diff --git a/sanic/app.py b/sanic/app.py index 172ed877..1adc2732 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -48,7 +48,7 @@ from sanic_routing.route import Route from sanic.application.ext import setup_ext from sanic.application.state import ApplicationState, ServerStage -from sanic.asgi import ASGIApp +from sanic.asgi import ASGIApp, Lifespan from sanic.base.root import BaseSanic from sanic.blueprint_group import BlueprintGroup from sanic.blueprints import Blueprint @@ -119,6 +119,7 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta): ) __slots__ = ( "_asgi_app", + "_asgi_lifespan", "_asgi_client", "_blueprint_order", "_delayed_tasks", @@ -198,6 +199,8 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta): self.config.INSPECTOR = inspector # Then we can do the rest + self._asgi_app: Optional[ASGIApp] = None + self._asgi_lifespan: Optional[Lifespan] = None self._asgi_client: Any = None self._blueprint_order: List[Blueprint] = [] self._delayed_tasks: List[str] = [] @@ -1349,12 +1352,14 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta): three arguments: scope, receive, send. See the ASGI reference for more details: https://asgi.readthedocs.io/en/latest """ - self.asgi = True if scope["type"] == "lifespan": + self.asgi = True self.motd("") - self._asgi_app = await ASGIApp.create(self, scope, receive, send) - asgi_app = self._asgi_app - await asgi_app() + self._asgi_lifespan = Lifespan(self, scope, receive, send) + await self._asgi_lifespan() + else: + self._asgi_app = await ASGIApp.create(self, scope, receive, send) + await self._asgi_app() _asgi_single_callable = True # We conform to ASGI 3.0 single-callable diff --git a/sanic/asgi.py b/sanic/asgi.py index d8ab4cfa..e5c07c2a 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -22,13 +22,15 @@ if TYPE_CHECKING: class Lifespan: - def __init__(self, asgi_app: ASGIApp) -> None: - self.asgi_app = asgi_app + def __init__( + self, sanic_app, scope: ASGIScope, receive: ASGIReceive, send: ASGISend + ) -> None: + self.sanic_app = sanic_app + self.scope = scope + self.receive = receive + self.send = send - if ( - "server.init.before" - in self.asgi_app.sanic_app.signal_router.name_index - ): + if "server.init.before" in self.sanic_app.signal_router.name_index: logger.debug( 'You have set a listener for "before_server_start" ' "in ASGI mode. " @@ -36,10 +38,7 @@ class Lifespan: "the ASGI server is started.", extra={"verbosity": 1}, ) - if ( - "server.shutdown.after" - in self.asgi_app.sanic_app.signal_router.name_index - ): + if "server.shutdown.after" in self.sanic_app.signal_router.name_index: logger.debug( 'You have set a listener for "after_server_stop" ' "in ASGI mode. " @@ -57,11 +56,11 @@ class Lifespan: in sequence since the ASGI lifespan protocol only supports a single startup event. """ - await self.asgi_app.sanic_app._startup() - await self.asgi_app.sanic_app._server_event("init", "before") - await self.asgi_app.sanic_app._server_event("init", "after") + await self.sanic_app._startup() + await self.sanic_app._server_event("init", "before") + await self.sanic_app._server_event("init", "after") - if not isinstance(self.asgi_app.sanic_app.config.USE_UVLOOP, Default): + if not isinstance(self.sanic_app.config.USE_UVLOOP, Default): warnings.warn( "You have set the USE_UVLOOP configuration option, but Sanic " "cannot control the event loop when running in ASGI mode." @@ -77,35 +76,33 @@ class Lifespan: in sequence since the ASGI lifespan protocol only supports a single shutdown event. """ - await self.asgi_app.sanic_app._server_event("shutdown", "before") - await self.asgi_app.sanic_app._server_event("shutdown", "after") + await self.sanic_app._server_event("shutdown", "before") + await self.sanic_app._server_event("shutdown", "after") - async def __call__( - self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend - ) -> None: - message = await receive() - if message["type"] == "lifespan.startup": - try: - await self.startup() - except Exception as e: - error_logger.exception(e) - await send( - {"type": "lifespan.startup.failed", "message": str(e)} - ) - else: - await send({"type": "lifespan.startup.complete"}) - - message = await receive() - if message["type"] == "lifespan.shutdown": - try: - await self.shutdown() - except Exception as e: - error_logger.exception(e) - await send( - {"type": "lifespan.shutdown.failed", "message": str(e)} - ) - else: - await send({"type": "lifespan.shutdown.complete"}) + async def __call__(self) -> None: + while True: + message = await self.receive() + if message["type"] == "lifespan.startup": + try: + await self.startup() + except Exception as e: + error_logger.exception(e) + await self.send( + {"type": "lifespan.startup.failed", "message": str(e)} + ) + else: + await self.send({"type": "lifespan.startup.complete"}) + elif message["type"] == "lifespan.shutdown": + try: + await self.shutdown() + except Exception as e: + error_logger.exception(e) + await self.send( + {"type": "lifespan.shutdown.failed", "message": str(e)} + ) + else: + await self.send({"type": "lifespan.shutdown.complete"}) + return class ASGIApp: @@ -117,19 +114,22 @@ class ASGIApp: stage: Stage response: Optional[BaseHTTPResponse] - def __init__(self) -> None: - self.ws = None - @classmethod async def create( - cls, sanic_app, scope: ASGIScope, receive: ASGIReceive, send: ASGISend - ) -> "ASGIApp": + cls, + sanic_app: Sanic, + scope: ASGIScope, + receive: ASGIReceive, + send: ASGISend, + ) -> ASGIApp: instance = cls() + instance.ws = None instance.sanic_app = sanic_app instance.transport = MockTransport(scope, receive, send) instance.transport.loop = sanic_app.loop instance.stage = Stage.IDLE instance.response = None + instance.sanic_app.state.is_started = True setattr(instance.transport, "add_task", sanic_app.loop.create_task) headers = Header( @@ -138,52 +138,47 @@ class ASGIApp: for key, value in scope.get("headers", []) ] ) - instance.lifespan = Lifespan(instance) + path = ( + scope["path"][1:] + if scope["path"].startswith("/") + else scope["path"] + ) + url = "/".join([scope.get("root_path", ""), quote(path)]) + url_bytes = url.encode("latin-1") + url_bytes += b"?" + scope["query_string"] - if scope["type"] == "lifespan": - await instance.lifespan(scope, receive, send) + if scope["type"] == "http": + version = scope["http_version"] + method = scope["method"] + elif scope["type"] == "websocket": + version = "1.1" + method = "GET" + + instance.ws = instance.transport.create_websocket_connection( + send, receive + ) else: - path = ( - scope["path"][1:] - if scope["path"].startswith("/") - else scope["path"] - ) - url = "/".join([scope.get("root_path", ""), quote(path)]) - url_bytes = url.encode("latin-1") - url_bytes += b"?" + scope["query_string"] + raise ServerError("Received unknown ASGI scope") - if scope["type"] == "http": - version = scope["http_version"] - method = scope["method"] - elif scope["type"] == "websocket": - version = "1.1" - method = "GET" + request_class = sanic_app.request_class or Request + instance.request = request_class( + url_bytes, + headers, + version, + method, + instance.transport, + sanic_app, + ) + instance.request.stream = instance # type: ignore + instance.request_body = True + instance.request.conn_info = ConnInfo(instance.transport) - instance.ws = instance.transport.create_websocket_connection( - send, receive - ) - else: - raise ServerError("Received unknown ASGI scope") - - request_class = sanic_app.request_class or Request - instance.request = request_class( - url_bytes, - headers, - version, - method, - instance.transport, - sanic_app, - ) - instance.request.stream = instance - instance.request_body = True - instance.request.conn_info = ConnInfo(instance.transport) - - await sanic_app.dispatch( - "http.lifecycle.request", - inline=True, - context={"request": instance.request}, - fail_not_found=False, - ) + await instance.sanic_app.dispatch( + "http.lifecycle.request", + inline=True, + context={"request": instance.request}, + fail_not_found=False, + ) return instance diff --git a/tests/test_asgi.py b/tests/test_asgi.py index 1b57436d..6b6872ac 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -2,13 +2,14 @@ import asyncio import logging from collections import deque, namedtuple +from unittest.mock import call import pytest import uvicorn from sanic import Sanic from sanic.application.state import Mode -from sanic.asgi import ASGIApp, MockTransport +from sanic.asgi import ASGIApp, Lifespan, MockTransport from sanic.exceptions import BadRequest, Forbidden, ServiceUnavailable from sanic.request import Request from sanic.response import json, text @@ -116,10 +117,6 @@ def test_listeners_triggered(caplog): stop_message, ) not in caplog.record_tuples - all_tasks = asyncio.all_tasks(asyncio.get_event_loop()) - for task in all_tasks: - task.cancel() - assert before_server_start assert after_server_start assert before_server_stop @@ -218,10 +215,6 @@ def test_listeners_triggered_async(app, caplog): stop_message, ) not in caplog.record_tuples - all_tasks = asyncio.all_tasks(asyncio.get_event_loop()) - for task in all_tasks: - task.cancel() - assert before_server_start assert after_server_start assert before_server_stop @@ -272,10 +265,6 @@ def test_non_default_uvloop_config_raises_warning(app): with pytest.warns(UserWarning) as records: server.run() - all_tasks = asyncio.all_tasks(asyncio.get_event_loop()) - for task in all_tasks: - task.cancel() - msg = "" for record in records: _msg = str(record.message) @@ -583,15 +572,28 @@ async def test_error_on_lifespan_exception_start(app, caplog): async def before_server_start(_): 1 / 0 - recv = AsyncMock(return_value={"type": "lifespan.startup"}) + recv = AsyncMock( + side_effect=[ + {"type": "lifespan.startup"}, + {"type": "lifespan.shutdown"}, + ] + ) send = AsyncMock() app.asgi = True + lifespan = Lifespan(app, {"type": "lifespan"}, recv, send) with caplog.at_level(logging.ERROR): - await ASGIApp.create(app, {"type": "lifespan"}, recv, send) + await lifespan() - send.assert_awaited_once_with( - {"type": "lifespan.startup.failed", "message": "division by zero"} + send.assert_has_calls( + [ + call( + { + "type": "lifespan.startup.failed", + "message": "division by zero", + } + ) + ] ) @@ -601,13 +603,26 @@ async def test_error_on_lifespan_exception_stop(app: Sanic): async def before_server_stop(_): 1 / 0 - recv = AsyncMock(return_value={"type": "lifespan.shutdown"}) + recv = AsyncMock( + side_effect=[ + {"type": "lifespan.startup"}, + {"type": "lifespan.shutdown"}, + ] + ) send = AsyncMock() app.asgi = True await app._startup() - await ASGIApp.create(app, {"type": "lifespan"}, recv, send) + lifespan = Lifespan(app, {"type": "lifespan"}, recv, send) + await lifespan() - send.assert_awaited_once_with( - {"type": "lifespan.shutdown.failed", "message": "division by zero"} + send.assert_has_calls( + [ + call( + { + "type": "lifespan.shutdown.failed", + "message": "division by zero", + } + ) + ] )