Separate ASGI request and lifespan callables (#2646)
This commit is contained in:
		
							
								
								
									
										15
									
								
								sanic/app.py
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								sanic/app.py
									
									
									
									
									
								
							| @@ -48,7 +48,7 @@ from sanic_routing.route import Route | ||||
|  | ||||
| from sanic.application.ext import setup_ext | ||||
| from sanic.application.state import ApplicationState, ServerStage | ||||
| from sanic.asgi import ASGIApp | ||||
| from sanic.asgi import ASGIApp, Lifespan | ||||
| from sanic.base.root import BaseSanic | ||||
| from sanic.blueprint_group import BlueprintGroup | ||||
| from sanic.blueprints import Blueprint | ||||
| @@ -119,6 +119,7 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta): | ||||
|     ) | ||||
|     __slots__ = ( | ||||
|         "_asgi_app", | ||||
|         "_asgi_lifespan", | ||||
|         "_asgi_client", | ||||
|         "_blueprint_order", | ||||
|         "_delayed_tasks", | ||||
| @@ -198,6 +199,8 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta): | ||||
|             self.config.INSPECTOR = inspector | ||||
|  | ||||
|         # Then we can do the rest | ||||
|         self._asgi_app: Optional[ASGIApp] = None | ||||
|         self._asgi_lifespan: Optional[Lifespan] = None | ||||
|         self._asgi_client: Any = None | ||||
|         self._blueprint_order: List[Blueprint] = [] | ||||
|         self._delayed_tasks: List[str] = [] | ||||
| @@ -1349,12 +1352,14 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta): | ||||
|         three arguments: scope, receive, send. See the ASGI reference for more | ||||
|         details: https://asgi.readthedocs.io/en/latest | ||||
|         """ | ||||
|         self.asgi = True | ||||
|         if scope["type"] == "lifespan": | ||||
|             self.asgi = True | ||||
|             self.motd("") | ||||
|         self._asgi_app = await ASGIApp.create(self, scope, receive, send) | ||||
|         asgi_app = self._asgi_app | ||||
|         await asgi_app() | ||||
|             self._asgi_lifespan = Lifespan(self, scope, receive, send) | ||||
|             await self._asgi_lifespan() | ||||
|         else: | ||||
|             self._asgi_app = await ASGIApp.create(self, scope, receive, send) | ||||
|             await self._asgi_app() | ||||
|  | ||||
|     _asgi_single_callable = True  # We conform to ASGI 3.0 single-callable | ||||
|  | ||||
|   | ||||
							
								
								
									
										173
									
								
								sanic/asgi.py
									
									
									
									
									
								
							
							
						
						
									
										173
									
								
								sanic/asgi.py
									
									
									
									
									
								
							| @@ -22,13 +22,15 @@ if TYPE_CHECKING: | ||||
|  | ||||
|  | ||||
| class Lifespan: | ||||
|     def __init__(self, asgi_app: ASGIApp) -> None: | ||||
|         self.asgi_app = asgi_app | ||||
|     def __init__( | ||||
|         self, sanic_app, scope: ASGIScope, receive: ASGIReceive, send: ASGISend | ||||
|     ) -> None: | ||||
|         self.sanic_app = sanic_app | ||||
|         self.scope = scope | ||||
|         self.receive = receive | ||||
|         self.send = send | ||||
|  | ||||
|         if ( | ||||
|             "server.init.before" | ||||
|             in self.asgi_app.sanic_app.signal_router.name_index | ||||
|         ): | ||||
|         if "server.init.before" in self.sanic_app.signal_router.name_index: | ||||
|             logger.debug( | ||||
|                 'You have set a listener for "before_server_start" ' | ||||
|                 "in ASGI mode. " | ||||
| @@ -36,10 +38,7 @@ class Lifespan: | ||||
|                 "the ASGI server is started.", | ||||
|                 extra={"verbosity": 1}, | ||||
|             ) | ||||
|         if ( | ||||
|             "server.shutdown.after" | ||||
|             in self.asgi_app.sanic_app.signal_router.name_index | ||||
|         ): | ||||
|         if "server.shutdown.after" in self.sanic_app.signal_router.name_index: | ||||
|             logger.debug( | ||||
|                 'You have set a listener for "after_server_stop" ' | ||||
|                 "in ASGI mode. " | ||||
| @@ -57,11 +56,11 @@ class Lifespan: | ||||
|         in sequence since the ASGI lifespan protocol only supports a single | ||||
|         startup event. | ||||
|         """ | ||||
|         await self.asgi_app.sanic_app._startup() | ||||
|         await self.asgi_app.sanic_app._server_event("init", "before") | ||||
|         await self.asgi_app.sanic_app._server_event("init", "after") | ||||
|         await self.sanic_app._startup() | ||||
|         await self.sanic_app._server_event("init", "before") | ||||
|         await self.sanic_app._server_event("init", "after") | ||||
|  | ||||
|         if not isinstance(self.asgi_app.sanic_app.config.USE_UVLOOP, Default): | ||||
|         if not isinstance(self.sanic_app.config.USE_UVLOOP, Default): | ||||
|             warnings.warn( | ||||
|                 "You have set the USE_UVLOOP configuration option, but Sanic " | ||||
|                 "cannot control the event loop when running in ASGI mode." | ||||
| @@ -77,35 +76,33 @@ class Lifespan: | ||||
|         in sequence since the ASGI lifespan protocol only supports a single | ||||
|         shutdown event. | ||||
|         """ | ||||
|         await self.asgi_app.sanic_app._server_event("shutdown", "before") | ||||
|         await self.asgi_app.sanic_app._server_event("shutdown", "after") | ||||
|         await self.sanic_app._server_event("shutdown", "before") | ||||
|         await self.sanic_app._server_event("shutdown", "after") | ||||
|  | ||||
|     async def __call__( | ||||
|         self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend | ||||
|     ) -> None: | ||||
|         message = await receive() | ||||
|         if message["type"] == "lifespan.startup": | ||||
|             try: | ||||
|                 await self.startup() | ||||
|             except Exception as e: | ||||
|                 error_logger.exception(e) | ||||
|                 await send( | ||||
|                     {"type": "lifespan.startup.failed", "message": str(e)} | ||||
|                 ) | ||||
|             else: | ||||
|                 await send({"type": "lifespan.startup.complete"}) | ||||
|  | ||||
|         message = await receive() | ||||
|         if message["type"] == "lifespan.shutdown": | ||||
|             try: | ||||
|                 await self.shutdown() | ||||
|             except Exception as e: | ||||
|                 error_logger.exception(e) | ||||
|                 await send( | ||||
|                     {"type": "lifespan.shutdown.failed", "message": str(e)} | ||||
|                 ) | ||||
|             else: | ||||
|                 await send({"type": "lifespan.shutdown.complete"}) | ||||
|     async def __call__(self) -> None: | ||||
|         while True: | ||||
|             message = await self.receive() | ||||
|             if message["type"] == "lifespan.startup": | ||||
|                 try: | ||||
|                     await self.startup() | ||||
|                 except Exception as e: | ||||
|                     error_logger.exception(e) | ||||
|                     await self.send( | ||||
|                         {"type": "lifespan.startup.failed", "message": str(e)} | ||||
|                     ) | ||||
|                 else: | ||||
|                     await self.send({"type": "lifespan.startup.complete"}) | ||||
|             elif message["type"] == "lifespan.shutdown": | ||||
|                 try: | ||||
|                     await self.shutdown() | ||||
|                 except Exception as e: | ||||
|                     error_logger.exception(e) | ||||
|                     await self.send( | ||||
|                         {"type": "lifespan.shutdown.failed", "message": str(e)} | ||||
|                     ) | ||||
|                 else: | ||||
|                     await self.send({"type": "lifespan.shutdown.complete"}) | ||||
|                 return | ||||
|  | ||||
|  | ||||
| class ASGIApp: | ||||
| @@ -117,19 +114,22 @@ class ASGIApp: | ||||
|     stage: Stage | ||||
|     response: Optional[BaseHTTPResponse] | ||||
|  | ||||
|     def __init__(self) -> None: | ||||
|         self.ws = None | ||||
|  | ||||
|     @classmethod | ||||
|     async def create( | ||||
|         cls, sanic_app, scope: ASGIScope, receive: ASGIReceive, send: ASGISend | ||||
|     ) -> "ASGIApp": | ||||
|         cls, | ||||
|         sanic_app: Sanic, | ||||
|         scope: ASGIScope, | ||||
|         receive: ASGIReceive, | ||||
|         send: ASGISend, | ||||
|     ) -> ASGIApp: | ||||
|         instance = cls() | ||||
|         instance.ws = None | ||||
|         instance.sanic_app = sanic_app | ||||
|         instance.transport = MockTransport(scope, receive, send) | ||||
|         instance.transport.loop = sanic_app.loop | ||||
|         instance.stage = Stage.IDLE | ||||
|         instance.response = None | ||||
|         instance.sanic_app.state.is_started = True | ||||
|         setattr(instance.transport, "add_task", sanic_app.loop.create_task) | ||||
|  | ||||
|         headers = Header( | ||||
| @@ -138,52 +138,47 @@ class ASGIApp: | ||||
|                 for key, value in scope.get("headers", []) | ||||
|             ] | ||||
|         ) | ||||
|         instance.lifespan = Lifespan(instance) | ||||
|         path = ( | ||||
|             scope["path"][1:] | ||||
|             if scope["path"].startswith("/") | ||||
|             else scope["path"] | ||||
|         ) | ||||
|         url = "/".join([scope.get("root_path", ""), quote(path)]) | ||||
|         url_bytes = url.encode("latin-1") | ||||
|         url_bytes += b"?" + scope["query_string"] | ||||
|  | ||||
|         if scope["type"] == "lifespan": | ||||
|             await instance.lifespan(scope, receive, send) | ||||
|         if scope["type"] == "http": | ||||
|             version = scope["http_version"] | ||||
|             method = scope["method"] | ||||
|         elif scope["type"] == "websocket": | ||||
|             version = "1.1" | ||||
|             method = "GET" | ||||
|  | ||||
|             instance.ws = instance.transport.create_websocket_connection( | ||||
|                 send, receive | ||||
|             ) | ||||
|         else: | ||||
|             path = ( | ||||
|                 scope["path"][1:] | ||||
|                 if scope["path"].startswith("/") | ||||
|                 else scope["path"] | ||||
|             ) | ||||
|             url = "/".join([scope.get("root_path", ""), quote(path)]) | ||||
|             url_bytes = url.encode("latin-1") | ||||
|             url_bytes += b"?" + scope["query_string"] | ||||
|             raise ServerError("Received unknown ASGI scope") | ||||
|  | ||||
|             if scope["type"] == "http": | ||||
|                 version = scope["http_version"] | ||||
|                 method = scope["method"] | ||||
|             elif scope["type"] == "websocket": | ||||
|                 version = "1.1" | ||||
|                 method = "GET" | ||||
|         request_class = sanic_app.request_class or Request | ||||
|         instance.request = request_class( | ||||
|             url_bytes, | ||||
|             headers, | ||||
|             version, | ||||
|             method, | ||||
|             instance.transport, | ||||
|             sanic_app, | ||||
|         ) | ||||
|         instance.request.stream = instance  # type: ignore | ||||
|         instance.request_body = True | ||||
|         instance.request.conn_info = ConnInfo(instance.transport) | ||||
|  | ||||
|                 instance.ws = instance.transport.create_websocket_connection( | ||||
|                     send, receive | ||||
|                 ) | ||||
|             else: | ||||
|                 raise ServerError("Received unknown ASGI scope") | ||||
|  | ||||
|             request_class = sanic_app.request_class or Request | ||||
|             instance.request = request_class( | ||||
|                 url_bytes, | ||||
|                 headers, | ||||
|                 version, | ||||
|                 method, | ||||
|                 instance.transport, | ||||
|                 sanic_app, | ||||
|             ) | ||||
|             instance.request.stream = instance | ||||
|             instance.request_body = True | ||||
|             instance.request.conn_info = ConnInfo(instance.transport) | ||||
|  | ||||
|             await sanic_app.dispatch( | ||||
|                 "http.lifecycle.request", | ||||
|                 inline=True, | ||||
|                 context={"request": instance.request}, | ||||
|                 fail_not_found=False, | ||||
|             ) | ||||
|         await instance.sanic_app.dispatch( | ||||
|             "http.lifecycle.request", | ||||
|             inline=True, | ||||
|             context={"request": instance.request}, | ||||
|             fail_not_found=False, | ||||
|         ) | ||||
|  | ||||
|         return instance | ||||
|  | ||||
|   | ||||
| @@ -2,13 +2,14 @@ import asyncio | ||||
| import logging | ||||
|  | ||||
| from collections import deque, namedtuple | ||||
| from unittest.mock import call | ||||
|  | ||||
| import pytest | ||||
| import uvicorn | ||||
|  | ||||
| from sanic import Sanic | ||||
| from sanic.application.state import Mode | ||||
| from sanic.asgi import ASGIApp, MockTransport | ||||
| from sanic.asgi import ASGIApp, Lifespan, MockTransport | ||||
| from sanic.exceptions import BadRequest, Forbidden, ServiceUnavailable | ||||
| from sanic.request import Request | ||||
| from sanic.response import json, text | ||||
| @@ -116,10 +117,6 @@ def test_listeners_triggered(caplog): | ||||
|         stop_message, | ||||
|     ) not in caplog.record_tuples | ||||
|  | ||||
|     all_tasks = asyncio.all_tasks(asyncio.get_event_loop()) | ||||
|     for task in all_tasks: | ||||
|         task.cancel() | ||||
|  | ||||
|     assert before_server_start | ||||
|     assert after_server_start | ||||
|     assert before_server_stop | ||||
| @@ -218,10 +215,6 @@ def test_listeners_triggered_async(app, caplog): | ||||
|         stop_message, | ||||
|     ) not in caplog.record_tuples | ||||
|  | ||||
|     all_tasks = asyncio.all_tasks(asyncio.get_event_loop()) | ||||
|     for task in all_tasks: | ||||
|         task.cancel() | ||||
|  | ||||
|     assert before_server_start | ||||
|     assert after_server_start | ||||
|     assert before_server_stop | ||||
| @@ -272,10 +265,6 @@ def test_non_default_uvloop_config_raises_warning(app): | ||||
|     with pytest.warns(UserWarning) as records: | ||||
|         server.run() | ||||
|  | ||||
|     all_tasks = asyncio.all_tasks(asyncio.get_event_loop()) | ||||
|     for task in all_tasks: | ||||
|         task.cancel() | ||||
|  | ||||
|     msg = "" | ||||
|     for record in records: | ||||
|         _msg = str(record.message) | ||||
| @@ -583,15 +572,28 @@ async def test_error_on_lifespan_exception_start(app, caplog): | ||||
|     async def before_server_start(_): | ||||
|         1 / 0 | ||||
|  | ||||
|     recv = AsyncMock(return_value={"type": "lifespan.startup"}) | ||||
|     recv = AsyncMock( | ||||
|         side_effect=[ | ||||
|             {"type": "lifespan.startup"}, | ||||
|             {"type": "lifespan.shutdown"}, | ||||
|         ] | ||||
|     ) | ||||
|     send = AsyncMock() | ||||
|     app.asgi = True | ||||
|  | ||||
|     lifespan = Lifespan(app, {"type": "lifespan"}, recv, send) | ||||
|     with caplog.at_level(logging.ERROR): | ||||
|         await ASGIApp.create(app, {"type": "lifespan"}, recv, send) | ||||
|         await lifespan() | ||||
|  | ||||
|     send.assert_awaited_once_with( | ||||
|         {"type": "lifespan.startup.failed", "message": "division by zero"} | ||||
|     send.assert_has_calls( | ||||
|         [ | ||||
|             call( | ||||
|                 { | ||||
|                     "type": "lifespan.startup.failed", | ||||
|                     "message": "division by zero", | ||||
|                 } | ||||
|             ) | ||||
|         ] | ||||
|     ) | ||||
|  | ||||
|  | ||||
| @@ -601,13 +603,26 @@ async def test_error_on_lifespan_exception_stop(app: Sanic): | ||||
|     async def before_server_stop(_): | ||||
|         1 / 0 | ||||
|  | ||||
|     recv = AsyncMock(return_value={"type": "lifespan.shutdown"}) | ||||
|     recv = AsyncMock( | ||||
|         side_effect=[ | ||||
|             {"type": "lifespan.startup"}, | ||||
|             {"type": "lifespan.shutdown"}, | ||||
|         ] | ||||
|     ) | ||||
|     send = AsyncMock() | ||||
|     app.asgi = True | ||||
|     await app._startup() | ||||
|  | ||||
|     await ASGIApp.create(app, {"type": "lifespan"}, recv, send) | ||||
|     lifespan = Lifespan(app, {"type": "lifespan"}, recv, send) | ||||
|     await lifespan() | ||||
|  | ||||
|     send.assert_awaited_once_with( | ||||
|         {"type": "lifespan.shutdown.failed", "message": "division by zero"} | ||||
|     send.assert_has_calls( | ||||
|         [ | ||||
|             call( | ||||
|                 { | ||||
|                     "type": "lifespan.shutdown.failed", | ||||
|                     "message": "division by zero", | ||||
|                 } | ||||
|             ) | ||||
|         ] | ||||
|     ) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Adam Hopkins
					Adam Hopkins