From f39b8b32f70445e80c63a236bf227c1694e3a05b Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Mon, 21 Jun 2021 14:39:06 +0300 Subject: [PATCH] Make sure ASGI ws subprotocols is a list (#2127) * Ensure protocols is a list for ASGI * Subprotocol updates --- sanic/app.py | 21 ++++++++++++--------- sanic/asgi.py | 1 - sanic/websocket.py | 23 +++++++++++++++++++---- tests/test_asgi.py | 8 ++++---- 4 files changed, 35 insertions(+), 18 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index 1d08f02e..002c92dc 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -585,7 +585,12 @@ class Sanic(BaseSanic): # determine if the parameter supplied by the caller # passes the test in the URL if param_info.pattern: - passes_pattern = param_info.pattern.match(supplied_param) + pattern = ( + param_info.pattern[1] + if isinstance(param_info.pattern, tuple) + else param_info.pattern + ) + passes_pattern = pattern.match(supplied_param) if not passes_pattern: if param_info.cast != str: msg = ( @@ -593,13 +598,13 @@ class Sanic(BaseSanic): f"for parameter `{param_info.name}` does " "not match pattern for type " f"`{param_info.cast.__name__}`: " - f"{param_info.pattern.pattern}" + f"{pattern.pattern}" ) else: msg = ( f'Value "{supplied_param}" for parameter ' f"`{param_info.name}` does not satisfy " - f"pattern {param_info.pattern.pattern}" + f"pattern {pattern.pattern}" ) raise URLBuildError(msg) @@ -740,17 +745,14 @@ class Sanic(BaseSanic): if response: response = await request.respond(response) - else: + elif not hasattr(handler, "is_websocket"): response = request.stream.response # type: ignore - # Make sure that response is finished / run StreamingHTTP callback + # Make sure that response is finished / run StreamingHTTP callback if isinstance(response, BaseHTTPResponse): await response.send(end_stream=True) else: - try: - # Fastest method for checking if the property exists - handler.is_websocket # type: ignore - except AttributeError: + if not hasattr(handler, "is_websocket"): raise ServerError( f"Invalid response type {response!r} " "(need HTTPResponse)" @@ -777,6 +779,7 @@ class Sanic(BaseSanic): if self.asgi: ws = request.transport.get_websocket_connection() + await ws.accept(subprotocols) else: protocol = request.transport.get_protocol() protocol.app = self diff --git a/sanic/asgi.py b/sanic/asgi.py index be598ec1..5765a5cd 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -140,7 +140,6 @@ class ASGIApp: instance.ws = instance.transport.create_websocket_connection( send, receive ) - await instance.ws.accept() else: raise ServerError("Received unknown ASGI scope") diff --git a/sanic/websocket.py b/sanic/websocket.py index c4bdd580..b5600ed7 100644 --- a/sanic/websocket.py +++ b/sanic/websocket.py @@ -41,7 +41,7 @@ class WebSocketProtocol(HttpProtocol): websocket_write_limit=2 ** 16, websocket_ping_interval=20, websocket_ping_timeout=20, - **kwargs + **kwargs, ): super().__init__(*args, **kwargs) self.websocket = None @@ -154,7 +154,7 @@ class WebSocketConnection: ) -> None: self._send = send self._receive = receive - self.subprotocols = subprotocols or [] + self._subprotocols = subprotocols or [] async def send(self, data: Union[str, bytes], *args, **kwargs) -> None: message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"} @@ -178,13 +178,28 @@ class WebSocketConnection: receive = recv - async def accept(self) -> None: + async def accept(self, subprotocols: Optional[List[str]] = None) -> None: + subprotocol = None + if subprotocols: + for subp in subprotocols: + if subp in self.subprotocols: + subprotocol = subp + break + await self._send( { "type": "websocket.accept", - "subprotocol": ",".join(list(self.subprotocols)), + "subprotocol": subprotocol, } ) async def close(self) -> None: pass + + @property + def subprotocols(self): + return self._subprotocols + + @subprotocols.setter + def subprotocols(self, subprotocols: Optional[List[str]] = None): + self._subprotocols = subprotocols or [] diff --git a/tests/test_asgi.py b/tests/test_asgi.py index a4632537..5be3fd26 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -218,7 +218,7 @@ async def test_websocket_accept_with_no_subprotocols( message = message_stack.popleft() assert message["type"] == "websocket.accept" - assert message["subprotocol"] == "" + assert message["subprotocol"] is None assert "bytes" not in message @@ -227,7 +227,7 @@ async def test_websocket_accept_with_subprotocol(send, receive, message_stack): subprotocols = ["graphql-ws"] ws = WebSocketConnection(send, receive, subprotocols) - await ws.accept() + await ws.accept(subprotocols) assert len(message_stack) == 1 @@ -244,13 +244,13 @@ async def test_websocket_accept_with_multiple_subprotocols( subprotocols = ["graphql-ws", "hello", "world"] ws = WebSocketConnection(send, receive, subprotocols) - await ws.accept() + await ws.accept(["hello", "world"]) assert len(message_stack) == 1 message = message_stack.popleft() assert message["type"] == "websocket.accept" - assert message["subprotocol"] == "graphql-ws,hello,world" + assert message["subprotocol"] == "hello" assert "bytes" not in message