1033 lines
33 KiB
Python
1033 lines
33 KiB
Python
"""
|
||
MySQL数据库查询模块
|
||
提供股票财务数据查询功能
|
||
"""
|
||
|
||
import asyncio
import json
import logging
import os
from datetime import date, datetime, timedelta
from decimal import Decimal
from typing import Any, Dict, List, Optional

import aiomysql
|
||
|
||
logger = logging.getLogger(__name__)


# MySQL connection settings.
# Each value can be overridden with a STOCK_MYSQL_* environment variable so a
# deployment can supply its own host/credentials without editing this file.
# FIXME(security): the fallback credentials below are committed to source
# control; rotate them and remove the fallbacks once every deployment sets the
# environment variables.
MYSQL_CONFIG = {
    'host': os.environ.get('STOCK_MYSQL_HOST', '222.128.1.157'),
    'port': int(os.environ.get('STOCK_MYSQL_PORT', '33060')),
    'user': os.environ.get('STOCK_MYSQL_USER', 'root'),
    'password': os.environ.get('STOCK_MYSQL_PASSWORD', 'Zzl5588161!'),
    'db': os.environ.get('STOCK_MYSQL_DB', 'stock'),
    'charset': 'utf8mb4',
    'autocommit': True,
}

# Module-wide aiomysql pool; created lazily by get_pool() and cleared by close_pool().
_pool = None
|
||
|
||
|
||
class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that additionally understands dates and Decimals.

    ``date``/``datetime`` objects are emitted as ISO-8601 strings and
    ``Decimal`` objects as floats. Any other unsupported object is handed to
    the base encoder, which raises ``TypeError``.
    """

    def default(self, obj):
        # Decimal and date/datetime are disjoint types, so checking order
        # does not matter.
        if isinstance(obj, Decimal):
            return float(obj)
        if isinstance(obj, (datetime, date)):
            return obj.isoformat()
        return super().default(obj)
|
||
|
||
|
||
# Lock serialising lazy pool creation. It is created on first use, from inside
# a running event loop, rather than at import time.
_pool_lock = None


async def get_pool():
    """Return the shared aiomysql connection pool, creating it on first use.

    Concurrent first callers are serialised with an ``asyncio.Lock`` so only one
    pool is ever created. Without the lock, two coroutines could both observe
    ``_pool is None`` across the ``await`` below and each build a pool, leaking
    one of them.

    Returns:
        The module-level aiomysql pool (minsize=1, maxsize=10) built from
        ``MYSQL_CONFIG``.
    """
    global _pool, _pool_lock

    # Fast path: the pool already exists, no locking needed.
    if _pool is not None:
        return _pool

    # There is no await between this check and the assignment, so lazily
    # creating the lock cannot race within one event loop.
    if _pool_lock is None:
        _pool_lock = asyncio.Lock()

    async with _pool_lock:
        # Re-check: another coroutine may have created the pool while we waited.
        if _pool is None:
            _pool = await aiomysql.create_pool(
                host=MYSQL_CONFIG['host'],
                port=MYSQL_CONFIG['port'],
                user=MYSQL_CONFIG['user'],
                password=MYSQL_CONFIG['password'],
                db=MYSQL_CONFIG['db'],
                charset=MYSQL_CONFIG['charset'],
                autocommit=MYSQL_CONFIG['autocommit'],
                minsize=1,
                maxsize=10
            )
            logger.info("MySQL connection pool created")

    return _pool
|
||
|
||
|
||
async def close_pool():
    """Shut down the shared MySQL pool, if any, and forget it.

    Waits for the pool's connections to close before clearing the module-level
    reference, so a later ``get_pool()`` call builds a fresh pool. Does nothing
    when no pool has been created.
    """
    global _pool
    if not _pool:
        return

    _pool.close()
    await _pool.wait_closed()
    _pool = None
    logger.info("MySQL connection pool closed")
|
||
|
||
|
||
def convert_row(row: Dict) -> Dict:
    """Return a JSON-friendly copy of a database row.

    ``Decimal`` values become ``float`` and ``date``/``datetime`` values become
    ISO-8601 strings; every other value is copied unchanged. A falsy ``row``
    (``None`` or an empty dict) yields an empty dict.
    """
    if not row:
        return {}

    def _to_jsonable(value):
        if isinstance(value, Decimal):
            return float(value)
        if isinstance(value, (datetime, date)):
            return value.isoformat()
        return value

    return {column: _to_jsonable(value) for column, value in row.items()}
|
||
|
||
|
||
async def get_stock_basic_info(seccode: str) -> Optional[Dict[str, Any]]:
    """
    Fetch the ``ea_baseinfo`` profile row for a single stock.

    Args:
        seccode: Security code matched against the ``SECCODE`` column.

    Returns:
        The first matching row, passed through ``convert_row``, or ``None``
        when no row matches.
    """
    sql = """
        SELECT
            SECCODE, SECNAME, ORGNAME,
            F001V as english_name,
            F003V as legal_representative,
            F004V as registered_address,
            F005V as office_address,
            F010D as establishment_date,
            F011V as website,
            F012V as email,
            F013V as phone,
            F015V as main_business,
            F016V as business_scope,
            F017V as company_profile,
            F030V as industry_level1,
            F032V as industry_level2,
            F034V as sw_industry_level1,
            F036V as sw_industry_level2,
            F026V as province,
            F028V as city,
            F041V as chairman,
            F042V as general_manager,
            UPDATE_DATE as update_date
        FROM ea_baseinfo
        WHERE SECCODE = %s
        LIMIT 1
    """

    pool = await get_pool()
    async with pool.acquire() as conn, conn.cursor(aiomysql.DictCursor) as cursor:
        await cursor.execute(sql, (seccode,))
        row = await cursor.fetchone()

    return convert_row(row) if row else None
|
||
|
||
|
||
async def get_stock_financial_index(
    seccode: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    limit: int = 10
) -> List[Dict[str, Any]]:
    """
    Fetch financial-indicator rows for one stock from ``ea_financialindex``.

    Args:
        seccode: Security code matched against ``SECCODE``.
        start_date: Optional inclusive lower bound on ``ENDDATE`` (YYYY-MM-DD).
        end_date: Optional inclusive upper bound on ``ENDDATE`` (YYYY-MM-DD).
        limit: Maximum number of rows to return.

    Returns:
        Rows ordered by ``ENDDATE`` descending, each passed through
        ``convert_row``; an empty list when nothing matches.
    """
    sql = """
        SELECT
            SECCODE, SECNAME, ENDDATE, STARTDATE,
            F069D as report_year,
            F003N as eps, -- 每股收益
            F004N as basic_eps,
            F008N as bps, -- 每股净资产
            F014N as roe, -- 净资产收益率
            F016N as roa, -- 总资产报酬率
            F017N as net_profit_margin, -- 净利润率
            F022N as receivable_turnover, -- 应收账款周转率
            F023N as inventory_turnover, -- 存货周转率
            F025N as total_asset_turnover, -- 总资产周转率
            F041N as debt_ratio, -- 资产负债率
            F042N as current_ratio, -- 流动比率
            F043N as quick_ratio, -- 速动比率
            F052N as revenue_growth, -- 营业收入增长率
            F053N as profit_growth, -- 净利润增长率
            F089N as revenue, -- 营业收入
            F090N as operating_cost, -- 营业成本
            F101N as net_profit, -- 净利润
            F102N as net_profit_parent -- 归母净利润
        FROM ea_financialindex
        WHERE SECCODE = %s
    """
    args: List[Any] = [seccode]

    # Optional date-range filters; falsy values (None / "") are skipped.
    optional_filters = (
        (" AND ENDDATE >= %s", start_date),
        (" AND ENDDATE <= %s", end_date),
    )
    for clause, value in optional_filters:
        if value:
            sql += clause
            args.append(value)

    sql += " ORDER BY ENDDATE DESC LIMIT %s"
    args.append(limit)

    pool = await get_pool()
    async with pool.acquire() as conn, conn.cursor(aiomysql.DictCursor) as cursor:
        await cursor.execute(sql, args)
        rows = await cursor.fetchall()

    return [convert_row(row) for row in rows]
|
||
|
||
|
||
async def get_stock_trade_data(
    seccode: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    limit: int = 30
) -> List[Dict[str, Any]]:
    """
    Fetch daily trading rows for one stock from ``ea_trade``.

    Args:
        seccode: Security code matched against ``SECCODE``.
        start_date: Optional inclusive lower bound on ``TRADEDATE`` (YYYY-MM-DD).
        end_date: Optional inclusive upper bound on ``TRADEDATE`` (YYYY-MM-DD).
        limit: Maximum number of rows to return.

    Returns:
        Rows ordered by ``TRADEDATE`` descending, each passed through
        ``convert_row``; an empty list when nothing matches.
    """
    sql = """
        SELECT
            SECCODE, SECNAME, TRADEDATE,
            F002N as prev_close, -- 昨日收盘价
            F003N as open_price, -- 开盘价
            F005N as high_price, -- 最高价
            F006N as low_price, -- 最低价
            F007N as close_price, -- 收盘价
            F004N as volume, -- 成交量
            F011N as turnover, -- 成交金额
            F009N as change_amount, -- 涨跌额
            F010N as change_pct, -- 涨跌幅
            F012N as turnover_rate, -- 换手率
            F013N as amplitude, -- 振幅
            F026N as pe_ratio, -- 市盈率
            F020N as total_shares, -- 总股本
            F021N as circulating_shares -- 流通股本
        FROM ea_trade
        WHERE SECCODE = %s
    """
    args: List[Any] = [seccode]

    # Optional date-range filters; falsy values (None / "") are skipped.
    optional_filters = (
        (" AND TRADEDATE >= %s", start_date),
        (" AND TRADEDATE <= %s", end_date),
    )
    for clause, value in optional_filters:
        if value:
            sql += clause
            args.append(value)

    sql += " ORDER BY TRADEDATE DESC LIMIT %s"
    args.append(limit)

    pool = await get_pool()
    async with pool.acquire() as conn, conn.cursor(aiomysql.DictCursor) as cursor:
        await cursor.execute(sql, args)
        rows = await cursor.fetchall()

    return [convert_row(row) for row in rows]
|
||
|
||
|
||
async def get_stock_balance_sheet(
    seccode: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    limit: int = 8
) -> List[Dict[str, Any]]:
    """
    Fetch balance-sheet rows for one stock from ``ea_asset``.

    Args:
        seccode: Security code matched against ``SECCODE``.
        start_date: Optional inclusive lower bound on ``ENDDATE``.
        end_date: Optional inclusive upper bound on ``ENDDATE``.
        limit: Maximum number of rows to return.

    Returns:
        Rows ordered by ``ENDDATE`` descending, each passed through
        ``convert_row``; an empty list when nothing matches.
    """
    sql = """
        SELECT
            SECCODE, SECNAME, ENDDATE,
            F001D as report_year,
            F006N as cash, -- 货币资金
            F009N as receivables, -- 应收账款
            F015N as inventory, -- 存货
            F019N as current_assets, -- 流动资产合计
            F023N as long_term_investment, -- 长期股权投资
            F025N as fixed_assets, -- 固定资产
            F037N as noncurrent_assets, -- 非流动资产合计
            F038N as total_assets, -- 资产总计
            F039N as short_term_loan, -- 短期借款
            F042N as payables, -- 应付账款
            F052N as current_liabilities, -- 流动负债合计
            F053N as long_term_loan, -- 长期借款
            F060N as noncurrent_liabilities, -- 非流动负债合计
            F061N as total_liabilities, -- 负债合计
            F062N as share_capital, -- 股本
            F063N as capital_reserve, -- 资本公积
            F065N as retained_earnings, -- 未分配利润
            F070N as total_equity -- 所有者权益合计
        FROM ea_asset
        WHERE SECCODE = %s
    """
    args: List[Any] = [seccode]

    # Optional date-range filters; falsy values (None / "") are skipped.
    optional_filters = (
        (" AND ENDDATE >= %s", start_date),
        (" AND ENDDATE <= %s", end_date),
    )
    for clause, value in optional_filters:
        if value:
            sql += clause
            args.append(value)

    sql += " ORDER BY ENDDATE DESC LIMIT %s"
    args.append(limit)

    pool = await get_pool()
    async with pool.acquire() as conn, conn.cursor(aiomysql.DictCursor) as cursor:
        await cursor.execute(sql, args)
        rows = await cursor.fetchall()

    return [convert_row(row) for row in rows]
|
||
|
||
|
||
async def get_stock_cashflow(
    seccode: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    limit: int = 8
) -> List[Dict[str, Any]]:
    """
    Fetch cash-flow-statement rows for one stock from ``ea_cashflow``.

    Args:
        seccode: Security code matched against ``SECCODE``.
        start_date: Optional inclusive lower bound on ``ENDDATE``.
        end_date: Optional inclusive upper bound on ``ENDDATE``.
        limit: Maximum number of rows to return.

    Returns:
        Rows ordered by ``ENDDATE`` descending, each passed through
        ``convert_row``; an empty list when nothing matches.
    """
    sql = """
        SELECT
            SECCODE, SECNAME, ENDDATE, STARTDATE,
            F001D as report_year,
            F009N as operating_cash_inflow, -- 经营活动现金流入
            F014N as operating_cash_outflow, -- 经营活动现金流出
            F015N as net_operating_cashflow, -- 经营活动现金流量净额
            F021N as investing_cash_inflow, -- 投资活动现金流入
            F026N as investing_cash_outflow, -- 投资活动现金流出
            F027N as net_investing_cashflow, -- 投资活动现金流量净额
            F031N as financing_cash_inflow, -- 筹资活动现金流入
            F035N as financing_cash_outflow, -- 筹资活动现金流出
            F036N as net_financing_cashflow, -- 筹资活动现金流量净额
            F039N as net_cash_increase, -- 现金及现金等价物净增加额
            F044N as net_profit, -- 净利润
            F046N as depreciation, -- 固定资产折旧
            F060N as net_operating_cashflow_adjusted -- 经营活动现金流量净额(补充)
        FROM ea_cashflow
        WHERE SECCODE = %s
    """
    args: List[Any] = [seccode]

    # Optional date-range filters; falsy values (None / "") are skipped.
    optional_filters = (
        (" AND ENDDATE >= %s", start_date),
        (" AND ENDDATE <= %s", end_date),
    )
    for clause, value in optional_filters:
        if value:
            sql += clause
            args.append(value)

    sql += " ORDER BY ENDDATE DESC LIMIT %s"
    args.append(limit)

    pool = await get_pool()
    async with pool.acquire() as conn, conn.cursor(aiomysql.DictCursor) as cursor:
        await cursor.execute(sql, args)
        rows = await cursor.fetchall()

    return [convert_row(row) for row in rows]
|
||
|
||
|
||
async def search_stocks_by_criteria(
    industry: Optional[str] = None,
    province: Optional[str] = None,
    min_market_cap: Optional[float] = None,
    max_market_cap: Optional[float] = None,
    limit: int = 50
) -> List[Dict[str, Any]]:
    """
    Search stocks by industry, province and market-cap range.

    Each stock from ``ea_baseinfo`` is LEFT JOINed with its most recent
    ``ea_trade`` row, so stocks without trade data still appear (with NULL
    quote fields) unless a market-cap bound is given.

    Args:
        industry: Substring matched (LIKE) against F030V, F032V or F034V.
        province: Exact match against F026V.
        min_market_cap: Optional lower bound on market cap, in units of
            100 million (亿元), computed as F007N * F020N / 1e8.
        max_market_cap: Optional upper bound on market cap, same units.
        limit: Maximum number of rows to return.

    Returns:
        Matching rows ordered by latest trade date descending, each passed
        through ``convert_row``.
    """
    pool = await get_pool()

    async with pool.acquire() as conn:
        async with conn.cursor(aiomysql.DictCursor) as cursor:
            query = """
                SELECT DISTINCT
                    b.SECCODE,
                    b.SECNAME,
                    b.F030V as industry_level1,
                    b.F032V as industry_level2,
                    b.F034V as sw_industry_level1,
                    b.F026V as province,
                    b.F028V as city,
                    b.F015V as main_business,
                    t.F007N as latest_price,
                    t.F010N as change_pct,
                    t.F026N as pe_ratio,
                    t.TRADEDATE as latest_trade_date
                FROM ea_baseinfo b
                LEFT JOIN (
                    SELECT SECCODE, MAX(TRADEDATE) as max_date
                    FROM ea_trade
                    GROUP BY SECCODE
                ) latest ON b.SECCODE = latest.SECCODE
                LEFT JOIN ea_trade t ON b.SECCODE = t.SECCODE
                    AND t.TRADEDATE = latest.max_date
                WHERE 1=1
            """
            params: List[Any] = []

            if industry:
                query += " AND (b.F030V LIKE %s OR b.F032V LIKE %s OR b.F034V LIKE %s)"
                pattern = f"%{industry}%"
                params.extend([pattern, pattern, pattern])

            if province:
                query += " AND b.F026V = %s"
                params.append(province)

            # Market cap = latest price * total shares / 100000000 (亿元).
            # Compare against None rather than truthiness so that an explicit
            # 0 bound is honoured instead of being silently ignored.
            if min_market_cap is not None:
                query += " AND (t.F007N * t.F020N / 100000000) >= %s"
                params.append(min_market_cap)

            if max_market_cap is not None:
                query += " AND (t.F007N * t.F020N / 100000000) <= %s"
                params.append(max_market_cap)

            query += " ORDER BY t.TRADEDATE DESC LIMIT %s"
            params.append(limit)

            await cursor.execute(query, params)
            results = await cursor.fetchall()

            return [convert_row(row) for row in results]
|
||
|
||
|
||
async def get_stock_comparison(
    seccodes: List[str],
    metric: str = "financial"
) -> Dict[str, Any]:
    """
    Compare the latest metrics of several stocks side by side.

    Args:
        seccodes: Stock codes to compare (at least two). A comma-separated
            string such as ``"600519,000858"`` is also accepted.
        metric: ``"financial"`` compares each stock's latest
            ``ea_financialindex`` row (max ``ENDDATE``); any other value
            compares each stock's latest ``ea_trade`` row (max ``TRADEDATE``).

    Returns:
        ``{"comparison_type": metric, "stocks": [...]}`` on success, or
        ``{"error": ...}`` when fewer than two codes are supplied.
    """
    # A bare string would otherwise be iterated character by character,
    # turning every character into a bogus stock code.
    if isinstance(seccodes, str):
        codes = [part.strip() for part in seccodes.split(',') if part.strip()]
    else:
        codes = list(seccodes or [])

    # Validate before touching the database so bad input never creates a pool.
    if len(codes) < 2:
        return {"error": "至少需要2个股票代码进行对比"}

    placeholders = ','.join(['%s'] * len(codes))

    if metric == "financial":
        # Latest financial indicators per stock.
        query = f"""
            SELECT
                f.SECCODE, f.SECNAME, f.ENDDATE,
                f.F003N as eps,
                f.F008N as bps,
                f.F014N as roe,
                f.F017N as net_profit_margin,
                f.F041N as debt_ratio,
                f.F052N as revenue_growth,
                f.F053N as profit_growth,
                f.F089N as revenue,
                f.F101N as net_profit
            FROM ea_financialindex f
            INNER JOIN (
                SELECT SECCODE, MAX(ENDDATE) as max_date
                FROM ea_financialindex
                WHERE SECCODE IN ({placeholders})
                GROUP BY SECCODE
            ) latest ON f.SECCODE = latest.SECCODE
                AND f.ENDDATE = latest.max_date
        """
    else:  # trade
        # Latest trading data per stock.
        query = f"""
            SELECT
                t.SECCODE, t.SECNAME, t.TRADEDATE,
                t.F007N as close_price,
                t.F010N as change_pct,
                t.F012N as turnover_rate,
                t.F026N as pe_ratio,
                t.F020N as total_shares,
                t.F021N as circulating_shares
            FROM ea_trade t
            INNER JOIN (
                SELECT SECCODE, MAX(TRADEDATE) as max_date
                FROM ea_trade
                WHERE SECCODE IN ({placeholders})
                GROUP BY SECCODE
            ) latest ON t.SECCODE = latest.SECCODE
                AND t.TRADEDATE = latest.max_date
        """

    pool = await get_pool()
    async with pool.acquire() as conn:
        async with conn.cursor(aiomysql.DictCursor) as cursor:
            # Only the placeholder count is interpolated; the codes themselves
            # are bound as parameters.
            await cursor.execute(query, codes)
            results = await cursor.fetchall()

            return {
                "comparison_type": metric,
                "stocks": [convert_row(row) for row in results]
            }
|
||
|
||
|
||
async def get_user_favorite_stocks(user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
    """
    Return a user's non-deleted watchlist stocks together with their latest quote.

    ``user_favorites`` is INNER JOINed with ``ea_baseinfo``, so favorites whose
    code is missing from ``ea_baseinfo`` are dropped. The most recent
    ``ea_trade`` row is LEFT JOINed, so stocks without trade data appear with
    NULL quote fields.

    NOTE(review): the query assumes a ``user_favorites`` table with
    ``user_id``, ``stock_code``, ``created_at`` and ``is_deleted`` columns;
    adjust it if the real schema differs.

    Args:
        user_id: Owner of the watchlist.
        limit: Maximum number of rows to return.

    Returns:
        Rows ordered by ``created_at`` descending (most recently added first),
        each passed through ``convert_row``.
    """
    sql = """
        SELECT
            f.user_id,
            f.stock_code,
            b.SECNAME as stock_name,
            b.F030V as industry,
            t.F007N as current_price,
            t.F010N as change_pct,
            t.F012N as turnover_rate,
            t.F026N as pe_ratio,
            t.TRADEDATE as latest_trade_date,
            f.created_at as favorite_time
        FROM user_favorites f
        INNER JOIN ea_baseinfo b ON f.stock_code = b.SECCODE
        LEFT JOIN (
            SELECT SECCODE, MAX(TRADEDATE) as max_date
            FROM ea_trade
            GROUP BY SECCODE
        ) latest ON b.SECCODE = latest.SECCODE
        LEFT JOIN ea_trade t ON b.SECCODE = t.SECCODE
            AND t.TRADEDATE = latest.max_date
        WHERE f.user_id = %s AND f.is_deleted = 0
        ORDER BY f.created_at DESC
        LIMIT %s
    """

    pool = await get_pool()
    async with pool.acquire() as conn, conn.cursor(aiomysql.DictCursor) as cursor:
        await cursor.execute(sql, [user_id, limit])
        rows = await cursor.fetchall()

    return [convert_row(row) for row in rows]
|
||
|
||
|
||
async def get_user_favorite_events(user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
    """
    Return a user's non-deleted favorite events.

    NOTE(review): the query assumes ``user_event_favorites`` and ``events``
    tables with the columns referenced below; adjust if the schema differs.

    Args:
        user_id: Owner of the favorites.
        limit: Maximum number of rows to return.

    Returns:
        Rows ordered by ``events.event_date`` descending, each passed through
        ``convert_row``.
    """
    sql = """
        SELECT
            f.user_id,
            f.event_id,
            e.title,
            e.description,
            e.event_date,
            e.importance,
            e.related_stocks,
            e.category,
            f.created_at as favorite_time
        FROM user_event_favorites f
        INNER JOIN events e ON f.event_id = e.id
        WHERE f.user_id = %s AND f.is_deleted = 0
        ORDER BY e.event_date DESC
        LIMIT %s
    """

    pool = await get_pool()
    async with pool.acquire() as conn, conn.cursor(aiomysql.DictCursor) as cursor:
        await cursor.execute(sql, [user_id, limit])
        rows = await cursor.fetchall()

    return [convert_row(row) for row in rows]
|
||
|
||
|
||
async def add_favorite_stock(user_id: str, stock_code: str) -> Dict[str, Any]:
    """
    Add a stock to a user's watchlist, reviving a soft-deleted entry if present.

    Args:
        user_id: Owner of the watchlist.
        stock_code: Stock code to add.

    Returns:
        A ``{"success": bool, "message": str}`` dict. ``success`` is False only
        when an active (``is_deleted != 1``) entry already exists.
    """
    lookup_sql = """
        SELECT id, is_deleted
        FROM user_favorites
        WHERE user_id = %s AND stock_code = %s
    """
    revive_sql = """
        UPDATE user_favorites
        SET is_deleted = 0, updated_at = NOW()
        WHERE id = %s
    """
    insert_sql = """
        INSERT INTO user_favorites (user_id, stock_code, created_at, updated_at, is_deleted)
        VALUES (%s, %s, NOW(), NOW(), 0)
    """

    pool = await get_pool()
    async with pool.acquire() as conn, conn.cursor(aiomysql.DictCursor) as cursor:
        await cursor.execute(lookup_sql, [user_id, stock_code])
        existing = await cursor.fetchone()

        # No prior row at all: create one.
        if not existing:
            await cursor.execute(insert_sql, [user_id, stock_code])
            return {"success": True, "message": "添加自选股成功"}

        # Row exists and is not soft-deleted: nothing to do.
        if existing['is_deleted'] != 1:
            return {"success": False, "message": "该股票已在自选中"}

        # Row exists but was soft-deleted: restore it.
        await cursor.execute(revive_sql, [existing['id']])
        return {"success": True, "message": "已恢复自选股"}
|
||
|
||
|
||
async def remove_favorite_stock(user_id: str, stock_code: str) -> Dict[str, Any]:
    """
    Soft-delete a stock from a user's watchlist.

    Sets ``is_deleted = 1`` on the user's active ``user_favorites`` row for the
    stock; rows that are already deleted are not touched.

    Args:
        user_id: Owner of the watchlist.
        stock_code: Stock code to remove.

    Returns:
        ``{"success": True, ...}`` if a row was updated, otherwise
        ``{"success": False, ...}``.
    """
    sql = """
        UPDATE user_favorites
        SET is_deleted = 1, updated_at = NOW()
        WHERE user_id = %s AND stock_code = %s AND is_deleted = 0
    """

    pool = await get_pool()
    async with pool.acquire() as conn, conn.cursor(aiomysql.DictCursor) as cursor:
        # cursor.execute returns the number of affected rows.
        affected = await cursor.execute(sql, [user_id, stock_code])

    if affected <= 0:
        return {"success": False, "message": "未找到该自选股"}
    return {"success": True, "message": "删除自选股成功"}
|
||
|
||
|
||
async def add_favorite_event(user_id: str, event_id: int) -> Dict[str, Any]:
    """
    Add an event to a user's favorites, reviving a soft-deleted entry if present.

    Args:
        user_id: Owner of the favorites.
        event_id: Event identifier to add.

    Returns:
        A ``{"success": bool, "message": str}`` dict. ``success`` is False only
        when an active (``is_deleted != 1``) entry already exists.
    """
    lookup_sql = """
        SELECT id, is_deleted
        FROM user_event_favorites
        WHERE user_id = %s AND event_id = %s
    """
    revive_sql = """
        UPDATE user_event_favorites
        SET is_deleted = 0, updated_at = NOW()
        WHERE id = %s
    """
    insert_sql = """
        INSERT INTO user_event_favorites (user_id, event_id, created_at, updated_at, is_deleted)
        VALUES (%s, %s, NOW(), NOW(), 0)
    """

    pool = await get_pool()
    async with pool.acquire() as conn, conn.cursor(aiomysql.DictCursor) as cursor:
        await cursor.execute(lookup_sql, [user_id, event_id])
        existing = await cursor.fetchone()

        # No prior row at all: create one.
        if not existing:
            await cursor.execute(insert_sql, [user_id, event_id])
            return {"success": True, "message": "添加自选事件成功"}

        # Row exists and is not soft-deleted: nothing to do.
        if existing['is_deleted'] != 1:
            return {"success": False, "message": "该事件已在自选中"}

        # Row exists but was soft-deleted: restore it.
        await cursor.execute(revive_sql, [existing['id']])
        return {"success": True, "message": "已恢复自选事件"}
|
||
|
||
|
||
async def remove_favorite_event(user_id: str, event_id: int) -> Dict[str, Any]:
    """
    Soft-delete an event from a user's favorites.

    Sets ``is_deleted = 1`` on the user's active ``user_event_favorites`` row
    for the event; rows that are already deleted are not touched.

    Args:
        user_id: Owner of the favorites.
        event_id: Event identifier to remove.

    Returns:
        ``{"success": True, ...}`` if a row was updated, otherwise
        ``{"success": False, ...}``.
    """
    sql = """
        UPDATE user_event_favorites
        SET is_deleted = 1, updated_at = NOW()
        WHERE user_id = %s AND event_id = %s AND is_deleted = 0
    """

    pool = await get_pool()
    async with pool.acquire() as conn, conn.cursor(aiomysql.DictCursor) as cursor:
        # cursor.execute returns the number of affected rows.
        affected = await cursor.execute(sql, [user_id, event_id])

    if affected <= 0:
        return {"success": False, "message": "未找到该自选事件"}
    return {"success": True, "message": "删除自选事件成功"}
|
||
|
||
|
||
# ==================== ClickHouse 分钟频数据查询 ====================
|
||
|
||
from clickhouse_driver import Client as ClickHouseClient
|
||
|
||
# ClickHouse connection settings.
# Each value can be overridden with a STOCK_CLICKHOUSE_* environment variable.
# FIXME(security): the fallback credentials below are committed to source
# control; rotate them and remove the fallbacks once deployments set the
# environment variables.
CLICKHOUSE_CONFIG = {
    'host': os.environ.get('STOCK_CLICKHOUSE_HOST', '222.128.1.157'),
    'port': int(os.environ.get('STOCK_CLICKHOUSE_PORT', '18000')),
    'user': os.environ.get('STOCK_CLICKHOUSE_USER', 'default'),
    'password': os.environ.get('STOCK_CLICKHOUSE_PASSWORD', 'Zzl33818!'),
    'database': os.environ.get('STOCK_CLICKHOUSE_DATABASE', 'stock'),
}

# ClickHouse client, created lazily by get_clickhouse_client().
_clickhouse_client = None
|
||
|
||
|
||
def get_clickhouse_client():
    """Return the module-wide ClickHouse client, building it on the first call.

    The client is constructed from ``CLICKHOUSE_CONFIG`` and cached in
    ``_clickhouse_client``; later calls return the same object.

    NOTE(review): this client's ``execute`` is synchronous, so calling it from
    the async helpers in this module blocks the event loop while a query runs.
    """
    global _clickhouse_client
    if _clickhouse_client is not None:
        return _clickhouse_client

    cfg = CLICKHOUSE_CONFIG
    _clickhouse_client = ClickHouseClient(
        host=cfg['host'],
        port=cfg['port'],
        user=cfg['user'],
        password=cfg['password'],
        database=cfg['database']
    )
    logger.info("ClickHouse client created")
    return _clickhouse_client
|
||
|
||
|
||
async def get_stock_minute_data(
    code: str,
    start_time: Optional[str] = None,
    end_time: Optional[str] = None,
    limit: int = 240
) -> List[Dict[str, Any]]:
    """
    Fetch minute-level bars for one stock from the ClickHouse ``stock_minute`` table.

    Args:
        code: Stock code with or without exchange suffix (e.g. '600519' or
            '600519.SH'). Bare codes get a suffix: 6xxxxx -> .SH,
            0xxxxx/3xxxxx -> .SZ, anything else -> .BJ.
        start_time: Optional inclusive lower bound, 'YYYY-MM-DD HH:MM:SS' or
            'YYYY-MM-DD'.
        end_time: Optional upper bound in the same formats. A full timestamp
            is inclusive; a bare date includes the whole of that day.
        limit: Maximum number of rows (default 240, one trading day of bars).

    Returns:
        Rows as dicts, newest first, with datetime values rendered as ISO
        strings. Returns an empty list if the query fails (the error is logged).
    """
    try:
        client = get_clickhouse_client()

        # stock_minute stores suffixed codes; normalise whitespace/case and add
        # the exchange suffix when the caller omitted it.
        normalized = code.strip().upper()
        if '.' in normalized:
            stock_code = normalized
        elif normalized.startswith('6'):
            stock_code = f"{normalized}.SH"
        elif normalized.startswith(('0', '3')):
            stock_code = f"{normalized}.SZ"
        else:
            stock_code = f"{normalized}.BJ"

        # Caller-supplied values are bound through clickhouse_driver's %(name)s
        # client-side substitution, which quotes and escapes them. Previously
        # they were pasted into the SQL with an f-string, allowing SQL injection.
        query = """
            SELECT
                code,
                timestamp,
                open,
                high,
                low,
                close,
                volume,
                amt
            FROM stock_minute
            WHERE code = %(code)s
        """
        params: Dict[str, Any] = {'code': stock_code}

        if start_time:
            query += " AND timestamp >= %(start_time)s"
            params['start_time'] = start_time

        if end_time:
            try:
                end_day = datetime.strptime(end_time, '%Y-%m-%d')
            except (TypeError, ValueError):
                end_day = None

            if end_day is None:
                query += " AND timestamp <= %(end_time)s"
                params['end_time'] = end_time
            else:
                # A bare date compared with "<=" means midnight and would drop
                # that day's bars; use "< next day" to include the whole day.
                query += " AND timestamp < %(end_time)s"
                params['end_time'] = (end_day + timedelta(days=1)).strftime('%Y-%m-%d')

        query += " ORDER BY timestamp DESC LIMIT %(limit)s"
        params['limit'] = int(limit)

        result = client.execute(query, params, with_column_types=True)

        rows = result[0]
        columns = [col[0] for col in result[1]]

        # Convert rows to dicts, rendering datetime values as ISO strings.
        data = []
        for row in rows:
            record = {}
            for col, value in zip(columns, row):
                record[col] = value.isoformat() if hasattr(value, 'isoformat') else value
            data.append(record)

        logger.info(f"[ClickHouse] 查询分钟数据: code={stock_code}, 返回 {len(data)} 条记录")
        return data

    except Exception as e:
        logger.error(f"[ClickHouse] 查询分钟数据失败: {e}", exc_info=True)
        return []
|
||
|
||
|
||
async def get_stock_minute_aggregation(
    code: str,
    date: str,
    interval: int = 5
) -> List[Dict[str, Any]]:
    """
    Aggregate one day of minute bars into N-minute OHLCV bars.

    Args:
        code: Stock code with or without exchange suffix. Bare codes get a
            suffix (6xxxxx -> .SH, 0xxxxx/3xxxxx -> .SZ, else .BJ), matching
            ``get_stock_minute_data``.
        date: Trading day, 'YYYY-MM-DD'.
        interval: Bar width in minutes (positive integer, default 5).

    Returns:
        Aggregated bars ordered by ``interval_start``, with datetime values
        rendered as ISO strings. Returns an empty list on any error, including
        an invalid ``interval`` (the error is logged).
    """
    try:
        client = get_clickhouse_client()

        # Normalise the code to the suffixed form used by stock_minute.
        # Stripping the suffix, as this function used to do, cannot match rows
        # stored with suffixed codes.
        normalized = code.strip().upper()
        if '.' in normalized:
            stock_code = normalized
        elif normalized.startswith('6'):
            stock_code = f"{normalized}.SH"
        elif normalized.startswith(('0', '3')):
            stock_code = f"{normalized}.SZ"
        else:
            stock_code = f"{normalized}.BJ"

        # interval is spliced into the SQL text (INTERVAL n MINUTE), so coerce
        # it to a positive int to keep arbitrary input out of the query.
        interval = int(interval)
        if interval <= 0:
            raise ValueError(f"interval must be a positive number of minutes, got {interval}")

        # Use ClickHouse time functions for the aggregation.
        query = f"""
            SELECT
                code,
                toStartOfInterval(timestamp, INTERVAL {interval} MINUTE) as interval_start,
                argMin(open, timestamp) as open,
                max(high) as high,
                min(low) as low,
                argMax(close, timestamp) as close,
                sum(volume) as volume,
                sum(amt) as amt
            FROM stock_minute
            WHERE code = %(code)s
              AND toDate(timestamp) = %(date)s
            GROUP BY code, interval_start
            ORDER BY interval_start
        """

        params = {'code': stock_code, 'date': date}

        result = client.execute(query, params, with_column_types=True)

        rows = result[0]
        columns = [col[0] for col in result[1]]

        data = []
        for row in rows:
            record = {}
            for col, value in zip(columns, row):
                record[col] = value.isoformat() if hasattr(value, 'isoformat') else value
            data.append(record)

        logger.info(f"[ClickHouse] 聚合分钟数据: code={stock_code}, date={date}, interval={interval}min, 返回 {len(data)} 条记录")
        return data

    except Exception as e:
        logger.error(f"[ClickHouse] 聚合分钟数据失败: {e}", exc_info=True)
        return []
|
||
|
||
|
||
async def get_stock_intraday_statistics(
    code: str,
    date: str
) -> Dict[str, Any]:
    """
    Compute intraday statistics for one stock on one day from minute bars.

    Args:
        code: Stock code with or without exchange suffix. Bare codes get a
            suffix (6xxxxx -> .SH, 0xxxxx/3xxxxx -> .SZ, else .BJ), matching
            ``get_stock_minute_data``.
        date: Trading day, 'YYYY-MM-DD'.

    Returns:
        ``{"success": True, "data": {...}}`` with open/high/low/close, total
        volume and amount, bar count, first/last timestamps, intraday range in
        percent (None when the low is not positive) and the population stddev
        of minute closes. Returns ``{"success": False, "error": ...}`` when no
        data is found or the query fails.
    """
    try:
        client = get_clickhouse_client()

        # Normalise the code to the suffixed form used by stock_minute.
        normalized = code.strip().upper()
        if '.' in normalized:
            stock_code = normalized
        elif normalized.startswith('6'):
            stock_code = f"{normalized}.SH"
        elif normalized.startswith(('0', '3')):
            stock_code = f"{normalized}.SZ"
        else:
            stock_code = f"{normalized}.BJ"

        # Aggregates are computed in a subquery under names that do not collide
        # with column names. With ClickHouse's default
        # prefer_column_name_to_alias = 0, an alias such as "max(high) AS high"
        # shadows the column. Expressions like "max(high) - min(low)" or
        # "stddevPop(close)" then nest aggregates and the query fails.
        query = """
            SELECT
                code,
                trade_date,
                open_px AS open,
                high_px AS high,
                low_px AS low,
                close_px AS close,
                total_volume,
                total_amount,
                data_points,
                first_time,
                last_time,
                if(low_px > 0, (high_px - low_px) / low_px * 100, NULL) AS intraday_range_pct,
                price_volatility
            FROM (
                SELECT
                    code,
                    toDate(timestamp) AS trade_date,
                    argMin(open, timestamp) AS open_px,
                    max(high) AS high_px,
                    min(low) AS low_px,
                    argMax(close, timestamp) AS close_px,
                    sum(volume) AS total_volume,
                    sum(amt) AS total_amount,
                    count(*) AS data_points,
                    min(timestamp) AS first_time,
                    max(timestamp) AS last_time,
                    stddevPop(close) AS price_volatility
                FROM stock_minute
                WHERE code = %(code)s
                  AND toDate(timestamp) = %(date)s
                GROUP BY code, trade_date
            )
        """

        params = {'code': stock_code, 'date': date}

        result = client.execute(query, params, with_column_types=True)

        if not result[0]:
            return {"success": False, "error": f"未找到 {stock_code} 在 {date} 的分钟数据"}

        row = result[0][0]
        columns = [col[0] for col in result[1]]

        # Dates become ISO strings; numeric values (including Decimal, which
        # json cannot serialise) become floats; anything else passes through.
        data = {}
        for col, value in zip(columns, row):
            if hasattr(value, 'isoformat'):
                data[col] = value.isoformat()
            elif isinstance(value, (int, float, Decimal)):
                data[col] = float(value)
            else:
                data[col] = value

        logger.info(f"[ClickHouse] 日内统计: code={stock_code}, date={date}")
        return {"success": True, "data": data}

    except Exception as e:
        logger.error(f"[ClickHouse] 日内统计失败: {e}", exc_info=True)
        return {"success": False, "error": str(e)}
|