178 lines
5.4 KiB
Python
178 lines
5.4 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
WebSocket 压力测试脚本
|
|
|
|
使用方式:
|
|
pip install "python-socketio[asyncio_client]"
|
|
|
|
# 测试 1000 个 WebSocket 连接
|
|
python websocket_test.py --url wss://valuefrontier.cn --connections 1000
|
|
|
|
# 测试 5000 个连接,持续 5 分钟
|
|
python websocket_test.py --url wss://valuefrontier.cn --connections 5000 --duration 300
|
|
"""
|
|
|
|
import argparse
|
|
import asyncio
|
|
import time
|
|
import statistics
|
|
from datetime import datetime
|
|
import socketio
|
|
|
|
# Shared counters, mutated by each client's event handlers and read by the
# reporting code.
stats = dict(
    connected=0,          # number of successful "connect" events
    disconnected=0,       # number of "disconnect" events
    messages_received=0,  # server events received (any event name)
    errors=0,             # connection attempts that raised an exception
    connect_times=[],     # per-connection latency, in milliseconds
)
|
|
|
|
|
|
async def create_client(client_id, url, namespace="/"):
    """Create one Socket.IO client, connect it, and wire its events into ``stats``.

    Args:
        client_id: Sequence number of this client. Not used by the body; kept
            so callers can pass an identifier.
        url: Server URL passed to ``AsyncClient.connect``.
        namespace: Namespace to connect to (default ``"/"``).

    Returns:
        The connected ``socketio.AsyncClient``, or ``None`` if ``connect``
        raised. A failure is counted in ``stats["errors"]``.
    """
    sio = socketio.AsyncClient(
        reconnection=False,  # a dropped connection must stay dropped during the test
        logger=False,
        engineio_logger=False,
    )

    # perf_counter is monotonic, so latency is not skewed by wall-clock jumps.
    start_time = time.perf_counter()

    @sio.event
    async def connect():
        connect_time = (time.perf_counter() - start_time) * 1000  # milliseconds
        stats["connected"] += 1
        stats["connect_times"].append(connect_time)

    @sio.event
    async def disconnect(*args):
        # Newer python-socketio releases may pass a disconnect reason; ignore it.
        stats["disconnected"] += 1

    @sio.event
    async def message(*args):
        # Accept any payload arity so every "message" packet is counted.
        stats["messages_received"] += 1

    @sio.on("*")
    async def catch_all(event, *args):
        # The catch-all handler receives the event name followed by however many
        # arguments the server emitted. A fixed single ``data`` parameter would
        # raise TypeError (and drop the count) for zero- or multi-argument events.
        stats["messages_received"] += 1

    try:
        await sio.connect(url, namespaces=[namespace])
    except Exception:
        # Best effort: a failed connection is a data point for the stress test,
        # not a fatal error, so it is counted and the caller receives None.
        stats["errors"] += 1
        return None
    return sio
|
|
|
|
|
|
async def run_test(url, num_connections, duration, batch_size=100):
    """Run the WebSocket stress test and print a report to stdout.

    Phases:
      1. Open ``num_connections`` clients in batches of ``batch_size``. Each
         batch connects concurrently, with a 0.1 s pause between batches.
      2. Hold the connections for ``duration`` seconds, counting received events.
      3. Disconnect every client that is still connected.
      4. Print a summary built from the module-level ``stats`` counters.

    Phase 3 runs in a ``finally`` block, so sockets are released even if the
    run is cancelled (e.g. Ctrl+C) during phase 1 or 2.

    Args:
        url: Server URL handed to ``create_client``.
        num_connections: Total number of clients to open.
        duration: Seconds to hold the connections open.
        batch_size: Clients connected concurrently per batch; must be >= 1.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        # range() would raise an opaque error for 0 and silently connect
        # nothing for negative values.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    print("=" * 60)
    print("🚀 WebSocket 压力测试")
    print(f" 目标: {url}")
    print(f" 连接数: {num_connections}")
    print(f" 持续时间: {duration} 秒")
    print(f" 批量大小: {batch_size}")
    print("=" * 60)

    clients = []
    # perf_counter is monotonic and intended for measuring elapsed time.
    start_time = time.perf_counter()
    try:
        print(f"\n📡 开始创建 {num_connections} 个 WebSocket 连接...")
        await _connect_in_batches(clients, url, num_connections, batch_size)
        _print_connect_summary(time.perf_counter() - start_time)

        print(f"\n⏳ 保持连接 {duration} 秒...")
        messages_before = stats["messages_received"]
        await asyncio.sleep(duration)
        messages_during = stats["messages_received"] - messages_before
    finally:
        print("\n📴 断开所有连接...")
        await _disconnect_all(clients)

    _print_final_report(
        time.perf_counter() - start_time, num_connections, duration, messages_during
    )


async def _connect_in_batches(clients, url, num_connections, batch_size):
    """Connect clients batch by batch, appending each successful one to ``clients``."""
    for batch_start in range(0, num_connections, batch_size):
        batch_end = min(batch_start + batch_size, num_connections)
        results = await asyncio.gather(
            *(create_client(j, url) for j in range(batch_start, batch_end)),
            return_exceptions=True,
        )
        # create_client returns None on failure. gather can also hand back any
        # BaseException (e.g. CancelledError), which must not be kept as a client.
        clients.extend(
            r for r in results if r is not None and not isinstance(r, BaseException)
        )

        progress = (batch_end / num_connections) * 100
        print(f" 进度: {batch_end}/{num_connections} ({progress:.1f}%) - "
              f"成功: {stats['connected']}, 失败: {stats['errors']}")

        # Short pause between batches to avoid a connection storm.
        await asyncio.sleep(0.1)


async def _disconnect_all(clients):
    """Disconnect every still-connected client; individual failures are ignored."""
    await asyncio.gather(
        *(client.disconnect() for client in clients if client.connected),
        return_exceptions=True,
    )


def _print_connect_summary(connect_duration):
    """Print the results of the connection phase."""
    times = stats["connect_times"]
    print("\n✅ 连接阶段完成!")
    print(f" 耗时: {connect_duration:.2f} 秒")
    print(f" 成功连接: {stats['connected']}")
    print(f" 连接失败: {stats['errors']}")
    if times:
        print(f" 平均连接时间: {statistics.mean(times):.2f} ms")
    # statistics.quantiles raises StatisticsError for < 2 samples before 3.13.
    if len(times) >= 2:
        print(f" P99 连接时间: {statistics.quantiles(times, n=100)[98]:.2f} ms")


def _print_final_report(total_duration, num_connections, duration, messages_during):
    """Print the final summary; rates fall back to 0.0 for non-positive denominators."""
    success_rate = (
        stats["connected"] / num_connections * 100 if num_connections > 0 else 0.0
    )
    message_rate = messages_during / duration if duration > 0 else 0.0

    print("\n" + "=" * 60)
    print("📊 测试结果")
    print("=" * 60)
    print(f" 总耗时: {total_duration:.2f} 秒")
    print(f" 成功连接: {stats['connected']}")
    print(f" 连接失败: {stats['errors']}")
    print(f" 连接成功率: {success_rate:.2f}%")
    print(f" 收到消息数: {messages_during}")
    print(f" 消息速率: {message_rate:.2f} msg/s")

    times = stats["connect_times"]
    if times:
        print("\n 连接延迟统计:")
        print(f" 最小: {min(times):.2f} ms")
        print(f" 最大: {max(times):.2f} ms")
        print(f" 平均: {statistics.mean(times):.2f} ms")
        print(f" 中位数: {statistics.median(times):.2f} ms")

    print("=" * 60)
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(description="WebSocket 压力测试")
|
|
parser.add_argument("--url", default="wss://valuefrontier.cn",
|
|
help="WebSocket URL (default: wss://valuefrontier.cn)")
|
|
parser.add_argument("--connections", type=int, default=1000,
|
|
help="并发连接数 (default: 1000)")
|
|
parser.add_argument("--duration", type=int, default=60,
|
|
help="测试持续时间(秒) (default: 60)")
|
|
parser.add_argument("--batch", type=int, default=100,
|
|
help="批量创建连接数 (default: 100)")
|
|
|
|
args = parser.parse_args()
|
|
|
|
asyncio.run(run_test(
|
|
url=args.url,
|
|
num_connections=args.connections,
|
|
duration=args.duration,
|
|
batch_size=args.batch
|
|
))
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script,
    # not when this module is imported.
    main()
|