# -*- coding: utf-8 -*-
"""
WebSocket stress-test script.

Opens many concurrent Socket.IO client connections against a server, holds
them open for a fixed period while counting received events, then
disconnects everything and prints connection/latency statistics.

Usage:
    # socketio.AsyncClient needs the asyncio client extra (installs aiohttp);
    # the plain ``[client]`` extra only provides the synchronous transports.
    pip install "python-socketio[asyncio_client]"

    # Test 1000 WebSocket connections
    python websocket_test.py --url wss://valuefrontier.cn --connections 1000

    # Test 5000 connections, held for 5 minutes
    python websocket_test.py --url wss://valuefrontier.cn --connections 5000 --duration 300
"""
import argparse
import asyncio
import contextlib
import statistics
import time
from collections import Counter
from datetime import datetime

import socketio

# Global counters updated by every client's event handlers. All handlers are
# coroutines on the single event loop started by asyncio.run(), and none of
# them awaits between reading and writing a counter.
stats = {
    "connected": 0,              # connect events received
    "disconnected": 0,           # disconnect events received
    "messages_received": 0,      # "message" events plus any other server event
    "errors": 0,                 # clients whose sio.connect() raised
    "connect_times": [],         # per-client connect latency, milliseconds
    "error_reasons": Counter(),  # exception type name -> number of failures
}


async def create_client(client_id, url, namespace="/"):
    """Create one Socket.IO client and connect it to ``url``.

    Args:
        client_id: Index of this client within the test. Informational only;
            it is not sent to the server.
        url: Server URL passed to ``AsyncClient.connect``.
        namespace: Socket.IO namespace to join. Handlers are registered on
            this same namespace so events are counted for non-default
            namespaces too.

    Returns:
        The connected ``socketio.AsyncClient``, or ``None`` if the connection
        attempt failed. Failures are recorded in ``stats``.
    """
    sio = socketio.AsyncClient(
        reconnection=False,  # a dropped connection should stay dropped, not be retried
        logger=False,
        engineio_logger=False,
    )
    start_time = time.time()

    async def on_connect():
        stats["connected"] += 1
        stats["connect_times"].append((time.time() - start_time) * 1000)

    async def on_disconnect(*args):
        # *args: newer python-socketio versions may pass a disconnect reason.
        stats["disconnected"] += 1

    async def on_message(*args):
        stats["messages_received"] += 1

    async def on_any_event(event, *args):
        # Catch-all for events without a dedicated handler. *args accepts
        # events that carry zero or several payload arguments.
        stats["messages_received"] += 1

    sio.on("connect", on_connect, namespace=namespace)
    sio.on("disconnect", on_disconnect, namespace=namespace)
    sio.on("message", on_message, namespace=namespace)
    sio.on("*", on_any_event, namespace=namespace)

    try:
        # NOTE(review): recent python-socketio versions only wait
        # ``wait_timeout`` (reportedly 1 s by default) for the namespace
        # handshake; under heavy load this may show up as connect failures.
        # Confirm against the installed version before tuning.
        await sio.connect(url, namespaces=[namespace])
    except Exception as exc:
        stats["errors"] += 1
        stats["error_reasons"][type(exc).__name__] += 1
        # Best-effort cleanup of any half-open transport. The client is
        # already counted as failed, so errors here are deliberately ignored.
        with contextlib.suppress(Exception):
            await sio.disconnect()
        return None
    return sio


def _p99(samples):
    """Return the 99th percentile of a non-empty list of samples.

    ``statistics.quantiles`` raises ``StatisticsError`` for fewer than two
    data points on older Pythons, so a single sample is returned as-is.
    """
    if len(samples) < 2:
        return samples[0]
    return statistics.quantiles(samples, n=100)[98]


async def _connect_in_batches(url, num_connections, batch_size, clients):
    """Open ``num_connections`` clients, ``batch_size`` at a time.

    Successfully connected clients are appended to ``clients`` as soon as each
    batch finishes. The caller can therefore still disconnect them if this
    coroutine is interrupted part-way through.
    """
    for batch_start in range(0, num_connections, batch_size):
        batch_end = min(batch_start + batch_size, num_connections)
        results = await asyncio.gather(
            *(create_client(j, url) for j in range(batch_start, batch_end)),
            return_exceptions=True,
        )
        clients.extend(r for r in results if isinstance(r, socketio.AsyncClient))

        progress = (batch_end / num_connections) * 100
        print(f" 进度: {batch_end}/{num_connections} ({progress:.1f}%) - "
              f"成功: {stats['connected']}, 失败: {stats['errors']}")

        # Short pause between batches to avoid a connection storm
        await asyncio.sleep(0.1)


async def _disconnect_all(clients):
    """Disconnect every still-connected client, ignoring per-client errors."""
    await asyncio.gather(
        *(client.disconnect() for client in clients if client.connected),
        return_exceptions=True,
    )


def _print_summary(num_connections, duration, total_duration,
                   messages_during, still_connected):
    """Print the final report from ``stats`` and the hold-phase measurements."""
    print("\n" + "=" * 60)
    print("📊 测试结果")
    print("=" * 60)
    print(f" 总耗时: {total_duration:.2f} 秒")
    print(f" 成功连接: {stats['connected']}")
    print(f" 连接失败: {stats['errors']}")
    print(f" 连接成功率: {(stats['connected'] / num_connections * 100):.2f}%")
    print(f" 保持连接数: {still_connected}")
    print(f" 收到消息数: {messages_during}")
    # Guard against --duration 0, which would otherwise divide by zero
    rate = messages_during / duration if duration > 0 else 0.0
    print(f" 消息速率: {rate:.2f} msg/s")

    connect_times = stats["connect_times"]
    if connect_times:
        print("\n 连接延迟统计:")
        print(f" 最小: {min(connect_times):.2f} ms")
        print(f" 最大: {max(connect_times):.2f} ms")
        print(f" 平均: {statistics.mean(connect_times):.2f} ms")
        print(f" 中位数: {statistics.median(connect_times):.2f} ms")

    if stats["error_reasons"]:
        print("\n 失败原因:")
        for reason, count in stats["error_reasons"].most_common():
            print(f" {reason}: {count}")

    print("=" * 60)


async def run_test(url, num_connections, duration, batch_size=100):
    """Run the WebSocket stress test and print a report.

    Args:
        url: Server URL each client connects to.
        num_connections: Total number of clients to open (must be >= 1).
        duration: Seconds to hold the connections open (must be >= 0).
        batch_size: Clients opened concurrently per batch (must be >= 1).

    Raises:
        ValueError: If any numeric argument is out of range.

    Clients opened before an error or cancellation (e.g. Ctrl+C) are still
    disconnected, because cleanup runs in a ``finally`` block.
    """
    if num_connections < 1:
        raise ValueError("num_connections must be >= 1")
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    if duration < 0:
        raise ValueError("duration must be >= 0")

    print("=" * 60)
    print("🚀 WebSocket 压力测试")
    print(f" 目标: {url}")
    print(f" 连接数: {num_connections}")
    print(f" 持续时间: {duration} 秒")
    print(f" 批量大小: {batch_size}")
    print("=" * 60)

    clients = []
    start_time = time.time()
    try:
        # Phase 1: open connections in batches
        print(f"\n📡 开始创建 {num_connections} 个 WebSocket 连接...")
        await _connect_in_batches(url, num_connections, batch_size, clients)

        connect_duration = time.time() - start_time
        print("\n✅ 连接阶段完成!")
        print(f" 耗时: {connect_duration:.2f} 秒")
        print(f" 成功连接: {stats['connected']}")
        print(f" 连接失败: {stats['errors']}")
        connect_times = stats["connect_times"]
        if connect_times:
            print(f" 平均连接时间: {statistics.mean(connect_times):.2f} ms")
            print(f" P99 连接时间: {_p99(connect_times):.2f} ms")

        # Phase 2: hold the connections open and count incoming events
        print(f"\n⏳ 保持连接 {duration} 秒...")
        messages_before = stats["messages_received"]
        await asyncio.sleep(duration)
        messages_during = stats["messages_received"] - messages_before
        # Clients whose connection dropped during the hold are no longer connected
        still_connected = sum(1 for client in clients if client.connected)
    finally:
        # Phase 3: always release connections, even on error/cancellation
        print("\n📴 断开所有连接...")
        await _disconnect_all(clients)

    _print_summary(num_connections, duration, time.time() - start_time,
                   messages_during, still_connected)


def main():
    """Parse command-line arguments and run the stress test."""
    parser = argparse.ArgumentParser(description="WebSocket 压力测试")
    parser.add_argument("--url", default="wss://valuefrontier.cn",
                        help="WebSocket URL (default: wss://valuefrontier.cn)")
    parser.add_argument("--connections", type=int, default=1000,
                        help="并发连接数 (default: 1000)")
    parser.add_argument("--duration", type=int, default=60,
                        help="测试持续时间(秒) (default: 60)")
    parser.add_argument("--batch", type=int, default=100,
                        help="批量创建连接数 (default: 100)")
    args = parser.parse_args()

    # Reject values that would otherwise crash later (division by zero, range step 0)
    if args.connections < 1:
        parser.error("--connections 必须 >= 1")
    if args.batch < 1:
        parser.error("--batch 必须 >= 1")
    if args.duration < 0:
        parser.error("--duration 必须 >= 0")

    try:
        asyncio.run(run_test(
            url=args.url,
            num_connections=args.connections,
            duration=args.duration,
            batch_size=args.batch,
        ))
    except KeyboardInterrupt:
        # asyncio.run cancels the test task on interrupt, which gives
        # run_test's finally block a chance to disconnect open clients.
        print("\n⚠️ 测试已中断")


if __name__ == "__main__":
    main()