Compare commits

..

1 Commits

Author SHA1 Message Date
L. Kärkkäinen
50cec6b41c Add websocket.handler.before/after/exception signals. 2023-10-20 22:43:31 +01:00
4 changed files with 99 additions and 10 deletions

View File

@@ -1369,6 +1369,12 @@ class Sanic(
protocol = request.transport.get_protocol() protocol = request.transport.get_protocol()
ws = await protocol.websocket_handshake(request, subprotocols) ws = await protocol.websocket_handshake(request, subprotocols)
await self.dispatch(
"websocket.handler.before",
inline=True,
context={"request": request, "websocket": ws},
fail_not_found=False,
)
# schedule the application handler # schedule the application handler
# its future is kept in self.websocket_tasks in case it # its future is kept in self.websocket_tasks in case it
# needs to be cancelled due to the server being stopped # needs to be cancelled due to the server being stopped
@@ -1377,10 +1383,24 @@ class Sanic(
cancelled = False cancelled = False
try: try:
await fut await fut
await self.dispatch(
"websocket.handler.after",
inline=True,
context={"request": request, "websocket": ws},
reverse=True,
fail_not_found=False,
)
except (CancelledError, ConnectionClosed): # type: ignore except (CancelledError, ConnectionClosed): # type: ignore
cancelled = True cancelled = True
except Exception as e: except Exception as e:
self.error_handler.log(request, e) self.error_handler.log(request, e)
await self.dispatch(
"websocket.handler.exception",
inline=True,
context={"request": request, "websocket": ws, "exception": e},
reverse=True,
fail_not_found=False,
)
finally: finally:
self.websocket_tasks.remove(fut) self.websocket_tasks.remove(fut)
if cancelled: if cancelled:

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import os import os
import ssl import ssl
from pathlib import Path, PurePath
from typing import Any, Dict, Iterable, Optional, Union from typing import Any, Dict, Iterable, Optional, Union
from sanic.log import logger from sanic.log import logger
@@ -40,23 +39,23 @@ def create_context(
def shorthand_to_ctx( def shorthand_to_ctx(
ctxdef: Union[None, ssl.SSLContext, dict, PurePath, str] ctxdef: Union[None, ssl.SSLContext, dict, str]
) -> Optional[ssl.SSLContext]: ) -> Optional[ssl.SSLContext]:
"""Convert an ssl argument shorthand to an SSLContext object.""" """Convert an ssl argument shorthand to an SSLContext object."""
if ctxdef is None or isinstance(ctxdef, ssl.SSLContext): if ctxdef is None or isinstance(ctxdef, ssl.SSLContext):
return ctxdef return ctxdef
if isinstance(ctxdef, (PurePath, str)): if isinstance(ctxdef, str):
return load_cert_dir(Path(ctxdef)) return load_cert_dir(ctxdef)
if isinstance(ctxdef, dict): if isinstance(ctxdef, dict):
return CertSimple(**ctxdef) return CertSimple(**ctxdef)
raise ValueError( raise ValueError(
f"Invalid ssl argument {type(ctxdef)}." f"Invalid ssl argument {type(ctxdef)}."
" Expecting one/list of: certdir | dict | SSLContext" " Expecting a list of certdirs, a dict or an SSLContext."
) )
def process_to_context( def process_to_context(
ssldef: Union[None, ssl.SSLContext, dict, PurePath, str, list, tuple] ssldef: Union[None, ssl.SSLContext, dict, str, list, tuple]
) -> Optional[ssl.SSLContext]: ) -> Optional[ssl.SSLContext]:
"""Process app.run ssl argument from easy formats to full SSLContext.""" """Process app.run ssl argument from easy formats to full SSLContext."""
return ( return (
@@ -66,11 +65,11 @@ def process_to_context(
) )
def load_cert_dir(p: Path) -> ssl.SSLContext: def load_cert_dir(p: str) -> ssl.SSLContext:
if p.is_file(): if os.path.isfile(p):
raise ValueError(f"Certificate folder expected but {p} is a file.") raise ValueError(f"Certificate folder expected but {p} is a file.")
keyfile = p / "privkey.pem" keyfile = os.path.join(p, "privkey.pem")
certfile = p / "fullchain.pem" certfile = os.path.join(p, "fullchain.pem")
if not os.access(keyfile, os.R_OK): if not os.access(keyfile, os.R_OK):
raise ValueError( raise ValueError(
f"Certificate not found or permission denied {keyfile}" f"Certificate not found or permission denied {keyfile}"

View File

@@ -38,6 +38,9 @@ class Event(Enum):
HTTP_LIFECYCLE_SEND = "http.lifecycle.send" HTTP_LIFECYCLE_SEND = "http.lifecycle.send"
HTTP_MIDDLEWARE_AFTER = "http.middleware.after" HTTP_MIDDLEWARE_AFTER = "http.middleware.after"
HTTP_MIDDLEWARE_BEFORE = "http.middleware.before" HTTP_MIDDLEWARE_BEFORE = "http.middleware.before"
WEBSOCKET_HANDLER_AFTER = "websocket.handler.after"
WEBSOCKET_HANDLER_BEFORE = "websocket.handler.before"
WEBSOCKET_HANDLER_EXCEPTION = "websocket.handler.exception"
RESERVED_NAMESPACES = { RESERVED_NAMESPACES = {
@@ -65,6 +68,11 @@ RESERVED_NAMESPACES = {
Event.HTTP_MIDDLEWARE_AFTER.value, Event.HTTP_MIDDLEWARE_AFTER.value,
Event.HTTP_MIDDLEWARE_BEFORE.value, Event.HTTP_MIDDLEWARE_BEFORE.value,
), ),
"websocket": {
Event.WEBSOCKET_HANDLER_AFTER.value,
Event.WEBSOCKET_HANDLER_BEFORE.value,
Event.WEBSOCKET_HANDLER_EXCEPTION.value,
},
} }

View File

@@ -54,3 +54,65 @@ def test_ws_handler_async_for(
) )
assert ws_proxy.client_sent == ["test 1", "test 2", ""] assert ws_proxy.client_sent == ["test 1", "test 2", ""]
assert ws_proxy.client_received == ["test 1", "test 2"] assert ws_proxy.client_received == ["test 1", "test 2"]
def signalapp(app):
@app.signal("websocket.handler.before")
async def ws_before(request: Request, websocket: Websocket):
app.ctx.seq.append("before")
print("before")
await websocket.send("before: " + await websocket.recv())
print("before2")
@app.signal("websocket.handler.after")
async def ws_after(request: Request, websocket: Websocket):
app.ctx.seq.append("after")
await websocket.send("after: " + await websocket.recv())
await websocket.recv()
@app.signal("websocket.handler.exception")
async def ws_exception(
request: Request, websocket: Websocket, exception: Exception
):
app.ctx.seq.append("exception")
await websocket.send(f"exception: {exception}")
await websocket.recv()
@app.websocket("/ws")
async def ws_handler(request: Request, ws: Websocket):
app.ctx.seq.append("ws")
@app.websocket("/wserror")
async def ws_error(request: Request, ws: Websocket):
print("wserr")
app.ctx.seq.append("wserror")
raise Exception(await ws.recv())
print("wserr2")
def test_ws_signals(
app: Sanic,
simple_ws_mimic_client: MimicClientType,
):
signalapp(app)
app.ctx.seq = []
_, ws_proxy = app.test_client.websocket(
"/ws", mimic=simple_ws_mimic_client
)
assert ws_proxy.client_received == ["before: test 1", "after: test 2"]
assert app.ctx.seq == ["before", "ws", "after"]
def test_ws_signals_exception(
app: Sanic,
simple_ws_mimic_client: MimicClientType,
):
signalapp(app)
app.ctx.seq = []
_, ws_proxy = app.test_client.websocket(
"/wserror", mimic=simple_ws_mimic_client
)
assert ws_proxy.client_received == ["before: test 1", "exception: test 2"]
assert app.ctx.seq == ["before", "wserror", "exception"]