diff --git a/sanic/asgi.py b/sanic/asgi.py index 21c2a483..87cb36f8 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -88,7 +88,7 @@ class MockTransport: self._websocket_connection = WebSocketConnection(send, receive) return self._websocket_connection - def add_task(self) -> None: # noqa + def add_task(self) -> None: raise NotImplementedError async def send(self, data) -> None: @@ -119,27 +119,20 @@ class Lifespan: "the ASGI server is stopped." ) - # async def pre_startup(self) -> None: - # for handler in self.asgi_app.sanic_app.listeners[ - # "before_server_start" - # ]: - # response = handler( - # self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop - # ) - # if isawaitable(response): - # await response - async def startup(self) -> None: - for handler in self.asgi_app.sanic_app.listeners[ - "before_server_start" - ]: - response = handler( - self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop - ) - if isawaitable(response): - await response + """ + Gather the listeners to fire on server start. + Because we are using a third-party server and not Sanic server, we do + not have access to fire anything BEFORE the server starts. + Therefore, we fire before_server_start and after_server_start + in sequence since the ASGI lifespan protocol only supports a single + startup event. + """ + listeners = self.asgi_app.sanic_app.listeners.get( + "before_server_start", [] + ) + self.asgi_app.sanic_app.listeners.get("after_server_start", []) - for handler in self.asgi_app.sanic_app.listeners["after_server_start"]: + for handler in listeners: response = handler( self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop ) @@ -147,14 +140,19 @@ class Lifespan: await response async def shutdown(self) -> None: - for handler in self.asgi_app.sanic_app.listeners["before_server_stop"]: - response = handler( - self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop - ) - if isawaitable(response): - await response + """ + Gather the listeners to fire on server stop. + Because we are using a third-party server and not Sanic server, we do + not have access to fire anything AFTER the server stops. + Therefore, we fire before_server_stop and after_server_stop + in sequence since the ASGI lifespan protocol only supports a single + shutdown event. + """ + listeners = self.asgi_app.sanic_app.listeners.get( + "before_server_stop", [] + ) + self.asgi_app.sanic_app.listeners.get("after_server_stop", []) - for handler in self.asgi_app.sanic_app.listeners["after_server_stop"]: + for handler in listeners: response = handler( self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop ) @@ -223,7 +221,8 @@ class ASGIApp: # TODO: # - close connection - instance.request = Request( + request_class = sanic_app.request_class or Request + instance.request = request_class( url_bytes, headers, version, @@ -267,6 +266,7 @@ class ASGIApp: message = await self.transport.receive() chunk = message.get("body", b"") await self.request.stream.put(chunk) + # self.sanic_app.loop.create_task(self.request.stream.put(chunk)) more_body = message.get("more_body", False) @@ -294,6 +294,7 @@ class ASGIApp: headers = [ (str(name).encode("latin-1"), str(value).encode("latin-1")) for name, value in response.headers.items() + # if name not in ("Set-Cookie",) ] except AttributeError: logger.error( diff --git a/tests/test_asgi.py b/tests/test_asgi.py index 911260ed..b97f2611 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -5,8 +5,11 @@ from collections import deque import pytest import uvicorn +from sanic import Sanic from sanic.asgi import MockTransport from sanic.exceptions import InvalidUsage +from sanic.request import Request +from sanic.response import text from sanic.websocket import WebSocketConnection @@ -201,3 +204,28 @@ def test_improper_websocket_connection(transport, send, receive): transport.create_websocket_connection(send, receive) connection = transport.get_websocket_connection() assert isinstance(connection, WebSocketConnection) + + +@pytest.mark.asyncio +async def test_request_class_regular(app): + @app.get("/regular") + def regular_request(request): + return text(request.__class__.__name__) + + _, response = await app.asgi_client.get("/regular") + assert response.body == b"Request" + + +@pytest.mark.asyncio +async def test_request_class_custom(): + class MyCustomRequest(Request): + pass + + app = Sanic(request_class=MyCustomRequest) + + @app.get("/custom") + def custom_request(request): + return text(request.__class__.__name__) + + _, response = await app.asgi_client.get("/custom") + assert response.body == b"MyCustomRequest" diff --git a/tests/test_config.py b/tests/test_config.py index a6ee7ec3..7d2a8395 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -19,8 +19,8 @@ def temp_path(): class ConfigTest: - not_for_config = 'should not be used' - CONFIG_VALUE = 'should be used' + not_for_config = "should not be used" + CONFIG_VALUE = "should be used" def test_load_from_object(app): @@ -31,15 +31,15 @@ def test_load_from_object(app): def test_load_from_object_string(app): - app.config.from_object('test_config.ConfigTest') - assert 'CONFIG_VALUE' in app.config - assert app.config.CONFIG_VALUE == 'should be used' - assert 'not_for_config' not in app.config + app.config.from_object("test_config.ConfigTest") + assert "CONFIG_VALUE" in app.config + assert app.config.CONFIG_VALUE == "should be used" + assert "not_for_config" not in app.config def test_load_from_object_string_exception(app): with pytest.raises(ImportError): - app.config.from_object('test_config.Config.test') + app.config.from_object("test_config.Config.test") def test_auto_load_env(): diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 19d0e9e2..cecdd7c6 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,8 +1,9 @@ import inspect +import pytest + from sanic import helpers from sanic.config import Config -import pytest def test_has_message_body(): @@ -63,15 +64,15 @@ def test_remove_entity_headers(): def test_import_string_class(): - obj = helpers.import_string('sanic.config.Config') + obj = helpers.import_string("sanic.config.Config") assert isinstance(obj, Config) def test_import_string_module(): - module = helpers.import_string('sanic.config') + module = helpers.import_string("sanic.config") assert inspect.ismodule(module) def test_import_string_exception(): with pytest.raises(ImportError): - helpers.import_string('test.test.test') + helpers.import_string("test.test.test") diff --git a/tox.ini b/tox.ini index f4933e03..3ff30ef8 100644 --- a/tox.ini +++ b/tox.ini @@ -54,4 +54,4 @@ deps = bandit commands = - bandit --recursive sanic --skip B404,B101 --exclude sanic/reloader_helpers.py \ No newline at end of file + bandit --recursive sanic --skip B404,B101 --exclude sanic/reloader_helpers.py