Websocket subprotocol (#1887)
* Added fix to include subprotocols from scope * Added unit test to validate fix * Changes by black * Made changes to WebsocketConnection protocol * Linter changes * Added unit tests * Fixing bugs in linting due to isort import checks * Reverting compat import changes * Fixing linter errors in compat.py Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
This commit is contained in:
		 Lee Tat Wai David
					Lee Tat Wai David
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							5ee8ee7b04
						
					
				
				
					commit
					5d5ed10a45
				
			| @@ -98,7 +98,9 @@ class MockTransport: | ||||
|     def create_websocket_connection( | ||||
|         self, send: ASGISend, receive: ASGIReceive | ||||
|     ) -> WebSocketConnection: | ||||
|         self._websocket_connection = WebSocketConnection(send, receive) | ||||
|         self._websocket_connection = WebSocketConnection( | ||||
|             send, receive, self.scope.get("subprotocols", []) | ||||
|         ) | ||||
|         return self._websocket_connection | ||||
|  | ||||
|     def add_task(self) -> None: | ||||
|   | ||||
| @@ -3,6 +3,7 @@ from typing import ( | ||||
|     Awaitable, | ||||
|     Callable, | ||||
|     Dict, | ||||
|     List, | ||||
|     MutableMapping, | ||||
|     Optional, | ||||
|     Union, | ||||
| @@ -137,9 +138,11 @@ class WebSocketConnection: | ||||
|         self, | ||||
|         send: Callable[[ASIMessage], Awaitable[None]], | ||||
|         receive: Callable[[], Awaitable[ASIMessage]], | ||||
|         subprotocols: Optional[List[str]] = None, | ||||
|     ) -> None: | ||||
|         self._send = send | ||||
|         self._receive = receive | ||||
|         self.subprotocols = subprotocols or [] | ||||
|  | ||||
|     async def send(self, data: Union[str, bytes], *args, **kwargs) -> None: | ||||
|         message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"} | ||||
| @@ -164,7 +167,14 @@ class WebSocketConnection: | ||||
|     receive = recv | ||||
|  | ||||
|     async def accept(self) -> None: | ||||
|         await self._send({"type": "websocket.accept", "subprotocol": ""}) | ||||
|         await self._send( | ||||
|             { | ||||
|                 "type": "websocket.accept", | ||||
|                 "subprotocol": ",".join( | ||||
|                     [subprotocol for subprotocol in self.subprotocols] | ||||
|                 ), | ||||
|             } | ||||
|         ) | ||||
|  | ||||
|     async def close(self) -> None: | ||||
|         pass | ||||
|   | ||||
| @@ -208,6 +208,53 @@ async def test_websocket_receive(send, receive, message_stack): | ||||
|     assert text == msg["text"] | ||||
|  | ||||
|  | ||||
| @pytest.mark.asyncio | ||||
| async def test_websocket_accept_with_no_subprotocols( | ||||
|     send, receive, message_stack | ||||
| ): | ||||
|     ws = WebSocketConnection(send, receive) | ||||
|     await ws.accept() | ||||
|  | ||||
|     assert len(message_stack) == 1 | ||||
|  | ||||
|     message = message_stack.popleft() | ||||
|     assert message["type"] == "websocket.accept" | ||||
|     assert message["subprotocol"] == "" | ||||
|     assert "bytes" not in message | ||||
|  | ||||
|  | ||||
| @pytest.mark.asyncio | ||||
| async def test_websocket_accept_with_subprotocol(send, receive, message_stack): | ||||
|     subprotocols = ["graphql-ws"] | ||||
|  | ||||
|     ws = WebSocketConnection(send, receive, subprotocols) | ||||
|     await ws.accept() | ||||
|  | ||||
|     assert len(message_stack) == 1 | ||||
|  | ||||
|     message = message_stack.popleft() | ||||
|     assert message["type"] == "websocket.accept" | ||||
|     assert message["subprotocol"] == "graphql-ws" | ||||
|     assert "bytes" not in message | ||||
|  | ||||
|  | ||||
| @pytest.mark.asyncio | ||||
| async def test_websocket_accept_with_multiple_subprotocols( | ||||
|     send, receive, message_stack | ||||
| ): | ||||
|     subprotocols = ["graphql-ws", "hello", "world"] | ||||
|  | ||||
|     ws = WebSocketConnection(send, receive, subprotocols) | ||||
|     await ws.accept() | ||||
|  | ||||
|     assert len(message_stack) == 1 | ||||
|  | ||||
|     message = message_stack.popleft() | ||||
|     assert message["type"] == "websocket.accept" | ||||
|     assert message["subprotocol"] == "graphql-ws,hello,world" | ||||
|     assert "bytes" not in message | ||||
|  | ||||
|  | ||||
| def test_improper_websocket_connection(transport, send, receive): | ||||
|     with pytest.raises(InvalidUsage): | ||||
|         transport.get_websocket_connection() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user