Compare commits
175 Commits
25492caf15
...
feature_20
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a446f71c04 | ||
|
|
e02cbcd9b7 | ||
|
|
9bb9eab922 | ||
|
|
3d7b0045b7 | ||
|
|
ada9f6e778 | ||
|
|
07aebbece5 | ||
|
|
7a11800cba | ||
|
|
3b352be1a8 | ||
|
|
c49dee72eb | ||
|
|
7159e510a6 | ||
|
|
385d452f5a | ||
|
|
bdc823e122 | ||
|
|
c83d239219 | ||
|
|
c4900bd280 | ||
|
|
7736212235 | ||
|
|
348d8a0ec3 | ||
|
|
5a0d6e1569 | ||
|
|
bc2b6ae41c | ||
|
|
ac7e627b2d | ||
|
|
21e83ac1bc | ||
|
|
e2dd9e2648 | ||
|
|
f2463922f3 | ||
|
|
9aaad00f87 | ||
|
|
024126025d | ||
|
|
e2f9f3278f | ||
|
|
2d03c88f43 | ||
|
|
515b538c84 | ||
|
|
b52b54347d | ||
|
|
4954373b5b | ||
|
|
66cd6c3a29 | ||
|
|
ba99f55b16 | ||
|
|
2f69f83d16 | ||
|
|
3bd48e1ddd | ||
|
|
84914b3cca | ||
|
|
da455946a3 | ||
|
|
e734319ec4 | ||
|
|
faf2446203 | ||
|
|
83b24b6d54 | ||
|
|
ab7164681a | ||
|
|
bc6d370f55 | ||
|
|
42215b2d59 | ||
|
|
c34aa37731 | ||
|
|
2eb2a22495 | ||
|
|
6a4c475d3a | ||
|
|
e08b9d2104 | ||
|
|
3f1f438440 | ||
|
|
24720dbba0 | ||
|
|
7877c41e9c | ||
|
|
b25d48e167 | ||
|
|
804de885e1 | ||
|
|
6738a09e3a | ||
|
|
67340e9b82 | ||
|
|
00f2937a34 | ||
|
|
91ed649220 | ||
|
|
391955f88c | ||
|
|
59f4b1cdb9 | ||
|
|
3d6d01964d | ||
|
|
3f3e13bddd | ||
|
|
d27cf5b7d8 | ||
|
|
03bc2d681b | ||
|
|
1022fa4077 | ||
|
|
406b951e53 | ||
|
|
7f392619e7 | ||
|
|
09ca7265d7 | ||
|
|
276b280cb9 | ||
|
|
adfc0bd478 | ||
|
|
85a857dc19 | ||
|
|
b89837d22e | ||
|
|
942dd16800 | ||
|
|
35e3b66684 | ||
|
|
b9ea08e601 | ||
|
|
d9106bf9f7 | ||
|
|
fb42ef566b | ||
|
|
a424b3338d | ||
|
|
9e6e3ae322 | ||
|
|
e92cc09e06 | ||
|
|
23112db115 | ||
|
|
7c7c70c4d9 | ||
|
|
e049429b09 | ||
|
|
b8cd520014 | ||
|
|
96fe919164 | ||
|
|
4672a24353 | ||
|
|
26bc5fece0 | ||
|
|
1c35ea24cd | ||
|
|
d76b0d32d6 | ||
|
|
eb093a5189 | ||
|
|
2c0b06e6a0 | ||
|
|
b3fb472c66 | ||
|
|
6797f54b6c | ||
|
|
a47e0feed8 | ||
|
|
13fa91a998 | ||
|
|
fba7a7ee96 | ||
|
|
32a73efb55 | ||
|
|
7819b4f8a2 | ||
|
|
6f74c1c1de | ||
|
|
3fed9d2d65 | ||
|
|
514917c0eb | ||
|
|
6ce913d79b | ||
|
|
6d5594556b | ||
|
|
c32091e83e | ||
|
|
2994de98c2 | ||
|
|
c237a4dc0c | ||
|
|
395dc27fe2 | ||
|
|
3abee6b907 | ||
|
|
d86cef9f79 | ||
|
|
9aaf4400c1 | ||
|
|
1cd8a2d7e9 | ||
|
|
af3cdc24b1 | ||
|
|
bfb6ef63d0 | ||
|
|
722d038b56 | ||
|
|
a01532ce65 | ||
|
|
fbeb66fb39 | ||
| 4fd1a24db4 | |||
| 3cb9b4237b | |||
|
|
7c00763999 | ||
| d6d4bb8a12 | |||
|
|
5f6e4387e5 | ||
| 1adbeda168 | |||
| 92458a8705 | |||
| 45339902aa | |||
| 2482b01b00 | |||
|
|
38076534b1 | ||
| d29ebfd501 | |||
|
|
a7ab87f7c4 | ||
|
|
9a77bb6f0b | ||
| da44dcd522 | |||
|
|
bf8847698b | ||
| e501ac3819 | |||
|
|
7c83ffe008 | ||
|
|
8786fa7b06 | ||
|
|
0997cd9992 | ||
|
|
c8d704363d | ||
|
|
0de4a1f7af | ||
|
|
3382dd1036 | ||
|
|
9423094af2 | ||
|
|
4f38505a80 | ||
|
|
4274341ed5 | ||
|
|
40f6eaced6 | ||
|
|
2dd7dd755a | ||
|
|
04ce16df56 | ||
|
|
d7759b1da3 | ||
|
|
701f96855e | ||
| d9daaeed19 | |||
| 205fd880f8 | |||
|
|
cd1a5b743f | ||
|
|
18c83237e2 | ||
| a6276ec435 | |||
| 87118209fe | |||
| a2d8ff7422 | |||
|
|
c1e10e6205 | ||
|
|
cf7376cc5a | ||
|
|
023684b8b7 | ||
| b40ca0e23c | |||
|
|
726d808f5c | ||
|
|
0e862d82a0 | ||
|
|
27fff4e60b | ||
|
|
4954c58525 | ||
|
|
91bd581a5e | ||
|
|
258708fca0 | ||
|
|
90391729bb | ||
|
|
2148d319ad | ||
|
|
c61d58b0e3 | ||
|
|
ed1c7b9fa9 | ||
|
|
e8763331cc | ||
|
|
15f5c445c5 | ||
|
|
c704b12bce | ||
|
|
c37a25d264 | ||
|
|
da2007386e | ||
|
|
76f13d6098 | ||
|
|
641514bbfd | ||
|
|
a8c8fe4211 | ||
|
|
65f71603e1 | ||
|
|
915ac2ebd3 | ||
|
|
4a5cd891bd | ||
|
|
429e96475f |
110
.husky/pre-commit
Executable file
110
.husky/pre-commit
Executable file
@@ -0,0 +1,110 @@
|
||||
#!/bin/sh
|
||||
|
||||
# ============================================
|
||||
# Git Pre-commit Hook
|
||||
# ============================================
|
||||
# 规则:
|
||||
# 1. src 目录下新增的代码文件必须使用 TypeScript (.ts/.tsx)
|
||||
# 2. 修改的代码不能使用 fetch,应使用 axios
|
||||
# ============================================
|
||||
|
||||
# 颜色定义
|
||||
RED='\033[0;31m'
|
||||
YELLOW='\033[1;33m'
|
||||
GREEN='\033[0;32m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
has_error=0
|
||||
|
||||
echo ""
|
||||
echo "🔍 正在检查代码规范..."
|
||||
echo ""
|
||||
|
||||
# ============================================
|
||||
# 规则 1: 新文件必须使用 TypeScript
|
||||
# ============================================
|
||||
|
||||
# 获取新增的文件(只检查 src 目录下的代码文件)
|
||||
new_js_files=$(git diff --cached --name-only --diff-filter=A | grep -E '^src/.*\.(js|jsx)$' || true)
|
||||
|
||||
if [ -n "$new_js_files" ]; then
|
||||
echo "${RED}❌ 错误: 发现新增的 JavaScript 文件${NC}"
|
||||
echo "${YELLOW} 新文件必须使用 TypeScript (.ts/.tsx)${NC}"
|
||||
echo ""
|
||||
echo " 以下文件需要改为 TypeScript:"
|
||||
echo "$new_js_files" | while read file; do
|
||||
echo " - $file"
|
||||
done
|
||||
echo ""
|
||||
echo " 💡 提示: 请将文件扩展名改为 .ts 或 .tsx"
|
||||
echo ""
|
||||
has_error=1
|
||||
fi
|
||||
|
||||
# ============================================
|
||||
# 规则 2: 禁止使用 fetch,应使用 axios
|
||||
# ============================================
|
||||
|
||||
# 获取所有暂存的文件(新增 + 修改)
|
||||
staged_files=$(git diff --cached --name-only --diff-filter=AM | grep -E '^src/.*\.(js|jsx|ts|tsx)$' || true)
|
||||
|
||||
if [ -n "$staged_files" ]; then
|
||||
# 检查暂存内容中是否包含 fetch 调用
|
||||
# 使用 git diff --cached 检查实际修改的内容
|
||||
fetch_found=""
|
||||
|
||||
for file in $staged_files; do
|
||||
# 检查该文件暂存的更改中是否有 fetch 调用
|
||||
# 排除注释和字符串中的 fetch
|
||||
# 匹配: fetch(, await fetch, .fetch(
|
||||
fetch_matches=$(git diff --cached -U0 "$file" 2>/dev/null | grep -E '^\+.*[^a-zA-Z_]fetch\s*\(' | grep -v '^\+\s*//' || true)
|
||||
|
||||
if [ -n "$fetch_matches" ]; then
|
||||
fetch_found="$fetch_found
|
||||
$file"
|
||||
fi
|
||||
done
|
||||
|
||||
if [ -n "$fetch_found" ]; then
|
||||
echo "${RED}❌ 错误: 检测到使用了 fetch API${NC}"
|
||||
echo "${YELLOW} 请使用 axios 进行 HTTP 请求${NC}"
|
||||
echo ""
|
||||
echo " 以下文件包含 fetch 调用:"
|
||||
echo "$fetch_found" | while read file; do
|
||||
if [ -n "$file" ]; then
|
||||
echo " - $file"
|
||||
fi
|
||||
done
|
||||
echo ""
|
||||
echo " 💡 修改建议:"
|
||||
echo " ${GREEN}// 替换前${NC}"
|
||||
echo " fetch('/api/data').then(res => res.json())"
|
||||
echo ""
|
||||
echo " ${GREEN}// 替换后${NC}"
|
||||
echo " import axios from 'axios';"
|
||||
echo " axios.get('/api/data').then(res => res.data)"
|
||||
echo ""
|
||||
has_error=1
|
||||
fi
|
||||
fi
|
||||
|
||||
# ============================================
|
||||
# 检查结果
|
||||
# ============================================
|
||||
|
||||
if [ $has_error -eq 1 ]; then
|
||||
echo "${RED}========================================${NC}"
|
||||
echo "${RED}提交被阻止,请修复以上问题后重试${NC}"
|
||||
echo "${RED}========================================${NC}"
|
||||
echo ""
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "${GREEN}✅ 代码规范检查通过${NC}"
|
||||
echo ""
|
||||
|
||||
# 运行 lint-staged(如果配置了)
|
||||
# 可选:在 package.json 中添加 "lint-staged" 配置来启用代码格式化
|
||||
# if [ -f "package.json" ] && grep -q '"lint-staged"' package.json; then
|
||||
# npx lint-staged
|
||||
# fi
|
||||
569
app.py
569
app.py
@@ -165,7 +165,7 @@ WECHAT_OPEN_APPID = 'wxa8d74c47041b5f87'
|
||||
WECHAT_OPEN_APPSECRET = 'eedef95b11787fd7ca7f1acc6c9061bc'
|
||||
|
||||
# 微信公众号配置(H5 网页授权用)
|
||||
WECHAT_MP_APPID = 'wx4e4b759f8fa9e43a'
|
||||
WECHAT_MP_APPID = 'wx8afd36f7c7b21ba0'
|
||||
WECHAT_MP_APPSECRET = 'ef1ca9064af271bb0405330efbc495aa'
|
||||
|
||||
# 微信回调地址
|
||||
@@ -6412,6 +6412,10 @@ def get_stock_kline(stock_code):
|
||||
except ValueError:
|
||||
return jsonify({'error': 'Invalid event_time format'}), 400
|
||||
|
||||
# 确保股票代码包含后缀(ClickHouse 中数据带后缀)
|
||||
if '.' not in stock_code:
|
||||
stock_code = f"{stock_code}.SH" if stock_code.startswith('6') else f"{stock_code}.SZ"
|
||||
|
||||
# 获取股票名称
|
||||
with engine.connect() as conn:
|
||||
result = conn.execute(text(
|
||||
@@ -7819,7 +7823,7 @@ def get_index_realtime(index_code):
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取指数实时行情失败: {index_code}, 错误: {str(e)}")
|
||||
app.logger.error(f"获取指数实时行情失败: {index_code}, 错误: {str(e)}")
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': str(e),
|
||||
@@ -7837,8 +7841,13 @@ def get_index_kline(index_code):
|
||||
except ValueError:
|
||||
return jsonify({'error': 'Invalid event_time format'}), 400
|
||||
|
||||
# 确保指数代码包含后缀(ClickHouse 中数据带后缀)
|
||||
# 399xxx -> 深交所, 其他(000xxx等)-> 上交所
|
||||
if '.' not in index_code:
|
||||
index_code = f"{index_code}.SZ" if index_code.startswith('39') else f"{index_code}.SH"
|
||||
|
||||
# 指数名称(暂无索引表,先返回代码本身)
|
||||
index_name = index_code
|
||||
index_name = index_code.split('.')[0]
|
||||
|
||||
if chart_type == 'minute':
|
||||
return get_index_minute_kline(index_code, event_datetime, index_name)
|
||||
@@ -12044,10 +12053,11 @@ def get_market_summary(seccode):
|
||||
|
||||
@app.route('/api/stocks/search', methods=['GET'])
|
||||
def search_stocks():
|
||||
"""搜索股票(支持股票代码、股票简称、拼音首字母)"""
|
||||
"""搜索股票和指数(支持代码、名称搜索)"""
|
||||
try:
|
||||
query = request.args.get('q', '').strip()
|
||||
limit = request.args.get('limit', 20, type=int)
|
||||
search_type = request.args.get('type', 'all') # all, stock, index
|
||||
|
||||
if not query:
|
||||
return jsonify({
|
||||
@@ -12055,73 +12065,132 @@ def search_stocks():
|
||||
'error': '请输入搜索关键词'
|
||||
}), 400
|
||||
|
||||
results = []
|
||||
|
||||
with engine.connect() as conn:
|
||||
test_sql = text("""
|
||||
SELECT SECCODE, SECNAME, F001V, F003V, F010V, F011V
|
||||
FROM ea_stocklist
|
||||
WHERE SECCODE = '300750'
|
||||
OR F001V LIKE '%ndsd%' LIMIT 5
|
||||
""")
|
||||
test_result = conn.execute(test_sql).fetchall()
|
||||
# 搜索指数(优先显示指数,因为通常用户搜索代码时指数更常用)
|
||||
if search_type in ('all', 'index'):
|
||||
index_sql = text("""
|
||||
SELECT DISTINCT
|
||||
INDEXCODE as stock_code,
|
||||
SECNAME as stock_name,
|
||||
INDEXNAME as full_name,
|
||||
F018V as exchange
|
||||
FROM ea_exchangeindex
|
||||
WHERE (
|
||||
UPPER(INDEXCODE) LIKE UPPER(:query_pattern)
|
||||
OR UPPER(SECNAME) LIKE UPPER(:query_pattern)
|
||||
OR UPPER(INDEXNAME) LIKE UPPER(:query_pattern)
|
||||
)
|
||||
ORDER BY CASE
|
||||
WHEN UPPER(INDEXCODE) = UPPER(:exact_query) THEN 1
|
||||
WHEN UPPER(SECNAME) = UPPER(:exact_query) THEN 2
|
||||
WHEN UPPER(INDEXCODE) LIKE UPPER(:prefix_pattern) THEN 3
|
||||
WHEN UPPER(SECNAME) LIKE UPPER(:prefix_pattern) THEN 4
|
||||
ELSE 5
|
||||
END,
|
||||
INDEXCODE
|
||||
LIMIT :limit
|
||||
""")
|
||||
|
||||
# 构建搜索SQL - 支持股票代码、股票简称、拼音简称搜索
|
||||
search_sql = text("""
|
||||
SELECT DISTINCT SECCODE as stock_code,
|
||||
SECNAME as stock_name,
|
||||
F001V as pinyin_abbr,
|
||||
F003V as security_type,
|
||||
F005V as exchange,
|
||||
F011V as listing_status
|
||||
FROM ea_stocklist
|
||||
WHERE (
|
||||
UPPER(SECCODE) LIKE UPPER(:query_pattern)
|
||||
OR UPPER(SECNAME) LIKE UPPER(:query_pattern)
|
||||
OR UPPER(F001V) LIKE UPPER(:query_pattern)
|
||||
)
|
||||
-- 基本过滤条件:只搜索正常的A股和B股
|
||||
AND (F011V = '正常上市' OR F010V = '013001') -- 正常上市状态
|
||||
AND F003V IN ('A股', 'B股') -- 只搜索A股和B股
|
||||
ORDER BY CASE
|
||||
WHEN UPPER(SECCODE) = UPPER(:exact_query) THEN 1
|
||||
WHEN UPPER(SECNAME) = UPPER(:exact_query) THEN 2
|
||||
WHEN UPPER(F001V) = UPPER(:exact_query) THEN 3
|
||||
WHEN UPPER(SECCODE) LIKE UPPER(:prefix_pattern) THEN 4
|
||||
WHEN UPPER(SECNAME) LIKE UPPER(:prefix_pattern) THEN 5
|
||||
WHEN UPPER(F001V) LIKE UPPER(:prefix_pattern) THEN 6
|
||||
ELSE 7
|
||||
END,
|
||||
SECCODE LIMIT :limit
|
||||
""")
|
||||
index_result = conn.execute(index_sql, {
|
||||
'query_pattern': f'%{query}%',
|
||||
'exact_query': query,
|
||||
'prefix_pattern': f'{query}%',
|
||||
'limit': limit
|
||||
}).fetchall()
|
||||
|
||||
result = conn.execute(search_sql, {
|
||||
'query_pattern': f'%{query}%',
|
||||
'exact_query': query,
|
||||
'prefix_pattern': f'{query}%',
|
||||
'limit': limit
|
||||
}).fetchall()
|
||||
for row in index_result:
|
||||
results.append({
|
||||
'stock_code': row.stock_code,
|
||||
'stock_name': row.stock_name,
|
||||
'full_name': row.full_name,
|
||||
'exchange': row.exchange,
|
||||
'isIndex': True,
|
||||
'security_type': '指数'
|
||||
})
|
||||
|
||||
stocks = []
|
||||
for row in result:
|
||||
# 获取当前价格
|
||||
current_price, _ = get_latest_price_from_clickhouse(row.stock_code)
|
||||
# 搜索股票
|
||||
if search_type in ('all', 'stock'):
|
||||
stock_sql = text("""
|
||||
SELECT DISTINCT SECCODE as stock_code,
|
||||
SECNAME as stock_name,
|
||||
F001V as pinyin_abbr,
|
||||
F003V as security_type,
|
||||
F005V as exchange,
|
||||
F011V as listing_status
|
||||
FROM ea_stocklist
|
||||
WHERE (
|
||||
UPPER(SECCODE) LIKE UPPER(:query_pattern)
|
||||
OR UPPER(SECNAME) LIKE UPPER(:query_pattern)
|
||||
OR UPPER(F001V) LIKE UPPER(:query_pattern)
|
||||
)
|
||||
AND (F011V = '正常上市' OR F010V = '013001')
|
||||
AND F003V IN ('A股', 'B股')
|
||||
ORDER BY CASE
|
||||
WHEN UPPER(SECCODE) = UPPER(:exact_query) THEN 1
|
||||
WHEN UPPER(SECNAME) = UPPER(:exact_query) THEN 2
|
||||
WHEN UPPER(F001V) = UPPER(:exact_query) THEN 3
|
||||
WHEN UPPER(SECCODE) LIKE UPPER(:prefix_pattern) THEN 4
|
||||
WHEN UPPER(SECNAME) LIKE UPPER(:prefix_pattern) THEN 5
|
||||
WHEN UPPER(F001V) LIKE UPPER(:prefix_pattern) THEN 6
|
||||
ELSE 7
|
||||
END,
|
||||
SECCODE
|
||||
LIMIT :limit
|
||||
""")
|
||||
|
||||
stocks.append({
|
||||
'stock_code': row.stock_code,
|
||||
'stock_name': row.stock_name,
|
||||
'current_price': current_price or 0, # 添加当前价格
|
||||
'pinyin_abbr': row.pinyin_abbr,
|
||||
'security_type': row.security_type,
|
||||
'exchange': row.exchange,
|
||||
'listing_status': row.listing_status
|
||||
})
|
||||
stock_result = conn.execute(stock_sql, {
|
||||
'query_pattern': f'%{query}%',
|
||||
'exact_query': query,
|
||||
'prefix_pattern': f'{query}%',
|
||||
'limit': limit
|
||||
}).fetchall()
|
||||
|
||||
for row in stock_result:
|
||||
results.append({
|
||||
'stock_code': row.stock_code,
|
||||
'stock_name': row.stock_name,
|
||||
'pinyin_abbr': row.pinyin_abbr,
|
||||
'security_type': row.security_type,
|
||||
'exchange': row.exchange,
|
||||
'listing_status': row.listing_status,
|
||||
'isIndex': False
|
||||
})
|
||||
|
||||
# 如果搜索全部,按相关性重新排序(精确匹配优先)
|
||||
if search_type == 'all':
|
||||
def sort_key(item):
|
||||
code = item['stock_code'].upper()
|
||||
name = item['stock_name'].upper()
|
||||
q = query.upper()
|
||||
# 精确匹配代码优先
|
||||
if code == q:
|
||||
return (0, not item['isIndex'], code) # 指数优先
|
||||
# 精确匹配名称
|
||||
if name == q:
|
||||
return (1, not item['isIndex'], code)
|
||||
# 前缀匹配代码
|
||||
if code.startswith(q):
|
||||
return (2, not item['isIndex'], code)
|
||||
# 前缀匹配名称
|
||||
if name.startswith(q):
|
||||
return (3, not item['isIndex'], code)
|
||||
return (4, not item['isIndex'], code)
|
||||
|
||||
results.sort(key=sort_key)
|
||||
|
||||
# 限制总数
|
||||
results = results[:limit]
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'data': stocks,
|
||||
'count': len(stocks)
|
||||
'data': results,
|
||||
'count': len(results)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
app.logger.error(f"搜索股票/指数错误: {e}")
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
@@ -12465,6 +12534,10 @@ def get_hotspot_overview():
|
||||
"""
|
||||
获取热点概览数据(用于个股中心的热点概览图表)
|
||||
返回:指数分时数据 + 概念异动标注
|
||||
|
||||
数据来源:
|
||||
- 指数分时:ClickHouse index_minute 表
|
||||
- 概念异动:MySQL concept_anomaly_hybrid 表(来自 realtime_detector.py)
|
||||
"""
|
||||
try:
|
||||
trade_date = request.args.get('date')
|
||||
@@ -12532,60 +12605,135 @@ def get_hotspot_overview():
|
||||
'change_pct': change_pct
|
||||
})
|
||||
|
||||
# 2. 获取概念异动数据
|
||||
# 2. 获取概念异动数据(优先从 V2 表,fallback 到旧表)
|
||||
alerts = []
|
||||
use_v2 = False
|
||||
|
||||
with engine.connect() as conn:
|
||||
alert_result = conn.execute(text("""
|
||||
SELECT
|
||||
concept_id, concept_name, alert_time, alert_type,
|
||||
change_pct, change_delta, limit_up_count, limit_up_delta,
|
||||
rank_position, index_price, index_change_pct,
|
||||
stock_count, concept_type, extra_info,
|
||||
prev_change_pct, zscore, importance_score
|
||||
FROM concept_minute_alert
|
||||
WHERE trade_date = :trade_date
|
||||
ORDER BY alert_time
|
||||
"""), {'trade_date': trade_date})
|
||||
# 尝试查询 V2 表(时间片对齐 + 持续确认版本)
|
||||
try:
|
||||
v2_result = conn.execute(text("""
|
||||
SELECT
|
||||
concept_id, alert_time, trade_date, alert_type,
|
||||
final_score, rule_score, ml_score, trigger_reason, confirm_ratio,
|
||||
alpha, alpha_zscore, amt_zscore, rank_zscore,
|
||||
momentum_3m, momentum_5m, limit_up_ratio, triggered_rules
|
||||
FROM concept_anomaly_v2
|
||||
WHERE trade_date = :trade_date
|
||||
ORDER BY alert_time
|
||||
"""), {'trade_date': trade_date})
|
||||
v2_rows = v2_result.fetchall()
|
||||
if v2_rows:
|
||||
use_v2 = True
|
||||
for row in v2_rows:
|
||||
triggered_rules = None
|
||||
if row[16]:
|
||||
try:
|
||||
triggered_rules = json.loads(row[16]) if isinstance(row[16], str) else row[16]
|
||||
except:
|
||||
pass
|
||||
|
||||
for row in alert_result:
|
||||
alert_time = row[2]
|
||||
extra_info = None
|
||||
if row[13]:
|
||||
try:
|
||||
extra_info = json.loads(row[13]) if isinstance(row[13], str) else row[13]
|
||||
except:
|
||||
pass
|
||||
alerts.append({
|
||||
'concept_id': row[0],
|
||||
'concept_name': row[0], # 后面会填充
|
||||
'time': row[1].strftime('%H:%M') if row[1] else None,
|
||||
'timestamp': row[1].isoformat() if row[1] else None,
|
||||
'alert_type': row[3],
|
||||
'final_score': float(row[4]) if row[4] else None,
|
||||
'rule_score': float(row[5]) if row[5] else None,
|
||||
'ml_score': float(row[6]) if row[6] else None,
|
||||
'trigger_reason': row[7],
|
||||
# V2 新增字段
|
||||
'confirm_ratio': float(row[8]) if row[8] else None,
|
||||
'alpha': float(row[9]) if row[9] else None,
|
||||
'alpha_zscore': float(row[10]) if row[10] else None,
|
||||
'amt_zscore': float(row[11]) if row[11] else None,
|
||||
'rank_zscore': float(row[12]) if row[12] else None,
|
||||
'momentum_3m': float(row[13]) if row[13] else None,
|
||||
'momentum_5m': float(row[14]) if row[14] else None,
|
||||
'limit_up_ratio': float(row[15]) if row[15] else 0,
|
||||
'triggered_rules': triggered_rules,
|
||||
# 兼容字段
|
||||
'importance_score': float(row[4]) / 100 if row[4] else None,
|
||||
'is_v2': True,
|
||||
})
|
||||
except Exception as v2_err:
|
||||
app.logger.debug(f"V2 表查询失败,使用旧表: {v2_err}")
|
||||
|
||||
# 从 extra_info 提取 zscore 和 importance_score(兼容旧数据)
|
||||
zscore = None
|
||||
importance_score = None
|
||||
if len(row) > 15:
|
||||
zscore = float(row[15]) if row[15] else None
|
||||
importance_score = float(row[16]) if row[16] else None
|
||||
if extra_info:
|
||||
zscore = zscore or extra_info.get('zscore')
|
||||
importance_score = importance_score or extra_info.get('importance_score')
|
||||
# Fallback: 查询旧表
|
||||
if not use_v2:
|
||||
try:
|
||||
alert_result = conn.execute(text("""
|
||||
SELECT
|
||||
a.concept_id, a.alert_time, a.trade_date, a.alert_type,
|
||||
a.final_score, a.rule_score, a.ml_score, a.trigger_reason,
|
||||
a.alpha, a.alpha_delta, a.amt_ratio, a.amt_delta,
|
||||
a.rank_pct, a.limit_up_ratio, a.stock_count, a.total_amt,
|
||||
a.triggered_rules
|
||||
FROM concept_anomaly_hybrid a
|
||||
WHERE a.trade_date = :trade_date
|
||||
ORDER BY a.alert_time
|
||||
"""), {'trade_date': trade_date})
|
||||
|
||||
alerts.append({
|
||||
'concept_id': row[0],
|
||||
'concept_name': row[1],
|
||||
'time': alert_time.strftime('%H:%M') if alert_time else None,
|
||||
'timestamp': alert_time.isoformat() if alert_time else None,
|
||||
'alert_type': row[3],
|
||||
'change_pct': float(row[4]) if row[4] else None,
|
||||
'change_delta': float(row[5]) if row[5] else None,
|
||||
'limit_up_count': row[6],
|
||||
'limit_up_delta': row[7],
|
||||
'rank_position': row[8],
|
||||
'index_price': float(row[9]) if row[9] else None,
|
||||
'index_change_pct': float(row[10]) if row[10] else None,
|
||||
'stock_count': row[11],
|
||||
'concept_type': row[12],
|
||||
'extra_info': extra_info,
|
||||
'prev_change_pct': float(row[14]) if len(row) > 14 and row[14] else None,
|
||||
'zscore': zscore,
|
||||
'importance_score': importance_score
|
||||
})
|
||||
for row in alert_result:
|
||||
triggered_rules = None
|
||||
if row[16]:
|
||||
try:
|
||||
triggered_rules = json.loads(row[16]) if isinstance(row[16], str) else row[16]
|
||||
except:
|
||||
pass
|
||||
|
||||
limit_up_ratio = float(row[13]) if row[13] else 0
|
||||
stock_count = int(row[14]) if row[14] else 0
|
||||
limit_up_count = int(limit_up_ratio * stock_count) if stock_count > 0 else 0
|
||||
|
||||
alerts.append({
|
||||
'concept_id': row[0],
|
||||
'concept_name': row[0],
|
||||
'time': row[1].strftime('%H:%M') if row[1] else None,
|
||||
'timestamp': row[1].isoformat() if row[1] else None,
|
||||
'alert_type': row[3],
|
||||
'final_score': float(row[4]) if row[4] else None,
|
||||
'rule_score': float(row[5]) if row[5] else None,
|
||||
'ml_score': float(row[6]) if row[6] else None,
|
||||
'trigger_reason': row[7],
|
||||
'alpha': float(row[8]) if row[8] else None,
|
||||
'alpha_delta': float(row[9]) if row[9] else None,
|
||||
'amt_ratio': float(row[10]) if row[10] else None,
|
||||
'amt_delta': float(row[11]) if row[11] else None,
|
||||
'rank_pct': float(row[12]) if row[12] else None,
|
||||
'limit_up_ratio': limit_up_ratio,
|
||||
'limit_up_count': limit_up_count,
|
||||
'stock_count': stock_count,
|
||||
'total_amt': float(row[15]) if row[15] else None,
|
||||
'triggered_rules': triggered_rules,
|
||||
'importance_score': float(row[4]) / 100 if row[4] else None,
|
||||
'is_v2': False,
|
||||
})
|
||||
except Exception as old_err:
|
||||
app.logger.debug(f"旧表查询也失败: {old_err}")
|
||||
|
||||
# 尝试批量获取概念名称
|
||||
if alerts:
|
||||
concept_ids = list(set(a['concept_id'] for a in alerts))
|
||||
concept_names = {} # 初始化 concept_names 字典
|
||||
try:
|
||||
from elasticsearch import Elasticsearch
|
||||
es_client = Elasticsearch(["http://222.128.1.157:19200"])
|
||||
es_result = es_client.mget(
|
||||
index='concept_library_v3',
|
||||
body={'ids': concept_ids},
|
||||
_source=['concept']
|
||||
)
|
||||
for doc in es_result.get('docs', []):
|
||||
if doc.get('found') and doc.get('_source'):
|
||||
concept_names[doc['_id']] = doc['_source'].get('concept', doc['_id'])
|
||||
# 更新 alerts 中的概念名称
|
||||
for alert in alerts:
|
||||
if alert['concept_id'] in concept_names:
|
||||
alert['concept_name'] = concept_names[alert['concept_id']]
|
||||
except Exception as e:
|
||||
app.logger.warning(f"获取概念名称失败: {e}")
|
||||
|
||||
# 计算统计信息
|
||||
day_high = max([d['price'] for d in index_timeline if d['price']], default=None)
|
||||
@@ -12614,6 +12762,7 @@ def get_hotspot_overview():
|
||||
'surge_up': len([a for a in alerts if a['alert_type'] == 'surge_up']),
|
||||
'surge_down': len([a for a in alerts if a['alert_type'] == 'surge_down']),
|
||||
'limit_up': len([a for a in alerts if a['alert_type'] == 'limit_up']),
|
||||
'volume_spike': len([a for a in alerts if a['alert_type'] == 'volume_spike']),
|
||||
'rank_jump': len([a for a in alerts if a['alert_type'] == 'rank_jump'])
|
||||
}
|
||||
}
|
||||
@@ -12621,7 +12770,205 @@ def get_hotspot_overview():
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"获取热点概览数据失败: {traceback.format_exc()}")
|
||||
error_trace = traceback.format_exc()
|
||||
app.logger.error(f"获取热点概览数据失败: {error_trace}")
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': str(e),
|
||||
'traceback': error_trace # 临时返回完整错误信息用于调试
|
||||
}), 500
|
||||
|
||||
|
||||
@app.route('/api/concept/<concept_id>/stocks', methods=['GET'])
|
||||
def get_concept_stocks(concept_id):
|
||||
"""
|
||||
获取概念的相关股票列表(带实时涨跌幅)
|
||||
|
||||
Args:
|
||||
concept_id: 概念 ID(来自 ES concept_library_v3)
|
||||
|
||||
Returns:
|
||||
- stocks: 股票列表 [{code, name, reason, change_pct}, ...]
|
||||
"""
|
||||
try:
|
||||
from elasticsearch import Elasticsearch
|
||||
from clickhouse_driver import Client
|
||||
|
||||
# 1. 从 ES 获取概念的股票列表
|
||||
es_client = Elasticsearch(["http://222.128.1.157:19200"])
|
||||
es_result = es_client.get(index='concept_library_v3', id=concept_id)
|
||||
|
||||
if not es_result.get('found'):
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': f'概念 {concept_id} 不存在'
|
||||
}), 404
|
||||
|
||||
source = es_result.get('_source', {})
|
||||
concept_name = source.get('concept', concept_id)
|
||||
raw_stocks = source.get('stocks', [])
|
||||
|
||||
if not raw_stocks:
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'data': {
|
||||
'concept_id': concept_id,
|
||||
'concept_name': concept_name,
|
||||
'stocks': []
|
||||
}
|
||||
})
|
||||
|
||||
# 提取股票代码和原因
|
||||
stocks_info = []
|
||||
stock_codes = []
|
||||
for s in raw_stocks:
|
||||
if isinstance(s, dict):
|
||||
code = s.get('code', '')
|
||||
if code and len(code) == 6:
|
||||
stocks_info.append({
|
||||
'code': code,
|
||||
'name': s.get('name', ''),
|
||||
'reason': s.get('reason', '')
|
||||
})
|
||||
stock_codes.append(code)
|
||||
|
||||
if not stock_codes:
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'data': {
|
||||
'concept_id': concept_id,
|
||||
'concept_name': concept_name,
|
||||
'stocks': stocks_info
|
||||
}
|
||||
})
|
||||
|
||||
# 2. 获取最新交易日和前一交易日
|
||||
today = datetime.now().date()
|
||||
trading_day = None
|
||||
prev_trading_day = None
|
||||
|
||||
with engine.connect() as conn:
|
||||
# 获取最新交易日
|
||||
result = conn.execute(text("""
|
||||
SELECT EXCHANGE_DATE FROM trading_days
|
||||
WHERE EXCHANGE_DATE <= :today
|
||||
ORDER BY EXCHANGE_DATE DESC LIMIT 1
|
||||
"""), {"today": today}).fetchone()
|
||||
if result:
|
||||
trading_day = result[0].date() if hasattr(result[0], 'date') else result[0]
|
||||
|
||||
# 获取前一交易日
|
||||
if trading_day:
|
||||
result = conn.execute(text("""
|
||||
SELECT EXCHANGE_DATE FROM trading_days
|
||||
WHERE EXCHANGE_DATE < :date
|
||||
ORDER BY EXCHANGE_DATE DESC LIMIT 1
|
||||
"""), {"date": trading_day}).fetchone()
|
||||
if result:
|
||||
prev_trading_day = result[0].date() if hasattr(result[0], 'date') else result[0]
|
||||
|
||||
# 3. 从 MySQL ea_trade 获取前一交易日收盘价(F007N)
|
||||
prev_close_map = {}
|
||||
if prev_trading_day and stock_codes:
|
||||
with engine.connect() as conn:
|
||||
placeholders = ','.join([f':code{i}' for i in range(len(stock_codes))])
|
||||
params = {f'code{i}': code for i, code in enumerate(stock_codes)}
|
||||
params['trade_date'] = prev_trading_day
|
||||
|
||||
result = conn.execute(text(f"""
|
||||
SELECT SECCODE, F007N
|
||||
FROM ea_trade
|
||||
WHERE SECCODE IN ({placeholders})
|
||||
AND TRADEDATE = :trade_date
|
||||
AND F007N > 0
|
||||
"""), params).fetchall()
|
||||
|
||||
prev_close_map = {row[0]: float(row[1]) for row in result if row[1]}
|
||||
|
||||
# 4. 从 ClickHouse 获取最新价格
|
||||
current_price_map = {}
|
||||
if stock_codes:
|
||||
try:
|
||||
ch_client = Client(
|
||||
host='127.0.0.1',
|
||||
port=9000,
|
||||
user='default',
|
||||
password='Zzl33818!',
|
||||
database='stock'
|
||||
)
|
||||
|
||||
# 转换为 ClickHouse 格式
|
||||
ch_codes = []
|
||||
code_mapping = {}
|
||||
for code in stock_codes:
|
||||
if code.startswith('6'):
|
||||
ch_code = f"{code}.SH"
|
||||
elif code.startswith('0') or code.startswith('3'):
|
||||
ch_code = f"{code}.SZ"
|
||||
else:
|
||||
ch_code = f"{code}.BJ"
|
||||
ch_codes.append(ch_code)
|
||||
code_mapping[ch_code] = code
|
||||
|
||||
ch_codes_str = "','".join(ch_codes)
|
||||
|
||||
# 查询当天最新价格
|
||||
query = f"""
|
||||
SELECT code, close
|
||||
FROM stock_minute
|
||||
WHERE code IN ('{ch_codes_str}')
|
||||
AND toDate(timestamp) = today()
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 1 BY code
|
||||
"""
|
||||
result = ch_client.execute(query)
|
||||
|
||||
for row in result:
|
||||
ch_code, close_price = row
|
||||
if ch_code in code_mapping and close_price:
|
||||
original_code = code_mapping[ch_code]
|
||||
current_price_map[original_code] = float(close_price)
|
||||
|
||||
except Exception as ch_err:
|
||||
app.logger.warning(f"ClickHouse 获取价格失败: {ch_err}")
|
||||
|
||||
# 5. 计算涨跌幅并合并数据
|
||||
result_stocks = []
|
||||
for stock in stocks_info:
|
||||
code = stock['code']
|
||||
prev_close = prev_close_map.get(code)
|
||||
current_price = current_price_map.get(code)
|
||||
|
||||
change_pct = None
|
||||
if prev_close and current_price and prev_close > 0:
|
||||
change_pct = round((current_price - prev_close) / prev_close * 100, 2)
|
||||
|
||||
result_stocks.append({
|
||||
'code': code,
|
||||
'name': stock['name'],
|
||||
'reason': stock['reason'],
|
||||
'change_pct': change_pct,
|
||||
'price': current_price,
|
||||
'prev_close': prev_close
|
||||
})
|
||||
|
||||
# 按涨跌幅排序(涨停优先)
|
||||
result_stocks.sort(key=lambda x: x.get('change_pct') if x.get('change_pct') is not None else -999, reverse=True)
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'data': {
|
||||
'concept_id': concept_id,
|
||||
'concept_name': concept_name,
|
||||
'stock_count': len(result_stocks),
|
||||
'trading_day': str(trading_day) if trading_day else None,
|
||||
'stocks': result_stocks
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
app.logger.error(f"获取概念股票失败: {traceback.format_exc()}")
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
@@ -12724,7 +13071,7 @@ def get_concept_alerts():
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"获取概念异动列表失败: {traceback.format_exc()}")
|
||||
app.logger.error(f"获取概念异动列表失败: {traceback.format_exc()}")
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,28 +0,0 @@
|
||||
2025-12-08 16:40:41,567 - INFO - ============================================================
|
||||
2025-12-08 16:40:41,567 - INFO - 🔄 回测: 2025-12-08 (Alpha Z-Score 方法)
|
||||
2025-12-08 16:40:41,569 - INFO - ============================================================
|
||||
2025-12-08 16:40:41,679 - INFO - 已清除 2025-12-08 的数据
|
||||
2025-12-08 16:40:41,903 - INFO - POST http://222.128.1.157:19200/concept_library_v3/_search?scroll=2m [status:200 duration:0.224s]
|
||||
2025-12-08 16:40:42,105 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.197s]
|
||||
2025-12-08 16:40:42,330 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.178s]
|
||||
2025-12-08 16:40:42,518 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.183s]
|
||||
2025-12-08 16:40:42,704 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.182s]
|
||||
2025-12-08 16:40:42,894 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.186s]
|
||||
2025-12-08 16:40:43,060 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.162s]
|
||||
2025-12-08 16:40:43,234 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.171s]
|
||||
2025-12-08 16:40:43,383 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.145s]
|
||||
2025-12-08 16:40:43,394 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.008s]
|
||||
2025-12-08 16:40:43,399 - INFO - DELETE http://222.128.1.157:19200/_search/scroll [status:200 duration:0.005s]
|
||||
2025-12-08 16:40:43,409 - INFO - 概念: 968, 股票: 5938
|
||||
2025-12-08 16:40:43,505 - INFO - 时间点: 241
|
||||
2025-12-08 16:41:02,028 - INFO - 进度: 30/241 (12%), 异动: 0
|
||||
2025-12-08 16:41:20,851 - INFO - 进度: 60/241 (24%), 异动: 0
|
||||
2025-12-08 16:41:39,396 - INFO - 进度: 90/241 (37%), 异动: 0
|
||||
2025-12-08 16:41:58,687 - INFO - 进度: 120/241 (49%), 异动: 0
|
||||
2025-12-08 16:43:08,124 - INFO - 进度: 150/241 (62%), 异动: 0
|
||||
2025-12-08 16:43:26,973 - INFO - 进度: 180/241 (74%), 异动: 0
|
||||
2025-12-08 16:43:45,746 - INFO - 进度: 210/241 (87%), 异动: 0
|
||||
2025-12-08 16:44:04,479 - INFO - 进度: 240/241 (99%), 异动: 0
|
||||
2025-12-08 16:44:05,123 - INFO - ============================================================
|
||||
2025-12-08 16:44:05,123 - INFO - ✅ 回测完成! 检测到 0 条异动
|
||||
2025-12-08 16:44:05,125 - INFO - ============================================================
|
||||
1625
concept_alert_ml.py
1625
concept_alert_ml.py
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,681 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
概念涨跌幅实时更新服务
|
||||
- 在交易时间段每分钟从ClickHouse获取最新分钟数据
|
||||
- 计算涨跌幅后更新MySQL的concept_daily_stats表
|
||||
- 支持叶子概念和母概念(lv1/lv2/lv3)
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy import create_engine, text
|
||||
from elasticsearch import Elasticsearch
|
||||
from clickhouse_driver import Client
|
||||
import time
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
import hashlib
|
||||
import argparse
|
||||
|
||||
# ==================== 配置 ====================
|
||||
|
||||
# MySQL配置
|
||||
MYSQL_ENGINE = create_engine(
|
||||
"mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock",
|
||||
echo=False
|
||||
)
|
||||
|
||||
# Elasticsearch配置
|
||||
ES_CLIENT = Elasticsearch(['http://222.128.1.157:19200'])
|
||||
INDEX_NAME = 'concept_library_v3'
|
||||
|
||||
# ClickHouse配置
|
||||
CLICKHOUSE_CONFIG = {
|
||||
'host': '222.128.1.157',
|
||||
'port': 18000,
|
||||
'user': 'default',
|
||||
'password': 'Zzl33818!',
|
||||
'database': 'stock'
|
||||
}
|
||||
|
||||
# 层级结构文件
|
||||
HIERARCHY_FILE = 'concept_hierarchy_v3.json'
|
||||
|
||||
# 交易时间配置
|
||||
TRADING_HOURS = {
|
||||
'morning_start': (9, 30),
|
||||
'morning_end': (11, 30),
|
||||
'afternoon_start': (13, 0),
|
||||
'afternoon_end': (15, 0),
|
||||
}
|
||||
|
||||
# ==================== 日志配置 ====================
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler(f'concept_realtime_{datetime.now().strftime("%Y%m%d")}.log', encoding='utf-8'),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ClickHouse客户端
|
||||
ch_client = None
|
||||
|
||||
|
||||
def get_ch_client():
|
||||
"""获取ClickHouse客户端"""
|
||||
global ch_client
|
||||
if ch_client is None:
|
||||
ch_client = Client(**CLICKHOUSE_CONFIG)
|
||||
return ch_client
|
||||
|
||||
|
||||
def generate_id(name: str) -> str:
|
||||
"""生成概念ID"""
|
||||
return hashlib.md5(name.encode('utf-8')).hexdigest()[:16]
|
||||
|
||||
|
||||
def code_to_ch_format(code: str) -> str:
|
||||
"""将6位股票代码转换为ClickHouse格式(带后缀)
|
||||
|
||||
规则:
|
||||
- 6开头 -> .SH(上海)
|
||||
- 0或3开头 -> .SZ(深圳)
|
||||
- 其他 -> .BJ(北京)
|
||||
- 非6位数字的忽略(可能是港股)
|
||||
"""
|
||||
if not code or len(code) != 6 or not code.isdigit():
|
||||
return None
|
||||
|
||||
if code.startswith('6'):
|
||||
return f"{code}.SH"
|
||||
elif code.startswith('0') or code.startswith('3'):
|
||||
return f"{code}.SZ"
|
||||
else:
|
||||
return f"{code}.BJ"
|
||||
|
||||
|
||||
def ch_code_to_pure(ch_code: str) -> str:
|
||||
"""将ClickHouse格式的股票代码转回纯6位代码"""
|
||||
if not ch_code:
|
||||
return None
|
||||
return ch_code.split('.')[0]
|
||||
|
||||
|
||||
# ==================== 概念数据获取 ====================
|
||||
|
||||
def get_all_concepts():
|
||||
"""从ES获取所有叶子概念及其股票列表"""
|
||||
concepts = []
|
||||
|
||||
query = {
|
||||
"query": {"match_all": {}},
|
||||
"size": 100,
|
||||
"_source": ["concept_id", "concept", "stocks"]
|
||||
}
|
||||
|
||||
resp = ES_CLIENT.search(index=INDEX_NAME, body=query, scroll='2m')
|
||||
scroll_id = resp['_scroll_id']
|
||||
hits = resp['hits']['hits']
|
||||
|
||||
while len(hits) > 0:
|
||||
for hit in hits:
|
||||
source = hit['_source']
|
||||
concept_info = {
|
||||
'concept_id': source.get('concept_id'),
|
||||
'concept_name': source.get('concept'),
|
||||
'stocks': [],
|
||||
'concept_type': 'leaf'
|
||||
}
|
||||
|
||||
# v3索引的stocks字段是 [{name, code}, ...]
|
||||
if 'stocks' in source and isinstance(source['stocks'], list):
|
||||
for stock in source['stocks']:
|
||||
if isinstance(stock, dict) and 'code' in stock and stock['code']:
|
||||
concept_info['stocks'].append(stock['code'])
|
||||
|
||||
if concept_info['stocks']:
|
||||
concepts.append(concept_info)
|
||||
|
||||
resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m')
|
||||
scroll_id = resp['_scroll_id']
|
||||
hits = resp['hits']['hits']
|
||||
|
||||
ES_CLIENT.clear_scroll(scroll_id=scroll_id)
|
||||
return concepts
|
||||
|
||||
|
||||
def load_hierarchy_concepts(leaf_concepts: list) -> list:
|
||||
"""加载层级结构,生成母概念(lv1/lv2/lv3)"""
|
||||
hierarchy_path = os.path.join(os.path.dirname(__file__), HIERARCHY_FILE)
|
||||
if not os.path.exists(hierarchy_path):
|
||||
logger.warning(f"层级文件不存在: {hierarchy_path}")
|
||||
return []
|
||||
|
||||
with open(hierarchy_path, 'r', encoding='utf-8') as f:
|
||||
hierarchy_data = json.load(f)
|
||||
|
||||
# 建立概念名称到股票的映射
|
||||
concept_to_stocks = {}
|
||||
for c in leaf_concepts:
|
||||
concept_to_stocks[c['concept_name']] = set(c['stocks'])
|
||||
|
||||
parent_concepts = []
|
||||
|
||||
for lv1 in hierarchy_data.get('hierarchy', []):
|
||||
lv1_name = lv1.get('lv1', '')
|
||||
lv1_stocks = set()
|
||||
|
||||
for child in lv1.get('children', []):
|
||||
lv2_name = child.get('lv2', '')
|
||||
lv2_stocks = set()
|
||||
|
||||
if 'children' in child:
|
||||
for lv3_child in child.get('children', []):
|
||||
lv3_name = lv3_child.get('lv3', '')
|
||||
lv3_stocks = set()
|
||||
|
||||
for concept_name in lv3_child.get('concepts', []):
|
||||
if concept_name in concept_to_stocks:
|
||||
lv3_stocks.update(concept_to_stocks[concept_name])
|
||||
|
||||
if lv3_stocks:
|
||||
parent_concepts.append({
|
||||
'concept_id': generate_id(f"lv3_{lv3_name}"),
|
||||
'concept_name': f"[三级] {lv3_name}",
|
||||
'stocks': list(lv3_stocks),
|
||||
'concept_type': 'lv3'
|
||||
})
|
||||
|
||||
lv2_stocks.update(lv3_stocks)
|
||||
else:
|
||||
for concept_name in child.get('concepts', []):
|
||||
if concept_name in concept_to_stocks:
|
||||
lv2_stocks.update(concept_to_stocks[concept_name])
|
||||
|
||||
if lv2_stocks:
|
||||
parent_concepts.append({
|
||||
'concept_id': generate_id(f"lv2_{lv2_name}"),
|
||||
'concept_name': f"[二级] {lv2_name}",
|
||||
'stocks': list(lv2_stocks),
|
||||
'concept_type': 'lv2'
|
||||
})
|
||||
|
||||
lv1_stocks.update(lv2_stocks)
|
||||
|
||||
if lv1_stocks:
|
||||
parent_concepts.append({
|
||||
'concept_id': generate_id(f"lv1_{lv1_name}"),
|
||||
'concept_name': f"[一级] {lv1_name}",
|
||||
'stocks': list(lv1_stocks),
|
||||
'concept_type': 'lv1'
|
||||
})
|
||||
|
||||
return parent_concepts
|
||||
|
||||
|
||||
# ==================== 基准价格获取 ====================
|
||||
|
||||
def get_base_prices(stock_codes: list, current_date: str) -> dict:
|
||||
"""获取当日的昨收价作为基准(从ea_trade的F002N字段)
|
||||
|
||||
ea_trade表字段说明:
|
||||
- F002N: 昨日收盘价
|
||||
- F007N: 最近成交价(收盘价)
|
||||
- F010N: 涨跌幅
|
||||
"""
|
||||
if not stock_codes:
|
||||
return {}
|
||||
|
||||
# 过滤出有效的6位股票代码
|
||||
valid_codes = [code for code in stock_codes if code and len(code) == 6 and code.isdigit()]
|
||||
if not valid_codes:
|
||||
return {}
|
||||
|
||||
stock_codes_str = "','".join(valid_codes)
|
||||
|
||||
# 获取当日数据中的昨收价(F002N)
|
||||
query = f"""
|
||||
SELECT SECCODE, F002N
|
||||
FROM ea_trade
|
||||
WHERE SECCODE IN ('{stock_codes_str}')
|
||||
AND TRADEDATE = (
|
||||
SELECT MAX(TRADEDATE)
|
||||
FROM ea_trade
|
||||
WHERE TRADEDATE <= '{current_date}'
|
||||
)
|
||||
AND F002N IS NOT NULL AND F002N > 0
|
||||
"""
|
||||
|
||||
try:
|
||||
with MYSQL_ENGINE.connect() as conn:
|
||||
result = conn.execute(text(query))
|
||||
base_prices = {row[0]: float(row[1]) for row in result if row[1] and float(row[1]) > 0}
|
||||
logger.info(f"获取到 {len(base_prices)} 个基准价格")
|
||||
return base_prices
|
||||
except Exception as e:
|
||||
logger.error(f"获取基准价格失败: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
# ==================== 实时价格获取 ====================
|
||||
|
||||
def get_latest_prices(stock_codes: list) -> dict:
|
||||
"""从ClickHouse获取最新分钟数据的收盘价
|
||||
|
||||
Args:
|
||||
stock_codes: 纯6位股票代码列表(如 ['000001', '600000'])
|
||||
|
||||
Returns:
|
||||
dict: {纯6位代码: {'close': 价格, 'timestamp': 时间}}
|
||||
"""
|
||||
if not stock_codes:
|
||||
return {}
|
||||
|
||||
client = get_ch_client()
|
||||
|
||||
# 转换为ClickHouse格式的代码(带后缀)
|
||||
ch_codes = []
|
||||
code_mapping = {} # ch_code -> pure_code
|
||||
for code in stock_codes:
|
||||
ch_code = code_to_ch_format(code)
|
||||
if ch_code:
|
||||
ch_codes.append(ch_code)
|
||||
code_mapping[ch_code] = code
|
||||
|
||||
if not ch_codes:
|
||||
logger.warning("没有有效的股票代码可查询")
|
||||
return {}
|
||||
|
||||
ch_codes_str = "','".join(ch_codes)
|
||||
|
||||
# 获取今日最新的分钟数据
|
||||
query = f"""
|
||||
SELECT code, close, timestamp
|
||||
FROM (
|
||||
SELECT code, close, timestamp,
|
||||
ROW_NUMBER() OVER (PARTITION BY code ORDER BY timestamp DESC) as rn
|
||||
FROM stock_minute
|
||||
WHERE code IN ('{ch_codes_str}')
|
||||
AND toDate(timestamp) = today()
|
||||
)
|
||||
WHERE rn = 1
|
||||
"""
|
||||
|
||||
try:
|
||||
result = client.execute(query)
|
||||
if not result:
|
||||
return {}
|
||||
|
||||
latest_prices = {}
|
||||
for row in result:
|
||||
ch_code, close, ts = row
|
||||
if close and close > 0:
|
||||
# 转回纯6位代码
|
||||
pure_code = code_mapping.get(ch_code)
|
||||
if pure_code:
|
||||
latest_prices[pure_code] = {
|
||||
'close': float(close),
|
||||
'timestamp': ts
|
||||
}
|
||||
|
||||
return latest_prices
|
||||
except Exception as e:
|
||||
logger.error(f"获取最新价格失败: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
# ==================== 涨跌幅计算 ====================
|
||||
|
||||
def calculate_change_pct(base_prices: dict, latest_prices: dict) -> dict:
|
||||
"""计算涨跌幅"""
|
||||
changes = {}
|
||||
for code, latest in latest_prices.items():
|
||||
if code in base_prices and base_prices[code] > 0:
|
||||
base = base_prices[code]
|
||||
close = latest['close']
|
||||
change_pct = (close - base) / base * 100
|
||||
changes[code] = round(change_pct, 4)
|
||||
return changes
|
||||
|
||||
|
||||
def calculate_concept_stats(concepts: list, stock_changes: dict, trade_date: str) -> list:
|
||||
"""计算所有概念的涨跌幅统计"""
|
||||
stats = []
|
||||
|
||||
for concept in concepts:
|
||||
concept_id = concept['concept_id']
|
||||
concept_name = concept['concept_name']
|
||||
stock_codes = concept['stocks']
|
||||
concept_type = concept.get('concept_type', 'leaf')
|
||||
|
||||
# 获取该概念股票的涨跌幅
|
||||
changes = [stock_changes[code] for code in stock_codes if code in stock_changes]
|
||||
|
||||
if not changes:
|
||||
continue
|
||||
|
||||
avg_change_pct = round(np.mean(changes), 4)
|
||||
stock_count = len(changes)
|
||||
|
||||
stats.append({
|
||||
'concept_id': concept_id,
|
||||
'concept_name': concept_name,
|
||||
'trade_date': trade_date,
|
||||
'avg_change_pct': avg_change_pct,
|
||||
'stock_count': stock_count,
|
||||
'concept_type': concept_type
|
||||
})
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
# ==================== MySQL更新 ====================
|
||||
|
||||
def update_mysql_stats(stats: list):
|
||||
"""更新MySQL的concept_daily_stats表"""
|
||||
if not stats:
|
||||
return 0
|
||||
|
||||
with MYSQL_ENGINE.begin() as conn:
|
||||
updated = 0
|
||||
for item in stats:
|
||||
upsert_sql = text("""
|
||||
REPLACE INTO concept_daily_stats
|
||||
(concept_id, concept_name, trade_date, avg_change_pct, stock_count, concept_type)
|
||||
VALUES (:concept_id, :concept_name, :trade_date, :avg_change_pct, :stock_count, :concept_type)
|
||||
""")
|
||||
conn.execute(upsert_sql, item)
|
||||
updated += 1
|
||||
|
||||
return updated
|
||||
|
||||
|
||||
# ==================== 交易时间判断 ====================
|
||||
|
||||
def is_trading_time() -> bool:
|
||||
"""判断当前是否为交易时间"""
|
||||
now = datetime.now()
|
||||
weekday = now.weekday()
|
||||
|
||||
# 周末不交易
|
||||
if weekday >= 5:
|
||||
return False
|
||||
|
||||
hour, minute = now.hour, now.minute
|
||||
current_time = hour * 60 + minute
|
||||
|
||||
# 上午 9:30 - 11:30
|
||||
morning_start = 9 * 60 + 30
|
||||
morning_end = 11 * 60 + 30
|
||||
|
||||
# 下午 13:00 - 15:00
|
||||
afternoon_start = 13 * 60
|
||||
afternoon_end = 15 * 60
|
||||
|
||||
return (morning_start <= current_time <= morning_end) or \
|
||||
(afternoon_start <= current_time <= afternoon_end)
|
||||
|
||||
|
||||
def get_next_update_time() -> int:
|
||||
"""获取距离下次更新的秒数"""
|
||||
now = datetime.now()
|
||||
|
||||
if is_trading_time():
|
||||
# 交易时间内,等到下一分钟
|
||||
return 60 - now.second
|
||||
else:
|
||||
# 非交易时间
|
||||
hour, minute = now.hour, now.minute
|
||||
|
||||
# 计算距离下次交易开始的时间
|
||||
if hour < 9 or (hour == 9 and minute < 30):
|
||||
# 等到9:30
|
||||
target = now.replace(hour=9, minute=30, second=0, microsecond=0)
|
||||
elif (hour == 11 and minute >= 30) or hour == 12:
|
||||
# 等到13:00
|
||||
target = now.replace(hour=13, minute=0, second=0, microsecond=0)
|
||||
elif hour >= 15:
|
||||
# 等到明天9:30
|
||||
target = (now + timedelta(days=1)).replace(hour=9, minute=30, second=0, microsecond=0)
|
||||
else:
|
||||
target = now + timedelta(minutes=1)
|
||||
|
||||
wait_seconds = (target - now).total_seconds()
|
||||
return max(60, int(wait_seconds))
|
||||
|
||||
|
||||
# ==================== 主运行逻辑 ====================
|
||||
|
||||
def run_once(concepts: list, all_stocks: list) -> int:
|
||||
"""执行一次更新"""
|
||||
now = datetime.now()
|
||||
trade_date = now.strftime('%Y-%m-%d')
|
||||
|
||||
# 获取基准价格(昨日收盘价)
|
||||
base_prices = get_base_prices(all_stocks, trade_date)
|
||||
if not base_prices:
|
||||
logger.warning("无法获取基准价格")
|
||||
return 0
|
||||
|
||||
# 获取最新价格
|
||||
latest_prices = get_latest_prices(all_stocks)
|
||||
if not latest_prices:
|
||||
logger.warning("无法获取最新价格")
|
||||
return 0
|
||||
|
||||
# 计算涨跌幅
|
||||
stock_changes = calculate_change_pct(base_prices, latest_prices)
|
||||
if not stock_changes:
|
||||
logger.warning("无涨跌幅数据")
|
||||
return 0
|
||||
|
||||
logger.info(f"获取到 {len(stock_changes)} 只股票的涨跌幅")
|
||||
|
||||
# 计算概念统计
|
||||
stats = calculate_concept_stats(concepts, stock_changes, trade_date)
|
||||
logger.info(f"计算了 {len(stats)} 个概念的涨跌幅")
|
||||
|
||||
# 更新MySQL
|
||||
updated = update_mysql_stats(stats)
|
||||
logger.info(f"更新了 {updated} 条记录到MySQL")
|
||||
|
||||
return updated
|
||||
|
||||
|
||||
def run_realtime():
|
||||
"""实时更新主循环"""
|
||||
logger.info("=" * 60)
|
||||
logger.info("启动概念涨跌幅实时更新服务")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# 加载概念数据
|
||||
logger.info("加载概念数据...")
|
||||
leaf_concepts = get_all_concepts()
|
||||
logger.info(f"获取到 {len(leaf_concepts)} 个叶子概念")
|
||||
|
||||
parent_concepts = load_hierarchy_concepts(leaf_concepts)
|
||||
logger.info(f"生成了 {len(parent_concepts)} 个母概念")
|
||||
|
||||
all_concepts = leaf_concepts + parent_concepts
|
||||
logger.info(f"总计 {len(all_concepts)} 个概念")
|
||||
|
||||
# 收集所有股票代码
|
||||
all_stocks = set()
|
||||
for c in all_concepts:
|
||||
all_stocks.update(c['stocks'])
|
||||
all_stocks = list(all_stocks)
|
||||
logger.info(f"监控 {len(all_stocks)} 只股票")
|
||||
|
||||
last_concept_update = datetime.now()
|
||||
|
||||
while True:
|
||||
try:
|
||||
now = datetime.now()
|
||||
|
||||
# 每小时重新加载概念数据
|
||||
if (now - last_concept_update).total_seconds() > 3600:
|
||||
logger.info("重新加载概念数据...")
|
||||
leaf_concepts = get_all_concepts()
|
||||
parent_concepts = load_hierarchy_concepts(leaf_concepts)
|
||||
all_concepts = leaf_concepts + parent_concepts
|
||||
all_stocks = set()
|
||||
for c in all_concepts:
|
||||
all_stocks.update(c['stocks'])
|
||||
all_stocks = list(all_stocks)
|
||||
last_concept_update = now
|
||||
logger.info(f"更新完成: {len(all_concepts)} 个概念, {len(all_stocks)} 只股票")
|
||||
|
||||
# 检查是否交易时间
|
||||
if not is_trading_time():
|
||||
wait_sec = get_next_update_time()
|
||||
wait_min = wait_sec // 60
|
||||
logger.info(f"非交易时间,等待 {wait_min} 分钟后重试...")
|
||||
time.sleep(min(wait_sec, 300)) # 最多等5分钟再检查
|
||||
continue
|
||||
|
||||
# 执行更新
|
||||
logger.info(f"\n{'=' * 40}")
|
||||
logger.info(f"更新时间: {now.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
updated = run_once(all_concepts, all_stocks)
|
||||
|
||||
# 等待下一分钟
|
||||
sleep_sec = 60 - datetime.now().second
|
||||
logger.info(f"完成,等待 {sleep_sec} 秒后继续...")
|
||||
time.sleep(sleep_sec)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\n收到退出信号,停止服务...")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"发生错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
time.sleep(60)
|
||||
|
||||
|
||||
def run_single():
|
||||
"""单次运行(不循环)"""
|
||||
logger.info("单次更新模式")
|
||||
|
||||
leaf_concepts = get_all_concepts()
|
||||
parent_concepts = load_hierarchy_concepts(leaf_concepts)
|
||||
all_concepts = leaf_concepts + parent_concepts
|
||||
|
||||
all_stocks = set()
|
||||
for c in all_concepts:
|
||||
all_stocks.update(c['stocks'])
|
||||
all_stocks = list(all_stocks)
|
||||
|
||||
logger.info(f"概念数: {len(all_concepts)}, 股票数: {len(all_stocks)}")
|
||||
|
||||
updated = run_once(all_concepts, all_stocks)
|
||||
logger.info(f"更新完成: {updated} 条记录")
|
||||
|
||||
|
||||
def show_status():
|
||||
"""显示当前状态"""
|
||||
print("\n" + "=" * 60)
|
||||
print("概念涨跌幅实时更新服务 - 状态")
|
||||
print("=" * 60)
|
||||
|
||||
# 当前时间
|
||||
now = datetime.now()
|
||||
print(f"\n当前时间: {now.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
print(f"是否交易时间: {'是' if is_trading_time() else '否'}")
|
||||
|
||||
# MySQL数据状态
|
||||
print("\nMySQL数据状态:")
|
||||
try:
|
||||
with MYSQL_ENGINE.connect() as conn:
|
||||
# 今日数据量
|
||||
result = conn.execute(text("""
|
||||
SELECT concept_type, COUNT(*) as cnt
|
||||
FROM concept_daily_stats
|
||||
WHERE trade_date = CURDATE()
|
||||
GROUP BY concept_type
|
||||
"""))
|
||||
rows = list(result)
|
||||
if rows:
|
||||
print(" 今日数据:")
|
||||
for row in rows:
|
||||
print(f" {row[0]}: {row[1]} 条")
|
||||
else:
|
||||
print(" 今日暂无数据")
|
||||
|
||||
# 最新更新时间
|
||||
result = conn.execute(text("""
|
||||
SELECT MAX(updated_at) FROM concept_daily_stats WHERE trade_date = CURDATE()
|
||||
"""))
|
||||
row = result.fetchone()
|
||||
if row and row[0]:
|
||||
print(f" 最后更新: {row[0]}")
|
||||
except Exception as e:
|
||||
print(f" 查询失败: {e}")
|
||||
|
||||
# ClickHouse数据状态
|
||||
print("\nClickHouse数据状态:")
|
||||
try:
|
||||
client = get_ch_client()
|
||||
result = client.execute("""
|
||||
SELECT COUNT(*), MAX(timestamp)
|
||||
FROM stock_minute
|
||||
WHERE toDate(timestamp) = today()
|
||||
""")
|
||||
if result:
|
||||
count, max_ts = result[0]
|
||||
print(f" 今日分钟数据: {count:,} 条")
|
||||
print(f" 最新时间戳: {max_ts}")
|
||||
except Exception as e:
|
||||
print(f" 查询失败: {e}")
|
||||
|
||||
# 今日涨跌幅TOP10
|
||||
print("\n今日涨跌幅 TOP10:")
|
||||
try:
|
||||
with MYSQL_ENGINE.connect() as conn:
|
||||
result = conn.execute(text("""
|
||||
SELECT concept_name, avg_change_pct, stock_count, concept_type
|
||||
FROM concept_daily_stats
|
||||
WHERE trade_date = CURDATE() AND concept_type = 'leaf'
|
||||
ORDER BY avg_change_pct DESC
|
||||
LIMIT 10
|
||||
"""))
|
||||
rows = list(result)
|
||||
if rows:
|
||||
print(f" {'概念':<25} | {'涨跌幅':>8} | {'股票数':>6}")
|
||||
print(" " + "-" * 50)
|
||||
for row in rows:
|
||||
name = row[0][:25] if len(row[0]) > 25 else row[0]
|
||||
print(f" {name:<25} | {row[1]:>7.2f}% | {row[2]:>6}")
|
||||
else:
|
||||
print(" 暂无数据")
|
||||
except Exception as e:
|
||||
print(f" 查询失败: {e}")
|
||||
|
||||
|
||||
# ==================== 主函数 ====================
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='概念涨跌幅实时更新服务')
|
||||
parser.add_argument('command', nargs='?', default='realtime',
|
||||
choices=['realtime', 'once', 'status'],
|
||||
help='命令: realtime(实时运行), once(单次运行), status(状态查看)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == 'realtime':
|
||||
run_realtime()
|
||||
elif args.command == 'once':
|
||||
run_single()
|
||||
elif args.command == 'status':
|
||||
show_status()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,89 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""创建异动检测所需的数据库表"""
|
||||
|
||||
import sys
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
engine = create_engine('mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock', echo=False)
|
||||
|
||||
# 删除旧表
|
||||
drop_sql1 = 'DROP TABLE IF EXISTS concept_minute_alert'
|
||||
drop_sql2 = 'DROP TABLE IF EXISTS index_minute_snapshot'
|
||||
|
||||
# 创建 concept_minute_alert 表
|
||||
# 支持 Z-Score + SVM 智能检测
|
||||
sql1 = '''
|
||||
CREATE TABLE concept_minute_alert (
|
||||
id BIGINT AUTO_INCREMENT PRIMARY KEY,
|
||||
concept_id VARCHAR(32) NOT NULL,
|
||||
concept_name VARCHAR(100) NOT NULL,
|
||||
alert_time DATETIME NOT NULL,
|
||||
alert_type VARCHAR(20) NOT NULL COMMENT 'surge_up=暴涨, surge_down=暴跌, limit_up=涨停增加, rank_jump=排名跃升',
|
||||
trade_date DATE NOT NULL,
|
||||
change_pct DECIMAL(10,4) COMMENT '当前涨跌幅',
|
||||
prev_change_pct DECIMAL(10,4) COMMENT '之前涨跌幅',
|
||||
change_delta DECIMAL(10,4) COMMENT '涨跌幅变化',
|
||||
limit_up_count INT DEFAULT 0 COMMENT '涨停数',
|
||||
prev_limit_up_count INT DEFAULT 0,
|
||||
limit_up_delta INT DEFAULT 0,
|
||||
limit_down_count INT DEFAULT 0 COMMENT '跌停数',
|
||||
rank_position INT COMMENT '当前排名',
|
||||
prev_rank_position INT COMMENT '之前排名',
|
||||
rank_delta INT COMMENT '排名变化(负数表示上升)',
|
||||
index_code VARCHAR(20) DEFAULT '000001.SH',
|
||||
index_price DECIMAL(12,4),
|
||||
index_change_pct DECIMAL(10,4),
|
||||
stock_count INT,
|
||||
concept_type VARCHAR(20) DEFAULT 'leaf',
|
||||
zscore DECIMAL(8,4) COMMENT 'Z-Score值',
|
||||
importance_score DECIMAL(6,4) COMMENT '重要性评分(0-1)',
|
||||
extra_info JSON COMMENT '扩展信息(包含zscore,svm_score等)',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
INDEX idx_trade_date (trade_date),
|
||||
INDEX idx_alert_time (alert_time),
|
||||
INDEX idx_concept_id (concept_id),
|
||||
INDEX idx_alert_type (alert_type),
|
||||
INDEX idx_trade_date_time (trade_date, alert_time),
|
||||
INDEX idx_importance (importance_score)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念异动记录表(智能版)'
|
||||
'''
|
||||
|
||||
# 创建 index_minute_snapshot 表
|
||||
sql2 = '''
|
||||
CREATE TABLE index_minute_snapshot (
|
||||
id BIGINT AUTO_INCREMENT PRIMARY KEY,
|
||||
index_code VARCHAR(20) NOT NULL,
|
||||
trade_date DATE NOT NULL,
|
||||
snapshot_time DATETIME NOT NULL,
|
||||
price DECIMAL(12,4),
|
||||
open_price DECIMAL(12,4),
|
||||
high_price DECIMAL(12,4),
|
||||
low_price DECIMAL(12,4),
|
||||
prev_close DECIMAL(12,4),
|
||||
change_pct DECIMAL(10,4),
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE KEY uk_index_time (index_code, snapshot_time),
|
||||
INDEX idx_trade_date (trade_date)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
'''
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('正在重建数据库表...\n')
|
||||
|
||||
with engine.begin() as conn:
|
||||
# 先删除旧表
|
||||
print('删除旧表...')
|
||||
conn.execute(text(drop_sql1))
|
||||
print(' - concept_minute_alert 已删除')
|
||||
conn.execute(text(drop_sql2))
|
||||
print(' - index_minute_snapshot 已删除')
|
||||
|
||||
# 创建新表
|
||||
print('\n创建新表...')
|
||||
conn.execute(text(sql1))
|
||||
print(' ✅ concept_minute_alert 表创建成功')
|
||||
conn.execute(text(sql2))
|
||||
print(' ✅ index_minute_snapshot 表创建成功')
|
||||
|
||||
print('\n✅ 所有表创建完成!')
|
||||
BIN
ml/__pycache__/realtime_detector.cpython-310.pyc
Normal file
BIN
ml/__pycache__/realtime_detector.cpython-310.pyc
Normal file
Binary file not shown.
859
ml/backtest_fast.py
Normal file
859
ml/backtest_fast.py
Normal file
@@ -0,0 +1,859 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
快速融合异动回测脚本
|
||||
|
||||
优化策略:
|
||||
1. 预先构建所有序列(向量化),避免循环内重复切片
|
||||
2. 批量 ML 推理(一次推理所有候选)
|
||||
3. 使用 NumPy 向量化操作替代 Python 循环
|
||||
|
||||
性能对比:
|
||||
- 原版:5分钟/天
|
||||
- 优化版:预计 10-30秒/天
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
||||
# ==================== 配置 ====================
|
||||
|
||||
MYSQL_ENGINE = create_engine(
|
||||
"mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
|
||||
echo=False
|
||||
)
|
||||
|
||||
FEATURES = ['alpha', 'alpha_delta', 'amt_ratio', 'amt_delta', 'rank_pct', 'limit_up_ratio']
|
||||
|
||||
CONFIG = {
|
||||
'seq_len': 15, # 序列长度(支持跨日后可从 9:30 检测)
|
||||
'min_alpha_abs': 0.3, # 最小 alpha 过滤
|
||||
'cooldown_minutes': 8,
|
||||
'max_alerts_per_minute': 20,
|
||||
'clip_value': 10.0,
|
||||
# === 融合权重:均衡 ===
|
||||
'rule_weight': 0.5,
|
||||
'ml_weight': 0.5,
|
||||
# === 触发阈值 ===
|
||||
'rule_trigger': 65, # 60 -> 65,略提高规则门槛
|
||||
'ml_trigger': 70, # 75 -> 70,略降低 ML 门槛
|
||||
'fusion_trigger': 45,
|
||||
}
|
||||
|
||||
|
||||
# ==================== 规则评分(向量化版)====================
|
||||
|
||||
def get_size_adjusted_thresholds(stock_count: np.ndarray) -> dict:
|
||||
"""
|
||||
根据概念股票数量计算动态阈值
|
||||
|
||||
设计思路:
|
||||
- 小概念(<10 只):波动大是正常的,需要更高阈值
|
||||
- 中概念(10-50 只):标准阈值
|
||||
- 大概念(>50 只):能有明显波动说明是真异动,降低阈值
|
||||
|
||||
返回各指标的调整系数(乘以基准阈值)
|
||||
"""
|
||||
n = len(stock_count)
|
||||
|
||||
# 基于股票数量的调整系数
|
||||
# 小概念:系数 > 1(提高阈值,更难触发)
|
||||
# 大概念:系数 < 1(降低阈值,更容易触发)
|
||||
size_factor = np.ones(n)
|
||||
|
||||
# 微型概念(<5 只):阈值 × 1.8
|
||||
tiny = stock_count < 5
|
||||
size_factor[tiny] = 1.8
|
||||
|
||||
# 小概念(5-10 只):阈值 × 1.4
|
||||
small = (stock_count >= 5) & (stock_count < 10)
|
||||
size_factor[small] = 1.4
|
||||
|
||||
# 中小概念(10-20 只):阈值 × 1.2
|
||||
medium_small = (stock_count >= 10) & (stock_count < 20)
|
||||
size_factor[medium_small] = 1.2
|
||||
|
||||
# 中概念(20-50 只):标准阈值 × 1.0
|
||||
medium = (stock_count >= 20) & (stock_count < 50)
|
||||
size_factor[medium] = 1.0
|
||||
|
||||
# 大概念(50-100 只):阈值 × 0.85
|
||||
large = (stock_count >= 50) & (stock_count < 100)
|
||||
size_factor[large] = 0.85
|
||||
|
||||
# 超大概念(>100 只):阈值 × 0.7
|
||||
xlarge = stock_count >= 100
|
||||
size_factor[xlarge] = 0.7
|
||||
|
||||
return size_factor
|
||||
|
||||
|
||||
def score_rules_batch(df: pd.DataFrame) -> Tuple[np.ndarray, List[List[str]]]:
|
||||
"""
|
||||
批量计算规则得分(向量化)- 考虑概念规模版
|
||||
|
||||
设计原则:
|
||||
- 规则作为辅助信号,不应单独主导决策
|
||||
- 根据概念股票数量动态调整阈值
|
||||
- 大概念异动更有价值,小概念需要更大波动才算异动
|
||||
|
||||
Args:
|
||||
df: DataFrame,包含所有特征列(必须包含 stock_count)
|
||||
Returns:
|
||||
scores: (n,) 规则得分数组
|
||||
triggered_rules: 每行触发的规则列表
|
||||
"""
|
||||
n = len(df)
|
||||
scores = np.zeros(n)
|
||||
triggered = [[] for _ in range(n)]
|
||||
|
||||
alpha = df['alpha'].values
|
||||
alpha_delta = df['alpha_delta'].values
|
||||
amt_ratio = df['amt_ratio'].values
|
||||
amt_delta = df['amt_delta'].values
|
||||
rank_pct = df['rank_pct'].values
|
||||
limit_up_ratio = df['limit_up_ratio'].values
|
||||
stock_count = df['stock_count'].values if 'stock_count' in df.columns else np.full(n, 20)
|
||||
|
||||
alpha_abs = np.abs(alpha)
|
||||
alpha_delta_abs = np.abs(alpha_delta)
|
||||
|
||||
# 获取基于规模的调整系数
|
||||
size_factor = get_size_adjusted_thresholds(stock_count)
|
||||
|
||||
# ========== Alpha 规则(动态阈值)==========
|
||||
# 基准阈值:极强 5%,强 4%,中等 3%
|
||||
# 实际阈值 = 基准 × size_factor
|
||||
|
||||
# 极强信号
|
||||
alpha_extreme_thresh = 5.0 * size_factor
|
||||
mask = alpha_abs >= alpha_extreme_thresh
|
||||
scores[mask] += 20
|
||||
for i in np.where(mask)[0]: triggered[i].append('alpha_extreme')
|
||||
|
||||
# 强信号
|
||||
alpha_strong_thresh = 4.0 * size_factor
|
||||
mask = (alpha_abs >= alpha_strong_thresh) & (alpha_abs < alpha_extreme_thresh)
|
||||
scores[mask] += 15
|
||||
for i in np.where(mask)[0]: triggered[i].append('alpha_strong')
|
||||
|
||||
# 中等信号
|
||||
alpha_medium_thresh = 3.0 * size_factor
|
||||
mask = (alpha_abs >= alpha_medium_thresh) & (alpha_abs < alpha_strong_thresh)
|
||||
scores[mask] += 10
|
||||
for i in np.where(mask)[0]: triggered[i].append('alpha_medium')
|
||||
|
||||
# ========== Alpha 加速度规则(动态阈值)==========
|
||||
delta_strong_thresh = 2.0 * size_factor
|
||||
mask = alpha_delta_abs >= delta_strong_thresh
|
||||
scores[mask] += 15
|
||||
for i in np.where(mask)[0]: triggered[i].append('alpha_delta_strong')
|
||||
|
||||
delta_medium_thresh = 1.5 * size_factor
|
||||
mask = (alpha_delta_abs >= delta_medium_thresh) & (alpha_delta_abs < delta_strong_thresh)
|
||||
scores[mask] += 10
|
||||
for i in np.where(mask)[0]: triggered[i].append('alpha_delta_medium')
|
||||
|
||||
# ========== 成交额规则(不受规模影响,放量就是放量)==========
|
||||
mask = amt_ratio >= 10.0
|
||||
scores[mask] += 20
|
||||
for i in np.where(mask)[0]: triggered[i].append('volume_extreme')
|
||||
|
||||
mask = (amt_ratio >= 6.0) & (amt_ratio < 10.0)
|
||||
scores[mask] += 12
|
||||
for i in np.where(mask)[0]: triggered[i].append('volume_strong')
|
||||
|
||||
# ========== 排名规则 ==========
|
||||
mask = rank_pct >= 0.98
|
||||
scores[mask] += 15
|
||||
for i in np.where(mask)[0]: triggered[i].append('rank_top')
|
||||
|
||||
mask = rank_pct <= 0.02
|
||||
scores[mask] += 15
|
||||
for i in np.where(mask)[0]: triggered[i].append('rank_bottom')
|
||||
|
||||
# ========== 涨停规则(动态阈值)==========
|
||||
# 大概念有涨停更有意义
|
||||
limit_high_thresh = 0.30 * size_factor
|
||||
mask = limit_up_ratio >= limit_high_thresh
|
||||
scores[mask] += 20
|
||||
for i in np.where(mask)[0]: triggered[i].append('limit_up_high')
|
||||
|
||||
limit_medium_thresh = 0.20 * size_factor
|
||||
mask = (limit_up_ratio >= limit_medium_thresh) & (limit_up_ratio < limit_high_thresh)
|
||||
scores[mask] += 12
|
||||
for i in np.where(mask)[0]: triggered[i].append('limit_up_medium')
|
||||
|
||||
# ========== 概念规模加分(大概念异动更有价值)==========
|
||||
# 大概念(50+)额外加分
|
||||
large_concept = stock_count >= 50
|
||||
has_signal = scores > 0 # 至少触发了某个规则
|
||||
mask = large_concept & has_signal
|
||||
scores[mask] += 10
|
||||
for i in np.where(mask)[0]: triggered[i].append('large_concept_bonus')
|
||||
|
||||
# 超大概念(100+)再加分
|
||||
xlarge_concept = stock_count >= 100
|
||||
mask = xlarge_concept & has_signal
|
||||
scores[mask] += 10
|
||||
for i in np.where(mask)[0]: triggered[i].append('xlarge_concept_bonus')
|
||||
|
||||
# ========== 组合规则(动态阈值)==========
|
||||
combo_alpha_thresh = 3.0 * size_factor
|
||||
|
||||
# Alpha + 放量 + 排名(三重验证)
|
||||
mask = (alpha_abs >= combo_alpha_thresh) & (amt_ratio >= 5.0) & ((rank_pct >= 0.95) | (rank_pct <= 0.05))
|
||||
scores[mask] += 20
|
||||
for i in np.where(mask)[0]: triggered[i].append('triple_signal')
|
||||
|
||||
# Alpha + 涨停(强组合)
|
||||
mask = (alpha_abs >= combo_alpha_thresh) & (limit_up_ratio >= 0.15 * size_factor)
|
||||
scores[mask] += 15
|
||||
for i in np.where(mask)[0]: triggered[i].append('alpha_with_limit')
|
||||
|
||||
# ========== 小概念惩罚(过滤噪音)==========
|
||||
# 微型概念(<5 只)如果只有单一信号,减分
|
||||
tiny_concept = stock_count < 5
|
||||
single_rule = np.array([len(t) <= 1 for t in triggered])
|
||||
mask = tiny_concept & single_rule & (scores > 0)
|
||||
scores[mask] *= 0.5 # 减半
|
||||
for i in np.where(mask)[0]: triggered[i].append('tiny_concept_penalty')
|
||||
|
||||
scores = np.clip(scores, 0, 100)
|
||||
return scores, triggered
|
||||
|
||||
|
||||
# ==================== ML 评分器 ====================
|
||||
|
||||
class FastMLScorer:
|
||||
"""快速 ML 评分器"""
|
||||
|
||||
def __init__(self, checkpoint_dir: str = 'ml/checkpoints', device: str = 'auto'):
|
||||
self.checkpoint_dir = Path(checkpoint_dir)
|
||||
|
||||
if device == 'auto':
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
elif device == 'cuda' and not torch.cuda.is_available():
|
||||
print("警告: CUDA 不可用,使用 CPU")
|
||||
self.device = torch.device('cpu')
|
||||
else:
|
||||
self.device = torch.device(device)
|
||||
|
||||
self.model = None
|
||||
self.thresholds = None
|
||||
self._load_model()
|
||||
|
||||
def _load_model(self):
|
||||
model_path = self.checkpoint_dir / 'best_model.pt'
|
||||
thresholds_path = self.checkpoint_dir / 'thresholds.json'
|
||||
config_path = self.checkpoint_dir / 'config.json'
|
||||
|
||||
if not model_path.exists():
|
||||
print(f"警告: 模型不存在 {model_path}")
|
||||
return
|
||||
|
||||
try:
|
||||
from model import LSTMAutoencoder
|
||||
|
||||
config = {}
|
||||
if config_path.exists():
|
||||
with open(config_path) as f:
|
||||
config = json.load(f).get('model', {})
|
||||
|
||||
# 处理旧配置键名
|
||||
if 'd_model' in config:
|
||||
config['hidden_dim'] = config.pop('d_model') // 2
|
||||
for key in ['num_encoder_layers', 'num_decoder_layers', 'nhead', 'dim_feedforward', 'max_seq_len', 'use_instance_norm']:
|
||||
config.pop(key, None)
|
||||
if 'num_layers' not in config:
|
||||
config['num_layers'] = 1
|
||||
|
||||
checkpoint = torch.load(model_path, map_location='cpu')
|
||||
self.model = LSTMAutoencoder(**config)
|
||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
if thresholds_path.exists():
|
||||
with open(thresholds_path) as f:
|
||||
self.thresholds = json.load(f)
|
||||
|
||||
print(f"ML模型加载成功 (设备: {self.device})")
|
||||
except Exception as e:
|
||||
print(f"ML模型加载失败: {e}")
|
||||
self.model = None
|
||||
|
||||
def is_ready(self):
|
||||
return self.model is not None
|
||||
|
||||
@torch.no_grad()
|
||||
def score_batch(self, sequences: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
批量计算 ML 得分
|
||||
|
||||
Args:
|
||||
sequences: (batch, seq_len, n_features)
|
||||
Returns:
|
||||
scores: (batch,) 0-100 分数
|
||||
"""
|
||||
if not self.is_ready() or len(sequences) == 0:
|
||||
return np.zeros(len(sequences))
|
||||
|
||||
x = torch.FloatTensor(sequences).to(self.device)
|
||||
output, _ = self.model(x)
|
||||
mse = ((output - x) ** 2).mean(dim=-1)
|
||||
errors = mse[:, -1].cpu().numpy()
|
||||
|
||||
p95 = self.thresholds.get('p95', 0.1) if self.thresholds else 0.1
|
||||
scores = np.clip(errors / p95 * 50, 0, 100)
|
||||
return scores
|
||||
|
||||
|
||||
# ==================== 快速回测 ====================
|
||||
|
||||
def build_sequences_fast(
|
||||
df: pd.DataFrame,
|
||||
seq_len: int = 30,
|
||||
prev_df: pd.DataFrame = None
|
||||
) -> Tuple[np.ndarray, pd.DataFrame]:
|
||||
"""
|
||||
快速构建所有有效序列
|
||||
|
||||
支持跨日序列:用前一天收盘数据 + 当天开盘数据拼接,实现 9:30 就能检测
|
||||
|
||||
Args:
|
||||
df: 当天数据
|
||||
seq_len: 序列长度
|
||||
prev_df: 前一天数据(可选,用于构建开盘时的序列)
|
||||
|
||||
返回:
|
||||
sequences: (n_valid, seq_len, n_features) 所有有效序列
|
||||
info_df: 对应的元信息 DataFrame
|
||||
"""
|
||||
# 确保按概念和时间排序
|
||||
df = df.sort_values(['concept_id', 'timestamp']).reset_index(drop=True)
|
||||
|
||||
# 如果有前一天数据,按概念构建尾部缓存(取每个概念最后 seq_len-1 条)
|
||||
prev_cache = {}
|
||||
if prev_df is not None and len(prev_df) > 0:
|
||||
prev_df = prev_df.sort_values(['concept_id', 'timestamp'])
|
||||
for concept_id, gdf in prev_df.groupby('concept_id'):
|
||||
tail_data = gdf.tail(seq_len - 1)
|
||||
if len(tail_data) > 0:
|
||||
feat_matrix = tail_data[FEATURES].values
|
||||
feat_matrix = np.nan_to_num(feat_matrix, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
feat_matrix = np.clip(feat_matrix, -CONFIG['clip_value'], CONFIG['clip_value'])
|
||||
prev_cache[concept_id] = feat_matrix
|
||||
|
||||
# 按概念分组
|
||||
groups = df.groupby('concept_id')
|
||||
|
||||
sequences = []
|
||||
infos = []
|
||||
|
||||
for concept_id, gdf in groups:
|
||||
gdf = gdf.reset_index(drop=True)
|
||||
|
||||
# 获取特征矩阵
|
||||
feat_matrix = gdf[FEATURES].values
|
||||
feat_matrix = np.nan_to_num(feat_matrix, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
feat_matrix = np.clip(feat_matrix, -CONFIG['clip_value'], CONFIG['clip_value'])
|
||||
|
||||
# 如果有前一天缓存,拼接到当天数据前面
|
||||
if concept_id in prev_cache:
|
||||
prev_data = prev_cache[concept_id]
|
||||
combined_matrix = np.vstack([prev_data, feat_matrix])
|
||||
# 计算偏移量:前一天数据的长度
|
||||
offset = len(prev_data)
|
||||
else:
|
||||
combined_matrix = feat_matrix
|
||||
offset = 0
|
||||
|
||||
# 滑动窗口构建序列
|
||||
n_total = len(combined_matrix)
|
||||
if n_total < seq_len:
|
||||
continue
|
||||
|
||||
for i in range(n_total - seq_len + 1):
|
||||
seq = combined_matrix[i:i + seq_len]
|
||||
|
||||
# 计算对应当天数据的索引
|
||||
# 序列最后一个点的位置 = i + seq_len - 1
|
||||
# 对应当天数据的索引 = (i + seq_len - 1) - offset
|
||||
today_idx = i + seq_len - 1 - offset
|
||||
|
||||
# 只要序列的最后一个点是当天的数据,就记录
|
||||
if today_idx < 0 or today_idx >= len(gdf):
|
||||
continue
|
||||
|
||||
sequences.append(seq)
|
||||
|
||||
# 记录最后一个时间步的信息(当天的)
|
||||
row = gdf.iloc[today_idx]
|
||||
infos.append({
|
||||
'concept_id': concept_id,
|
||||
'timestamp': row['timestamp'],
|
||||
'alpha': row['alpha'],
|
||||
'alpha_delta': row.get('alpha_delta', 0),
|
||||
'amt_ratio': row.get('amt_ratio', 1),
|
||||
'amt_delta': row.get('amt_delta', 0),
|
||||
'rank_pct': row.get('rank_pct', 0.5),
|
||||
'limit_up_ratio': row.get('limit_up_ratio', 0),
|
||||
'stock_count': row.get('stock_count', 0),
|
||||
'total_amt': row.get('total_amt', 0),
|
||||
})
|
||||
|
||||
if not sequences:
|
||||
return np.array([]), pd.DataFrame()
|
||||
|
||||
return np.array(sequences), pd.DataFrame(infos)
|
||||
|
||||
|
||||
def backtest_single_day_fast(
|
||||
ml_scorer: FastMLScorer,
|
||||
df: pd.DataFrame,
|
||||
date: str,
|
||||
config: Dict,
|
||||
prev_df: pd.DataFrame = None
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
快速回测单天(向量化版本)
|
||||
|
||||
Args:
|
||||
ml_scorer: ML 评分器
|
||||
df: 当天数据
|
||||
date: 日期
|
||||
config: 配置
|
||||
prev_df: 前一天数据(用于 9:30 开始检测)
|
||||
"""
|
||||
seq_len = config.get('seq_len', 30)
|
||||
|
||||
# 1. 构建所有序列(支持跨日)
|
||||
sequences, info_df = build_sequences_fast(df, seq_len, prev_df)
|
||||
|
||||
if len(sequences) == 0:
|
||||
return []
|
||||
|
||||
# 2. 过滤小波动
|
||||
alpha_abs = np.abs(info_df['alpha'].values)
|
||||
valid_mask = alpha_abs >= config['min_alpha_abs']
|
||||
|
||||
sequences = sequences[valid_mask]
|
||||
info_df = info_df[valid_mask].reset_index(drop=True)
|
||||
|
||||
if len(sequences) == 0:
|
||||
return []
|
||||
|
||||
# 3. 批量规则评分
|
||||
rule_scores, triggered_rules = score_rules_batch(info_df)
|
||||
|
||||
# 4. 批量 ML 评分(分批处理避免显存溢出)
|
||||
batch_size = 2048
|
||||
ml_scores = []
|
||||
for i in range(0, len(sequences), batch_size):
|
||||
batch_seq = sequences[i:i+batch_size]
|
||||
batch_scores = ml_scorer.score_batch(batch_seq)
|
||||
ml_scores.append(batch_scores)
|
||||
ml_scores = np.concatenate(ml_scores) if ml_scores else np.zeros(len(sequences))
|
||||
|
||||
# 5. 融合得分
|
||||
w1, w2 = config['rule_weight'], config['ml_weight']
|
||||
final_scores = w1 * rule_scores + w2 * ml_scores
|
||||
|
||||
# 6. 判断异动
|
||||
is_anomaly = (
|
||||
(rule_scores >= config['rule_trigger']) |
|
||||
(ml_scores >= config['ml_trigger']) |
|
||||
(final_scores >= config['fusion_trigger'])
|
||||
)
|
||||
|
||||
# 7. 应用冷却期(按概念+时间排序后处理)
|
||||
info_df['rule_score'] = rule_scores
|
||||
info_df['ml_score'] = ml_scores
|
||||
info_df['final_score'] = final_scores
|
||||
info_df['is_anomaly'] = is_anomaly
|
||||
info_df['triggered_rules'] = triggered_rules
|
||||
|
||||
# 只保留异动
|
||||
anomaly_df = info_df[info_df['is_anomaly']].copy()
|
||||
|
||||
if len(anomaly_df) == 0:
|
||||
return []
|
||||
|
||||
# 应用冷却期
|
||||
anomaly_df = anomaly_df.sort_values(['concept_id', 'timestamp'])
|
||||
cooldown = {}
|
||||
keep_mask = []
|
||||
|
||||
for _, row in anomaly_df.iterrows():
|
||||
cid = row['concept_id']
|
||||
ts = row['timestamp']
|
||||
|
||||
if cid in cooldown:
|
||||
try:
|
||||
diff = (ts - cooldown[cid]).total_seconds() / 60
|
||||
except:
|
||||
diff = config['cooldown_minutes'] + 1
|
||||
|
||||
if diff < config['cooldown_minutes']:
|
||||
keep_mask.append(False)
|
||||
continue
|
||||
|
||||
cooldown[cid] = ts
|
||||
keep_mask.append(True)
|
||||
|
||||
anomaly_df = anomaly_df[keep_mask]
|
||||
|
||||
# 8. 按时间分组,每分钟最多 max_alerts_per_minute 个
|
||||
alerts = []
|
||||
for ts, group in anomaly_df.groupby('timestamp'):
|
||||
group = group.nlargest(config['max_alerts_per_minute'], 'final_score')
|
||||
|
||||
for _, row in group.iterrows():
|
||||
alpha = row['alpha']
|
||||
if alpha >= 1.5:
|
||||
atype = 'surge_up'
|
||||
elif alpha <= -1.5:
|
||||
atype = 'surge_down'
|
||||
elif row['amt_ratio'] >= 3.0:
|
||||
atype = 'volume_spike'
|
||||
else:
|
||||
atype = 'unknown'
|
||||
|
||||
rule_score = row['rule_score']
|
||||
ml_score = row['ml_score']
|
||||
final_score = row['final_score']
|
||||
|
||||
if rule_score >= config['rule_trigger']:
|
||||
trigger = f'规则强信号({rule_score:.0f}分)'
|
||||
elif ml_score >= config['ml_trigger']:
|
||||
trigger = f'ML强信号({ml_score:.0f}分)'
|
||||
else:
|
||||
trigger = f'融合触发({final_score:.0f}分)'
|
||||
|
||||
alerts.append({
|
||||
'concept_id': row['concept_id'],
|
||||
'alert_time': row['timestamp'],
|
||||
'trade_date': date,
|
||||
'alert_type': atype,
|
||||
'final_score': final_score,
|
||||
'rule_score': rule_score,
|
||||
'ml_score': ml_score,
|
||||
'trigger_reason': trigger,
|
||||
'triggered_rules': row['triggered_rules'],
|
||||
'alpha': alpha,
|
||||
'alpha_delta': row['alpha_delta'],
|
||||
'amt_ratio': row['amt_ratio'],
|
||||
'amt_delta': row['amt_delta'],
|
||||
'rank_pct': row['rank_pct'],
|
||||
'limit_up_ratio': row['limit_up_ratio'],
|
||||
'stock_count': row['stock_count'],
|
||||
'total_amt': row['total_amt'],
|
||||
})
|
||||
|
||||
return alerts
|
||||
|
||||
|
||||
# ==================== 数据加载 ====================
|
||||
|
||||
def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]:
|
||||
file_path = Path(data_dir) / f"features_{date}.parquet"
|
||||
if not file_path.exists():
|
||||
return None
|
||||
return pd.read_parquet(file_path)
|
||||
|
||||
|
||||
def get_available_dates(data_dir: str, start: str, end: str) -> List[str]:
|
||||
data_path = Path(data_dir)
|
||||
dates = []
|
||||
for f in sorted(data_path.glob("features_*.parquet")):
|
||||
d = f.stem.replace('features_', '')
|
||||
if start <= d <= end:
|
||||
dates.append(d)
|
||||
return dates
|
||||
|
||||
|
||||
def get_prev_trading_day(data_dir: str, date: str) -> Optional[str]:
|
||||
"""获取给定日期之前最近的有数据的交易日"""
|
||||
data_path = Path(data_dir)
|
||||
all_dates = sorted([f.stem.replace('features_', '') for f in data_path.glob("features_*.parquet")])
|
||||
|
||||
for i, d in enumerate(all_dates):
|
||||
if d == date and i > 0:
|
||||
return all_dates[i - 1]
|
||||
return None
|
||||
|
||||
|
||||
def export_to_csv(alerts: List[Dict], path: str):
|
||||
if alerts:
|
||||
pd.DataFrame(alerts).to_csv(path, index=False, encoding='utf-8-sig')
|
||||
print(f"已导出: {path}")
|
||||
|
||||
|
||||
# ==================== 数据库写入 ====================
|
||||
|
||||
def init_db_table():
|
||||
"""
|
||||
初始化数据库表(如果不存在则创建)
|
||||
|
||||
表结构说明:
|
||||
- concept_id: 概念ID
|
||||
- alert_time: 异动时间(精确到分钟)
|
||||
- trade_date: 交易日期
|
||||
- alert_type: 异动类型(surge_up/surge_down/volume_spike/unknown)
|
||||
- final_score: 最终得分(0-100)
|
||||
- rule_score: 规则得分(0-100)
|
||||
- ml_score: ML得分(0-100)
|
||||
- trigger_reason: 触发原因
|
||||
- alpha: 超额收益率
|
||||
- alpha_delta: alpha变化速度
|
||||
- amt_ratio: 成交额放大倍数
|
||||
- rank_pct: 排名百分位
|
||||
- stock_count: 概念股票数量
|
||||
- triggered_rules: 触发的规则列表(JSON)
|
||||
"""
|
||||
create_sql = text("""
|
||||
CREATE TABLE IF NOT EXISTS concept_anomaly_hybrid (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
concept_id VARCHAR(64) NOT NULL,
|
||||
alert_time DATETIME NOT NULL,
|
||||
trade_date DATE NOT NULL,
|
||||
alert_type VARCHAR(32) NOT NULL,
|
||||
final_score FLOAT NOT NULL,
|
||||
rule_score FLOAT NOT NULL,
|
||||
ml_score FLOAT NOT NULL,
|
||||
trigger_reason VARCHAR(64),
|
||||
alpha FLOAT,
|
||||
alpha_delta FLOAT,
|
||||
amt_ratio FLOAT,
|
||||
amt_delta FLOAT,
|
||||
rank_pct FLOAT,
|
||||
limit_up_ratio FLOAT,
|
||||
stock_count INT,
|
||||
total_amt FLOAT,
|
||||
triggered_rules JSON,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE KEY uk_concept_time (concept_id, alert_time, trade_date),
|
||||
INDEX idx_trade_date (trade_date),
|
||||
INDEX idx_concept_id (concept_id),
|
||||
INDEX idx_final_score (final_score),
|
||||
INDEX idx_alert_type (alert_type)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念异动检测结果(融合版)'
|
||||
""")
|
||||
|
||||
with MYSQL_ENGINE.begin() as conn:
|
||||
conn.execute(create_sql)
|
||||
print("数据库表已就绪: concept_anomaly_hybrid")
|
||||
|
||||
|
||||
def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
|
||||
"""
|
||||
保存异动到 MySQL
|
||||
|
||||
Args:
|
||||
alerts: 异动列表
|
||||
dry_run: 是否只模拟,不实际写入
|
||||
|
||||
Returns:
|
||||
实际保存的记录数
|
||||
"""
|
||||
if not alerts:
|
||||
return 0
|
||||
|
||||
if dry_run:
|
||||
print(f" [Dry Run] 将写入 {len(alerts)} 条异动")
|
||||
return len(alerts)
|
||||
|
||||
saved = 0
|
||||
skipped = 0
|
||||
|
||||
with MYSQL_ENGINE.begin() as conn:
|
||||
for alert in alerts:
|
||||
try:
|
||||
# 检查是否已存在(使用 INSERT IGNORE 更高效)
|
||||
insert_sql = text("""
|
||||
INSERT IGNORE INTO concept_anomaly_hybrid
|
||||
(concept_id, alert_time, trade_date, alert_type,
|
||||
final_score, rule_score, ml_score, trigger_reason,
|
||||
alpha, alpha_delta, amt_ratio, amt_delta,
|
||||
rank_pct, limit_up_ratio, stock_count, total_amt,
|
||||
triggered_rules)
|
||||
VALUES
|
||||
(:concept_id, :alert_time, :trade_date, :alert_type,
|
||||
:final_score, :rule_score, :ml_score, :trigger_reason,
|
||||
:alpha, :alpha_delta, :amt_ratio, :amt_delta,
|
||||
:rank_pct, :limit_up_ratio, :stock_count, :total_amt,
|
||||
:triggered_rules)
|
||||
""")
|
||||
|
||||
result = conn.execute(insert_sql, {
|
||||
'concept_id': alert['concept_id'],
|
||||
'alert_time': alert['alert_time'],
|
||||
'trade_date': alert['trade_date'],
|
||||
'alert_type': alert['alert_type'],
|
||||
'final_score': alert['final_score'],
|
||||
'rule_score': alert['rule_score'],
|
||||
'ml_score': alert['ml_score'],
|
||||
'trigger_reason': alert['trigger_reason'],
|
||||
'alpha': alert.get('alpha', 0),
|
||||
'alpha_delta': alert.get('alpha_delta', 0),
|
||||
'amt_ratio': alert.get('amt_ratio', 1),
|
||||
'amt_delta': alert.get('amt_delta', 0),
|
||||
'rank_pct': alert.get('rank_pct', 0.5),
|
||||
'limit_up_ratio': alert.get('limit_up_ratio', 0),
|
||||
'stock_count': alert.get('stock_count', 0),
|
||||
'total_amt': alert.get('total_amt', 0),
|
||||
'triggered_rules': json.dumps(alert.get('triggered_rules', []), ensure_ascii=False),
|
||||
})
|
||||
|
||||
if result.rowcount > 0:
|
||||
saved += 1
|
||||
else:
|
||||
skipped += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f" 保存失败: {alert['concept_id']} @ {alert['alert_time']} - {e}")
|
||||
|
||||
if skipped > 0:
|
||||
print(f" 跳过 {skipped} 条重复记录")
|
||||
|
||||
return saved
|
||||
|
||||
|
||||
def clear_alerts_by_date(trade_date: str) -> int:
|
||||
"""清除指定日期的异动记录(用于重新回测)"""
|
||||
with MYSQL_ENGINE.begin() as conn:
|
||||
result = conn.execute(
|
||||
text("DELETE FROM concept_anomaly_hybrid WHERE trade_date = :trade_date"),
|
||||
{'trade_date': trade_date}
|
||||
)
|
||||
return result.rowcount
|
||||
|
||||
|
||||
def analyze_alerts(alerts: List[Dict]):
|
||||
if not alerts:
|
||||
print("无异动")
|
||||
return
|
||||
|
||||
df = pd.DataFrame(alerts)
|
||||
print(f"\n总异动: {len(alerts)}")
|
||||
print(f"\n类型分布:\n{df['alert_type'].value_counts()}")
|
||||
print(f"\n得分统计:")
|
||||
print(f" 最终: {df['final_score'].mean():.1f} (max: {df['final_score'].max():.1f})")
|
||||
print(f" 规则: {df['rule_score'].mean():.1f} (max: {df['rule_score'].max():.1f})")
|
||||
print(f" ML: {df['ml_score'].mean():.1f} (max: {df['ml_score'].max():.1f})")
|
||||
|
||||
trigger_type = df['trigger_reason'].apply(
|
||||
lambda x: '规则' if '规则' in x else ('ML' if 'ML' in x else '融合')
|
||||
)
|
||||
print(f"\n触发来源:\n{trigger_type.value_counts()}")
|
||||
|
||||
|
||||
# ==================== 主函数 ====================
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='快速融合异动回测')
|
||||
parser.add_argument('--data_dir', default='ml/data')
|
||||
parser.add_argument('--checkpoint_dir', default='ml/checkpoints')
|
||||
parser.add_argument('--start', required=True)
|
||||
parser.add_argument('--end', default=None)
|
||||
parser.add_argument('--dry-run', action='store_true', help='模拟运行,不写入数据库')
|
||||
parser.add_argument('--export-csv', default=None, help='导出 CSV 文件路径')
|
||||
parser.add_argument('--save-db', action='store_true', help='保存结果到数据库')
|
||||
parser.add_argument('--clear-first', action='store_true', help='写入前先清除该日期的旧数据')
|
||||
parser.add_argument('--device', default='auto')
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.end is None:
|
||||
args.end = args.start
|
||||
|
||||
print("=" * 60)
|
||||
print("快速融合异动回测")
|
||||
print("=" * 60)
|
||||
print(f"日期: {args.start} ~ {args.end}")
|
||||
print(f"设备: {args.device}")
|
||||
print(f"保存数据库: {args.save_db}")
|
||||
print("=" * 60)
|
||||
|
||||
# 初始化数据库表(如果需要保存)
|
||||
if args.save_db and not args.dry_run:
|
||||
init_db_table()
|
||||
|
||||
# 初始化 ML 评分器
|
||||
ml_scorer = FastMLScorer(args.checkpoint_dir, args.device)
|
||||
|
||||
# 获取日期
|
||||
dates = get_available_dates(args.data_dir, args.start, args.end)
|
||||
if not dates:
|
||||
print("无数据")
|
||||
return
|
||||
|
||||
print(f"找到 {len(dates)} 天数据\n")
|
||||
|
||||
# 回测(支持跨日序列)
|
||||
all_alerts = []
|
||||
total_saved = 0
|
||||
prev_df = None # 缓存前一天数据
|
||||
|
||||
for i, date in enumerate(tqdm(dates, desc="回测")):
|
||||
df = load_daily_features(args.data_dir, date)
|
||||
if df is None or df.empty:
|
||||
prev_df = None # 当天无数据,清空缓存
|
||||
continue
|
||||
|
||||
# 第一天需要加载前一天数据(如果存在)
|
||||
if i == 0 and prev_df is None:
|
||||
prev_date = get_prev_trading_day(args.data_dir, date)
|
||||
if prev_date:
|
||||
prev_df = load_daily_features(args.data_dir, prev_date)
|
||||
if prev_df is not None:
|
||||
tqdm.write(f" 加载前一天数据: {prev_date}")
|
||||
|
||||
alerts = backtest_single_day_fast(ml_scorer, df, date, CONFIG, prev_df)
|
||||
all_alerts.extend(alerts)
|
||||
|
||||
# 保存到数据库
|
||||
if args.save_db and alerts:
|
||||
if args.clear_first and not args.dry_run:
|
||||
cleared = clear_alerts_by_date(date)
|
||||
if cleared > 0:
|
||||
tqdm.write(f" 清除 {date} 旧数据: {cleared} 条")
|
||||
|
||||
saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run)
|
||||
total_saved += saved
|
||||
tqdm.write(f" {date}: {len(alerts)} 个异动, 保存 {saved} 条")
|
||||
elif alerts:
|
||||
tqdm.write(f" {date}: {len(alerts)} 个异动")
|
||||
|
||||
# 当天数据成为下一天的 prev_df
|
||||
prev_df = df
|
||||
|
||||
# 导出 CSV
|
||||
if args.export_csv:
|
||||
export_to_csv(all_alerts, args.export_csv)
|
||||
|
||||
# 分析
|
||||
analyze_alerts(all_alerts)
|
||||
|
||||
print(f"\n总计: {len(all_alerts)} 个异动")
|
||||
if args.save_db:
|
||||
print(f"已保存到数据库: {total_saved} 条")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -93,12 +93,12 @@ def backtest_single_day_hybrid(
|
||||
seq_len: int = 30
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
使用融合检测器回测单天数据
|
||||
使用融合检测器回测单天数据(批量优化版)
|
||||
"""
|
||||
alerts = []
|
||||
|
||||
# 按概念分组
|
||||
grouped = df.groupby('concept_id', sort=False)
|
||||
# 按概念分组,预先构建字典
|
||||
grouped_dict = {cid: cdf for cid, cdf in df.groupby('concept_id', sort=False)}
|
||||
|
||||
# 冷却记录
|
||||
cooldown = {}
|
||||
@@ -114,27 +114,46 @@ def backtest_single_day_hybrid(
|
||||
current_time = all_timestamps[t_idx]
|
||||
window_start_time = all_timestamps[t_idx - seq_len + 1]
|
||||
|
||||
minute_alerts = []
|
||||
# 批量收集该时刻所有候选概念
|
||||
batch_sequences = []
|
||||
batch_features = []
|
||||
batch_infos = []
|
||||
|
||||
for concept_id, concept_df in grouped_dict.items():
|
||||
# 检查冷却(提前过滤)
|
||||
if concept_id in cooldown:
|
||||
last_alert = cooldown[concept_id]
|
||||
if isinstance(current_time, datetime):
|
||||
time_diff = (current_time - last_alert).total_seconds() / 60
|
||||
else:
|
||||
time_diff = BACKTEST_CONFIG['cooldown_minutes'] + 1
|
||||
if time_diff < BACKTEST_CONFIG['cooldown_minutes']:
|
||||
continue
|
||||
|
||||
for concept_id, concept_df in grouped:
|
||||
# 获取时间窗口内的数据
|
||||
mask = (concept_df['timestamp'] >= window_start_time) & (concept_df['timestamp'] <= current_time)
|
||||
window_df = concept_df[mask].sort_values('timestamp')
|
||||
window_df = concept_df.loc[mask]
|
||||
|
||||
if len(window_df) < seq_len:
|
||||
continue
|
||||
|
||||
window_df = window_df.tail(seq_len)
|
||||
window_df = window_df.sort_values('timestamp').tail(seq_len)
|
||||
|
||||
# 提取特征序列(给 ML 模型)
|
||||
# 当前时刻特征
|
||||
current_row = window_df.iloc[-1]
|
||||
alpha = current_row.get('alpha', 0)
|
||||
|
||||
# 过滤微小波动(提前过滤)
|
||||
if abs(alpha) < BACKTEST_CONFIG['min_alpha_abs']:
|
||||
continue
|
||||
|
||||
# 提取特征序列
|
||||
sequence = window_df[FEATURES].values
|
||||
sequence = np.nan_to_num(sequence, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
sequence = np.clip(sequence, -BACKTEST_CONFIG['clip_value'], BACKTEST_CONFIG['clip_value'])
|
||||
|
||||
# 当前时刻特征(给规则系统)
|
||||
current_row = window_df.iloc[-1]
|
||||
current_features = {
|
||||
'alpha': current_row.get('alpha', 0),
|
||||
'alpha': alpha,
|
||||
'alpha_delta': current_row.get('alpha_delta', 0),
|
||||
'amt_ratio': current_row.get('amt_ratio', 1),
|
||||
'amt_delta': current_row.get('amt_delta', 0),
|
||||
@@ -142,41 +161,79 @@ def backtest_single_day_hybrid(
|
||||
'limit_up_ratio': current_row.get('limit_up_ratio', 0),
|
||||
}
|
||||
|
||||
# 过滤微小波动
|
||||
if abs(current_features['alpha']) < BACKTEST_CONFIG['min_alpha_abs']:
|
||||
batch_sequences.append(sequence)
|
||||
batch_features.append(current_features)
|
||||
batch_infos.append({
|
||||
'concept_id': concept_id,
|
||||
'stock_count': current_row.get('stock_count', 0),
|
||||
'total_amt': current_row.get('total_amt', 0),
|
||||
})
|
||||
|
||||
if not batch_sequences:
|
||||
continue
|
||||
|
||||
# 批量 ML 推理
|
||||
sequences_array = np.array(batch_sequences)
|
||||
ml_scores = detector.ml_scorer.score(sequences_array) if detector.ml_scorer.is_ready() else [0.0] * len(batch_sequences)
|
||||
if isinstance(ml_scores, float):
|
||||
ml_scores = [ml_scores]
|
||||
|
||||
# 批量规则评分 + 融合
|
||||
minute_alerts = []
|
||||
for i, (features, info) in enumerate(zip(batch_features, batch_infos)):
|
||||
concept_id = info['concept_id']
|
||||
|
||||
# 规则评分
|
||||
rule_score, rule_details = detector.rule_scorer.score(features)
|
||||
|
||||
# ML 评分
|
||||
ml_score = ml_scores[i] if i < len(ml_scores) else 0.0
|
||||
|
||||
# 融合
|
||||
w1 = detector.config['rule_weight']
|
||||
w2 = detector.config['ml_weight']
|
||||
final_score = w1 * rule_score + w2 * ml_score
|
||||
|
||||
# 判断是否异动
|
||||
is_anomaly = False
|
||||
trigger_reason = ''
|
||||
|
||||
if rule_score >= detector.config['rule_trigger']:
|
||||
is_anomaly = True
|
||||
trigger_reason = f'规则强信号({rule_score:.0f}分)'
|
||||
elif ml_score >= detector.config['ml_trigger']:
|
||||
is_anomaly = True
|
||||
trigger_reason = f'ML强信号({ml_score:.0f}分)'
|
||||
elif final_score >= detector.config['fusion_trigger']:
|
||||
is_anomaly = True
|
||||
trigger_reason = f'融合触发({final_score:.0f}分)'
|
||||
|
||||
if not is_anomaly:
|
||||
continue
|
||||
|
||||
# 检查冷却
|
||||
if concept_id in cooldown:
|
||||
last_alert = cooldown[concept_id]
|
||||
if isinstance(current_time, datetime):
|
||||
time_diff = (current_time - last_alert).total_seconds() / 60
|
||||
else:
|
||||
time_diff = BACKTEST_CONFIG['cooldown_minutes'] + 1
|
||||
# 异动类型
|
||||
alpha = features.get('alpha', 0)
|
||||
if alpha >= 1.5:
|
||||
anomaly_type = 'surge_up'
|
||||
elif alpha <= -1.5:
|
||||
anomaly_type = 'surge_down'
|
||||
elif features.get('amt_ratio', 1) >= 3.0:
|
||||
anomaly_type = 'volume_spike'
|
||||
else:
|
||||
anomaly_type = 'unknown'
|
||||
|
||||
if time_diff < BACKTEST_CONFIG['cooldown_minutes']:
|
||||
continue
|
||||
|
||||
# 融合检测
|
||||
result = detector.detect(current_features, sequence)
|
||||
|
||||
if not result.is_anomaly:
|
||||
continue
|
||||
|
||||
# 记录异动
|
||||
alert = {
|
||||
'concept_id': concept_id,
|
||||
'alert_time': current_time,
|
||||
'trade_date': date,
|
||||
'alert_type': result.anomaly_type,
|
||||
'final_score': result.final_score,
|
||||
'rule_score': result.rule_score,
|
||||
'ml_score': result.ml_score,
|
||||
'trigger_reason': result.trigger_reason,
|
||||
'triggered_rules': list(result.rule_details.keys()),
|
||||
**current_features,
|
||||
'stock_count': current_row.get('stock_count', 0),
|
||||
'total_amt': current_row.get('total_amt', 0),
|
||||
'alert_type': anomaly_type,
|
||||
'final_score': final_score,
|
||||
'rule_score': rule_score,
|
||||
'ml_score': ml_score,
|
||||
'trigger_reason': trigger_reason,
|
||||
'triggered_rules': list(rule_details.keys()),
|
||||
**features,
|
||||
**info,
|
||||
}
|
||||
|
||||
minute_alerts.append(alert)
|
||||
@@ -341,6 +398,8 @@ def main():
|
||||
help='规则权重 (0-1)')
|
||||
parser.add_argument('--ml-weight', type=float, default=0.4,
|
||||
help='ML权重 (0-1)')
|
||||
parser.add_argument('--device', type=str, default='cuda',
|
||||
help='设备 (cuda/cpu),默认 cuda')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -355,15 +414,19 @@ def main():
|
||||
print(f"模型目录: {args.checkpoint_dir}")
|
||||
print(f"规则权重: {args.rule_weight}")
|
||||
print(f"ML权重: {args.ml_weight}")
|
||||
print(f"设备: {args.device}")
|
||||
print(f"Dry Run: {args.dry_run}")
|
||||
print("=" * 60)
|
||||
|
||||
# 初始化融合检测器
|
||||
# 初始化融合检测器(使用 GPU)
|
||||
config = {
|
||||
'rule_weight': args.rule_weight,
|
||||
'ml_weight': args.ml_weight,
|
||||
}
|
||||
detector = create_detector(args.checkpoint_dir, config)
|
||||
|
||||
# 修改 detector.py 中 MLScorer 的设备
|
||||
from detector import HybridAnomalyDetector
|
||||
detector = HybridAnomalyDetector(config, args.checkpoint_dir, device=args.device)
|
||||
|
||||
# 获取可用日期
|
||||
dates = get_available_dates(args.data_dir, args.start, args.end)
|
||||
|
||||
294
ml/backtest_v2.py
Normal file
294
ml/backtest_v2.py
Normal file
@@ -0,0 +1,294 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
V2 回测脚本 - 验证时间片对齐 + 持续性确认的效果
|
||||
|
||||
回测指标:
|
||||
1. 准确率:异动后 N 分钟内 alpha 是否继续上涨/下跌
|
||||
2. 虚警率:多少异动是噪音
|
||||
3. 持续性:平均异动持续时长
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from ml.detector_v2 import AnomalyDetectorV2, CONFIG
|
||||
|
||||
|
||||
# ==================== 配置 ====================
|
||||
|
||||
MYSQL_ENGINE = create_engine(
|
||||
"mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
|
||||
echo=False
|
||||
)
|
||||
|
||||
|
||||
# ==================== 回测评估 ====================
|
||||
|
||||
def evaluate_alerts(
|
||||
alerts: List[Dict],
|
||||
raw_data: pd.DataFrame,
|
||||
lookahead_minutes: int = 10
|
||||
) -> Dict:
|
||||
"""
|
||||
评估异动质量
|
||||
|
||||
指标:
|
||||
1. 方向正确率:异动后 N 分钟 alpha 方向是否一致
|
||||
2. 持续率:异动后 N 分钟内有多少时刻 alpha 保持同向
|
||||
3. 峰值收益:异动后 N 分钟内的最大 alpha
|
||||
"""
|
||||
if not alerts:
|
||||
return {'accuracy': 0, 'sustained_rate': 0, 'avg_peak': 0, 'total_alerts': 0}
|
||||
|
||||
results = []
|
||||
|
||||
for alert in alerts:
|
||||
concept_id = alert['concept_id']
|
||||
alert_time = alert['alert_time']
|
||||
alert_alpha = alert['alpha']
|
||||
is_up = alert_alpha > 0
|
||||
|
||||
# 获取该概念在异动后的数据
|
||||
concept_data = raw_data[
|
||||
(raw_data['concept_id'] == concept_id) &
|
||||
(raw_data['timestamp'] > alert_time)
|
||||
].head(lookahead_minutes)
|
||||
|
||||
if len(concept_data) < 3:
|
||||
continue
|
||||
|
||||
future_alphas = concept_data['alpha'].values
|
||||
|
||||
# 方向正确:未来 alpha 平均值与当前同向
|
||||
avg_future_alpha = np.mean(future_alphas)
|
||||
direction_correct = (is_up and avg_future_alpha > 0) or (not is_up and avg_future_alpha < 0)
|
||||
|
||||
# 持续率:有多少时刻保持同向
|
||||
if is_up:
|
||||
sustained_count = sum(1 for a in future_alphas if a > 0)
|
||||
else:
|
||||
sustained_count = sum(1 for a in future_alphas if a < 0)
|
||||
sustained_rate = sustained_count / len(future_alphas)
|
||||
|
||||
# 峰值收益
|
||||
if is_up:
|
||||
peak = max(future_alphas)
|
||||
else:
|
||||
peak = min(future_alphas)
|
||||
|
||||
results.append({
|
||||
'direction_correct': direction_correct,
|
||||
'sustained_rate': sustained_rate,
|
||||
'peak': peak,
|
||||
'alert_alpha': alert_alpha,
|
||||
})
|
||||
|
||||
if not results:
|
||||
return {'accuracy': 0, 'sustained_rate': 0, 'avg_peak': 0, 'total_alerts': 0}
|
||||
|
||||
return {
|
||||
'accuracy': np.mean([r['direction_correct'] for r in results]),
|
||||
'sustained_rate': np.mean([r['sustained_rate'] for r in results]),
|
||||
'avg_peak': np.mean([abs(r['peak']) for r in results]),
|
||||
'total_alerts': len(alerts),
|
||||
'evaluated_alerts': len(results),
|
||||
}
|
||||
|
||||
|
||||
def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
|
||||
"""保存异动到 MySQL"""
|
||||
if not alerts or dry_run:
|
||||
return 0
|
||||
|
||||
# 确保表存在
|
||||
with MYSQL_ENGINE.begin() as conn:
|
||||
conn.execute(text("""
|
||||
CREATE TABLE IF NOT EXISTS concept_anomaly_v2 (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
concept_id VARCHAR(64) NOT NULL,
|
||||
alert_time DATETIME NOT NULL,
|
||||
trade_date DATE NOT NULL,
|
||||
alert_type VARCHAR(32) NOT NULL,
|
||||
final_score FLOAT NOT NULL,
|
||||
rule_score FLOAT NOT NULL,
|
||||
ml_score FLOAT NOT NULL,
|
||||
trigger_reason VARCHAR(128),
|
||||
confirm_ratio FLOAT,
|
||||
alpha FLOAT,
|
||||
alpha_zscore FLOAT,
|
||||
amt_zscore FLOAT,
|
||||
rank_zscore FLOAT,
|
||||
momentum_3m FLOAT,
|
||||
momentum_5m FLOAT,
|
||||
limit_up_ratio FLOAT,
|
||||
triggered_rules JSON,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE KEY uk_concept_time (concept_id, alert_time, trade_date),
|
||||
INDEX idx_trade_date (trade_date),
|
||||
INDEX idx_final_score (final_score)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念异动 V2(时间片对齐+持续确认)'
|
||||
"""))
|
||||
|
||||
# 插入数据
|
||||
saved = 0
|
||||
with MYSQL_ENGINE.begin() as conn:
|
||||
for alert in alerts:
|
||||
try:
|
||||
conn.execute(text("""
|
||||
INSERT IGNORE INTO concept_anomaly_v2
|
||||
(concept_id, alert_time, trade_date, alert_type,
|
||||
final_score, rule_score, ml_score, trigger_reason, confirm_ratio,
|
||||
alpha, alpha_zscore, amt_zscore, rank_zscore,
|
||||
momentum_3m, momentum_5m, limit_up_ratio, triggered_rules)
|
||||
VALUES
|
||||
(:concept_id, :alert_time, :trade_date, :alert_type,
|
||||
:final_score, :rule_score, :ml_score, :trigger_reason, :confirm_ratio,
|
||||
:alpha, :alpha_zscore, :amt_zscore, :rank_zscore,
|
||||
:momentum_3m, :momentum_5m, :limit_up_ratio, :triggered_rules)
|
||||
"""), {
|
||||
'concept_id': alert['concept_id'],
|
||||
'alert_time': alert['alert_time'],
|
||||
'trade_date': alert['trade_date'],
|
||||
'alert_type': alert['alert_type'],
|
||||
'final_score': alert['final_score'],
|
||||
'rule_score': alert['rule_score'],
|
||||
'ml_score': alert['ml_score'],
|
||||
'trigger_reason': alert['trigger_reason'],
|
||||
'confirm_ratio': alert.get('confirm_ratio', 0),
|
||||
'alpha': alert['alpha'],
|
||||
'alpha_zscore': alert.get('alpha_zscore', 0),
|
||||
'amt_zscore': alert.get('amt_zscore', 0),
|
||||
'rank_zscore': alert.get('rank_zscore', 0),
|
||||
'momentum_3m': alert.get('momentum_3m', 0),
|
||||
'momentum_5m': alert.get('momentum_5m', 0),
|
||||
'limit_up_ratio': alert.get('limit_up_ratio', 0),
|
||||
'triggered_rules': json.dumps(alert.get('triggered_rules', [])),
|
||||
})
|
||||
saved += 1
|
||||
except Exception as e:
|
||||
print(f"保存失败: {e}")
|
||||
|
||||
return saved
|
||||
|
||||
|
||||
# ==================== 主函数 ====================
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='V2 回测')
|
||||
parser.add_argument('--start', type=str, required=True, help='开始日期')
|
||||
parser.add_argument('--end', type=str, default=None, help='结束日期')
|
||||
parser.add_argument('--model_dir', type=str, default='ml/checkpoints_v2')
|
||||
parser.add_argument('--baseline_dir', type=str, default='ml/data_v2/baselines')
|
||||
parser.add_argument('--save', action='store_true', help='保存到数据库')
|
||||
parser.add_argument('--lookahead', type=int, default=10, help='评估前瞻时间(分钟)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
end_date = args.end or args.start
|
||||
|
||||
print("=" * 60)
|
||||
print("V2 回测 - 时间片对齐 + 持续性确认")
|
||||
print("=" * 60)
|
||||
print(f"日期范围: {args.start} ~ {end_date}")
|
||||
print(f"模型目录: {args.model_dir}")
|
||||
print(f"评估前瞻: {args.lookahead} 分钟")
|
||||
|
||||
# 初始化检测器
|
||||
detector = AnomalyDetectorV2(
|
||||
model_dir=args.model_dir,
|
||||
baseline_dir=args.baseline_dir
|
||||
)
|
||||
|
||||
# 获取交易日
|
||||
from prepare_data_v2 import get_trading_days
|
||||
trading_days = get_trading_days(args.start, end_date)
|
||||
|
||||
if not trading_days:
|
||||
print("无交易日")
|
||||
return
|
||||
|
||||
print(f"交易日数: {len(trading_days)}")
|
||||
|
||||
# 回测统计
|
||||
total_stats = {
|
||||
'total_alerts': 0,
|
||||
'accuracy_sum': 0,
|
||||
'sustained_sum': 0,
|
||||
'peak_sum': 0,
|
||||
'day_count': 0,
|
||||
}
|
||||
|
||||
all_alerts = []
|
||||
|
||||
for trade_date in tqdm(trading_days, desc="回测进度"):
|
||||
# 检测异动
|
||||
alerts = detector.detect(trade_date)
|
||||
|
||||
if not alerts:
|
||||
continue
|
||||
|
||||
all_alerts.extend(alerts)
|
||||
|
||||
# 评估
|
||||
raw_data = detector._compute_raw_features(trade_date)
|
||||
if raw_data.empty:
|
||||
continue
|
||||
|
||||
stats = evaluate_alerts(alerts, raw_data, args.lookahead)
|
||||
|
||||
if stats['evaluated_alerts'] > 0:
|
||||
total_stats['total_alerts'] += stats['total_alerts']
|
||||
total_stats['accuracy_sum'] += stats['accuracy'] * stats['evaluated_alerts']
|
||||
total_stats['sustained_sum'] += stats['sustained_rate'] * stats['evaluated_alerts']
|
||||
total_stats['peak_sum'] += stats['avg_peak'] * stats['evaluated_alerts']
|
||||
total_stats['day_count'] += 1
|
||||
|
||||
print(f"\n[{trade_date}] 异动: {stats['total_alerts']}, "
|
||||
f"准确率: {stats['accuracy']:.1%}, "
|
||||
f"持续率: {stats['sustained_rate']:.1%}, "
|
||||
f"峰值: {stats['avg_peak']:.2f}%")
|
||||
|
||||
# 汇总
|
||||
print("\n" + "=" * 60)
|
||||
print("回测汇总")
|
||||
print("=" * 60)
|
||||
|
||||
if total_stats['total_alerts'] > 0:
|
||||
avg_accuracy = total_stats['accuracy_sum'] / total_stats['total_alerts']
|
||||
avg_sustained = total_stats['sustained_sum'] / total_stats['total_alerts']
|
||||
avg_peak = total_stats['peak_sum'] / total_stats['total_alerts']
|
||||
|
||||
print(f"总异动数: {total_stats['total_alerts']}")
|
||||
print(f"回测天数: {total_stats['day_count']}")
|
||||
print(f"平均每天: {total_stats['total_alerts'] / max(1, total_stats['day_count']):.1f} 个")
|
||||
print(f"方向准确率: {avg_accuracy:.1%}")
|
||||
print(f"持续率: {avg_sustained:.1%}")
|
||||
print(f"平均峰值: {avg_peak:.2f}%")
|
||||
else:
|
||||
print("无异动检测结果")
|
||||
|
||||
# 保存
|
||||
if args.save and all_alerts:
|
||||
print(f"\n保存 {len(all_alerts)} 条异动到数据库...")
|
||||
saved = save_alerts_to_mysql(all_alerts)
|
||||
print(f"保存完成: {saved} 条")
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
31
ml/checkpoints_v2/config.json
Normal file
31
ml/checkpoints_v2/config.json
Normal file
@@ -0,0 +1,31 @@
|
||||
{
|
||||
"seq_len": 10,
|
||||
"stride": 2,
|
||||
"train_end_date": "2025-06-30",
|
||||
"val_end_date": "2025-09-30",
|
||||
"features": [
|
||||
"alpha_zscore",
|
||||
"amt_zscore",
|
||||
"rank_zscore",
|
||||
"momentum_3m",
|
||||
"momentum_5m",
|
||||
"limit_up_ratio"
|
||||
],
|
||||
"batch_size": 32768,
|
||||
"epochs": 150,
|
||||
"learning_rate": 0.0006,
|
||||
"weight_decay": 1e-05,
|
||||
"gradient_clip": 1.0,
|
||||
"patience": 15,
|
||||
"min_delta": 1e-06,
|
||||
"model": {
|
||||
"n_features": 6,
|
||||
"hidden_dim": 32,
|
||||
"latent_dim": 4,
|
||||
"num_layers": 1,
|
||||
"dropout": 0.2,
|
||||
"bidirectional": true
|
||||
},
|
||||
"clip_value": 5.0,
|
||||
"threshold_percentiles": [90, 95, 99]
|
||||
}
|
||||
8
ml/checkpoints_v2/thresholds.json
Normal file
8
ml/checkpoints_v2/thresholds.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"p90": 0.15,
|
||||
"p95": 0.25,
|
||||
"p99": 0.50,
|
||||
"mean": 0.08,
|
||||
"std": 0.12,
|
||||
"median": 0.06
|
||||
}
|
||||
@@ -243,9 +243,12 @@ class MLScorer:
|
||||
):
|
||||
self.checkpoint_dir = Path(checkpoint_dir)
|
||||
|
||||
# 设备
|
||||
# 设备检测
|
||||
if device == 'auto':
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
elif device == 'cuda' and not torch.cuda.is_available():
|
||||
print("警告: CUDA 不可用,使用 CPU")
|
||||
self.device = torch.device('cpu')
|
||||
else:
|
||||
self.device = torch.device(device)
|
||||
|
||||
@@ -276,8 +279,8 @@ class MLScorer:
|
||||
with open(config_path, 'r') as f:
|
||||
self.config = json.load(f)
|
||||
|
||||
# 加载模型
|
||||
checkpoint = torch.load(model_path, map_location=self.device)
|
||||
# 先用 CPU 加载模型(避免 CUDA 不可用问题),再移动到目标设备
|
||||
checkpoint = torch.load(model_path, map_location='cpu')
|
||||
|
||||
model_config = self.config.get('model', {}) if self.config else {}
|
||||
self.model = create_model(model_config)
|
||||
@@ -294,6 +297,8 @@ class MLScorer:
|
||||
|
||||
except Exception as e:
|
||||
print(f"警告: 模型加载失败 - {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
self.model = None
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
@@ -551,7 +556,8 @@ if __name__ == "__main__":
|
||||
},
|
||||
]
|
||||
|
||||
print("\n测试结果:")
|
||||
print("\n" + "-" * 60)
|
||||
print("测试1: 只用规则(无序列数据)")
|
||||
print("-" * 60)
|
||||
|
||||
for case in test_cases:
|
||||
@@ -567,5 +573,63 @@ if __name__ == "__main__":
|
||||
print(f" 异动类型: {result.anomaly_type}")
|
||||
print(f" 触发规则: {list(result.rule_details.keys())}")
|
||||
|
||||
# 测试2: 带序列数据的融合检测
|
||||
print("\n" + "-" * 60)
|
||||
print("测试2: 融合检测(规则 + ML)")
|
||||
print("-" * 60)
|
||||
|
||||
# 生成模拟序列数据
|
||||
seq_len = 30
|
||||
n_features = 6
|
||||
|
||||
# 正常序列:小幅波动
|
||||
normal_sequence = np.random.randn(seq_len, n_features) * 0.3
|
||||
normal_sequence[:, 0] = np.linspace(0, 0.5, seq_len) # alpha 缓慢上升
|
||||
normal_sequence[:, 2] = np.abs(normal_sequence[:, 2]) + 1 # amt_ratio > 0
|
||||
|
||||
# 异常序列:最后几个时间步突然变化
|
||||
anomaly_sequence = np.random.randn(seq_len, n_features) * 0.3
|
||||
anomaly_sequence[-5:, 0] = np.linspace(1, 4, 5) # alpha 突然飙升
|
||||
anomaly_sequence[-5:, 1] = np.linspace(0.2, 1.5, 5) # alpha_delta 加速
|
||||
anomaly_sequence[-5:, 2] = np.linspace(2, 6, 5) # amt_ratio 放量
|
||||
anomaly_sequence[:, 2] = np.abs(anomaly_sequence[:, 2]) + 1
|
||||
|
||||
# 测试正常序列
|
||||
normal_features = {
|
||||
'alpha': float(normal_sequence[-1, 0]),
|
||||
'alpha_delta': float(normal_sequence[-1, 1]),
|
||||
'amt_ratio': float(normal_sequence[-1, 2]),
|
||||
'amt_delta': float(normal_sequence[-1, 3]),
|
||||
'rank_pct': 0.5,
|
||||
'limit_up_ratio': 0.02
|
||||
}
|
||||
|
||||
result = detector.detect(normal_features, normal_sequence)
|
||||
print(f"\n正常序列:")
|
||||
print(f" 异动: {'是' if result.is_anomaly else '否'}")
|
||||
print(f" 最终得分: {result.final_score:.1f}")
|
||||
print(f" 规则得分: {result.rule_score:.1f}")
|
||||
print(f" ML得分: {result.ml_score:.1f}")
|
||||
|
||||
# 测试异常序列
|
||||
anomaly_features = {
|
||||
'alpha': float(anomaly_sequence[-1, 0]),
|
||||
'alpha_delta': float(anomaly_sequence[-1, 1]),
|
||||
'amt_ratio': float(anomaly_sequence[-1, 2]),
|
||||
'amt_delta': float(anomaly_sequence[-1, 3]),
|
||||
'rank_pct': 0.95,
|
||||
'limit_up_ratio': 0.15
|
||||
}
|
||||
|
||||
result = detector.detect(anomaly_features, anomaly_sequence)
|
||||
print(f"\n异常序列:")
|
||||
print(f" 异动: {'是' if result.is_anomaly else '否'}")
|
||||
print(f" 最终得分: {result.final_score:.1f}")
|
||||
print(f" 规则得分: {result.rule_score:.1f}")
|
||||
print(f" ML得分: {result.ml_score:.1f}")
|
||||
if result.is_anomaly:
|
||||
print(f" 触发原因: {result.trigger_reason}")
|
||||
print(f" 异动类型: {result.anomaly_type}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("测试完成!")
|
||||
|
||||
716
ml/detector_v2.py
Normal file
716
ml/detector_v2.py
Normal file
@@ -0,0 +1,716 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
异动检测器 V2 - 基于时间片对齐 + 持续性确认
|
||||
|
||||
核心改进:
|
||||
1. Z-Score 特征:相对于同时间片历史的偏离
|
||||
2. 短序列 LSTM:10分钟序列,开盘即可用
|
||||
3. 持续性确认:5分钟窗口内60%时刻超标才确认为异动
|
||||
|
||||
检测流程:
|
||||
1. 计算当前时刻的 Z-Score(对比同时间片历史基线)
|
||||
2. 构建最近10分钟的 Z-Score 序列
|
||||
3. LSTM 计算重构误差(ML分数)
|
||||
4. 规则评分(基于 Z-Score 的规则)
|
||||
5. 滑动窗口确认:最近5分钟内是否有足够多的时刻超标
|
||||
6. 只有通过持续性确认的才输出为异动
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import pickle
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from collections import defaultdict, deque
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from sqlalchemy import create_engine, text
|
||||
from elasticsearch import Elasticsearch
|
||||
from clickhouse_driver import Client
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from ml.model import TransformerAutoencoder
|
||||
|
||||
# ==================== 配置 ====================
|
||||
|
||||
MYSQL_ENGINE = create_engine(
|
||||
"mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
|
||||
echo=False
|
||||
)
|
||||
|
||||
ES_CLIENT = Elasticsearch(['http://127.0.0.1:9200'])
|
||||
ES_INDEX = 'concept_library_v3'
|
||||
|
||||
CLICKHOUSE_CONFIG = {
|
||||
'host': '127.0.0.1',
|
||||
'port': 9000,
|
||||
'user': 'default',
|
||||
'password': 'Zzl33818!',
|
||||
'database': 'stock'
|
||||
}
|
||||
|
||||
REFERENCE_INDEX = '000001.SH'
|
||||
|
||||
# 检测配置
|
||||
CONFIG = {
|
||||
# 序列配置
|
||||
'seq_len': 10, # LSTM 序列长度(分钟)
|
||||
|
||||
# 持续性确认配置(核心!)
|
||||
'confirm_window': 5, # 确认窗口(分钟)
|
||||
'confirm_ratio': 0.6, # 确认比例(60%时刻需要超标)
|
||||
|
||||
# Z-Score 阈值
|
||||
'alpha_zscore_threshold': 2.0, # Alpha Z-Score 阈值
|
||||
'amt_zscore_threshold': 2.5, # 成交额 Z-Score 阈值
|
||||
|
||||
# 融合权重
|
||||
'rule_weight': 0.5,
|
||||
'ml_weight': 0.5,
|
||||
|
||||
# 触发阈值
|
||||
'rule_trigger': 60,
|
||||
'ml_trigger': 70,
|
||||
'fusion_trigger': 50,
|
||||
|
||||
# 冷却期
|
||||
'cooldown_minutes': 10,
|
||||
'max_alerts_per_minute': 15,
|
||||
|
||||
# Z-Score 截断
|
||||
'zscore_clip': 5.0,
|
||||
}
|
||||
|
||||
# V2 特征列表
|
||||
FEATURES_V2 = [
|
||||
'alpha_zscore', 'amt_zscore', 'rank_zscore',
|
||||
'momentum_3m', 'momentum_5m', 'limit_up_ratio'
|
||||
]
|
||||
|
||||
|
||||
# ==================== 工具函数 ====================
|
||||
|
||||
def get_ch_client():
|
||||
return Client(**CLICKHOUSE_CONFIG)
|
||||
|
||||
|
||||
def code_to_ch_format(code: str) -> str:
|
||||
if not code or len(code) != 6 or not code.isdigit():
|
||||
return None
|
||||
if code.startswith('6'):
|
||||
return f"{code}.SH"
|
||||
elif code.startswith('0') or code.startswith('3'):
|
||||
return f"{code}.SZ"
|
||||
else:
|
||||
return f"{code}.BJ"
|
||||
|
||||
|
||||
def time_to_slot(ts) -> str:
|
||||
"""时间戳转时间片(HH:MM)"""
|
||||
if isinstance(ts, str):
|
||||
return ts
|
||||
return ts.strftime('%H:%M')
|
||||
|
||||
|
||||
# ==================== 基线加载 ====================
|
||||
|
||||
def load_baselines(baseline_dir: str = 'ml/data_v2/baselines') -> Dict[str, pd.DataFrame]:
|
||||
"""加载时间片基线"""
|
||||
baseline_file = os.path.join(baseline_dir, 'baselines.pkl')
|
||||
if os.path.exists(baseline_file):
|
||||
with open(baseline_file, 'rb') as f:
|
||||
return pickle.load(f)
|
||||
return {}
|
||||
|
||||
|
||||
# ==================== 规则评分(基于 Z-Score)====================
|
||||
|
||||
def score_rules_zscore(row: Dict) -> Tuple[float, List[str]]:
|
||||
"""
|
||||
基于 Z-Score 的规则评分
|
||||
|
||||
设计思路:Z-Score 已经标准化,直接用阈值判断
|
||||
"""
|
||||
score = 0.0
|
||||
triggered = []
|
||||
|
||||
alpha_zscore = row.get('alpha_zscore', 0)
|
||||
amt_zscore = row.get('amt_zscore', 0)
|
||||
rank_zscore = row.get('rank_zscore', 0)
|
||||
momentum_3m = row.get('momentum_3m', 0)
|
||||
momentum_5m = row.get('momentum_5m', 0)
|
||||
limit_up_ratio = row.get('limit_up_ratio', 0)
|
||||
|
||||
alpha_zscore_abs = abs(alpha_zscore)
|
||||
amt_zscore_abs = abs(amt_zscore)
|
||||
|
||||
# ========== Alpha Z-Score 规则 ==========
|
||||
if alpha_zscore_abs >= 4.0:
|
||||
score += 25
|
||||
triggered.append('alpha_zscore_extreme')
|
||||
elif alpha_zscore_abs >= 3.0:
|
||||
score += 18
|
||||
triggered.append('alpha_zscore_strong')
|
||||
elif alpha_zscore_abs >= 2.0:
|
||||
score += 10
|
||||
triggered.append('alpha_zscore_moderate')
|
||||
|
||||
# ========== 成交额 Z-Score 规则 ==========
|
||||
if amt_zscore >= 4.0:
|
||||
score += 20
|
||||
triggered.append('amt_zscore_extreme')
|
||||
elif amt_zscore >= 3.0:
|
||||
score += 12
|
||||
triggered.append('amt_zscore_strong')
|
||||
elif amt_zscore >= 2.0:
|
||||
score += 6
|
||||
triggered.append('amt_zscore_moderate')
|
||||
|
||||
# ========== 排名 Z-Score 规则 ==========
|
||||
if abs(rank_zscore) >= 3.0:
|
||||
score += 15
|
||||
triggered.append('rank_zscore_extreme')
|
||||
elif abs(rank_zscore) >= 2.0:
|
||||
score += 8
|
||||
triggered.append('rank_zscore_strong')
|
||||
|
||||
# ========== 动量规则 ==========
|
||||
if momentum_3m >= 1.0:
|
||||
score += 12
|
||||
triggered.append('momentum_3m_strong')
|
||||
elif momentum_3m >= 0.5:
|
||||
score += 6
|
||||
triggered.append('momentum_3m_moderate')
|
||||
|
||||
if momentum_5m >= 1.5:
|
||||
score += 10
|
||||
triggered.append('momentum_5m_strong')
|
||||
|
||||
# ========== 涨停比例规则 ==========
|
||||
if limit_up_ratio >= 0.3:
|
||||
score += 20
|
||||
triggered.append('limit_up_extreme')
|
||||
elif limit_up_ratio >= 0.15:
|
||||
score += 12
|
||||
triggered.append('limit_up_strong')
|
||||
elif limit_up_ratio >= 0.08:
|
||||
score += 5
|
||||
triggered.append('limit_up_moderate')
|
||||
|
||||
# ========== 组合规则 ==========
|
||||
# Alpha Z-Score + 成交额放大
|
||||
if alpha_zscore_abs >= 2.0 and amt_zscore >= 2.0:
|
||||
score += 15
|
||||
triggered.append('combo_alpha_amt')
|
||||
|
||||
# Alpha Z-Score + 涨停
|
||||
if alpha_zscore_abs >= 2.0 and limit_up_ratio >= 0.1:
|
||||
score += 12
|
||||
triggered.append('combo_alpha_limitup')
|
||||
|
||||
return min(score, 100), triggered
|
||||
|
||||
|
||||
# ==================== ML 评分器 ====================
|
||||
|
||||
class MLScorerV2:
|
||||
"""V2 ML 评分器"""
|
||||
|
||||
def __init__(self, model_dir: str = 'ml/checkpoints_v2'):
|
||||
self.model_dir = model_dir
|
||||
self.model = None
|
||||
self.thresholds = None
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self._load_model()
|
||||
|
||||
def _load_model(self):
|
||||
"""加载模型和阈值"""
|
||||
model_path = os.path.join(self.model_dir, 'best_model.pt')
|
||||
threshold_path = os.path.join(self.model_dir, 'thresholds.json')
|
||||
config_path = os.path.join(self.model_dir, 'config.json')
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
print(f"警告: 模型文件不存在: {model_path}")
|
||||
return
|
||||
|
||||
# 加载配置
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
|
||||
# 创建模型
|
||||
model_config = config.get('model', {})
|
||||
self.model = TransformerAutoencoder(**model_config)
|
||||
|
||||
# 加载权重
|
||||
checkpoint = torch.load(model_path, map_location=self.device)
|
||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
# 加载阈值
|
||||
if os.path.exists(threshold_path):
|
||||
with open(threshold_path, 'r') as f:
|
||||
self.thresholds = json.load(f)
|
||||
|
||||
print(f"V2 模型加载完成: {model_path}")
|
||||
|
||||
@torch.no_grad()
|
||||
def score_batch(self, sequences: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
批量计算 ML 分数
|
||||
|
||||
返回 0-100 的分数,越高越异常
|
||||
"""
|
||||
if self.model is None:
|
||||
return np.zeros(len(sequences))
|
||||
|
||||
# 转换为 tensor
|
||||
x = torch.FloatTensor(sequences).to(self.device)
|
||||
|
||||
# 计算重构误差
|
||||
errors = self.model.compute_reconstruction_error(x, reduction='none')
|
||||
# 取最后一个时刻的误差
|
||||
last_errors = errors[:, -1].cpu().numpy()
|
||||
|
||||
# 转换为 0-100 分数
|
||||
if self.thresholds:
|
||||
p50 = self.thresholds.get('median', 0.1)
|
||||
p99 = self.thresholds.get('p99', 1.0)
|
||||
|
||||
# 线性映射:p50 -> 50分,p99 -> 99分
|
||||
scores = 50 + (last_errors - p50) / (p99 - p50) * 49
|
||||
scores = np.clip(scores, 0, 100)
|
||||
else:
|
||||
# 没有阈值时,简单归一化
|
||||
scores = last_errors * 100
|
||||
scores = np.clip(scores, 0, 100)
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
# ==================== 实时数据管理器 ====================
|
||||
|
||||
class RealtimeDataManagerV2:
|
||||
"""
|
||||
V2 实时数据管理器
|
||||
|
||||
维护:
|
||||
1. 每个概念的历史 Z-Score 序列(用于 LSTM 输入)
|
||||
2. 每个概念的异动候选队列(用于持续性确认)
|
||||
"""
|
||||
|
||||
def __init__(self, concepts: List[dict], baselines: Dict[str, pd.DataFrame]):
|
||||
self.concepts = {c['concept_id']: c for c in concepts}
|
||||
self.baselines = baselines
|
||||
|
||||
# 概念到股票的映射
|
||||
self.concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts}
|
||||
|
||||
# 历史 Z-Score 序列(每个概念)
|
||||
# {concept_id: deque([(timestamp, features_dict), ...], maxlen=seq_len)}
|
||||
self.zscore_history = defaultdict(lambda: deque(maxlen=CONFIG['seq_len']))
|
||||
|
||||
# 异动候选队列(用于持续性确认)
|
||||
# {concept_id: deque([(timestamp, score), ...], maxlen=confirm_window)}
|
||||
self.anomaly_candidates = defaultdict(lambda: deque(maxlen=CONFIG['confirm_window']))
|
||||
|
||||
# 冷却期记录
|
||||
self.cooldown = {}
|
||||
|
||||
# 上一次更新的时间戳
|
||||
self.last_timestamp = None
|
||||
|
||||
def compute_zscore_features(
|
||||
self,
|
||||
concept_id: str,
|
||||
timestamp,
|
||||
alpha: float,
|
||||
total_amt: float,
|
||||
rank_pct: float,
|
||||
limit_up_ratio: float
|
||||
) -> Optional[Dict]:
|
||||
"""计算单个概念单个时刻的 Z-Score 特征"""
|
||||
if concept_id not in self.baselines:
|
||||
return None
|
||||
|
||||
baseline = self.baselines[concept_id]
|
||||
time_slot = time_to_slot(timestamp)
|
||||
|
||||
# 查找对应时间片的基线
|
||||
bl_row = baseline[baseline['time_slot'] == time_slot]
|
||||
if bl_row.empty:
|
||||
return None
|
||||
|
||||
bl = bl_row.iloc[0]
|
||||
|
||||
# 检查样本量
|
||||
if bl.get('sample_count', 0) < 10:
|
||||
return None
|
||||
|
||||
# 计算 Z-Score
|
||||
alpha_zscore = (alpha - bl['alpha_mean']) / bl['alpha_std']
|
||||
amt_zscore = (total_amt - bl['amt_mean']) / bl['amt_std']
|
||||
rank_zscore = (rank_pct - bl['rank_mean']) / bl['rank_std']
|
||||
|
||||
# 截断
|
||||
clip = CONFIG['zscore_clip']
|
||||
alpha_zscore = np.clip(alpha_zscore, -clip, clip)
|
||||
amt_zscore = np.clip(amt_zscore, -clip, clip)
|
||||
rank_zscore = np.clip(rank_zscore, -clip, clip)
|
||||
|
||||
# 计算动量(需要历史)
|
||||
history = self.zscore_history[concept_id]
|
||||
momentum_3m = 0
|
||||
momentum_5m = 0
|
||||
|
||||
if len(history) >= 3:
|
||||
recent_alphas = [h[1]['alpha'] for h in list(history)[-3:]]
|
||||
older_alphas = [h[1]['alpha'] for h in list(history)[-6:-3]] if len(history) >= 6 else [alpha]
|
||||
momentum_3m = np.mean(recent_alphas) - np.mean(older_alphas)
|
||||
|
||||
if len(history) >= 5:
|
||||
recent_alphas = [h[1]['alpha'] for h in list(history)[-5:]]
|
||||
older_alphas = [h[1]['alpha'] for h in list(history)[-10:-5]] if len(history) >= 10 else [alpha]
|
||||
momentum_5m = np.mean(recent_alphas) - np.mean(older_alphas)
|
||||
|
||||
return {
|
||||
'alpha': alpha,
|
||||
'alpha_zscore': alpha_zscore,
|
||||
'amt_zscore': amt_zscore,
|
||||
'rank_zscore': rank_zscore,
|
||||
'momentum_3m': momentum_3m,
|
||||
'momentum_5m': momentum_5m,
|
||||
'limit_up_ratio': limit_up_ratio,
|
||||
'total_amt': total_amt,
|
||||
'rank_pct': rank_pct,
|
||||
}
|
||||
|
||||
def update(self, concept_id: str, timestamp, features: Dict):
|
||||
"""更新概念的历史数据"""
|
||||
self.zscore_history[concept_id].append((timestamp, features))
|
||||
|
||||
def get_sequence(self, concept_id: str) -> Optional[np.ndarray]:
|
||||
"""获取用于 LSTM 的序列"""
|
||||
history = self.zscore_history[concept_id]
|
||||
|
||||
if len(history) < CONFIG['seq_len']:
|
||||
return None
|
||||
|
||||
# 提取特征
|
||||
feature_list = []
|
||||
for _, features in history:
|
||||
feature_list.append([
|
||||
features['alpha_zscore'],
|
||||
features['amt_zscore'],
|
||||
features['rank_zscore'],
|
||||
features['momentum_3m'],
|
||||
features['momentum_5m'],
|
||||
features['limit_up_ratio'],
|
||||
])
|
||||
|
||||
return np.array(feature_list)
|
||||
|
||||
def add_anomaly_candidate(self, concept_id: str, timestamp, score: float):
|
||||
"""添加异动候选"""
|
||||
self.anomaly_candidates[concept_id].append((timestamp, score))
|
||||
|
||||
def check_sustained_anomaly(self, concept_id: str, threshold: float) -> Tuple[bool, float]:
|
||||
"""
|
||||
检查是否为持续性异动
|
||||
|
||||
返回:(是否确认, 确认比例)
|
||||
"""
|
||||
candidates = self.anomaly_candidates[concept_id]
|
||||
|
||||
if len(candidates) < CONFIG['confirm_window']:
|
||||
return False, 0.0
|
||||
|
||||
# 统计超过阈值的时刻数量
|
||||
exceed_count = sum(1 for _, score in candidates if score >= threshold)
|
||||
ratio = exceed_count / len(candidates)
|
||||
|
||||
return ratio >= CONFIG['confirm_ratio'], ratio
|
||||
|
||||
def check_cooldown(self, concept_id: str, timestamp) -> bool:
|
||||
"""检查是否在冷却期"""
|
||||
if concept_id not in self.cooldown:
|
||||
return False
|
||||
|
||||
last_alert = self.cooldown[concept_id]
|
||||
try:
|
||||
diff = (timestamp - last_alert).total_seconds() / 60
|
||||
return diff < CONFIG['cooldown_minutes']
|
||||
except:
|
||||
return False
|
||||
|
||||
def set_cooldown(self, concept_id: str, timestamp):
|
||||
"""设置冷却期"""
|
||||
self.cooldown[concept_id] = timestamp
|
||||
|
||||
|
||||
# ==================== 异动检测器 V2 ====================
|
||||
|
||||
class AnomalyDetectorV2:
|
||||
"""
|
||||
V2 异动检测器
|
||||
|
||||
核心流程:
|
||||
1. 获取实时数据
|
||||
2. 计算 Z-Score 特征
|
||||
3. 规则评分 + ML 评分
|
||||
4. 持续性确认
|
||||
5. 输出异动
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_dir: str = 'ml/checkpoints_v2',
|
||||
baseline_dir: str = 'ml/data_v2/baselines'
|
||||
):
|
||||
# 加载概念
|
||||
self.concepts = self._load_concepts()
|
||||
|
||||
# 加载基线
|
||||
self.baselines = load_baselines(baseline_dir)
|
||||
print(f"加载了 {len(self.baselines)} 个概念的基线")
|
||||
|
||||
# 初始化 ML 评分器
|
||||
self.ml_scorer = MLScorerV2(model_dir)
|
||||
|
||||
# 初始化数据管理器
|
||||
self.data_manager = RealtimeDataManagerV2(self.concepts, self.baselines)
|
||||
|
||||
# 收集所有股票
|
||||
self.all_stocks = list(set(s for c in self.concepts for s in c['stocks']))
|
||||
|
||||
def _load_concepts(self) -> List[dict]:
|
||||
"""从 ES 加载概念"""
|
||||
concepts = []
|
||||
query = {"query": {"match_all": {}}, "size": 100, "_source": ["concept_id", "concept", "stocks"]}
|
||||
|
||||
resp = ES_CLIENT.search(index=ES_INDEX, body=query, scroll='2m')
|
||||
scroll_id = resp['_scroll_id']
|
||||
hits = resp['hits']['hits']
|
||||
|
||||
while len(hits) > 0:
|
||||
for hit in hits:
|
||||
source = hit['_source']
|
||||
stocks = []
|
||||
if 'stocks' in source and isinstance(source['stocks'], list):
|
||||
for stock in source['stocks']:
|
||||
if isinstance(stock, dict) and 'code' in stock and stock['code']:
|
||||
stocks.append(stock['code'])
|
||||
if stocks:
|
||||
concepts.append({
|
||||
'concept_id': source.get('concept_id'),
|
||||
'concept_name': source.get('concept'),
|
||||
'stocks': stocks
|
||||
})
|
||||
|
||||
resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m')
|
||||
scroll_id = resp['_scroll_id']
|
||||
hits = resp['hits']['hits']
|
||||
|
||||
ES_CLIENT.clear_scroll(scroll_id=scroll_id)
|
||||
print(f"加载了 {len(concepts)} 个概念")
|
||||
return concepts
|
||||
|
||||
def detect(self, trade_date: str) -> List[Dict]:
|
||||
"""
|
||||
检测指定日期的异动
|
||||
|
||||
返回异动列表
|
||||
"""
|
||||
print(f"\n检测 {trade_date} 的异动...")
|
||||
|
||||
# 获取原始数据
|
||||
raw_features = self._compute_raw_features(trade_date)
|
||||
if raw_features.empty:
|
||||
print("无数据")
|
||||
return []
|
||||
|
||||
# 按时间排序
|
||||
timestamps = sorted(raw_features['timestamp'].unique())
|
||||
print(f"时间点数: {len(timestamps)}")
|
||||
|
||||
all_alerts = []
|
||||
|
||||
for ts in timestamps:
|
||||
ts_data = raw_features[raw_features['timestamp'] == ts]
|
||||
ts_alerts = self._process_timestamp(ts, ts_data, trade_date)
|
||||
all_alerts.extend(ts_alerts)
|
||||
|
||||
print(f"共检测到 {len(all_alerts)} 个异动")
|
||||
return all_alerts
|
||||
|
||||
def _compute_raw_features(self, trade_date: str) -> pd.DataFrame:
|
||||
"""计算原始特征(同 prepare_data_v2)"""
|
||||
# 这里简化处理,直接调用数据准备逻辑
|
||||
from prepare_data_v2 import compute_raw_concept_features
|
||||
return compute_raw_concept_features(trade_date, self.concepts, self.all_stocks)
|
||||
|
||||
def _process_timestamp(self, timestamp, ts_data: pd.DataFrame, trade_date: str) -> List[Dict]:
|
||||
"""处理单个时间戳"""
|
||||
alerts = []
|
||||
candidates = [] # (concept_id, features, rule_score, triggered_rules)
|
||||
|
||||
for _, row in ts_data.iterrows():
|
||||
concept_id = row['concept_id']
|
||||
|
||||
# 计算 Z-Score 特征
|
||||
features = self.data_manager.compute_zscore_features(
|
||||
concept_id, timestamp,
|
||||
row['alpha'], row['total_amt'], row['rank_pct'], row['limit_up_ratio']
|
||||
)
|
||||
|
||||
if features is None:
|
||||
continue
|
||||
|
||||
# 更新历史
|
||||
self.data_manager.update(concept_id, timestamp, features)
|
||||
|
||||
# 规则评分
|
||||
rule_score, triggered_rules = score_rules_zscore(features)
|
||||
|
||||
# 收集候选
|
||||
candidates.append((concept_id, features, rule_score, triggered_rules))
|
||||
|
||||
if not candidates:
|
||||
return []
|
||||
|
||||
# 批量 ML 评分
|
||||
sequences = []
|
||||
valid_candidates = []
|
||||
|
||||
for concept_id, features, rule_score, triggered_rules in candidates:
|
||||
seq = self.data_manager.get_sequence(concept_id)
|
||||
if seq is not None:
|
||||
sequences.append(seq)
|
||||
valid_candidates.append((concept_id, features, rule_score, triggered_rules))
|
||||
|
||||
if not sequences:
|
||||
return []
|
||||
|
||||
sequences = np.array(sequences)
|
||||
ml_scores = self.ml_scorer.score_batch(sequences)
|
||||
|
||||
# 融合评分 + 持续性确认
|
||||
for i, (concept_id, features, rule_score, triggered_rules) in enumerate(valid_candidates):
|
||||
ml_score = ml_scores[i]
|
||||
final_score = CONFIG['rule_weight'] * rule_score + CONFIG['ml_weight'] * ml_score
|
||||
|
||||
# 判断是否触发
|
||||
is_triggered = (
|
||||
rule_score >= CONFIG['rule_trigger'] or
|
||||
ml_score >= CONFIG['ml_trigger'] or
|
||||
final_score >= CONFIG['fusion_trigger']
|
||||
)
|
||||
|
||||
# 添加到候选队列
|
||||
self.data_manager.add_anomaly_candidate(concept_id, timestamp, final_score)
|
||||
|
||||
if not is_triggered:
|
||||
continue
|
||||
|
||||
# 检查冷却期
|
||||
if self.data_manager.check_cooldown(concept_id, timestamp):
|
||||
continue
|
||||
|
||||
# 持续性确认
|
||||
is_sustained, confirm_ratio = self.data_manager.check_sustained_anomaly(
|
||||
concept_id, CONFIG['fusion_trigger']
|
||||
)
|
||||
|
||||
if not is_sustained:
|
||||
continue
|
||||
|
||||
# 确认为异动!
|
||||
self.data_manager.set_cooldown(concept_id, timestamp)
|
||||
|
||||
# 确定异动类型
|
||||
alpha = features['alpha']
|
||||
if alpha >= 1.5:
|
||||
alert_type = 'surge_up'
|
||||
elif alpha <= -1.5:
|
||||
alert_type = 'surge_down'
|
||||
elif features['amt_zscore'] >= 3.0:
|
||||
alert_type = 'volume_spike'
|
||||
else:
|
||||
alert_type = 'surge'
|
||||
|
||||
# 确定触发原因
|
||||
if rule_score >= CONFIG['rule_trigger']:
|
||||
trigger_reason = f'规则({rule_score:.0f})+持续确认({confirm_ratio:.0%})'
|
||||
elif ml_score >= CONFIG['ml_trigger']:
|
||||
trigger_reason = f'ML({ml_score:.0f})+持续确认({confirm_ratio:.0%})'
|
||||
else:
|
||||
trigger_reason = f'融合({final_score:.0f})+持续确认({confirm_ratio:.0%})'
|
||||
|
||||
alerts.append({
|
||||
'concept_id': concept_id,
|
||||
'concept_name': self.data_manager.concepts.get(concept_id, {}).get('concept_name', concept_id),
|
||||
'alert_time': timestamp,
|
||||
'trade_date': trade_date,
|
||||
'alert_type': alert_type,
|
||||
'final_score': final_score,
|
||||
'rule_score': rule_score,
|
||||
'ml_score': ml_score,
|
||||
'trigger_reason': trigger_reason,
|
||||
'confirm_ratio': confirm_ratio,
|
||||
'alpha': alpha,
|
||||
'alpha_zscore': features['alpha_zscore'],
|
||||
'amt_zscore': features['amt_zscore'],
|
||||
'rank_zscore': features['rank_zscore'],
|
||||
'momentum_3m': features['momentum_3m'],
|
||||
'momentum_5m': features['momentum_5m'],
|
||||
'limit_up_ratio': features['limit_up_ratio'],
|
||||
'triggered_rules': triggered_rules,
|
||||
})
|
||||
|
||||
# 每分钟最多 N 个
|
||||
if len(alerts) > CONFIG['max_alerts_per_minute']:
|
||||
alerts = sorted(alerts, key=lambda x: x['final_score'], reverse=True)
|
||||
alerts = alerts[:CONFIG['max_alerts_per_minute']]
|
||||
|
||||
return alerts
|
||||
|
||||
|
||||
# ==================== 主函数 ====================
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='V2 异动检测器')
|
||||
parser.add_argument('--date', type=str, default=None, help='检测日期(默认今天)')
|
||||
parser.add_argument('--model_dir', type=str, default='ml/checkpoints_v2')
|
||||
parser.add_argument('--baseline_dir', type=str, default='ml/data_v2/baselines')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
trade_date = args.date or datetime.now().strftime('%Y-%m-%d')
|
||||
|
||||
detector = AnomalyDetectorV2(
|
||||
model_dir=args.model_dir,
|
||||
baseline_dir=args.baseline_dir
|
||||
)
|
||||
|
||||
alerts = detector.detect(trade_date)
|
||||
|
||||
print(f"\n检测结果:")
|
||||
for alert in alerts[:20]:
|
||||
print(f" [{alert['alert_time'].strftime('%H:%M') if hasattr(alert['alert_time'], 'strftime') else alert['alert_time']}] "
|
||||
f"{alert['concept_name']} ({alert['alert_type']}) "
|
||||
f"分数={alert['final_score']:.0f} "
|
||||
f"确认率={alert['confirm_ratio']:.0%}")
|
||||
|
||||
if len(alerts) > 20:
|
||||
print(f" ... 共 {len(alerts)} 个异动")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -85,9 +85,12 @@ class LSTMAutoencoder(nn.Module):
|
||||
nn.Tanh(), # 限制范围,增加约束
|
||||
)
|
||||
|
||||
# 使用 LeakyReLU 替代 ReLU
|
||||
# 原因:Z-Score 数据范围是 [-5, +5],ReLU 会截断负值,丢失跌幅信息
|
||||
# LeakyReLU 保留负值信号(乘以 0.1)
|
||||
self.bottleneck_up = nn.Sequential(
|
||||
nn.Linear(latent_dim, hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.LeakyReLU(negative_slope=0.1),
|
||||
)
|
||||
|
||||
# Decoder: 单向 LSTM
|
||||
|
||||
@@ -26,7 +26,9 @@ import hashlib
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Set, Tuple
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from multiprocessing import Manager
|
||||
import multiprocessing
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
@@ -128,7 +130,7 @@ def get_all_concepts() -> List[dict]:
|
||||
hits = resp['hits']['hits']
|
||||
|
||||
ES_CLIENT.clear_scroll(scroll_id=scroll_id)
|
||||
logger.info(f"获取到 {len(concepts)} 个概念")
|
||||
print(f"获取到 {len(concepts)} 个概念")
|
||||
return concepts
|
||||
|
||||
|
||||
@@ -148,7 +150,7 @@ def get_trading_days(start_date: str, end_date: str) -> List[str]:
|
||||
|
||||
result = client.execute(query)
|
||||
days = [row[0].strftime('%Y-%m-%d') for row in result]
|
||||
logger.info(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}")
|
||||
print(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}")
|
||||
return days
|
||||
|
||||
|
||||
@@ -223,21 +225,23 @@ def get_daily_index_data(trade_date: str, index_code: str = REFERENCE_INDEX) ->
|
||||
|
||||
|
||||
def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
|
||||
"""获取昨收价"""
|
||||
"""获取昨收价(上一交易日的收盘价 F007N)"""
|
||||
valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
|
||||
if not valid_codes:
|
||||
return {}
|
||||
|
||||
codes_str = "','".join(valid_codes)
|
||||
|
||||
# 注意:F007N 是"最近成交价"即当日收盘价,F002N 是"昨日收盘价"
|
||||
# 我们需要查上一交易日的 F007N(那天的收盘价)作为今天的昨收
|
||||
query = f"""
|
||||
SELECT SECCODE, F002N
|
||||
SELECT SECCODE, F007N
|
||||
FROM ea_trade
|
||||
WHERE SECCODE IN ('{codes_str}')
|
||||
AND TRADEDATE = (
|
||||
SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}'
|
||||
)
|
||||
AND F002N IS NOT NULL AND F002N > 0
|
||||
AND F007N IS NOT NULL AND F007N > 0
|
||||
"""
|
||||
|
||||
try:
|
||||
@@ -245,7 +249,7 @@ def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
|
||||
result = conn.execute(text(query))
|
||||
return {row[0]: float(row[1]) for row in result if row[1]}
|
||||
except Exception as e:
|
||||
logger.error(f"获取昨收价失败: {e}")
|
||||
print(f"获取昨收价失败: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
@@ -264,7 +268,7 @@ def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) ->
|
||||
if result and result[0]:
|
||||
return float(result[0])
|
||||
except Exception as e:
|
||||
logger.error(f"获取指数昨收失败: {e}")
|
||||
print(f"获取指数昨收失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
@@ -285,25 +289,19 @@ def compute_daily_features(
|
||||
"""
|
||||
|
||||
# 1. 获取数据
|
||||
logger.info(f" 获取股票数据...")
|
||||
stock_df = get_daily_stock_data(trade_date, all_stocks)
|
||||
if stock_df.empty:
|
||||
logger.warning(f" 无股票数据")
|
||||
return pd.DataFrame()
|
||||
|
||||
logger.info(f" 获取指数数据...")
|
||||
index_df = get_daily_index_data(trade_date)
|
||||
if index_df.empty:
|
||||
logger.warning(f" 无指数数据")
|
||||
return pd.DataFrame()
|
||||
|
||||
# 2. 获取昨收价
|
||||
logger.info(f" 获取昨收价...")
|
||||
prev_close = get_prev_close(all_stocks, trade_date)
|
||||
index_prev_close = get_index_prev_close(trade_date)
|
||||
|
||||
if not prev_close or not index_prev_close:
|
||||
logger.warning(f" 无昨收价数据")
|
||||
return pd.DataFrame()
|
||||
|
||||
# 3. 计算股票涨跌幅和成交额
|
||||
@@ -317,7 +315,6 @@ def compute_daily_features(
|
||||
|
||||
# 5. 获取所有时间点
|
||||
timestamps = sorted(stock_df['timestamp'].unique())
|
||||
logger.info(f" 时间点数: {len(timestamps)}")
|
||||
|
||||
# 6. 按时间点计算概念特征
|
||||
results = []
|
||||
@@ -414,87 +411,126 @@ def compute_daily_features(
|
||||
if amt_delta_std > 0:
|
||||
final_df['amt_delta'] = final_df['amt_delta'] / amt_delta_std
|
||||
|
||||
logger.info(f" 计算完成: {len(final_df)} 条记录")
|
||||
return final_df
|
||||
|
||||
|
||||
# ==================== 主流程 ====================
|
||||
|
||||
def process_single_day(trade_date: str, concepts: List[dict], all_stocks: List[str]) -> str:
|
||||
"""处理单个交易日"""
|
||||
def process_single_day(args) -> Tuple[str, bool]:
|
||||
"""
|
||||
处理单个交易日(多进程版本)
|
||||
|
||||
Args:
|
||||
args: (trade_date, concepts, all_stocks) 元组
|
||||
|
||||
Returns:
|
||||
(trade_date, success) 元组
|
||||
"""
|
||||
trade_date, concepts, all_stocks = args
|
||||
output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet')
|
||||
|
||||
# 检查是否已处理
|
||||
if os.path.exists(output_file):
|
||||
logger.info(f"[{trade_date}] 已存在,跳过")
|
||||
return output_file
|
||||
print(f"[{trade_date}] 已存在,跳过")
|
||||
return (trade_date, True)
|
||||
|
||||
logger.info(f"[{trade_date}] 开始处理...")
|
||||
print(f"[{trade_date}] 开始处理...")
|
||||
|
||||
try:
|
||||
df = compute_daily_features(trade_date, concepts, all_stocks)
|
||||
|
||||
if df.empty:
|
||||
logger.warning(f"[{trade_date}] 无数据")
|
||||
return None
|
||||
print(f"[{trade_date}] 无数据")
|
||||
return (trade_date, False)
|
||||
|
||||
# 保存
|
||||
df.to_parquet(output_file, index=False)
|
||||
logger.info(f"[{trade_date}] 保存完成: {output_file}")
|
||||
return output_file
|
||||
print(f"[{trade_date}] 保存完成")
|
||||
return (trade_date, True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{trade_date}] 处理失败: {e}")
|
||||
print(f"[{trade_date}] 处理失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return None
|
||||
return (trade_date, False)
|
||||
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
|
||||
parser = argparse.ArgumentParser(description='准备训练数据')
|
||||
parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期')
|
||||
parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)')
|
||||
parser.add_argument('--workers', type=int, default=1, help='并行数(建议1,避免数据库压力)')
|
||||
parser.add_argument('--workers', type=int, default=18, help='并行进程数(默认18)')
|
||||
parser.add_argument('--force', action='store_true', help='强制重新处理已存在的文件')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
end_date = args.end or datetime.now().strftime('%Y-%m-%d')
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("数据准备 - Transformer Autoencoder 训练数据")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"日期范围: {args.start} ~ {end_date}")
|
||||
print("=" * 60)
|
||||
print("数据准备 - Transformer Autoencoder 训练数据")
|
||||
print("=" * 60)
|
||||
print(f"日期范围: {args.start} ~ {end_date}")
|
||||
print(f"并行进程数: {args.workers}")
|
||||
|
||||
# 1. 获取概念列表
|
||||
concepts = get_all_concepts()
|
||||
|
||||
# 收集所有股票
|
||||
all_stocks = list(set(s for c in concepts for s in c['stocks']))
|
||||
logger.info(f"股票总数: {len(all_stocks)}")
|
||||
print(f"股票总数: {len(all_stocks)}")
|
||||
|
||||
# 2. 获取交易日列表
|
||||
trading_days = get_trading_days(args.start, end_date)
|
||||
|
||||
if not trading_days:
|
||||
logger.error("无交易日数据")
|
||||
print("无交易日数据")
|
||||
return
|
||||
|
||||
# 3. 处理每个交易日
|
||||
logger.info(f"\n开始处理 {len(trading_days)} 个交易日...")
|
||||
# 如果强制模式,删除已有文件
|
||||
if args.force:
|
||||
for trade_date in trading_days:
|
||||
output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet')
|
||||
if os.path.exists(output_file):
|
||||
os.remove(output_file)
|
||||
print(f"删除已有文件: {output_file}")
|
||||
|
||||
# 3. 准备任务参数
|
||||
tasks = [(trade_date, concepts, all_stocks) for trade_date in trading_days]
|
||||
|
||||
print(f"\n开始处理 {len(trading_days)} 个交易日({args.workers} 进程并行)...")
|
||||
|
||||
# 4. 多进程处理
|
||||
success_count = 0
|
||||
for i, trade_date in enumerate(trading_days):
|
||||
logger.info(f"\n[{i+1}/{len(trading_days)}] {trade_date}")
|
||||
result = process_single_day(trade_date, concepts, all_stocks)
|
||||
if result:
|
||||
success_count += 1
|
||||
failed_dates = []
|
||||
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info(f"处理完成: {success_count}/{len(trading_days)} 个交易日")
|
||||
logger.info(f"数据保存在: {OUTPUT_DIR}")
|
||||
logger.info("=" * 60)
|
||||
with ProcessPoolExecutor(max_workers=args.workers) as executor:
|
||||
# 提交所有任务
|
||||
futures = {executor.submit(process_single_day, task): task[0] for task in tasks}
|
||||
|
||||
# 使用 tqdm 显示进度
|
||||
with tqdm(total=len(futures), desc="处理进度", unit="天") as pbar:
|
||||
for future in as_completed(futures):
|
||||
trade_date = futures[future]
|
||||
try:
|
||||
result_date, success = future.result()
|
||||
if success:
|
||||
success_count += 1
|
||||
else:
|
||||
failed_dates.append(result_date)
|
||||
except Exception as e:
|
||||
print(f"\n[{trade_date}] 进程异常: {e}")
|
||||
failed_dates.append(trade_date)
|
||||
pbar.update(1)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(f"处理完成: {success_count}/{len(trading_days)} 个交易日")
|
||||
if failed_dates:
|
||||
print(f"失败日期: {failed_dates[:10]}{'...' if len(failed_dates) > 10 else ''}")
|
||||
print(f"数据保存在: {OUTPUT_DIR}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
715
ml/prepare_data_v2.py
Normal file
715
ml/prepare_data_v2.py
Normal file
@@ -0,0 +1,715 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据准备 V2 - 基于时间片对齐的特征计算(修复版)
|
||||
|
||||
核心改进:
|
||||
1. 时间片对齐:9:35 和历史的 9:35 比,而不是和前30分钟比
|
||||
2. Z-Score 特征:相对于同时间片历史分布的偏离程度
|
||||
3. 滚动窗口基线:每个日期使用它之前 N 天的数据作为基线(不是固定的最后 N 天!)
|
||||
4. 基于 Z-Score 的动量:消除一天内波动率异构性
|
||||
|
||||
修复:
|
||||
- 滚动窗口基线:避免未来数据泄露
|
||||
- Z-Score 动量:消除早盘/尾盘波动率差异
|
||||
- 进程级数据库单例:避免连接池爆炸
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy import create_engine, text
|
||||
from elasticsearch import Elasticsearch
|
||||
from clickhouse_driver import Client
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from tqdm import tqdm
|
||||
from collections import defaultdict
|
||||
import warnings
|
||||
import pickle
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
# ==================== 配置 ====================
|
||||
|
||||
MYSQL_URL = "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock"
|
||||
ES_HOST = 'http://127.0.0.1:9200'
|
||||
ES_INDEX = 'concept_library_v3'
|
||||
|
||||
CLICKHOUSE_CONFIG = {
|
||||
'host': '127.0.0.1',
|
||||
'port': 9000,
|
||||
'user': 'default',
|
||||
'password': 'Zzl33818!',
|
||||
'database': 'stock'
|
||||
}
|
||||
|
||||
REFERENCE_INDEX = '000001.SH'
|
||||
|
||||
# 输出目录
|
||||
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'data_v2')
|
||||
BASELINE_DIR = os.path.join(OUTPUT_DIR, 'baselines')
|
||||
RAW_CACHE_DIR = os.path.join(OUTPUT_DIR, 'raw_cache')
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
os.makedirs(BASELINE_DIR, exist_ok=True)
|
||||
os.makedirs(RAW_CACHE_DIR, exist_ok=True)
|
||||
|
||||
# 特征配置
|
||||
CONFIG = {
|
||||
'baseline_days': 20, # 滚动窗口大小
|
||||
'min_baseline_samples': 10, # 最少需要10个样本才算有效基线
|
||||
'limit_up_threshold': 9.8,
|
||||
'limit_down_threshold': -9.8,
|
||||
'zscore_clip': 5.0,
|
||||
}
|
||||
|
||||
# 特征列表
|
||||
FEATURES_V2 = [
|
||||
'alpha', 'alpha_zscore', 'amt_zscore', 'rank_zscore',
|
||||
'momentum_3m', 'momentum_5m', 'limit_up_ratio',
|
||||
]
|
||||
|
||||
# ==================== 进程级单例(避免连接池爆炸)====================
|
||||
|
||||
# 进程级全局变量
|
||||
_process_mysql_engine = None
|
||||
_process_es_client = None
|
||||
_process_ch_client = None
|
||||
|
||||
|
||||
def init_process_connections():
|
||||
"""进程初始化时调用,创建连接(单例)"""
|
||||
global _process_mysql_engine, _process_es_client, _process_ch_client
|
||||
_process_mysql_engine = create_engine(MYSQL_URL, echo=False, pool_pre_ping=True, pool_size=5)
|
||||
_process_es_client = Elasticsearch([ES_HOST])
|
||||
_process_ch_client = Client(**CLICKHOUSE_CONFIG)
|
||||
|
||||
|
||||
def get_mysql_engine():
|
||||
"""获取进程级 MySQL Engine(单例)"""
|
||||
global _process_mysql_engine
|
||||
if _process_mysql_engine is None:
|
||||
_process_mysql_engine = create_engine(MYSQL_URL, echo=False, pool_pre_ping=True, pool_size=5)
|
||||
return _process_mysql_engine
|
||||
|
||||
|
||||
def get_es_client():
|
||||
"""获取进程级 ES 客户端(单例)"""
|
||||
global _process_es_client
|
||||
if _process_es_client is None:
|
||||
_process_es_client = Elasticsearch([ES_HOST])
|
||||
return _process_es_client
|
||||
|
||||
|
||||
def get_ch_client():
|
||||
"""获取进程级 ClickHouse 客户端(单例)"""
|
||||
global _process_ch_client
|
||||
if _process_ch_client is None:
|
||||
_process_ch_client = Client(**CLICKHOUSE_CONFIG)
|
||||
return _process_ch_client
|
||||
|
||||
|
||||
# ==================== 工具函数 ====================
|
||||
|
||||
def code_to_ch_format(code: str) -> str:
|
||||
if not code or len(code) != 6 or not code.isdigit():
|
||||
return None
|
||||
if code.startswith('6'):
|
||||
return f"{code}.SH"
|
||||
elif code.startswith('0') or code.startswith('3'):
|
||||
return f"{code}.SZ"
|
||||
else:
|
||||
return f"{code}.BJ"
|
||||
|
||||
|
||||
def time_to_slot(ts) -> str:
|
||||
"""将时间戳转换为时间片(HH:MM格式)"""
|
||||
if isinstance(ts, str):
|
||||
return ts
|
||||
return ts.strftime('%H:%M')
|
||||
|
||||
|
||||
# ==================== 获取概念列表 ====================
|
||||
|
||||
def get_all_concepts() -> List[dict]:
|
||||
"""从ES获取所有叶子概念"""
|
||||
es_client = get_es_client()
|
||||
concepts = []
|
||||
|
||||
query = {
|
||||
"query": {"match_all": {}},
|
||||
"size": 100,
|
||||
"_source": ["concept_id", "concept", "stocks"]
|
||||
}
|
||||
|
||||
resp = es_client.search(index=ES_INDEX, body=query, scroll='2m')
|
||||
scroll_id = resp['_scroll_id']
|
||||
hits = resp['hits']['hits']
|
||||
|
||||
while len(hits) > 0:
|
||||
for hit in hits:
|
||||
source = hit['_source']
|
||||
stocks = []
|
||||
if 'stocks' in source and isinstance(source['stocks'], list):
|
||||
for stock in source['stocks']:
|
||||
if isinstance(stock, dict) and 'code' in stock and stock['code']:
|
||||
stocks.append(stock['code'])
|
||||
|
||||
if stocks:
|
||||
concepts.append({
|
||||
'concept_id': source.get('concept_id'),
|
||||
'concept_name': source.get('concept'),
|
||||
'stocks': stocks
|
||||
})
|
||||
|
||||
resp = es_client.scroll(scroll_id=scroll_id, scroll='2m')
|
||||
scroll_id = resp['_scroll_id']
|
||||
hits = resp['hits']['hits']
|
||||
|
||||
es_client.clear_scroll(scroll_id=scroll_id)
|
||||
print(f"获取到 {len(concepts)} 个概念")
|
||||
return concepts
|
||||
|
||||
|
||||
# ==================== 获取交易日列表 ====================
|
||||
|
||||
def get_trading_days(start_date: str, end_date: str) -> List[str]:
|
||||
"""获取交易日列表"""
|
||||
client = get_ch_client()
|
||||
|
||||
query = f"""
|
||||
SELECT DISTINCT toDate(timestamp) as trade_date
|
||||
FROM stock_minute
|
||||
WHERE toDate(timestamp) >= '{start_date}'
|
||||
AND toDate(timestamp) <= '{end_date}'
|
||||
ORDER BY trade_date
|
||||
"""
|
||||
|
||||
result = client.execute(query)
|
||||
days = [row[0].strftime('%Y-%m-%d') for row in result]
|
||||
if days:
|
||||
print(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}")
|
||||
return days
|
||||
|
||||
|
||||
# ==================== 获取昨收价 ====================
|
||||
|
||||
def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
|
||||
"""获取昨收价(上一交易日的收盘价 F007N)"""
|
||||
valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
|
||||
if not valid_codes:
|
||||
return {}
|
||||
|
||||
codes_str = "','".join(valid_codes)
|
||||
query = f"""
|
||||
SELECT SECCODE, F007N
|
||||
FROM ea_trade
|
||||
WHERE SECCODE IN ('{codes_str}')
|
||||
AND TRADEDATE = (
|
||||
SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}'
|
||||
)
|
||||
AND F007N IS NOT NULL AND F007N > 0
|
||||
"""
|
||||
|
||||
try:
|
||||
engine = get_mysql_engine()
|
||||
with engine.connect() as conn:
|
||||
result = conn.execute(text(query))
|
||||
return {row[0]: float(row[1]) for row in result if row[1]}
|
||||
except Exception as e:
|
||||
print(f"获取昨收价失败: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) -> float:
|
||||
"""获取指数昨收价"""
|
||||
code_no_suffix = index_code.split('.')[0]
|
||||
|
||||
try:
|
||||
engine = get_mysql_engine()
|
||||
with engine.connect() as conn:
|
||||
result = conn.execute(text("""
|
||||
SELECT F006N FROM ea_exchangetrade
|
||||
WHERE INDEXCODE = :code AND TRADEDATE < :today
|
||||
ORDER BY TRADEDATE DESC LIMIT 1
|
||||
"""), {'code': code_no_suffix, 'today': trade_date}).fetchone()
|
||||
|
||||
if result and result[0]:
|
||||
return float(result[0])
|
||||
except Exception as e:
|
||||
print(f"获取指数昨收失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# ==================== 获取分钟数据 ====================
|
||||
|
||||
def get_daily_stock_data(trade_date: str, stock_codes: List[str]) -> pd.DataFrame:
|
||||
"""获取单日所有股票的分钟数据"""
|
||||
client = get_ch_client()
|
||||
|
||||
ch_codes = []
|
||||
code_map = {}
|
||||
for code in stock_codes:
|
||||
ch_code = code_to_ch_format(code)
|
||||
if ch_code:
|
||||
ch_codes.append(ch_code)
|
||||
code_map[ch_code] = code
|
||||
|
||||
if not ch_codes:
|
||||
return pd.DataFrame()
|
||||
|
||||
ch_codes_str = "','".join(ch_codes)
|
||||
|
||||
query = f"""
|
||||
SELECT code, timestamp, close, volume, amt
|
||||
FROM stock_minute
|
||||
WHERE toDate(timestamp) = '{trade_date}'
|
||||
AND code IN ('{ch_codes_str}')
|
||||
ORDER BY code, timestamp
|
||||
"""
|
||||
|
||||
result = client.execute(query)
|
||||
if not result:
|
||||
return pd.DataFrame()
|
||||
|
||||
df = pd.DataFrame(result, columns=['ch_code', 'timestamp', 'close', 'volume', 'amt'])
|
||||
df['code'] = df['ch_code'].map(code_map)
|
||||
df = df.dropna(subset=['code'])
|
||||
|
||||
return df[['code', 'timestamp', 'close', 'volume', 'amt']]
|
||||
|
||||
|
||||
def get_daily_index_data(trade_date: str, index_code: str = REFERENCE_INDEX) -> pd.DataFrame:
|
||||
"""获取单日指数分钟数据"""
|
||||
client = get_ch_client()
|
||||
|
||||
query = f"""
|
||||
SELECT timestamp, close, volume, amt
|
||||
FROM index_minute
|
||||
WHERE toDate(timestamp) = '{trade_date}'
|
||||
AND code = '{index_code}'
|
||||
ORDER BY timestamp
|
||||
"""
|
||||
|
||||
result = client.execute(query)
|
||||
if not result:
|
||||
return pd.DataFrame()
|
||||
|
||||
df = pd.DataFrame(result, columns=['timestamp', 'close', 'volume', 'amt'])
|
||||
return df
|
||||
|
||||
|
||||
# ==================== 计算原始概念特征(单日)====================
|
||||
|
||||
def compute_raw_concept_features(
|
||||
trade_date: str,
|
||||
concepts: List[dict],
|
||||
all_stocks: List[str]
|
||||
) -> pd.DataFrame:
|
||||
"""计算单日概念的原始特征(alpha, amt, rank_pct, limit_up_ratio)"""
|
||||
# 检查缓存
|
||||
cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
|
||||
if os.path.exists(cache_file):
|
||||
return pd.read_parquet(cache_file)
|
||||
|
||||
# 获取数据
|
||||
stock_df = get_daily_stock_data(trade_date, all_stocks)
|
||||
if stock_df.empty:
|
||||
return pd.DataFrame()
|
||||
|
||||
index_df = get_daily_index_data(trade_date)
|
||||
if index_df.empty:
|
||||
return pd.DataFrame()
|
||||
|
||||
# 获取昨收价
|
||||
prev_close = get_prev_close(all_stocks, trade_date)
|
||||
index_prev_close = get_index_prev_close(trade_date)
|
||||
|
||||
if not prev_close or not index_prev_close:
|
||||
return pd.DataFrame()
|
||||
|
||||
# 计算涨跌幅
|
||||
stock_df['prev_close'] = stock_df['code'].map(prev_close)
|
||||
stock_df = stock_df.dropna(subset=['prev_close'])
|
||||
stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100
|
||||
|
||||
index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100
|
||||
index_change_map = dict(zip(index_df['timestamp'], index_df['change_pct']))
|
||||
|
||||
# 获取所有时间点
|
||||
timestamps = sorted(stock_df['timestamp'].unique())
|
||||
|
||||
# 概念到股票的映射
|
||||
concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts}
|
||||
|
||||
results = []
|
||||
|
||||
for ts in timestamps:
|
||||
ts_stock_data = stock_df[stock_df['timestamp'] == ts]
|
||||
index_change = index_change_map.get(ts, 0)
|
||||
|
||||
stock_change = dict(zip(ts_stock_data['code'], ts_stock_data['change_pct']))
|
||||
stock_amt = dict(zip(ts_stock_data['code'], ts_stock_data['amt']))
|
||||
|
||||
concept_features = []
|
||||
|
||||
for concept_id, stocks in concept_stocks.items():
|
||||
concept_changes = [stock_change[s] for s in stocks if s in stock_change]
|
||||
concept_amts = [stock_amt.get(s, 0) for s in stocks if s in stock_change]
|
||||
|
||||
if not concept_changes:
|
||||
continue
|
||||
|
||||
avg_change = np.mean(concept_changes)
|
||||
total_amt = sum(concept_amts)
|
||||
alpha = avg_change - index_change
|
||||
|
||||
limit_up_count = sum(1 for c in concept_changes if c >= CONFIG['limit_up_threshold'])
|
||||
limit_up_ratio = limit_up_count / len(concept_changes)
|
||||
|
||||
concept_features.append({
|
||||
'concept_id': concept_id,
|
||||
'alpha': alpha,
|
||||
'total_amt': total_amt,
|
||||
'limit_up_ratio': limit_up_ratio,
|
||||
'stock_count': len(concept_changes),
|
||||
})
|
||||
|
||||
if not concept_features:
|
||||
continue
|
||||
|
||||
concept_df = pd.DataFrame(concept_features)
|
||||
concept_df['rank_pct'] = concept_df['alpha'].rank(pct=True)
|
||||
concept_df['timestamp'] = ts
|
||||
concept_df['time_slot'] = time_to_slot(ts)
|
||||
concept_df['trade_date'] = trade_date
|
||||
|
||||
results.append(concept_df)
|
||||
|
||||
if not results:
|
||||
return pd.DataFrame()
|
||||
|
||||
result_df = pd.concat(results, ignore_index=True)
|
||||
|
||||
# 保存缓存
|
||||
result_df.to_parquet(cache_file, index=False)
|
||||
|
||||
return result_df
|
||||
|
||||
|
||||
# ==================== 滚动窗口基线计算 ====================
|
||||
|
||||
def compute_rolling_baseline(
|
||||
historical_data: pd.DataFrame,
|
||||
concept_id: str
|
||||
) -> Dict[str, Dict]:
|
||||
"""
|
||||
计算单个概念的滚动基线
|
||||
|
||||
返回: {time_slot: {alpha_mean, alpha_std, amt_mean, amt_std, rank_mean, rank_std, sample_count}}
|
||||
"""
|
||||
if historical_data.empty:
|
||||
return {}
|
||||
|
||||
concept_data = historical_data[historical_data['concept_id'] == concept_id]
|
||||
if concept_data.empty:
|
||||
return {}
|
||||
|
||||
baseline_dict = {}
|
||||
|
||||
for time_slot, group in concept_data.groupby('time_slot'):
|
||||
if len(group) < CONFIG['min_baseline_samples']:
|
||||
continue
|
||||
|
||||
alpha_std = group['alpha'].std()
|
||||
amt_std = group['total_amt'].std()
|
||||
rank_std = group['rank_pct'].std()
|
||||
|
||||
baseline_dict[time_slot] = {
|
||||
'alpha_mean': group['alpha'].mean(),
|
||||
'alpha_std': max(alpha_std if pd.notna(alpha_std) else 1.0, 0.1),
|
||||
'amt_mean': group['total_amt'].mean(),
|
||||
'amt_std': max(amt_std if pd.notna(amt_std) else group['total_amt'].mean() * 0.5, 1.0),
|
||||
'rank_mean': group['rank_pct'].mean(),
|
||||
'rank_std': max(rank_std if pd.notna(rank_std) else 0.2, 0.05),
|
||||
'sample_count': len(group),
|
||||
}
|
||||
|
||||
return baseline_dict
|
||||
|
||||
|
||||
# ==================== 计算单日 Z-Score 特征(带滚动基线)====================
|
||||
|
||||
def compute_zscore_features_rolling(
|
||||
trade_date: str,
|
||||
concepts: List[dict],
|
||||
all_stocks: List[str],
|
||||
historical_raw_data: pd.DataFrame # 该日期之前 N 天的原始数据
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
计算单日的 Z-Score 特征(使用滚动窗口基线)
|
||||
|
||||
关键改进:
|
||||
1. 基线只使用 trade_date 之前的数据(无未来泄露)
|
||||
2. 动量基于 Z-Score 计算(消除波动率异构性)
|
||||
"""
|
||||
# 计算当日原始特征
|
||||
raw_df = compute_raw_concept_features(trade_date, concepts, all_stocks)
|
||||
|
||||
if raw_df.empty:
|
||||
return pd.DataFrame()
|
||||
|
||||
zscore_records = []
|
||||
|
||||
for concept_id, group in raw_df.groupby('concept_id'):
|
||||
# 计算该概念的滚动基线(只用历史数据)
|
||||
baseline_dict = compute_rolling_baseline(historical_raw_data, concept_id)
|
||||
|
||||
if not baseline_dict:
|
||||
continue
|
||||
|
||||
# 按时间排序
|
||||
group = group.sort_values('timestamp').reset_index(drop=True)
|
||||
|
||||
# Z-Score 历史(用于计算基于 Z-Score 的动量)
|
||||
zscore_history = []
|
||||
|
||||
for idx, row in group.iterrows():
|
||||
time_slot = row['time_slot']
|
||||
|
||||
if time_slot not in baseline_dict:
|
||||
continue
|
||||
|
||||
bl = baseline_dict[time_slot]
|
||||
|
||||
# 计算 Z-Score
|
||||
alpha_zscore = (row['alpha'] - bl['alpha_mean']) / bl['alpha_std']
|
||||
amt_zscore = (row['total_amt'] - bl['amt_mean']) / bl['amt_std']
|
||||
rank_zscore = (row['rank_pct'] - bl['rank_mean']) / bl['rank_std']
|
||||
|
||||
# 截断极端值
|
||||
clip = CONFIG['zscore_clip']
|
||||
alpha_zscore = np.clip(alpha_zscore, -clip, clip)
|
||||
amt_zscore = np.clip(amt_zscore, -clip, clip)
|
||||
rank_zscore = np.clip(rank_zscore, -clip, clip)
|
||||
|
||||
# 记录 Z-Score 历史
|
||||
zscore_history.append(alpha_zscore)
|
||||
|
||||
# 基于 Z-Score 计算动量(消除波动率异构性)
|
||||
momentum_3m = 0.0
|
||||
momentum_5m = 0.0
|
||||
|
||||
if len(zscore_history) >= 3:
|
||||
recent_3 = zscore_history[-3:]
|
||||
older_3 = zscore_history[-6:-3] if len(zscore_history) >= 6 else [zscore_history[0]]
|
||||
momentum_3m = np.mean(recent_3) - np.mean(older_3)
|
||||
|
||||
if len(zscore_history) >= 5:
|
||||
recent_5 = zscore_history[-5:]
|
||||
older_5 = zscore_history[-10:-5] if len(zscore_history) >= 10 else [zscore_history[0]]
|
||||
momentum_5m = np.mean(recent_5) - np.mean(older_5)
|
||||
|
||||
zscore_records.append({
|
||||
'concept_id': concept_id,
|
||||
'timestamp': row['timestamp'],
|
||||
'time_slot': time_slot,
|
||||
'trade_date': trade_date,
|
||||
# 原始特征
|
||||
'alpha': row['alpha'],
|
||||
'total_amt': row['total_amt'],
|
||||
'limit_up_ratio': row['limit_up_ratio'],
|
||||
'stock_count': row['stock_count'],
|
||||
'rank_pct': row['rank_pct'],
|
||||
# Z-Score 特征
|
||||
'alpha_zscore': alpha_zscore,
|
||||
'amt_zscore': amt_zscore,
|
||||
'rank_zscore': rank_zscore,
|
||||
# 基于 Z-Score 的动量
|
||||
'momentum_3m': momentum_3m,
|
||||
'momentum_5m': momentum_5m,
|
||||
})
|
||||
|
||||
if not zscore_records:
|
||||
return pd.DataFrame()
|
||||
|
||||
return pd.DataFrame(zscore_records)
|
||||
|
||||
|
||||
# ==================== 多进程处理 ====================
|
||||
|
||||
def process_single_day_v2(args) -> Tuple[str, bool]:
|
||||
"""处理单个交易日(多进程版本)"""
|
||||
trade_date, day_index, concepts, all_stocks, all_trading_days = args
|
||||
output_file = os.path.join(OUTPUT_DIR, f'features_v2_{trade_date}.parquet')
|
||||
|
||||
if os.path.exists(output_file):
|
||||
return (trade_date, True)
|
||||
|
||||
try:
|
||||
# 计算滚动窗口范围(该日期之前的 N 天)
|
||||
baseline_days = CONFIG['baseline_days']
|
||||
|
||||
# 找出 trade_date 之前的交易日
|
||||
start_idx = max(0, day_index - baseline_days)
|
||||
end_idx = day_index # 不包含当天
|
||||
|
||||
if end_idx <= start_idx:
|
||||
# 没有足够的历史数据
|
||||
return (trade_date, False)
|
||||
|
||||
historical_days = all_trading_days[start_idx:end_idx]
|
||||
|
||||
# 加载历史原始数据
|
||||
historical_dfs = []
|
||||
for hist_date in historical_days:
|
||||
cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{hist_date}.parquet')
|
||||
if os.path.exists(cache_file):
|
||||
historical_dfs.append(pd.read_parquet(cache_file))
|
||||
else:
|
||||
# 需要计算
|
||||
hist_df = compute_raw_concept_features(hist_date, concepts, all_stocks)
|
||||
if not hist_df.empty:
|
||||
historical_dfs.append(hist_df)
|
||||
|
||||
if not historical_dfs:
|
||||
return (trade_date, False)
|
||||
|
||||
historical_raw_data = pd.concat(historical_dfs, ignore_index=True)
|
||||
|
||||
# 计算当日 Z-Score 特征(使用滚动基线)
|
||||
df = compute_zscore_features_rolling(trade_date, concepts, all_stocks, historical_raw_data)
|
||||
|
||||
if df.empty:
|
||||
return (trade_date, False)
|
||||
|
||||
df.to_parquet(output_file, index=False)
|
||||
return (trade_date, True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{trade_date}] 处理失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return (trade_date, False)
|
||||
|
||||
|
||||
# ==================== 主流程 ====================
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='准备训练数据 V2(滚动窗口基线 + Z-Score 动量)')
|
||||
parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期')
|
||||
parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)')
|
||||
parser.add_argument('--workers', type=int, default=18, help='并行进程数')
|
||||
parser.add_argument('--baseline-days', type=int, default=20, help='滚动基线窗口大小')
|
||||
parser.add_argument('--force', action='store_true', help='强制重新计算(忽略缓存)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
end_date = args.end or datetime.now().strftime('%Y-%m-%d')
|
||||
CONFIG['baseline_days'] = args.baseline_days
|
||||
|
||||
print("=" * 60)
|
||||
print("数据准备 V2 - 滚动窗口基线 + Z-Score 动量")
|
||||
print("=" * 60)
|
||||
print(f"日期范围: {args.start} ~ {end_date}")
|
||||
print(f"并行进程数: {args.workers}")
|
||||
print(f"滚动基线窗口: {args.baseline_days} 天")
|
||||
|
||||
# 初始化主进程连接
|
||||
init_process_connections()
|
||||
|
||||
# 1. 获取概念列表
|
||||
concepts = get_all_concepts()
|
||||
all_stocks = list(set(s for c in concepts for s in c['stocks']))
|
||||
print(f"股票总数: {len(all_stocks)}")
|
||||
|
||||
# 2. 获取交易日列表
|
||||
trading_days = get_trading_days(args.start, end_date)
|
||||
|
||||
if not trading_days:
|
||||
print("无交易日数据")
|
||||
return
|
||||
|
||||
# 3. 第一阶段:预计算所有原始特征(用于缓存)
|
||||
print(f"\n{'='*60}")
|
||||
print("第一阶段:预计算原始特征(用于滚动基线)")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# 如果强制重新计算,删除缓存
|
||||
if args.force:
|
||||
import shutil
|
||||
if os.path.exists(RAW_CACHE_DIR):
|
||||
shutil.rmtree(RAW_CACHE_DIR)
|
||||
os.makedirs(RAW_CACHE_DIR, exist_ok=True)
|
||||
if os.path.exists(OUTPUT_DIR):
|
||||
for f in os.listdir(OUTPUT_DIR):
|
||||
if f.startswith('features_v2_'):
|
||||
os.remove(os.path.join(OUTPUT_DIR, f))
|
||||
|
||||
# 单线程预计算原始特征(因为需要顺序缓存)
|
||||
print(f"预计算 {len(trading_days)} 天的原始特征...")
|
||||
for trade_date in tqdm(trading_days, desc="预计算原始特征"):
|
||||
cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
|
||||
if not os.path.exists(cache_file):
|
||||
compute_raw_concept_features(trade_date, concepts, all_stocks)
|
||||
|
||||
# 4. 第二阶段:计算 Z-Score 特征(多进程)
|
||||
print(f"\n{'='*60}")
|
||||
print("第二阶段:计算 Z-Score 特征(滚动基线)")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# 从第 baseline_days 天开始(前面的没有足够历史)
|
||||
start_idx = args.baseline_days
|
||||
processable_days = trading_days[start_idx:]
|
||||
|
||||
if not processable_days:
|
||||
print(f"错误:需要至少 {args.baseline_days + 1} 天的数据")
|
||||
return
|
||||
|
||||
print(f"可处理日期: {processable_days[0]} ~ {processable_days[-1]} ({len(processable_days)} 天)")
|
||||
print(f"跳过前 {start_idx} 天(基线预热期)")
|
||||
|
||||
# 构建任务
|
||||
tasks = []
|
||||
for i, trade_date in enumerate(trading_days):
|
||||
if i >= start_idx:
|
||||
tasks.append((trade_date, i, concepts, all_stocks, trading_days))
|
||||
|
||||
print(f"开始处理 {len(tasks)} 个交易日({args.workers} 进程并行)...")
|
||||
|
||||
success_count = 0
|
||||
failed_dates = []
|
||||
|
||||
# 使用进程池初始化器
|
||||
with ProcessPoolExecutor(max_workers=args.workers, initializer=init_process_connections) as executor:
|
||||
futures = {executor.submit(process_single_day_v2, task): task[0] for task in tasks}
|
||||
|
||||
with tqdm(total=len(futures), desc="处理进度", unit="天") as pbar:
|
||||
for future in as_completed(futures):
|
||||
trade_date = futures[future]
|
||||
try:
|
||||
result_date, success = future.result()
|
||||
if success:
|
||||
success_count += 1
|
||||
else:
|
||||
failed_dates.append(result_date)
|
||||
except Exception as e:
|
||||
print(f"\n[{trade_date}] 进程异常: {e}")
|
||||
failed_dates.append(trade_date)
|
||||
pbar.update(1)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(f"处理完成: {success_count}/{len(tasks)} 个交易日")
|
||||
if failed_dates:
|
||||
print(f"失败日期: {failed_dates[:10]}{'...' if len(failed_dates) > 10 else ''}")
|
||||
print(f"数据保存在: {OUTPUT_DIR}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1520
ml/realtime_detector.py
Normal file
1520
ml/realtime_detector.py
Normal file
File diff suppressed because it is too large
Load Diff
729
ml/realtime_detector_v2.py
Normal file
729
ml/realtime_detector_v2.py
Normal file
@@ -0,0 +1,729 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
V2 实时异动检测器
|
||||
|
||||
使用方法:
|
||||
# 作为模块导入
|
||||
from ml.realtime_detector_v2 import RealtimeDetectorV2
|
||||
|
||||
detector = RealtimeDetectorV2()
|
||||
alerts = detector.detect_realtime() # 检测当前时刻
|
||||
|
||||
# 或命令行测试
|
||||
python ml/realtime_detector_v2.py --date 2025-12-09
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import pickle
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from collections import defaultdict, deque
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from sqlalchemy import create_engine, text
|
||||
from elasticsearch import Elasticsearch
|
||||
from clickhouse_driver import Client
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from ml.model import TransformerAutoencoder
|
||||
|
||||
# ==================== 配置 ====================
|
||||
|
||||
MYSQL_URL = "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock"
|
||||
ES_HOST = 'http://127.0.0.1:9200'
|
||||
ES_INDEX = 'concept_library_v3'
|
||||
|
||||
CLICKHOUSE_CONFIG = {
|
||||
'host': '127.0.0.1',
|
||||
'port': 9000,
|
||||
'user': 'default',
|
||||
'password': 'Zzl33818!',
|
||||
'database': 'stock'
|
||||
}
|
||||
|
||||
REFERENCE_INDEX = '000001.SH'
|
||||
BASELINE_FILE = 'ml/data_v2/baselines/realtime_baseline.pkl'
|
||||
MODEL_DIR = 'ml/checkpoints_v2'
|
||||
|
||||
# 检测配置
|
||||
CONFIG = {
|
||||
'seq_len': 10, # LSTM 序列长度
|
||||
'confirm_window': 5, # 持续确认窗口
|
||||
'confirm_ratio': 0.6, # 确认比例
|
||||
'rule_weight': 0.5,
|
||||
'ml_weight': 0.5,
|
||||
'rule_trigger': 60,
|
||||
'ml_trigger': 70,
|
||||
'fusion_trigger': 50,
|
||||
'cooldown_minutes': 10,
|
||||
'max_alerts_per_minute': 15,
|
||||
'zscore_clip': 5.0,
|
||||
'limit_up_threshold': 9.8,
|
||||
}
|
||||
|
||||
FEATURES = ['alpha_zscore', 'amt_zscore', 'rank_zscore', 'momentum_3m', 'momentum_5m', 'limit_up_ratio']
|
||||
|
||||
|
||||
# ==================== 数据库连接 ====================
|
||||
|
||||
_mysql_engine = None
|
||||
_es_client = None
|
||||
_ch_client = None
|
||||
|
||||
|
||||
def get_mysql_engine():
|
||||
global _mysql_engine
|
||||
if _mysql_engine is None:
|
||||
_mysql_engine = create_engine(MYSQL_URL, echo=False, pool_pre_ping=True)
|
||||
return _mysql_engine
|
||||
|
||||
|
||||
def get_es_client():
|
||||
global _es_client
|
||||
if _es_client is None:
|
||||
_es_client = Elasticsearch([ES_HOST])
|
||||
return _es_client
|
||||
|
||||
|
||||
def get_ch_client():
|
||||
global _ch_client
|
||||
if _ch_client is None:
|
||||
_ch_client = Client(**CLICKHOUSE_CONFIG)
|
||||
return _ch_client
|
||||
|
||||
|
||||
def code_to_ch_format(code: str) -> str:
|
||||
if not code or len(code) != 6 or not code.isdigit():
|
||||
return None
|
||||
if code.startswith('6'):
|
||||
return f"{code}.SH"
|
||||
elif code.startswith('0') or code.startswith('3'):
|
||||
return f"{code}.SZ"
|
||||
return f"{code}.BJ"
|
||||
|
||||
|
||||
def time_to_slot(ts) -> str:
|
||||
if isinstance(ts, str):
|
||||
return ts
|
||||
return ts.strftime('%H:%M')
|
||||
|
||||
|
||||
# ==================== 规则评分 ====================
|
||||
|
||||
def score_rules_zscore(features: Dict) -> Tuple[float, List[str]]:
|
||||
"""基于 Z-Score 的规则评分"""
|
||||
score = 0.0
|
||||
triggered = []
|
||||
|
||||
alpha_z = abs(features.get('alpha_zscore', 0))
|
||||
amt_z = features.get('amt_zscore', 0)
|
||||
rank_z = abs(features.get('rank_zscore', 0))
|
||||
mom_3m = features.get('momentum_3m', 0)
|
||||
mom_5m = features.get('momentum_5m', 0)
|
||||
limit_up = features.get('limit_up_ratio', 0)
|
||||
|
||||
# Alpha Z-Score
|
||||
if alpha_z >= 4.0:
|
||||
score += 25
|
||||
triggered.append('alpha_extreme')
|
||||
elif alpha_z >= 3.0:
|
||||
score += 18
|
||||
triggered.append('alpha_strong')
|
||||
elif alpha_z >= 2.0:
|
||||
score += 10
|
||||
triggered.append('alpha_moderate')
|
||||
|
||||
# 成交额 Z-Score
|
||||
if amt_z >= 4.0:
|
||||
score += 20
|
||||
triggered.append('amt_extreme')
|
||||
elif amt_z >= 3.0:
|
||||
score += 12
|
||||
triggered.append('amt_strong')
|
||||
elif amt_z >= 2.0:
|
||||
score += 6
|
||||
triggered.append('amt_moderate')
|
||||
|
||||
# 排名 Z-Score
|
||||
if rank_z >= 3.0:
|
||||
score += 15
|
||||
triggered.append('rank_extreme')
|
||||
elif rank_z >= 2.0:
|
||||
score += 8
|
||||
triggered.append('rank_strong')
|
||||
|
||||
# 动量(基于 Z-Score 的)
|
||||
if mom_3m >= 1.0:
|
||||
score += 12
|
||||
triggered.append('momentum_3m_strong')
|
||||
elif mom_3m >= 0.5:
|
||||
score += 6
|
||||
triggered.append('momentum_3m_moderate')
|
||||
|
||||
if mom_5m >= 1.5:
|
||||
score += 10
|
||||
triggered.append('momentum_5m_strong')
|
||||
|
||||
# 涨停比例
|
||||
if limit_up >= 0.3:
|
||||
score += 20
|
||||
triggered.append('limit_up_extreme')
|
||||
elif limit_up >= 0.15:
|
||||
score += 12
|
||||
triggered.append('limit_up_strong')
|
||||
elif limit_up >= 0.08:
|
||||
score += 5
|
||||
triggered.append('limit_up_moderate')
|
||||
|
||||
# 组合规则
|
||||
if alpha_z >= 2.0 and amt_z >= 2.0:
|
||||
score += 15
|
||||
triggered.append('combo_alpha_amt')
|
||||
|
||||
if alpha_z >= 2.0 and limit_up >= 0.1:
|
||||
score += 12
|
||||
triggered.append('combo_alpha_limitup')
|
||||
|
||||
return min(score, 100), triggered
|
||||
|
||||
|
||||
# ==================== 实时检测器 ====================
|
||||
|
||||
class RealtimeDetectorV2:
|
||||
"""V2 实时异动检测器"""
|
||||
|
||||
def __init__(self, model_dir: str = MODEL_DIR, baseline_file: str = BASELINE_FILE):
|
||||
print("初始化 V2 实时检测器...")
|
||||
|
||||
# 加载概念
|
||||
self.concepts = self._load_concepts()
|
||||
self.concept_stocks = {c['concept_id']: set(c['stocks']) for c in self.concepts}
|
||||
self.all_stocks = list(set(s for c in self.concepts for s in c['stocks']))
|
||||
|
||||
# 加载基线
|
||||
self.baselines = self._load_baselines(baseline_file)
|
||||
|
||||
# 加载模型
|
||||
self.model, self.thresholds, self.device = self._load_model(model_dir)
|
||||
|
||||
# 状态管理
|
||||
self.zscore_history = defaultdict(lambda: deque(maxlen=CONFIG['seq_len']))
|
||||
self.anomaly_candidates = defaultdict(lambda: deque(maxlen=CONFIG['confirm_window']))
|
||||
self.cooldown = {}
|
||||
|
||||
print(f"初始化完成: {len(self.concepts)} 概念, {len(self.baselines)} 基线")
|
||||
|
||||
def _load_concepts(self) -> List[dict]:
|
||||
"""从 ES 加载概念"""
|
||||
es = get_es_client()
|
||||
concepts = []
|
||||
|
||||
query = {"query": {"match_all": {}}, "size": 100, "_source": ["concept_id", "concept", "stocks"]}
|
||||
resp = es.search(index=ES_INDEX, body=query, scroll='2m')
|
||||
scroll_id = resp['_scroll_id']
|
||||
hits = resp['hits']['hits']
|
||||
|
||||
while hits:
|
||||
for hit in hits:
|
||||
src = hit['_source']
|
||||
stocks = [s['code'] for s in src.get('stocks', []) if isinstance(s, dict) and s.get('code')]
|
||||
if stocks:
|
||||
concepts.append({
|
||||
'concept_id': src.get('concept_id'),
|
||||
'concept_name': src.get('concept'),
|
||||
'stocks': stocks
|
||||
})
|
||||
resp = es.scroll(scroll_id=scroll_id, scroll='2m')
|
||||
scroll_id = resp['_scroll_id']
|
||||
hits = resp['hits']['hits']
|
||||
|
||||
es.clear_scroll(scroll_id=scroll_id)
|
||||
return concepts
|
||||
|
||||
def _load_baselines(self, baseline_file: str) -> Dict:
|
||||
"""加载基线"""
|
||||
if not os.path.exists(baseline_file):
|
||||
print(f"警告: 基线文件不存在: {baseline_file}")
|
||||
print("请先运行: python ml/update_baseline.py")
|
||||
return {}
|
||||
|
||||
with open(baseline_file, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
print(f"基线日期范围: {data.get('date_range', 'unknown')}")
|
||||
print(f"更新时间: {data.get('update_time', 'unknown')}")
|
||||
|
||||
return data.get('baselines', {})
|
||||
|
||||
def _load_model(self, model_dir: str):
|
||||
"""加载模型"""
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
config_path = os.path.join(model_dir, 'config.json')
|
||||
model_path = os.path.join(model_dir, 'best_model.pt')
|
||||
threshold_path = os.path.join(model_dir, 'thresholds.json')
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
print(f"警告: 模型不存在: {model_path}")
|
||||
return None, {}, device
|
||||
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
model = TransformerAutoencoder(**config['model'])
|
||||
checkpoint = torch.load(model_path, map_location=device)
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
thresholds = {}
|
||||
if os.path.exists(threshold_path):
|
||||
with open(threshold_path) as f:
|
||||
thresholds = json.load(f)
|
||||
|
||||
print(f"模型已加载: {model_path}")
|
||||
return model, thresholds, device
|
||||
|
||||
def _get_realtime_data(self, trade_date: str) -> pd.DataFrame:
|
||||
"""获取实时数据并计算原始特征"""
|
||||
ch = get_ch_client()
|
||||
|
||||
# 获取股票数据
|
||||
ch_codes = [code_to_ch_format(c) for c in self.all_stocks if code_to_ch_format(c)]
|
||||
ch_codes_str = "','".join(ch_codes)
|
||||
|
||||
stock_query = f"""
|
||||
SELECT code, timestamp, close, amt
|
||||
FROM stock_minute
|
||||
WHERE toDate(timestamp) = '{trade_date}'
|
||||
AND code IN ('{ch_codes_str}')
|
||||
ORDER BY timestamp
|
||||
"""
|
||||
stock_result = ch.execute(stock_query)
|
||||
if not stock_result:
|
||||
return pd.DataFrame()
|
||||
|
||||
stock_df = pd.DataFrame(stock_result, columns=['ch_code', 'timestamp', 'close', 'amt'])
|
||||
|
||||
# 映射回原始代码
|
||||
ch_to_code = {code_to_ch_format(c): c for c in self.all_stocks if code_to_ch_format(c)}
|
||||
stock_df['code'] = stock_df['ch_code'].map(ch_to_code)
|
||||
stock_df = stock_df.dropna(subset=['code'])
|
||||
|
||||
# 获取指数数据
|
||||
index_query = f"""
|
||||
SELECT timestamp, close
|
||||
FROM index_minute
|
||||
WHERE toDate(timestamp) = '{trade_date}'
|
||||
AND code = '{REFERENCE_INDEX}'
|
||||
ORDER BY timestamp
|
||||
"""
|
||||
index_result = ch.execute(index_query)
|
||||
if not index_result:
|
||||
return pd.DataFrame()
|
||||
|
||||
index_df = pd.DataFrame(index_result, columns=['timestamp', 'close'])
|
||||
|
||||
# 获取昨收价
|
||||
engine = get_mysql_engine()
|
||||
codes_str = "','".join([c for c in self.all_stocks if c and len(c) == 6])
|
||||
|
||||
with engine.connect() as conn:
|
||||
prev_result = conn.execute(text(f"""
|
||||
SELECT SECCODE, F007N FROM ea_trade
|
||||
WHERE SECCODE IN ('{codes_str}')
|
||||
AND TRADEDATE = (SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}')
|
||||
AND F007N > 0
|
||||
"""))
|
||||
prev_close = {row[0]: float(row[1]) for row in prev_result if row[1]}
|
||||
|
||||
idx_result = conn.execute(text("""
|
||||
SELECT F006N FROM ea_exchangetrade
|
||||
WHERE INDEXCODE = '000001' AND TRADEDATE < :today
|
||||
ORDER BY TRADEDATE DESC LIMIT 1
|
||||
"""), {'today': trade_date}).fetchone()
|
||||
index_prev_close = float(idx_result[0]) if idx_result else None
|
||||
|
||||
if not prev_close or not index_prev_close:
|
||||
return pd.DataFrame()
|
||||
|
||||
# 计算涨跌幅
|
||||
stock_df['prev_close'] = stock_df['code'].map(prev_close)
|
||||
stock_df = stock_df.dropna(subset=['prev_close'])
|
||||
stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100
|
||||
|
||||
index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100
|
||||
index_map = dict(zip(index_df['timestamp'], index_df['change_pct']))
|
||||
|
||||
# 按时间聚合概念特征
|
||||
results = []
|
||||
for ts in sorted(stock_df['timestamp'].unique()):
|
||||
ts_data = stock_df[stock_df['timestamp'] == ts]
|
||||
idx_chg = index_map.get(ts, 0)
|
||||
|
||||
stock_chg = dict(zip(ts_data['code'], ts_data['change_pct']))
|
||||
stock_amt = dict(zip(ts_data['code'], ts_data['amt']))
|
||||
|
||||
for cid, stocks in self.concept_stocks.items():
|
||||
changes = [stock_chg[s] for s in stocks if s in stock_chg]
|
||||
amts = [stock_amt.get(s, 0) for s in stocks if s in stock_chg]
|
||||
|
||||
if not changes:
|
||||
continue
|
||||
|
||||
alpha = np.mean(changes) - idx_chg
|
||||
total_amt = sum(amts)
|
||||
limit_up_ratio = sum(1 for c in changes if c >= CONFIG['limit_up_threshold']) / len(changes)
|
||||
|
||||
results.append({
|
||||
'concept_id': cid,
|
||||
'timestamp': ts,
|
||||
'time_slot': time_to_slot(ts),
|
||||
'alpha': alpha,
|
||||
'total_amt': total_amt,
|
||||
'limit_up_ratio': limit_up_ratio,
|
||||
'stock_count': len(changes),
|
||||
})
|
||||
|
||||
if not results:
|
||||
return pd.DataFrame()
|
||||
|
||||
df = pd.DataFrame(results)
|
||||
|
||||
# 计算排名
|
||||
for ts in df['timestamp'].unique():
|
||||
mask = df['timestamp'] == ts
|
||||
df.loc[mask, 'rank_pct'] = df.loc[mask, 'alpha'].rank(pct=True)
|
||||
|
||||
return df
|
||||
|
||||
def _compute_zscore(self, concept_id: str, time_slot: str, alpha: float, total_amt: float, rank_pct: float) -> Optional[Dict]:
|
||||
"""计算 Z-Score"""
|
||||
if concept_id not in self.baselines:
|
||||
return None
|
||||
|
||||
baseline = self.baselines[concept_id]
|
||||
if time_slot not in baseline:
|
||||
return None
|
||||
|
||||
bl = baseline[time_slot]
|
||||
|
||||
alpha_z = np.clip((alpha - bl['alpha_mean']) / bl['alpha_std'], -5, 5)
|
||||
amt_z = np.clip((total_amt - bl['amt_mean']) / bl['amt_std'], -5, 5)
|
||||
rank_z = np.clip((rank_pct - bl['rank_mean']) / bl['rank_std'], -5, 5)
|
||||
|
||||
# 动量(基于 Z-Score 历史)
|
||||
history = list(self.zscore_history[concept_id])
|
||||
mom_3m = 0.0
|
||||
mom_5m = 0.0
|
||||
|
||||
if len(history) >= 3:
|
||||
recent = [h['alpha_zscore'] for h in history[-3:]]
|
||||
older = [h['alpha_zscore'] for h in history[-6:-3]] if len(history) >= 6 else [history[0]['alpha_zscore']]
|
||||
mom_3m = np.mean(recent) - np.mean(older)
|
||||
|
||||
if len(history) >= 5:
|
||||
recent = [h['alpha_zscore'] for h in history[-5:]]
|
||||
older = [h['alpha_zscore'] for h in history[-10:-5]] if len(history) >= 10 else [history[0]['alpha_zscore']]
|
||||
mom_5m = np.mean(recent) - np.mean(older)
|
||||
|
||||
return {
|
||||
'alpha_zscore': float(alpha_z),
|
||||
'amt_zscore': float(amt_z),
|
||||
'rank_zscore': float(rank_z),
|
||||
'momentum_3m': float(mom_3m),
|
||||
'momentum_5m': float(mom_5m),
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _ml_score(self, sequences: np.ndarray) -> np.ndarray:
|
||||
"""批量 ML 评分"""
|
||||
if self.model is None or len(sequences) == 0:
|
||||
return np.zeros(len(sequences))
|
||||
|
||||
x = torch.FloatTensor(sequences).to(self.device)
|
||||
errors = self.model.compute_reconstruction_error(x, reduction='none')
|
||||
last_errors = errors[:, -1].cpu().numpy()
|
||||
|
||||
# 转换为 0-100 分数
|
||||
if self.thresholds:
|
||||
p50 = self.thresholds.get('median', 0.001)
|
||||
p99 = self.thresholds.get('p99', 0.05)
|
||||
scores = 50 + (last_errors - p50) / (p99 - p50 + 1e-6) * 49
|
||||
else:
|
||||
scores = last_errors * 1000
|
||||
|
||||
return np.clip(scores, 0, 100)
|
||||
|
||||
def detect(self, trade_date: str = None) -> List[Dict]:
|
||||
"""检测指定日期的异动"""
|
||||
trade_date = trade_date or datetime.now().strftime('%Y-%m-%d')
|
||||
print(f"\n检测 {trade_date} 的异动...")
|
||||
|
||||
# 重置状态
|
||||
self.zscore_history.clear()
|
||||
self.anomaly_candidates.clear()
|
||||
self.cooldown.clear()
|
||||
|
||||
# 获取数据
|
||||
raw_df = self._get_realtime_data(trade_date)
|
||||
if raw_df.empty:
|
||||
print("无数据")
|
||||
return []
|
||||
|
||||
timestamps = sorted(raw_df['timestamp'].unique())
|
||||
print(f"时间点数: {len(timestamps)}")
|
||||
|
||||
all_alerts = []
|
||||
|
||||
for ts in timestamps:
|
||||
ts_data = raw_df[raw_df['timestamp'] == ts]
|
||||
time_slot = time_to_slot(ts)
|
||||
|
||||
candidates = []
|
||||
|
||||
# 计算每个概念的 Z-Score
|
||||
for _, row in ts_data.iterrows():
|
||||
cid = row['concept_id']
|
||||
|
||||
zscore = self._compute_zscore(
|
||||
cid, time_slot,
|
||||
row['alpha'], row['total_amt'], row['rank_pct']
|
||||
)
|
||||
|
||||
if zscore is None:
|
||||
continue
|
||||
|
||||
# 完整特征
|
||||
features = {
|
||||
**zscore,
|
||||
'alpha': row['alpha'],
|
||||
'limit_up_ratio': row['limit_up_ratio'],
|
||||
'total_amt': row['total_amt'],
|
||||
}
|
||||
|
||||
# 更新历史
|
||||
self.zscore_history[cid].append(zscore)
|
||||
|
||||
# 规则评分
|
||||
rule_score, triggered = score_rules_zscore(features)
|
||||
|
||||
candidates.append((cid, features, rule_score, triggered))
|
||||
|
||||
if not candidates:
|
||||
continue
|
||||
|
||||
# 批量 ML 评分
|
||||
sequences = []
|
||||
valid_candidates = []
|
||||
|
||||
for cid, features, rule_score, triggered in candidates:
|
||||
history = list(self.zscore_history[cid])
|
||||
if len(history) >= CONFIG['seq_len']:
|
||||
seq = np.array([[h['alpha_zscore'], h['amt_zscore'], h['rank_zscore'],
|
||||
h['momentum_3m'], h['momentum_5m'], features['limit_up_ratio']]
|
||||
for h in history])
|
||||
sequences.append(seq)
|
||||
valid_candidates.append((cid, features, rule_score, triggered))
|
||||
|
||||
if not sequences:
|
||||
continue
|
||||
|
||||
ml_scores = self._ml_score(np.array(sequences))
|
||||
|
||||
# 融合 + 确认
|
||||
for i, (cid, features, rule_score, triggered) in enumerate(valid_candidates):
|
||||
ml_score = ml_scores[i]
|
||||
final_score = CONFIG['rule_weight'] * rule_score + CONFIG['ml_weight'] * ml_score
|
||||
|
||||
# 判断触发
|
||||
is_triggered = (
|
||||
rule_score >= CONFIG['rule_trigger'] or
|
||||
ml_score >= CONFIG['ml_trigger'] or
|
||||
final_score >= CONFIG['fusion_trigger']
|
||||
)
|
||||
|
||||
self.anomaly_candidates[cid].append((ts, final_score))
|
||||
|
||||
if not is_triggered:
|
||||
continue
|
||||
|
||||
# 冷却期
|
||||
if cid in self.cooldown:
|
||||
if (ts - self.cooldown[cid]).total_seconds() < CONFIG['cooldown_minutes'] * 60:
|
||||
continue
|
||||
|
||||
# 持续确认
|
||||
recent = list(self.anomaly_candidates[cid])
|
||||
if len(recent) < CONFIG['confirm_window']:
|
||||
continue
|
||||
|
||||
exceed = sum(1 for _, s in recent if s >= CONFIG['fusion_trigger'])
|
||||
ratio = exceed / len(recent)
|
||||
|
||||
if ratio < CONFIG['confirm_ratio']:
|
||||
continue
|
||||
|
||||
# 确认异动!
|
||||
self.cooldown[cid] = ts
|
||||
|
||||
alpha = features['alpha']
|
||||
alert_type = 'surge_up' if alpha >= 1.5 else 'surge_down' if alpha <= -1.5 else 'surge'
|
||||
|
||||
concept_name = next((c['concept_name'] for c in self.concepts if c['concept_id'] == cid), cid)
|
||||
|
||||
all_alerts.append({
|
||||
'concept_id': cid,
|
||||
'concept_name': concept_name,
|
||||
'alert_time': ts,
|
||||
'trade_date': trade_date,
|
||||
'alert_type': alert_type,
|
||||
'final_score': float(final_score),
|
||||
'rule_score': float(rule_score),
|
||||
'ml_score': float(ml_score),
|
||||
'confirm_ratio': float(ratio),
|
||||
'alpha': float(alpha),
|
||||
'alpha_zscore': float(features['alpha_zscore']),
|
||||
'amt_zscore': float(features['amt_zscore']),
|
||||
'rank_zscore': float(features['rank_zscore']),
|
||||
'momentum_3m': float(features['momentum_3m']),
|
||||
'momentum_5m': float(features['momentum_5m']),
|
||||
'limit_up_ratio': float(features['limit_up_ratio']),
|
||||
'triggered_rules': triggered,
|
||||
'trigger_reason': f"融合({final_score:.0f})+确认({ratio:.0%})",
|
||||
})
|
||||
|
||||
print(f"检测到 {len(all_alerts)} 个异动")
|
||||
return all_alerts
|
||||
|
||||
|
||||
# ==================== 数据库存储 ====================
|
||||
|
||||
def create_v2_table():
|
||||
"""创建 V2 异动表(如果不存在)"""
|
||||
engine = get_mysql_engine()
|
||||
with engine.begin() as conn:
|
||||
conn.execute(text("""
|
||||
CREATE TABLE IF NOT EXISTS concept_anomaly_v2 (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
concept_id VARCHAR(50) NOT NULL,
|
||||
alert_time DATETIME NOT NULL,
|
||||
trade_date DATE NOT NULL,
|
||||
alert_type VARCHAR(20) NOT NULL,
|
||||
final_score FLOAT,
|
||||
rule_score FLOAT,
|
||||
ml_score FLOAT,
|
||||
trigger_reason VARCHAR(200),
|
||||
confirm_ratio FLOAT,
|
||||
alpha FLOAT,
|
||||
alpha_zscore FLOAT,
|
||||
amt_zscore FLOAT,
|
||||
rank_zscore FLOAT,
|
||||
momentum_3m FLOAT,
|
||||
momentum_5m FLOAT,
|
||||
limit_up_ratio FLOAT,
|
||||
triggered_rules TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE KEY uk_concept_time (concept_id, alert_time),
|
||||
INDEX idx_trade_date (trade_date),
|
||||
INDEX idx_alert_type (alert_type)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
"""))
|
||||
print("concept_anomaly_v2 表已就绪")
|
||||
|
||||
|
||||
def save_alerts_to_db(alerts: List[Dict]) -> int:
|
||||
"""保存异动到数据库"""
|
||||
if not alerts:
|
||||
return 0
|
||||
|
||||
engine = get_mysql_engine()
|
||||
saved = 0
|
||||
|
||||
with engine.begin() as conn:
|
||||
for alert in alerts:
|
||||
try:
|
||||
insert_sql = text("""
|
||||
INSERT IGNORE INTO concept_anomaly_v2
|
||||
(concept_id, alert_time, trade_date, alert_type,
|
||||
final_score, rule_score, ml_score, trigger_reason, confirm_ratio,
|
||||
alpha, alpha_zscore, amt_zscore, rank_zscore,
|
||||
momentum_3m, momentum_5m, limit_up_ratio, triggered_rules)
|
||||
VALUES
|
||||
(:concept_id, :alert_time, :trade_date, :alert_type,
|
||||
:final_score, :rule_score, :ml_score, :trigger_reason, :confirm_ratio,
|
||||
:alpha, :alpha_zscore, :amt_zscore, :rank_zscore,
|
||||
:momentum_3m, :momentum_5m, :limit_up_ratio, :triggered_rules)
|
||||
""")
|
||||
|
||||
result = conn.execute(insert_sql, {
|
||||
'concept_id': alert['concept_id'],
|
||||
'alert_time': alert['alert_time'],
|
||||
'trade_date': alert['trade_date'],
|
||||
'alert_type': alert['alert_type'],
|
||||
'final_score': alert['final_score'],
|
||||
'rule_score': alert['rule_score'],
|
||||
'ml_score': alert['ml_score'],
|
||||
'trigger_reason': alert['trigger_reason'],
|
||||
'confirm_ratio': alert['confirm_ratio'],
|
||||
'alpha': alert['alpha'],
|
||||
'alpha_zscore': alert['alpha_zscore'],
|
||||
'amt_zscore': alert['amt_zscore'],
|
||||
'rank_zscore': alert['rank_zscore'],
|
||||
'momentum_3m': alert['momentum_3m'],
|
||||
'momentum_5m': alert['momentum_5m'],
|
||||
'limit_up_ratio': alert['limit_up_ratio'],
|
||||
'triggered_rules': json.dumps(alert.get('triggered_rules', []), ensure_ascii=False),
|
||||
})
|
||||
|
||||
if result.rowcount > 0:
|
||||
saved += 1
|
||||
except Exception as e:
|
||||
print(f"保存失败: {alert['concept_id']} - {e}")
|
||||
|
||||
return saved
|
||||
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--date', type=str, default=None)
|
||||
parser.add_argument('--no-save', action='store_true', help='不保存到数据库,只打印')
|
||||
args = parser.parse_args()
|
||||
|
||||
# 确保表存在
|
||||
if not args.no_save:
|
||||
create_v2_table()
|
||||
|
||||
detector = RealtimeDetectorV2()
|
||||
alerts = detector.detect(args.date)
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"检测结果 ({len(alerts)} 个异动)")
|
||||
print('='*60)
|
||||
|
||||
for a in alerts[:20]:
|
||||
print(f"[{a['alert_time'].strftime('%H:%M') if hasattr(a['alert_time'], 'strftime') else a['alert_time']}] "
|
||||
f"{a['concept_name']} | {a['alert_type']} | "
|
||||
f"分数={a['final_score']:.0f} 确认={a['confirm_ratio']:.0%} "
|
||||
f"α={a['alpha']:.2f}% αZ={a['alpha_zscore']:.1f}")
|
||||
|
||||
if len(alerts) > 20:
|
||||
print(f"... 共 {len(alerts)} 个")
|
||||
|
||||
# 保存到数据库
|
||||
if not args.no_save and alerts:
|
||||
saved = save_alerts_to_db(alerts)
|
||||
print(f"\n✅ 已保存 {saved}/{len(alerts)} 条到 concept_anomaly_v2 表")
|
||||
elif args.no_save:
|
||||
print(f"\n⚠️ --no-save 模式,未保存到数据库")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
622
ml/train_v2.py
Normal file
622
ml/train_v2.py
Normal file
@@ -0,0 +1,622 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
训练脚本 V2 - 基于 Z-Score 特征的 LSTM Autoencoder
|
||||
|
||||
改进点:
|
||||
1. 使用 Z-Score 特征(相对于同时间片历史的偏离)
|
||||
2. 短序列:10分钟(不需要30分钟预热)
|
||||
3. 开盘即可检测:9:30 直接有特征
|
||||
|
||||
模型输入:
|
||||
- 过去10分钟的 Z-Score 特征序列
|
||||
- 特征:alpha_zscore, amt_zscore, rank_zscore, momentum_3m, momentum_5m, limit_up_ratio
|
||||
|
||||
模型学习:
|
||||
- 学习 Z-Score 序列的"正常演化模式"
|
||||
- 异动 = Z-Score 序列的异常演化(重构误差大)
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Dict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
||||
from tqdm import tqdm
|
||||
|
||||
from model import TransformerAutoencoder, AnomalyDetectionLoss, count_parameters
|
||||
|
||||
# 性能优化
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
try:
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
HAS_MATPLOTLIB = True
|
||||
except ImportError:
|
||||
HAS_MATPLOTLIB = False
|
||||
|
||||
|
||||
# ==================== 配置 ====================
|
||||
|
||||
TRAIN_CONFIG = {
|
||||
# 数据配置(改进!)
|
||||
'seq_len': 10, # 10分钟序列(不是30分钟!)
|
||||
'stride': 2, # 步长2分钟
|
||||
|
||||
# 时间切分
|
||||
'train_end_date': '2024-06-30',
|
||||
'val_end_date': '2024-09-30',
|
||||
|
||||
# V2 特征(Z-Score 为主)
|
||||
'features': [
|
||||
'alpha_zscore', # Alpha 的 Z-Score
|
||||
'amt_zscore', # 成交额的 Z-Score
|
||||
'rank_zscore', # 排名的 Z-Score
|
||||
'momentum_3m', # 3分钟动量
|
||||
'momentum_5m', # 5分钟动量
|
||||
'limit_up_ratio', # 涨停占比
|
||||
],
|
||||
|
||||
# 训练配置
|
||||
'batch_size': 4096,
|
||||
'epochs': 100,
|
||||
'learning_rate': 3e-4,
|
||||
'weight_decay': 1e-5,
|
||||
'gradient_clip': 1.0,
|
||||
|
||||
# 早停配置
|
||||
'patience': 15,
|
||||
'min_delta': 1e-6,
|
||||
|
||||
# 模型配置(小型 LSTM)
|
||||
'model': {
|
||||
'n_features': 6,
|
||||
'hidden_dim': 32,
|
||||
'latent_dim': 4,
|
||||
'num_layers': 1,
|
||||
'dropout': 0.2,
|
||||
'bidirectional': True,
|
||||
},
|
||||
|
||||
# 标准化配置
|
||||
'clip_value': 5.0, # Z-Score 已经标准化,clip 5.0 足够
|
||||
|
||||
# 阈值配置
|
||||
'threshold_percentiles': [90, 95, 99],
|
||||
}
|
||||
|
||||
|
||||
# ==================== 数据加载 ====================
|
||||
|
||||
def load_data_by_date(data_dir: str, features: List[str]) -> Dict[str, pd.DataFrame]:
|
||||
"""按日期加载 V2 数据"""
|
||||
data_path = Path(data_dir)
|
||||
parquet_files = sorted(data_path.glob("features_v2_*.parquet"))
|
||||
|
||||
if not parquet_files:
|
||||
raise FileNotFoundError(f"未找到 V2 数据文件: {data_dir}")
|
||||
|
||||
print(f"找到 {len(parquet_files)} 个 V2 数据文件")
|
||||
|
||||
date_data = {}
|
||||
|
||||
for pf in tqdm(parquet_files, desc="加载数据"):
|
||||
date = pf.stem.replace('features_v2_', '')
|
||||
|
||||
df = pd.read_parquet(pf)
|
||||
|
||||
required_cols = features + ['concept_id', 'timestamp']
|
||||
missing_cols = [c for c in required_cols if c not in df.columns]
|
||||
if missing_cols:
|
||||
print(f"警告: {date} 缺少列: {missing_cols}, 跳过")
|
||||
continue
|
||||
|
||||
date_data[date] = df
|
||||
|
||||
print(f"成功加载 {len(date_data)} 天的数据")
|
||||
return date_data
|
||||
|
||||
|
||||
def split_data_by_date(
|
||||
date_data: Dict[str, pd.DataFrame],
|
||||
train_end: str,
|
||||
val_end: str
|
||||
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
|
||||
"""按日期划分数据集"""
|
||||
train_data = {}
|
||||
val_data = {}
|
||||
test_data = {}
|
||||
|
||||
for date, df in date_data.items():
|
||||
if date <= train_end:
|
||||
train_data[date] = df
|
||||
elif date <= val_end:
|
||||
val_data[date] = df
|
||||
else:
|
||||
test_data[date] = df
|
||||
|
||||
print(f"数据集划分:")
|
||||
print(f" 训练集: {len(train_data)} 天 (<= {train_end})")
|
||||
print(f" 验证集: {len(val_data)} 天 ({train_end} ~ {val_end})")
|
||||
print(f" 测试集: {len(test_data)} 天 (> {val_end})")
|
||||
|
||||
return train_data, val_data, test_data
|
||||
|
||||
|
||||
def build_sequences_by_concept(
|
||||
date_data: Dict[str, pd.DataFrame],
|
||||
features: List[str],
|
||||
seq_len: int,
|
||||
stride: int
|
||||
) -> np.ndarray:
|
||||
"""按概念分组构建序列"""
|
||||
all_dfs = []
|
||||
for date, df in sorted(date_data.items()):
|
||||
df = df.copy()
|
||||
df['date'] = date
|
||||
all_dfs.append(df)
|
||||
|
||||
if not all_dfs:
|
||||
return np.array([])
|
||||
|
||||
combined = pd.concat(all_dfs, ignore_index=True)
|
||||
combined = combined.sort_values(['concept_id', 'date', 'timestamp'])
|
||||
|
||||
all_sequences = []
|
||||
grouped = combined.groupby('concept_id', sort=False)
|
||||
n_concepts = len(grouped)
|
||||
|
||||
for concept_id, concept_df in tqdm(grouped, desc="构建序列", total=n_concepts, leave=False):
|
||||
feature_data = concept_df[features].values
|
||||
feature_data = np.nan_to_num(feature_data, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
|
||||
n_points = len(feature_data)
|
||||
for start in range(0, n_points - seq_len + 1, stride):
|
||||
seq = feature_data[start:start + seq_len]
|
||||
all_sequences.append(seq)
|
||||
|
||||
if not all_sequences:
|
||||
return np.array([])
|
||||
|
||||
sequences = np.array(all_sequences)
|
||||
print(f" 构建序列: {len(sequences):,} 条 (来自 {n_concepts} 个概念)")
|
||||
|
||||
return sequences
|
||||
|
||||
|
||||
# ==================== 数据集 ====================
|
||||
|
||||
class SequenceDataset(Dataset):
|
||||
def __init__(self, sequences: np.ndarray):
|
||||
self.sequences = torch.FloatTensor(sequences)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.sequences)
|
||||
|
||||
def __getitem__(self, idx: int) -> torch.Tensor:
|
||||
return self.sequences[idx]
|
||||
|
||||
|
||||
# ==================== 训练器 ====================
|
||||
|
||||
class EarlyStopping:
|
||||
def __init__(self, patience: int = 10, min_delta: float = 1e-6):
|
||||
self.patience = patience
|
||||
self.min_delta = min_delta
|
||||
self.counter = 0
|
||||
self.best_loss = float('inf')
|
||||
self.early_stop = False
|
||||
|
||||
def __call__(self, val_loss: float) -> bool:
|
||||
if val_loss < self.best_loss - self.min_delta:
|
||||
self.best_loss = val_loss
|
||||
self.counter = 0
|
||||
else:
|
||||
self.counter += 1
|
||||
if self.counter >= self.patience:
|
||||
self.early_stop = True
|
||||
return self.early_stop
|
||||
|
||||
|
||||
class Trainer:
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
train_loader: DataLoader,
|
||||
val_loader: DataLoader,
|
||||
config: Dict,
|
||||
device: torch.device,
|
||||
save_dir: str = 'ml/checkpoints_v2'
|
||||
):
|
||||
self.model = model.to(device)
|
||||
self.train_loader = train_loader
|
||||
self.val_loader = val_loader
|
||||
self.config = config
|
||||
self.device = device
|
||||
self.save_dir = Path(save_dir)
|
||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.optimizer = AdamW(
|
||||
model.parameters(),
|
||||
lr=config['learning_rate'],
|
||||
weight_decay=config['weight_decay']
|
||||
)
|
||||
|
||||
self.scheduler = CosineAnnealingWarmRestarts(
|
||||
self.optimizer, T_0=10, T_mult=2, eta_min=1e-6
|
||||
)
|
||||
|
||||
self.criterion = AnomalyDetectionLoss()
|
||||
|
||||
self.early_stopping = EarlyStopping(
|
||||
patience=config['patience'],
|
||||
min_delta=config['min_delta']
|
||||
)
|
||||
|
||||
self.use_amp = torch.cuda.is_available()
|
||||
self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None
|
||||
if self.use_amp:
|
||||
print(" ✓ 启用 AMP 混合精度训练")
|
||||
|
||||
self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
|
||||
self.best_val_loss = float('inf')
|
||||
|
||||
def train_epoch(self) -> float:
|
||||
self.model.train()
|
||||
total_loss = 0.0
|
||||
n_batches = 0
|
||||
|
||||
pbar = tqdm(self.train_loader, desc="Training", leave=False)
|
||||
for batch in pbar:
|
||||
batch = batch.to(self.device, non_blocking=True)
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if self.use_amp:
|
||||
with torch.cuda.amp.autocast():
|
||||
output, latent = self.model(batch)
|
||||
loss, _ = self.criterion(output, batch, latent)
|
||||
|
||||
self.scaler.scale(loss).backward()
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['gradient_clip'])
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
else:
|
||||
output, latent = self.model(batch)
|
||||
loss, _ = self.criterion(output, batch, latent)
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['gradient_clip'])
|
||||
self.optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
n_batches += 1
|
||||
pbar.set_postfix({'loss': f"{loss.item():.4f}"})
|
||||
|
||||
return total_loss / n_batches
|
||||
|
||||
@torch.no_grad()
|
||||
def validate(self) -> float:
|
||||
self.model.eval()
|
||||
total_loss = 0.0
|
||||
n_batches = 0
|
||||
|
||||
for batch in self.val_loader:
|
||||
batch = batch.to(self.device, non_blocking=True)
|
||||
|
||||
if self.use_amp:
|
||||
with torch.cuda.amp.autocast():
|
||||
output, latent = self.model(batch)
|
||||
loss, _ = self.criterion(output, batch, latent)
|
||||
else:
|
||||
output, latent = self.model(batch)
|
||||
loss, _ = self.criterion(output, batch, latent)
|
||||
|
||||
total_loss += loss.item()
|
||||
n_batches += 1
|
||||
|
||||
return total_loss / n_batches
|
||||
|
||||
def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
|
||||
model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
|
||||
|
||||
checkpoint = {
|
||||
'epoch': epoch,
|
||||
'model_state_dict': model_to_save.state_dict(),
|
||||
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||
'scheduler_state_dict': self.scheduler.state_dict(),
|
||||
'val_loss': val_loss,
|
||||
'config': self.config,
|
||||
}
|
||||
|
||||
torch.save(checkpoint, self.save_dir / 'last_checkpoint.pt')
|
||||
|
||||
if is_best:
|
||||
torch.save(checkpoint, self.save_dir / 'best_model.pt')
|
||||
print(f" ✓ 保存最佳模型 (val_loss: {val_loss:.6f})")
|
||||
|
||||
def train(self, epochs: int):
|
||||
print(f"\n开始训练 ({epochs} epochs)...")
|
||||
print(f"设备: {self.device}")
|
||||
print(f"模型参数量: {count_parameters(self.model):,}")
|
||||
|
||||
for epoch in range(1, epochs + 1):
|
||||
print(f"\nEpoch {epoch}/{epochs}")
|
||||
|
||||
train_loss = self.train_epoch()
|
||||
val_loss = self.validate()
|
||||
|
||||
self.scheduler.step()
|
||||
current_lr = self.optimizer.param_groups[0]['lr']
|
||||
|
||||
self.history['train_loss'].append(train_loss)
|
||||
self.history['val_loss'].append(val_loss)
|
||||
self.history['learning_rate'].append(current_lr)
|
||||
|
||||
print(f" Train Loss: {train_loss:.6f}")
|
||||
print(f" Val Loss: {val_loss:.6f}")
|
||||
print(f" LR: {current_lr:.2e}")
|
||||
|
||||
is_best = val_loss < self.best_val_loss
|
||||
if is_best:
|
||||
self.best_val_loss = val_loss
|
||||
self.save_checkpoint(epoch, val_loss, is_best)
|
||||
|
||||
if self.early_stopping(val_loss):
|
||||
print(f"\n早停触发!")
|
||||
break
|
||||
|
||||
print(f"\n训练完成!最佳验证损失: {self.best_val_loss:.6f}")
|
||||
self.save_history()
|
||||
|
||||
return self.history
|
||||
|
||||
def save_history(self):
|
||||
history_path = self.save_dir / 'training_history.json'
|
||||
with open(history_path, 'w') as f:
|
||||
json.dump(self.history, f, indent=2)
|
||||
print(f"训练历史已保存: {history_path}")
|
||||
|
||||
if HAS_MATPLOTLIB:
|
||||
self.plot_training_curves()
|
||||
|
||||
def plot_training_curves(self):
|
||||
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
|
||||
epochs = range(1, len(self.history['train_loss']) + 1)
|
||||
|
||||
ax1 = axes[0]
|
||||
ax1.plot(epochs, self.history['train_loss'], 'b-', label='Train Loss', linewidth=2)
|
||||
ax1.plot(epochs, self.history['val_loss'], 'r-', label='Val Loss', linewidth=2)
|
||||
ax1.set_xlabel('Epoch')
|
||||
ax1.set_ylabel('Loss')
|
||||
ax1.set_title('Training & Validation Loss (V2)')
|
||||
ax1.legend()
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
best_epoch = np.argmin(self.history['val_loss']) + 1
|
||||
best_val_loss = min(self.history['val_loss'])
|
||||
ax1.axvline(x=best_epoch, color='g', linestyle='--', alpha=0.7)
|
||||
ax1.scatter([best_epoch], [best_val_loss], color='g', s=100, zorder=5)
|
||||
|
||||
ax2 = axes[1]
|
||||
ax2.plot(epochs, self.history['learning_rate'], 'g-', linewidth=2)
|
||||
ax2.set_xlabel('Epoch')
|
||||
ax2.set_ylabel('Learning Rate')
|
||||
ax2.set_title('Learning Rate Schedule')
|
||||
ax2.set_yscale('log')
|
||||
ax2.grid(True, alpha=0.3)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(self.save_dir / 'training_curves.png', dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
print(f"训练曲线已保存")
|
||||
|
||||
|
||||
# ==================== 阈值计算 ====================
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_thresholds(
|
||||
model: nn.Module,
|
||||
data_loader: DataLoader,
|
||||
device: torch.device,
|
||||
percentiles: List[float] = [90, 95, 99]
|
||||
) -> Dict[str, float]:
|
||||
"""在验证集上计算阈值"""
|
||||
model.eval()
|
||||
all_errors = []
|
||||
|
||||
print("计算异动阈值...")
|
||||
for batch in tqdm(data_loader, desc="Computing thresholds"):
|
||||
batch = batch.to(device)
|
||||
errors = model.compute_reconstruction_error(batch, reduction='none')
|
||||
seq_errors = errors[:, -1] # 最后一个时刻
|
||||
all_errors.append(seq_errors.cpu().numpy())
|
||||
|
||||
all_errors = np.concatenate(all_errors)
|
||||
|
||||
thresholds = {}
|
||||
for p in percentiles:
|
||||
threshold = np.percentile(all_errors, p)
|
||||
thresholds[f'p{p}'] = float(threshold)
|
||||
print(f" P{p}: {threshold:.6f}")
|
||||
|
||||
thresholds['mean'] = float(np.mean(all_errors))
|
||||
thresholds['std'] = float(np.std(all_errors))
|
||||
thresholds['median'] = float(np.median(all_errors))
|
||||
|
||||
return thresholds
|
||||
|
||||
|
||||
# ==================== 主函数 ====================
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='训练 V2 模型')
|
||||
parser.add_argument('--data_dir', type=str, default='ml/data_v2', help='V2 数据目录')
|
||||
parser.add_argument('--epochs', type=int, default=100)
|
||||
parser.add_argument('--batch_size', type=int, default=4096)
|
||||
parser.add_argument('--lr', type=float, default=3e-4)
|
||||
parser.add_argument('--device', type=str, default='auto')
|
||||
parser.add_argument('--save_dir', type=str, default='ml/checkpoints_v2')
|
||||
parser.add_argument('--train_end', type=str, default='2024-06-30')
|
||||
parser.add_argument('--val_end', type=str, default='2024-09-30')
|
||||
parser.add_argument('--seq_len', type=int, default=10, help='序列长度(分钟)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
config = TRAIN_CONFIG.copy()
|
||||
config['batch_size'] = args.batch_size
|
||||
config['epochs'] = args.epochs
|
||||
config['learning_rate'] = args.lr
|
||||
config['train_end_date'] = args.train_end
|
||||
config['val_end_date'] = args.val_end
|
||||
config['seq_len'] = args.seq_len
|
||||
|
||||
if args.device == 'auto':
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
else:
|
||||
device = torch.device(args.device)
|
||||
|
||||
print("=" * 60)
|
||||
print("概念异动检测模型训练 V2(Z-Score 特征)")
|
||||
print("=" * 60)
|
||||
print(f"数据目录: {args.data_dir}")
|
||||
print(f"设备: {device}")
|
||||
print(f"序列长度: {config['seq_len']} 分钟")
|
||||
print(f"批次大小: {config['batch_size']}")
|
||||
print(f"特征: {config['features']}")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. 加载数据
|
||||
print("\n[1/6] 加载 V2 数据...")
|
||||
date_data = load_data_by_date(args.data_dir, config['features'])
|
||||
|
||||
# 2. 划分数据集
|
||||
print("\n[2/6] 划分数据集...")
|
||||
train_data, val_data, test_data = split_data_by_date(
|
||||
date_data, config['train_end_date'], config['val_end_date']
|
||||
)
|
||||
|
||||
# 3. 构建序列
|
||||
print("\n[3/6] 构建序列...")
|
||||
print("训练集:")
|
||||
train_sequences = build_sequences_by_concept(
|
||||
train_data, config['features'], config['seq_len'], config['stride']
|
||||
)
|
||||
print("验证集:")
|
||||
val_sequences = build_sequences_by_concept(
|
||||
val_data, config['features'], config['seq_len'], config['stride']
|
||||
)
|
||||
|
||||
if len(train_sequences) == 0:
|
||||
print("错误: 训练集为空!")
|
||||
return
|
||||
|
||||
# 4. 预处理
|
||||
print("\n[4/6] 数据预处理...")
|
||||
clip_value = config['clip_value']
|
||||
print(f" Z-Score 特征已标准化,截断: ±{clip_value}")
|
||||
|
||||
train_sequences = np.clip(train_sequences, -clip_value, clip_value)
|
||||
if len(val_sequences) > 0:
|
||||
val_sequences = np.clip(val_sequences, -clip_value, clip_value)
|
||||
|
||||
# 保存配置
|
||||
save_dir = Path(args.save_dir)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(save_dir / 'config.json', 'w') as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
# 5. 创建数据加载器
|
||||
print("\n[5/6] 创建数据加载器...")
|
||||
train_dataset = SequenceDataset(train_sequences)
|
||||
val_dataset = SequenceDataset(val_sequences) if len(val_sequences) > 0 else None
|
||||
|
||||
print(f" 训练序列: {len(train_dataset):,}")
|
||||
print(f" 验证序列: {len(val_dataset) if val_dataset else 0:,}")
|
||||
|
||||
n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
|
||||
num_workers = min(32, 8 * n_gpus) if sys.platform != 'win32' else 0
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config['batch_size'],
|
||||
shuffle=True,
|
||||
num_workers=num_workers,
|
||||
pin_memory=True,
|
||||
prefetch_factor=4 if num_workers > 0 else None,
|
||||
persistent_workers=True if num_workers > 0 else False,
|
||||
drop_last=True
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=config['batch_size'] * 2,
|
||||
shuffle=False,
|
||||
num_workers=num_workers,
|
||||
pin_memory=True,
|
||||
) if val_dataset else None
|
||||
|
||||
# 6. 训练
|
||||
print("\n[6/6] 训练模型...")
|
||||
model = TransformerAutoencoder(**config['model'])
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
print(f" 使用 {torch.cuda.device_count()} 张 GPU 并行训练")
|
||||
model = nn.DataParallel(model)
|
||||
|
||||
if val_loader is None:
|
||||
print("警告: 验证集为空,使用训练集的 10% 作为验证")
|
||||
split_idx = int(len(train_dataset) * 0.9)
|
||||
train_subset = torch.utils.data.Subset(train_dataset, range(split_idx))
|
||||
val_subset = torch.utils.data.Subset(train_dataset, range(split_idx, len(train_dataset)))
|
||||
train_loader = DataLoader(train_subset, batch_size=config['batch_size'], shuffle=True, num_workers=num_workers, pin_memory=True)
|
||||
val_loader = DataLoader(val_subset, batch_size=config['batch_size'], shuffle=False, num_workers=num_workers, pin_memory=True)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
config=config,
|
||||
device=device,
|
||||
save_dir=args.save_dir
|
||||
)
|
||||
|
||||
trainer.train(config['epochs'])
|
||||
|
||||
# 计算阈值
|
||||
print("\n[额外] 计算异动阈值...")
|
||||
best_checkpoint = torch.load(save_dir / 'best_model.pt', map_location=device)
|
||||
|
||||
# 创建新的单 GPU 模型用于计算阈值(避免 DataParallel 问题)
|
||||
threshold_model = TransformerAutoencoder(**config['model'])
|
||||
threshold_model.load_state_dict(best_checkpoint['model_state_dict'])
|
||||
threshold_model.to(device)
|
||||
threshold_model.eval()
|
||||
|
||||
thresholds = compute_thresholds(threshold_model, val_loader, device, config['threshold_percentiles'])
|
||||
|
||||
with open(save_dir / 'thresholds.json', 'w') as f:
|
||||
json.dump(thresholds, f, indent=2)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("训练完成!")
|
||||
print(f"模型保存位置: {args.save_dir}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
132
ml/update_baseline.py
Normal file
132
ml/update_baseline.py
Normal file
@@ -0,0 +1,132 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
每日盘后运行:更新滚动基线
|
||||
|
||||
使用方法:
|
||||
python ml/update_baseline.py
|
||||
|
||||
建议加入 crontab,每天 15:30 后运行:
|
||||
30 15 * * 1-5 cd /path/to/project && python ml/update_baseline.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import pickle
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from ml.prepare_data_v2 import (
|
||||
get_all_concepts, get_trading_days, compute_raw_concept_features,
|
||||
init_process_connections, CONFIG, RAW_CACHE_DIR, BASELINE_DIR
|
||||
)
|
||||
|
||||
|
||||
def update_rolling_baseline(baseline_days: int = 20):
|
||||
"""
|
||||
更新滚动基线(用于实盘检测)
|
||||
|
||||
基线 = 最近 N 个交易日每个时间片的统计量
|
||||
"""
|
||||
print("=" * 60)
|
||||
print("更新滚动基线(用于实盘)")
|
||||
print("=" * 60)
|
||||
|
||||
# 初始化连接
|
||||
init_process_connections()
|
||||
|
||||
# 获取概念列表
|
||||
concepts = get_all_concepts()
|
||||
all_stocks = list(set(s for c in concepts for s in c['stocks']))
|
||||
|
||||
# 获取最近的交易日
|
||||
today = datetime.now().strftime('%Y-%m-%d')
|
||||
start_date = (datetime.now() - timedelta(days=60)).strftime('%Y-%m-%d') # 多取一些
|
||||
|
||||
trading_days = get_trading_days(start_date, today)
|
||||
|
||||
if len(trading_days) < baseline_days:
|
||||
print(f"错误:交易日不足 {baseline_days} 天")
|
||||
return
|
||||
|
||||
# 只取最近 N 天
|
||||
recent_days = trading_days[-baseline_days:]
|
||||
print(f"使用 {len(recent_days)} 天数据: {recent_days[0]} ~ {recent_days[-1]}")
|
||||
|
||||
# 加载原始数据
|
||||
all_data = []
|
||||
for trade_date in tqdm(recent_days, desc="加载数据"):
|
||||
cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
|
||||
|
||||
if os.path.exists(cache_file):
|
||||
df = pd.read_parquet(cache_file)
|
||||
else:
|
||||
df = compute_raw_concept_features(trade_date, concepts, all_stocks)
|
||||
|
||||
if not df.empty:
|
||||
all_data.append(df)
|
||||
|
||||
if not all_data:
|
||||
print("错误:无数据")
|
||||
return
|
||||
|
||||
combined = pd.concat(all_data, ignore_index=True)
|
||||
print(f"总数据量: {len(combined):,} 条")
|
||||
|
||||
# 按概念计算基线
|
||||
baselines = {}
|
||||
|
||||
for concept_id, group in tqdm(combined.groupby('concept_id'), desc="计算基线"):
|
||||
baseline_dict = {}
|
||||
|
||||
for time_slot, slot_group in group.groupby('time_slot'):
|
||||
if len(slot_group) < CONFIG['min_baseline_samples']:
|
||||
continue
|
||||
|
||||
alpha_std = slot_group['alpha'].std()
|
||||
amt_std = slot_group['total_amt'].std()
|
||||
rank_std = slot_group['rank_pct'].std()
|
||||
|
||||
baseline_dict[time_slot] = {
|
||||
'alpha_mean': float(slot_group['alpha'].mean()),
|
||||
'alpha_std': float(max(alpha_std if pd.notna(alpha_std) else 1.0, 0.1)),
|
||||
'amt_mean': float(slot_group['total_amt'].mean()),
|
||||
'amt_std': float(max(amt_std if pd.notna(amt_std) else slot_group['total_amt'].mean() * 0.5, 1.0)),
|
||||
'rank_mean': float(slot_group['rank_pct'].mean()),
|
||||
'rank_std': float(max(rank_std if pd.notna(rank_std) else 0.2, 0.05)),
|
||||
'sample_count': len(slot_group),
|
||||
}
|
||||
|
||||
if baseline_dict:
|
||||
baselines[concept_id] = baseline_dict
|
||||
|
||||
print(f"计算了 {len(baselines)} 个概念的基线")
|
||||
|
||||
# 保存
|
||||
os.makedirs(BASELINE_DIR, exist_ok=True)
|
||||
baseline_file = os.path.join(BASELINE_DIR, 'realtime_baseline.pkl')
|
||||
|
||||
with open(baseline_file, 'wb') as f:
|
||||
pickle.dump({
|
||||
'baselines': baselines,
|
||||
'update_time': datetime.now().isoformat(),
|
||||
'date_range': [recent_days[0], recent_days[-1]],
|
||||
'baseline_days': baseline_days,
|
||||
}, f)
|
||||
|
||||
print(f"基线已保存: {baseline_file}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--days', type=int, default=20, help='基线天数')
|
||||
args = parser.parse_args()
|
||||
|
||||
update_rolling_baseline(args.days)
|
||||
@@ -131,12 +131,14 @@
|
||||
"eslint-plugin-prettier": "3.4.0",
|
||||
"gulp": "4.0.2",
|
||||
"gulp-append-prepend": "1.0.9",
|
||||
"husky": "^9.1.7",
|
||||
"imagemin": "^9.0.1",
|
||||
"imagemin-mozjpeg": "^10.0.0",
|
||||
"imagemin-pngquant": "^10.0.0",
|
||||
"kill-port": "^2.0.1",
|
||||
"less": "^4.4.2",
|
||||
"less-loader": "^12.3.0",
|
||||
"lint-staged": "^16.2.7",
|
||||
"msw": "^2.11.5",
|
||||
"prettier": "2.2.1",
|
||||
"react-error-overlay": "6.0.9",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// src/views/Community/components/StockDetailPanel/components/MiniTimelineChart.js
|
||||
// src/components/Charts/Stock/MiniTimelineChart.js
|
||||
import React, { useState, useEffect, useMemo, useRef, useCallback } from 'react';
|
||||
import ReactECharts from 'echarts-for-react';
|
||||
import * as echarts from 'echarts';
|
||||
4
src/components/Charts/Stock/hooks/index.js
Normal file
4
src/components/Charts/Stock/hooks/index.js
Normal file
@@ -0,0 +1,4 @@
|
||||
// src/components/Charts/Stock/hooks/index.js
|
||||
// 股票图表 Hooks 统一导出
|
||||
|
||||
export { useEventStocks } from './useEventStocks';
|
||||
@@ -1,4 +1,4 @@
|
||||
// src/views/Community/components/StockDetailPanel/hooks/useEventStocks.js
|
||||
// src/components/Charts/Stock/hooks/useEventStocks.js
|
||||
import { useSelector, useDispatch, shallowEqual } from 'react-redux';
|
||||
import { useEffect, useCallback, useMemo } from 'react';
|
||||
import {
|
||||
@@ -8,8 +8,8 @@ import {
|
||||
fetchHistoricalEvents,
|
||||
fetchChainAnalysis,
|
||||
fetchExpectationScore
|
||||
} from '../../../../../store/slices/stockSlice';
|
||||
import { logger } from '../../../../../utils/logger';
|
||||
} from '@store/slices/stockSlice';
|
||||
import { logger } from '@utils/logger';
|
||||
|
||||
/**
|
||||
* 事件股票数据 Hook
|
||||
5
src/components/Charts/Stock/index.js
Normal file
5
src/components/Charts/Stock/index.js
Normal file
@@ -0,0 +1,5 @@
|
||||
// src/components/Charts/Stock/index.js
|
||||
// 股票图表组件统一导出
|
||||
|
||||
export { default as MiniTimelineChart } from './MiniTimelineChart';
|
||||
export { useEventStocks } from './hooks/useEventStocks';
|
||||
@@ -121,7 +121,7 @@ const CitationMark = ({ citationId, citation }) => {
|
||||
title={`引用来源 [${citationId}]`}
|
||||
trigger={triggerType}
|
||||
placement="top"
|
||||
overlayInnerStyle={{ maxWidth: 340, padding: '8px' }}
|
||||
styles={{ body: { maxWidth: 340, padding: '8px' } }}
|
||||
open={popoverVisible}
|
||||
onOpenChange={setPopoverVisible}
|
||||
zIndex={2000}
|
||||
|
||||
@@ -146,7 +146,7 @@ const CitedContent = ({
|
||||
</div>
|
||||
|
||||
{/* 响应式样式 */}
|
||||
<style jsx>{`
|
||||
<style>{`
|
||||
@media (max-width: 768px) {
|
||||
.ai-badge-responsive {
|
||||
font-size: 10px !important;
|
||||
|
||||
@@ -10,7 +10,7 @@ import {
|
||||
useColorModeValue,
|
||||
} from '@chakra-ui/react';
|
||||
import { ViewIcon } from '@chakra-ui/icons';
|
||||
import EventFollowButton from '@views/Community/components/EventCard/EventFollowButton';
|
||||
import { EventFollowButton } from '@views/Community/components/EventCard/atoms';
|
||||
|
||||
/**
|
||||
* 精简信息栏组件
|
||||
|
||||
@@ -16,7 +16,7 @@ import {
|
||||
} from '@chakra-ui/react';
|
||||
import { getImportanceConfig } from '@constants/importanceLevels';
|
||||
import { eventService } from '@services/eventService';
|
||||
import { useEventStocks } from '@views/Community/components/StockDetailPanel/hooks/useEventStocks';
|
||||
import { useEventStocks } from '@components/Charts/Stock';
|
||||
import { toggleEventFollow, selectEventFollowStatus } from '@store/slices/communityDataSlice';
|
||||
import { useAuth } from '@contexts/AuthContext';
|
||||
import EventHeaderInfo from './EventHeaderInfo';
|
||||
|
||||
@@ -15,7 +15,7 @@ import {
|
||||
import { ViewIcon } from '@chakra-ui/icons';
|
||||
import dayjs from 'dayjs';
|
||||
import StockChangeIndicators from '../StockChangeIndicators';
|
||||
import EventFollowButton from '@views/Community/components/EventCard/EventFollowButton';
|
||||
import { EventFollowButton } from '@views/Community/components/EventCard/atoms';
|
||||
|
||||
/**
|
||||
* 事件头部信息区组件
|
||||
|
||||
@@ -20,7 +20,7 @@ import { StarIcon } from '@chakra-ui/icons';
|
||||
import { Tag } from 'antd';
|
||||
import { RobotOutlined } from '@ant-design/icons';
|
||||
import { selectIsMobile } from '@store/slices/deviceSlice';
|
||||
import MiniTimelineChart from '@views/Community/components/StockDetailPanel/components/MiniTimelineChart';
|
||||
import { MiniTimelineChart } from '@components/Charts/Stock';
|
||||
import MiniKLineChart from './MiniKLineChart';
|
||||
import TimelineChartModal from '@components/StockChart/TimelineChartModal';
|
||||
import KLineChartModal from '@components/StockChart/KLineChartModal';
|
||||
|
||||
84
src/components/FavoriteButton/index.tsx
Normal file
84
src/components/FavoriteButton/index.tsx
Normal file
@@ -0,0 +1,84 @@
|
||||
/**
|
||||
* FavoriteButton - 通用关注/收藏按钮组件(图标按钮)
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { IconButton, Tooltip, Spinner } from '@chakra-ui/react';
|
||||
import { Star } from 'lucide-react';
|
||||
|
||||
export interface FavoriteButtonProps {
|
||||
/** 是否已关注 */
|
||||
isFavorite: boolean;
|
||||
/** 加载状态 */
|
||||
isLoading?: boolean;
|
||||
/** 点击回调 */
|
||||
onClick: () => void;
|
||||
/** 按钮大小 */
|
||||
size?: 'sm' | 'md' | 'lg';
|
||||
/** 颜色主题 */
|
||||
colorScheme?: 'gold' | 'default';
|
||||
/** 是否显示 tooltip */
|
||||
showTooltip?: boolean;
|
||||
}
|
||||
|
||||
// 颜色配置
|
||||
const COLORS = {
|
||||
gold: {
|
||||
active: '#F4D03F', // 已关注 - 亮金色
|
||||
inactive: '#C9A961', // 未关注 - 暗金色
|
||||
hoverBg: 'whiteAlpha.100',
|
||||
},
|
||||
default: {
|
||||
active: 'yellow.400',
|
||||
inactive: 'gray.400',
|
||||
hoverBg: 'gray.100',
|
||||
},
|
||||
};
|
||||
|
||||
const FavoriteButton: React.FC<FavoriteButtonProps> = ({
|
||||
isFavorite,
|
||||
isLoading = false,
|
||||
onClick,
|
||||
size = 'sm',
|
||||
colorScheme = 'gold',
|
||||
showTooltip = true,
|
||||
}) => {
|
||||
const colors = COLORS[colorScheme];
|
||||
const currentColor = isFavorite ? colors.active : colors.inactive;
|
||||
const label = isFavorite ? '取消关注' : '加入自选';
|
||||
|
||||
const iconButton = (
|
||||
<IconButton
|
||||
aria-label={label}
|
||||
icon={
|
||||
isLoading ? (
|
||||
<Spinner size="sm" color={currentColor} />
|
||||
) : (
|
||||
<Star
|
||||
size={size === 'sm' ? 18 : size === 'md' ? 20 : 24}
|
||||
fill={isFavorite ? currentColor : 'none'}
|
||||
stroke={currentColor}
|
||||
/>
|
||||
)
|
||||
}
|
||||
variant="ghost"
|
||||
color={currentColor}
|
||||
size={size}
|
||||
onClick={onClick}
|
||||
isDisabled={isLoading}
|
||||
_hover={{ bg: colors.hoverBg }}
|
||||
/>
|
||||
);
|
||||
|
||||
if (showTooltip) {
|
||||
return (
|
||||
<Tooltip label={label} placement="top">
|
||||
{iconButton}
|
||||
</Tooltip>
|
||||
);
|
||||
}
|
||||
|
||||
return iconButton;
|
||||
};
|
||||
|
||||
export default FavoriteButton;
|
||||
@@ -1,5 +1,7 @@
|
||||
// src/components/InvestmentCalendar/index.js
|
||||
import React, { useState, useEffect, useCallback } from 'react';
|
||||
import React, { useState, useEffect, useCallback, useRef } from 'react';
|
||||
import { useSelector, useDispatch } from 'react-redux';
|
||||
import { loadWatchlist, toggleWatchlist } from '@store/slices/stockSlice';
|
||||
import {
|
||||
Card, Calendar, Badge, Modal, Table, Tabs, Tag, Button, List, Spin, Empty,
|
||||
Drawer, Typography, Divider, Space, Tooltip, message, Alert
|
||||
@@ -24,6 +26,10 @@ const { TabPane } = Tabs;
|
||||
const { Text, Title, Paragraph } = Typography;
|
||||
|
||||
const InvestmentCalendar = () => {
|
||||
// Redux 状态
|
||||
const dispatch = useDispatch();
|
||||
const reduxWatchlist = useSelector(state => state.stock.watchlist);
|
||||
|
||||
// 权限控制
|
||||
const { hasFeatureAccess, getUpgradeRecommendation } = useSubscription();
|
||||
const [upgradeModalOpen, setUpgradeModalOpen] = useState(false);
|
||||
@@ -45,7 +51,6 @@ const InvestmentCalendar = () => {
|
||||
const [selectedStock, setSelectedStock] = useState(null);
|
||||
const [selectedEventTime, setSelectedEventTime] = useState(null); // 记录事件时间
|
||||
const [followingIds, setFollowingIds] = useState([]); // 正在处理关注的事件ID列表
|
||||
const [addingToWatchlist, setAddingToWatchlist] = useState({}); // 正在添加到自选的股票代码
|
||||
const [expandedReasons, setExpandedReasons] = useState({}); // 跟踪每个股票关联理由的展开状态
|
||||
|
||||
// 加载月度事件统计
|
||||
@@ -174,10 +179,29 @@ const InvestmentCalendar = () => {
|
||||
}
|
||||
};
|
||||
|
||||
// 使用 ref 确保只加载一次自选股
|
||||
const watchlistLoadedRef = useRef(false);
|
||||
|
||||
// 组件挂载时加载自选股列表(仅加载一次)
|
||||
useEffect(() => {
|
||||
if (!watchlistLoadedRef.current) {
|
||||
watchlistLoadedRef.current = true;
|
||||
dispatch(loadWatchlist());
|
||||
}
|
||||
}, [dispatch]);
|
||||
|
||||
useEffect(() => {
|
||||
loadEventCounts(currentMonth);
|
||||
}, [currentMonth, loadEventCounts]);
|
||||
|
||||
// 检查股票是否已在自选中
|
||||
const isStockInWatchlist = useCallback((stockCode) => {
|
||||
const sixDigitCode = getSixDigitCode(stockCode);
|
||||
return reduxWatchlist.some(item =>
|
||||
getSixDigitCode(item.stock_code) === sixDigitCode
|
||||
);
|
||||
}, [reduxWatchlist]);
|
||||
|
||||
// 自定义日期单元格渲染(Ant Design 5.x API)
|
||||
const cellRender = (current, info) => {
|
||||
// 只处理日期单元格,月份单元格返回默认
|
||||
@@ -220,7 +244,12 @@ const InvestmentCalendar = () => {
|
||||
};
|
||||
|
||||
// 处理日期选择
|
||||
const handleDateSelect = (value) => {
|
||||
// info.source 区分选择来源:'date' = 点击日期,'month'/'year' = 切换月份/年份
|
||||
const handleDateSelect = (value, info) => {
|
||||
// 只有点击日期单元格时才打开弹窗,切换月份/年份时不打开
|
||||
if (info?.source !== 'date') {
|
||||
return;
|
||||
}
|
||||
setSelectedDate(value);
|
||||
loadDateEvents(value);
|
||||
setModalVisible(true);
|
||||
@@ -379,42 +408,35 @@ const InvestmentCalendar = () => {
|
||||
}
|
||||
};
|
||||
|
||||
// 添加单只股票到自选(支持新旧格式)
|
||||
// 添加单只股票到自选(乐观更新,无需 loading 状态)
|
||||
const addSingleToWatchlist = async (stock) => {
|
||||
// 兼容新旧格式
|
||||
const code = stock.code || stock[0];
|
||||
const name = stock.name || stock[1];
|
||||
const stockCode = getSixDigitCode(code);
|
||||
|
||||
setAddingToWatchlist(prev => ({ ...prev, [stockCode]: true }));
|
||||
// 检查是否已在自选中
|
||||
if (isStockInWatchlist(code)) {
|
||||
message.info(`${name} 已在自选中`);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch('/api/account/watchlist', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
credentials: 'include',
|
||||
body: JSON.stringify({
|
||||
stock_code: stockCode, // 使用六位代码
|
||||
stock_name: name // 股票名称
|
||||
})
|
||||
});
|
||||
// 乐观更新:dispatch 后 Redux 立即更新状态,UI 立即响应
|
||||
await dispatch(toggleWatchlist({
|
||||
stockCode,
|
||||
stockName: name,
|
||||
isInWatchlist: false // false 表示添加
|
||||
})).unwrap();
|
||||
|
||||
const data = await response.json();
|
||||
if (data.success) {
|
||||
message.success(`已将 ${name}(${stockCode}) 添加到自选`);
|
||||
} else {
|
||||
message.error(data.error || '添加失败');
|
||||
}
|
||||
message.success(`已将 ${name}(${stockCode}) 添加到自选`);
|
||||
} catch (error) {
|
||||
// 失败时 Redux 会自动回滚状态
|
||||
logger.error('InvestmentCalendar', 'addSingleToWatchlist', error, {
|
||||
stockCode,
|
||||
stockName: name
|
||||
});
|
||||
message.error('添加失败,请重试');
|
||||
} finally {
|
||||
setAddingToWatchlist(prev => ({ ...prev, [stockCode]: false }));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -524,7 +546,9 @@ const InvestmentCalendar = () => {
|
||||
{concepts && concepts.length > 0 ? (
|
||||
concepts.slice(0, 3).map((concept, index) => (
|
||||
<Tag key={index} icon={<TagsOutlined />}>
|
||||
{Array.isArray(concept) ? concept[0] : concept}
|
||||
{typeof concept === 'string'
|
||||
? concept
|
||||
: (concept?.concept || concept?.name || '未知')}
|
||||
</Tag>
|
||||
))
|
||||
) : (
|
||||
@@ -791,17 +815,16 @@ const InvestmentCalendar = () => {
|
||||
key: 'action',
|
||||
width: 100,
|
||||
render: (_, record) => {
|
||||
const stockCode = getSixDigitCode(record.code);
|
||||
const isAdding = addingToWatchlist[stockCode] || false;
|
||||
const inWatchlist = isStockInWatchlist(record.code);
|
||||
|
||||
return (
|
||||
<Button
|
||||
type="default"
|
||||
type={inWatchlist ? "primary" : "default"}
|
||||
size="small"
|
||||
loading={isAdding}
|
||||
onClick={() => addSingleToWatchlist(record)}
|
||||
disabled={inWatchlist}
|
||||
>
|
||||
加自选
|
||||
{inWatchlist ? '已关注' : '加自选'}
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
@@ -919,7 +942,7 @@ const InvestmentCalendar = () => {
|
||||
<Table
|
||||
dataSource={selectedStocks}
|
||||
columns={stockColumns}
|
||||
rowKey={(record) => record[0]}
|
||||
rowKey={(record) => record.code}
|
||||
size="middle"
|
||||
pagination={false}
|
||||
/>
|
||||
|
||||
@@ -313,12 +313,29 @@ const StockChartAntdModal = ({
|
||||
axisPointer: { type: 'cross' },
|
||||
formatter: function(params) {
|
||||
const d = params[0]?.dataIndex ?? 0;
|
||||
const priceChangePercent = ((prices[d] - prevClose) / prevClose * 100);
|
||||
const avgChangePercent = ((avgPrices[d] - prevClose) / prevClose * 100);
|
||||
const price = prices[d];
|
||||
const avgPrice = avgPrices[d];
|
||||
const volume = volumes[d];
|
||||
|
||||
// 安全计算涨跌幅,处理 undefined/null/0 的情况
|
||||
const safeCalcPercent = (val, base) => {
|
||||
if (val == null || base == null || base === 0) return 0;
|
||||
return ((val - base) / base * 100);
|
||||
};
|
||||
|
||||
const priceChangePercent = safeCalcPercent(price, prevClose);
|
||||
const avgChangePercent = safeCalcPercent(avgPrice, prevClose);
|
||||
const priceColor = priceChangePercent >= 0 ? '#ef5350' : '#26a69a';
|
||||
const avgColor = avgChangePercent >= 0 ? '#ef5350' : '#26a69a';
|
||||
|
||||
return `时间:${times[d]}<br/>现价:<span style="color: ${priceColor}">¥${prices[d]?.toFixed(2)} (${priceChangePercent >= 0 ? '+' : ''}${priceChangePercent.toFixed(2)}%)</span><br/>均价:<span style="color: ${avgColor}">¥${avgPrices[d]?.toFixed(2)} (${avgChangePercent >= 0 ? '+' : ''}${avgChangePercent.toFixed(2)}%)</span><br/>昨收:¥${prevClose?.toFixed(2)}<br/>成交量:${Math.round(volumes[d]/100)}手`;
|
||||
// 安全格式化数字
|
||||
const safeFixed = (val, digits = 2) => (val != null && !isNaN(val)) ? val.toFixed(digits) : '-';
|
||||
const formatPercent = (val) => {
|
||||
if (val == null || isNaN(val)) return '-';
|
||||
return (val >= 0 ? '+' : '') + val.toFixed(2) + '%';
|
||||
};
|
||||
|
||||
return `时间:${times[d] || '-'}<br/>现价:<span style="color: ${priceColor}">¥${safeFixed(price)} (${formatPercent(priceChangePercent)})</span><br/>均价:<span style="color: ${avgColor}">¥${safeFixed(avgPrice)} (${formatPercent(avgChangePercent)})</span><br/>昨收:¥${safeFixed(prevClose)}<br/>成交量:${volume != null ? Math.round(volume/100) + '手' : '-'}`;
|
||||
}
|
||||
},
|
||||
grid: [
|
||||
@@ -337,6 +354,7 @@ const StockChartAntdModal = ({
|
||||
position: 'left',
|
||||
axisLabel: {
|
||||
formatter: function(value) {
|
||||
if (value == null || isNaN(value)) return '-';
|
||||
return (value >= 0 ? '+' : '') + value.toFixed(2) + '%';
|
||||
}
|
||||
},
|
||||
@@ -354,11 +372,12 @@ const StockChartAntdModal = ({
|
||||
position: 'right',
|
||||
axisLabel: {
|
||||
formatter: function(value) {
|
||||
if (value == null || isNaN(value)) return '-';
|
||||
return (value >= 0 ? '+' : '') + value.toFixed(2) + '%';
|
||||
}
|
||||
}
|
||||
},
|
||||
{ type: 'value', gridIndex: 1, scale: true, axisLabel: { formatter: v => Math.round(v/100) + '手' } }
|
||||
{ type: 'value', gridIndex: 1, scale: true, axisLabel: { formatter: v => (v != null && !isNaN(v)) ? Math.round(v/100) + '手' : '-' } }
|
||||
],
|
||||
dataZoom: [
|
||||
{ type: 'inside', xAxisIndex: [0, 1], start: 0, end: 100 },
|
||||
|
||||
@@ -217,27 +217,34 @@ const TimelineChartModal: React.FC<TimelineChartModalProps> = ({
|
||||
if (dataIndex === undefined) return '';
|
||||
|
||||
const item = data[dataIndex];
|
||||
const changeColor = item.change_percent >= 0 ? '#ef5350' : '#26a69a';
|
||||
const changeSign = item.change_percent >= 0 ? '+' : '';
|
||||
if (!item) return '';
|
||||
|
||||
// 安全格式化数字
|
||||
const safeFixed = (val: any, digits = 2) =>
|
||||
val != null && !isNaN(val) ? Number(val).toFixed(digits) : '-';
|
||||
|
||||
const changePercent = item.change_percent ?? 0;
|
||||
const changeColor = changePercent >= 0 ? '#ef5350' : '#26a69a';
|
||||
const changeSign = changePercent >= 0 ? '+' : '';
|
||||
|
||||
return `
|
||||
<div style="padding: 8px;">
|
||||
<div style="font-weight: bold; margin-bottom: 8px;">${item.time}</div>
|
||||
<div style="font-weight: bold; margin-bottom: 8px;">${item.time || '-'}</div>
|
||||
<div style="display: flex; justify-content: space-between; margin-bottom: 4px;">
|
||||
<span>价格:</span>
|
||||
<span style="color: ${changeColor}; font-weight: bold; margin-left: 20px;">${item.price.toFixed(2)}</span>
|
||||
<span style="color: ${changeColor}; font-weight: bold; margin-left: 20px;">${safeFixed(item.price)}</span>
|
||||
</div>
|
||||
<div style="display: flex; justify-content: space-between; margin-bottom: 4px;">
|
||||
<span>均价:</span>
|
||||
<span style="color: #ffa726; margin-left: 20px;">${item.avg_price.toFixed(2)}</span>
|
||||
<span style="color: #ffa726; margin-left: 20px;">${safeFixed(item.avg_price)}</span>
|
||||
</div>
|
||||
<div style="display: flex; justify-content: space-between; margin-bottom: 4px;">
|
||||
<span>涨跌幅:</span>
|
||||
<span style="color: ${changeColor}; margin-left: 20px;">${changeSign}${item.change_percent.toFixed(2)}%</span>
|
||||
<span style="color: ${changeColor}; margin-left: 20px;">${changeSign}${safeFixed(changePercent)}%</span>
|
||||
</div>
|
||||
<div style="display: flex; justify-content: space-between;">
|
||||
<span>成交量:</span>
|
||||
<span style="margin-left: 20px;">${(item.volume / 100).toFixed(0)}手</span>
|
||||
<span style="margin-left: 20px;">${item.volume != null ? (item.volume / 100).toFixed(0) : '-'}手</span>
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
@@ -314,7 +321,7 @@ const TimelineChartModal: React.FC<TimelineChartModalProps> = ({
|
||||
axisLabel: {
|
||||
color: '#999',
|
||||
fontSize: isMobile ? 10 : 12,
|
||||
formatter: (value: number) => value.toFixed(2),
|
||||
formatter: (value: number) => (value != null && !isNaN(value)) ? value.toFixed(2) : '-',
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -333,6 +340,7 @@ const TimelineChartModal: React.FC<TimelineChartModalProps> = ({
|
||||
color: '#999',
|
||||
fontSize: isMobile ? 10 : 12,
|
||||
formatter: (value: number) => {
|
||||
if (value == null || isNaN(value)) return '-';
|
||||
if (value >= 10000) {
|
||||
return (value / 10000).toFixed(1) + '万';
|
||||
}
|
||||
|
||||
232
src/components/SubTabContainer/index.tsx
Normal file
232
src/components/SubTabContainer/index.tsx
Normal file
@@ -0,0 +1,232 @@
|
||||
/**
|
||||
* SubTabContainer - 二级导航容器组件
|
||||
*
|
||||
* 用于模块内的子功能切换(如公司档案下的股权结构、管理团队等)
|
||||
* 与 TabContainer(一级导航)区分:无 Card 包裹,直接融入父容器
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* <SubTabContainer
|
||||
* tabs={[
|
||||
* { key: 'tab1', name: 'Tab 1', icon: FaHome, component: Tab1 },
|
||||
* { key: 'tab2', name: 'Tab 2', icon: FaUser, component: Tab2 },
|
||||
* ]}
|
||||
* componentProps={{ stockCode: '000001' }}
|
||||
* onTabChange={(index, key) => console.log('切换到', key)}
|
||||
* />
|
||||
* ```
|
||||
*/
|
||||
|
||||
import React, { useState, useCallback, memo } from 'react';
|
||||
import {
|
||||
Box,
|
||||
Tabs,
|
||||
TabList,
|
||||
TabPanels,
|
||||
Tab,
|
||||
TabPanel,
|
||||
Icon,
|
||||
HStack,
|
||||
Text,
|
||||
Spacer,
|
||||
} from '@chakra-ui/react';
|
||||
import type { ComponentType } from 'react';
|
||||
import type { IconType } from 'react-icons';
|
||||
|
||||
/**
|
||||
* Tab 配置项
|
||||
*/
|
||||
export interface SubTabConfig {
|
||||
key: string;
|
||||
name: string;
|
||||
icon?: IconType | ComponentType;
|
||||
component?: ComponentType<any>;
|
||||
}
|
||||
|
||||
/**
|
||||
* 主题配置
|
||||
*/
|
||||
export interface SubTabTheme {
|
||||
bg: string;
|
||||
borderColor: string;
|
||||
tabSelectedBg: string;
|
||||
tabSelectedColor: string;
|
||||
tabUnselectedColor: string;
|
||||
tabHoverBg: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* 预设主题
|
||||
*/
|
||||
const THEME_PRESETS: Record<string, SubTabTheme> = {
|
||||
blackGold: {
|
||||
bg: 'gray.900',
|
||||
borderColor: 'rgba(212, 175, 55, 0.3)',
|
||||
tabSelectedBg: '#D4AF37',
|
||||
tabSelectedColor: 'gray.900',
|
||||
tabUnselectedColor: '#D4AF37',
|
||||
tabHoverBg: 'gray.600',
|
||||
},
|
||||
default: {
|
||||
bg: 'white',
|
||||
borderColor: 'gray.200',
|
||||
tabSelectedBg: 'blue.500',
|
||||
tabSelectedColor: 'white',
|
||||
tabUnselectedColor: 'gray.600',
|
||||
tabHoverBg: 'gray.100',
|
||||
},
|
||||
};
|
||||
|
||||
export interface SubTabContainerProps {
|
||||
/** Tab 配置数组 */
|
||||
tabs: SubTabConfig[];
|
||||
/** 传递给 Tab 内容组件的 props */
|
||||
componentProps?: Record<string, any>;
|
||||
/** 默认选中的 Tab 索引 */
|
||||
defaultIndex?: number;
|
||||
/** 受控模式下的当前索引 */
|
||||
index?: number;
|
||||
/** Tab 变更回调 */
|
||||
onTabChange?: (index: number, tabKey: string) => void;
|
||||
/** 主题预设 */
|
||||
themePreset?: 'blackGold' | 'default';
|
||||
/** 自定义主题(优先级高于预设) */
|
||||
theme?: Partial<SubTabTheme>;
|
||||
/** 内容区内边距 */
|
||||
contentPadding?: number;
|
||||
/** 是否懒加载 */
|
||||
isLazy?: boolean;
|
||||
/** TabList 右侧自定义内容 */
|
||||
rightElement?: React.ReactNode;
|
||||
}
|
||||
|
||||
const SubTabContainer: React.FC<SubTabContainerProps> = memo(({
|
||||
tabs,
|
||||
componentProps = {},
|
||||
defaultIndex = 0,
|
||||
index: controlledIndex,
|
||||
onTabChange,
|
||||
themePreset = 'blackGold',
|
||||
theme: customTheme,
|
||||
contentPadding = 4,
|
||||
isLazy = true,
|
||||
rightElement,
|
||||
}) => {
|
||||
// 内部状态(非受控模式)
|
||||
const [internalIndex, setInternalIndex] = useState(defaultIndex);
|
||||
|
||||
// 当前索引
|
||||
const currentIndex = controlledIndex ?? internalIndex;
|
||||
|
||||
// 记录已访问的 Tab 索引(用于真正的懒加载)
|
||||
const [visitedTabs, setVisitedTabs] = useState<Set<number>>(
|
||||
() => new Set([controlledIndex ?? defaultIndex])
|
||||
);
|
||||
|
||||
// 合并主题
|
||||
const theme: SubTabTheme = {
|
||||
...THEME_PRESETS[themePreset],
|
||||
...customTheme,
|
||||
};
|
||||
|
||||
/**
|
||||
* 处理 Tab 切换
|
||||
*/
|
||||
const handleTabChange = useCallback(
|
||||
(newIndex: number) => {
|
||||
const tabKey = tabs[newIndex]?.key || '';
|
||||
onTabChange?.(newIndex, tabKey);
|
||||
|
||||
// 记录已访问的 Tab(用于懒加载)
|
||||
setVisitedTabs(prev => {
|
||||
if (prev.has(newIndex)) return prev;
|
||||
return new Set(prev).add(newIndex);
|
||||
});
|
||||
|
||||
if (controlledIndex === undefined) {
|
||||
setInternalIndex(newIndex);
|
||||
}
|
||||
},
|
||||
[tabs, onTabChange, controlledIndex]
|
||||
);
|
||||
|
||||
return (
|
||||
<Box>
|
||||
<Tabs
|
||||
isLazy={isLazy}
|
||||
variant="unstyled"
|
||||
index={currentIndex}
|
||||
onChange={handleTabChange}
|
||||
>
|
||||
<TabList
|
||||
bg={theme.bg}
|
||||
borderBottom="1px solid"
|
||||
borderColor={theme.borderColor}
|
||||
pl={0}
|
||||
pr={2}
|
||||
py={1.5}
|
||||
flexWrap="nowrap"
|
||||
gap={1}
|
||||
alignItems="center"
|
||||
overflowX="auto"
|
||||
css={{
|
||||
'&::-webkit-scrollbar': { display: 'none' },
|
||||
scrollbarWidth: 'none',
|
||||
}}
|
||||
>
|
||||
{tabs.map((tab) => (
|
||||
<Tab
|
||||
key={tab.key}
|
||||
color={theme.tabUnselectedColor}
|
||||
borderRadius="full"
|
||||
px={2.5}
|
||||
py={1.5}
|
||||
fontSize="xs"
|
||||
whiteSpace="nowrap"
|
||||
flexShrink={0}
|
||||
_selected={{
|
||||
bg: theme.tabSelectedBg,
|
||||
color: theme.tabSelectedColor,
|
||||
fontWeight: 'bold',
|
||||
}}
|
||||
_hover={{
|
||||
bg: theme.tabHoverBg,
|
||||
}}
|
||||
>
|
||||
<HStack spacing={1}>
|
||||
{tab.icon && <Icon as={tab.icon} boxSize={3} />}
|
||||
<Text>{tab.name}</Text>
|
||||
</HStack>
|
||||
</Tab>
|
||||
))}
|
||||
{rightElement && (
|
||||
<>
|
||||
<Spacer />
|
||||
<Box flexShrink={0}>{rightElement}</Box>
|
||||
</>
|
||||
)}
|
||||
</TabList>
|
||||
|
||||
<TabPanels p={contentPadding}>
|
||||
{tabs.map((tab, idx) => {
|
||||
const Component = tab.component;
|
||||
// 懒加载:只渲染已访问过的 Tab
|
||||
const shouldRender = !isLazy || visitedTabs.has(idx);
|
||||
|
||||
return (
|
||||
<TabPanel key={tab.key} p={0}>
|
||||
{shouldRender && Component ? (
|
||||
<Component {...componentProps} />
|
||||
) : null}
|
||||
</TabPanel>
|
||||
);
|
||||
})}
|
||||
</TabPanels>
|
||||
</Tabs>
|
||||
</Box>
|
||||
);
|
||||
});
|
||||
|
||||
SubTabContainer.displayName = 'SubTabContainer';
|
||||
|
||||
export default SubTabContainer;
|
||||
56
src/components/TabContainer/TabNavigation.tsx
Normal file
56
src/components/TabContainer/TabNavigation.tsx
Normal file
@@ -0,0 +1,56 @@
|
||||
/**
|
||||
* TabNavigation 通用导航组件
|
||||
*
|
||||
* 渲染 Tab 按钮列表,支持图标 + 文字
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { TabList, Tab, HStack, Icon, Text } from '@chakra-ui/react';
|
||||
import type { TabNavigationProps } from './types';
|
||||
|
||||
const TabNavigation: React.FC<TabNavigationProps> = ({
|
||||
tabs,
|
||||
themeColors,
|
||||
borderRadius = 'lg',
|
||||
}) => {
|
||||
return (
|
||||
<TabList
|
||||
bg={themeColors.bg}
|
||||
borderBottom="1px solid"
|
||||
borderColor={themeColors.dividerColor}
|
||||
borderTopLeftRadius={borderRadius}
|
||||
borderTopRightRadius={borderRadius}
|
||||
pl={0}
|
||||
pr={4}
|
||||
py={2}
|
||||
flexWrap="wrap"
|
||||
gap={2}
|
||||
>
|
||||
{tabs.map((tab) => (
|
||||
<Tab
|
||||
key={tab.key}
|
||||
color={themeColors.unselectedText}
|
||||
borderRadius="full"
|
||||
px={4}
|
||||
py={2}
|
||||
fontSize="sm"
|
||||
_selected={{
|
||||
bg: themeColors.selectedBg,
|
||||
color: themeColors.selectedText,
|
||||
fontWeight: 'bold',
|
||||
}}
|
||||
_hover={{
|
||||
bg: 'whiteAlpha.100',
|
||||
}}
|
||||
>
|
||||
<HStack spacing={2}>
|
||||
{tab.icon && <Icon as={tab.icon} boxSize={4} />}
|
||||
<Text>{tab.name}</Text>
|
||||
</HStack>
|
||||
</Tab>
|
||||
))}
|
||||
</TabList>
|
||||
);
|
||||
};
|
||||
|
||||
export default TabNavigation;
|
||||
55
src/components/TabContainer/constants.ts
Normal file
55
src/components/TabContainer/constants.ts
Normal file
@@ -0,0 +1,55 @@
|
||||
/**
|
||||
* TabContainer 常量和主题预设
|
||||
*/
|
||||
|
||||
import type { ThemeColors, ThemePreset } from './types';
|
||||
|
||||
/**
|
||||
* 主题预设配置
|
||||
*/
|
||||
export const THEME_PRESETS: Record<ThemePreset, Required<ThemeColors>> = {
|
||||
// 黑金主题(原 Company 模块风格)
|
||||
blackGold: {
|
||||
bg: '#1A202C',
|
||||
selectedBg: '#C9A961',
|
||||
selectedText: '#FFFFFF',
|
||||
unselectedText: '#D4AF37',
|
||||
dividerColor: 'gray.600',
|
||||
},
|
||||
// 默认主题(Chakra 风格)
|
||||
default: {
|
||||
bg: 'white',
|
||||
selectedBg: 'blue.500',
|
||||
selectedText: 'white',
|
||||
unselectedText: 'gray.600',
|
||||
dividerColor: 'gray.200',
|
||||
},
|
||||
// 深色主题
|
||||
dark: {
|
||||
bg: 'gray.800',
|
||||
selectedBg: 'blue.400',
|
||||
selectedText: 'white',
|
||||
unselectedText: 'gray.300',
|
||||
dividerColor: 'gray.600',
|
||||
},
|
||||
// 浅色主题
|
||||
light: {
|
||||
bg: 'gray.50',
|
||||
selectedBg: 'blue.500',
|
||||
selectedText: 'white',
|
||||
unselectedText: 'gray.700',
|
||||
dividerColor: 'gray.300',
|
||||
},
|
||||
};
|
||||
|
||||
/**
|
||||
* 默认配置
|
||||
*/
|
||||
export const DEFAULT_CONFIG = {
|
||||
themePreset: 'blackGold' as ThemePreset,
|
||||
isLazy: true,
|
||||
size: 'lg' as const,
|
||||
borderRadius: 'lg',
|
||||
shadow: 'lg',
|
||||
panelPadding: 0,
|
||||
};
|
||||
134
src/components/TabContainer/index.tsx
Normal file
134
src/components/TabContainer/index.tsx
Normal file
@@ -0,0 +1,134 @@
|
||||
/**
|
||||
* TabContainer 通用 Tab 容器组件
|
||||
*
|
||||
* 功能:
|
||||
* - 管理 Tab 切换状态(支持受控/非受控模式)
|
||||
* - 动态渲染 Tab 导航和内容
|
||||
* - 支持多种主题预设(黑金、默认、深色、浅色)
|
||||
* - 支持自定义主题颜色
|
||||
* - 支持懒加载
|
||||
*
|
||||
* @example
|
||||
* // 基础用法(传入 components)
|
||||
* <TabContainer
|
||||
* tabs={[
|
||||
* { key: 'tab1', name: 'Tab 1', icon: FaHome, component: Tab1Content },
|
||||
* { key: 'tab2', name: 'Tab 2', icon: FaUser, component: Tab2Content },
|
||||
* ]}
|
||||
* componentProps={{ userId: '123' }}
|
||||
* onTabChange={(index, key) => console.log('切换到', key)}
|
||||
* />
|
||||
*
|
||||
* @example
|
||||
* // 自定义渲染用法(使用 children)
|
||||
* <TabContainer tabs={tabs} themePreset="dark">
|
||||
* <TabPanel>自定义内容 1</TabPanel>
|
||||
* <TabPanel>自定义内容 2</TabPanel>
|
||||
* </TabContainer>
|
||||
*/
|
||||
|
||||
import React, { useState, useCallback, useMemo } from 'react';
|
||||
import {
|
||||
Card,
|
||||
CardBody,
|
||||
Tabs,
|
||||
TabPanels,
|
||||
TabPanel,
|
||||
} from '@chakra-ui/react';
|
||||
|
||||
import TabNavigation from './TabNavigation';
|
||||
import { THEME_PRESETS, DEFAULT_CONFIG } from './constants';
|
||||
import type { TabContainerProps, ThemeColors } from './types';
|
||||
|
||||
// 导出类型和常量
|
||||
export type { TabConfig, ThemeColors, ThemePreset, TabContainerProps } from './types';
|
||||
export { THEME_PRESETS } from './constants';
|
||||
|
||||
const TabContainer: React.FC<TabContainerProps> = ({
|
||||
tabs,
|
||||
componentProps = {},
|
||||
onTabChange,
|
||||
defaultIndex = 0,
|
||||
index: controlledIndex,
|
||||
themePreset = DEFAULT_CONFIG.themePreset,
|
||||
themeColors: customThemeColors,
|
||||
isLazy = DEFAULT_CONFIG.isLazy,
|
||||
size = DEFAULT_CONFIG.size,
|
||||
borderRadius = DEFAULT_CONFIG.borderRadius,
|
||||
shadow = DEFAULT_CONFIG.shadow,
|
||||
panelPadding = DEFAULT_CONFIG.panelPadding,
|
||||
children,
|
||||
}) => {
|
||||
// 内部状态(非受控模式)
|
||||
const [internalIndex, setInternalIndex] = useState(defaultIndex);
|
||||
|
||||
// 当前索引(支持受控/非受控)
|
||||
const currentIndex = controlledIndex ?? internalIndex;
|
||||
|
||||
// 合并主题颜色(自定义颜色优先)
|
||||
const themeColors: Required<ThemeColors> = useMemo(() => ({
|
||||
...THEME_PRESETS[themePreset],
|
||||
...customThemeColors,
|
||||
}), [themePreset, customThemeColors]);
|
||||
|
||||
/**
|
||||
* 处理 Tab 切换
|
||||
*/
|
||||
const handleTabChange = useCallback((newIndex: number) => {
|
||||
const tabKey = tabs[newIndex]?.key || '';
|
||||
|
||||
// 触发回调
|
||||
onTabChange?.(newIndex, tabKey, currentIndex);
|
||||
|
||||
// 非受控模式下更新内部状态
|
||||
if (controlledIndex === undefined) {
|
||||
setInternalIndex(newIndex);
|
||||
}
|
||||
}, [tabs, onTabChange, currentIndex, controlledIndex]);
|
||||
|
||||
/**
|
||||
* 渲染 Tab 内容
|
||||
*/
|
||||
const renderTabPanels = () => {
|
||||
// 如果传入了 children,直接渲染 children
|
||||
if (children) {
|
||||
return children;
|
||||
}
|
||||
|
||||
// 否则根据 tabs 配置渲染
|
||||
return tabs.map((tab) => {
|
||||
const Component = tab.component;
|
||||
return (
|
||||
<TabPanel key={tab.key} px={panelPadding} py={panelPadding}>
|
||||
{Component ? <Component {...componentProps} /> : null}
|
||||
</TabPanel>
|
||||
);
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<Card shadow={shadow} bg={themeColors.bg} borderRadius={borderRadius}>
|
||||
<CardBody p={0}>
|
||||
<Tabs
|
||||
isLazy={isLazy}
|
||||
variant="unstyled"
|
||||
size={size}
|
||||
index={currentIndex}
|
||||
onChange={handleTabChange}
|
||||
>
|
||||
{/* Tab 导航 */}
|
||||
<TabNavigation
|
||||
tabs={tabs}
|
||||
themeColors={themeColors}
|
||||
borderRadius={borderRadius}
|
||||
/>
|
||||
|
||||
{/* Tab 内容面板 */}
|
||||
<TabPanels>{renderTabPanels()}</TabPanels>
|
||||
</Tabs>
|
||||
</CardBody>
|
||||
</Card>
|
||||
);
|
||||
};
|
||||
|
||||
export default TabContainer;
|
||||
85
src/components/TabContainer/types.ts
Normal file
85
src/components/TabContainer/types.ts
Normal file
@@ -0,0 +1,85 @@
|
||||
/**
|
||||
* TabContainer 通用 Tab 容器组件类型定义
|
||||
*/
|
||||
|
||||
import type { ComponentType, ReactNode } from 'react';
|
||||
import type { IconType } from 'react-icons';
|
||||
|
||||
/**
|
||||
* Tab 配置项
|
||||
*/
|
||||
export interface TabConfig {
|
||||
/** Tab 唯一标识 */
|
||||
key: string;
|
||||
/** Tab 显示名称 */
|
||||
name: string;
|
||||
/** Tab 图标(可选) */
|
||||
icon?: IconType | ComponentType;
|
||||
/** Tab 内容组件(可选,如果不传则使用 children 渲染) */
|
||||
component?: ComponentType<any>;
|
||||
}
|
||||
|
||||
/**
|
||||
* 主题颜色配置
|
||||
*/
|
||||
export interface ThemeColors {
|
||||
/** 容器背景色 */
|
||||
bg?: string;
|
||||
/** 选中 Tab 背景色 */
|
||||
selectedBg?: string;
|
||||
/** 选中 Tab 文字颜色 */
|
||||
selectedText?: string;
|
||||
/** 未选中 Tab 文字颜色 */
|
||||
unselectedText?: string;
|
||||
/** 分割线颜色 */
|
||||
dividerColor?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* 预设主题类型
|
||||
*/
|
||||
export type ThemePreset = 'blackGold' | 'default' | 'dark' | 'light';
|
||||
|
||||
/**
|
||||
* TabContainer 组件 Props
|
||||
*/
|
||||
export interface TabContainerProps {
|
||||
/** Tab 配置数组 */
|
||||
tabs: TabConfig[];
|
||||
/** 传递给 Tab 内容组件的通用 props */
|
||||
componentProps?: Record<string, any>;
|
||||
/** Tab 变更回调 */
|
||||
onTabChange?: (index: number, tabKey: string, prevIndex: number) => void;
|
||||
/** 默认选中的 Tab 索引 */
|
||||
defaultIndex?: number;
|
||||
/** 受控模式下的当前索引 */
|
||||
index?: number;
|
||||
/** 主题预设 */
|
||||
themePreset?: ThemePreset;
|
||||
/** 自定义主题颜色(优先级高于预设) */
|
||||
themeColors?: ThemeColors;
|
||||
/** 是否启用懒加载 */
|
||||
isLazy?: boolean;
|
||||
/** Tab 尺寸 */
|
||||
size?: 'sm' | 'md' | 'lg';
|
||||
/** 容器圆角 */
|
||||
borderRadius?: string;
|
||||
/** 容器阴影 */
|
||||
shadow?: string;
|
||||
/** 自定义 Tab 面板内边距 */
|
||||
panelPadding?: number | string;
|
||||
/** 子元素(用于自定义渲染 Tab 内容) */
|
||||
children?: ReactNode;
|
||||
}
|
||||
|
||||
/**
|
||||
* TabNavigation 组件 Props
|
||||
*/
|
||||
export interface TabNavigationProps {
|
||||
/** Tab 配置数组 */
|
||||
tabs: TabConfig[];
|
||||
/** 主题颜色 */
|
||||
themeColors: Required<ThemeColors>;
|
||||
/** 容器圆角 */
|
||||
borderRadius?: string;
|
||||
}
|
||||
100
src/components/TabPanelContainer/index.tsx
Normal file
100
src/components/TabPanelContainer/index.tsx
Normal file
@@ -0,0 +1,100 @@
|
||||
/**
|
||||
* TabPanelContainer - Tab 面板通用容器组件
|
||||
*
|
||||
* 提供统一的:
|
||||
* - Loading 状态处理
|
||||
* - VStack 布局
|
||||
* - 免责声明(可选)
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* <TabPanelContainer loading={loading} showDisclaimer>
|
||||
* <YourContent />
|
||||
* </TabPanelContainer>
|
||||
* ```
|
||||
*/
|
||||
|
||||
import React, { memo } from 'react';
|
||||
import { VStack, Center, Spinner, Text, Box } from '@chakra-ui/react';
|
||||
|
||||
// 默认免责声明文案
|
||||
const DEFAULT_DISCLAIMER =
|
||||
'免责声明:本内容由AI模型基于新闻、公告、研报等公开信息自动分析和生成,未经许可严禁转载。所有内容仅供参考,不构成任何投资建议,请投资者注意风险,独立审慎决策。';
|
||||
|
||||
export interface TabPanelContainerProps {
|
||||
/** 是否处于加载状态 */
|
||||
loading?: boolean;
|
||||
/** 加载状态显示的文案 */
|
||||
loadingMessage?: string;
|
||||
/** 加载状态高度 */
|
||||
loadingHeight?: string;
|
||||
/** 子组件间距,默认 6 */
|
||||
spacing?: number;
|
||||
/** 内边距,默认 4 */
|
||||
padding?: number;
|
||||
/** 是否显示免责声明,默认 false */
|
||||
showDisclaimer?: boolean;
|
||||
/** 自定义免责声明文案 */
|
||||
disclaimerText?: string;
|
||||
/** 子组件 */
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
/**
|
||||
* 加载状态组件
|
||||
*/
|
||||
const LoadingState: React.FC<{ message: string; height: string }> = ({
|
||||
message,
|
||||
height,
|
||||
}) => (
|
||||
<Center h={height}>
|
||||
<VStack spacing={3}>
|
||||
<Spinner size="lg" color="#D4AF37" thickness="3px" />
|
||||
<Text fontSize="sm" color="gray.500">
|
||||
{message}
|
||||
</Text>
|
||||
</VStack>
|
||||
</Center>
|
||||
);
|
||||
|
||||
/**
|
||||
* 免责声明组件
|
||||
*/
|
||||
const DisclaimerText: React.FC<{ text: string }> = ({ text }) => (
|
||||
<Text mt={4} color="gray.500" fontSize="12px" lineHeight="1.5">
|
||||
{text}
|
||||
</Text>
|
||||
);
|
||||
|
||||
/**
|
||||
* Tab 面板通用容器
|
||||
*/
|
||||
const TabPanelContainer: React.FC<TabPanelContainerProps> = memo(
|
||||
({
|
||||
loading = false,
|
||||
loadingMessage = '加载中...',
|
||||
loadingHeight = '200px',
|
||||
spacing = 6,
|
||||
padding = 4,
|
||||
showDisclaimer = false,
|
||||
disclaimerText = DEFAULT_DISCLAIMER,
|
||||
children,
|
||||
}) => {
|
||||
if (loading) {
|
||||
return <LoadingState message={loadingMessage} height={loadingHeight} />;
|
||||
}
|
||||
|
||||
return (
|
||||
<Box p={padding}>
|
||||
<VStack spacing={spacing} align="stretch">
|
||||
{children}
|
||||
</VStack>
|
||||
{showDisclaimer && <DisclaimerText text={disclaimerText} />}
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
TabPanelContainer.displayName = 'TabPanelContainer';
|
||||
|
||||
export default TabPanelContainer;
|
||||
@@ -661,6 +661,12 @@ export const NotificationProvider = ({ children }) => {
|
||||
|
||||
// ========== 连接到 Socket 服务(⚡ 异步初始化,不阻塞首屏) ==========
|
||||
useEffect(() => {
|
||||
// ⚡ Mock 模式下跳过 Socket 连接(避免连接生产服务器失败的错误)
|
||||
if (process.env.REACT_APP_ENABLE_MOCK === 'true') {
|
||||
logger.debug('NotificationContext', 'Mock 模式,跳过 Socket 连接');
|
||||
return;
|
||||
}
|
||||
|
||||
// ⚡ 防止 React Strict Mode 导致的重复初始化
|
||||
if (socketInitialized) {
|
||||
logger.debug('NotificationContext', 'Socket 已初始化,跳过重复执行(Strict Mode 保护)');
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
import { useState, useEffect, useCallback, useRef } from 'react';
|
||||
import { initWebVitalsTracking, getCachedMetrics } from '@utils/performance/webVitals';
|
||||
import { collectResourceStats, collectApiStats } from '@utils/performance/resourceMonitor';
|
||||
import { performanceMonitor } from '@utils/performanceMonitor';
|
||||
import { usePerformanceMark } from '@hooks/usePerformanceTracker';
|
||||
import posthog from 'posthog-js';
|
||||
import type {
|
||||
FirstScreenMetrics,
|
||||
@@ -44,11 +46,17 @@ export const useFirstScreenMetrics = (
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [metrics, setMetrics] = useState<FirstScreenMetrics | null>(null);
|
||||
|
||||
// 使用 ref 记录页面加载开始时间
|
||||
const pageLoadStartRef = useRef<number>(performance.now());
|
||||
const skeletonStartRef = useRef<number>(performance.now());
|
||||
// 使用 ref 避免重复标记
|
||||
const hasMarkedRef = useRef(false);
|
||||
const hasInitializedRef = useRef(false);
|
||||
|
||||
// 在组件首次渲染时标记开始时间点
|
||||
if (!hasMarkedRef.current) {
|
||||
hasMarkedRef.current = true;
|
||||
performanceMonitor.mark(`${pageType}-page-load-start`);
|
||||
performanceMonitor.mark(`${pageType}-skeleton-start`);
|
||||
}
|
||||
|
||||
/**
|
||||
* 收集所有首屏指标
|
||||
*/
|
||||
@@ -82,12 +90,20 @@ export const useFirstScreenMetrics = (
|
||||
customProperties,
|
||||
});
|
||||
|
||||
// 5. 计算首屏可交互时间(TTI)
|
||||
const now = performance.now();
|
||||
const timeToInteractive = now - pageLoadStartRef.current;
|
||||
// 5. 标记可交互时间点,并计算 TTI
|
||||
performanceMonitor.mark(`${pageType}-interactive`);
|
||||
const timeToInteractive = performanceMonitor.measure(
|
||||
`${pageType}-page-load-start`,
|
||||
`${pageType}-interactive`,
|
||||
`${pageType} TTI`
|
||||
) || 0;
|
||||
|
||||
// 6. 计算骨架屏展示时长
|
||||
const skeletonDisplayDuration = now - skeletonStartRef.current;
|
||||
const skeletonDisplayDuration = performanceMonitor.measure(
|
||||
`${pageType}-skeleton-start`,
|
||||
`${pageType}-interactive`,
|
||||
`${pageType} 骨架屏时长`
|
||||
) || 0;
|
||||
|
||||
const firstScreenMetrics: FirstScreenMetrics = {
|
||||
webVitals,
|
||||
@@ -143,9 +159,9 @@ export const useFirstScreenMetrics = (
|
||||
const remeasure = useCallback(() => {
|
||||
setIsLoading(true);
|
||||
|
||||
// 重置计时器
|
||||
pageLoadStartRef.current = performance.now();
|
||||
skeletonStartRef.current = performance.now();
|
||||
// 重置性能标记
|
||||
performanceMonitor.mark(`${pageType}-page-load-start`);
|
||||
performanceMonitor.mark(`${pageType}-skeleton-start`);
|
||||
|
||||
// 延迟收集指标(等待 Web Vitals 完成)
|
||||
setTimeout(() => {
|
||||
@@ -167,7 +183,7 @@ export const useFirstScreenMetrics = (
|
||||
setIsLoading(false);
|
||||
}
|
||||
}, 1000); // 延迟 1 秒收集
|
||||
}, [collectAllMetrics, trackFirstScreenInteractive, enableConsoleLog]);
|
||||
}, [pageType, collectAllMetrics, trackFirstScreenInteractive, enableConsoleLog]);
|
||||
|
||||
/**
|
||||
* 导出指标为 JSON
|
||||
@@ -247,7 +263,7 @@ export const useFirstScreenMetrics = (
|
||||
*
|
||||
* 使用示例:
|
||||
* ```tsx
|
||||
* const { markSkeletonEnd } = useSkeletonTiming();
|
||||
* const { markSkeletonEnd } = useSkeletonTiming('home-skeleton');
|
||||
*
|
||||
* useEffect(() => {
|
||||
* if (!loading) {
|
||||
@@ -256,27 +272,32 @@ export const useFirstScreenMetrics = (
|
||||
* }, [loading, markSkeletonEnd]);
|
||||
* ```
|
||||
*/
|
||||
export const useSkeletonTiming = () => {
|
||||
const skeletonStartRef = useRef<number>(performance.now());
|
||||
const skeletonEndRef = useRef<number | null>(null);
|
||||
export const useSkeletonTiming = (prefix = 'skeleton') => {
|
||||
const { mark, getMeasure } = usePerformanceMark(prefix);
|
||||
const hasMarkedEndRef = useRef(false);
|
||||
const hasMarkedStartRef = useRef(false);
|
||||
|
||||
// 在组件首次渲染时标记开始
|
||||
if (!hasMarkedStartRef.current) {
|
||||
hasMarkedStartRef.current = true;
|
||||
mark('start');
|
||||
}
|
||||
|
||||
const markSkeletonEnd = useCallback(() => {
|
||||
if (!skeletonEndRef.current) {
|
||||
skeletonEndRef.current = performance.now();
|
||||
const duration = skeletonEndRef.current - skeletonStartRef.current;
|
||||
if (!hasMarkedEndRef.current) {
|
||||
hasMarkedEndRef.current = true;
|
||||
mark('end');
|
||||
const duration = getMeasure('start', 'end');
|
||||
|
||||
if (process.env.NODE_ENV === 'development') {
|
||||
if (process.env.NODE_ENV === 'development' && duration) {
|
||||
console.log(`⏱️ Skeleton Display Duration: ${(duration / 1000).toFixed(2)}s`);
|
||||
}
|
||||
}
|
||||
}, []);
|
||||
}, [mark, getMeasure]);
|
||||
|
||||
const getSkeletonDuration = useCallback((): number | null => {
|
||||
if (skeletonEndRef.current) {
|
||||
return skeletonEndRef.current - skeletonStartRef.current;
|
||||
}
|
||||
return null;
|
||||
}, []);
|
||||
return getMeasure('start', 'end');
|
||||
}, [getMeasure]);
|
||||
|
||||
return {
|
||||
markSkeletonEnd,
|
||||
|
||||
129
src/hooks/usePerformanceTracker.ts
Normal file
129
src/hooks/usePerformanceTracker.ts
Normal file
@@ -0,0 +1,129 @@
|
||||
/**
|
||||
* React 性能追踪 Hooks
|
||||
* 封装 performanceMonitor 工具,提供 React 友好的性能追踪 API
|
||||
*/
|
||||
|
||||
import { useEffect, useRef, useCallback } from 'react';
|
||||
import { performanceMonitor } from '@utils/performanceMonitor';
|
||||
|
||||
/**
|
||||
* usePerformanceMark 返回值类型
|
||||
*/
|
||||
export interface UsePerformanceMarkReturn {
|
||||
/** 标记时间点 */
|
||||
mark: (suffix: string) => void;
|
||||
/** 测量并记录到 performanceMonitor */
|
||||
measure: (startSuffix: string, endSuffix: string, name?: string) => number | null;
|
||||
/** 获取测量值(不记录) */
|
||||
getMeasure: (startSuffix: string, endSuffix: string) => number | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* usePerformanceTracker - 自动追踪组件渲染性能
|
||||
*
|
||||
* @param componentName - 组件名称,用于标记
|
||||
* @param options - 配置选项
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* function MyComponent() {
|
||||
* usePerformanceTracker('MyComponent');
|
||||
* return <div>...</div>;
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* 自动标记:
|
||||
* - {componentName}-mount: 组件挂载时
|
||||
* - {componentName}-rendered: 首次渲染完成
|
||||
* - {componentName}-unmount: 组件卸载时
|
||||
*/
|
||||
export function usePerformanceTracker(
|
||||
componentName: string,
|
||||
options: { trackRender?: boolean } = {}
|
||||
): void {
|
||||
const { trackRender = true } = options;
|
||||
const hasMounted = useRef(false);
|
||||
|
||||
// 首次渲染时立即标记(同步)
|
||||
if (!hasMounted.current) {
|
||||
performanceMonitor.mark(`${componentName}-mount`);
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
if (!hasMounted.current) {
|
||||
hasMounted.current = true;
|
||||
|
||||
// 渲染完成标记(在 useEffect 中,表示 DOM 已更新)
|
||||
if (trackRender) {
|
||||
performanceMonitor.mark(`${componentName}-rendered`);
|
||||
performanceMonitor.measure(
|
||||
`${componentName}-mount`,
|
||||
`${componentName}-rendered`,
|
||||
`${componentName} 渲染`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// 组件卸载时标记
|
||||
return () => {
|
||||
performanceMonitor.mark(`${componentName}-unmount`);
|
||||
};
|
||||
}, [componentName, trackRender]);
|
||||
}
|
||||
|
||||
/**
|
||||
* usePerformanceMark - 手动标记自定义操作的性能
|
||||
*
|
||||
* @param prefix - 标记前缀,用于区分不同操作
|
||||
* @returns 包含 mark、measure、getMeasure 方法的对象
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* function MyComponent() {
|
||||
* const { mark, getMeasure } = usePerformanceMark('api-call');
|
||||
*
|
||||
* const handleFetch = async () => {
|
||||
* mark('start');
|
||||
* await fetchData();
|
||||
* mark('end');
|
||||
* const duration = getMeasure('start', 'end');
|
||||
* console.log('API耗时:', duration, 'ms');
|
||||
* };
|
||||
*
|
||||
* return <button onClick={handleFetch}>加载数据</button>;
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
export function usePerformanceMark(prefix: string): UsePerformanceMarkReturn {
|
||||
const mark = useCallback(
|
||||
(suffix: string) => {
|
||||
performanceMonitor.mark(`${prefix}-${suffix}`);
|
||||
},
|
||||
[prefix]
|
||||
);
|
||||
|
||||
const measure = useCallback(
|
||||
(startSuffix: string, endSuffix: string, name?: string) => {
|
||||
return performanceMonitor.measure(
|
||||
`${prefix}-${startSuffix}`,
|
||||
`${prefix}-${endSuffix}`,
|
||||
name || `${prefix}: ${startSuffix} → ${endSuffix}`
|
||||
);
|
||||
},
|
||||
[prefix]
|
||||
);
|
||||
|
||||
const getMeasure = useCallback(
|
||||
(startSuffix: string, endSuffix: string) => {
|
||||
return performanceMonitor.measure(
|
||||
`${prefix}-${startSuffix}`,
|
||||
`${prefix}-${endSuffix}`
|
||||
);
|
||||
},
|
||||
[prefix]
|
||||
);
|
||||
|
||||
return { mark, measure, getMeasure };
|
||||
}
|
||||
|
||||
export default usePerformanceTracker;
|
||||
@@ -1,244 +0,0 @@
|
||||
// src/hooks/useSearchEvents.js
|
||||
// 全局搜索功能事件追踪 Hook
|
||||
|
||||
import { useCallback } from 'react';
|
||||
import { usePostHogTrack } from './usePostHogRedux';
|
||||
import { RETENTION_EVENTS } from '../lib/constants';
|
||||
import { logger } from '../utils/logger';
|
||||
|
||||
/**
|
||||
* 全局搜索事件追踪 Hook
|
||||
* @param {Object} options - 配置选项
|
||||
* @param {string} options.context - 搜索上下文 ('global' | 'stock' | 'news' | 'concept' | 'simulation')
|
||||
* @returns {Object} 事件追踪处理函数集合
|
||||
*/
|
||||
export const useSearchEvents = ({ context = 'global' } = {}) => {
|
||||
const { track } = usePostHogTrack();
|
||||
|
||||
/**
|
||||
* 追踪搜索开始(聚焦搜索框)
|
||||
* @param {string} placeholder - 搜索框提示文本
|
||||
*/
|
||||
const trackSearchInitiated = useCallback((placeholder = '') => {
|
||||
track(RETENTION_EVENTS.SEARCH_INITIATED, {
|
||||
context,
|
||||
placeholder,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
|
||||
logger.debug('useSearchEvents', '🔍 Search Initiated', {
|
||||
context,
|
||||
placeholder,
|
||||
});
|
||||
}, [track, context]);
|
||||
|
||||
/**
|
||||
* 追踪搜索查询提交
|
||||
* @param {string} query - 搜索查询词
|
||||
* @param {number} resultCount - 搜索结果数量
|
||||
* @param {Object} filters - 应用的筛选条件
|
||||
*/
|
||||
const trackSearchQuerySubmitted = useCallback((query, resultCount = 0, filters = {}) => {
|
||||
if (!query) {
|
||||
logger.warn('useSearchEvents', 'trackSearchQuerySubmitted: query is required');
|
||||
return;
|
||||
}
|
||||
|
||||
track(RETENTION_EVENTS.SEARCH_QUERY_SUBMITTED, {
|
||||
query,
|
||||
query_length: query.length,
|
||||
result_count: resultCount,
|
||||
has_results: resultCount > 0,
|
||||
context,
|
||||
filters: filters,
|
||||
filter_count: Object.keys(filters).length,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
|
||||
// 如果没有搜索结果,额外追踪
|
||||
if (resultCount === 0) {
|
||||
track(RETENTION_EVENTS.SEARCH_NO_RESULTS, {
|
||||
query,
|
||||
context,
|
||||
filters,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
|
||||
logger.debug('useSearchEvents', '❌ Search No Results', {
|
||||
query,
|
||||
context,
|
||||
});
|
||||
} else {
|
||||
logger.debug('useSearchEvents', '✅ Search Query Submitted', {
|
||||
query,
|
||||
resultCount,
|
||||
context,
|
||||
});
|
||||
}
|
||||
}, [track, context]);
|
||||
|
||||
/**
|
||||
* 追踪搜索结果点击
|
||||
* @param {Object} result - 被点击的搜索结果
|
||||
* @param {string} result.type - 结果类型 ('stock' | 'news' | 'concept' | 'event')
|
||||
* @param {string} result.id - 结果ID
|
||||
* @param {string} result.title - 结果标题
|
||||
* @param {number} position - 在搜索结果中的位置
|
||||
* @param {string} query - 搜索查询词
|
||||
*/
|
||||
const trackSearchResultClicked = useCallback((result, position = 0, query = '') => {
|
||||
if (!result || !result.type) {
|
||||
logger.warn('useSearchEvents', 'trackSearchResultClicked: result object with type is required');
|
||||
return;
|
||||
}
|
||||
|
||||
track(RETENTION_EVENTS.SEARCH_RESULT_CLICKED, {
|
||||
result_type: result.type,
|
||||
result_id: result.id || result.code || '',
|
||||
result_title: result.title || result.name || '',
|
||||
position,
|
||||
query,
|
||||
context,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
|
||||
logger.debug('useSearchEvents', '🎯 Search Result Clicked', {
|
||||
type: result.type,
|
||||
id: result.id || result.code,
|
||||
position,
|
||||
context,
|
||||
});
|
||||
}, [track, context]);
|
||||
|
||||
/**
|
||||
* 追踪搜索筛选应用
|
||||
* @param {Object} filters - 应用的筛选条件
|
||||
* @param {string} filterType - 筛选类型 ('sort' | 'category' | 'date_range' | 'price_range')
|
||||
* @param {any} filterValue - 筛选值
|
||||
*/
|
||||
const trackSearchFilterApplied = useCallback((filterType, filterValue, filters = {}) => {
|
||||
if (!filterType) {
|
||||
logger.warn('useSearchEvents', 'trackSearchFilterApplied: filterType is required');
|
||||
return;
|
||||
}
|
||||
|
||||
track(RETENTION_EVENTS.SEARCH_FILTER_APPLIED, {
|
||||
filter_type: filterType,
|
||||
filter_value: String(filterValue),
|
||||
all_filters: filters,
|
||||
context,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
|
||||
logger.debug('useSearchEvents', '🔍 Search Filter Applied', {
|
||||
filterType,
|
||||
filterValue,
|
||||
context,
|
||||
});
|
||||
}, [track, context]);
|
||||
|
||||
/**
|
||||
* 追踪搜索建议点击(自动完成)
|
||||
* @param {string} suggestion - 被点击的搜索建议
|
||||
* @param {number} position - 在建议列表中的位置
|
||||
* @param {string} source - 建议来源 ('history' | 'popular' | 'related')
|
||||
*/
|
||||
const trackSearchSuggestionClicked = useCallback((suggestion, position = 0, source = 'popular') => {
|
||||
if (!suggestion) {
|
||||
logger.warn('useSearchEvents', 'trackSearchSuggestionClicked: suggestion is required');
|
||||
return;
|
||||
}
|
||||
|
||||
track('Search Suggestion Clicked', {
|
||||
suggestion,
|
||||
position,
|
||||
source,
|
||||
context,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
|
||||
logger.debug('useSearchEvents', '💡 Search Suggestion Clicked', {
|
||||
suggestion,
|
||||
position,
|
||||
source,
|
||||
context,
|
||||
});
|
||||
}, [track, context]);
|
||||
|
||||
/**
|
||||
* 追踪搜索历史查看
|
||||
* @param {number} historyCount - 历史记录数量
|
||||
*/
|
||||
const trackSearchHistoryViewed = useCallback((historyCount = 0) => {
|
||||
track('Search History Viewed', {
|
||||
history_count: historyCount,
|
||||
has_history: historyCount > 0,
|
||||
context,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
|
||||
logger.debug('useSearchEvents', '📜 Search History Viewed', {
|
||||
historyCount,
|
||||
context,
|
||||
});
|
||||
}, [track, context]);
|
||||
|
||||
/**
|
||||
* 追踪搜索历史清除
|
||||
*/
|
||||
const trackSearchHistoryCleared = useCallback(() => {
|
||||
track('Search History Cleared', {
|
||||
context,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
|
||||
logger.debug('useSearchEvents', '🗑️ Search History Cleared', {
|
||||
context,
|
||||
});
|
||||
}, [track, context]);
|
||||
|
||||
/**
|
||||
* 追踪热门搜索词点击
|
||||
* @param {string} keyword - 被点击的热门关键词
|
||||
* @param {number} position - 在列表中的位置
|
||||
* @param {number} heatScore - 热度分数
|
||||
*/
|
||||
const trackPopularKeywordClicked = useCallback((keyword, position = 0, heatScore = 0) => {
|
||||
if (!keyword) {
|
||||
logger.warn('useSearchEvents', 'trackPopularKeywordClicked: keyword is required');
|
||||
return;
|
||||
}
|
||||
|
||||
track('Popular Keyword Clicked', {
|
||||
keyword,
|
||||
position,
|
||||
heat_score: heatScore,
|
||||
context,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
|
||||
logger.debug('useSearchEvents', '🔥 Popular Keyword Clicked', {
|
||||
keyword,
|
||||
position,
|
||||
context,
|
||||
});
|
||||
}, [track, context]);
|
||||
|
||||
return {
|
||||
// 搜索流程事件
|
||||
trackSearchInitiated,
|
||||
trackSearchQuerySubmitted,
|
||||
trackSearchResultClicked,
|
||||
|
||||
// 筛选和建议
|
||||
trackSearchFilterApplied,
|
||||
trackSearchSuggestionClicked,
|
||||
|
||||
// 历史和热门
|
||||
trackSearchHistoryViewed,
|
||||
trackSearchHistoryCleared,
|
||||
trackPopularKeywordClicked,
|
||||
};
|
||||
};
|
||||
|
||||
export default useSearchEvents;
|
||||
11
src/index.js
11
src/index.js
@@ -5,6 +5,17 @@ import { BrowserRouter as Router } from 'react-router-dom';
|
||||
|
||||
// ⚡ 性能监控:在应用启动时尽早标记
|
||||
import { performanceMonitor } from './utils/performanceMonitor';
|
||||
|
||||
// T0: HTML 加载完成时间点
|
||||
if (document.readyState === 'complete') {
|
||||
performanceMonitor.mark('html-loaded');
|
||||
} else {
|
||||
window.addEventListener('load', () => {
|
||||
performanceMonitor.mark('html-loaded');
|
||||
});
|
||||
}
|
||||
|
||||
// T1: React 开始初始化
|
||||
performanceMonitor.mark('app-start');
|
||||
|
||||
// ⚡ 已删除 brainwave.css(项目未安装 Tailwind CSS,该文件无效)
|
||||
|
||||
@@ -128,6 +128,13 @@ export const mockRealtimeQuotes = [
|
||||
|
||||
// ==================== 关注事件数据 ====================
|
||||
|
||||
// 事件关注内存存储(Set 存储已关注的事件 ID)
|
||||
export const followedEventsSet = new Set();
|
||||
|
||||
// 关注事件完整数据存储(Map: eventId -> eventData)
|
||||
export const followedEventsMap = new Map();
|
||||
|
||||
// 初始关注事件列表(用于初始化)
|
||||
export const mockFollowingEvents = [
|
||||
{
|
||||
id: 101,
|
||||
@@ -231,6 +238,74 @@ export const mockFollowingEvents = [
|
||||
}
|
||||
];
|
||||
|
||||
// 初始化:将 mockFollowingEvents 的数据加入内存存储
|
||||
mockFollowingEvents.forEach(event => {
|
||||
followedEventsSet.add(event.id);
|
||||
followedEventsMap.set(event.id, event);
|
||||
});
|
||||
|
||||
/**
|
||||
* 切换事件关注状态
|
||||
* @param {number} eventId - 事件 ID
|
||||
* @param {Object} eventData - 事件数据(关注时需要)
|
||||
* @returns {{ isFollowing: boolean, followerCount: number }}
|
||||
*/
|
||||
export function toggleEventFollowStatus(eventId, eventData = null) {
|
||||
const wasFollowing = followedEventsSet.has(eventId);
|
||||
|
||||
if (wasFollowing) {
|
||||
// 取消关注
|
||||
followedEventsSet.delete(eventId);
|
||||
followedEventsMap.delete(eventId);
|
||||
} else {
|
||||
// 添加关注
|
||||
followedEventsSet.add(eventId);
|
||||
if (eventData) {
|
||||
followedEventsMap.set(eventId, {
|
||||
...eventData,
|
||||
followed_at: new Date().toISOString()
|
||||
});
|
||||
} else {
|
||||
// 如果没有提供事件数据,创建基础数据
|
||||
followedEventsMap.set(eventId, {
|
||||
id: eventId,
|
||||
title: `事件 ${eventId}`,
|
||||
tags: [],
|
||||
followed_at: new Date().toISOString()
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const isFollowing = !wasFollowing;
|
||||
const followerCount = isFollowing ? Math.floor(Math.random() * 500) + 100 : Math.floor(Math.random() * 500) + 50;
|
||||
|
||||
console.log('[Mock Data] 切换事件关注状态:', {
|
||||
eventId,
|
||||
wasFollowing,
|
||||
isFollowing,
|
||||
followedEventsCount: followedEventsSet.size
|
||||
});
|
||||
|
||||
return { isFollowing, followerCount };
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查事件是否已关注
|
||||
* @param {number} eventId - 事件 ID
|
||||
* @returns {boolean}
|
||||
*/
|
||||
export function isEventFollowed(eventId) {
|
||||
return followedEventsSet.has(eventId);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有已关注的事件列表
|
||||
* @returns {Array}
|
||||
*/
|
||||
export function getFollowedEvents() {
|
||||
return Array.from(followedEventsMap.values());
|
||||
}
|
||||
|
||||
// ==================== 评论数据 ====================
|
||||
|
||||
export const mockEventComments = [
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -874,8 +874,20 @@ export function generateMockEvents(params = {}) {
|
||||
e.title.toLowerCase().includes(query) ||
|
||||
e.description.toLowerCase().includes(query) ||
|
||||
// keywords 是对象数组 { concept, score, ... },需要访问 concept 属性
|
||||
e.keywords.some(k => k.concept && k.concept.toLowerCase().includes(query))
|
||||
e.keywords.some(k => k.concept && k.concept.toLowerCase().includes(query)) ||
|
||||
// 搜索 related_stocks 中的股票名称和代码
|
||||
(e.related_stocks && e.related_stocks.some(stock =>
|
||||
(stock.stock_name && stock.stock_name.toLowerCase().includes(query)) ||
|
||||
(stock.stock_code && stock.stock_code.toLowerCase().includes(query))
|
||||
)) ||
|
||||
// 搜索行业
|
||||
(e.industry && e.industry.toLowerCase().includes(query))
|
||||
);
|
||||
|
||||
// 如果搜索结果为空,返回所有事件(宽松模式)
|
||||
if (filteredEvents.length === 0) {
|
||||
filteredEvents = allEvents;
|
||||
}
|
||||
}
|
||||
|
||||
// 行业筛选
|
||||
@@ -1042,7 +1054,7 @@ function generateTransmissionChain(industry, index) {
|
||||
|
||||
let nodeName;
|
||||
if (nodeType === 'company' && industryStock) {
|
||||
nodeName = industryStock.name;
|
||||
nodeName = industryStock.stock_name;
|
||||
} else if (nodeType === 'industry') {
|
||||
nodeName = `${industry}产业`;
|
||||
} else if (nodeType === 'policy') {
|
||||
@@ -1133,7 +1145,7 @@ export function generateDynamicNewsEvents(timeRange = null, count = 30) {
|
||||
const stock = industryStocks[j % industryStocks.length];
|
||||
relatedStocks.push({
|
||||
stock_code: stock.stock_code,
|
||||
stock_name: stock.name,
|
||||
stock_name: stock.stock_name,
|
||||
relation_desc: relationDescriptions[j % relationDescriptions.length]
|
||||
});
|
||||
}
|
||||
@@ -1145,7 +1157,7 @@ export function generateDynamicNewsEvents(timeRange = null, count = 30) {
|
||||
if (!relatedStocks.some(s => s.stock_code === randomStock.stock_code)) {
|
||||
relatedStocks.push({
|
||||
stock_code: randomStock.stock_code,
|
||||
stock_name: randomStock.name,
|
||||
stock_name: randomStock.stock_name,
|
||||
relation_desc: relationDescriptions[relatedStocks.length % relationDescriptions.length]
|
||||
});
|
||||
}
|
||||
|
||||
@@ -10,73 +10,323 @@ export const generateFinancialData = (stockCode) => {
|
||||
|
||||
// 股票基本信息
|
||||
stockInfo: {
|
||||
code: stockCode,
|
||||
name: stockCode === '000001' ? '平安银行' : '示例公司',
|
||||
stock_code: stockCode,
|
||||
stock_name: stockCode === '000001' ? '平安银行' : '示例公司',
|
||||
industry: stockCode === '000001' ? '银行' : '制造业',
|
||||
list_date: '1991-04-03',
|
||||
market: 'SZ'
|
||||
market: 'SZ',
|
||||
// 关键指标
|
||||
key_metrics: {
|
||||
eps: 2.72,
|
||||
roe: 16.23,
|
||||
gross_margin: 71.92,
|
||||
net_margin: 32.56,
|
||||
roa: 1.05
|
||||
},
|
||||
// 增长率
|
||||
growth_rates: {
|
||||
revenue_growth: 8.2,
|
||||
profit_growth: 12.5,
|
||||
asset_growth: 5.6,
|
||||
equity_growth: 6.8
|
||||
},
|
||||
// 财务概要
|
||||
financial_summary: {
|
||||
revenue: 162350,
|
||||
net_profit: 52860,
|
||||
total_assets: 5024560,
|
||||
total_liabilities: 4698880
|
||||
},
|
||||
// 最新业绩预告
|
||||
latest_forecast: {
|
||||
forecast_type: '预增',
|
||||
content: '预计全年净利润同比增长10%-17%'
|
||||
}
|
||||
},
|
||||
|
||||
// 资产负债表
|
||||
// 资产负债表 - 嵌套结构
|
||||
balanceSheet: periods.map((period, i) => ({
|
||||
period,
|
||||
total_assets: 5024560 - i * 50000, // 百万元
|
||||
total_liabilities: 4698880 - i * 48000,
|
||||
shareholders_equity: 325680 - i * 2000,
|
||||
current_assets: 2512300 - i * 25000,
|
||||
non_current_assets: 2512260 - i * 25000,
|
||||
current_liabilities: 3456780 - i * 35000,
|
||||
non_current_liabilities: 1242100 - i * 13000
|
||||
assets: {
|
||||
current_assets: {
|
||||
cash: 856780 - i * 10000,
|
||||
trading_financial_assets: 234560 - i * 5000,
|
||||
notes_receivable: 12340 - i * 200,
|
||||
accounts_receivable: 45670 - i * 1000,
|
||||
prepayments: 8900 - i * 100,
|
||||
other_receivables: 23450 - i * 500,
|
||||
inventory: 156780 - i * 3000,
|
||||
contract_assets: 34560 - i * 800,
|
||||
other_current_assets: 67890 - i * 1500,
|
||||
total: 2512300 - i * 25000
|
||||
},
|
||||
non_current_assets: {
|
||||
long_term_equity_investments: 234560 - i * 5000,
|
||||
investment_property: 45670 - i * 1000,
|
||||
fixed_assets: 678900 - i * 15000,
|
||||
construction_in_progress: 123450 - i * 3000,
|
||||
right_of_use_assets: 34560 - i * 800,
|
||||
intangible_assets: 89012 - i * 2000,
|
||||
goodwill: 45670 - i * 1000,
|
||||
deferred_tax_assets: 12340 - i * 300,
|
||||
other_non_current_assets: 67890 - i * 1500,
|
||||
total: 2512260 - i * 25000
|
||||
},
|
||||
total: 5024560 - i * 50000
|
||||
},
|
||||
liabilities: {
|
||||
current_liabilities: {
|
||||
short_term_borrowings: 456780 - i * 10000,
|
||||
notes_payable: 23450 - i * 500,
|
||||
accounts_payable: 234560 - i * 5000,
|
||||
advance_receipts: 12340 - i * 300,
|
||||
contract_liabilities: 34560 - i * 800,
|
||||
employee_compensation_payable: 45670 - i * 1000,
|
||||
taxes_payable: 23450 - i * 500,
|
||||
other_payables: 78900 - i * 1500,
|
||||
non_current_liabilities_due_within_one_year: 89012 - i * 2000,
|
||||
total: 3456780 - i * 35000
|
||||
},
|
||||
non_current_liabilities: {
|
||||
long_term_borrowings: 678900 - i * 15000,
|
||||
bonds_payable: 234560 - i * 5000,
|
||||
lease_liabilities: 45670 - i * 1000,
|
||||
deferred_tax_liabilities: 12340 - i * 300,
|
||||
other_non_current_liabilities: 89012 - i * 2000,
|
||||
total: 1242100 - i * 13000
|
||||
},
|
||||
total: 4698880 - i * 48000
|
||||
},
|
||||
equity: {
|
||||
share_capital: 19405,
|
||||
capital_reserve: 89012 - i * 2000,
|
||||
surplus_reserve: 45670 - i * 1000,
|
||||
undistributed_profit: 156780 - i * 3000,
|
||||
treasury_stock: 0,
|
||||
other_comprehensive_income: 12340 - i * 300,
|
||||
parent_company_equity: 315680 - i * 1800,
|
||||
minority_interests: 10000 - i * 200,
|
||||
total: 325680 - i * 2000
|
||||
}
|
||||
})),
|
||||
|
||||
// 利润表
|
||||
// 利润表 - 嵌套结构
|
||||
incomeStatement: periods.map((period, i) => ({
|
||||
period,
|
||||
revenue: 162350 - i * 4000, // 百万元
|
||||
operating_cost: 45620 - i * 1200,
|
||||
gross_profit: 116730 - i * 2800,
|
||||
operating_profit: 68450 - i * 1500,
|
||||
net_profit: 52860 - i * 1200,
|
||||
eps: 2.72 - i * 0.06
|
||||
revenue: {
|
||||
total_operating_revenue: 162350 - i * 4000,
|
||||
operating_revenue: 158900 - i * 3900,
|
||||
other_income: 3450 - i * 100
|
||||
},
|
||||
costs: {
|
||||
total_operating_cost: 93900 - i * 2500,
|
||||
operating_cost: 45620 - i * 1200,
|
||||
taxes_and_surcharges: 4560 - i * 100,
|
||||
selling_expenses: 12340 - i * 300,
|
||||
admin_expenses: 15670 - i * 400,
|
||||
rd_expenses: 8900 - i * 200,
|
||||
financial_expenses: 6810 - i * 300,
|
||||
interest_expense: 8900 - i * 200,
|
||||
interest_income: 2090 - i * 50,
|
||||
three_expenses_total: 34820 - i * 1000,
|
||||
four_expenses_total: 43720 - i * 1200,
|
||||
asset_impairment_loss: 1200 - i * 50,
|
||||
credit_impairment_loss: 2340 - i * 100
|
||||
},
|
||||
other_gains: {
|
||||
fair_value_change: 1230 - i * 50,
|
||||
investment_income: 3450 - i * 100,
|
||||
investment_income_from_associates: 890 - i * 20,
|
||||
exchange_income: 560 - i * 10,
|
||||
asset_disposal_income: 340 - i * 10
|
||||
},
|
||||
profit: {
|
||||
operating_profit: 68450 - i * 1500,
|
||||
total_profit: 69500 - i * 1500,
|
||||
income_tax_expense: 16640 - i * 300,
|
||||
net_profit: 52860 - i * 1200,
|
||||
parent_net_profit: 51200 - i * 1150,
|
||||
minority_profit: 1660 - i * 50,
|
||||
continuing_operations_net_profit: 52860 - i * 1200,
|
||||
discontinued_operations_net_profit: 0
|
||||
},
|
||||
non_operating: {
|
||||
non_operating_income: 1050 - i * 20,
|
||||
non_operating_expenses: 450 - i * 10
|
||||
},
|
||||
per_share: {
|
||||
basic_eps: 2.72 - i * 0.06,
|
||||
diluted_eps: 2.70 - i * 0.06
|
||||
},
|
||||
comprehensive_income: {
|
||||
other_comprehensive_income: 890 - i * 20,
|
||||
total_comprehensive_income: 53750 - i * 1220,
|
||||
parent_comprehensive_income: 52050 - i * 1170,
|
||||
minority_comprehensive_income: 1700 - i * 50
|
||||
}
|
||||
})),
|
||||
|
||||
// 现金流量表
|
||||
// 现金流量表 - 嵌套结构
|
||||
cashflow: periods.map((period, i) => ({
|
||||
period,
|
||||
operating_cashflow: 125600 - i * 3000, // 百万元
|
||||
investing_cashflow: -45300 - i * 1000,
|
||||
financing_cashflow: -38200 + i * 500,
|
||||
net_cashflow: 42100 - i * 1500,
|
||||
cash_ending: 456780 - i * 10000
|
||||
operating_activities: {
|
||||
inflow: {
|
||||
cash_from_sales: 178500 - i * 4500
|
||||
},
|
||||
outflow: {
|
||||
cash_for_goods: 52900 - i * 1500
|
||||
},
|
||||
net_flow: 125600 - i * 3000
|
||||
},
|
||||
investment_activities: {
|
||||
net_flow: -45300 - i * 1000
|
||||
},
|
||||
financing_activities: {
|
||||
net_flow: -38200 + i * 500
|
||||
},
|
||||
cash_changes: {
|
||||
net_increase: 42100 - i * 1500,
|
||||
ending_balance: 456780 - i * 10000
|
||||
},
|
||||
key_metrics: {
|
||||
free_cash_flow: 80300 - i * 2000
|
||||
}
|
||||
})),
|
||||
|
||||
// 财务指标
|
||||
// 财务指标 - 嵌套结构
|
||||
financialMetrics: periods.map((period, i) => ({
|
||||
period,
|
||||
roe: 16.23 - i * 0.3, // %
|
||||
roa: 1.05 - i * 0.02,
|
||||
gross_margin: 71.92 - i * 0.5,
|
||||
net_margin: 32.56 - i * 0.3,
|
||||
current_ratio: 0.73 + i * 0.01,
|
||||
quick_ratio: 0.71 + i * 0.01,
|
||||
debt_ratio: 93.52 + i * 0.05,
|
||||
asset_turnover: 0.41 - i * 0.01,
|
||||
inventory_turnover: 0, // 银行无库存
|
||||
receivable_turnover: 0 // 银行特殊
|
||||
profitability: {
|
||||
roe: 16.23 - i * 0.3,
|
||||
roe_deducted: 15.89 - i * 0.3,
|
||||
roe_weighted: 16.45 - i * 0.3,
|
||||
roa: 1.05 - i * 0.02,
|
||||
gross_margin: 71.92 - i * 0.5,
|
||||
net_profit_margin: 32.56 - i * 0.3,
|
||||
operating_profit_margin: 42.16 - i * 0.4,
|
||||
cost_profit_ratio: 115.8 - i * 1.2,
|
||||
ebit: 86140 - i * 1800
|
||||
},
|
||||
per_share_metrics: {
|
||||
eps: 2.72 - i * 0.06,
|
||||
basic_eps: 2.72 - i * 0.06,
|
||||
diluted_eps: 2.70 - i * 0.06,
|
||||
deducted_eps: 2.65 - i * 0.06,
|
||||
bvps: 16.78 - i * 0.1,
|
||||
operating_cash_flow_ps: 6.47 - i * 0.15,
|
||||
capital_reserve_ps: 4.59 - i * 0.1,
|
||||
undistributed_profit_ps: 8.08 - i * 0.15
|
||||
},
|
||||
growth: {
|
||||
revenue_growth: 8.2 - i * 0.5,
|
||||
net_profit_growth: 12.5 - i * 0.8,
|
||||
deducted_profit_growth: 11.8 - i * 0.7,
|
||||
parent_profit_growth: 12.3 - i * 0.75,
|
||||
operating_cash_flow_growth: 15.6 - i * 1.0,
|
||||
total_asset_growth: 5.6 - i * 0.3,
|
||||
equity_growth: 6.8 - i * 0.4,
|
||||
fixed_asset_growth: 4.2 - i * 0.2
|
||||
},
|
||||
operational_efficiency: {
|
||||
total_asset_turnover: 0.41 - i * 0.01,
|
||||
fixed_asset_turnover: 2.35 - i * 0.05,
|
||||
current_asset_turnover: 0.82 - i * 0.02,
|
||||
receivable_turnover: 12.5 - i * 0.3,
|
||||
receivable_days: 29.2 + i * 0.7,
|
||||
inventory_turnover: 0, // 银行无库存
|
||||
inventory_days: 0,
|
||||
working_capital_turnover: 1.68 - i * 0.04
|
||||
},
|
||||
solvency: {
|
||||
current_ratio: 0.73 + i * 0.01,
|
||||
quick_ratio: 0.71 + i * 0.01,
|
||||
cash_ratio: 0.25 + i * 0.005,
|
||||
conservative_quick_ratio: 0.68 + i * 0.01,
|
||||
asset_liability_ratio: 93.52 + i * 0.05,
|
||||
interest_coverage: 8.56 - i * 0.2,
|
||||
cash_to_maturity_debt_ratio: 0.45 - i * 0.01,
|
||||
tangible_asset_debt_ratio: 94.12 + i * 0.05
|
||||
},
|
||||
expense_ratios: {
|
||||
selling_expense_ratio: 7.60 + i * 0.1,
|
||||
admin_expense_ratio: 9.65 + i * 0.1,
|
||||
financial_expense_ratio: 4.19 + i * 0.1,
|
||||
rd_expense_ratio: 5.48 + i * 0.1,
|
||||
three_expense_ratio: 21.44 + i * 0.3,
|
||||
four_expense_ratio: 26.92 + i * 0.4,
|
||||
cost_ratio: 28.10 + i * 0.2
|
||||
}
|
||||
})),
|
||||
|
||||
// 主营业务
|
||||
// 主营业务 - 按产品/业务分类
|
||||
mainBusiness: {
|
||||
by_product: [
|
||||
{ name: '对公业务', revenue: 68540, ratio: 42.2, yoy_growth: 6.8 },
|
||||
{ name: '零售业务', revenue: 81320, ratio: 50.1, yoy_growth: 11.2 },
|
||||
{ name: '金融市场业务', revenue: 12490, ratio: 7.7, yoy_growth: 3.5 }
|
||||
product_classification: [
|
||||
{
|
||||
period: '2024-09-30',
|
||||
report_type: '2024年三季报',
|
||||
products: [
|
||||
{ content: '零售金融业务', revenue: 81320000000, gross_margin: 68.5, profit_margin: 42.3, profit: 34398160000 },
|
||||
{ content: '对公金融业务', revenue: 68540000000, gross_margin: 62.8, profit_margin: 38.6, profit: 26456440000 },
|
||||
{ content: '金融市场业务', revenue: 12490000000, gross_margin: 75.2, profit_margin: 52.1, profit: 6507290000 },
|
||||
{ content: '合计', revenue: 162350000000, gross_margin: 67.5, profit_margin: 41.2, profit: 66883200000 },
|
||||
]
|
||||
},
|
||||
{
|
||||
period: '2024-06-30',
|
||||
report_type: '2024年中报',
|
||||
products: [
|
||||
{ content: '零售金融业务', revenue: 78650000000, gross_margin: 67.8, profit_margin: 41.5, profit: 32639750000 },
|
||||
{ content: '对公金融业务', revenue: 66280000000, gross_margin: 61.9, profit_margin: 37.8, profit: 25053840000 },
|
||||
{ content: '金融市场业务', revenue: 11870000000, gross_margin: 74.5, profit_margin: 51.2, profit: 6077440000 },
|
||||
{ content: '合计', revenue: 156800000000, gross_margin: 66.8, profit_margin: 40.5, profit: 63504000000 },
|
||||
]
|
||||
},
|
||||
{
|
||||
period: '2024-03-31',
|
||||
report_type: '2024年一季报',
|
||||
products: [
|
||||
{ content: '零售金融业务', revenue: 38920000000, gross_margin: 67.2, profit_margin: 40.8, profit: 15879360000 },
|
||||
{ content: '对公金融业务', revenue: 32650000000, gross_margin: 61.2, profit_margin: 37.1, profit: 12113150000 },
|
||||
{ content: '金融市场业务', revenue: 5830000000, gross_margin: 73.8, profit_margin: 50.5, profit: 2944150000 },
|
||||
{ content: '合计', revenue: 77400000000, gross_margin: 66.1, profit_margin: 39.8, profit: 30805200000 },
|
||||
]
|
||||
},
|
||||
{
|
||||
period: '2023-12-31',
|
||||
report_type: '2023年年报',
|
||||
products: [
|
||||
{ content: '零售金融业务', revenue: 152680000000, gross_margin: 66.5, profit_margin: 40.2, profit: 61377360000 },
|
||||
{ content: '对公金融业务', revenue: 128450000000, gross_margin: 60.5, profit_margin: 36.5, profit: 46884250000 },
|
||||
{ content: '金融市场业务', revenue: 22870000000, gross_margin: 73.2, profit_margin: 49.8, profit: 11389260000 },
|
||||
{ content: '合计', revenue: 304000000000, gross_margin: 65.2, profit_margin: 39.2, profit: 119168000000 },
|
||||
]
|
||||
},
|
||||
],
|
||||
by_region: [
|
||||
{ name: '华南地区', revenue: 56800, ratio: 35.0, yoy_growth: 9.2 },
|
||||
{ name: '华东地区', revenue: 48705, ratio: 30.0, yoy_growth: 8.5 },
|
||||
{ name: '华北地区', revenue: 32470, ratio: 20.0, yoy_growth: 7.8 },
|
||||
{ name: '其他地区', revenue: 24375, ratio: 15.0, yoy_growth: 6.5 }
|
||||
industry_classification: [
|
||||
{
|
||||
period: '2024-09-30',
|
||||
report_type: '2024年三季报',
|
||||
industries: [
|
||||
{ content: '华南地区', revenue: 56817500000, gross_margin: 69.2, profit_margin: 43.5, profit: 24715612500 },
|
||||
{ content: '华东地区', revenue: 48705000000, gross_margin: 67.8, profit_margin: 41.2, profit: 20066460000 },
|
||||
{ content: '华北地区', revenue: 32470000000, gross_margin: 65.5, profit_margin: 38.8, profit: 12598360000 },
|
||||
{ content: '西南地区', revenue: 16235000000, gross_margin: 64.2, profit_margin: 37.5, profit: 6088125000 },
|
||||
{ content: '其他地区', revenue: 8122500000, gross_margin: 62.8, profit_margin: 35.2, profit: 2859120000 },
|
||||
{ content: '合计', revenue: 162350000000, gross_margin: 67.5, profit_margin: 41.2, profit: 66883200000 },
|
||||
]
|
||||
},
|
||||
{
|
||||
period: '2024-06-30',
|
||||
report_type: '2024年中报',
|
||||
industries: [
|
||||
{ content: '华南地区', revenue: 54880000000, gross_margin: 68.5, profit_margin: 42.8, profit: 23488640000 },
|
||||
{ content: '华东地区', revenue: 47040000000, gross_margin: 67.1, profit_margin: 40.5, profit: 19051200000 },
|
||||
{ content: '华北地区', revenue: 31360000000, gross_margin: 64.8, profit_margin: 38.1, profit: 11948160000 },
|
||||
{ content: '西南地区', revenue: 15680000000, gross_margin: 63.5, profit_margin: 36.8, profit: 5770240000 },
|
||||
{ content: '其他地区', revenue: 7840000000, gross_margin: 62.1, profit_margin: 34.5, profit: 2704800000 },
|
||||
{ content: '合计', revenue: 156800000000, gross_margin: 66.8, profit_margin: 40.5, profit: 63504000000 },
|
||||
]
|
||||
},
|
||||
]
|
||||
},
|
||||
|
||||
@@ -92,48 +342,74 @@ export const generateFinancialData = (stockCode) => {
|
||||
publish_date: '2024-10-15'
|
||||
},
|
||||
|
||||
// 行业排名
|
||||
industryRank: {
|
||||
industry: '银行',
|
||||
total_companies: 42,
|
||||
rankings: [
|
||||
{ metric: '总资产', rank: 8, value: 5024560, percentile: 19 },
|
||||
{ metric: '营业收入', rank: 9, value: 162350, percentile: 21 },
|
||||
{ metric: '净利润', rank: 8, value: 52860, percentile: 19 },
|
||||
{ metric: 'ROE', rank: 12, value: 16.23, percentile: 29 },
|
||||
{ metric: '不良贷款率', rank: 18, value: 1.02, percentile: 43 }
|
||||
]
|
||||
},
|
||||
// 行业排名(数组格式,符合 IndustryRankingView 组件要求)
|
||||
industryRank: [
|
||||
{
|
||||
period: '2024-09-30',
|
||||
report_type: '三季报',
|
||||
rankings: [
|
||||
{
|
||||
industry_name: stockCode === '000001' ? '银行' : '制造业',
|
||||
level_description: '一级行业',
|
||||
metrics: {
|
||||
eps: { value: 2.72, rank: 8, industry_avg: 1.85 },
|
||||
bvps: { value: 15.23, rank: 12, industry_avg: 12.50 },
|
||||
roe: { value: 16.23, rank: 10, industry_avg: 12.00 },
|
||||
revenue_growth: { value: 8.2, rank: 15, industry_avg: 5.50 },
|
||||
profit_growth: { value: 12.5, rank: 9, industry_avg: 8.00 },
|
||||
operating_margin: { value: 32.56, rank: 6, industry_avg: 25.00 },
|
||||
debt_ratio: { value: 92.5, rank: 35, industry_avg: 88.00 },
|
||||
receivable_turnover: { value: 5.2, rank: 18, industry_avg: 4.80 }
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
// 期间对比
|
||||
periodComparison: {
|
||||
periods: ['Q3-2024', 'Q2-2024', 'Q1-2024', 'Q4-2023'],
|
||||
metrics: [
|
||||
{
|
||||
name: '营业收入',
|
||||
unit: '百万元',
|
||||
values: [41500, 40800, 40200, 40850],
|
||||
yoy: [8.2, 7.8, 8.5, 9.2]
|
||||
},
|
||||
{
|
||||
name: '净利润',
|
||||
unit: '百万元',
|
||||
values: [13420, 13180, 13050, 13210],
|
||||
yoy: [12.5, 11.2, 10.8, 12.3]
|
||||
},
|
||||
{
|
||||
name: 'ROE',
|
||||
unit: '%',
|
||||
values: [16.23, 15.98, 15.75, 16.02],
|
||||
yoy: [1.2, 0.8, 0.5, 1.0]
|
||||
},
|
||||
{
|
||||
name: 'EPS',
|
||||
unit: '元',
|
||||
values: [0.69, 0.68, 0.67, 0.68],
|
||||
yoy: [12.3, 11.5, 10.5, 12.0]
|
||||
// 期间对比 - 营收与利润趋势数据
|
||||
periodComparison: [
|
||||
{
|
||||
period: '2024-09-30',
|
||||
performance: {
|
||||
revenue: 41500000000, // 415亿
|
||||
net_profit: 13420000000 // 134.2亿
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
period: '2024-06-30',
|
||||
performance: {
|
||||
revenue: 40800000000, // 408亿
|
||||
net_profit: 13180000000 // 131.8亿
|
||||
}
|
||||
},
|
||||
{
|
||||
period: '2024-03-31',
|
||||
performance: {
|
||||
revenue: 40200000000, // 402亿
|
||||
net_profit: 13050000000 // 130.5亿
|
||||
}
|
||||
},
|
||||
{
|
||||
period: '2023-12-31',
|
||||
performance: {
|
||||
revenue: 40850000000, // 408.5亿
|
||||
net_profit: 13210000000 // 132.1亿
|
||||
}
|
||||
},
|
||||
{
|
||||
period: '2023-09-30',
|
||||
performance: {
|
||||
revenue: 38500000000, // 385亿
|
||||
net_profit: 11920000000 // 119.2亿
|
||||
}
|
||||
},
|
||||
{
|
||||
period: '2023-06-30',
|
||||
performance: {
|
||||
revenue: 37800000000, // 378亿
|
||||
net_profit: 11850000000 // 118.5亿
|
||||
}
|
||||
}
|
||||
]
|
||||
};
|
||||
};
|
||||
|
||||
@@ -24,8 +24,9 @@ export const generateMarketData = (stockCode) => {
|
||||
low: parseFloat(low.toFixed(2)),
|
||||
volume: Math.floor(Math.random() * 500000000) + 100000000, // 1-6亿股
|
||||
amount: Math.floor(Math.random() * 7000000000) + 1300000000, // 13-80亿元
|
||||
turnover_rate: (Math.random() * 2 + 0.5).toFixed(2), // 0.5-2.5%
|
||||
change_pct: (Math.random() * 6 - 3).toFixed(2) // -3% to +3%
|
||||
turnover_rate: parseFloat((Math.random() * 2 + 0.5).toFixed(2)), // 0.5-2.5%
|
||||
change_percent: parseFloat((Math.random() * 6 - 3).toFixed(2)), // -3% to +3%
|
||||
pe_ratio: parseFloat((Math.random() * 3 + 4).toFixed(2)) // 4-7
|
||||
};
|
||||
})
|
||||
},
|
||||
@@ -78,73 +79,119 @@ export const generateMarketData = (stockCode) => {
|
||||
}))
|
||||
},
|
||||
|
||||
// 股权质押
|
||||
// 股权质押 - 匹配 PledgeData[] 类型
|
||||
pledgeData: {
|
||||
success: true,
|
||||
data: {
|
||||
total_pledged: 25.6, // 质押比例%
|
||||
major_shareholders: [
|
||||
{ name: '中国平安保险集团', pledged_shares: 0, total_shares: 10168542300, pledge_ratio: 0 },
|
||||
{ name: '深圳市投资控股', pledged_shares: 50000000, total_shares: 382456100, pledge_ratio: 13.08 }
|
||||
],
|
||||
update_date: '2024-09-30'
|
||||
}
|
||||
data: Array(12).fill(null).map((_, i) => {
|
||||
const date = new Date();
|
||||
date.setMonth(date.getMonth() - (11 - i));
|
||||
return {
|
||||
end_date: date.toISOString().split('T')[0].slice(0, 7) + '-01',
|
||||
unrestricted_pledge: Math.floor(Math.random() * 1000000000) + 500000000,
|
||||
restricted_pledge: Math.floor(Math.random() * 200000000) + 50000000,
|
||||
total_pledge: Math.floor(Math.random() * 1200000000) + 550000000,
|
||||
total_shares: 19405918198,
|
||||
pledge_ratio: parseFloat((Math.random() * 3 + 6).toFixed(2)), // 6-9%
|
||||
pledge_count: Math.floor(Math.random() * 50) + 100 // 100-150
|
||||
};
|
||||
})
|
||||
},
|
||||
|
||||
// 市场摘要
|
||||
// 市场摘要 - 匹配 MarketSummary 类型
|
||||
summaryData: {
|
||||
success: true,
|
||||
data: {
|
||||
current_price: basePrice,
|
||||
change: 0.25,
|
||||
change_pct: 1.89,
|
||||
open: 13.35,
|
||||
high: 13.68,
|
||||
low: 13.28,
|
||||
volume: 345678900,
|
||||
amount: 4678900000,
|
||||
turnover_rate: 1.78,
|
||||
pe_ratio: 4.96,
|
||||
pb_ratio: 0.72,
|
||||
total_market_cap: 262300000000,
|
||||
circulating_market_cap: 262300000000
|
||||
stock_code: stockCode,
|
||||
stock_name: stockCode === '000001' ? '平安银行' : '示例股票',
|
||||
latest_trade: {
|
||||
close: basePrice,
|
||||
change_percent: 1.89,
|
||||
volume: 345678900,
|
||||
amount: 4678900000,
|
||||
turnover_rate: 1.78,
|
||||
pe_ratio: 4.96
|
||||
},
|
||||
latest_funding: {
|
||||
financing_balance: 5823000000,
|
||||
securities_balance: 125600000
|
||||
},
|
||||
latest_pledge: {
|
||||
pledge_ratio: 8.25
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
// 涨停分析
|
||||
// 涨停分析 - 返回数组格式,每个元素对应一个交易日
|
||||
riseAnalysisData: {
|
||||
success: true,
|
||||
data: {
|
||||
is_limit_up: false,
|
||||
limit_up_price: basePrice * 1.10,
|
||||
current_price: basePrice,
|
||||
distance_to_limit: 8.92, // %
|
||||
consecutive_days: 0,
|
||||
reason: '',
|
||||
concept_tags: ['银行', '深圳国资', 'MSCI', '沪深300']
|
||||
}
|
||||
data: Array(30).fill(null).map((_, i) => {
|
||||
const tradeDate = new Date(Date.now() - (29 - i) * 24 * 60 * 60 * 1000).toISOString().split('T')[0];
|
||||
const isLimitUp = Math.random() < 0.05; // 5%概率涨停
|
||||
return {
|
||||
trade_date: tradeDate,
|
||||
is_limit_up: isLimitUp,
|
||||
limit_up_price: (basePrice * 1.10).toFixed(2),
|
||||
current_price: (basePrice + (Math.random() - 0.5) * 0.5).toFixed(2),
|
||||
distance_to_limit: (Math.random() * 10).toFixed(2), // %
|
||||
consecutive_days: isLimitUp ? Math.floor(Math.random() * 3) + 1 : 0,
|
||||
reason: isLimitUp ? '业绩超预期' : '',
|
||||
concept_tags: ['银行', '深圳国资', 'MSCI', '沪深300'],
|
||||
analysis: isLimitUp ? '股价触及涨停板,资金流入明显' : '股价正常波动,交投活跃'
|
||||
};
|
||||
})
|
||||
},
|
||||
|
||||
// 最新分时数据
|
||||
// 最新分时数据 - 匹配 MinuteData 类型
|
||||
latestMinuteData: {
|
||||
success: true,
|
||||
data: Array(240).fill(null).map((_, i) => {
|
||||
const minute = 9 * 60 + 30 + i; // 从9:30开始
|
||||
const hour = Math.floor(minute / 60);
|
||||
const min = minute % 60;
|
||||
const time = `${hour.toString().padStart(2, '0')}:${min.toString().padStart(2, '0')}`;
|
||||
const randomChange = (Math.random() - 0.5) * 0.1;
|
||||
return {
|
||||
time,
|
||||
price: (basePrice + randomChange).toFixed(2),
|
||||
volume: Math.floor(Math.random() * 2000000) + 500000,
|
||||
avg_price: (basePrice + randomChange * 0.8).toFixed(2)
|
||||
};
|
||||
}),
|
||||
data: (() => {
|
||||
const minuteData = [];
|
||||
// 上午 9:30-11:30 (120分钟)
|
||||
for (let i = 0; i < 120; i++) {
|
||||
const hour = 9 + Math.floor((30 + i) / 60);
|
||||
const min = (30 + i) % 60;
|
||||
const time = `${hour.toString().padStart(2, '0')}:${min.toString().padStart(2, '0')}`;
|
||||
const randomChange = (Math.random() - 0.5) * 0.1;
|
||||
const open = parseFloat((basePrice + randomChange).toFixed(2));
|
||||
const close = parseFloat((basePrice + randomChange + (Math.random() - 0.5) * 0.05).toFixed(2));
|
||||
const high = parseFloat(Math.max(open, close, open + Math.random() * 0.05).toFixed(2));
|
||||
const low = parseFloat(Math.min(open, close, close - Math.random() * 0.05).toFixed(2));
|
||||
minuteData.push({
|
||||
time,
|
||||
open,
|
||||
close,
|
||||
high,
|
||||
low,
|
||||
volume: Math.floor(Math.random() * 2000000) + 500000,
|
||||
amount: Math.floor(Math.random() * 30000000) + 5000000
|
||||
});
|
||||
}
|
||||
// 下午 13:00-15:00 (120分钟)
|
||||
for (let i = 0; i < 120; i++) {
|
||||
const hour = 13 + Math.floor(i / 60);
|
||||
const min = i % 60;
|
||||
const time = `${hour.toString().padStart(2, '0')}:${min.toString().padStart(2, '0')}`;
|
||||
const randomChange = (Math.random() - 0.5) * 0.1;
|
||||
const open = parseFloat((basePrice + randomChange).toFixed(2));
|
||||
const close = parseFloat((basePrice + randomChange + (Math.random() - 0.5) * 0.05).toFixed(2));
|
||||
const high = parseFloat(Math.max(open, close, open + Math.random() * 0.05).toFixed(2));
|
||||
const low = parseFloat(Math.min(open, close, close - Math.random() * 0.05).toFixed(2));
|
||||
minuteData.push({
|
||||
time,
|
||||
open,
|
||||
close,
|
||||
high,
|
||||
low,
|
||||
volume: Math.floor(Math.random() * 1500000) + 400000,
|
||||
amount: Math.floor(Math.random() * 25000000) + 4000000
|
||||
});
|
||||
}
|
||||
return minuteData;
|
||||
})(),
|
||||
code: stockCode,
|
||||
name: stockCode === '000001' ? '平安银行' : '示例股票',
|
||||
trade_date: new Date().toISOString().split('T')[0],
|
||||
type: 'minute'
|
||||
type: '1min'
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
@@ -9,7 +9,8 @@ import {
|
||||
mockInvestmentPlans,
|
||||
mockCalendarEvents,
|
||||
mockSubscriptionCurrent,
|
||||
getCalendarEventsByDateRange
|
||||
getCalendarEventsByDateRange,
|
||||
getFollowedEvents
|
||||
} from '../data/account';
|
||||
|
||||
// 模拟网络延迟(毫秒)
|
||||
@@ -250,7 +251,7 @@ export const accountHandlers = [
|
||||
|
||||
// ==================== 事件关注管理 ====================
|
||||
|
||||
// 8. 获取关注的事件
|
||||
// 8. 获取关注的事件(使用内存状态动态返回)
|
||||
http.get('/api/account/events/following', async () => {
|
||||
await delay(NETWORK_DELAY);
|
||||
|
||||
@@ -262,11 +263,14 @@ export const accountHandlers = [
|
||||
);
|
||||
}
|
||||
|
||||
console.log('[Mock] 获取关注的事件');
|
||||
// 从内存存储获取已关注的事件列表
|
||||
const followedEvents = getFollowedEvents();
|
||||
|
||||
console.log('[Mock] 获取关注的事件, 数量:', followedEvents.length);
|
||||
|
||||
return HttpResponse.json({
|
||||
success: true,
|
||||
data: mockFollowingEvents
|
||||
data: followedEvents
|
||||
});
|
||||
}),
|
||||
|
||||
|
||||
@@ -1,16 +1,28 @@
|
||||
// src/mocks/handlers/bytedesk.js
|
||||
/**
|
||||
* Bytedesk 客服 Widget MSW Handler
|
||||
* 使用 passthrough 让请求通过到真实服务器,消除 MSW 警告
|
||||
* Mock 模式下返回模拟数据
|
||||
*/
|
||||
|
||||
import { http, passthrough } from 'msw';
|
||||
import { http, HttpResponse, passthrough } from 'msw';
|
||||
|
||||
export const bytedeskHandlers = [
|
||||
// Bytedesk API 请求 - 直接 passthrough
|
||||
// 匹配 /bytedesk/* 路径(通过代理访问后端)
|
||||
// 未读消息数量
|
||||
http.get('/bytedesk/visitor/api/v1/message/unread/count', () => {
|
||||
return HttpResponse.json({
|
||||
code: 200,
|
||||
message: 'success',
|
||||
data: { count: 0 },
|
||||
});
|
||||
}),
|
||||
|
||||
// 其他 Bytedesk API - 返回通用成功响应
|
||||
http.all('/bytedesk/*', () => {
|
||||
return passthrough();
|
||||
return HttpResponse.json({
|
||||
code: 200,
|
||||
message: 'success',
|
||||
data: null,
|
||||
});
|
||||
}),
|
||||
|
||||
// Bytedesk 外部 CDN/服务请求
|
||||
|
||||
@@ -43,12 +43,10 @@ export const companyHandlers = [
|
||||
const { stockCode } = params;
|
||||
const data = getCompanyData(stockCode);
|
||||
|
||||
// 直接返回 keyFactorsTimeline 对象(包含 key_factors 和 development_timeline)
|
||||
return HttpResponse.json({
|
||||
success: true,
|
||||
data: {
|
||||
timeline: data.keyFactorsTimeline,
|
||||
total: data.keyFactorsTimeline.length
|
||||
}
|
||||
data: data.keyFactorsTimeline
|
||||
});
|
||||
}),
|
||||
|
||||
@@ -69,10 +67,19 @@ export const companyHandlers = [
|
||||
await delay(150);
|
||||
const { stockCode } = params;
|
||||
const data = getCompanyData(stockCode);
|
||||
const raw = data.actualControl;
|
||||
|
||||
// 数据已经是数组格式,只做数值转换(holding_ratio 从 0-100 转为 0-1)
|
||||
const formatted = Array.isArray(raw)
|
||||
? raw.map(item => ({
|
||||
...item,
|
||||
holding_ratio: item.holding_ratio > 1 ? item.holding_ratio / 100 : item.holding_ratio,
|
||||
}))
|
||||
: [];
|
||||
|
||||
return HttpResponse.json({
|
||||
success: true,
|
||||
data: data.actualControl
|
||||
data: formatted
|
||||
});
|
||||
}),
|
||||
|
||||
@@ -81,10 +88,19 @@ export const companyHandlers = [
|
||||
await delay(150);
|
||||
const { stockCode } = params;
|
||||
const data = getCompanyData(stockCode);
|
||||
const raw = data.concentration;
|
||||
|
||||
// 数据已经是数组格式,只做数值转换(holding_ratio 从 0-100 转为 0-1)
|
||||
const formatted = Array.isArray(raw)
|
||||
? raw.map(item => ({
|
||||
...item,
|
||||
holding_ratio: item.holding_ratio > 1 ? item.holding_ratio / 100 : item.holding_ratio,
|
||||
}))
|
||||
: [];
|
||||
|
||||
return HttpResponse.json({
|
||||
success: true,
|
||||
data: data.concentration
|
||||
data: formatted
|
||||
});
|
||||
}),
|
||||
|
||||
@@ -212,4 +228,64 @@ export const companyHandlers = [
|
||||
data: data.forecastReport || null
|
||||
});
|
||||
}),
|
||||
|
||||
// 14. 价值链关联公司
|
||||
http.get('/api/company/value-chain/related-companies', async ({ request }) => {
|
||||
await delay(300);
|
||||
|
||||
const url = new URL(request.url);
|
||||
const nodeName = url.searchParams.get('node_name') || '';
|
||||
|
||||
console.log('[Mock] 获取价值链关联公司:', nodeName);
|
||||
|
||||
// 生成模拟的关联公司数据
|
||||
const relatedCompanies = [
|
||||
{
|
||||
stock_code: '601318',
|
||||
stock_name: '中国平安',
|
||||
industry: '保险',
|
||||
relation: '同业竞争',
|
||||
market_cap: 9200,
|
||||
change_pct: 1.25
|
||||
},
|
||||
{
|
||||
stock_code: '600036',
|
||||
stock_name: '招商银行',
|
||||
industry: '银行',
|
||||
relation: '核心供应商',
|
||||
market_cap: 8500,
|
||||
change_pct: 0.85
|
||||
},
|
||||
{
|
||||
stock_code: '601166',
|
||||
stock_name: '兴业银行',
|
||||
industry: '银行',
|
||||
relation: '同业竞争',
|
||||
market_cap: 4200,
|
||||
change_pct: -0.32
|
||||
},
|
||||
{
|
||||
stock_code: '601398',
|
||||
stock_name: '工商银行',
|
||||
industry: '银行',
|
||||
relation: '同业竞争',
|
||||
market_cap: 15000,
|
||||
change_pct: 0.15
|
||||
},
|
||||
{
|
||||
stock_code: '601288',
|
||||
stock_name: '农业银行',
|
||||
industry: '银行',
|
||||
relation: '同业竞争',
|
||||
market_cap: 12000,
|
||||
change_pct: -0.08
|
||||
}
|
||||
];
|
||||
|
||||
return HttpResponse.json({
|
||||
success: true,
|
||||
data: relatedCompanies,
|
||||
node_name: nodeName
|
||||
});
|
||||
}),
|
||||
];
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
import { http, HttpResponse } from 'msw';
|
||||
import { getEventRelatedStocks, generateMockEvents, generateHotEvents, generatePopularKeywords, generateDynamicNewsEvents } from '../data/events';
|
||||
import { getMockFutureEvents, getMockEventCountsForMonth } from '../data/account';
|
||||
import { getMockFutureEvents, getMockEventCountsForMonth, toggleEventFollowStatus, isEventFollowed } from '../data/account';
|
||||
import { generatePopularConcepts } from './concept';
|
||||
|
||||
// 模拟网络延迟
|
||||
@@ -119,9 +119,12 @@ export const eventHandlers = [
|
||||
try {
|
||||
const result = generateMockEvents(params);
|
||||
|
||||
// 返回格式兼容 NewsPanel 期望的结构
|
||||
// NewsPanel 期望: { success, data: [], pagination: {} }
|
||||
return HttpResponse.json({
|
||||
success: true,
|
||||
data: result,
|
||||
data: result.events, // 事件数组
|
||||
pagination: result.pagination, // 分页信息
|
||||
message: '获取成功'
|
||||
});
|
||||
} catch (error) {
|
||||
@@ -135,16 +138,14 @@ export const eventHandlers = [
|
||||
{
|
||||
success: false,
|
||||
error: '获取事件列表失败',
|
||||
data: {
|
||||
events: [],
|
||||
pagination: {
|
||||
page: 1,
|
||||
per_page: 10,
|
||||
total: 0,
|
||||
pages: 0, // ← 对齐后端字段名
|
||||
has_prev: false, // ← 对齐后端
|
||||
has_next: false // ← 对齐后端
|
||||
}
|
||||
data: [],
|
||||
pagination: {
|
||||
page: 1,
|
||||
per_page: 10,
|
||||
total: 0,
|
||||
pages: 0,
|
||||
has_prev: false,
|
||||
has_next: false
|
||||
}
|
||||
},
|
||||
{ status: 500 }
|
||||
@@ -260,15 +261,19 @@ export const eventHandlers = [
|
||||
await delay(200);
|
||||
|
||||
const { eventId } = params;
|
||||
const numericEventId = parseInt(eventId, 10);
|
||||
|
||||
console.log('[Mock] 获取事件详情, eventId:', eventId);
|
||||
console.log('[Mock] 获取事件详情, eventId:', numericEventId);
|
||||
|
||||
try {
|
||||
// 检查是否已关注
|
||||
const isFollowing = isEventFollowed(numericEventId);
|
||||
|
||||
// 返回模拟的事件详情数据
|
||||
return HttpResponse.json({
|
||||
success: true,
|
||||
data: {
|
||||
id: parseInt(eventId),
|
||||
id: numericEventId,
|
||||
title: `测试事件 ${eventId} - 重大政策发布`,
|
||||
description: '这是一个模拟的事件描述,用于开发测试。该事件涉及重要政策变化,可能对相关板块产生显著影响。建议关注后续发展动态。',
|
||||
importance: ['S', 'A', 'B', 'C'][Math.floor(Math.random() * 4)],
|
||||
@@ -278,7 +283,7 @@ export const eventHandlers = [
|
||||
related_avg_chg: parseFloat((Math.random() * 10 - 5).toFixed(2)),
|
||||
follower_count: Math.floor(Math.random() * 500) + 50,
|
||||
view_count: Math.floor(Math.random() * 5000) + 100,
|
||||
is_following: false,
|
||||
is_following: isFollowing, // 使用内存状态
|
||||
post_count: Math.floor(Math.random() * 50),
|
||||
expectation_surprise_score: parseFloat((Math.random() * 100).toFixed(1)),
|
||||
},
|
||||
@@ -297,6 +302,45 @@ export const eventHandlers = [
|
||||
}
|
||||
}),
|
||||
|
||||
// 获取事件超预期得分
|
||||
http.get('/api/events/:eventId/expectation-score', async ({ params }) => {
|
||||
await delay(200);
|
||||
|
||||
const { eventId } = params;
|
||||
|
||||
console.log('[Mock] 获取事件超预期得分, eventId:', eventId);
|
||||
|
||||
try {
|
||||
// 生成模拟的超预期得分数据
|
||||
const score = parseFloat((Math.random() * 100).toFixed(1));
|
||||
const avgChange = parseFloat((Math.random() * 10 - 2).toFixed(2));
|
||||
const maxChange = parseFloat((Math.random() * 15).toFixed(2));
|
||||
|
||||
return HttpResponse.json({
|
||||
success: true,
|
||||
data: {
|
||||
event_id: parseInt(eventId),
|
||||
expectation_score: score,
|
||||
avg_change: avgChange,
|
||||
max_change: maxChange,
|
||||
stock_count: Math.floor(Math.random() * 20) + 5,
|
||||
updated_at: new Date().toISOString(),
|
||||
},
|
||||
message: '获取成功'
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('[Mock] 获取事件超预期得分失败:', error);
|
||||
return HttpResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: '获取事件超预期得分失败',
|
||||
data: null
|
||||
},
|
||||
{ status: 500 }
|
||||
);
|
||||
}
|
||||
}),
|
||||
|
||||
// 获取事件相关股票
|
||||
http.get('/api/events/:eventId/stocks', async ({ params }) => {
|
||||
await delay(300);
|
||||
@@ -356,19 +400,29 @@ export const eventHandlers = [
|
||||
}
|
||||
}),
|
||||
|
||||
// 切换事件关注状态
|
||||
http.post('/api/events/:eventId/follow', async ({ params }) => {
|
||||
// 切换事件关注状态(使用内存状态管理)
|
||||
http.post('/api/events/:eventId/follow', async ({ params, request }) => {
|
||||
await delay(200);
|
||||
|
||||
const { eventId } = params;
|
||||
const numericEventId = parseInt(eventId, 10);
|
||||
|
||||
console.log('[Mock] 切换事件关注状态, eventId:', eventId);
|
||||
console.log('[Mock] 切换事件关注状态, eventId:', numericEventId);
|
||||
|
||||
try {
|
||||
// 模拟切换逻辑:随机生成关注状态
|
||||
// 实际应用中,这里应该从某个状态存储中读取和更新
|
||||
const isFollowing = Math.random() > 0.5;
|
||||
const followerCount = Math.floor(Math.random() * 1000) + 100;
|
||||
// 尝试从请求体获取事件数据(用于新关注时保存完整信息)
|
||||
let eventData = null;
|
||||
try {
|
||||
const body = await request.json();
|
||||
if (body && body.title) {
|
||||
eventData = body;
|
||||
}
|
||||
} catch {
|
||||
// 没有请求体或解析失败,忽略
|
||||
}
|
||||
|
||||
// 使用内存状态管理切换关注
|
||||
const { isFollowing, followerCount } = toggleEventFollowStatus(numericEventId, eventData);
|
||||
|
||||
return HttpResponse.json({
|
||||
success: true,
|
||||
|
||||
@@ -8,6 +8,53 @@ import { generateMarketData } from '../data/market';
|
||||
const delay = (ms) => new Promise(resolve => setTimeout(resolve, ms));
|
||||
|
||||
export const marketHandlers = [
|
||||
// 0. 指数实时行情数据
|
||||
http.get('/api/index/:indexCode/realtime', async ({ params }) => {
|
||||
await delay(100);
|
||||
const { indexCode } = params;
|
||||
|
||||
console.log('[Mock] 获取指数实时行情, indexCode:', indexCode);
|
||||
|
||||
// 指数基础数据
|
||||
const indexData = {
|
||||
'000001': { name: '上证指数', basePrice: 3200, baseVolume: 3500 },
|
||||
'399001': { name: '深证成指', basePrice: 10500, baseVolume: 4200 },
|
||||
'399006': { name: '创业板指', basePrice: 2100, baseVolume: 1800 },
|
||||
'000300': { name: '沪深300', basePrice: 3800, baseVolume: 2800 },
|
||||
'000016': { name: '上证50', basePrice: 2600, baseVolume: 1200 },
|
||||
'000905': { name: '中证500', basePrice: 5800, baseVolume: 1500 },
|
||||
};
|
||||
|
||||
const baseData = indexData[indexCode] || { name: `指数${indexCode}`, basePrice: 3000, baseVolume: 2000 };
|
||||
|
||||
// 生成随机波动
|
||||
const changePercent = parseFloat((Math.random() * 4 - 2).toFixed(2)); // -2% ~ +2%
|
||||
const price = parseFloat((baseData.basePrice * (1 + changePercent / 100)).toFixed(2));
|
||||
const change = parseFloat((price - baseData.basePrice).toFixed(2));
|
||||
const volume = parseFloat((baseData.baseVolume * (0.8 + Math.random() * 0.4)).toFixed(2)); // 80%-120% of base
|
||||
const amount = parseFloat((volume * price / 10000).toFixed(2)); // 亿元
|
||||
|
||||
return HttpResponse.json({
|
||||
success: true,
|
||||
data: {
|
||||
index_code: indexCode,
|
||||
index_name: baseData.name,
|
||||
current_price: price,
|
||||
change: change,
|
||||
change_percent: changePercent,
|
||||
open_price: parseFloat((baseData.basePrice * (1 + (Math.random() * 0.01 - 0.005))).toFixed(2)),
|
||||
high_price: parseFloat((price * (1 + Math.random() * 0.01)).toFixed(2)),
|
||||
low_price: parseFloat((price * (1 - Math.random() * 0.01)).toFixed(2)),
|
||||
prev_close: baseData.basePrice,
|
||||
volume: volume, // 亿手
|
||||
amount: amount, // 亿元
|
||||
update_time: new Date().toISOString(),
|
||||
market_status: 'trading', // trading, closed, pre-market, after-hours
|
||||
},
|
||||
message: '获取成功'
|
||||
});
|
||||
}),
|
||||
|
||||
// 1. 成交数据
|
||||
http.get('/api/market/trade/:stockCode', async ({ params, request }) => {
|
||||
await delay(200);
|
||||
@@ -299,7 +346,173 @@ export const marketHandlers = [
|
||||
});
|
||||
}),
|
||||
|
||||
// 11. 市场统计数据(个股中心页面使用)
|
||||
// 11. 热点概览数据(大盘分时 + 概念异动)
|
||||
http.get('/api/market/hotspot-overview', async ({ request }) => {
|
||||
await delay(300);
|
||||
const url = new URL(request.url);
|
||||
const date = url.searchParams.get('date');
|
||||
|
||||
const tradeDate = date || new Date().toISOString().split('T')[0];
|
||||
|
||||
// 生成分时数据(240个点,9:30-11:30 + 13:00-15:00)
|
||||
const timeline = [];
|
||||
const basePrice = 3900 + Math.random() * 100; // 基准价格 3900-4000
|
||||
const prevClose = basePrice;
|
||||
let currentPrice = basePrice;
|
||||
let cumulativeVolume = 0;
|
||||
|
||||
// 上午时段 9:30-11:30 (120分钟)
|
||||
for (let i = 0; i < 120; i++) {
|
||||
const hour = 9 + Math.floor((i + 30) / 60);
|
||||
const minute = (i + 30) % 60;
|
||||
const time = `${hour.toString().padStart(2, '0')}:${minute.toString().padStart(2, '0')}`;
|
||||
|
||||
// 模拟价格波动
|
||||
const volatility = 0.002; // 0.2%波动
|
||||
const drift = (Math.random() - 0.5) * 0.001; // 微小趋势
|
||||
currentPrice = currentPrice * (1 + (Math.random() - 0.5) * volatility + drift);
|
||||
|
||||
const volume = Math.floor(Math.random() * 500000 + 100000); // 成交量
|
||||
cumulativeVolume += volume;
|
||||
|
||||
timeline.push({
|
||||
time,
|
||||
price: parseFloat(currentPrice.toFixed(2)),
|
||||
volume: cumulativeVolume,
|
||||
change_pct: parseFloat(((currentPrice - prevClose) / prevClose * 100).toFixed(2))
|
||||
});
|
||||
}
|
||||
|
||||
// 下午时段 13:00-15:00 (120分钟)
|
||||
for (let i = 0; i < 120; i++) {
|
||||
const hour = 13 + Math.floor(i / 60);
|
||||
const minute = i % 60;
|
||||
const time = `${hour.toString().padStart(2, '0')}:${minute.toString().padStart(2, '0')}`;
|
||||
|
||||
// 下午波动略小
|
||||
const volatility = 0.0015;
|
||||
const drift = (Math.random() - 0.5) * 0.0008;
|
||||
currentPrice = currentPrice * (1 + (Math.random() - 0.5) * volatility + drift);
|
||||
|
||||
const volume = Math.floor(Math.random() * 400000 + 80000);
|
||||
cumulativeVolume += volume;
|
||||
|
||||
timeline.push({
|
||||
time,
|
||||
price: parseFloat(currentPrice.toFixed(2)),
|
||||
volume: cumulativeVolume,
|
||||
change_pct: parseFloat(((currentPrice - prevClose) / prevClose * 100).toFixed(2))
|
||||
});
|
||||
}
|
||||
|
||||
// 生成概念异动数据
|
||||
const conceptNames = [
|
||||
'人工智能', 'AI眼镜', '机器人', '核电', '国企', '卫星导航',
|
||||
'福建自贸区', '两岸融合', 'CRO', '三季报增长', '百货零售',
|
||||
'人形机器人', '央企', '数据中心', 'CPO', '新能源', '电网设备',
|
||||
'氢能源', '算力租赁', '厦门国资', '乳业', '低空安防', '创新药',
|
||||
'商业航天', '控制权变更', '文化传媒', '海峡两岸'
|
||||
];
|
||||
|
||||
const alertTypes = ['surge_up', 'surge_down', 'volume_spike', 'limit_up', 'rank_jump'];
|
||||
|
||||
// 生成 15-25 个异动
|
||||
const alertCount = Math.floor(Math.random() * 10) + 15;
|
||||
const alerts = [];
|
||||
const usedTimes = new Set();
|
||||
|
||||
for (let i = 0; i < alertCount; i++) {
|
||||
// 随机选择一个时间点
|
||||
let timeIdx;
|
||||
let attempts = 0;
|
||||
do {
|
||||
timeIdx = Math.floor(Math.random() * timeline.length);
|
||||
attempts++;
|
||||
} while (usedTimes.has(timeIdx) && attempts < 50);
|
||||
|
||||
if (attempts >= 50) continue;
|
||||
|
||||
// 同一时间可以有多个异动
|
||||
const time = timeline[timeIdx].time;
|
||||
const conceptName = conceptNames[Math.floor(Math.random() * conceptNames.length)];
|
||||
const alertType = alertTypes[Math.floor(Math.random() * alertTypes.length)];
|
||||
|
||||
// 根据类型生成 alpha
|
||||
let alpha;
|
||||
if (alertType === 'surge_up') {
|
||||
alpha = parseFloat((Math.random() * 3 + 2).toFixed(2)); // +2% ~ +5%
|
||||
} else if (alertType === 'surge_down') {
|
||||
alpha = parseFloat((-Math.random() * 3 - 1.5).toFixed(2)); // -1.5% ~ -4.5%
|
||||
} else {
|
||||
alpha = parseFloat((Math.random() * 4 - 1).toFixed(2)); // -1% ~ +3%
|
||||
}
|
||||
|
||||
const finalScore = Math.floor(Math.random() * 40 + 45); // 45-85分
|
||||
const ruleScore = Math.floor(Math.random() * 30 + 40);
|
||||
const mlScore = Math.floor(Math.random() * 30 + 40);
|
||||
|
||||
alerts.push({
|
||||
concept_id: `CONCEPT_${1000 + i}`,
|
||||
concept_name: conceptName,
|
||||
time,
|
||||
alert_type: alertType,
|
||||
alpha,
|
||||
alpha_delta: parseFloat((Math.random() * 2 - 0.5).toFixed(2)),
|
||||
amt_ratio: parseFloat((Math.random() * 5 + 1).toFixed(2)),
|
||||
limit_up_count: alertType === 'limit_up' ? Math.floor(Math.random() * 5 + 1) : 0,
|
||||
limit_up_ratio: parseFloat((Math.random() * 0.3).toFixed(3)),
|
||||
final_score: finalScore,
|
||||
rule_score: ruleScore,
|
||||
ml_score: mlScore,
|
||||
trigger_reason: finalScore >= 65 ? '规则强信号' : (mlScore >= 70 ? 'ML强信号' : '融合触发'),
|
||||
importance_score: parseFloat((finalScore / 100).toFixed(2)),
|
||||
index_price: timeline[timeIdx].price
|
||||
});
|
||||
}
|
||||
|
||||
// 按时间排序
|
||||
alerts.sort((a, b) => a.time.localeCompare(b.time));
|
||||
|
||||
// 统计异动类型
|
||||
const alertSummary = alerts.reduce((acc, alert) => {
|
||||
acc[alert.alert_type] = (acc[alert.alert_type] || 0) + 1;
|
||||
return acc;
|
||||
}, {});
|
||||
|
||||
// 计算指数统计
|
||||
const prices = timeline.map(t => t.price);
|
||||
const latestPrice = prices[prices.length - 1];
|
||||
const highPrice = Math.max(...prices);
|
||||
const lowPrice = Math.min(...prices);
|
||||
const changePct = ((latestPrice - prevClose) / prevClose * 100);
|
||||
|
||||
console.log('[Mock Market] 获取热点概览数据:', {
|
||||
date: tradeDate,
|
||||
timelinePoints: timeline.length,
|
||||
alertCount: alerts.length
|
||||
});
|
||||
|
||||
return HttpResponse.json({
|
||||
success: true,
|
||||
data: {
|
||||
index: {
|
||||
code: '000001.SH',
|
||||
name: '上证指数',
|
||||
latest_price: latestPrice,
|
||||
prev_close: prevClose,
|
||||
high: highPrice,
|
||||
low: lowPrice,
|
||||
change_pct: parseFloat(changePct.toFixed(2)),
|
||||
timeline
|
||||
},
|
||||
alerts,
|
||||
alert_summary: alertSummary
|
||||
},
|
||||
trade_date: tradeDate
|
||||
});
|
||||
}),
|
||||
|
||||
// 12. 市场统计数据(个股中心页面使用)
|
||||
http.get('/api/market/statistics', async ({ request }) => {
|
||||
await delay(200);
|
||||
const url = new URL(request.url);
|
||||
|
||||
@@ -341,6 +341,68 @@ export const stockHandlers = [
|
||||
}
|
||||
}),
|
||||
|
||||
// 获取股票业绩预告
|
||||
http.get('/api/stock/:stockCode/forecast', async ({ params }) => {
|
||||
await delay(200);
|
||||
|
||||
const { stockCode } = params;
|
||||
console.log('[Mock Stock] 获取业绩预告:', { stockCode });
|
||||
|
||||
// 生成股票列表用于查找名称
|
||||
const stockList = generateStockList();
|
||||
const stockInfo = stockList.find(s => s.code === stockCode.replace(/\.(SH|SZ)$/i, ''));
|
||||
const stockName = stockInfo?.name || `股票${stockCode}`;
|
||||
|
||||
// 业绩预告类型列表
|
||||
const forecastTypes = ['预增', '预减', '略增', '略减', '扭亏', '续亏', '首亏', '续盈'];
|
||||
|
||||
// 生成业绩预告数据
|
||||
const forecasts = [
|
||||
{
|
||||
forecast_type: '预增',
|
||||
report_date: '2024年年报',
|
||||
content: `${stockName}预计2024年度归属于上市公司股东的净利润为58亿元至62亿元,同比增长10%至17%。`,
|
||||
reason: '报告期内,公司主营业务收入稳步增长,产品结构持续优化,毛利率提升;同时公司加大研发投入,新产品市场表现良好。',
|
||||
change_range: {
|
||||
lower: 10,
|
||||
upper: 17
|
||||
},
|
||||
publish_date: '2024-10-15'
|
||||
},
|
||||
{
|
||||
forecast_type: '略增',
|
||||
report_date: '2024年三季报',
|
||||
content: `${stockName}预计2024年1-9月归属于上市公司股东的净利润为42亿元至45亿元,同比增长5%至12%。`,
|
||||
reason: '公司积极拓展市场渠道,销售规模持续扩大,经营效益稳步提升。',
|
||||
change_range: {
|
||||
lower: 5,
|
||||
upper: 12
|
||||
},
|
||||
publish_date: '2024-07-12'
|
||||
},
|
||||
{
|
||||
forecast_type: forecastTypes[Math.floor(Math.random() * forecastTypes.length)],
|
||||
report_date: '2024年中报',
|
||||
content: `${stockName}预计2024年上半年归属于上市公司股东的净利润为28亿元至30亿元。`,
|
||||
reason: '受益于行业景气度回升及公司降本增效措施效果显现,经营业绩同比有所改善。',
|
||||
change_range: {
|
||||
lower: 3,
|
||||
upper: 8
|
||||
},
|
||||
publish_date: '2024-04-20'
|
||||
}
|
||||
];
|
||||
|
||||
return HttpResponse.json({
|
||||
success: true,
|
||||
data: {
|
||||
stock_code: stockCode,
|
||||
stock_name: stockName,
|
||||
forecasts: forecasts
|
||||
}
|
||||
});
|
||||
}),
|
||||
|
||||
// 获取股票报价(批量)
|
||||
http.post('/api/stock/quotes', async ({ request }) => {
|
||||
await delay(200);
|
||||
@@ -368,6 +430,25 @@ export const stockHandlers = [
|
||||
stockMap[s.code] = s.name;
|
||||
});
|
||||
|
||||
// 行业和指数映射表
|
||||
const stockIndustryMap = {
|
||||
'000001': { industry_l1: '金融', industry: '银行', index_tags: ['沪深300', '上证180'] },
|
||||
'600519': { industry_l1: '消费', industry: '白酒', index_tags: ['沪深300', '上证50'] },
|
||||
'300750': { industry_l1: '工业', industry: '电池', index_tags: ['创业板50'] },
|
||||
'601318': { industry_l1: '金融', industry: '保险', index_tags: ['沪深300', '上证50'] },
|
||||
'600036': { industry_l1: '金融', industry: '银行', index_tags: ['沪深300', '上证50'] },
|
||||
'000858': { industry_l1: '消费', industry: '白酒', index_tags: ['沪深300'] },
|
||||
'002594': { industry_l1: '汽车', industry: '乘用车', index_tags: ['沪深300', '创业板指'] },
|
||||
};
|
||||
|
||||
const defaultIndustries = [
|
||||
{ industry_l1: '科技', industry: '软件' },
|
||||
{ industry_l1: '医药', industry: '化学制药' },
|
||||
{ industry_l1: '消费', industry: '食品' },
|
||||
{ industry_l1: '金融', industry: '证券' },
|
||||
{ industry_l1: '工业', industry: '机械' },
|
||||
];
|
||||
|
||||
// 为每只股票生成报价数据
|
||||
const quotesData = {};
|
||||
codes.forEach(stockCode => {
|
||||
@@ -380,6 +461,11 @@ export const stockHandlers = [
|
||||
// 昨收
|
||||
const prevClose = parseFloat((basePrice - change).toFixed(2));
|
||||
|
||||
// 获取行业和指数信息
|
||||
const codeWithoutSuffix = stockCode.replace(/\.(SH|SZ)$/i, '');
|
||||
const industryInfo = stockIndustryMap[codeWithoutSuffix] ||
|
||||
defaultIndustries[Math.floor(Math.random() * defaultIndustries.length)];
|
||||
|
||||
quotesData[stockCode] = {
|
||||
code: stockCode,
|
||||
name: stockMap[stockCode] || `股票${stockCode}`,
|
||||
@@ -393,7 +479,23 @@ export const stockHandlers = [
|
||||
volume: Math.floor(Math.random() * 100000000),
|
||||
amount: parseFloat((Math.random() * 10000000000).toFixed(2)),
|
||||
market: stockCode.startsWith('6') ? 'SH' : 'SZ',
|
||||
update_time: new Date().toISOString()
|
||||
update_time: new Date().toISOString(),
|
||||
// 行业和指数标签
|
||||
industry_l1: industryInfo.industry_l1,
|
||||
industry: industryInfo.industry,
|
||||
index_tags: industryInfo.index_tags || [],
|
||||
// 关键指标
|
||||
pe: parseFloat((Math.random() * 50 + 5).toFixed(2)),
|
||||
eps: parseFloat((Math.random() * 5 + 0.1).toFixed(3)),
|
||||
pb: parseFloat((Math.random() * 8 + 0.5).toFixed(2)),
|
||||
market_cap: `${(Math.random() * 5000 + 100).toFixed(0)}亿`,
|
||||
week52_low: parseFloat((basePrice * 0.7).toFixed(2)),
|
||||
week52_high: parseFloat((basePrice * 1.3).toFixed(2)),
|
||||
// 主力动态
|
||||
main_net_inflow: parseFloat((Math.random() * 10 - 5).toFixed(2)),
|
||||
institution_holding: parseFloat((Math.random() * 50 + 10).toFixed(2)),
|
||||
buy_ratio: parseFloat((Math.random() * 40 + 30).toFixed(2)),
|
||||
sell_ratio: parseFloat((100 - (Math.random() * 40 + 30)).toFixed(2))
|
||||
};
|
||||
});
|
||||
|
||||
|
||||
@@ -35,9 +35,9 @@ export const lazyComponents = {
|
||||
|
||||
// 公司相关模块
|
||||
CompanyIndex: React.lazy(() => import('@views/Company')),
|
||||
ForecastReport: React.lazy(() => import('@views/Company/ForecastReport')),
|
||||
FinancialPanorama: React.lazy(() => import('@views/Company/FinancialPanorama')),
|
||||
MarketDataView: React.lazy(() => import('@views/Company/MarketDataView')),
|
||||
ForecastReport: React.lazy(() => import('@views/Company/components/ForecastReport')),
|
||||
FinancialPanorama: React.lazy(() => import('@views/Company/components/FinancialPanorama')),
|
||||
MarketDataView: React.lazy(() => import('@views/Company/components/MarketDataView')),
|
||||
|
||||
// Agent模块
|
||||
AgentChat: React.lazy(() => import('@views/AgentChat')),
|
||||
|
||||
@@ -608,14 +608,40 @@ const communityDataSlice = createSlice({
|
||||
state.error[stateKey] = action.payload;
|
||||
logger.error('CommunityData', `${stateKey} 加载失败`, new Error(action.payload));
|
||||
})
|
||||
// toggleEventFollow
|
||||
// ===== toggleEventFollow(乐观更新)=====
|
||||
// pending: 立即切换状态
|
||||
.addCase(toggleEventFollow.pending, (state, action) => {
|
||||
const eventId = action.meta.arg;
|
||||
const current = state.eventFollowStatus[eventId];
|
||||
// 乐观切换:如果当前已关注则变为未关注,反之亦然
|
||||
state.eventFollowStatus[eventId] = {
|
||||
isFollowing: !(current?.isFollowing),
|
||||
followerCount: current?.followerCount ?? 0
|
||||
};
|
||||
logger.debug('CommunityData', 'toggleEventFollow pending (乐观更新)', {
|
||||
eventId,
|
||||
newIsFollowing: !(current?.isFollowing)
|
||||
});
|
||||
})
|
||||
// rejected: 回滚状态
|
||||
.addCase(toggleEventFollow.rejected, (state, action) => {
|
||||
const eventId = action.meta.arg;
|
||||
const current = state.eventFollowStatus[eventId];
|
||||
// 回滚:恢复到之前的状态(再次切换回去)
|
||||
state.eventFollowStatus[eventId] = {
|
||||
isFollowing: !(current?.isFollowing),
|
||||
followerCount: current?.followerCount ?? 0
|
||||
};
|
||||
logger.error('CommunityData', 'toggleEventFollow rejected (已回滚)', {
|
||||
eventId,
|
||||
error: action.payload
|
||||
});
|
||||
})
|
||||
// fulfilled: 使用 API 返回的准确数据覆盖
|
||||
.addCase(toggleEventFollow.fulfilled, (state, action) => {
|
||||
const { eventId, isFollowing, followerCount } = action.payload;
|
||||
state.eventFollowStatus[eventId] = { isFollowing, followerCount };
|
||||
logger.debug('CommunityData', 'toggleEventFollow fulfilled', { eventId, isFollowing, followerCount });
|
||||
})
|
||||
.addCase(toggleEventFollow.rejected, (_state, action) => {
|
||||
logger.error('CommunityData', 'toggleEventFollow rejected', action.payload);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
@@ -4,6 +4,56 @@ import { eventService, stockService } from '../../services/eventService';
|
||||
import { logger } from '../../utils/logger';
|
||||
import { getApiBase } from '../../utils/apiConfig';
|
||||
|
||||
// ==================== Watchlist 缓存配置 ====================
|
||||
const WATCHLIST_CACHE_KEY = 'watchlist_cache';
|
||||
const WATCHLIST_CACHE_DURATION = 7 * 24 * 60 * 60 * 1000; // 7天
|
||||
|
||||
/**
|
||||
* 从 localStorage 读取自选股缓存
|
||||
*/
|
||||
const loadWatchlistFromCache = () => {
|
||||
try {
|
||||
const cached = localStorage.getItem(WATCHLIST_CACHE_KEY);
|
||||
if (!cached) return null;
|
||||
|
||||
const { data, timestamp } = JSON.parse(cached);
|
||||
const now = Date.now();
|
||||
|
||||
// 检查缓存是否过期(7天)
|
||||
if (now - timestamp > WATCHLIST_CACHE_DURATION) {
|
||||
localStorage.removeItem(WATCHLIST_CACHE_KEY);
|
||||
logger.debug('stockSlice', '自选股缓存已过期');
|
||||
return null;
|
||||
}
|
||||
|
||||
logger.debug('stockSlice', '自选股 localStorage 缓存命中', {
|
||||
count: data?.length || 0,
|
||||
age: Math.round((now - timestamp) / 1000 / 60) + '分钟前'
|
||||
});
|
||||
return data;
|
||||
} catch (error) {
|
||||
logger.error('stockSlice', 'loadWatchlistFromCache', error);
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* 保存自选股到 localStorage
|
||||
*/
|
||||
const saveWatchlistToCache = (data) => {
|
||||
try {
|
||||
localStorage.setItem(WATCHLIST_CACHE_KEY, JSON.stringify({
|
||||
data,
|
||||
timestamp: Date.now()
|
||||
}));
|
||||
logger.debug('stockSlice', '自选股已缓存到 localStorage', {
|
||||
count: data?.length || 0
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('stockSlice', 'saveWatchlistToCache', error);
|
||||
}
|
||||
};
|
||||
|
||||
// ==================== Async Thunks ====================
|
||||
|
||||
/**
|
||||
@@ -153,13 +203,28 @@ export const fetchExpectationScore = createAsyncThunk(
|
||||
|
||||
/**
|
||||
* 加载用户自选股列表(包含完整信息)
|
||||
* 缓存策略:Redux 内存缓存 → localStorage 持久缓存(7天) → API 请求
|
||||
*/
|
||||
export const loadWatchlist = createAsyncThunk(
|
||||
'stock/loadWatchlist',
|
||||
async () => {
|
||||
async (_, { getState }) => {
|
||||
logger.debug('stockSlice', 'loadWatchlist');
|
||||
|
||||
try {
|
||||
// 1. 先检查 Redux 内存缓存
|
||||
const reduxCached = getState().stock.watchlist;
|
||||
if (reduxCached && reduxCached.length > 0) {
|
||||
logger.debug('stockSlice', 'Redux watchlist 缓存命中', { count: reduxCached.length });
|
||||
return reduxCached;
|
||||
}
|
||||
|
||||
// 2. 再检查 localStorage 持久缓存(7天有效期)
|
||||
const localCached = loadWatchlistFromCache();
|
||||
if (localCached && localCached.length > 0) {
|
||||
return localCached;
|
||||
}
|
||||
|
||||
// 3. 缓存无效,调用 API
|
||||
const apiBase = getApiBase();
|
||||
const response = await fetch(`${apiBase}/api/account/watchlist`, {
|
||||
credentials: 'include'
|
||||
@@ -172,6 +237,10 @@ export const loadWatchlist = createAsyncThunk(
|
||||
stock_code: item.stock_code,
|
||||
stock_name: item.stock_name,
|
||||
}));
|
||||
|
||||
// 保存到 localStorage 缓存
|
||||
saveWatchlistToCache(watchlistData);
|
||||
|
||||
logger.debug('stockSlice', '自选股列表加载成功', {
|
||||
count: watchlistData.length
|
||||
});
|
||||
@@ -340,6 +409,26 @@ const stockSlice = createSlice({
|
||||
delete state.historicalEventsCache[eventId];
|
||||
delete state.chainAnalysisCache[eventId];
|
||||
delete state.expectationScores[eventId];
|
||||
},
|
||||
|
||||
/**
|
||||
* 乐观更新:添加自选股(同步)
|
||||
*/
|
||||
optimisticAddWatchlist: (state, action) => {
|
||||
const { stockCode, stockName } = action.payload;
|
||||
// 避免重复添加
|
||||
const exists = state.watchlist.some(item => item.stock_code === stockCode);
|
||||
if (!exists) {
|
||||
state.watchlist.push({ stock_code: stockCode, stock_name: stockName || '' });
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* 乐观更新:移除自选股(同步)
|
||||
*/
|
||||
optimisticRemoveWatchlist: (state, action) => {
|
||||
const { stockCode } = action.payload;
|
||||
state.watchlist = state.watchlist.filter(item => item.stock_code !== stockCode);
|
||||
}
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
@@ -440,9 +529,10 @@ const stockSlice = createSlice({
|
||||
state.loading.allStocks = false;
|
||||
})
|
||||
|
||||
// ===== toggleWatchlist =====
|
||||
.addCase(toggleWatchlist.fulfilled, (state, action) => {
|
||||
const { stockCode, stockName, isInWatchlist } = action.payload;
|
||||
// ===== toggleWatchlist(乐观更新)=====
|
||||
// pending: 立即更新状态
|
||||
.addCase(toggleWatchlist.pending, (state, action) => {
|
||||
const { stockCode, stockName, isInWatchlist } = action.meta.arg;
|
||||
if (isInWatchlist) {
|
||||
// 移除
|
||||
state.watchlist = state.watchlist.filter(item => item.stock_code !== stockCode);
|
||||
@@ -453,6 +543,26 @@ const stockSlice = createSlice({
|
||||
state.watchlist.push({ stock_code: stockCode, stock_name: stockName });
|
||||
}
|
||||
}
|
||||
})
|
||||
// rejected: 回滚状态
|
||||
.addCase(toggleWatchlist.rejected, (state, action) => {
|
||||
const { stockCode, stockName, isInWatchlist } = action.meta.arg;
|
||||
// 回滚:与 pending 操作相反
|
||||
if (isInWatchlist) {
|
||||
// 之前移除了,现在加回来
|
||||
const exists = state.watchlist.some(item => item.stock_code === stockCode);
|
||||
if (!exists) {
|
||||
state.watchlist.push({ stock_code: stockCode, stock_name: stockName });
|
||||
}
|
||||
} else {
|
||||
// 之前添加了,现在移除
|
||||
state.watchlist = state.watchlist.filter(item => item.stock_code !== stockCode);
|
||||
}
|
||||
})
|
||||
// fulfilled: 同步更新 localStorage 缓存
|
||||
.addCase(toggleWatchlist.fulfilled, (state) => {
|
||||
// 状态已在 pending 时更新,这里同步到 localStorage
|
||||
saveWatchlistToCache(state.watchlist);
|
||||
});
|
||||
}
|
||||
});
|
||||
@@ -461,7 +571,9 @@ export const {
|
||||
updateQuote,
|
||||
updateQuotes,
|
||||
clearQuotes,
|
||||
clearEventCache
|
||||
clearEventCache,
|
||||
optimisticAddWatchlist,
|
||||
optimisticRemoveWatchlist
|
||||
} = stockSlice.actions;
|
||||
|
||||
export default stockSlice.reducer;
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
// 性能监控工具 - 统计白屏时间和性能指标
|
||||
|
||||
import { logger } from './logger';
|
||||
import { reportPerformanceMetrics } from '../lib/posthog';
|
||||
|
||||
/**
|
||||
* 性能指标接口
|
||||
@@ -208,6 +209,9 @@ class PerformanceMonitor {
|
||||
// 性能分析建议
|
||||
this.analyzePerformance();
|
||||
|
||||
// 上报性能指标到 PostHog
|
||||
reportPerformanceMetrics(this.metrics);
|
||||
|
||||
return this.metrics;
|
||||
}
|
||||
|
||||
|
||||
@@ -103,3 +103,71 @@ export const PriceArrow = ({ value }) => {
|
||||
|
||||
return <Icon color={color} boxSize="16px" />;
|
||||
};
|
||||
|
||||
// ==================== 货币/数值格式化 ====================
|
||||
|
||||
/**
|
||||
* 格式化货币金额(自动选择单位:亿元/万元/元)
|
||||
* @param {number|null|undefined} value - 金额(单位:元)
|
||||
* @returns {string} 格式化后的金额字符串
|
||||
*/
|
||||
export const formatCurrency = (value) => {
|
||||
if (value === null || value === undefined) return '-';
|
||||
const absValue = Math.abs(value);
|
||||
if (absValue >= 100000000) {
|
||||
return (value / 100000000).toFixed(2) + '亿元';
|
||||
} else if (absValue >= 10000) {
|
||||
return (value / 10000).toFixed(2) + '万元';
|
||||
}
|
||||
return value.toFixed(2) + '元';
|
||||
};
|
||||
|
||||
/**
|
||||
* 格式化业务营收(支持指定单位)
|
||||
* @param {number|null|undefined} value - 营收金额
|
||||
* @param {string} [unit] - 原始单位(元/万元/亿元)
|
||||
* @returns {string} 格式化后的营收字符串
|
||||
*/
|
||||
export const formatBusinessRevenue = (value, unit) => {
|
||||
if (value === null || value === undefined) return '-';
|
||||
if (unit) {
|
||||
if (unit === '元') {
|
||||
const absValue = Math.abs(value);
|
||||
if (absValue >= 100000000) {
|
||||
return (value / 100000000).toFixed(2) + '亿元';
|
||||
} else if (absValue >= 10000) {
|
||||
return (value / 10000).toFixed(2) + '万元';
|
||||
}
|
||||
return value.toFixed(0) + '元';
|
||||
} else if (unit === '万元') {
|
||||
const absValue = Math.abs(value);
|
||||
if (absValue >= 10000) {
|
||||
return (value / 10000).toFixed(2) + '亿元';
|
||||
}
|
||||
return value.toFixed(2) + '万元';
|
||||
} else if (unit === '亿元') {
|
||||
return value.toFixed(2) + '亿元';
|
||||
} else {
|
||||
return value.toFixed(2) + unit;
|
||||
}
|
||||
}
|
||||
// 无单位时,假设为元
|
||||
const absValue = Math.abs(value);
|
||||
if (absValue >= 100000000) {
|
||||
return (value / 100000000).toFixed(2) + '亿元';
|
||||
} else if (absValue >= 10000) {
|
||||
return (value / 10000).toFixed(2) + '万元';
|
||||
}
|
||||
return value.toFixed(2) + '元';
|
||||
};
|
||||
|
||||
/**
|
||||
* 格式化百分比
|
||||
* @param {number|null|undefined} value - 百分比值
|
||||
* @param {number} [decimals=2] - 小数位数
|
||||
* @returns {string} 格式化后的百分比字符串
|
||||
*/
|
||||
export const formatPercentage = (value, decimals = 2) => {
|
||||
if (value === null || value === undefined) return '-';
|
||||
return value.toFixed(decimals) + '%';
|
||||
};
|
||||
|
||||
184
src/views/Community/README.md
Normal file
184
src/views/Community/README.md
Normal file
@@ -0,0 +1,184 @@
|
||||
# Community 模块说明
|
||||
|
||||
本目录包含社区页面的所有组件,按功能模块进行组织。
|
||||
|
||||
## 目录结构
|
||||
|
||||
```
|
||||
src/views/Community/
|
||||
├── index.js # 页面入口
|
||||
├── components/ # 组件目录
|
||||
│ ├── SearchFilters/ # 搜索筛选模块
|
||||
│ │ ├── CompactSearchBox.js
|
||||
│ │ ├── CompactSearchBox.css
|
||||
│ │ ├── TradingTimeFilter.js
|
||||
│ │ └── index.js
|
||||
│ ├── EventCard/ # 事件卡片模块
|
||||
│ │ ├── atoms/ # 原子组件
|
||||
│ │ └── index.js
|
||||
│ ├── HotEvents/ # 热点事件模块
|
||||
│ │ ├── HotEvents.js
|
||||
│ │ ├── HotEvents.css
|
||||
│ │ ├── HotEventsSection.js
|
||||
│ │ └── index.js
|
||||
│ ├── DynamicNews/ # 动态新闻模块
|
||||
│ │ ├── layouts/
|
||||
│ │ ├── hooks/
|
||||
│ │ └── index.js
|
||||
│ ├── EventDetailModal/ # 事件详情弹窗模块
|
||||
│ │ ├── EventDetailModal.tsx
|
||||
│ │ ├── EventDetailModal.less
|
||||
│ │ └── index.ts
|
||||
│ └── HeroPanel.js # 英雄面板(独立组件)
|
||||
└── hooks/ # 页面级 Hooks
|
||||
├── useEventData.js
|
||||
├── useEventFilters.js
|
||||
└── useCommunityEvents.js
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 1. SearchFilters 模块(搜索筛选)
|
||||
|
||||
路径:`components/SearchFilters/`
|
||||
|
||||
| 文件 | 行数 | 功能 |
|
||||
|------|------|------|
|
||||
| `CompactSearchBox.js` | 612 | 紧凑搜索框,集成关键词搜索、概念/行业筛选 |
|
||||
| `TradingTimeFilter.js` | 491 | 交易时间筛选器,被 CompactSearchBox 引用 |
|
||||
|
||||
**使用方式**:
|
||||
```javascript
|
||||
import { CompactSearchBox } from './components/SearchFilters';
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 2. EventCard 模块(事件卡片)
|
||||
|
||||
路径:`components/EventCard/`
|
||||
|
||||
### 主卡片组件
|
||||
|
||||
| 文件 | 功能 |
|
||||
|------|------|
|
||||
| `CompactEventCard.js` | 紧凑事件卡片(列表模式) |
|
||||
| `DetailedEventCard.js` | 详细事件卡片(展开模式) |
|
||||
| `DynamicNewsEventCard.js` | 动态新闻事件卡片 |
|
||||
| `HorizontalDynamicNewsEventCard.js` | 水平布局新闻卡片 |
|
||||
|
||||
### 原子组件(atoms/)
|
||||
|
||||
| 文件 | 功能 |
|
||||
|------|------|
|
||||
| `EventHeader.js` | 事件标题头部 |
|
||||
| `EventDescription.js` | 事件描述文本 |
|
||||
| `EventStats.js` | 事件统计数据 |
|
||||
| `EventPriceDisplay.js` | 股价显示 |
|
||||
| `EventTimeline.js` | 事件时间线 |
|
||||
| `EventFollowButton.js` | 关注按钮 |
|
||||
| `EventImportanceBadge.js` | 重要性徽章 |
|
||||
| `ImportanceBadge.js` | 通用重要性徽章 |
|
||||
| `ImportanceStamp.js` | 重要性印章 |
|
||||
| `KeywordsCarousel.js` | 关键词轮播 |
|
||||
|
||||
**使用方式**:
|
||||
```javascript
|
||||
// 使用主卡片组件
|
||||
import EventCard from './components/EventCard';
|
||||
|
||||
// 使用原子组件
|
||||
import { EventHeader, EventTimeline } from './components/EventCard/atoms';
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. HotEvents 模块(热点事件)
|
||||
|
||||
路径:`components/HotEvents/`
|
||||
|
||||
| 文件 | 功能 |
|
||||
|------|------|
|
||||
| `HotEvents.js` | 热点事件列表渲染 |
|
||||
| `HotEventsSection.js` | 热点事件区块容器 |
|
||||
|
||||
**使用方式**:
|
||||
```javascript
|
||||
import { HotEventsSection } from './components/HotEvents';
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. DynamicNews 模块(动态新闻)
|
||||
|
||||
路径:`components/DynamicNews/`
|
||||
|
||||
### 主组件
|
||||
|
||||
| 文件 | 功能 |
|
||||
|------|------|
|
||||
| `DynamicNewsCard.js` | 主列表容器(695行) |
|
||||
| `EventScrollList.js` | 事件滚动列表 |
|
||||
| `EventDetailScrollPanel.js` | 事件详情滚动面板 |
|
||||
| `ModeToggleButtons.js` | 模式切换按钮 |
|
||||
| `PaginationControl.js` | 分页控制器 |
|
||||
| `constants.js` | 常量配置 |
|
||||
|
||||
### 布局组件(layouts/)
|
||||
|
||||
| 文件 | 功能 |
|
||||
|------|------|
|
||||
| `VerticalModeLayout.js` | 垂直布局模式 |
|
||||
| `VirtualizedFourRowGrid.js` | 虚拟滚动四行网格(性能优化) |
|
||||
|
||||
### Hooks(hooks/)
|
||||
|
||||
| 文件 | 功能 |
|
||||
|------|------|
|
||||
| `usePagination.js` | 分页逻辑 Hook |
|
||||
|
||||
**使用方式**:
|
||||
```javascript
|
||||
import { DynamicNewsCard } from './components/DynamicNews';
|
||||
import { usePagination } from './components/DynamicNews/hooks';
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. EventDetailModal 模块(事件详情弹窗)
|
||||
|
||||
路径:`components/EventDetailModal/`
|
||||
|
||||
| 文件 | 功能 |
|
||||
|------|------|
|
||||
| `EventDetailModal.tsx` | 事件详情弹窗(TypeScript) |
|
||||
| `EventDetailModal.less` | 弹窗样式 |
|
||||
|
||||
**使用方式**:
|
||||
```javascript
|
||||
import EventDetailModal from './components/EventDetailModal';
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. 独立组件
|
||||
|
||||
路径:`components/`
|
||||
|
||||
| 文件 | 行数 | 功能 |
|
||||
|------|------|------|
|
||||
| `HeroPanel.js` | 972 | 首页英雄面板(指数K线 + 概念词云) |
|
||||
|
||||
**说明**:
|
||||
- `HeroPanel.js` 使用懒加载,包含 ECharts (~600KB)
|
||||
|
||||
---
|
||||
|
||||
## 更新日志
|
||||
|
||||
- **2024-12-09**: 目录结构重组
|
||||
- 创建 `SearchFilters/` 模块(含 CSS)
|
||||
- 创建 `EventCard/atoms/` 原子组件目录
|
||||
- 创建 `HotEvents/` 模块(含 CSS)
|
||||
- 重组 `DynamicNews/` 模块(含 layouts/ 和 hooks/)
|
||||
- 创建 `EventDetailModal/` 模块
|
||||
@@ -1,4 +1,4 @@
|
||||
// src/views/Community/components/DynamicNewsCard.js
|
||||
// src/views/Community/components/DynamicNews/DynamicNewsCard.js
|
||||
// 横向滚动事件卡片组件(实时要闻·动态追踪)
|
||||
|
||||
import React, { forwardRef, useState, useEffect, useMemo, useCallback, useRef, useImperativeHandle } from 'react';
|
||||
@@ -30,23 +30,23 @@ import {
|
||||
Icon,
|
||||
} from '@chakra-ui/react';
|
||||
import { TimeIcon, BellIcon } from '@chakra-ui/icons';
|
||||
import { useNotification } from '../../../contexts/NotificationContext';
|
||||
import EventScrollList from './DynamicNewsCard/EventScrollList';
|
||||
import ModeToggleButtons from './DynamicNewsCard/ModeToggleButtons';
|
||||
import PaginationControl from './DynamicNewsCard/PaginationControl';
|
||||
import { useNotification } from '@contexts/NotificationContext';
|
||||
import EventScrollList from './EventScrollList';
|
||||
import ModeToggleButtons from './ModeToggleButtons';
|
||||
import PaginationControl from './PaginationControl';
|
||||
import DynamicNewsDetailPanel from '@components/EventDetailPanel';
|
||||
import CompactSearchBox from './CompactSearchBox';
|
||||
import CompactSearchBox from '../SearchFilters/CompactSearchBox';
|
||||
import {
|
||||
fetchDynamicNews,
|
||||
toggleEventFollow,
|
||||
selectEventFollowStatus,
|
||||
selectVerticalEventsWithLoading,
|
||||
selectFourRowEventsWithLoading
|
||||
} from '../../../store/slices/communityDataSlice';
|
||||
import { usePagination } from './DynamicNewsCard/hooks/usePagination';
|
||||
import { PAGINATION_CONFIG, DISPLAY_MODES, REFRESH_DEBOUNCE_DELAY } from './DynamicNewsCard/constants';
|
||||
import { PROFESSIONAL_COLORS } from '../../../constants/professionalTheme';
|
||||
import { debounce } from '../../../utils/debounce';
|
||||
} from '@store/slices/communityDataSlice';
|
||||
import { usePagination } from './hooks/usePagination';
|
||||
import { PAGINATION_CONFIG, DISPLAY_MODES, REFRESH_DEBOUNCE_DELAY } from './constants';
|
||||
import { PROFESSIONAL_COLORS } from '@constants/professionalTheme';
|
||||
import { debounce } from '@utils/debounce';
|
||||
import { useDevice } from '@hooks/useDevice';
|
||||
|
||||
// 🔍 调试:渲染计数器
|
||||
@@ -1,4 +1,4 @@
|
||||
// src/views/Community/components/DynamicNewsCard/EventDetailScrollPanel.js
|
||||
// src/views/Community/components/DynamicNews/EventDetailScrollPanel.js
|
||||
// 事件详情滚动面板组件
|
||||
|
||||
import React from 'react';
|
||||
@@ -1,4 +1,4 @@
|
||||
// src/views/Community/components/DynamicNewsCard/EventScrollList.js
|
||||
// src/views/Community/components/DynamicNews/EventScrollList.js
|
||||
// 横向滚动事件列表组件
|
||||
|
||||
import React, { useRef, useCallback } from 'react';
|
||||
@@ -6,8 +6,8 @@ import {
|
||||
Box,
|
||||
useColorModeValue
|
||||
} from '@chakra-ui/react';
|
||||
import VirtualizedFourRowGrid from './VirtualizedFourRowGrid';
|
||||
import VerticalModeLayout from './VerticalModeLayout';
|
||||
import VirtualizedFourRowGrid from './layouts/VirtualizedFourRowGrid';
|
||||
import VerticalModeLayout from './layouts/VerticalModeLayout';
|
||||
|
||||
/**
|
||||
* 事件列表组件 - 支持纵向和平铺两种展示模式
|
||||
@@ -1,4 +1,4 @@
|
||||
// src/views/Community/components/DynamicNewsCard/ModeToggleButtons.js
|
||||
// src/views/Community/components/DynamicNews/ModeToggleButtons.js
|
||||
// 事件列表模式切换按钮组
|
||||
|
||||
import React from 'react';
|
||||
@@ -1,4 +1,4 @@
|
||||
// src/views/Community/components/DynamicNewsCard/PaginationControl.js
|
||||
// src/views/Community/components/DynamicNews/PaginationControl.js
|
||||
// 分页控制器组件
|
||||
|
||||
import React, { useState } from 'react';
|
||||
@@ -1,4 +1,4 @@
|
||||
// src/views/Community/components/DynamicNewsCard/constants.js
|
||||
// src/views/Community/components/DynamicNews/constants.js
|
||||
// 动态新闻卡片组件 - 常量配置
|
||||
|
||||
// ========== 分页配置常量 ==========
|
||||
@@ -0,0 +1,4 @@
|
||||
// src/views/Community/components/DynamicNews/hooks/index.js
|
||||
// Hooks
|
||||
|
||||
export { usePagination } from './usePagination';
|
||||
@@ -1,9 +1,9 @@
|
||||
// src/views/Community/components/DynamicNewsCard/hooks/usePagination.js
|
||||
// src/views/Community/components/DynamicNews/hooks/usePagination.js
|
||||
// 分页逻辑自定义 Hook
|
||||
|
||||
import { useState, useMemo, useCallback, useRef } from 'react';
|
||||
import { fetchDynamicNews, updatePaginationPage } from '../../../../../store/slices/communityDataSlice';
|
||||
import { logger } from '../../../../../utils/logger';
|
||||
import { fetchDynamicNews, updatePaginationPage } from '@store/slices/communityDataSlice';
|
||||
import { logger } from '@utils/logger';
|
||||
import {
|
||||
PAGINATION_CONFIG,
|
||||
DISPLAY_MODES,
|
||||
21
src/views/Community/components/DynamicNews/index.js
Normal file
21
src/views/Community/components/DynamicNews/index.js
Normal file
@@ -0,0 +1,21 @@
|
||||
// src/views/Community/components/DynamicNews/index.js
|
||||
// 动态新闻模块
|
||||
|
||||
// 主组件
|
||||
export { default as DynamicNewsCard } from './DynamicNewsCard';
|
||||
|
||||
// 布局组件
|
||||
export { default as VerticalModeLayout } from './layouts/VerticalModeLayout';
|
||||
export { default as VirtualizedFourRowGrid } from './layouts/VirtualizedFourRowGrid';
|
||||
|
||||
// 子组件
|
||||
export { default as EventScrollList } from './EventScrollList';
|
||||
export { default as EventDetailScrollPanel } from './EventDetailScrollPanel';
|
||||
export { default as ModeToggleButtons } from './ModeToggleButtons';
|
||||
export { default as PaginationControl } from './PaginationControl';
|
||||
|
||||
// Hooks
|
||||
export { usePagination } from './hooks/usePagination';
|
||||
|
||||
// 常量
|
||||
export * from './constants';
|
||||
@@ -1,4 +1,4 @@
|
||||
// src/views/Community/components/DynamicNewsCard/VerticalModeLayout.js
|
||||
// src/views/Community/components/DynamicNews/layouts/VerticalModeLayout.js
|
||||
// 纵向分栏模式布局组件
|
||||
|
||||
import React, { useState } from 'react';
|
||||
@@ -12,10 +12,10 @@ import {
|
||||
useDisclosure
|
||||
} from '@chakra-ui/react';
|
||||
import { InfoIcon } from '@chakra-ui/icons';
|
||||
import HorizontalDynamicNewsEventCard from '../EventCard/HorizontalDynamicNewsEventCard';
|
||||
import EventDetailScrollPanel from './EventDetailScrollPanel';
|
||||
import EventDetailModal from '../EventDetailModal';
|
||||
import PaginationControl from './PaginationControl';
|
||||
import HorizontalDynamicNewsEventCard from '../../EventCard/HorizontalDynamicNewsEventCard';
|
||||
import EventDetailScrollPanel from '../EventDetailScrollPanel';
|
||||
import EventDetailModal from '../../EventDetailModal';
|
||||
import PaginationControl from '../PaginationControl';
|
||||
|
||||
/**
|
||||
* 纵向分栏模式布局
|
||||
@@ -1,4 +1,4 @@
|
||||
// src/views/Community/components/DynamicNewsCard/VirtualizedFourRowGrid.js
|
||||
// src/views/Community/components/DynamicNews/layouts/VirtualizedFourRowGrid.js
|
||||
// 虚拟化网格组件(支持多列布局 + 纵向滚动 + 无限滚动)
|
||||
|
||||
import React, { useRef, useMemo, useEffect, forwardRef, useImperativeHandle } from 'react';
|
||||
@@ -6,7 +6,7 @@ import { useVirtualizer } from '@tanstack/react-virtual';
|
||||
import { Box, Grid, Spinner, Text, VStack, Center, HStack, IconButton, useBreakpointValue } from '@chakra-ui/react';
|
||||
import { RepeatIcon } from '@chakra-ui/icons';
|
||||
import { useColorModeValue } from '@chakra-ui/react';
|
||||
import DynamicNewsEventCard from '../EventCard/DynamicNewsEventCard';
|
||||
import DynamicNewsEventCard from '../../EventCard/DynamicNewsEventCard';
|
||||
|
||||
/**
|
||||
* 虚拟化网格组件(支持多列布局 + 无限滚动)
|
||||
@@ -0,0 +1,5 @@
|
||||
// src/views/Community/components/DynamicNews/layouts/index.js
|
||||
// 布局组件
|
||||
|
||||
export { default as VerticalModeLayout } from './VerticalModeLayout';
|
||||
export { default as VirtualizedFourRowGrid } from './VirtualizedFourRowGrid';
|
||||
@@ -1,83 +0,0 @@
|
||||
// src/views/Community/components/DynamicNewsCard/PageNavigationButton.js
|
||||
// 翻页导航按钮组件
|
||||
|
||||
import React from 'react';
|
||||
import { IconButton, useColorModeValue } from '@chakra-ui/react';
|
||||
import { ChevronLeftIcon, ChevronRightIcon } from '@chakra-ui/icons';
|
||||
|
||||
/**
|
||||
* 翻页导航按钮组件
|
||||
* @param {Object} props
|
||||
* @param {'prev'|'next'} props.direction - 按钮方向(prev=上一页,next=下一页)
|
||||
* @param {number} props.currentPage - 当前页码
|
||||
* @param {number} props.totalPages - 总页数
|
||||
* @param {Function} props.onPageChange - 翻页回调
|
||||
* @param {string} props.mode - 显示模式(只在carousel/grid模式下显示)
|
||||
*/
|
||||
const PageNavigationButton = ({
|
||||
direction,
|
||||
currentPage,
|
||||
totalPages,
|
||||
onPageChange,
|
||||
mode
|
||||
}) => {
|
||||
// 主题适配
|
||||
const arrowBtnBg = useColorModeValue('rgba(255, 255, 255, 0.9)', 'rgba(0, 0, 0, 0.6)');
|
||||
const arrowBtnHoverBg = useColorModeValue('rgba(255, 255, 255, 1)', 'rgba(0, 0, 0, 0.8)');
|
||||
|
||||
// 根据方向计算配置
|
||||
const isPrev = direction === 'prev';
|
||||
const isNext = direction === 'next';
|
||||
|
||||
const Icon = isPrev ? ChevronLeftIcon : ChevronRightIcon;
|
||||
const position = isPrev ? 'left' : 'right';
|
||||
const label = isPrev ? '上一页' : '下一页';
|
||||
const targetPage = isPrev ? currentPage - 1 : currentPage + 1;
|
||||
const shouldShow = isPrev
|
||||
? currentPage > 1
|
||||
: currentPage < totalPages;
|
||||
const isDisabled = isNext ? currentPage >= totalPages : false;
|
||||
|
||||
// 判断是否显示(只在单排/双排模式显示)
|
||||
const shouldRender = shouldShow && (mode === 'carousel' || mode === 'grid');
|
||||
|
||||
if (!shouldRender) return null;
|
||||
|
||||
const handleClick = () => {
|
||||
console.log(
|
||||
`%c🔵 [翻页] 点击${label}: 当前页${currentPage} → 目标页${targetPage} (共${totalPages}页)`,
|
||||
'color: #3B82F6; font-weight: bold;'
|
||||
);
|
||||
onPageChange(targetPage);
|
||||
};
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
icon={<Icon boxSize={6} color="blue.500" />}
|
||||
position="absolute"
|
||||
{...{ [position]: 0 }}
|
||||
top="50%"
|
||||
transform="translateY(-50%)"
|
||||
zIndex={2}
|
||||
onClick={handleClick}
|
||||
variant="ghost"
|
||||
size="md"
|
||||
w="40px"
|
||||
h="40px"
|
||||
minW="40px"
|
||||
borderRadius="full"
|
||||
bg={arrowBtnBg}
|
||||
boxShadow="0 2px 8px rgba(0, 0, 0, 0.15)"
|
||||
_hover={{
|
||||
bg: arrowBtnHoverBg,
|
||||
boxShadow: '0 4px 12px rgba(0, 0, 0, 0.2)',
|
||||
transform: 'translateY(-50%) scale(1.05)'
|
||||
}}
|
||||
isDisabled={isDisabled}
|
||||
aria-label={label}
|
||||
title={label}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default PageNavigationButton;
|
||||
@@ -1,88 +0,0 @@
|
||||
// src/views/Community/components/DynamicNewsCard/hooks/useInfiniteScroll.js
|
||||
// 无限滚动 Hook
|
||||
|
||||
import { useEffect, useRef, useCallback } from 'react';
|
||||
|
||||
/**
|
||||
* 无限滚动 Hook
|
||||
* 监听容器滚动事件,当滚动到底部附近时触发加载更多数据
|
||||
*
|
||||
* @param {Object} options - 配置选项
|
||||
* @param {Function} options.onLoadMore - 加载更多回调函数(返回 Promise)
|
||||
* @param {boolean} options.hasMore - 是否还有更多数据
|
||||
* @param {boolean} options.isLoading - 是否正在加载
|
||||
* @param {number} options.threshold - 触发阈值(距离底部多少像素时触发,默认200px)
|
||||
* @returns {Object} { containerRef } - 容器引用
|
||||
*/
|
||||
export const useInfiniteScroll = ({
|
||||
onLoadMore,
|
||||
hasMore = true,
|
||||
isLoading = false,
|
||||
threshold = 200
|
||||
}) => {
|
||||
const containerRef = useRef(null);
|
||||
const isLoadingRef = useRef(false);
|
||||
|
||||
// 滚动处理函数
|
||||
const handleScroll = useCallback(() => {
|
||||
const container = containerRef.current;
|
||||
|
||||
// 检查条件:容器存在、未加载中、还有更多数据
|
||||
if (!container || isLoadingRef.current || !hasMore) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { scrollTop, scrollHeight, clientHeight } = container;
|
||||
const distanceToBottom = scrollHeight - scrollTop - clientHeight;
|
||||
|
||||
// 距离底部小于阈值时触发加载
|
||||
if (distanceToBottom < threshold) {
|
||||
console.log(
|
||||
'%c⬇️ [懒加载] 触发加载下一页',
|
||||
'color: #8B5CF6; font-weight: bold;',
|
||||
{
|
||||
scrollTop,
|
||||
scrollHeight,
|
||||
clientHeight,
|
||||
distanceToBottom,
|
||||
threshold
|
||||
}
|
||||
);
|
||||
|
||||
isLoadingRef.current = true;
|
||||
|
||||
// 调用加载函数并更新状态
|
||||
onLoadMore()
|
||||
.then(() => {
|
||||
console.log('%c✅ [懒加载] 加载完成', 'color: #10B981; font-weight: bold;');
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('%c❌ [懒加载] 加载失败', 'color: #DC2626; font-weight: bold;', error);
|
||||
})
|
||||
.finally(() => {
|
||||
isLoadingRef.current = false;
|
||||
});
|
||||
}
|
||||
}, [onLoadMore, hasMore, threshold]);
|
||||
|
||||
// 绑定滚动事件
|
||||
useEffect(() => {
|
||||
const container = containerRef.current;
|
||||
if (!container) return;
|
||||
|
||||
// 添加滚动监听
|
||||
container.addEventListener('scroll', handleScroll, { passive: true });
|
||||
|
||||
// 清理函数
|
||||
return () => {
|
||||
container.removeEventListener('scroll', handleScroll);
|
||||
};
|
||||
}, [handleScroll]);
|
||||
|
||||
// 更新 loading 状态的 ref
|
||||
useEffect(() => {
|
||||
isLoadingRef.current = isLoading;
|
||||
}, [isLoading]);
|
||||
|
||||
return { containerRef };
|
||||
};
|
||||
@@ -12,13 +12,15 @@ import {
|
||||
useColorModeValue,
|
||||
} from '@chakra-ui/react';
|
||||
import dayjs from 'dayjs';
|
||||
import { getImportanceConfig } from '../../../../constants/importanceLevels';
|
||||
import { getImportanceConfig } from '@constants/importanceLevels';
|
||||
|
||||
// 导入子组件
|
||||
import EventTimeline from './EventTimeline';
|
||||
import EventHeader from './EventHeader';
|
||||
import EventStats from './EventStats';
|
||||
import EventFollowButton from './EventFollowButton';
|
||||
import {
|
||||
EventTimeline,
|
||||
EventHeader,
|
||||
EventStats,
|
||||
EventFollowButton,
|
||||
} from './atoms';
|
||||
|
||||
/**
|
||||
* 紧凑模式事件卡片组件
|
||||
|
||||
@@ -10,15 +10,17 @@ import {
|
||||
useColorModeValue,
|
||||
} from '@chakra-ui/react';
|
||||
import dayjs from 'dayjs';
|
||||
import { getImportanceConfig } from '../../../../constants/importanceLevels';
|
||||
import { getImportanceConfig } from '@constants/importanceLevels';
|
||||
|
||||
// 导入子组件
|
||||
import EventTimeline from './EventTimeline';
|
||||
import EventHeader from './EventHeader';
|
||||
import EventStats from './EventStats';
|
||||
import EventFollowButton from './EventFollowButton';
|
||||
import EventPriceDisplay from './EventPriceDisplay';
|
||||
import EventDescription from './EventDescription';
|
||||
import {
|
||||
EventTimeline,
|
||||
EventHeader,
|
||||
EventStats,
|
||||
EventFollowButton,
|
||||
EventPriceDisplay,
|
||||
EventDescription,
|
||||
} from './atoms';
|
||||
|
||||
/**
|
||||
* 详细模式事件卡片组件
|
||||
|
||||
@@ -12,12 +12,12 @@ import {
|
||||
useColorModeValue,
|
||||
} from '@chakra-ui/react';
|
||||
import dayjs from 'dayjs';
|
||||
import { getImportanceConfig } from '../../../../constants/importanceLevels';
|
||||
import { getChangeColor } from '../../../../utils/colorUtils';
|
||||
import { getImportanceConfig } from '@constants/importanceLevels';
|
||||
import { getChangeColor } from '@utils/colorUtils';
|
||||
|
||||
// 导入子组件
|
||||
import EventFollowButton from './EventFollowButton';
|
||||
import StockChangeIndicators from '../../../../components/StockChangeIndicators';
|
||||
import { EventFollowButton } from './atoms';
|
||||
import StockChangeIndicators from '@components/StockChangeIndicators';
|
||||
|
||||
/**
|
||||
* 动态新闻事件卡片组件(极简版)
|
||||
|
||||
@@ -13,17 +13,19 @@ import {
|
||||
useColorModeValue,
|
||||
useBreakpointValue,
|
||||
} from '@chakra-ui/react';
|
||||
import { getImportanceConfig } from '../../../../constants/importanceLevels';
|
||||
import { PROFESSIONAL_COLORS } from '../../../../constants/professionalTheme';
|
||||
import { getImportanceConfig } from '@constants/importanceLevels';
|
||||
import { PROFESSIONAL_COLORS } from '@constants/professionalTheme';
|
||||
import { useDevice } from '@hooks/useDevice';
|
||||
import dayjs from 'dayjs';
|
||||
|
||||
// 导入子组件
|
||||
import ImportanceStamp from './ImportanceStamp';
|
||||
import EventTimeline from './EventTimeline';
|
||||
import EventFollowButton from './EventFollowButton';
|
||||
import StockChangeIndicators from '../../../../components/StockChangeIndicators';
|
||||
import KeywordsCarousel from './KeywordsCarousel';
|
||||
import {
|
||||
ImportanceStamp,
|
||||
EventTimeline,
|
||||
EventFollowButton,
|
||||
KeywordsCarousel,
|
||||
} from './atoms';
|
||||
import StockChangeIndicators from '@components/StockChangeIndicators';
|
||||
|
||||
/**
|
||||
* 横向布局的动态新闻事件卡片组件
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// src/views/Community/components/EventCard/EventFollowButton.js
|
||||
// src/views/Community/components/EventCard/atoms/EventFollowButton.js
|
||||
import React from 'react';
|
||||
import { IconButton, Box } from '@chakra-ui/react';
|
||||
import { AiFillStar, AiOutlineStar } from 'react-icons/ai';
|
||||
@@ -1,8 +1,8 @@
|
||||
// src/views/Community/components/EventCard/EventImportanceBadge.js
|
||||
// src/views/Community/components/EventCard/atoms/EventImportanceBadge.js
|
||||
import React from 'react';
|
||||
import { Badge, Tooltip, VStack, HStack, Text, Divider, Circle } from '@chakra-ui/react';
|
||||
import { InfoIcon } from '@chakra-ui/icons';
|
||||
import { getImportanceConfig, getAllImportanceLevels } from '../../../../constants/importanceLevels';
|
||||
import { getImportanceConfig, getAllImportanceLevels } from '@constants/importanceLevels';
|
||||
|
||||
/**
|
||||
* 事件重要性等级标签组件
|
||||
@@ -1,7 +1,7 @@
|
||||
// src/views/Community/components/EventCard/EventPriceDisplay.js
|
||||
// src/views/Community/components/EventCard/atoms/EventPriceDisplay.js
|
||||
import React, { useState } from 'react';
|
||||
import { HStack, Box, Text, Tooltip, Progress } from '@chakra-ui/react';
|
||||
import { PriceArrow } from '../../../../utils/priceFormatters';
|
||||
import { PriceArrow } from '@utils/priceFormatters';
|
||||
|
||||
/**
|
||||
* 事件价格变动显示组件
|
||||
@@ -1,4 +1,4 @@
|
||||
// src/views/Community/components/EventCard/ImportanceBadge.js
|
||||
// src/views/Community/components/EventCard/atoms/ImportanceBadge.js
|
||||
// 重要性标签通用组件
|
||||
|
||||
import React from 'react';
|
||||
@@ -14,7 +14,7 @@ import {
|
||||
PopoverArrow,
|
||||
Portal,
|
||||
} from '@chakra-ui/react';
|
||||
import { getImportanceConfig, getAllImportanceLevels } from '../../../../constants/importanceLevels';
|
||||
import { getImportanceConfig, getAllImportanceLevels } from '@constants/importanceLevels';
|
||||
|
||||
/**
|
||||
* 重要性标签组件(实心样式)
|
||||
@@ -1,4 +1,4 @@
|
||||
// src/views/Community/components/EventCard/ImportanceStamp.js
|
||||
// src/views/Community/components/EventCard/atoms/ImportanceStamp.js
|
||||
// 重要性印章组件
|
||||
|
||||
import React from 'react';
|
||||
@@ -7,7 +7,7 @@ import {
|
||||
Text,
|
||||
useColorModeValue,
|
||||
} from '@chakra-ui/react';
|
||||
import { getImportanceConfig } from '../../../../constants/importanceLevels';
|
||||
import { getImportanceConfig } from '@constants/importanceLevels';
|
||||
|
||||
/**
|
||||
* 重要性印章组件(模拟盖章效果)
|
||||
@@ -1,8 +1,8 @@
|
||||
// src/views/Community/components/EventCard/KeywordsCarousel.js
|
||||
// src/views/Community/components/EventCard/atoms/KeywordsCarousel.js
|
||||
// Keywords标签组件(点击切换)
|
||||
import React, { useState } from 'react';
|
||||
import { Box, Text, Tooltip } from '@chakra-ui/react';
|
||||
import { PROFESSIONAL_COLORS } from '../../../../constants/professionalTheme';
|
||||
import { PROFESSIONAL_COLORS } from '@constants/professionalTheme';
|
||||
|
||||
/**
|
||||
* Keywords标签组件(点击切换下一个)
|
||||
13
src/views/Community/components/EventCard/atoms/index.js
Normal file
13
src/views/Community/components/EventCard/atoms/index.js
Normal file
@@ -0,0 +1,13 @@
|
||||
// src/views/Community/components/EventCard/atoms/index.js
|
||||
// 事件卡片原子组件
|
||||
|
||||
export { default as EventDescription } from './EventDescription';
|
||||
export { default as EventFollowButton } from './EventFollowButton';
|
||||
export { default as EventHeader } from './EventHeader';
|
||||
export { default as EventImportanceBadge } from './EventImportanceBadge';
|
||||
export { default as EventPriceDisplay } from './EventPriceDisplay';
|
||||
export { default as EventStats } from './EventStats';
|
||||
export { default as EventTimeline } from './EventTimeline';
|
||||
export { default as ImportanceBadge } from './ImportanceBadge';
|
||||
export { default as ImportanceStamp } from './ImportanceStamp';
|
||||
export { default as KeywordsCarousel } from './KeywordsCarousel';
|
||||
4
src/views/Community/components/EventDetailModal/index.ts
Normal file
4
src/views/Community/components/EventDetailModal/index.ts
Normal file
@@ -0,0 +1,4 @@
|
||||
// src/views/Community/components/EventDetailModal/index.ts
|
||||
// 事件详情弹窗模块
|
||||
|
||||
export { default } from './EventDetailModal';
|
||||
@@ -1,614 +0,0 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import {
|
||||
Modal,
|
||||
ModalOverlay,
|
||||
ModalContent,
|
||||
ModalHeader,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
Box,
|
||||
Text,
|
||||
VStack,
|
||||
HStack,
|
||||
Avatar,
|
||||
Textarea,
|
||||
Button,
|
||||
Divider,
|
||||
useToast,
|
||||
Badge,
|
||||
Flex,
|
||||
IconButton,
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuList,
|
||||
MenuItem,
|
||||
useColorModeValue,
|
||||
Spinner,
|
||||
Center,
|
||||
Collapse,
|
||||
Input,
|
||||
} from '@chakra-ui/react';
|
||||
import {
|
||||
ChatIcon,
|
||||
TimeIcon,
|
||||
DeleteIcon,
|
||||
EditIcon,
|
||||
ChevronDownIcon,
|
||||
TriangleDownIcon,
|
||||
TriangleUpIcon,
|
||||
} from '@chakra-ui/icons';
|
||||
import { FaHeart, FaRegHeart, FaComment } from 'react-icons/fa';
|
||||
import { format } from 'date-fns';
|
||||
import { zhCN } from 'date-fns/locale';
|
||||
import { eventService } from '../../../services/eventService';
|
||||
import { logger } from '../../../utils/logger';
|
||||
|
||||
const EventDiscussionModal = ({ isOpen, onClose, eventId, eventTitle, discussionType = '事件讨论' }) => {
|
||||
const [posts, setPosts] = useState([]);
|
||||
const [newPostContent, setNewPostContent] = useState('');
|
||||
const [newPostTitle, setNewPostTitle] = useState('');
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [submitting, setSubmitting] = useState(false);
|
||||
const [expandedPosts, setExpandedPosts] = useState({});
|
||||
const [postComments, setPostComments] = useState({});
|
||||
const [replyContents, setReplyContents] = useState({});
|
||||
const [loadingComments, setLoadingComments] = useState({});
|
||||
|
||||
const toast = useToast();
|
||||
const bgColor = useColorModeValue('white', 'gray.800');
|
||||
const borderColor = useColorModeValue('gray.200', 'gray.600');
|
||||
const hoverBg = useColorModeValue('gray.50', 'gray.700');
|
||||
|
||||
// 加载帖子列表
|
||||
const loadPosts = async () => {
|
||||
if (!eventId) return;
|
||||
|
||||
setLoading(true);
|
||||
try {
|
||||
const response = await fetch(`/api/events/${eventId}/posts?sort=latest&page=1&per_page=20`, {
|
||||
method: 'GET',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
credentials: 'include'
|
||||
});
|
||||
const result = await response.json();
|
||||
|
||||
if (response.ok && result.success) {
|
||||
setPosts(result.data || []);
|
||||
logger.debug('EventDiscussionModal', '帖子列表加载成功', {
|
||||
eventId,
|
||||
postsCount: result.data?.length || 0
|
||||
});
|
||||
} else {
|
||||
logger.error('EventDiscussionModal', 'loadPosts', new Error('API返回错误'), {
|
||||
eventId,
|
||||
status: response.status,
|
||||
message: result.message
|
||||
});
|
||||
toast({
|
||||
title: '加载帖子失败',
|
||||
status: 'error',
|
||||
duration: 3000,
|
||||
isClosable: true,
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('EventDiscussionModal', 'loadPosts', error, { eventId });
|
||||
toast({
|
||||
title: '加载帖子失败',
|
||||
status: 'error',
|
||||
duration: 3000,
|
||||
isClosable: true,
|
||||
});
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
// 加载帖子的评论
|
||||
const loadPostComments = async (postId) => {
|
||||
setLoadingComments(prev => ({ ...prev, [postId]: true }));
|
||||
try {
|
||||
const response = await fetch(`/api/posts/${postId}/comments?sort=latest`, {
|
||||
method: 'GET',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
credentials: 'include'
|
||||
});
|
||||
const result = await response.json();
|
||||
|
||||
if (response.ok && result.success) {
|
||||
setPostComments(prev => ({ ...prev, [postId]: result.data || [] }));
|
||||
logger.debug('EventDiscussionModal', '评论加载成功', {
|
||||
postId,
|
||||
commentsCount: result.data?.length || 0
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('EventDiscussionModal', 'loadPostComments', error, { postId });
|
||||
} finally {
|
||||
setLoadingComments(prev => ({ ...prev, [postId]: false }));
|
||||
}
|
||||
};
|
||||
|
||||
// 切换展开/收起评论
|
||||
const togglePostComments = async (postId) => {
|
||||
const isExpanded = expandedPosts[postId];
|
||||
if (!isExpanded) {
|
||||
// 展开时加载评论
|
||||
await loadPostComments(postId);
|
||||
}
|
||||
setExpandedPosts(prev => ({ ...prev, [postId]: !isExpanded }));
|
||||
};
|
||||
|
||||
// 提交新帖子
|
||||
const handleSubmitPost = async () => {
|
||||
if (!newPostContent.trim()) return;
|
||||
|
||||
setSubmitting(true);
|
||||
try {
|
||||
const response = await fetch(`/api/events/${eventId}/posts`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
credentials: 'include',
|
||||
body: JSON.stringify({
|
||||
title: newPostTitle.trim(),
|
||||
content: newPostContent.trim(),
|
||||
content_type: 'text',
|
||||
})
|
||||
});
|
||||
const result = await response.json();
|
||||
|
||||
if (response.ok && result.success) {
|
||||
setNewPostContent('');
|
||||
setNewPostTitle('');
|
||||
loadPosts();
|
||||
logger.info('EventDiscussionModal', '帖子发布成功', {
|
||||
eventId,
|
||||
postId: result.data?.id
|
||||
});
|
||||
toast({
|
||||
title: '帖子发布成功',
|
||||
status: 'success',
|
||||
duration: 2000,
|
||||
isClosable: true,
|
||||
});
|
||||
} else {
|
||||
logger.error('EventDiscussionModal', 'handleSubmitPost', new Error('API返回错误'), {
|
||||
eventId,
|
||||
message: result.message
|
||||
});
|
||||
toast({
|
||||
title: result.message || '帖子发布失败',
|
||||
status: 'error',
|
||||
duration: 3000,
|
||||
isClosable: true,
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('EventDiscussionModal', 'handleSubmitPost', error, { eventId });
|
||||
toast({
|
||||
title: '帖子发布失败',
|
||||
status: 'error',
|
||||
duration: 3000,
|
||||
isClosable: true,
|
||||
});
|
||||
} finally {
|
||||
setSubmitting(false);
|
||||
}
|
||||
};
|
||||
|
||||
// 删除帖子
|
||||
const handleDeletePost = async (postId) => {
|
||||
if (!window.confirm('确定要删除这个帖子吗?')) return;
|
||||
|
||||
try {
|
||||
const response = await fetch(`/api/posts/${postId}`, {
|
||||
method: 'DELETE',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
credentials: 'include'
|
||||
});
|
||||
const result = await response.json();
|
||||
|
||||
if (response.ok && result.success) {
|
||||
loadPosts();
|
||||
logger.info('EventDiscussionModal', '帖子删除成功', { postId });
|
||||
toast({
|
||||
title: '帖子已删除',
|
||||
status: 'success',
|
||||
duration: 2000,
|
||||
isClosable: true,
|
||||
});
|
||||
} else {
|
||||
logger.error('EventDiscussionModal', 'handleDeletePost', new Error('API返回错误'), {
|
||||
postId,
|
||||
message: result.message
|
||||
});
|
||||
toast({
|
||||
title: result.message || '删除失败',
|
||||
status: 'error',
|
||||
duration: 3000,
|
||||
isClosable: true,
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('EventDiscussionModal', 'handleDeletePost', error, { postId });
|
||||
toast({
|
||||
title: '删除失败',
|
||||
status: 'error',
|
||||
duration: 3000,
|
||||
isClosable: true,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// 点赞帖子
|
||||
const handleLikePost = async (postId) => {
|
||||
try {
|
||||
const response = await fetch(`/api/posts/${postId}/like`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
credentials: 'include'
|
||||
});
|
||||
const result = await response.json();
|
||||
|
||||
if (response.ok && result.success) {
|
||||
// 更新帖子列表中的点赞状态
|
||||
setPosts(prev => prev.map(post =>
|
||||
post.id === postId
|
||||
? { ...post, likes_count: result.likes_count, liked: result.liked }
|
||||
: post
|
||||
));
|
||||
logger.debug('EventDiscussionModal', '点赞操作成功', {
|
||||
postId,
|
||||
liked: result.liked,
|
||||
likesCount: result.likes_count
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('EventDiscussionModal', 'handleLikePost', error, { postId });
|
||||
toast({
|
||||
title: '操作失败',
|
||||
status: 'error',
|
||||
duration: 2000,
|
||||
isClosable: true,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// 提交评论
|
||||
const handleSubmitComment = async (postId) => {
|
||||
const content = replyContents[postId];
|
||||
if (!content?.trim()) return;
|
||||
|
||||
try {
|
||||
const response = await fetch(`/api/posts/${postId}/comments`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
credentials: 'include',
|
||||
body: JSON.stringify({
|
||||
content: content.trim(),
|
||||
})
|
||||
});
|
||||
const result = await response.json();
|
||||
|
||||
if (response.ok && result.success) {
|
||||
setReplyContents(prev => ({ ...prev, [postId]: '' }));
|
||||
// 重新加载该帖子的评论
|
||||
await loadPostComments(postId);
|
||||
// 更新帖子的评论数
|
||||
setPosts(prev => prev.map(post =>
|
||||
post.id === postId
|
||||
? { ...post, comments_count: (post.comments_count || 0) + 1 }
|
||||
: post
|
||||
));
|
||||
logger.info('EventDiscussionModal', '评论发布成功', {
|
||||
postId,
|
||||
commentId: result.data?.id
|
||||
});
|
||||
toast({
|
||||
title: '评论发布成功',
|
||||
status: 'success',
|
||||
duration: 2000,
|
||||
isClosable: true,
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('EventDiscussionModal', 'handleSubmitComment', error, { postId });
|
||||
toast({
|
||||
title: '评论发布失败',
|
||||
status: 'error',
|
||||
duration: 3000,
|
||||
isClosable: true,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// 删除评论
|
||||
const handleDeleteComment = async (commentId, postId) => {
|
||||
if (!window.confirm('确定要删除这条评论吗?')) return;
|
||||
|
||||
try {
|
||||
const response = await fetch(`/api/comments/${commentId}`, {
|
||||
method: 'DELETE',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
credentials: 'include'
|
||||
});
|
||||
const result = await response.json();
|
||||
|
||||
if (response.ok && result.success) {
|
||||
// 重新加载该帖子的评论
|
||||
await loadPostComments(postId);
|
||||
// 更新帖子的评论数
|
||||
setPosts(prev => prev.map(post =>
|
||||
post.id === postId
|
||||
? { ...post, comments_count: Math.max(0, (post.comments_count || 0) - 1) }
|
||||
: post
|
||||
));
|
||||
logger.info('EventDiscussionModal', '评论删除成功', { commentId, postId });
|
||||
toast({
|
||||
title: '评论已删除',
|
||||
status: 'success',
|
||||
duration: 2000,
|
||||
isClosable: true,
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('EventDiscussionModal', 'handleDeleteComment', error, { commentId, postId });
|
||||
toast({
|
||||
title: '删除失败',
|
||||
status: 'error',
|
||||
duration: 3000,
|
||||
isClosable: true,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (isOpen) {
|
||||
loadPosts();
|
||||
}
|
||||
}, [isOpen, eventId]);
|
||||
|
||||
return (
|
||||
<Modal isOpen={isOpen} onClose={onClose} size="xl">
|
||||
<ModalOverlay />
|
||||
<ModalContent maxH="80vh">
|
||||
<ModalHeader>
|
||||
<VStack align="start" spacing={1}>
|
||||
<HStack>
|
||||
<ChatIcon />
|
||||
<Text>{discussionType}</Text>
|
||||
</HStack>
|
||||
{eventTitle && (
|
||||
<Text fontSize="sm" color="gray.500" fontWeight="normal">
|
||||
{eventTitle}
|
||||
</Text>
|
||||
)}
|
||||
</VStack>
|
||||
</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
|
||||
<ModalBody overflowY="auto">
|
||||
{/* 发布新帖子 */}
|
||||
<Box mb={4}>
|
||||
<Input
|
||||
value={newPostTitle}
|
||||
onChange={(e) => setNewPostTitle(e.target.value)}
|
||||
placeholder="帖子标题(可选)"
|
||||
size="sm"
|
||||
mb={2}
|
||||
/>
|
||||
<Textarea
|
||||
value={newPostContent}
|
||||
onChange={(e) => setNewPostContent(e.target.value)}
|
||||
placeholder="分享您的观点..."
|
||||
size="sm"
|
||||
resize="vertical"
|
||||
minH="80px"
|
||||
/>
|
||||
<Flex justify="flex-end" mt={2}>
|
||||
<Button
|
||||
colorScheme="blue"
|
||||
size="sm"
|
||||
onClick={handleSubmitPost}
|
||||
isLoading={submitting}
|
||||
isDisabled={!newPostContent.trim()}
|
||||
>
|
||||
发布帖子
|
||||
</Button>
|
||||
</Flex>
|
||||
</Box>
|
||||
|
||||
<Divider mb={4} />
|
||||
|
||||
{/* 帖子列表 */}
|
||||
{loading ? (
|
||||
<Center py={8}>
|
||||
<Spinner size="lg" />
|
||||
</Center>
|
||||
) : posts.length > 0 ? (
|
||||
<VStack spacing={4} align="stretch">
|
||||
{posts.map((post) => (
|
||||
<Box
|
||||
key={post.id}
|
||||
p={4}
|
||||
borderWidth="1px"
|
||||
borderColor={borderColor}
|
||||
borderRadius="md"
|
||||
transition="background 0.2s"
|
||||
>
|
||||
{/* 帖子头部 */}
|
||||
<Flex justify="space-between" align="start" mb={3}>
|
||||
<HStack align="start" spacing={3}>
|
||||
<Avatar
|
||||
size="sm"
|
||||
name={post.user?.username || '匿名用户'}
|
||||
src={post.user?.avatar_url}
|
||||
/>
|
||||
<VStack align="start" spacing={1} flex={1}>
|
||||
<HStack>
|
||||
<Text fontWeight="semibold" fontSize="sm">
|
||||
{post.user?.username || '匿名用户'}
|
||||
</Text>
|
||||
</HStack>
|
||||
<HStack fontSize="xs" color="gray.500">
|
||||
<TimeIcon />
|
||||
<Text>
|
||||
{format(new Date(post.created_at), 'MM月dd日 HH:mm', {
|
||||
locale: zhCN,
|
||||
})}
|
||||
</Text>
|
||||
</HStack>
|
||||
</VStack>
|
||||
</HStack>
|
||||
|
||||
{/* 操作菜单 */}
|
||||
<Menu>
|
||||
<MenuButton
|
||||
as={IconButton}
|
||||
icon={<ChevronDownIcon />}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
/>
|
||||
<MenuList>
|
||||
<MenuItem
|
||||
icon={<DeleteIcon />}
|
||||
color="red.500"
|
||||
onClick={() => handleDeletePost(post.id)}
|
||||
>
|
||||
删除帖子
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
</Menu>
|
||||
</Flex>
|
||||
|
||||
{/* 帖子标题 */}
|
||||
{post.title && (
|
||||
<Text fontSize="md" fontWeight="bold" mb={2}>
|
||||
{post.title}
|
||||
</Text>
|
||||
)}
|
||||
|
||||
{/* 帖子内容 */}
|
||||
<Text fontSize="sm" whiteSpace="pre-wrap" mb={3}>
|
||||
{post.content}
|
||||
</Text>
|
||||
|
||||
{/* 帖子操作栏 */}
|
||||
<HStack spacing={4} mb={3}>
|
||||
<Button
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
leftIcon={post.liked ? <FaHeart /> : <FaRegHeart />}
|
||||
color={post.liked ? 'red.500' : 'gray.600'}
|
||||
onClick={() => handleLikePost(post.id)}
|
||||
>
|
||||
{post.likes_count || 0}
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
leftIcon={<FaComment />}
|
||||
onClick={() => togglePostComments(post.id)}
|
||||
rightIcon={expandedPosts[post.id] ? <TriangleUpIcon /> : <TriangleDownIcon />}
|
||||
>
|
||||
{post.comments_count || 0} 评论
|
||||
</Button>
|
||||
</HStack>
|
||||
|
||||
{/* 评论区 */}
|
||||
<Collapse in={expandedPosts[post.id]} animateOpacity>
|
||||
<Box borderTopWidth="1px" borderColor={borderColor} pt={3}>
|
||||
{/* 评论输入框 */}
|
||||
<HStack mb={3}>
|
||||
<Textarea
|
||||
size="sm"
|
||||
placeholder="写下你的评论..."
|
||||
value={replyContents[post.id] || ''}
|
||||
onChange={(e) => setReplyContents(prev => ({ ...prev, [post.id]: e.target.value }))}
|
||||
rows={2}
|
||||
flex={1}
|
||||
/>
|
||||
<Button
|
||||
size="sm"
|
||||
colorScheme="blue"
|
||||
onClick={() => handleSubmitComment(post.id)}
|
||||
isDisabled={!replyContents[post.id]?.trim()}
|
||||
>
|
||||
评论
|
||||
</Button>
|
||||
</HStack>
|
||||
|
||||
{/* 评论列表 */}
|
||||
{loadingComments[post.id] ? (
|
||||
<Center py={4}>
|
||||
<Spinner size="sm" />
|
||||
</Center>
|
||||
) : (
|
||||
<VStack align="stretch" spacing={2}>
|
||||
{postComments[post.id]?.map((comment) => (
|
||||
<Box key={comment.id} pl={4} borderLeftWidth="2px" borderColor="gray.200">
|
||||
<HStack justify="space-between" mb={1}>
|
||||
<HStack spacing={2}>
|
||||
<Avatar size="xs" name={comment.user?.username} src={comment.user?.avatar_url} />
|
||||
<Text fontSize="sm" fontWeight="medium">
|
||||
{comment.user?.username || '匿名用户'}
|
||||
</Text>
|
||||
<Text fontSize="xs" color="gray.500">
|
||||
{format(new Date(comment.created_at), 'MM-dd HH:mm')}
|
||||
</Text>
|
||||
</HStack>
|
||||
<IconButton
|
||||
size="xs"
|
||||
icon={<DeleteIcon />}
|
||||
variant="ghost"
|
||||
onClick={() => handleDeleteComment(comment.id, post.id)}
|
||||
/>
|
||||
</HStack>
|
||||
<Text fontSize="sm" pl={7}>
|
||||
{comment.content}
|
||||
</Text>
|
||||
|
||||
{/* 显示回复 */}
|
||||
{comment.replies && comment.replies.length > 0 && (
|
||||
<VStack align="stretch" mt={2} spacing={1} pl={4}>
|
||||
{comment.replies.map((reply) => (
|
||||
<Box key={reply.id}>
|
||||
<HStack spacing={1}>
|
||||
<Text fontSize="xs" fontWeight="medium">
|
||||
{reply.user?.username}:
|
||||
</Text>
|
||||
<Text fontSize="xs">{reply.content}</Text>
|
||||
</HStack>
|
||||
</Box>
|
||||
))}
|
||||
</VStack>
|
||||
)}
|
||||
</Box>
|
||||
))}
|
||||
{(!postComments[post.id] || postComments[post.id].length === 0) && (
|
||||
<Text fontSize="sm" color="gray.500" textAlign="center" py={2}>
|
||||
暂无评论
|
||||
</Text>
|
||||
)}
|
||||
</VStack>
|
||||
)}
|
||||
</Box>
|
||||
</Collapse>
|
||||
</Box>
|
||||
))}
|
||||
</VStack>
|
||||
) : (
|
||||
<Center py={8}>
|
||||
<VStack>
|
||||
<ChatIcon boxSize={8} color="gray.400" />
|
||||
<Text color="gray.500">暂无帖子,快来发表您的观点吧!</Text>
|
||||
</VStack>
|
||||
</Center>
|
||||
)}
|
||||
</ModalBody>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
||||
export default EventDiscussionModal;
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user