Lifespan and code cleanup

This commit is contained in:
Adam Hopkins 2019-06-04 10:58:00 +03:00
parent aebe2b5809
commit 3685b4de85
16 changed files with 924 additions and 179 deletions

View File

@ -54,7 +54,7 @@ class Sanic:
logging.config.dictConfig(log_config or LOGGING_CONFIG_DEFAULTS)
self.name = name
self.asgi = True
self.asgi = False
self.router = router or Router()
self.request_class = request_class
self.error_handler = error_handler or ErrorHandler()
@ -1393,5 +1393,6 @@ class Sanic:
# -------------------------------------------------------------------- #
async def __call__(self, scope, receive, send):
self.asgi = True
asgi_app = await ASGIApp.create(self, scope, receive, send)
await asgi_app()

View File

@ -1,13 +1,15 @@
import asyncio
import warnings
from functools import partial
from http.cookies import SimpleCookie
from inspect import isawaitable
from typing import Any, Awaitable, Callable, MutableMapping, Union
from urllib.parse import quote
from multidict import CIMultiDict
from sanic.exceptions import InvalidUsage
from sanic.exceptions import InvalidUsage, ServerError
from sanic.log import logger
from sanic.request import Request
from sanic.response import HTTPResponse, StreamingHTTPResponse
from sanic.server import StreamBuffer
@ -102,16 +104,30 @@ class Lifespan:
def __init__(self, asgi_app: "ASGIApp") -> None:
self.asgi_app = asgi_app
async def startup(self) -> None:
if self.asgi_app.sanic_app.listeners["before_server_start"]:
if "before_server_start" in self.asgi_app.sanic_app.listeners:
warnings.warn(
'You have set a listener for "before_server_start". In ASGI mode it will be ignored. Perhaps you want to run it "after_server_start" instead?'
'You have set a listener for "before_server_start" in ASGI mode. '
"It will be executed as early as possible, but not before "
"the ASGI server is started."
)
if self.asgi_app.sanic_app.listeners["after_server_stop"]:
if "after_server_stop" in self.asgi_app.sanic_app.listeners:
warnings.warn(
'You have set a listener for "after_server_stop". In ASGI mode it will be ignored. Perhaps you want to run it "before_server_stop" instead?'
'You have set a listener for "after_server_stop" in ASGI mode. '
"It will be executed as late as possible, but not before "
"the ASGI server is stopped."
)
async def pre_startup(self) -> None:
for handler in self.asgi_app.sanic_app.listeners[
"before_server_start"
]:
response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if isawaitable(response):
await response
async def startup(self) -> None:
for handler in self.asgi_app.sanic_app.listeners["after_server_start"]:
response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
@ -127,6 +143,16 @@ class Lifespan:
if isawaitable(response):
await response
async def post_shutdown(self) -> None:
for handler in self.asgi_app.sanic_app.listeners[
"before_server_start"
]:
response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if isawaitable(response):
await response
async def __call__(
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
) -> None:
@ -164,14 +190,15 @@ class ASGIApp:
instance.do_stream = (
True if headers.get("expect") == "100-continue" else False
)
instance.lifespan = Lifespan(instance)
await instance.pre_startup()
if scope["type"] == "lifespan":
lifespan = Lifespan(instance)
await lifespan(scope, receive, send)
await instance.lifespan(scope, receive, send)
else:
url_bytes = scope.get("root_path", "") + scope["path"]
url_bytes = scope.get("root_path", "") + quote(scope["path"])
url_bytes = url_bytes.encode("latin-1")
url_bytes += scope["query_string"]
url_bytes += b"?" + scope["query_string"]
if scope["type"] == "http":
version = scope["http_version"]
@ -250,10 +277,28 @@ class ASGIApp:
Write the response.
"""
headers = [
(str(name).encode("latin-1"), str(value).encode("latin-1"))
for name, value in response.headers.items()
]
try:
headers = [
(str(name).encode("latin-1"), str(value).encode("latin-1"))
for name, value in response.headers.items()
# if name not in ("Set-Cookie",)
]
except AttributeError:
logger.error(
"Invalid response object for url %s, "
"Expected Type: HTTPResponse, Actual Type: %s",
self.request.url,
type(response),
)
exception = ServerError("Invalid response type")
response = self.sanic_app.error_handler.response(
self.request, exception
)
headers = [
(str(name).encode("latin-1"), str(value).encode("latin-1"))
for name, value in response.headers.items()
if name not in (b"Set-Cookie",)
]
if "content-length" not in response.headers and not isinstance(
response, StreamingHTTPResponse
@ -262,6 +307,14 @@ class ASGIApp:
(b"content-length", str(len(response.body)).encode("latin-1"))
]
if response.cookies:
cookies = SimpleCookie()
cookies.load(response.cookies)
headers += [
(b"set-cookie", cookie.encode("utf-8"))
for name, cookie in response.cookies.items()
]
await self.transport.send(
{
"type": "http.response.start",

View File

@ -406,6 +406,7 @@ class Router:
if not self.hosts:
return self._get(request.path, request.method, "")
# virtual hosts specified; try to match route to the host header
try:
return self._get(
request.path, request.method, request.headers.get("Host", "")

View File

@ -16,6 +16,7 @@ from sanic.log import logger
from sanic.response import text
ASGI_HOST = "mockserver"
HOST = "127.0.0.1"
PORT = 42101
@ -275,7 +276,7 @@ class SanicASGIAdapter(requests.asgi.ASGIAdapter):
body = message.get("body", b"")
more_body = message.get("more_body", False)
if request.method != "HEAD":
raw_kwargs["body"] += body
raw_kwargs["content"] += body
if not more_body:
response_complete = True
elif message["type"] == "http.response.template":
@ -285,7 +286,7 @@ class SanicASGIAdapter(requests.asgi.ASGIAdapter):
request_complete = False
response_started = False
response_complete = False
raw_kwargs = {"body": b""} # type: typing.Dict[str, typing.Any]
raw_kwargs = {"content": b""} # type: typing.Dict[str, typing.Any]
template = None
context = None
return_value = None
@ -327,11 +328,11 @@ class SanicASGITestClient(requests.ASGISession):
def __init__(
self,
app: "Sanic",
base_url: str = "http://mockserver",
base_url: str = "http://{}".format(ASGI_HOST),
suppress_exceptions: bool = False,
) -> None:
app.__class__.__call__ = app_call_with_return
app.asgi = True
super().__init__(app)
adapter = SanicASGIAdapter(
@ -343,12 +344,16 @@ class SanicASGITestClient(requests.ASGISession):
self.app = app
self.base_url = base_url
async def send(self, *args, **kwargs):
return await super().send(*args, **kwargs)
# async def send(self, prepared_request, *args, **kwargs):
# return await super().send(*args, **kwargs)
async def request(self, method, url, gather_request=True, *args, **kwargs):
self.gather_request = gather_request
print(url)
response = await super().request(method, url, *args, **kwargs)
response.status = response.status_code
response.body = response.content
response.content_type = response.headers.get("content-type")
if hasattr(response, "return_value"):
request = response.return_value
@ -361,124 +366,3 @@ class SanicASGITestClient(requests.ASGISession):
settings = super().merge_environment_settings(*args, **kwargs)
settings.update({"gather_return": self.gather_request})
return settings
# class SanicASGITestClient(requests.ASGISession):
# __test__ = False # For pytest to not discover this up.
# def __init__(
# self,
# app: "Sanic",
# base_url: str = "http://mockserver",
# suppress_exceptions: bool = False,
# ) -> None:
# app.testing = True
# super().__init__(
# app, base_url=base_url, suppress_exceptions=suppress_exceptions
# )
# # adapter = _ASGIAdapter(
# # app, raise_server_exceptions=raise_server_exceptions
# # )
# # self.mount("http://", adapter)
# # self.mount("https://", adapter)
# # self.mount("ws://", adapter)
# # self.mount("wss://", adapter)
# # self.headers.update({"user-agent": "testclient"})
# # self.base_url = base_url
# # def request(
# # self,
# # method: str,
# # url: str = "/",
# # params: typing.Any = None,
# # data: typing.Any = None,
# # headers: typing.MutableMapping[str, str] = None,
# # cookies: typing.Any = None,
# # files: typing.Any = None,
# # auth: typing.Any = None,
# # timeout: typing.Any = None,
# # allow_redirects: bool = None,
# # proxies: typing.MutableMapping[str, str] = None,
# # hooks: typing.Any = None,
# # stream: bool = None,
# # verify: typing.Union[bool, str] = None,
# # cert: typing.Union[str, typing.Tuple[str, str]] = None,
# # json: typing.Any = None,
# # debug=None,
# # gather_request=True,
# # ) -> requests.Response:
# # if debug is not None:
# # self.app.debug = debug
# # url = urljoin(self.base_url, url)
# # response = super().request(
# # method,
# # url,
# # params=params,
# # data=data,
# # headers=headers,
# # cookies=cookies,
# # files=files,
# # auth=auth,
# # timeout=timeout,
# # allow_redirects=allow_redirects,
# # proxies=proxies,
# # hooks=hooks,
# # stream=stream,
# # verify=verify,
# # cert=cert,
# # json=json,
# # )
# # response.status = response.status_code
# # response.body = response.content
# # try:
# # response.json = response.json()
# # except:
# # response.json = None
# # if gather_request:
# # request = response.request
# # parsed = urlparse(request.url)
# # request.scheme = parsed.scheme
# # request.path = parsed.path
# # request.args = parse_qs(parsed.query)
# # return request, response
# # return response
# # def get(self, *args, **kwargs):
# # if "uri" in kwargs:
# # kwargs["url"] = kwargs.pop("uri")
# # return self.request("get", *args, **kwargs)
# # def post(self, *args, **kwargs):
# # if "uri" in kwargs:
# # kwargs["url"] = kwargs.pop("uri")
# # return self.request("post", *args, **kwargs)
# # def put(self, *args, **kwargs):
# # if "uri" in kwargs:
# # kwargs["url"] = kwargs.pop("uri")
# # return self.request("put", *args, **kwargs)
# # def delete(self, *args, **kwargs):
# # if "uri" in kwargs:
# # kwargs["url"] = kwargs.pop("uri")
# # return self.request("delete", *args, **kwargs)
# # def patch(self, *args, **kwargs):
# # if "uri" in kwargs:
# # kwargs["url"] = kwargs.pop("uri")
# # return self.request("patch", *args, **kwargs)
# # def options(self, *args, **kwargs):
# # if "uri" in kwargs:
# # kwargs["url"] = kwargs.pop("uri")
# # return self.request("options", *args, **kwargs)
# # def head(self, *args, **kwargs):
# # return self._sanic_endpoint_test("head", *args, **kwargs)
# # def websocket(self, *args, **kwargs):
# # return self._sanic_endpoint_test("websocket", *args, **kwargs)

View File

@ -57,7 +57,7 @@ def test_asyncio_server_start_serving(app):
def test_app_loop_not_running(app):
with pytest.raises(SanicException) as excinfo:
_ = app.loop
app.loop
assert str(excinfo.value) == (
"Loop can only be retrieved after the app has started "

View File

@ -1,5 +1,5 @@
from sanic.testing import SanicASGITestClient
def asgi_client_instantiation(app):
def test_asgi_client_instantiation(app):
assert isinstance(app.asgi_client, SanicASGITestClient)

View File

@ -226,6 +226,7 @@ def test_config_access_log_passing_in_run(app):
assert app.config.ACCESS_LOG == True
@pytest.mark.asyncio
async def test_config_access_log_passing_in_create_server(app):
assert app.config.ACCESS_LOG == True

View File

@ -27,6 +27,24 @@ def test_cookies(app):
assert response_cookies["right_back"].value == "at you"
@pytest.mark.asyncio
async def test_cookies_asgi(app):
@app.route("/")
def handler(request):
response = text("Cookies are: {}".format(request.cookies["test"]))
response.cookies["right_back"] = "at you"
return response
request, response = await app.asgi_client.get(
"/", cookies={"test": "working!"}
)
response_cookies = SimpleCookie()
response_cookies.load(response.headers.get("set-cookie", {}))
assert response.text == "Cookies are: working!"
assert response_cookies["right_back"].value == "at you"
@pytest.mark.parametrize("httponly,expected", [(False, False), (True, True)])
def test_false_cookies_encoded(app, httponly, expected):
@app.route("/")

View File

@ -24,7 +24,9 @@ old_conn = None
class ReusableSanicConnectionPool(httpcore.ConnectionPool):
async def acquire_connection(self, origin):
global old_conn
connection = self.active_connections.pop_by_origin(origin, http2_only=True)
connection = self.active_connections.pop_by_origin(
origin, http2_only=True
)
if connection is None:
connection = self.keepalive_connections.pop_by_origin(origin)
@ -187,11 +189,7 @@ class ReuseableSanicTestClient(SanicTestClient):
self._session = ResusableSanicSession()
try:
response = await getattr(self._session, method.lower())(
url,
verify=False,
timeout=request_keepalive,
*args,
**kwargs,
url, verify=False, timeout=request_keepalive, *args, **kwargs
)
except NameError:
raise Exception(response.status_code)

View File

@ -110,21 +110,19 @@ def test_redirect_with_header_injection(redirect_app):
@pytest.mark.parametrize("test_str", ["sanic-test", "sanictest", "sanic test"])
async def test_redirect_with_params(app, test_client, test_str):
def test_redirect_with_params(app, test_str):
use_in_uri = quote(test_str)
@app.route("/api/v1/test/<test>/")
async def init_handler(request, test):
assert test == test_str
return redirect("/api/v2/test/{}/".format(quote(test)))
return redirect("/api/v2/test/{}/".format(use_in_uri))
@app.route("/api/v2/test/<test>/")
async def target_handler(request, test):
assert test == test_str
return text("OK")
test_cli = await test_client(app)
response = await test_cli.get("/api/v1/test/{}/".format(quote(test_str)))
_, response = app.test_client.get("/api/v1/test/{}/".format(use_in_uri))
assert response.status == 200
txt = await response.text()
assert txt == "OK"
assert response.content == b"OK"

View File

@ -1,10 +1,13 @@
import asyncio
import contextlib
import pytest
from sanic.response import stream, text
async def test_request_cancel_when_connection_lost(loop, app, test_client):
@pytest.mark.asyncio
async def test_request_cancel_when_connection_lost(app):
app.still_serving_cancelled_request = False
@app.get("/")
@ -14,10 +17,9 @@ async def test_request_cancel_when_connection_lost(loop, app, test_client):
app.still_serving_cancelled_request = True
return text("OK")
test_cli = await test_client(app)
# schedule client call
task = loop.create_task(test_cli.get("/"))
loop = asyncio.get_event_loop()
task = loop.create_task(app.asgi_client.get("/"))
loop.call_later(0.01, task)
await asyncio.sleep(0.5)
@ -33,7 +35,8 @@ async def test_request_cancel_when_connection_lost(loop, app, test_client):
assert app.still_serving_cancelled_request is False
async def test_stream_request_cancel_when_conn_lost(loop, app, test_client):
@pytest.mark.asyncio
async def test_stream_request_cancel_when_conn_lost(app):
app.still_serving_cancelled_request = False
@app.post("/post/<id>", stream=True)
@ -53,10 +56,9 @@ async def test_stream_request_cancel_when_conn_lost(loop, app, test_client):
return stream(streaming)
test_cli = await test_client(app)
# schedule client call
task = loop.create_task(test_cli.post("/post/1"))
loop = asyncio.get_event_loop()
task = loop.create_task(app.asgi_client.post("/post/1"))
loop.call_later(0.01, task)
await asyncio.sleep(0.5)

View File

@ -111,7 +111,6 @@ def test_request_stream_app(app):
result += body.decode("utf-8")
return text(result)
assert app.is_request_stream is True
request, response = app.test_client.get("/get")

View File

@ -13,15 +13,12 @@ class DelayableSanicConnectionPool(httpcore.ConnectionPool):
self._request_delay = request_delay
super().__init__(*args, **kwargs)
async def send(
self,
request,
stream=False,
ssl=None,
timeout=None,
):
async def send(self, request, stream=False, ssl=None, timeout=None):
connection = await self.acquire_connection(request.url.origin)
if connection.h11_connection is None and connection.h2_connection is None:
if (
connection.h11_connection is None
and connection.h2_connection is None
):
await connection.connect(ssl=ssl, timeout=timeout)
if self._request_delay:
await asyncio.sleep(self._request_delay)

File diff suppressed because it is too large Load Diff

View File

@ -292,7 +292,7 @@ def test_stream_response_writes_correct_content_to_transport_when_chunked(
async def mock_drain():
pass
def mock_push_data(data):
async def mock_push_data(data):
response.protocol.transport.write(data)
response.protocol.push_data = mock_push_data
@ -330,7 +330,7 @@ def test_stream_response_writes_correct_content_to_transport_when_not_chunked(
async def mock_drain():
pass
def mock_push_data(data):
async def mock_push_data(data):
response.protocol.transport.write(data)
response.protocol.push_data = mock_push_data

View File

@ -76,6 +76,7 @@ def test_all_listeners(app):
assert app.name + listener_name == output.pop()
@pytest.mark.asyncio
async def test_trigger_before_events_create_server(app):
class MySanicDb:
pass