Make sure ASGI ws subprotocols is a list (#2127)

* Ensure protocols is a list for ASGI

* Subprotocol updates
This commit is contained in:
Adam Hopkins 2021-06-21 14:39:06 +03:00 committed by GitHub
parent c543d19f8a
commit f39b8b32f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 35 additions and 18 deletions

View File

@ -585,7 +585,12 @@ class Sanic(BaseSanic):
# determine if the parameter supplied by the caller # determine if the parameter supplied by the caller
# passes the test in the URL # passes the test in the URL
if param_info.pattern: if param_info.pattern:
passes_pattern = param_info.pattern.match(supplied_param) pattern = (
param_info.pattern[1]
if isinstance(param_info.pattern, tuple)
else param_info.pattern
)
passes_pattern = pattern.match(supplied_param)
if not passes_pattern: if not passes_pattern:
if param_info.cast != str: if param_info.cast != str:
msg = ( msg = (
@ -593,13 +598,13 @@ class Sanic(BaseSanic):
f"for parameter `{param_info.name}` does " f"for parameter `{param_info.name}` does "
"not match pattern for type " "not match pattern for type "
f"`{param_info.cast.__name__}`: " f"`{param_info.cast.__name__}`: "
f"{param_info.pattern.pattern}" f"{pattern.pattern}"
) )
else: else:
msg = ( msg = (
f'Value "{supplied_param}" for parameter ' f'Value "{supplied_param}" for parameter '
f"`{param_info.name}` does not satisfy " f"`{param_info.name}` does not satisfy "
f"pattern {param_info.pattern.pattern}" f"pattern {pattern.pattern}"
) )
raise URLBuildError(msg) raise URLBuildError(msg)
@ -740,17 +745,14 @@ class Sanic(BaseSanic):
if response: if response:
response = await request.respond(response) response = await request.respond(response)
else: elif not hasattr(handler, "is_websocket"):
response = request.stream.response # type: ignore response = request.stream.response # type: ignore
# Make sure that response is finished / run StreamingHTTP callback
# Make sure that response is finished / run StreamingHTTP callback
if isinstance(response, BaseHTTPResponse): if isinstance(response, BaseHTTPResponse):
await response.send(end_stream=True) await response.send(end_stream=True)
else: else:
try: if not hasattr(handler, "is_websocket"):
# Fastest method for checking if the property exists
handler.is_websocket # type: ignore
except AttributeError:
raise ServerError( raise ServerError(
f"Invalid response type {response!r} " f"Invalid response type {response!r} "
"(need HTTPResponse)" "(need HTTPResponse)"
@ -777,6 +779,7 @@ class Sanic(BaseSanic):
if self.asgi: if self.asgi:
ws = request.transport.get_websocket_connection() ws = request.transport.get_websocket_connection()
await ws.accept(subprotocols)
else: else:
protocol = request.transport.get_protocol() protocol = request.transport.get_protocol()
protocol.app = self protocol.app = self

View File

@ -140,7 +140,6 @@ class ASGIApp:
instance.ws = instance.transport.create_websocket_connection( instance.ws = instance.transport.create_websocket_connection(
send, receive send, receive
) )
await instance.ws.accept()
else: else:
raise ServerError("Received unknown ASGI scope") raise ServerError("Received unknown ASGI scope")

View File

@ -41,7 +41,7 @@ class WebSocketProtocol(HttpProtocol):
websocket_write_limit=2 ** 16, websocket_write_limit=2 ** 16,
websocket_ping_interval=20, websocket_ping_interval=20,
websocket_ping_timeout=20, websocket_ping_timeout=20,
**kwargs **kwargs,
): ):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.websocket = None self.websocket = None
@ -154,7 +154,7 @@ class WebSocketConnection:
) -> None: ) -> None:
self._send = send self._send = send
self._receive = receive self._receive = receive
self.subprotocols = subprotocols or [] self._subprotocols = subprotocols or []
async def send(self, data: Union[str, bytes], *args, **kwargs) -> None: async def send(self, data: Union[str, bytes], *args, **kwargs) -> None:
message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"} message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"}
@ -178,13 +178,28 @@ class WebSocketConnection:
receive = recv receive = recv
async def accept(self) -> None: async def accept(self, subprotocols: Optional[List[str]] = None) -> None:
subprotocol = None
if subprotocols:
for subp in subprotocols:
if subp in self.subprotocols:
subprotocol = subp
break
await self._send( await self._send(
{ {
"type": "websocket.accept", "type": "websocket.accept",
"subprotocol": ",".join(list(self.subprotocols)), "subprotocol": subprotocol,
} }
) )
async def close(self) -> None: async def close(self) -> None:
pass pass
@property
def subprotocols(self):
return self._subprotocols
@subprotocols.setter
def subprotocols(self, subprotocols: Optional[List[str]] = None):
self._subprotocols = subprotocols or []

View File

@ -218,7 +218,7 @@ async def test_websocket_accept_with_no_subprotocols(
message = message_stack.popleft() message = message_stack.popleft()
assert message["type"] == "websocket.accept" assert message["type"] == "websocket.accept"
assert message["subprotocol"] == "" assert message["subprotocol"] is None
assert "bytes" not in message assert "bytes" not in message
@ -227,7 +227,7 @@ async def test_websocket_accept_with_subprotocol(send, receive, message_stack):
subprotocols = ["graphql-ws"] subprotocols = ["graphql-ws"]
ws = WebSocketConnection(send, receive, subprotocols) ws = WebSocketConnection(send, receive, subprotocols)
await ws.accept() await ws.accept(subprotocols)
assert len(message_stack) == 1 assert len(message_stack) == 1
@ -244,13 +244,13 @@ async def test_websocket_accept_with_multiple_subprotocols(
subprotocols = ["graphql-ws", "hello", "world"] subprotocols = ["graphql-ws", "hello", "world"]
ws = WebSocketConnection(send, receive, subprotocols) ws = WebSocketConnection(send, receive, subprotocols)
await ws.accept() await ws.accept(["hello", "world"])
assert len(message_stack) == 1 assert len(message_stack) == 1
message = message_stack.popleft() message = message_stack.popleft()
assert message["type"] == "websocket.accept" assert message["type"] == "websocket.accept"
assert message["subprotocol"] == "graphql-ws,hello,world" assert message["subprotocol"] == "hello"
assert "bytes" not in message assert "bytes" not in message