From 16503319e5240b6fc1774e8607ffde7469a89774 Mon Sep 17 00:00:00 2001 From: Zhiwei Date: Tue, 20 Sep 2022 16:20:32 -0500 Subject: [PATCH] Make WebsocketImplProtocol async iterable (#2490) --- sanic/server/websockets/impl.py | 13 +++++++- tests/test_ws_handlers.py | 56 +++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 tests/test_ws_handlers.py diff --git a/sanic/server/websockets/impl.py b/sanic/server/websockets/impl.py index 2bc3fc5e..6104c1f0 100644 --- a/sanic/server/websockets/impl.py +++ b/sanic/server/websockets/impl.py @@ -13,7 +13,11 @@ from typing import ( ) from websockets.connection import CLOSED, CLOSING, OPEN, Event -from websockets.exceptions import ConnectionClosed, ConnectionClosedError +from websockets.exceptions import ( + ConnectionClosed, + ConnectionClosedError, + ConnectionClosedOK, +) from websockets.frames import Frame, Opcode from websockets.server import ServerConnection from websockets.typing import Data @@ -840,3 +844,10 @@ class WebsocketImplProtocol: self.abort_pings() if self.connection_lost_waiter: self.connection_lost_waiter.set_result(None) + + async def __aiter__(self): + try: + while True: + yield await self.recv() + except ConnectionClosedOK: + return diff --git a/tests/test_ws_handlers.py b/tests/test_ws_handlers.py new file mode 100644 index 00000000..5205090d --- /dev/null +++ b/tests/test_ws_handlers.py @@ -0,0 +1,56 @@ +from typing import Any, Callable, Coroutine + +import pytest + +from websockets.client import WebSocketClientProtocol + +from sanic import Request, Sanic, Websocket + + +MimicClientType = Callable[ + [WebSocketClientProtocol], Coroutine[None, None, Any] +] + + +@pytest.fixture +def simple_ws_mimic_client(): + async def client_mimic(ws: WebSocketClientProtocol): + await ws.send("test 1") + await ws.recv() + await ws.send("test 2") + await ws.recv() + + return client_mimic + + +def test_ws_handler( + app: Sanic, + simple_ws_mimic_client: MimicClientType, +): + @app.websocket("/ws") + async def ws_echo_handler(request: Request, ws: Websocket): + while True: + msg = await ws.recv() + await ws.send(msg) + + _, ws_proxy = app.test_client.websocket( + "/ws", mimic=simple_ws_mimic_client + ) + assert ws_proxy.client_sent == ["test 1", "test 2", ""] + assert ws_proxy.client_received == ["test 1", "test 2"] + + +def test_ws_handler_async_for( + app: Sanic, + simple_ws_mimic_client: MimicClientType, +): + @app.websocket("/ws") + async def ws_echo_handler(request: Request, ws: Websocket): + async for msg in ws: + await ws.send(msg) + + _, ws_proxy = app.test_client.websocket( + "/ws", mimic=simple_ws_mimic_client + ) + assert ws_proxy.client_sent == ["test 1", "test 2", ""] + assert ws_proxy.client_received == ["test 1", "test 2"]