fix(websocket): ASGI websocket must pass thru bytes as is (#2651)

This commit is contained in:
Ryu Juheon 2023-02-05 23:41:54 +09:00 committed by GitHub
parent c7a71cd00c
commit 5e7f6998bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 3 deletions

View File

@ -45,7 +45,7 @@ class WebSocketConnection:
await self._send(message) await self._send(message)
async def recv(self, *args, **kwargs) -> Optional[str]: async def recv(self, *args, **kwargs) -> Optional[Union[str, bytes]]:
message = await self._receive() message = await self._receive()
if message["type"] == "websocket.receive": if message["type"] == "websocket.receive":
@ -53,7 +53,7 @@ class WebSocketConnection:
return message["text"] return message["text"]
except KeyError: except KeyError:
try: try:
return message["bytes"].decode() return message["bytes"]
except KeyError: except KeyError:
raise InvalidUsage("Bad ASGI message received") raise InvalidUsage("Bad ASGI message received")
elif message["type"] == "websocket.disconnect": elif message["type"] == "websocket.disconnect":

View File

@ -342,7 +342,7 @@ async def test_websocket_send(send, receive, message_stack):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_websocket_receive(send, receive, message_stack): async def test_websocket_text_receive(send, receive, message_stack):
msg = {"text": "hello", "type": "websocket.receive"} msg = {"text": "hello", "type": "websocket.receive"}
message_stack.append(msg) message_stack.append(msg)
@ -351,6 +351,15 @@ async def test_websocket_receive(send, receive, message_stack):
assert text == msg["text"] assert text == msg["text"]
@pytest.mark.asyncio
async def test_websocket_bytes_receive(send, receive, message_stack):
msg = {"bytes": b"hello", "type": "websocket.receive"}
message_stack.append(msg)
ws = WebSocketConnection(send, receive)
data = await ws.receive()
assert data == msg["bytes"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_websocket_accept_with_no_subprotocols( async def test_websocket_accept_with_no_subprotocols(