From 50cec6b41cae88233d6e86a04c95ab5b1dba43f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Fri, 20 Oct 2023 22:43:31 +0100 Subject: [PATCH] Add websocket.handler.before/after/exception signals. --- sanic/app.py | 20 +++++++++++++ sanic/signals.py | 8 +++++ tests/test_ws_handlers.py | 62 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+) diff --git a/sanic/app.py b/sanic/app.py index 452954ca..12786a18 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -1369,6 +1369,12 @@ class Sanic( protocol = request.transport.get_protocol() ws = await protocol.websocket_handshake(request, subprotocols) + await self.dispatch( + "websocket.handler.before", + inline=True, + context={"request": request, "websocket": ws}, + fail_not_found=False, + ) # schedule the application handler # its future is kept in self.websocket_tasks in case it # needs to be cancelled due to the server being stopped @@ -1377,10 +1383,24 @@ class Sanic( cancelled = False try: await fut + await self.dispatch( + "websocket.handler.after", + inline=True, + context={"request": request, "websocket": ws}, + reverse=True, + fail_not_found=False, + ) except (CancelledError, ConnectionClosed): # type: ignore cancelled = True except Exception as e: self.error_handler.log(request, e) + await self.dispatch( + "websocket.handler.exception", + inline=True, + context={"request": request, "websocket": ws, "exception": e}, + reverse=True, + fail_not_found=False, + ) finally: self.websocket_tasks.remove(fut) if cancelled: diff --git a/sanic/signals.py b/sanic/signals.py index fe252c12..b5069068 100644 --- a/sanic/signals.py +++ b/sanic/signals.py @@ -38,6 +38,9 @@ class Event(Enum): HTTP_LIFECYCLE_SEND = "http.lifecycle.send" HTTP_MIDDLEWARE_AFTER = "http.middleware.after" HTTP_MIDDLEWARE_BEFORE = "http.middleware.before" + WEBSOCKET_HANDLER_AFTER = "websocket.handler.after" + WEBSOCKET_HANDLER_BEFORE = "websocket.handler.before" + WEBSOCKET_HANDLER_EXCEPTION = "websocket.handler.exception" RESERVED_NAMESPACES = { @@ -65,6 +68,11 @@ RESERVED_NAMESPACES = { Event.HTTP_MIDDLEWARE_AFTER.value, Event.HTTP_MIDDLEWARE_BEFORE.value, ), + "websocket": { + Event.WEBSOCKET_HANDLER_AFTER.value, + Event.WEBSOCKET_HANDLER_BEFORE.value, + Event.WEBSOCKET_HANDLER_EXCEPTION.value, + }, } diff --git a/tests/test_ws_handlers.py b/tests/test_ws_handlers.py index 5205090d..c70dbeed 100644 --- a/tests/test_ws_handlers.py +++ b/tests/test_ws_handlers.py @@ -54,3 +54,65 @@ def test_ws_handler_async_for( ) assert ws_proxy.client_sent == ["test 1", "test 2", ""] assert ws_proxy.client_received == ["test 1", "test 2"] + + +def signalapp(app): + @app.signal("websocket.handler.before") + async def ws_before(request: Request, websocket: Websocket): + app.ctx.seq.append("before") + print("before") + await websocket.send("before: " + await websocket.recv()) + print("before2") + + @app.signal("websocket.handler.after") + async def ws_after(request: Request, websocket: Websocket): + app.ctx.seq.append("after") + await websocket.send("after: " + await websocket.recv()) + await websocket.recv() + + @app.signal("websocket.handler.exception") + async def ws_exception( + request: Request, websocket: Websocket, exception: Exception + ): + app.ctx.seq.append("exception") + await websocket.send(f"exception: {exception}") + await websocket.recv() + + @app.websocket("/ws") + async def ws_handler(request: Request, ws: Websocket): + app.ctx.seq.append("ws") + + @app.websocket("/wserror") + async def ws_error(request: Request, ws: Websocket): + print("wserr") + app.ctx.seq.append("wserror") + raise Exception(await ws.recv()) + print("wserr2") + + +def test_ws_signals( + app: Sanic, + simple_ws_mimic_client: MimicClientType, +): + signalapp(app) + + app.ctx.seq = [] + _, ws_proxy = app.test_client.websocket( + "/ws", mimic=simple_ws_mimic_client + ) + assert ws_proxy.client_received == ["before: test 1", "after: test 2"] + assert app.ctx.seq == ["before", "ws", "after"] + + +def test_ws_signals_exception( + app: Sanic, + simple_ws_mimic_client: MimicClientType, +): + signalapp(app) + + app.ctx.seq = [] + _, ws_proxy = app.test_client.websocket( + "/wserror", mimic=simple_ws_mimic_client + ) + assert ws_proxy.client_received == ["before: test 1", "exception: test 2"] + assert app.ctx.seq == ["before", "wserror", "exception"]