Implement ASGI lifespan events to match Sanic listeners

This commit is contained in:
Adam Hopkins 2019-05-27 12:33:25 +03:00
parent 22c0d97783
commit 9172399b8c
3 changed files with 131 additions and 41 deletions

View File

@ -56,3 +56,33 @@ async def handler_stream(request):
body = body.decode("utf-8").replace("1", "A") body = body.decode("utf-8").replace("1", "A")
# await response.write(body) # await response.write(body)
return stream(streaming) return stream(streaming)
@app.listener("before_server_start")
async def listener_before_server_start(*args, **kwargs):
print("before_server_start")
@app.listener("after_server_start")
async def listener_after_server_start(*args, **kwargs):
print("after_server_start")
@app.listener("before_server_stop")
async def listener_before_server_stop(*args, **kwargs):
print("before_server_stop")
@app.listener("after_server_stop")
async def listener_after_server_stop(*args, **kwargs):
print("after_server_stop")
@app.middleware("request")
async def print_on_request(request):
print("print_on_request")
@app.middleware("response")
async def print_on_response(request, response):
print("print_on_response")

View File

@ -1,11 +1,18 @@
from typing import Any, Awaitable, Callable, MutableMapping, Union
import asyncio import asyncio
from multidict import CIMultiDict import warnings
from functools import partial from functools import partial
from inspect import isawaitable
from typing import Any, Awaitable, Callable, MutableMapping, Union
from multidict import CIMultiDict
from sanic.exceptions import InvalidUsage
from sanic.request import Request from sanic.request import Request
from sanic.response import HTTPResponse, StreamingHTTPResponse from sanic.response import HTTPResponse, StreamingHTTPResponse
from sanic.websocket import WebSocketConnection
from sanic.server import StreamBuffer from sanic.server import StreamBuffer
from sanic.websocket import WebSocketConnection
ASGIScope = MutableMapping[str, Any] ASGIScope = MutableMapping[str, Any]
ASGIMessage = MutableMapping[str, Any] ASGIMessage = MutableMapping[str, Any]
@ -20,30 +27,29 @@ class MockProtocol:
self._not_paused.set() self._not_paused.set()
self._complete = asyncio.Event(loop=loop) self._complete = asyncio.Event(loop=loop)
def pause_writing(self): def pause_writing(self) -> None:
self._not_paused.clear() self._not_paused.clear()
def resume_writing(self): def resume_writing(self) -> None:
self._not_paused.set() self._not_paused.set()
async def complete(self): async def complete(self) -> None:
self._not_paused.set() self._not_paused.set()
await self.transport.send( await self.transport.send(
{"type": "http.response.body", "body": b"", "more_body": False} {"type": "http.response.body", "body": b"", "more_body": False}
) )
@property @property
def is_complete(self): def is_complete(self) -> bool:
return self._complete.is_set() return self._complete.is_set()
async def push_data(self, data): async def push_data(self, data: bytes) -> None:
if not self.is_complete: if not self.is_complete:
await self.transport.send( await self.transport.send(
{"type": "http.response.body", "body": data, "more_body": True} {"type": "http.response.body", "body": data, "more_body": True}
) )
async def drain(self): async def drain(self) -> None:
print("draining")
await self._not_paused.wait() await self._not_paused.wait()
@ -57,7 +63,7 @@ class MockTransport:
self._protocol = None self._protocol = None
self.loop = None self.loop = None
def get_protocol(self): def get_protocol(self) -> MockProtocol:
if not self._protocol: if not self._protocol:
self._protocol = MockProtocol(self, self.loop) self._protocol = MockProtocol(self, self.loop)
return self._protocol return self._protocol
@ -69,7 +75,10 @@ class MockTransport:
return self.scope.get("scheme") in ["https", "wss"] return self.scope.get("scheme") in ["https", "wss"]
def get_websocket_connection(self) -> WebSocketConnection: def get_websocket_connection(self) -> WebSocketConnection:
return self._websocket_connection try:
return self._websocket_connection
except AttributeError:
raise InvalidUsage("Improper websocket connection.")
def create_websocket_connection( def create_websocket_connection(
self, send: ASGISend, receive: ASGIReceive self, send: ASGISend, receive: ASGIReceive
@ -77,19 +86,61 @@ class MockTransport:
self._websocket_connection = WebSocketConnection(send, receive) self._websocket_connection = WebSocketConnection(send, receive)
return self._websocket_connection return self._websocket_connection
def add_task(self): def add_task(self) -> None:
raise NotImplementedError raise NotImplementedError
async def send(self, data): async def send(self, data) -> None:
print(">> sending. more:", data.get("more_body"))
# TODO: # TODO:
# - Validation on data and that it is formatted properly and is valid # - Validation on data and that it is formatted properly and is valid
await self._send(data) await self._send(data)
async def receive(self): async def receive(self) -> ASGIMessage:
return await self._receive() return await self._receive()
class Lifespan:
def __init__(self, asgi_app: "ASGIApp") -> None:
self.asgi_app = asgi_app
async def startup(self) -> None:
if self.asgi_app.sanic_app.listeners["before_server_start"]:
warnings.warn(
'You have set a listener for "before_server_start". In ASGI mode it will be ignored. Perhaps you want to run it "after_server_start" instead?'
)
if self.asgi_app.sanic_app.listeners["after_server_stop"]:
warnings.warn(
'You have set a listener for "after_server_stop". In ASGI mode it will be ignored. Perhaps you want to run it "before_server_stop" instead?'
)
for handler in self.asgi_app.sanic_app.listeners["after_server_start"]:
response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if isawaitable(response):
await response
async def shutdown(self) -> None:
for handler in self.asgi_app.sanic_app.listeners["before_server_stop"]:
response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if isawaitable(response):
await response
async def __call__(
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
) -> None:
message = await receive()
if message["type"] == "lifespan.startup":
await self.startup()
await send({"type": "lifespan.startup.complete"})
message = await receive()
if message["type"] == "lifespan.shutdown":
await self.shutdown()
await send({"type": "lifespan.shutdown.complete"})
class ASGIApp: class ASGIApp:
def __init__(self) -> None: def __init__(self) -> None:
self.ws = None self.ws = None
@ -104,42 +155,51 @@ class ASGIApp:
instance.transport.add_task = sanic_app.loop.create_task instance.transport.add_task = sanic_app.loop.create_task
instance.transport.loop = sanic_app.loop instance.transport.loop = sanic_app.loop
url_bytes = scope.get("root_path", "") + scope["path"]
url_bytes = url_bytes.encode("latin-1")
url_bytes += scope["query_string"]
headers = CIMultiDict( headers = CIMultiDict(
[ [
(key.decode("latin-1"), value.decode("latin-1")) (key.decode("latin-1"), value.decode("latin-1"))
for key, value in scope.get("headers", []) for key, value in scope.get("headers", [])
] ]
) )
instance.do_stream = ( instance.do_stream = (
True if headers.get("expect") == "100-continue" else False True if headers.get("expect") == "100-continue" else False
) )
if scope["type"] == "http": if scope["type"] == "lifespan":
version = scope["http_version"] lifespan = Lifespan(instance)
method = scope["method"] await lifespan(scope, receive, send)
elif scope["type"] == "websocket":
version = "1.1"
method = "GET"
instance.ws = instance.transport.create_websocket_connection(
send, receive
)
await instance.ws.accept()
else: else:
pass url_bytes = scope.get("root_path", "") + scope["path"]
# TODO: url_bytes = url_bytes.encode("latin-1")
# - close connection url_bytes += scope["query_string"]
instance.request = Request( if scope["type"] == "http":
url_bytes, headers, version, method, instance.transport, sanic_app version = scope["http_version"]
) method = scope["method"]
elif scope["type"] == "websocket":
version = "1.1"
method = "GET"
if sanic_app.is_request_stream: instance.ws = instance.transport.create_websocket_connection(
instance.request.stream = StreamBuffer() send, receive
)
await instance.ws.accept()
else:
pass
# TODO:
# - close connection
instance.request = Request(
url_bytes,
headers,
version,
method,
instance.transport,
sanic_app,
)
if sanic_app.is_request_stream:
instance.request.stream = StreamBuffer()
return instance return instance

View File

@ -1,6 +1,6 @@
import typing
import types
import asyncio import asyncio
import types
import typing
from json import JSONDecodeError from json import JSONDecodeError
from socket import socket from socket import socket