Websocket subprotocol (#1887)

* Added fix to include subprotocols from scope

* Added unit test to validate fix

* Changes by black

* Made changes to WebsocketConnection protocol

* Linter changes

* Added unit tests

* Fixing bugs in linting due to isort import checks

* Reverting compat import changes

* Fixing linter errors in compat.py

Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
This commit is contained in:
Lee Tat Wai David 2020-07-29 19:09:26 +08:00 committed by GitHub
parent 5ee8ee7b04
commit 5d5ed10a45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 61 additions and 2 deletions

View File

@ -98,7 +98,9 @@ class MockTransport:
def create_websocket_connection( def create_websocket_connection(
self, send: ASGISend, receive: ASGIReceive self, send: ASGISend, receive: ASGIReceive
) -> WebSocketConnection: ) -> WebSocketConnection:
self._websocket_connection = WebSocketConnection(send, receive) self._websocket_connection = WebSocketConnection(
send, receive, self.scope.get("subprotocols", [])
)
return self._websocket_connection return self._websocket_connection
def add_task(self) -> None: def add_task(self) -> None:

View File

@ -3,6 +3,7 @@ from typing import (
Awaitable, Awaitable,
Callable, Callable,
Dict, Dict,
List,
MutableMapping, MutableMapping,
Optional, Optional,
Union, Union,
@ -137,9 +138,11 @@ class WebSocketConnection:
self, self,
send: Callable[[ASIMessage], Awaitable[None]], send: Callable[[ASIMessage], Awaitable[None]],
receive: Callable[[], Awaitable[ASIMessage]], receive: Callable[[], Awaitable[ASIMessage]],
subprotocols: Optional[List[str]] = None,
) -> None: ) -> None:
self._send = send self._send = send
self._receive = receive self._receive = receive
self.subprotocols = subprotocols or []
async def send(self, data: Union[str, bytes], *args, **kwargs) -> None: async def send(self, data: Union[str, bytes], *args, **kwargs) -> None:
message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"} message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"}
@ -164,7 +167,14 @@ class WebSocketConnection:
receive = recv receive = recv
async def accept(self) -> None: async def accept(self) -> None:
await self._send({"type": "websocket.accept", "subprotocol": ""}) await self._send(
{
"type": "websocket.accept",
"subprotocol": ",".join(
[subprotocol for subprotocol in self.subprotocols]
),
}
)
async def close(self) -> None: async def close(self) -> None:
pass pass

View File

@ -208,6 +208,53 @@ async def test_websocket_receive(send, receive, message_stack):
assert text == msg["text"] assert text == msg["text"]
@pytest.mark.asyncio
async def test_websocket_accept_with_no_subprotocols(
send, receive, message_stack
):
ws = WebSocketConnection(send, receive)
await ws.accept()
assert len(message_stack) == 1
message = message_stack.popleft()
assert message["type"] == "websocket.accept"
assert message["subprotocol"] == ""
assert "bytes" not in message
@pytest.mark.asyncio
async def test_websocket_accept_with_subprotocol(send, receive, message_stack):
subprotocols = ["graphql-ws"]
ws = WebSocketConnection(send, receive, subprotocols)
await ws.accept()
assert len(message_stack) == 1
message = message_stack.popleft()
assert message["type"] == "websocket.accept"
assert message["subprotocol"] == "graphql-ws"
assert "bytes" not in message
@pytest.mark.asyncio
async def test_websocket_accept_with_multiple_subprotocols(
send, receive, message_stack
):
subprotocols = ["graphql-ws", "hello", "world"]
ws = WebSocketConnection(send, receive, subprotocols)
await ws.accept()
assert len(message_stack) == 1
message = message_stack.popleft()
assert message["type"] == "websocket.accept"
assert message["subprotocol"] == "graphql-ws,hello,world"
assert "bytes" not in message
def test_improper_websocket_connection(transport, send, receive): def test_improper_websocket_connection(transport, send, receive):
with pytest.raises(InvalidUsage): with pytest.raises(InvalidUsage):
transport.get_websocket_connection() transport.get_websocket_connection()