Lifespan and code cleanup

This commit is contained in:
Adam Hopkins
2019-06-04 10:58:00 +03:00
parent aebe2b5809
commit 3685b4de85
16 changed files with 924 additions and 179 deletions

View File

@@ -54,7 +54,7 @@ class Sanic:
logging.config.dictConfig(log_config or LOGGING_CONFIG_DEFAULTS)
self.name = name
self.asgi = True
self.asgi = False
self.router = router or Router()
self.request_class = request_class
self.error_handler = error_handler or ErrorHandler()
@@ -1393,5 +1393,6 @@ class Sanic:
# -------------------------------------------------------------------- #
async def __call__(self, scope, receive, send):
self.asgi = True
asgi_app = await ASGIApp.create(self, scope, receive, send)
await asgi_app()

View File

@@ -1,13 +1,15 @@
import asyncio
import warnings
from functools import partial
from http.cookies import SimpleCookie
from inspect import isawaitable
from typing import Any, Awaitable, Callable, MutableMapping, Union
from urllib.parse import quote
from multidict import CIMultiDict
from sanic.exceptions import InvalidUsage
from sanic.exceptions import InvalidUsage, ServerError
from sanic.log import logger
from sanic.request import Request
from sanic.response import HTTPResponse, StreamingHTTPResponse
from sanic.server import StreamBuffer
@@ -102,16 +104,30 @@ class Lifespan:
def __init__(self, asgi_app: "ASGIApp") -> None:
self.asgi_app = asgi_app
async def startup(self) -> None:
if self.asgi_app.sanic_app.listeners["before_server_start"]:
if "before_server_start" in self.asgi_app.sanic_app.listeners:
warnings.warn(
'You have set a listener for "before_server_start". In ASGI mode it will be ignored. Perhaps you want to run it "after_server_start" instead?'
'You have set a listener for "before_server_start" in ASGI mode. '
"It will be executed as early as possible, but not before "
"the ASGI server is started."
)
if self.asgi_app.sanic_app.listeners["after_server_stop"]:
if "after_server_stop" in self.asgi_app.sanic_app.listeners:
warnings.warn(
'You have set a listener for "after_server_stop". In ASGI mode it will be ignored. Perhaps you want to run it "before_server_stop" instead?'
'You have set a listener for "after_server_stop" in ASGI mode. '
"It will be executed as late as possible, but not before "
"the ASGI server is stopped."
)
async def pre_startup(self) -> None:
for handler in self.asgi_app.sanic_app.listeners[
"before_server_start"
]:
response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if isawaitable(response):
await response
async def startup(self) -> None:
for handler in self.asgi_app.sanic_app.listeners["after_server_start"]:
response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
@@ -127,6 +143,16 @@ class Lifespan:
if isawaitable(response):
await response
async def post_shutdown(self) -> None:
for handler in self.asgi_app.sanic_app.listeners[
"before_server_start"
]:
response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if isawaitable(response):
await response
async def __call__(
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
) -> None:
@@ -164,14 +190,15 @@ class ASGIApp:
instance.do_stream = (
True if headers.get("expect") == "100-continue" else False
)
instance.lifespan = Lifespan(instance)
await instance.pre_startup()
if scope["type"] == "lifespan":
lifespan = Lifespan(instance)
await lifespan(scope, receive, send)
await instance.lifespan(scope, receive, send)
else:
url_bytes = scope.get("root_path", "") + scope["path"]
url_bytes = scope.get("root_path", "") + quote(scope["path"])
url_bytes = url_bytes.encode("latin-1")
url_bytes += scope["query_string"]
url_bytes += b"?" + scope["query_string"]
if scope["type"] == "http":
version = scope["http_version"]
@@ -250,10 +277,28 @@ class ASGIApp:
Write the response.
"""
headers = [
(str(name).encode("latin-1"), str(value).encode("latin-1"))
for name, value in response.headers.items()
]
try:
headers = [
(str(name).encode("latin-1"), str(value).encode("latin-1"))
for name, value in response.headers.items()
# if name not in ("Set-Cookie",)
]
except AttributeError:
logger.error(
"Invalid response object for url %s, "
"Expected Type: HTTPResponse, Actual Type: %s",
self.request.url,
type(response),
)
exception = ServerError("Invalid response type")
response = self.sanic_app.error_handler.response(
self.request, exception
)
headers = [
(str(name).encode("latin-1"), str(value).encode("latin-1"))
for name, value in response.headers.items()
if name not in (b"Set-Cookie",)
]
if "content-length" not in response.headers and not isinstance(
response, StreamingHTTPResponse
@@ -262,6 +307,14 @@ class ASGIApp:
(b"content-length", str(len(response.body)).encode("latin-1"))
]
if response.cookies:
cookies = SimpleCookie()
cookies.load(response.cookies)
headers += [
(b"set-cookie", cookie.encode("utf-8"))
for name, cookie in response.cookies.items()
]
await self.transport.send(
{
"type": "http.response.start",

View File

@@ -406,6 +406,7 @@ class Router:
if not self.hosts:
return self._get(request.path, request.method, "")
# virtual hosts specified; try to match route to the host header
try:
return self._get(
request.path, request.method, request.headers.get("Host", "")

View File

@@ -16,6 +16,7 @@ from sanic.log import logger
from sanic.response import text
ASGI_HOST = "mockserver"
HOST = "127.0.0.1"
PORT = 42101
@@ -275,7 +276,7 @@ class SanicASGIAdapter(requests.asgi.ASGIAdapter):
body = message.get("body", b"")
more_body = message.get("more_body", False)
if request.method != "HEAD":
raw_kwargs["body"] += body
raw_kwargs["content"] += body
if not more_body:
response_complete = True
elif message["type"] == "http.response.template":
@@ -285,7 +286,7 @@ class SanicASGIAdapter(requests.asgi.ASGIAdapter):
request_complete = False
response_started = False
response_complete = False
raw_kwargs = {"body": b""} # type: typing.Dict[str, typing.Any]
raw_kwargs = {"content": b""} # type: typing.Dict[str, typing.Any]
template = None
context = None
return_value = None
@@ -327,11 +328,11 @@ class SanicASGITestClient(requests.ASGISession):
def __init__(
self,
app: "Sanic",
base_url: str = "http://mockserver",
base_url: str = "http://{}".format(ASGI_HOST),
suppress_exceptions: bool = False,
) -> None:
app.__class__.__call__ = app_call_with_return
app.asgi = True
super().__init__(app)
adapter = SanicASGIAdapter(
@@ -343,12 +344,16 @@ class SanicASGITestClient(requests.ASGISession):
self.app = app
self.base_url = base_url
async def send(self, *args, **kwargs):
return await super().send(*args, **kwargs)
# async def send(self, prepared_request, *args, **kwargs):
# return await super().send(*args, **kwargs)
async def request(self, method, url, gather_request=True, *args, **kwargs):
self.gather_request = gather_request
print(url)
response = await super().request(method, url, *args, **kwargs)
response.status = response.status_code
response.body = response.content
response.content_type = response.headers.get("content-type")
if hasattr(response, "return_value"):
request = response.return_value
@@ -361,124 +366,3 @@ class SanicASGITestClient(requests.ASGISession):
settings = super().merge_environment_settings(*args, **kwargs)
settings.update({"gather_return": self.gather_request})
return settings
# class SanicASGITestClient(requests.ASGISession):
# __test__ = False # For pytest to not discover this up.
# def __init__(
# self,
# app: "Sanic",
# base_url: str = "http://mockserver",
# suppress_exceptions: bool = False,
# ) -> None:
# app.testing = True
# super().__init__(
# app, base_url=base_url, suppress_exceptions=suppress_exceptions
# )
# # adapter = _ASGIAdapter(
# # app, raise_server_exceptions=raise_server_exceptions
# # )
# # self.mount("http://", adapter)
# # self.mount("https://", adapter)
# # self.mount("ws://", adapter)
# # self.mount("wss://", adapter)
# # self.headers.update({"user-agent": "testclient"})
# # self.base_url = base_url
# # def request(
# # self,
# # method: str,
# # url: str = "/",
# # params: typing.Any = None,
# # data: typing.Any = None,
# # headers: typing.MutableMapping[str, str] = None,
# # cookies: typing.Any = None,
# # files: typing.Any = None,
# # auth: typing.Any = None,
# # timeout: typing.Any = None,
# # allow_redirects: bool = None,
# # proxies: typing.MutableMapping[str, str] = None,
# # hooks: typing.Any = None,
# # stream: bool = None,
# # verify: typing.Union[bool, str] = None,
# # cert: typing.Union[str, typing.Tuple[str, str]] = None,
# # json: typing.Any = None,
# # debug=None,
# # gather_request=True,
# # ) -> requests.Response:
# # if debug is not None:
# # self.app.debug = debug
# # url = urljoin(self.base_url, url)
# # response = super().request(
# # method,
# # url,
# # params=params,
# # data=data,
# # headers=headers,
# # cookies=cookies,
# # files=files,
# # auth=auth,
# # timeout=timeout,
# # allow_redirects=allow_redirects,
# # proxies=proxies,
# # hooks=hooks,
# # stream=stream,
# # verify=verify,
# # cert=cert,
# # json=json,
# # )
# # response.status = response.status_code
# # response.body = response.content
# # try:
# # response.json = response.json()
# # except:
# # response.json = None
# # if gather_request:
# # request = response.request
# # parsed = urlparse(request.url)
# # request.scheme = parsed.scheme
# # request.path = parsed.path
# # request.args = parse_qs(parsed.query)
# # return request, response
# # return response
# # def get(self, *args, **kwargs):
# # if "uri" in kwargs:
# # kwargs["url"] = kwargs.pop("uri")
# # return self.request("get", *args, **kwargs)
# # def post(self, *args, **kwargs):
# # if "uri" in kwargs:
# # kwargs["url"] = kwargs.pop("uri")
# # return self.request("post", *args, **kwargs)
# # def put(self, *args, **kwargs):
# # if "uri" in kwargs:
# # kwargs["url"] = kwargs.pop("uri")
# # return self.request("put", *args, **kwargs)
# # def delete(self, *args, **kwargs):
# # if "uri" in kwargs:
# # kwargs["url"] = kwargs.pop("uri")
# # return self.request("delete", *args, **kwargs)
# # def patch(self, *args, **kwargs):
# # if "uri" in kwargs:
# # kwargs["url"] = kwargs.pop("uri")
# # return self.request("patch", *args, **kwargs)
# # def options(self, *args, **kwargs):
# # if "uri" in kwargs:
# # kwargs["url"] = kwargs.pop("uri")
# # return self.request("options", *args, **kwargs)
# # def head(self, *args, **kwargs):
# # return self._sanic_endpoint_test("head", *args, **kwargs)
# # def websocket(self, *args, **kwargs):
# # return self._sanic_endpoint_test("websocket", *args, **kwargs)