Implement ASGI lifespan events to match Sanic listeners
This commit is contained in:
parent
22c0d97783
commit
9172399b8c
|
@ -56,3 +56,33 @@ async def handler_stream(request):
|
||||||
body = body.decode("utf-8").replace("1", "A")
|
body = body.decode("utf-8").replace("1", "A")
|
||||||
# await response.write(body)
|
# await response.write(body)
|
||||||
return stream(streaming)
|
return stream(streaming)
|
||||||
|
|
||||||
|
|
||||||
|
@app.listener("before_server_start")
|
||||||
|
async def listener_before_server_start(*args, **kwargs):
|
||||||
|
print("before_server_start")
|
||||||
|
|
||||||
|
|
||||||
|
@app.listener("after_server_start")
|
||||||
|
async def listener_after_server_start(*args, **kwargs):
|
||||||
|
print("after_server_start")
|
||||||
|
|
||||||
|
|
||||||
|
@app.listener("before_server_stop")
|
||||||
|
async def listener_before_server_stop(*args, **kwargs):
|
||||||
|
print("before_server_stop")
|
||||||
|
|
||||||
|
|
||||||
|
@app.listener("after_server_stop")
|
||||||
|
async def listener_after_server_stop(*args, **kwargs):
|
||||||
|
print("after_server_stop")
|
||||||
|
|
||||||
|
|
||||||
|
@app.middleware("request")
|
||||||
|
async def print_on_request(request):
|
||||||
|
print("print_on_request")
|
||||||
|
|
||||||
|
|
||||||
|
@app.middleware("response")
|
||||||
|
async def print_on_response(request, response):
|
||||||
|
print("print_on_response")
|
||||||
|
|
138
sanic/asgi.py
138
sanic/asgi.py
|
@ -1,11 +1,18 @@
|
||||||
from typing import Any, Awaitable, Callable, MutableMapping, Union
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from multidict import CIMultiDict
|
import warnings
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from inspect import isawaitable
|
||||||
|
from typing import Any, Awaitable, Callable, MutableMapping, Union
|
||||||
|
|
||||||
|
from multidict import CIMultiDict
|
||||||
|
|
||||||
|
from sanic.exceptions import InvalidUsage
|
||||||
from sanic.request import Request
|
from sanic.request import Request
|
||||||
from sanic.response import HTTPResponse, StreamingHTTPResponse
|
from sanic.response import HTTPResponse, StreamingHTTPResponse
|
||||||
from sanic.websocket import WebSocketConnection
|
|
||||||
from sanic.server import StreamBuffer
|
from sanic.server import StreamBuffer
|
||||||
|
from sanic.websocket import WebSocketConnection
|
||||||
|
|
||||||
|
|
||||||
ASGIScope = MutableMapping[str, Any]
|
ASGIScope = MutableMapping[str, Any]
|
||||||
ASGIMessage = MutableMapping[str, Any]
|
ASGIMessage = MutableMapping[str, Any]
|
||||||
|
@ -20,30 +27,29 @@ class MockProtocol:
|
||||||
self._not_paused.set()
|
self._not_paused.set()
|
||||||
self._complete = asyncio.Event(loop=loop)
|
self._complete = asyncio.Event(loop=loop)
|
||||||
|
|
||||||
def pause_writing(self):
|
def pause_writing(self) -> None:
|
||||||
self._not_paused.clear()
|
self._not_paused.clear()
|
||||||
|
|
||||||
def resume_writing(self):
|
def resume_writing(self) -> None:
|
||||||
self._not_paused.set()
|
self._not_paused.set()
|
||||||
|
|
||||||
async def complete(self):
|
async def complete(self) -> None:
|
||||||
self._not_paused.set()
|
self._not_paused.set()
|
||||||
await self.transport.send(
|
await self.transport.send(
|
||||||
{"type": "http.response.body", "body": b"", "more_body": False}
|
{"type": "http.response.body", "body": b"", "more_body": False}
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_complete(self):
|
def is_complete(self) -> bool:
|
||||||
return self._complete.is_set()
|
return self._complete.is_set()
|
||||||
|
|
||||||
async def push_data(self, data):
|
async def push_data(self, data: bytes) -> None:
|
||||||
if not self.is_complete:
|
if not self.is_complete:
|
||||||
await self.transport.send(
|
await self.transport.send(
|
||||||
{"type": "http.response.body", "body": data, "more_body": True}
|
{"type": "http.response.body", "body": data, "more_body": True}
|
||||||
)
|
)
|
||||||
|
|
||||||
async def drain(self):
|
async def drain(self) -> None:
|
||||||
print("draining")
|
|
||||||
await self._not_paused.wait()
|
await self._not_paused.wait()
|
||||||
|
|
||||||
|
|
||||||
|
@ -57,7 +63,7 @@ class MockTransport:
|
||||||
self._protocol = None
|
self._protocol = None
|
||||||
self.loop = None
|
self.loop = None
|
||||||
|
|
||||||
def get_protocol(self):
|
def get_protocol(self) -> MockProtocol:
|
||||||
if not self._protocol:
|
if not self._protocol:
|
||||||
self._protocol = MockProtocol(self, self.loop)
|
self._protocol = MockProtocol(self, self.loop)
|
||||||
return self._protocol
|
return self._protocol
|
||||||
|
@ -69,7 +75,10 @@ class MockTransport:
|
||||||
return self.scope.get("scheme") in ["https", "wss"]
|
return self.scope.get("scheme") in ["https", "wss"]
|
||||||
|
|
||||||
def get_websocket_connection(self) -> WebSocketConnection:
|
def get_websocket_connection(self) -> WebSocketConnection:
|
||||||
return self._websocket_connection
|
try:
|
||||||
|
return self._websocket_connection
|
||||||
|
except AttributeError:
|
||||||
|
raise InvalidUsage("Improper websocket connection.")
|
||||||
|
|
||||||
def create_websocket_connection(
|
def create_websocket_connection(
|
||||||
self, send: ASGISend, receive: ASGIReceive
|
self, send: ASGISend, receive: ASGIReceive
|
||||||
|
@ -77,19 +86,61 @@ class MockTransport:
|
||||||
self._websocket_connection = WebSocketConnection(send, receive)
|
self._websocket_connection = WebSocketConnection(send, receive)
|
||||||
return self._websocket_connection
|
return self._websocket_connection
|
||||||
|
|
||||||
def add_task(self):
|
def add_task(self) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def send(self, data):
|
async def send(self, data) -> None:
|
||||||
print(">> sending. more:", data.get("more_body"))
|
|
||||||
# TODO:
|
# TODO:
|
||||||
# - Validation on data and that it is formatted properly and is valid
|
# - Validation on data and that it is formatted properly and is valid
|
||||||
await self._send(data)
|
await self._send(data)
|
||||||
|
|
||||||
async def receive(self):
|
async def receive(self) -> ASGIMessage:
|
||||||
return await self._receive()
|
return await self._receive()
|
||||||
|
|
||||||
|
|
||||||
|
class Lifespan:
|
||||||
|
def __init__(self, asgi_app: "ASGIApp") -> None:
|
||||||
|
self.asgi_app = asgi_app
|
||||||
|
|
||||||
|
async def startup(self) -> None:
|
||||||
|
if self.asgi_app.sanic_app.listeners["before_server_start"]:
|
||||||
|
warnings.warn(
|
||||||
|
'You have set a listener for "before_server_start". In ASGI mode it will be ignored. Perhaps you want to run it "after_server_start" instead?'
|
||||||
|
)
|
||||||
|
if self.asgi_app.sanic_app.listeners["after_server_stop"]:
|
||||||
|
warnings.warn(
|
||||||
|
'You have set a listener for "after_server_stop". In ASGI mode it will be ignored. Perhaps you want to run it "before_server_stop" instead?'
|
||||||
|
)
|
||||||
|
|
||||||
|
for handler in self.asgi_app.sanic_app.listeners["after_server_start"]:
|
||||||
|
response = handler(
|
||||||
|
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
|
||||||
|
)
|
||||||
|
if isawaitable(response):
|
||||||
|
await response
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
for handler in self.asgi_app.sanic_app.listeners["before_server_stop"]:
|
||||||
|
response = handler(
|
||||||
|
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
|
||||||
|
)
|
||||||
|
if isawaitable(response):
|
||||||
|
await response
|
||||||
|
|
||||||
|
async def __call__(
|
||||||
|
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
|
||||||
|
) -> None:
|
||||||
|
message = await receive()
|
||||||
|
if message["type"] == "lifespan.startup":
|
||||||
|
await self.startup()
|
||||||
|
await send({"type": "lifespan.startup.complete"})
|
||||||
|
|
||||||
|
message = await receive()
|
||||||
|
if message["type"] == "lifespan.shutdown":
|
||||||
|
await self.shutdown()
|
||||||
|
await send({"type": "lifespan.shutdown.complete"})
|
||||||
|
|
||||||
|
|
||||||
class ASGIApp:
|
class ASGIApp:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.ws = None
|
self.ws = None
|
||||||
|
@ -104,42 +155,51 @@ class ASGIApp:
|
||||||
instance.transport.add_task = sanic_app.loop.create_task
|
instance.transport.add_task = sanic_app.loop.create_task
|
||||||
instance.transport.loop = sanic_app.loop
|
instance.transport.loop = sanic_app.loop
|
||||||
|
|
||||||
url_bytes = scope.get("root_path", "") + scope["path"]
|
|
||||||
url_bytes = url_bytes.encode("latin-1")
|
|
||||||
url_bytes += scope["query_string"]
|
|
||||||
headers = CIMultiDict(
|
headers = CIMultiDict(
|
||||||
[
|
[
|
||||||
(key.decode("latin-1"), value.decode("latin-1"))
|
(key.decode("latin-1"), value.decode("latin-1"))
|
||||||
for key, value in scope.get("headers", [])
|
for key, value in scope.get("headers", [])
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
instance.do_stream = (
|
instance.do_stream = (
|
||||||
True if headers.get("expect") == "100-continue" else False
|
True if headers.get("expect") == "100-continue" else False
|
||||||
)
|
)
|
||||||
|
|
||||||
if scope["type"] == "http":
|
if scope["type"] == "lifespan":
|
||||||
version = scope["http_version"]
|
lifespan = Lifespan(instance)
|
||||||
method = scope["method"]
|
await lifespan(scope, receive, send)
|
||||||
elif scope["type"] == "websocket":
|
|
||||||
version = "1.1"
|
|
||||||
method = "GET"
|
|
||||||
|
|
||||||
instance.ws = instance.transport.create_websocket_connection(
|
|
||||||
send, receive
|
|
||||||
)
|
|
||||||
await instance.ws.accept()
|
|
||||||
else:
|
else:
|
||||||
pass
|
url_bytes = scope.get("root_path", "") + scope["path"]
|
||||||
# TODO:
|
url_bytes = url_bytes.encode("latin-1")
|
||||||
# - close connection
|
url_bytes += scope["query_string"]
|
||||||
|
|
||||||
instance.request = Request(
|
if scope["type"] == "http":
|
||||||
url_bytes, headers, version, method, instance.transport, sanic_app
|
version = scope["http_version"]
|
||||||
)
|
method = scope["method"]
|
||||||
|
elif scope["type"] == "websocket":
|
||||||
|
version = "1.1"
|
||||||
|
method = "GET"
|
||||||
|
|
||||||
if sanic_app.is_request_stream:
|
instance.ws = instance.transport.create_websocket_connection(
|
||||||
instance.request.stream = StreamBuffer()
|
send, receive
|
||||||
|
)
|
||||||
|
await instance.ws.accept()
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
# TODO:
|
||||||
|
# - close connection
|
||||||
|
|
||||||
|
instance.request = Request(
|
||||||
|
url_bytes,
|
||||||
|
headers,
|
||||||
|
version,
|
||||||
|
method,
|
||||||
|
instance.transport,
|
||||||
|
sanic_app,
|
||||||
|
)
|
||||||
|
|
||||||
|
if sanic_app.is_request_stream:
|
||||||
|
instance.request.stream = StreamBuffer()
|
||||||
|
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import typing
|
|
||||||
import types
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import types
|
||||||
|
import typing
|
||||||
|
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from socket import socket
|
from socket import socket
|
||||||
|
|
Loading…
Reference in New Issue
Block a user