Implement ASGI lifespan events to match Sanic listeners

This commit is contained in:
Adam Hopkins 2019-05-27 12:33:25 +03:00
parent 22c0d97783
commit 9172399b8c
3 changed files with 131 additions and 41 deletions

View File

@ -56,3 +56,33 @@ async def handler_stream(request):
body = body.decode("utf-8").replace("1", "A")
# await response.write(body)
return stream(streaming)
@app.listener("before_server_start")
async def listener_before_server_start(*args, **kwargs):
print("before_server_start")
@app.listener("after_server_start")
async def listener_after_server_start(*args, **kwargs):
print("after_server_start")
@app.listener("before_server_stop")
async def listener_before_server_stop(*args, **kwargs):
print("before_server_stop")
@app.listener("after_server_stop")
async def listener_after_server_stop(*args, **kwargs):
print("after_server_stop")
@app.middleware("request")
async def print_on_request(request):
print("print_on_request")
@app.middleware("response")
async def print_on_response(request, response):
print("print_on_response")

View File

@ -1,11 +1,18 @@
from typing import Any, Awaitable, Callable, MutableMapping, Union
import asyncio
from multidict import CIMultiDict
import warnings
from functools import partial
from inspect import isawaitable
from typing import Any, Awaitable, Callable, MutableMapping, Union
from multidict import CIMultiDict
from sanic.exceptions import InvalidUsage
from sanic.request import Request
from sanic.response import HTTPResponse, StreamingHTTPResponse
from sanic.websocket import WebSocketConnection
from sanic.server import StreamBuffer
from sanic.websocket import WebSocketConnection
ASGIScope = MutableMapping[str, Any]
ASGIMessage = MutableMapping[str, Any]
@ -20,30 +27,29 @@ class MockProtocol:
self._not_paused.set()
self._complete = asyncio.Event(loop=loop)
def pause_writing(self):
def pause_writing(self) -> None:
self._not_paused.clear()
def resume_writing(self):
def resume_writing(self) -> None:
self._not_paused.set()
async def complete(self):
async def complete(self) -> None:
self._not_paused.set()
await self.transport.send(
{"type": "http.response.body", "body": b"", "more_body": False}
)
@property
def is_complete(self):
def is_complete(self) -> bool:
return self._complete.is_set()
async def push_data(self, data):
async def push_data(self, data: bytes) -> None:
if not self.is_complete:
await self.transport.send(
{"type": "http.response.body", "body": data, "more_body": True}
)
async def drain(self):
print("draining")
async def drain(self) -> None:
await self._not_paused.wait()
@ -57,7 +63,7 @@ class MockTransport:
self._protocol = None
self.loop = None
def get_protocol(self):
def get_protocol(self) -> MockProtocol:
if not self._protocol:
self._protocol = MockProtocol(self, self.loop)
return self._protocol
@ -69,7 +75,10 @@ class MockTransport:
return self.scope.get("scheme") in ["https", "wss"]
def get_websocket_connection(self) -> WebSocketConnection:
return self._websocket_connection
try:
return self._websocket_connection
except AttributeError:
raise InvalidUsage("Improper websocket connection.")
def create_websocket_connection(
self, send: ASGISend, receive: ASGIReceive
@ -77,19 +86,61 @@ class MockTransport:
self._websocket_connection = WebSocketConnection(send, receive)
return self._websocket_connection
def add_task(self):
def add_task(self) -> None:
raise NotImplementedError
async def send(self, data):
print(">> sending. more:", data.get("more_body"))
async def send(self, data) -> None:
# TODO:
# - Validation on data and that it is formatted properly and is valid
await self._send(data)
async def receive(self):
async def receive(self) -> ASGIMessage:
return await self._receive()
class Lifespan:
def __init__(self, asgi_app: "ASGIApp") -> None:
self.asgi_app = asgi_app
async def startup(self) -> None:
if self.asgi_app.sanic_app.listeners["before_server_start"]:
warnings.warn(
'You have set a listener for "before_server_start". In ASGI mode it will be ignored. Perhaps you want to run it "after_server_start" instead?'
)
if self.asgi_app.sanic_app.listeners["after_server_stop"]:
warnings.warn(
'You have set a listener for "after_server_stop". In ASGI mode it will be ignored. Perhaps you want to run it "before_server_stop" instead?'
)
for handler in self.asgi_app.sanic_app.listeners["after_server_start"]:
response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if isawaitable(response):
await response
async def shutdown(self) -> None:
for handler in self.asgi_app.sanic_app.listeners["before_server_stop"]:
response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if isawaitable(response):
await response
async def __call__(
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
) -> None:
message = await receive()
if message["type"] == "lifespan.startup":
await self.startup()
await send({"type": "lifespan.startup.complete"})
message = await receive()
if message["type"] == "lifespan.shutdown":
await self.shutdown()
await send({"type": "lifespan.shutdown.complete"})
class ASGIApp:
def __init__(self) -> None:
self.ws = None
@ -104,42 +155,51 @@ class ASGIApp:
instance.transport.add_task = sanic_app.loop.create_task
instance.transport.loop = sanic_app.loop
url_bytes = scope.get("root_path", "") + scope["path"]
url_bytes = url_bytes.encode("latin-1")
url_bytes += scope["query_string"]
headers = CIMultiDict(
[
(key.decode("latin-1"), value.decode("latin-1"))
for key, value in scope.get("headers", [])
]
)
instance.do_stream = (
True if headers.get("expect") == "100-continue" else False
)
if scope["type"] == "http":
version = scope["http_version"]
method = scope["method"]
elif scope["type"] == "websocket":
version = "1.1"
method = "GET"
instance.ws = instance.transport.create_websocket_connection(
send, receive
)
await instance.ws.accept()
if scope["type"] == "lifespan":
lifespan = Lifespan(instance)
await lifespan(scope, receive, send)
else:
pass
# TODO:
# - close connection
url_bytes = scope.get("root_path", "") + scope["path"]
url_bytes = url_bytes.encode("latin-1")
url_bytes += scope["query_string"]
instance.request = Request(
url_bytes, headers, version, method, instance.transport, sanic_app
)
if scope["type"] == "http":
version = scope["http_version"]
method = scope["method"]
elif scope["type"] == "websocket":
version = "1.1"
method = "GET"
if sanic_app.is_request_stream:
instance.request.stream = StreamBuffer()
instance.ws = instance.transport.create_websocket_connection(
send, receive
)
await instance.ws.accept()
else:
pass
# TODO:
# - close connection
instance.request = Request(
url_bytes,
headers,
version,
method,
instance.transport,
sanic_app,
)
if sanic_app.is_request_stream:
instance.request.stream = StreamBuffer()
return instance

View File

@ -1,6 +1,6 @@
import typing
import types
import asyncio
import types
import typing
from json import JSONDecodeError
from socket import socket