diff --git a/.coveragerc b/.coveragerc index 724b2872..60831593 100644 --- a/.coveragerc +++ b/.coveragerc @@ -5,3 +5,11 @@ omit = site-packages, sanic/utils.py, sanic/__main__.py [html] directory = coverage + +[report] +exclude_lines = + no cov + no qa + noqa + NOQA + pragma: no cover diff --git a/sanic/asgi.py b/sanic/asgi.py index 7d820350..21c2a483 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -88,7 +88,7 @@ class MockTransport: self._websocket_connection = WebSocketConnection(send, receive) return self._websocket_connection - def add_task(self) -> None: + def add_task(self) -> None: # noqa raise NotImplementedError async def send(self, data) -> None: @@ -119,15 +119,15 @@ class Lifespan: "the ASGI server is stopped." ) - async def pre_startup(self) -> None: - for handler in self.asgi_app.sanic_app.listeners[ - "before_server_start" - ]: - response = handler( - self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop - ) - if isawaitable(response): - await response + # async def pre_startup(self) -> None: + # for handler in self.asgi_app.sanic_app.listeners[ + # "before_server_start" + # ]: + # response = handler( + # self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop + # ) + # if isawaitable(response): + # await response async def startup(self) -> None: for handler in self.asgi_app.sanic_app.listeners[ @@ -233,7 +233,14 @@ class ASGIApp: ) if sanic_app.is_request_stream: - instance.request.stream = StreamBuffer() + is_stream_handler = sanic_app.router.is_stream_handler( + instance.request + ) + if is_stream_handler: + instance.request.stream = StreamBuffer( + sanic_app.config.REQUEST_BUFFER_QUEUE_SIZE + ) + instance.do_stream = True return instance diff --git a/sanic/testing.py b/sanic/testing.py index 0755fb9e..06d75fc1 100644 --- a/sanic/testing.py +++ b/sanic/testing.py @@ -136,7 +136,7 @@ class SanicTestClient: try: request, response = results return request, response - except BaseException: + except BaseException: # noqa raise ValueError( "Request and response object expected, got ({})".format( results @@ -145,7 +145,7 @@ class SanicTestClient: else: try: return results[-1] - except BaseException: + except BaseException: # noqa raise ValueError( "Request object expected, got ({})".format(results) ) @@ -175,7 +175,7 @@ class SanicTestClient: return self._sanic_endpoint_test("websocket", *args, **kwargs) -class SanicASGIAdapter(requests.asgi.ASGIAdapter): +class SanicASGIAdapter(requests.asgi.ASGIAdapter): # noqa async def send( # type: ignore self, request: requests.PreparedRequest, @@ -218,19 +218,43 @@ class SanicASGIAdapter(requests.asgi.ASGIAdapter): for key, value in request.headers.items() ] - scope = { - "type": "http", - "http_version": "1.1", - "method": request.method, - "path": unquote(path), - "root_path": "", - "scheme": scheme, - "query_string": query.encode(), - "headers": headers, - "client": ["testclient", 50000], - "server": [host, port], - "extensions": {"http.response.template": {}}, - } + no_response = False + if scheme in {"ws", "wss"}: + subprotocol = request.headers.get("sec-websocket-protocol", None) + if subprotocol is None: + subprotocols = [] # type: typing.Sequence[str] + else: + subprotocols = [ + value.strip() for value in subprotocol.split(",") + ] + + scope = { + "type": "websocket", + "path": unquote(path), + "root_path": "", + "scheme": scheme, + "query_string": query.encode(), + "headers": headers, + "client": ["testclient", 50000], + "server": [host, port], + "subprotocols": subprotocols, + } + no_response = True + + else: + scope = { + "type": "http", + "http_version": "1.1", + "method": request.method, + "path": unquote(path), + "root_path": "", + "scheme": scheme, + "query_string": query.encode(), + "headers": headers, + "client": ["testclient", 50000], + "server": [host, port], + "extensions": {"http.response.template": {}}, + } async def receive(): nonlocal request_complete, response_complete @@ -306,6 +330,10 @@ class SanicASGIAdapter(requests.asgi.ASGIAdapter): if not self.suppress_exceptions: raise exc from None + if no_response: + response_started = True + raw_kwargs = {"status_code": 204, "headers": []} + if not self.suppress_exceptions: assert response_started, "TestClient did not receive any response." elif not response_started: @@ -349,13 +377,15 @@ class SanicASGITestClient(requests.ASGISession): ) self.mount("http://", adapter) self.mount("https://", adapter) + self.mount("ws://", adapter) + self.mount("wss://", adapter) self.headers.update({"user-agent": "testclient"}) self.app = app self.base_url = base_url async def request(self, method, url, gather_request=True, *args, **kwargs): + self.gather_request = gather_request - print(url) response = await super().request(method, url, *args, **kwargs) response.status = response.status_code response.body = response.content @@ -372,3 +402,22 @@ class SanicASGITestClient(requests.ASGISession): settings = super().merge_environment_settings(*args, **kwargs) settings.update({"gather_return": self.gather_request}) return settings + + async def websocket(self, uri, subprotocols=None, *args, **kwargs): + if uri.startswith(("ws:", "wss:")): + url = uri + else: + uri = uri if uri.startswith("/") else "/{uri}".format(uri=uri) + url = "ws://testserver{uri}".format(uri=uri) + + headers = kwargs.get("headers", {}) + headers.setdefault("connection", "upgrade") + headers.setdefault("sec-websocket-key", "testserver==") + headers.setdefault("sec-websocket-version", "13") + if subprotocols is not None: + headers.setdefault( + "sec-websocket-protocol", ", ".join(subprotocols) + ) + kwargs["headers"] = headers + + return await self.request("websocket", url, **kwargs) diff --git a/sanic/websocket.py b/sanic/websocket.py index ff321284..f87188e4 100644 --- a/sanic/websocket.py +++ b/sanic/websocket.py @@ -143,9 +143,8 @@ class WebSocketConnection: return message["text"] elif message["type"] == "websocket.disconnect": pass - # await self._send({ - # "type": "websocket.close" - # }) + + receive = recv async def accept(self) -> None: await self._send({"type": "websocket.accept", "subprotocol": ""}) diff --git a/tests/test_asgi.py b/tests/test_asgi.py index d0fa1d91..911260ed 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -1,5 +1,203 @@ -from sanic.testing import SanicASGITestClient +import asyncio + +from collections import deque + +import pytest +import uvicorn + +from sanic.asgi import MockTransport +from sanic.exceptions import InvalidUsage +from sanic.websocket import WebSocketConnection -def test_asgi_client_instantiation(app): - assert isinstance(app.asgi_client, SanicASGITestClient) +@pytest.fixture +def message_stack(): + return deque() + + +@pytest.fixture +def receive(message_stack): + async def _receive(): + return message_stack.popleft() + + return _receive + + +@pytest.fixture +def send(message_stack): + async def _send(message): + message_stack.append(message) + + return _send + + +@pytest.fixture +def transport(message_stack, receive, send): + return MockTransport({}, receive, send) + + +@pytest.fixture +# @pytest.mark.asyncio +def protocol(transport, loop): + return transport.get_protocol() + + +def test_listeners_triggered(app): + before_server_start = False + after_server_start = False + before_server_stop = False + after_server_stop = False + + @app.listener("before_server_start") + def do_before_server_start(*args, **kwargs): + nonlocal before_server_start + before_server_start = True + + @app.listener("after_server_start") + def do_after_server_start(*args, **kwargs): + nonlocal after_server_start + after_server_start = True + + @app.listener("before_server_stop") + def do_before_server_stop(*args, **kwargs): + nonlocal before_server_stop + before_server_stop = True + + @app.listener("after_server_stop") + def do_after_server_stop(*args, **kwargs): + nonlocal after_server_stop + after_server_stop = True + + class CustomServer(uvicorn.Server): + def install_signal_handlers(self): + pass + + config = uvicorn.Config(app=app, loop="asyncio", limit_max_requests=0) + server = CustomServer(config=config) + + with pytest.warns(UserWarning): + server.run() + + for task in asyncio.Task.all_tasks(): + task.cancel() + + assert before_server_start + assert after_server_start + assert before_server_stop + assert after_server_stop + + +def test_listeners_triggered_async(app): + before_server_start = False + after_server_start = False + before_server_stop = False + after_server_stop = False + + @app.listener("before_server_start") + async def do_before_server_start(*args, **kwargs): + nonlocal before_server_start + before_server_start = True + + @app.listener("after_server_start") + async def do_after_server_start(*args, **kwargs): + nonlocal after_server_start + after_server_start = True + + @app.listener("before_server_stop") + async def do_before_server_stop(*args, **kwargs): + nonlocal before_server_stop + before_server_stop = True + + @app.listener("after_server_stop") + async def do_after_server_stop(*args, **kwargs): + nonlocal after_server_stop + after_server_stop = True + + class CustomServer(uvicorn.Server): + def install_signal_handlers(self): + pass + + config = uvicorn.Config(app=app, loop="asyncio", limit_max_requests=0) + server = CustomServer(config=config) + + with pytest.warns(UserWarning): + server.run() + + for task in asyncio.Task.all_tasks(): + task.cancel() + + assert before_server_start + assert after_server_start + assert before_server_stop + assert after_server_stop + + +@pytest.mark.asyncio +async def test_mockprotocol_events(protocol): + assert protocol._not_paused.is_set() + protocol.pause_writing() + assert not protocol._not_paused.is_set() + protocol.resume_writing() + assert protocol._not_paused.is_set() + + +@pytest.mark.asyncio +async def test_protocol_push_data(protocol, message_stack): + text = b"hello" + + await protocol.push_data(text) + await protocol.complete() + + assert len(message_stack) == 2 + + message = message_stack.popleft() + assert message["type"] == "http.response.body" + assert message["more_body"] + assert message["body"] == text + + message = message_stack.popleft() + assert message["type"] == "http.response.body" + assert not message["more_body"] + assert message["body"] == b"" + + +@pytest.mark.asyncio +async def test_websocket_send(send, receive, message_stack): + text_string = "hello" + text_bytes = b"hello" + + ws = WebSocketConnection(send, receive) + await ws.send(text_string) + await ws.send(text_bytes) + + assert len(message_stack) == 2 + + message = message_stack.popleft() + assert message["type"] == "websocket.send" + assert message["text"] == text_string + assert "bytes" not in message + + message = message_stack.popleft() + assert message["type"] == "websocket.send" + assert message["bytes"] == text_bytes + assert "text" not in message + + +@pytest.mark.asyncio +async def test_websocket_receive(send, receive, message_stack): + msg = {"text": "hello", "type": "websocket.receive"} + message_stack.append(msg) + + ws = WebSocketConnection(send, receive) + text = await ws.receive() + + assert text == msg["text"] + + +def test_improper_websocket_connection(transport, send, receive): + with pytest.raises(InvalidUsage): + transport.get_websocket_connection() + + transport.create_websocket_connection(send, receive) + connection = transport.get_websocket_connection() + assert isinstance(connection, WebSocketConnection) diff --git a/tests/test_asgi_client.py b/tests/test_asgi_client.py new file mode 100644 index 00000000..d0fa1d91 --- /dev/null +++ b/tests/test_asgi_client.py @@ -0,0 +1,5 @@ +from sanic.testing import SanicASGITestClient + + +def test_asgi_client_instantiation(app): + assert isinstance(app.asgi_client, SanicASGITestClient) diff --git a/tests/test_request_stream.py b/tests/test_request_stream.py index 38c33acd..8f893e2b 100644 --- a/tests/test_request_stream.py +++ b/tests/test_request_stream.py @@ -197,6 +197,121 @@ def test_request_stream_app(app): assert response.text == data +@pytest.mark.asyncio +async def test_request_stream_app_asgi(app): + """for self.is_request_stream = True and decorators""" + + @app.get("/get") + async def get(request): + assert request.stream is None + return text("GET") + + @app.head("/head") + async def head(request): + assert request.stream is None + return text("HEAD") + + @app.delete("/delete") + async def delete(request): + assert request.stream is None + return text("DELETE") + + @app.options("/options") + async def options(request): + assert request.stream is None + return text("OPTIONS") + + @app.post("/_post/") + async def _post(request, id): + assert request.stream is None + return text("_POST") + + @app.post("/post/", stream=True) + async def post(request, id): + assert isinstance(request.stream, StreamBuffer) + result = "" + while True: + body = await request.stream.read() + if body is None: + break + result += body.decode("utf-8") + return text(result) + + @app.put("/_put") + async def _put(request): + assert request.stream is None + return text("_PUT") + + @app.put("/put", stream=True) + async def put(request): + assert isinstance(request.stream, StreamBuffer) + result = "" + while True: + body = await request.stream.read() + if body is None: + break + result += body.decode("utf-8") + return text(result) + + @app.patch("/_patch") + async def _patch(request): + assert request.stream is None + return text("_PATCH") + + @app.patch("/patch", stream=True) + async def patch(request): + assert isinstance(request.stream, StreamBuffer) + result = "" + while True: + body = await request.stream.read() + if body is None: + break + result += body.decode("utf-8") + return text(result) + + assert app.is_request_stream is True + + request, response = await app.asgi_client.get("/get") + assert response.status == 200 + assert response.text == "GET" + + request, response = await app.asgi_client.head("/head") + assert response.status == 200 + assert response.text == "" + + request, response = await app.asgi_client.delete("/delete") + assert response.status == 200 + assert response.text == "DELETE" + + request, response = await app.asgi_client.options("/options") + assert response.status == 200 + assert response.text == "OPTIONS" + + request, response = await app.asgi_client.post("/_post/1", data=data) + assert response.status == 200 + assert response.text == "_POST" + + request, response = await app.asgi_client.post("/post/1", data=data) + assert response.status == 200 + assert response.text == data + + request, response = await app.asgi_client.put("/_put", data=data) + assert response.status == 200 + assert response.text == "_PUT" + + request, response = await app.asgi_client.put("/put", data=data) + assert response.status == 200 + assert response.text == data + + request, response = await app.asgi_client.patch("/_patch", data=data) + assert response.status == 200 + assert response.text == "_PATCH" + + request, response = await app.asgi_client.patch("/patch", data=data) + assert response.status == 200 + assert response.text == data + + def test_request_stream_handle_exception(app): """for handling exceptions properly""" diff --git a/tests/test_routes.py b/tests/test_routes.py index 4617803e..3b24389f 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -474,6 +474,19 @@ def test_websocket_route(app, url): assert ev.is_set() +@pytest.mark.asyncio +@pytest.mark.parametrize("url", ["/ws", "ws"]) +async def test_websocket_route_asgi(app, url): + ev = asyncio.Event() + + @app.websocket(url) + async def handler(request, ws): + ev.set() + + request, response = await app.asgi_client.websocket(url) + assert ev.is_set() + + def test_websocket_route_with_subprotocols(app): results = [] diff --git a/tox.ini b/tox.ini index 616b7acd..80694198 100644 --- a/tox.ini +++ b/tox.ini @@ -18,6 +18,7 @@ deps = beautifulsoup4 gunicorn pytest-benchmark + uvicorn commands = pytest {posargs:tests --cov sanic} - coverage combine --append