Make sure ASGI ws subprotocols is a list (#2127)
* Ensure protocols is a list for ASGI * Subprotocol updates
This commit is contained in:
parent
c543d19f8a
commit
f39b8b32f7
21
sanic/app.py
21
sanic/app.py
|
@ -585,7 +585,12 @@ class Sanic(BaseSanic):
|
|||
# determine if the parameter supplied by the caller
|
||||
# passes the test in the URL
|
||||
if param_info.pattern:
|
||||
passes_pattern = param_info.pattern.match(supplied_param)
|
||||
pattern = (
|
||||
param_info.pattern[1]
|
||||
if isinstance(param_info.pattern, tuple)
|
||||
else param_info.pattern
|
||||
)
|
||||
passes_pattern = pattern.match(supplied_param)
|
||||
if not passes_pattern:
|
||||
if param_info.cast != str:
|
||||
msg = (
|
||||
|
@ -593,13 +598,13 @@ class Sanic(BaseSanic):
|
|||
f"for parameter `{param_info.name}` does "
|
||||
"not match pattern for type "
|
||||
f"`{param_info.cast.__name__}`: "
|
||||
f"{param_info.pattern.pattern}"
|
||||
f"{pattern.pattern}"
|
||||
)
|
||||
else:
|
||||
msg = (
|
||||
f'Value "{supplied_param}" for parameter '
|
||||
f"`{param_info.name}` does not satisfy "
|
||||
f"pattern {param_info.pattern.pattern}"
|
||||
f"pattern {pattern.pattern}"
|
||||
)
|
||||
raise URLBuildError(msg)
|
||||
|
||||
|
@ -740,17 +745,14 @@ class Sanic(BaseSanic):
|
|||
|
||||
if response:
|
||||
response = await request.respond(response)
|
||||
else:
|
||||
elif not hasattr(handler, "is_websocket"):
|
||||
response = request.stream.response # type: ignore
|
||||
# Make sure that response is finished / run StreamingHTTP callback
|
||||
|
||||
# Make sure that response is finished / run StreamingHTTP callback
|
||||
if isinstance(response, BaseHTTPResponse):
|
||||
await response.send(end_stream=True)
|
||||
else:
|
||||
try:
|
||||
# Fastest method for checking if the property exists
|
||||
handler.is_websocket # type: ignore
|
||||
except AttributeError:
|
||||
if not hasattr(handler, "is_websocket"):
|
||||
raise ServerError(
|
||||
f"Invalid response type {response!r} "
|
||||
"(need HTTPResponse)"
|
||||
|
@ -777,6 +779,7 @@ class Sanic(BaseSanic):
|
|||
|
||||
if self.asgi:
|
||||
ws = request.transport.get_websocket_connection()
|
||||
await ws.accept(subprotocols)
|
||||
else:
|
||||
protocol = request.transport.get_protocol()
|
||||
protocol.app = self
|
||||
|
|
|
@ -140,7 +140,6 @@ class ASGIApp:
|
|||
instance.ws = instance.transport.create_websocket_connection(
|
||||
send, receive
|
||||
)
|
||||
await instance.ws.accept()
|
||||
else:
|
||||
raise ServerError("Received unknown ASGI scope")
|
||||
|
||||
|
|
|
@ -41,7 +41,7 @@ class WebSocketProtocol(HttpProtocol):
|
|||
websocket_write_limit=2 ** 16,
|
||||
websocket_ping_interval=20,
|
||||
websocket_ping_timeout=20,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.websocket = None
|
||||
|
@ -154,7 +154,7 @@ class WebSocketConnection:
|
|||
) -> None:
|
||||
self._send = send
|
||||
self._receive = receive
|
||||
self.subprotocols = subprotocols or []
|
||||
self._subprotocols = subprotocols or []
|
||||
|
||||
async def send(self, data: Union[str, bytes], *args, **kwargs) -> None:
|
||||
message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"}
|
||||
|
@ -178,13 +178,28 @@ class WebSocketConnection:
|
|||
|
||||
receive = recv
|
||||
|
||||
async def accept(self) -> None:
|
||||
async def accept(self, subprotocols: Optional[List[str]] = None) -> None:
|
||||
subprotocol = None
|
||||
if subprotocols:
|
||||
for subp in subprotocols:
|
||||
if subp in self.subprotocols:
|
||||
subprotocol = subp
|
||||
break
|
||||
|
||||
await self._send(
|
||||
{
|
||||
"type": "websocket.accept",
|
||||
"subprotocol": ",".join(list(self.subprotocols)),
|
||||
"subprotocol": subprotocol,
|
||||
}
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def subprotocols(self):
|
||||
return self._subprotocols
|
||||
|
||||
@subprotocols.setter
|
||||
def subprotocols(self, subprotocols: Optional[List[str]] = None):
|
||||
self._subprotocols = subprotocols or []
|
||||
|
|
|
@ -218,7 +218,7 @@ async def test_websocket_accept_with_no_subprotocols(
|
|||
|
||||
message = message_stack.popleft()
|
||||
assert message["type"] == "websocket.accept"
|
||||
assert message["subprotocol"] == ""
|
||||
assert message["subprotocol"] is None
|
||||
assert "bytes" not in message
|
||||
|
||||
|
||||
|
@ -227,7 +227,7 @@ async def test_websocket_accept_with_subprotocol(send, receive, message_stack):
|
|||
subprotocols = ["graphql-ws"]
|
||||
|
||||
ws = WebSocketConnection(send, receive, subprotocols)
|
||||
await ws.accept()
|
||||
await ws.accept(subprotocols)
|
||||
|
||||
assert len(message_stack) == 1
|
||||
|
||||
|
@ -244,13 +244,13 @@ async def test_websocket_accept_with_multiple_subprotocols(
|
|||
subprotocols = ["graphql-ws", "hello", "world"]
|
||||
|
||||
ws = WebSocketConnection(send, receive, subprotocols)
|
||||
await ws.accept()
|
||||
await ws.accept(["hello", "world"])
|
||||
|
||||
assert len(message_stack) == 1
|
||||
|
||||
message = message_stack.popleft()
|
||||
assert message["type"] == "websocket.accept"
|
||||
assert message["subprotocol"] == "graphql-ws,hello,world"
|
||||
assert message["subprotocol"] == "hello"
|
||||
assert "bytes" not in message
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user