Make WebsocketImplProtocol async iterable (#2490)

This commit is contained in:
Zhiwei
2022-09-20 16:20:32 -05:00
committed by GitHub
parent 389363ab71
commit 16503319e5
2 changed files with 68 additions and 1 deletions

View File

@@ -13,7 +13,11 @@ from typing import (
)
from websockets.connection import CLOSED, CLOSING, OPEN, Event
from websockets.exceptions import ConnectionClosed, ConnectionClosedError
from websockets.exceptions import (
ConnectionClosed,
ConnectionClosedError,
ConnectionClosedOK,
)
from websockets.frames import Frame, Opcode
from websockets.server import ServerConnection
from websockets.typing import Data
@@ -840,3 +844,10 @@ class WebsocketImplProtocol:
self.abort_pings()
if self.connection_lost_waiter:
self.connection_lost_waiter.set_result(None)
async def __aiter__(self):
try:
while True:
yield await self.recv()
except ConnectionClosedOK:
return

56
tests/test_ws_handlers.py Normal file
View File

@@ -0,0 +1,56 @@
from typing import Any, Callable, Coroutine
import pytest
from websockets.client import WebSocketClientProtocol
from sanic import Request, Sanic, Websocket
MimicClientType = Callable[
[WebSocketClientProtocol], Coroutine[None, None, Any]
]
@pytest.fixture
def simple_ws_mimic_client():
async def client_mimic(ws: WebSocketClientProtocol):
await ws.send("test 1")
await ws.recv()
await ws.send("test 2")
await ws.recv()
return client_mimic
def test_ws_handler(
app: Sanic,
simple_ws_mimic_client: MimicClientType,
):
@app.websocket("/ws")
async def ws_echo_handler(request: Request, ws: Websocket):
while True:
msg = await ws.recv()
await ws.send(msg)
_, ws_proxy = app.test_client.websocket(
"/ws", mimic=simple_ws_mimic_client
)
assert ws_proxy.client_sent == ["test 1", "test 2", ""]
assert ws_proxy.client_received == ["test 1", "test 2"]
def test_ws_handler_async_for(
app: Sanic,
simple_ws_mimic_client: MimicClientType,
):
@app.websocket("/ws")
async def ws_echo_handler(request: Request, ws: Websocket):
async for msg in ws:
await ws.send(msg)
_, ws_proxy = app.test_client.websocket(
"/ws", mimic=simple_ws_mimic_client
)
assert ws_proxy.client_sent == ["test 1", "test 2", ""]
assert ws_proxy.client_received == ["test 1", "test 2"]