Make sure ASGI ws subprotocols is a list (#2127)
* Ensure protocols is a list for ASGI * Subprotocol updates
This commit is contained in:
parent
c543d19f8a
commit
f39b8b32f7
21
sanic/app.py
21
sanic/app.py
@ -585,7 +585,12 @@ class Sanic(BaseSanic):
|
|||||||
# determine if the parameter supplied by the caller
|
# determine if the parameter supplied by the caller
|
||||||
# passes the test in the URL
|
# passes the test in the URL
|
||||||
if param_info.pattern:
|
if param_info.pattern:
|
||||||
passes_pattern = param_info.pattern.match(supplied_param)
|
pattern = (
|
||||||
|
param_info.pattern[1]
|
||||||
|
if isinstance(param_info.pattern, tuple)
|
||||||
|
else param_info.pattern
|
||||||
|
)
|
||||||
|
passes_pattern = pattern.match(supplied_param)
|
||||||
if not passes_pattern:
|
if not passes_pattern:
|
||||||
if param_info.cast != str:
|
if param_info.cast != str:
|
||||||
msg = (
|
msg = (
|
||||||
@ -593,13 +598,13 @@ class Sanic(BaseSanic):
|
|||||||
f"for parameter `{param_info.name}` does "
|
f"for parameter `{param_info.name}` does "
|
||||||
"not match pattern for type "
|
"not match pattern for type "
|
||||||
f"`{param_info.cast.__name__}`: "
|
f"`{param_info.cast.__name__}`: "
|
||||||
f"{param_info.pattern.pattern}"
|
f"{pattern.pattern}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
msg = (
|
msg = (
|
||||||
f'Value "{supplied_param}" for parameter '
|
f'Value "{supplied_param}" for parameter '
|
||||||
f"`{param_info.name}` does not satisfy "
|
f"`{param_info.name}` does not satisfy "
|
||||||
f"pattern {param_info.pattern.pattern}"
|
f"pattern {pattern.pattern}"
|
||||||
)
|
)
|
||||||
raise URLBuildError(msg)
|
raise URLBuildError(msg)
|
||||||
|
|
||||||
@ -740,17 +745,14 @@ class Sanic(BaseSanic):
|
|||||||
|
|
||||||
if response:
|
if response:
|
||||||
response = await request.respond(response)
|
response = await request.respond(response)
|
||||||
else:
|
elif not hasattr(handler, "is_websocket"):
|
||||||
response = request.stream.response # type: ignore
|
response = request.stream.response # type: ignore
|
||||||
# Make sure that response is finished / run StreamingHTTP callback
|
|
||||||
|
|
||||||
|
# Make sure that response is finished / run StreamingHTTP callback
|
||||||
if isinstance(response, BaseHTTPResponse):
|
if isinstance(response, BaseHTTPResponse):
|
||||||
await response.send(end_stream=True)
|
await response.send(end_stream=True)
|
||||||
else:
|
else:
|
||||||
try:
|
if not hasattr(handler, "is_websocket"):
|
||||||
# Fastest method for checking if the property exists
|
|
||||||
handler.is_websocket # type: ignore
|
|
||||||
except AttributeError:
|
|
||||||
raise ServerError(
|
raise ServerError(
|
||||||
f"Invalid response type {response!r} "
|
f"Invalid response type {response!r} "
|
||||||
"(need HTTPResponse)"
|
"(need HTTPResponse)"
|
||||||
@ -777,6 +779,7 @@ class Sanic(BaseSanic):
|
|||||||
|
|
||||||
if self.asgi:
|
if self.asgi:
|
||||||
ws = request.transport.get_websocket_connection()
|
ws = request.transport.get_websocket_connection()
|
||||||
|
await ws.accept(subprotocols)
|
||||||
else:
|
else:
|
||||||
protocol = request.transport.get_protocol()
|
protocol = request.transport.get_protocol()
|
||||||
protocol.app = self
|
protocol.app = self
|
||||||
|
@ -140,7 +140,6 @@ class ASGIApp:
|
|||||||
instance.ws = instance.transport.create_websocket_connection(
|
instance.ws = instance.transport.create_websocket_connection(
|
||||||
send, receive
|
send, receive
|
||||||
)
|
)
|
||||||
await instance.ws.accept()
|
|
||||||
else:
|
else:
|
||||||
raise ServerError("Received unknown ASGI scope")
|
raise ServerError("Received unknown ASGI scope")
|
||||||
|
|
||||||
|
@ -41,7 +41,7 @@ class WebSocketProtocol(HttpProtocol):
|
|||||||
websocket_write_limit=2 ** 16,
|
websocket_write_limit=2 ** 16,
|
||||||
websocket_ping_interval=20,
|
websocket_ping_interval=20,
|
||||||
websocket_ping_timeout=20,
|
websocket_ping_timeout=20,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.websocket = None
|
self.websocket = None
|
||||||
@ -154,7 +154,7 @@ class WebSocketConnection:
|
|||||||
) -> None:
|
) -> None:
|
||||||
self._send = send
|
self._send = send
|
||||||
self._receive = receive
|
self._receive = receive
|
||||||
self.subprotocols = subprotocols or []
|
self._subprotocols = subprotocols or []
|
||||||
|
|
||||||
async def send(self, data: Union[str, bytes], *args, **kwargs) -> None:
|
async def send(self, data: Union[str, bytes], *args, **kwargs) -> None:
|
||||||
message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"}
|
message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"}
|
||||||
@ -178,13 +178,28 @@ class WebSocketConnection:
|
|||||||
|
|
||||||
receive = recv
|
receive = recv
|
||||||
|
|
||||||
async def accept(self) -> None:
|
async def accept(self, subprotocols: Optional[List[str]] = None) -> None:
|
||||||
|
subprotocol = None
|
||||||
|
if subprotocols:
|
||||||
|
for subp in subprotocols:
|
||||||
|
if subp in self.subprotocols:
|
||||||
|
subprotocol = subp
|
||||||
|
break
|
||||||
|
|
||||||
await self._send(
|
await self._send(
|
||||||
{
|
{
|
||||||
"type": "websocket.accept",
|
"type": "websocket.accept",
|
||||||
"subprotocol": ",".join(list(self.subprotocols)),
|
"subprotocol": subprotocol,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def subprotocols(self):
|
||||||
|
return self._subprotocols
|
||||||
|
|
||||||
|
@subprotocols.setter
|
||||||
|
def subprotocols(self, subprotocols: Optional[List[str]] = None):
|
||||||
|
self._subprotocols = subprotocols or []
|
||||||
|
@ -218,7 +218,7 @@ async def test_websocket_accept_with_no_subprotocols(
|
|||||||
|
|
||||||
message = message_stack.popleft()
|
message = message_stack.popleft()
|
||||||
assert message["type"] == "websocket.accept"
|
assert message["type"] == "websocket.accept"
|
||||||
assert message["subprotocol"] == ""
|
assert message["subprotocol"] is None
|
||||||
assert "bytes" not in message
|
assert "bytes" not in message
|
||||||
|
|
||||||
|
|
||||||
@ -227,7 +227,7 @@ async def test_websocket_accept_with_subprotocol(send, receive, message_stack):
|
|||||||
subprotocols = ["graphql-ws"]
|
subprotocols = ["graphql-ws"]
|
||||||
|
|
||||||
ws = WebSocketConnection(send, receive, subprotocols)
|
ws = WebSocketConnection(send, receive, subprotocols)
|
||||||
await ws.accept()
|
await ws.accept(subprotocols)
|
||||||
|
|
||||||
assert len(message_stack) == 1
|
assert len(message_stack) == 1
|
||||||
|
|
||||||
@ -244,13 +244,13 @@ async def test_websocket_accept_with_multiple_subprotocols(
|
|||||||
subprotocols = ["graphql-ws", "hello", "world"]
|
subprotocols = ["graphql-ws", "hello", "world"]
|
||||||
|
|
||||||
ws = WebSocketConnection(send, receive, subprotocols)
|
ws = WebSocketConnection(send, receive, subprotocols)
|
||||||
await ws.accept()
|
await ws.accept(["hello", "world"])
|
||||||
|
|
||||||
assert len(message_stack) == 1
|
assert len(message_stack) == 1
|
||||||
|
|
||||||
message = message_stack.popleft()
|
message = message_stack.popleft()
|
||||||
assert message["type"] == "websocket.accept"
|
assert message["type"] == "websocket.accept"
|
||||||
assert message["subprotocol"] == "graphql-ws,hello,world"
|
assert message["subprotocol"] == "hello"
|
||||||
assert "bytes" not in message
|
assert "bytes" not in message
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user