Separate ASGI request and lifespan callables (#2646)

This commit is contained in:
Adam Hopkins
2023-03-17 09:05:33 +02:00
committed by GitHub
parent 08a81c81be
commit 5ee36fd933
3 changed files with 130 additions and 115 deletions

View File

@@ -48,7 +48,7 @@ from sanic_routing.route import Route
from sanic.application.ext import setup_ext
from sanic.application.state import ApplicationState, ServerStage
from sanic.asgi import ASGIApp
from sanic.asgi import ASGIApp, Lifespan
from sanic.base.root import BaseSanic
from sanic.blueprint_group import BlueprintGroup
from sanic.blueprints import Blueprint
@@ -119,6 +119,7 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta):
)
__slots__ = (
"_asgi_app",
"_asgi_lifespan",
"_asgi_client",
"_blueprint_order",
"_delayed_tasks",
@@ -198,6 +199,8 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta):
self.config.INSPECTOR = inspector
# Then we can do the rest
self._asgi_app: Optional[ASGIApp] = None
self._asgi_lifespan: Optional[Lifespan] = None
self._asgi_client: Any = None
self._blueprint_order: List[Blueprint] = []
self._delayed_tasks: List[str] = []
@@ -1349,12 +1352,14 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta):
three arguments: scope, receive, send. See the ASGI reference for more
details: https://asgi.readthedocs.io/en/latest
"""
self.asgi = True
if scope["type"] == "lifespan":
self.asgi = True
self.motd("")
self._asgi_app = await ASGIApp.create(self, scope, receive, send)
asgi_app = self._asgi_app
await asgi_app()
self._asgi_lifespan = Lifespan(self, scope, receive, send)
await self._asgi_lifespan()
else:
self._asgi_app = await ASGIApp.create(self, scope, receive, send)
await self._asgi_app()
_asgi_single_callable = True # We conform to ASGI 3.0 single-callable

View File

@@ -22,13 +22,15 @@ if TYPE_CHECKING:
class Lifespan:
def __init__(self, asgi_app: ASGIApp) -> None:
self.asgi_app = asgi_app
def __init__(
self, sanic_app, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
) -> None:
self.sanic_app = sanic_app
self.scope = scope
self.receive = receive
self.send = send
if (
"server.init.before"
in self.asgi_app.sanic_app.signal_router.name_index
):
if "server.init.before" in self.sanic_app.signal_router.name_index:
logger.debug(
'You have set a listener for "before_server_start" '
"in ASGI mode. "
@@ -36,10 +38,7 @@ class Lifespan:
"the ASGI server is started.",
extra={"verbosity": 1},
)
if (
"server.shutdown.after"
in self.asgi_app.sanic_app.signal_router.name_index
):
if "server.shutdown.after" in self.sanic_app.signal_router.name_index:
logger.debug(
'You have set a listener for "after_server_stop" '
"in ASGI mode. "
@@ -57,11 +56,11 @@ class Lifespan:
in sequence since the ASGI lifespan protocol only supports a single
startup event.
"""
await self.asgi_app.sanic_app._startup()
await self.asgi_app.sanic_app._server_event("init", "before")
await self.asgi_app.sanic_app._server_event("init", "after")
await self.sanic_app._startup()
await self.sanic_app._server_event("init", "before")
await self.sanic_app._server_event("init", "after")
if not isinstance(self.asgi_app.sanic_app.config.USE_UVLOOP, Default):
if not isinstance(self.sanic_app.config.USE_UVLOOP, Default):
warnings.warn(
"You have set the USE_UVLOOP configuration option, but Sanic "
"cannot control the event loop when running in ASGI mode."
@@ -77,35 +76,33 @@ class Lifespan:
in sequence since the ASGI lifespan protocol only supports a single
shutdown event.
"""
await self.asgi_app.sanic_app._server_event("shutdown", "before")
await self.asgi_app.sanic_app._server_event("shutdown", "after")
await self.sanic_app._server_event("shutdown", "before")
await self.sanic_app._server_event("shutdown", "after")
async def __call__(
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
) -> None:
message = await receive()
if message["type"] == "lifespan.startup":
try:
await self.startup()
except Exception as e:
error_logger.exception(e)
await send(
{"type": "lifespan.startup.failed", "message": str(e)}
)
else:
await send({"type": "lifespan.startup.complete"})
message = await receive()
if message["type"] == "lifespan.shutdown":
try:
await self.shutdown()
except Exception as e:
error_logger.exception(e)
await send(
{"type": "lifespan.shutdown.failed", "message": str(e)}
)
else:
await send({"type": "lifespan.shutdown.complete"})
async def __call__(self) -> None:
while True:
message = await self.receive()
if message["type"] == "lifespan.startup":
try:
await self.startup()
except Exception as e:
error_logger.exception(e)
await self.send(
{"type": "lifespan.startup.failed", "message": str(e)}
)
else:
await self.send({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown":
try:
await self.shutdown()
except Exception as e:
error_logger.exception(e)
await self.send(
{"type": "lifespan.shutdown.failed", "message": str(e)}
)
else:
await self.send({"type": "lifespan.shutdown.complete"})
return
class ASGIApp:
@@ -117,19 +114,22 @@ class ASGIApp:
stage: Stage
response: Optional[BaseHTTPResponse]
def __init__(self) -> None:
self.ws = None
@classmethod
async def create(
cls, sanic_app, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
) -> "ASGIApp":
cls,
sanic_app: Sanic,
scope: ASGIScope,
receive: ASGIReceive,
send: ASGISend,
) -> ASGIApp:
instance = cls()
instance.ws = None
instance.sanic_app = sanic_app
instance.transport = MockTransport(scope, receive, send)
instance.transport.loop = sanic_app.loop
instance.stage = Stage.IDLE
instance.response = None
instance.sanic_app.state.is_started = True
setattr(instance.transport, "add_task", sanic_app.loop.create_task)
headers = Header(
@@ -138,52 +138,47 @@ class ASGIApp:
for key, value in scope.get("headers", [])
]
)
instance.lifespan = Lifespan(instance)
path = (
scope["path"][1:]
if scope["path"].startswith("/")
else scope["path"]
)
url = "/".join([scope.get("root_path", ""), quote(path)])
url_bytes = url.encode("latin-1")
url_bytes += b"?" + scope["query_string"]
if scope["type"] == "lifespan":
await instance.lifespan(scope, receive, send)
if scope["type"] == "http":
version = scope["http_version"]
method = scope["method"]
elif scope["type"] == "websocket":
version = "1.1"
method = "GET"
instance.ws = instance.transport.create_websocket_connection(
send, receive
)
else:
path = (
scope["path"][1:]
if scope["path"].startswith("/")
else scope["path"]
)
url = "/".join([scope.get("root_path", ""), quote(path)])
url_bytes = url.encode("latin-1")
url_bytes += b"?" + scope["query_string"]
raise ServerError("Received unknown ASGI scope")
if scope["type"] == "http":
version = scope["http_version"]
method = scope["method"]
elif scope["type"] == "websocket":
version = "1.1"
method = "GET"
request_class = sanic_app.request_class or Request
instance.request = request_class(
url_bytes,
headers,
version,
method,
instance.transport,
sanic_app,
)
instance.request.stream = instance # type: ignore
instance.request_body = True
instance.request.conn_info = ConnInfo(instance.transport)
instance.ws = instance.transport.create_websocket_connection(
send, receive
)
else:
raise ServerError("Received unknown ASGI scope")
request_class = sanic_app.request_class or Request
instance.request = request_class(
url_bytes,
headers,
version,
method,
instance.transport,
sanic_app,
)
instance.request.stream = instance
instance.request_body = True
instance.request.conn_info = ConnInfo(instance.transport)
await sanic_app.dispatch(
"http.lifecycle.request",
inline=True,
context={"request": instance.request},
fail_not_found=False,
)
await instance.sanic_app.dispatch(
"http.lifecycle.request",
inline=True,
context={"request": instance.request},
fail_not_found=False,
)
return instance

View File

@@ -2,13 +2,14 @@ import asyncio
import logging
from collections import deque, namedtuple
from unittest.mock import call
import pytest
import uvicorn
from sanic import Sanic
from sanic.application.state import Mode
from sanic.asgi import ASGIApp, MockTransport
from sanic.asgi import ASGIApp, Lifespan, MockTransport
from sanic.exceptions import BadRequest, Forbidden, ServiceUnavailable
from sanic.request import Request
from sanic.response import json, text
@@ -116,10 +117,6 @@ def test_listeners_triggered(caplog):
stop_message,
) not in caplog.record_tuples
all_tasks = asyncio.all_tasks(asyncio.get_event_loop())
for task in all_tasks:
task.cancel()
assert before_server_start
assert after_server_start
assert before_server_stop
@@ -218,10 +215,6 @@ def test_listeners_triggered_async(app, caplog):
stop_message,
) not in caplog.record_tuples
all_tasks = asyncio.all_tasks(asyncio.get_event_loop())
for task in all_tasks:
task.cancel()
assert before_server_start
assert after_server_start
assert before_server_stop
@@ -272,10 +265,6 @@ def test_non_default_uvloop_config_raises_warning(app):
with pytest.warns(UserWarning) as records:
server.run()
all_tasks = asyncio.all_tasks(asyncio.get_event_loop())
for task in all_tasks:
task.cancel()
msg = ""
for record in records:
_msg = str(record.message)
@@ -583,15 +572,28 @@ async def test_error_on_lifespan_exception_start(app, caplog):
async def before_server_start(_):
1 / 0
recv = AsyncMock(return_value={"type": "lifespan.startup"})
recv = AsyncMock(
side_effect=[
{"type": "lifespan.startup"},
{"type": "lifespan.shutdown"},
]
)
send = AsyncMock()
app.asgi = True
lifespan = Lifespan(app, {"type": "lifespan"}, recv, send)
with caplog.at_level(logging.ERROR):
await ASGIApp.create(app, {"type": "lifespan"}, recv, send)
await lifespan()
send.assert_awaited_once_with(
{"type": "lifespan.startup.failed", "message": "division by zero"}
send.assert_has_calls(
[
call(
{
"type": "lifespan.startup.failed",
"message": "division by zero",
}
)
]
)
@@ -601,13 +603,26 @@ async def test_error_on_lifespan_exception_stop(app: Sanic):
async def before_server_stop(_):
1 / 0
recv = AsyncMock(return_value={"type": "lifespan.shutdown"})
recv = AsyncMock(
side_effect=[
{"type": "lifespan.startup"},
{"type": "lifespan.shutdown"},
]
)
send = AsyncMock()
app.asgi = True
await app._startup()
await ASGIApp.create(app, {"type": "lifespan"}, recv, send)
lifespan = Lifespan(app, {"type": "lifespan"}, recv, send)
await lifespan()
send.assert_awaited_once_with(
{"type": "lifespan.shutdown.failed", "message": "division by zero"}
send.assert_has_calls(
[
call(
{
"type": "lifespan.shutdown.failed",
"message": "division by zero",
}
)
]
)