Compare commits

..

1 Commits

Author SHA1 Message Date
L. Kärkkäinen
50cec6b41c Add websocket.handler.before/after/exception signals. 2023-10-20 22:43:31 +01:00
6 changed files with 104 additions and 11 deletions

View File

@@ -1369,6 +1369,12 @@ class Sanic(
protocol = request.transport.get_protocol()
ws = await protocol.websocket_handshake(request, subprotocols)
await self.dispatch(
"websocket.handler.before",
inline=True,
context={"request": request, "websocket": ws},
fail_not_found=False,
)
# schedule the application handler
# its future is kept in self.websocket_tasks in case it
# needs to be cancelled due to the server being stopped
@@ -1377,10 +1383,24 @@ class Sanic(
cancelled = False
try:
await fut
await self.dispatch(
"websocket.handler.after",
inline=True,
context={"request": request, "websocket": ws},
reverse=True,
fail_not_found=False,
)
except (CancelledError, ConnectionClosed): # type: ignore
cancelled = True
except Exception as e:
self.error_handler.log(request, e)
await self.dispatch(
"websocket.handler.exception",
inline=True,
context={"request": request, "websocket": ws, "exception": e},
reverse=True,
fail_not_found=False,
)
finally:
self.websocket_tasks.remove(fut)
if cancelled:

View File

@@ -149,11 +149,13 @@ class CookieRequestParameters(RequestParameters):
except KeyError:
return super().get(name, default)
def getlist(self, name: str) -> list[Any]:
def getlist(
self, name: str, default: Optional[Any] = None
) -> Optional[Any]:
try:
return self._get_prefixed_cookie(name)
except KeyError:
return super().getlist(name)
return super().getlist(name, default)
def _get_prefixed_cookie(self, name: str) -> Any:
getitem = super().__getitem__

View File

@@ -19,14 +19,15 @@ class RequestParameters(dict):
return super().get(name, [default])[0]
def getlist(
self, name: str
) -> list[Any]:
self, name: str, default: Optional[Any] = None
) -> Optional[Any]:
"""Return the entire list
Args:
name (str): The name of the parameter
default (Optional[Any], optional): The default value. Defaults to None.
Returns:
list[Any]: The entire list of values or [] if not found
Optional[Any]: The entire list
""" # noqa: E501
return super().get(name) or []
return super().get(name, default)

View File

@@ -38,6 +38,9 @@ class Event(Enum):
HTTP_LIFECYCLE_SEND = "http.lifecycle.send"
HTTP_MIDDLEWARE_AFTER = "http.middleware.after"
HTTP_MIDDLEWARE_BEFORE = "http.middleware.before"
WEBSOCKET_HANDLER_AFTER = "websocket.handler.after"
WEBSOCKET_HANDLER_BEFORE = "websocket.handler.before"
WEBSOCKET_HANDLER_EXCEPTION = "websocket.handler.exception"
RESERVED_NAMESPACES = {
@@ -65,6 +68,11 @@ RESERVED_NAMESPACES = {
Event.HTTP_MIDDLEWARE_AFTER.value,
Event.HTTP_MIDDLEWARE_BEFORE.value,
),
"websocket": {
Event.WEBSOCKET_HANDLER_AFTER.value,
Event.WEBSOCKET_HANDLER_BEFORE.value,
Event.WEBSOCKET_HANDLER_EXCEPTION.value,
},
}

View File

@@ -445,10 +445,10 @@ def test_cookie_accessors(app: Sanic):
"four": request.cookies.get("four", "fallback"),
},
"getlist": {
"one": request.cookies.getlist("one"),
"two": request.cookies.getlist("two"),
"three": request.cookies.getlist("three"),
"four": request.cookies.getlist("four"),
"one": request.cookies.getlist("one", ["fallback"]),
"two": request.cookies.getlist("two", ["fallback"]),
"three": request.cookies.getlist("three", ["fallback"]),
"four": request.cookies.getlist("four", ["fallback"]),
},
"getattr": {
"one": request.cookies.one,
@@ -484,7 +484,7 @@ def test_cookie_accessors(app: Sanic):
"one": ["1"],
"two": ["2"],
"three": ["3"],
"four": [],
"four": ["fallback"],
},
"getattr": {
"one": "1",

View File

@@ -54,3 +54,65 @@ def test_ws_handler_async_for(
)
assert ws_proxy.client_sent == ["test 1", "test 2", ""]
assert ws_proxy.client_received == ["test 1", "test 2"]
def signalapp(app):
@app.signal("websocket.handler.before")
async def ws_before(request: Request, websocket: Websocket):
app.ctx.seq.append("before")
print("before")
await websocket.send("before: " + await websocket.recv())
print("before2")
@app.signal("websocket.handler.after")
async def ws_after(request: Request, websocket: Websocket):
app.ctx.seq.append("after")
await websocket.send("after: " + await websocket.recv())
await websocket.recv()
@app.signal("websocket.handler.exception")
async def ws_exception(
request: Request, websocket: Websocket, exception: Exception
):
app.ctx.seq.append("exception")
await websocket.send(f"exception: {exception}")
await websocket.recv()
@app.websocket("/ws")
async def ws_handler(request: Request, ws: Websocket):
app.ctx.seq.append("ws")
@app.websocket("/wserror")
async def ws_error(request: Request, ws: Websocket):
print("wserr")
app.ctx.seq.append("wserror")
raise Exception(await ws.recv())
print("wserr2")
def test_ws_signals(
app: Sanic,
simple_ws_mimic_client: MimicClientType,
):
signalapp(app)
app.ctx.seq = []
_, ws_proxy = app.test_client.websocket(
"/ws", mimic=simple_ws_mimic_client
)
assert ws_proxy.client_received == ["before: test 1", "after: test 2"]
assert app.ctx.seq == ["before", "ws", "after"]
def test_ws_signals_exception(
app: Sanic,
simple_ws_mimic_client: MimicClientType,
):
signalapp(app)
app.ctx.seq = []
_, ws_proxy = app.test_client.websocket(
"/wserror", mimic=simple_ws_mimic_client
)
assert ws_proxy.client_received == ["before: test 1", "exception: test 2"]
assert app.ctx.seq == ["before", "wserror", "exception"]