fix(websocket): ASGI websocket must pass thru bytes as is (#2651)
This commit is contained in:
parent
c7a71cd00c
commit
5e7f6998bd
|
@ -45,7 +45,7 @@ class WebSocketConnection:
|
||||||
|
|
||||||
await self._send(message)
|
await self._send(message)
|
||||||
|
|
||||||
async def recv(self, *args, **kwargs) -> Optional[str]:
|
async def recv(self, *args, **kwargs) -> Optional[Union[str, bytes]]:
|
||||||
message = await self._receive()
|
message = await self._receive()
|
||||||
|
|
||||||
if message["type"] == "websocket.receive":
|
if message["type"] == "websocket.receive":
|
||||||
|
@ -53,7 +53,7 @@ class WebSocketConnection:
|
||||||
return message["text"]
|
return message["text"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
try:
|
try:
|
||||||
return message["bytes"].decode()
|
return message["bytes"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise InvalidUsage("Bad ASGI message received")
|
raise InvalidUsage("Bad ASGI message received")
|
||||||
elif message["type"] == "websocket.disconnect":
|
elif message["type"] == "websocket.disconnect":
|
||||||
|
|
|
@ -342,7 +342,7 @@ async def test_websocket_send(send, receive, message_stack):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_websocket_receive(send, receive, message_stack):
|
async def test_websocket_text_receive(send, receive, message_stack):
|
||||||
msg = {"text": "hello", "type": "websocket.receive"}
|
msg = {"text": "hello", "type": "websocket.receive"}
|
||||||
message_stack.append(msg)
|
message_stack.append(msg)
|
||||||
|
|
||||||
|
@ -351,6 +351,15 @@ async def test_websocket_receive(send, receive, message_stack):
|
||||||
|
|
||||||
assert text == msg["text"]
|
assert text == msg["text"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_websocket_bytes_receive(send, receive, message_stack):
|
||||||
|
msg = {"bytes": b"hello", "type": "websocket.receive"}
|
||||||
|
message_stack.append(msg)
|
||||||
|
|
||||||
|
ws = WebSocketConnection(send, receive)
|
||||||
|
data = await ws.receive()
|
||||||
|
|
||||||
|
assert data == msg["bytes"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_websocket_accept_with_no_subprotocols(
|
async def test_websocket_accept_with_no_subprotocols(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user