# Files
# vf_react/stress_test/websocket_test.py
# 2025-12-12 00:02:55 +08:00

# 178 lines
# 5.4 KiB
# Python

# -*- coding: utf-8 -*-
"""
WebSocket 压力测试脚本
使用方式:
pip install "python-socketio[asyncio_client]"
# 测试 1000 个 WebSocket 连接
python websocket_test.py --url wss://valuefrontier.cn --connections 1000
# 测试 5000 个连接,持续 5 分钟
python websocket_test.py --url wss://valuefrontier.cn --connections 5000 --duration 300
"""
import argparse
import asyncio
import time
import statistics
from datetime import datetime
import socketio
# Run-wide counters, updated by the per-client event handlers in
# create_client() and reported by run_test().
stats = dict(
    connected=0,          # "connect" events received
    disconnected=0,       # "disconnect" events received
    messages_received=0,  # "message" events plus catch-all events
    errors=0,             # connect() attempts that raised
    connect_times=[],     # per-connection latency in milliseconds
)
async def create_client(client_id, url, namespace="/"):
    """Open one Socket.IO client connection and wire it into ``stats``.

    Args:
        client_id: Index of this client within the run (not used inside
            the function; kept for interface compatibility).
        url: Server URL passed to ``AsyncClient.connect``.
        namespace: Namespace to connect to. The event handlers are
            registered on this same namespace so that connects and
            messages are counted for it (not only for "/").

    Returns:
        The connected ``socketio.AsyncClient``, or ``None`` if ``connect``
        raised; in that case ``stats["errors"]`` is incremented.
    """
    sio = socketio.AsyncClient(
        reconnection=False,
        logger=False,
        engineio_logger=False,
    )
    start_time = time.time()

    @sio.on("connect", namespace=namespace)
    async def connect():
        # Latency from just before connect() was issued to the connect event.
        connect_time = (time.time() - start_time) * 1000
        stats["connected"] += 1
        stats["connect_times"].append(connect_time)

    @sio.on("disconnect", namespace=namespace)
    async def disconnect(*_args):
        # *_args tolerates library versions that pass a disconnect reason.
        stats["disconnected"] += 1

    @sio.on("message", namespace=namespace)
    async def message(*_args):
        stats["messages_received"] += 1

    @sio.on("*", namespace=namespace)
    async def catch_all(_event, *_args):
        # Catch-all for events without a dedicated handler; *_args accepts
        # events emitted with zero or several payload arguments.
        stats["messages_received"] += 1

    try:
        await sio.connect(url, namespaces=[namespace])
    except Exception:  # best-effort: any failure counts as a failed connection
        stats["errors"] += 1
        return None
    return sio
def _p99(samples):
    """Return the 99th-percentile value of a non-empty list of latencies.

    Before Python 3.13, ``statistics.quantiles`` raises ``StatisticsError``
    for fewer than two data points, which happens when exactly one client
    connects. A single sample is treated as its own percentile.
    """
    if len(samples) < 2:
        return samples[0]
    return statistics.quantiles(samples, n=100)[98]


async def run_test(url, num_connections, duration, batch_size=100):
    """Run the WebSocket stress test and print a report.

    Opens ``num_connections`` clients against ``url`` via ``create_client``,
    ``batch_size`` at a time. It holds them open for ``duration`` seconds
    while counting received messages, then disconnects them and prints
    connection, latency and message statistics taken from ``stats``.

    Args:
        url: Server URL.
        num_connections: Total number of clients to open.
        duration: Seconds to keep the connections open.
        batch_size: Clients opened concurrently per batch.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    print("=" * 60)
    print("🚀 WebSocket 压力测试")
    print(f" 目标: {url}")
    print(f" 连接数: {num_connections}")
    print(f" 持续时间: {duration} 秒")
    print(f" 批量大小: {batch_size}")
    print("=" * 60)

    clients = []
    start_time = time.time()

    # Create the connections in batches.
    print(f"\n📡 开始创建 {num_connections} 个 WebSocket 连接...")
    for i in range(0, num_connections, batch_size):
        batch_end = min(i + batch_size, num_connections)
        batch_results = await asyncio.gather(
            *(create_client(j, url) for j in range(i, batch_end)),
            return_exceptions=True,
        )
        # create_client returns None on failure. With return_exceptions=True,
        # gather can also return exception objects, including
        # CancelledError, which is a BaseException but not an Exception.
        clients.extend(
            result for result in batch_results
            if result is not None and not isinstance(result, BaseException)
        )
        # Print progress.
        progress = (batch_end / num_connections) * 100
        print(f" 进度: {batch_end}/{num_connections} ({progress:.1f}%) - "
              f"成功: {stats['connected']}, 失败: {stats['errors']}")
        # Short pause between batches to avoid a connection storm.
        await asyncio.sleep(0.1)

    connect_duration = time.time() - start_time
    connect_times = stats["connect_times"]
    print("\n✅ 连接阶段完成!")
    print(f" 耗时: {connect_duration:.2f} 秒")
    print(f" 成功连接: {stats['connected']}")
    print(f" 连接失败: {stats['errors']}")
    if connect_times:
        print(f" 平均连接时间: {statistics.mean(connect_times):.2f} ms")
        print(f" P99 连接时间: {_p99(connect_times):.2f} ms")

    # Hold the connections open and count messages received meanwhile.
    print(f"\n⏳ 保持连接 {duration} 秒...")
    messages_before = stats["messages_received"]
    await asyncio.sleep(duration)
    messages_during = stats["messages_received"] - messages_before

    # Disconnect every client that is still connected.
    print("\n📴 断开所有连接...")
    await asyncio.gather(
        *(client.disconnect() for client in clients if client.connected),
        return_exceptions=True,
    )

    # Final report.
    total_duration = time.time() - start_time
    print("\n" + "=" * 60)
    print("📊 测试结果")
    print("=" * 60)
    print(f" 总耗时: {total_duration:.2f} 秒")
    print(f" 成功连接: {stats['connected']}")
    print(f" 连接失败: {stats['errors']}")
    if num_connections > 0:  # avoid ZeroDivisionError
        print(f" 连接成功率: {(stats['connected'] / num_connections * 100):.2f}%")
    print(f" 收到消息数: {messages_during}")
    if duration > 0:  # avoid ZeroDivisionError
        print(f" 消息速率: {(messages_during / duration):.2f} msg/s")
    if connect_times:
        print("\n 连接延迟统计:")
        print(f" 最小: {min(connect_times):.2f} ms")
        print(f" 最大: {max(connect_times):.2f} ms")
        print(f" 平均: {statistics.mean(connect_times):.2f} ms")
        print(f" 中位数: {statistics.median(connect_times):.2f} ms")
    print("=" * 60)
def main():
    """Parse command-line options and run the WebSocket stress test.

    The run exits with an argparse usage error (exit status 2) in these
    cases:

    * ``--connections`` is below 1;
    * ``--batch`` is below 1;
    * ``--duration`` is negative.

    Without this check, those values fail later inside the test with
    ``ZeroDivisionError`` or ``ValueError``, or yield nonsensical rates.
    """
    parser = argparse.ArgumentParser(description="WebSocket 压力测试")
    parser.add_argument("--url", default="wss://valuefrontier.cn",
                        help="WebSocket URL (default: wss://valuefrontier.cn)")
    parser.add_argument("--connections", type=int, default=1000,
                        help="并发连接数 (default: 1000)")
    parser.add_argument("--duration", type=int, default=60,
                        help="测试持续时间(秒) (default: 60)")
    parser.add_argument("--batch", type=int, default=100,
                        help="批量创建连接数 (default: 100)")
    args = parser.parse_args()

    # Validate before starting any network activity.
    if args.connections < 1:
        parser.error("--connections 必须 >= 1")
    if args.batch < 1:
        parser.error("--batch 必须 >= 1")
    if args.duration < 0:
        parser.error("--duration 必须 >= 0")

    asyncio.run(run_test(
        url=args.url,
        num_connections=args.connections,
        duration=args.duration,
        batch_size=args.batch,
    ))


if __name__ == "__main__":
    main()