Compare commits
1 Commits
tlspath
...
websocket-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
50cec6b41c |
20
sanic/app.py
20
sanic/app.py
@@ -1369,6 +1369,12 @@ class Sanic(
|
|||||||
protocol = request.transport.get_protocol()
|
protocol = request.transport.get_protocol()
|
||||||
ws = await protocol.websocket_handshake(request, subprotocols)
|
ws = await protocol.websocket_handshake(request, subprotocols)
|
||||||
|
|
||||||
|
await self.dispatch(
|
||||||
|
"websocket.handler.before",
|
||||||
|
inline=True,
|
||||||
|
context={"request": request, "websocket": ws},
|
||||||
|
fail_not_found=False,
|
||||||
|
)
|
||||||
# schedule the application handler
|
# schedule the application handler
|
||||||
# its future is kept in self.websocket_tasks in case it
|
# its future is kept in self.websocket_tasks in case it
|
||||||
# needs to be cancelled due to the server being stopped
|
# needs to be cancelled due to the server being stopped
|
||||||
@@ -1377,10 +1383,24 @@ class Sanic(
|
|||||||
cancelled = False
|
cancelled = False
|
||||||
try:
|
try:
|
||||||
await fut
|
await fut
|
||||||
|
await self.dispatch(
|
||||||
|
"websocket.handler.after",
|
||||||
|
inline=True,
|
||||||
|
context={"request": request, "websocket": ws},
|
||||||
|
reverse=True,
|
||||||
|
fail_not_found=False,
|
||||||
|
)
|
||||||
except (CancelledError, ConnectionClosed): # type: ignore
|
except (CancelledError, ConnectionClosed): # type: ignore
|
||||||
cancelled = True
|
cancelled = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.error_handler.log(request, e)
|
self.error_handler.log(request, e)
|
||||||
|
await self.dispatch(
|
||||||
|
"websocket.handler.exception",
|
||||||
|
inline=True,
|
||||||
|
context={"request": request, "websocket": ws, "exception": e},
|
||||||
|
reverse=True,
|
||||||
|
fail_not_found=False,
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
self.websocket_tasks.remove(fut)
|
self.websocket_tasks.remove(fut)
|
||||||
if cancelled:
|
if cancelled:
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
|||||||
import os
|
import os
|
||||||
import ssl
|
import ssl
|
||||||
|
|
||||||
from pathlib import Path, PurePath
|
|
||||||
from typing import Any, Dict, Iterable, Optional, Union
|
from typing import Any, Dict, Iterable, Optional, Union
|
||||||
|
|
||||||
from sanic.log import logger
|
from sanic.log import logger
|
||||||
@@ -40,23 +39,23 @@ def create_context(
|
|||||||
|
|
||||||
|
|
||||||
def shorthand_to_ctx(
|
def shorthand_to_ctx(
|
||||||
ctxdef: Union[None, ssl.SSLContext, dict, PurePath, str]
|
ctxdef: Union[None, ssl.SSLContext, dict, str]
|
||||||
) -> Optional[ssl.SSLContext]:
|
) -> Optional[ssl.SSLContext]:
|
||||||
"""Convert an ssl argument shorthand to an SSLContext object."""
|
"""Convert an ssl argument shorthand to an SSLContext object."""
|
||||||
if ctxdef is None or isinstance(ctxdef, ssl.SSLContext):
|
if ctxdef is None or isinstance(ctxdef, ssl.SSLContext):
|
||||||
return ctxdef
|
return ctxdef
|
||||||
if isinstance(ctxdef, (PurePath, str)):
|
if isinstance(ctxdef, str):
|
||||||
return load_cert_dir(Path(ctxdef))
|
return load_cert_dir(ctxdef)
|
||||||
if isinstance(ctxdef, dict):
|
if isinstance(ctxdef, dict):
|
||||||
return CertSimple(**ctxdef)
|
return CertSimple(**ctxdef)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid ssl argument {type(ctxdef)}."
|
f"Invalid ssl argument {type(ctxdef)}."
|
||||||
" Expecting one/list of: certdir | dict | SSLContext"
|
" Expecting a list of certdirs, a dict or an SSLContext."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def process_to_context(
|
def process_to_context(
|
||||||
ssldef: Union[None, ssl.SSLContext, dict, PurePath, str, list, tuple]
|
ssldef: Union[None, ssl.SSLContext, dict, str, list, tuple]
|
||||||
) -> Optional[ssl.SSLContext]:
|
) -> Optional[ssl.SSLContext]:
|
||||||
"""Process app.run ssl argument from easy formats to full SSLContext."""
|
"""Process app.run ssl argument from easy formats to full SSLContext."""
|
||||||
return (
|
return (
|
||||||
@@ -66,11 +65,11 @@ def process_to_context(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_cert_dir(p: Path) -> ssl.SSLContext:
|
def load_cert_dir(p: str) -> ssl.SSLContext:
|
||||||
if p.is_file():
|
if os.path.isfile(p):
|
||||||
raise ValueError(f"Certificate folder expected but {p} is a file.")
|
raise ValueError(f"Certificate folder expected but {p} is a file.")
|
||||||
keyfile = p / "privkey.pem"
|
keyfile = os.path.join(p, "privkey.pem")
|
||||||
certfile = p / "fullchain.pem"
|
certfile = os.path.join(p, "fullchain.pem")
|
||||||
if not os.access(keyfile, os.R_OK):
|
if not os.access(keyfile, os.R_OK):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Certificate not found or permission denied {keyfile}"
|
f"Certificate not found or permission denied {keyfile}"
|
||||||
|
|||||||
@@ -38,6 +38,9 @@ class Event(Enum):
|
|||||||
HTTP_LIFECYCLE_SEND = "http.lifecycle.send"
|
HTTP_LIFECYCLE_SEND = "http.lifecycle.send"
|
||||||
HTTP_MIDDLEWARE_AFTER = "http.middleware.after"
|
HTTP_MIDDLEWARE_AFTER = "http.middleware.after"
|
||||||
HTTP_MIDDLEWARE_BEFORE = "http.middleware.before"
|
HTTP_MIDDLEWARE_BEFORE = "http.middleware.before"
|
||||||
|
WEBSOCKET_HANDLER_AFTER = "websocket.handler.after"
|
||||||
|
WEBSOCKET_HANDLER_BEFORE = "websocket.handler.before"
|
||||||
|
WEBSOCKET_HANDLER_EXCEPTION = "websocket.handler.exception"
|
||||||
|
|
||||||
|
|
||||||
RESERVED_NAMESPACES = {
|
RESERVED_NAMESPACES = {
|
||||||
@@ -65,6 +68,11 @@ RESERVED_NAMESPACES = {
|
|||||||
Event.HTTP_MIDDLEWARE_AFTER.value,
|
Event.HTTP_MIDDLEWARE_AFTER.value,
|
||||||
Event.HTTP_MIDDLEWARE_BEFORE.value,
|
Event.HTTP_MIDDLEWARE_BEFORE.value,
|
||||||
),
|
),
|
||||||
|
"websocket": {
|
||||||
|
Event.WEBSOCKET_HANDLER_AFTER.value,
|
||||||
|
Event.WEBSOCKET_HANDLER_BEFORE.value,
|
||||||
|
Event.WEBSOCKET_HANDLER_EXCEPTION.value,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -54,3 +54,65 @@ def test_ws_handler_async_for(
|
|||||||
)
|
)
|
||||||
assert ws_proxy.client_sent == ["test 1", "test 2", ""]
|
assert ws_proxy.client_sent == ["test 1", "test 2", ""]
|
||||||
assert ws_proxy.client_received == ["test 1", "test 2"]
|
assert ws_proxy.client_received == ["test 1", "test 2"]
|
||||||
|
|
||||||
|
|
||||||
|
def signalapp(app):
|
||||||
|
@app.signal("websocket.handler.before")
|
||||||
|
async def ws_before(request: Request, websocket: Websocket):
|
||||||
|
app.ctx.seq.append("before")
|
||||||
|
print("before")
|
||||||
|
await websocket.send("before: " + await websocket.recv())
|
||||||
|
print("before2")
|
||||||
|
|
||||||
|
@app.signal("websocket.handler.after")
|
||||||
|
async def ws_after(request: Request, websocket: Websocket):
|
||||||
|
app.ctx.seq.append("after")
|
||||||
|
await websocket.send("after: " + await websocket.recv())
|
||||||
|
await websocket.recv()
|
||||||
|
|
||||||
|
@app.signal("websocket.handler.exception")
|
||||||
|
async def ws_exception(
|
||||||
|
request: Request, websocket: Websocket, exception: Exception
|
||||||
|
):
|
||||||
|
app.ctx.seq.append("exception")
|
||||||
|
await websocket.send(f"exception: {exception}")
|
||||||
|
await websocket.recv()
|
||||||
|
|
||||||
|
@app.websocket("/ws")
|
||||||
|
async def ws_handler(request: Request, ws: Websocket):
|
||||||
|
app.ctx.seq.append("ws")
|
||||||
|
|
||||||
|
@app.websocket("/wserror")
|
||||||
|
async def ws_error(request: Request, ws: Websocket):
|
||||||
|
print("wserr")
|
||||||
|
app.ctx.seq.append("wserror")
|
||||||
|
raise Exception(await ws.recv())
|
||||||
|
print("wserr2")
|
||||||
|
|
||||||
|
|
||||||
|
def test_ws_signals(
|
||||||
|
app: Sanic,
|
||||||
|
simple_ws_mimic_client: MimicClientType,
|
||||||
|
):
|
||||||
|
signalapp(app)
|
||||||
|
|
||||||
|
app.ctx.seq = []
|
||||||
|
_, ws_proxy = app.test_client.websocket(
|
||||||
|
"/ws", mimic=simple_ws_mimic_client
|
||||||
|
)
|
||||||
|
assert ws_proxy.client_received == ["before: test 1", "after: test 2"]
|
||||||
|
assert app.ctx.seq == ["before", "ws", "after"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_ws_signals_exception(
|
||||||
|
app: Sanic,
|
||||||
|
simple_ws_mimic_client: MimicClientType,
|
||||||
|
):
|
||||||
|
signalapp(app)
|
||||||
|
|
||||||
|
app.ctx.seq = []
|
||||||
|
_, ws_proxy = app.test_client.websocket(
|
||||||
|
"/wserror", mimic=simple_ws_mimic_client
|
||||||
|
)
|
||||||
|
assert ws_proxy.client_received == ["before: test 1", "exception: test 2"]
|
||||||
|
assert app.ctx.seq == ["before", "wserror", "exception"]
|
||||||
|
|||||||
Reference in New Issue
Block a user