diff --git a/sanic/asgi.py b/sanic/asgi.py index 2ae6f369..5ec13cf4 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -98,7 +98,9 @@ class MockTransport: def create_websocket_connection( self, send: ASGISend, receive: ASGIReceive ) -> WebSocketConnection: - self._websocket_connection = WebSocketConnection(send, receive) + self._websocket_connection = WebSocketConnection( + send, receive, self.scope.get("subprotocols", []) + ) return self._websocket_connection def add_task(self) -> None: diff --git a/sanic/websocket.py b/sanic/websocket.py index 4ae83c85..9443b704 100644 --- a/sanic/websocket.py +++ b/sanic/websocket.py @@ -3,6 +3,7 @@ from typing import ( Awaitable, Callable, Dict, + List, MutableMapping, Optional, Union, @@ -137,9 +138,11 @@ class WebSocketConnection: self, send: Callable[[ASIMessage], Awaitable[None]], receive: Callable[[], Awaitable[ASIMessage]], + subprotocols: Optional[List[str]] = None, ) -> None: self._send = send self._receive = receive + self.subprotocols = subprotocols or [] async def send(self, data: Union[str, bytes], *args, **kwargs) -> None: message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"} @@ -164,7 +167,14 @@ class WebSocketConnection: receive = recv async def accept(self) -> None: - await self._send({"type": "websocket.accept", "subprotocol": ""}) + await self._send( + { + "type": "websocket.accept", + "subprotocol": ",".join( + [subprotocol for subprotocol in self.subprotocols] + ), + } + ) async def close(self) -> None: pass diff --git a/tests/test_asgi.py b/tests/test_asgi.py index 05b2e96d..0c728493 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -208,6 +208,53 @@ async def test_websocket_receive(send, receive, message_stack): assert text == msg["text"] +@pytest.mark.asyncio +async def test_websocket_accept_with_no_subprotocols( + send, receive, message_stack +): + ws = WebSocketConnection(send, receive) + await ws.accept() + + assert len(message_stack) == 1 + + message = message_stack.popleft() + assert message["type"] == "websocket.accept" + assert message["subprotocol"] == "" + assert "bytes" not in message + + +@pytest.mark.asyncio +async def test_websocket_accept_with_subprotocol(send, receive, message_stack): + subprotocols = ["graphql-ws"] + + ws = WebSocketConnection(send, receive, subprotocols) + await ws.accept() + + assert len(message_stack) == 1 + + message = message_stack.popleft() + assert message["type"] == "websocket.accept" + assert message["subprotocol"] == "graphql-ws" + assert "bytes" not in message + + +@pytest.mark.asyncio +async def test_websocket_accept_with_multiple_subprotocols( + send, receive, message_stack +): + subprotocols = ["graphql-ws", "hello", "world"] + + ws = WebSocketConnection(send, receive, subprotocols) + await ws.accept() + + assert len(message_stack) == 1 + + message = message_stack.popleft() + assert message["type"] == "websocket.accept" + assert message["subprotocol"] == "graphql-ws,hello,world" + assert "bytes" not in message + + def test_improper_websocket_connection(transport, send, receive): with pytest.raises(InvalidUsage): transport.get_websocket_connection()