Make sure ASGI ws subprotocols is a list (#2127)

* Ensure protocols is a list for ASGI

* Subprotocol updates
This commit is contained in:
Adam Hopkins 2021-06-21 14:39:06 +03:00 committed by GitHub
parent c543d19f8a
commit f39b8b32f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 35 additions and 18 deletions

View File

@ -585,7 +585,12 @@ class Sanic(BaseSanic):
# determine if the parameter supplied by the caller
# passes the test in the URL
if param_info.pattern:
passes_pattern = param_info.pattern.match(supplied_param)
pattern = (
param_info.pattern[1]
if isinstance(param_info.pattern, tuple)
else param_info.pattern
)
passes_pattern = pattern.match(supplied_param)
if not passes_pattern:
if param_info.cast != str:
msg = (
@ -593,13 +598,13 @@ class Sanic(BaseSanic):
f"for parameter `{param_info.name}` does "
"not match pattern for type "
f"`{param_info.cast.__name__}`: "
f"{param_info.pattern.pattern}"
f"{pattern.pattern}"
)
else:
msg = (
f'Value "{supplied_param}" for parameter '
f"`{param_info.name}` does not satisfy "
f"pattern {param_info.pattern.pattern}"
f"pattern {pattern.pattern}"
)
raise URLBuildError(msg)
@ -740,17 +745,14 @@ class Sanic(BaseSanic):
if response:
response = await request.respond(response)
else:
elif not hasattr(handler, "is_websocket"):
response = request.stream.response # type: ignore
# Make sure that response is finished / run StreamingHTTP callback
# Make sure that response is finished / run StreamingHTTP callback
if isinstance(response, BaseHTTPResponse):
await response.send(end_stream=True)
else:
try:
# Fastest method for checking if the property exists
handler.is_websocket # type: ignore
except AttributeError:
if not hasattr(handler, "is_websocket"):
raise ServerError(
f"Invalid response type {response!r} "
"(need HTTPResponse)"
@ -777,6 +779,7 @@ class Sanic(BaseSanic):
if self.asgi:
ws = request.transport.get_websocket_connection()
await ws.accept(subprotocols)
else:
protocol = request.transport.get_protocol()
protocol.app = self

View File

@ -140,7 +140,6 @@ class ASGIApp:
instance.ws = instance.transport.create_websocket_connection(
send, receive
)
await instance.ws.accept()
else:
raise ServerError("Received unknown ASGI scope")

View File

@ -41,7 +41,7 @@ class WebSocketProtocol(HttpProtocol):
websocket_write_limit=2 ** 16,
websocket_ping_interval=20,
websocket_ping_timeout=20,
**kwargs
**kwargs,
):
super().__init__(*args, **kwargs)
self.websocket = None
@ -154,7 +154,7 @@ class WebSocketConnection:
) -> None:
self._send = send
self._receive = receive
self.subprotocols = subprotocols or []
self._subprotocols = subprotocols or []
async def send(self, data: Union[str, bytes], *args, **kwargs) -> None:
message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"}
@ -178,13 +178,28 @@ class WebSocketConnection:
receive = recv
async def accept(self) -> None:
async def accept(self, subprotocols: Optional[List[str]] = None) -> None:
subprotocol = None
if subprotocols:
for subp in subprotocols:
if subp in self.subprotocols:
subprotocol = subp
break
await self._send(
{
"type": "websocket.accept",
"subprotocol": ",".join(list(self.subprotocols)),
"subprotocol": subprotocol,
}
)
async def close(self) -> None:
pass
@property
def subprotocols(self):
return self._subprotocols
@subprotocols.setter
def subprotocols(self, subprotocols: Optional[List[str]] = None):
self._subprotocols = subprotocols or []

View File

@ -218,7 +218,7 @@ async def test_websocket_accept_with_no_subprotocols(
message = message_stack.popleft()
assert message["type"] == "websocket.accept"
assert message["subprotocol"] == ""
assert message["subprotocol"] is None
assert "bytes" not in message
@ -227,7 +227,7 @@ async def test_websocket_accept_with_subprotocol(send, receive, message_stack):
subprotocols = ["graphql-ws"]
ws = WebSocketConnection(send, receive, subprotocols)
await ws.accept()
await ws.accept(subprotocols)
assert len(message_stack) == 1
@ -244,13 +244,13 @@ async def test_websocket_accept_with_multiple_subprotocols(
subprotocols = ["graphql-ws", "hello", "world"]
ws = WebSocketConnection(send, receive, subprotocols)
await ws.accept()
await ws.accept(["hello", "world"])
assert len(message_stack) == 1
message = message_stack.popleft()
assert message["type"] == "websocket.accept"
assert message["subprotocol"] == "graphql-ws,hello,world"
assert message["subprotocol"] == "hello"
assert "bytes" not in message