Lifespan and code cleanup

This commit is contained in:
Adam Hopkins 2019-06-04 10:58:00 +03:00
parent aebe2b5809
commit 3685b4de85
16 changed files with 924 additions and 179 deletions

View File

@ -54,7 +54,7 @@ class Sanic:
logging.config.dictConfig(log_config or LOGGING_CONFIG_DEFAULTS) logging.config.dictConfig(log_config or LOGGING_CONFIG_DEFAULTS)
self.name = name self.name = name
self.asgi = True self.asgi = False
self.router = router or Router() self.router = router or Router()
self.request_class = request_class self.request_class = request_class
self.error_handler = error_handler or ErrorHandler() self.error_handler = error_handler or ErrorHandler()
@ -1393,5 +1393,6 @@ class Sanic:
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
async def __call__(self, scope, receive, send): async def __call__(self, scope, receive, send):
self.asgi = True
asgi_app = await ASGIApp.create(self, scope, receive, send) asgi_app = await ASGIApp.create(self, scope, receive, send)
await asgi_app() await asgi_app()

View File

@ -1,13 +1,15 @@
import asyncio import asyncio
import warnings import warnings
from functools import partial from http.cookies import SimpleCookie
from inspect import isawaitable from inspect import isawaitable
from typing import Any, Awaitable, Callable, MutableMapping, Union from typing import Any, Awaitable, Callable, MutableMapping, Union
from urllib.parse import quote
from multidict import CIMultiDict from multidict import CIMultiDict
from sanic.exceptions import InvalidUsage from sanic.exceptions import InvalidUsage, ServerError
from sanic.log import logger
from sanic.request import Request from sanic.request import Request
from sanic.response import HTTPResponse, StreamingHTTPResponse from sanic.response import HTTPResponse, StreamingHTTPResponse
from sanic.server import StreamBuffer from sanic.server import StreamBuffer
@ -102,16 +104,30 @@ class Lifespan:
def __init__(self, asgi_app: "ASGIApp") -> None: def __init__(self, asgi_app: "ASGIApp") -> None:
self.asgi_app = asgi_app self.asgi_app = asgi_app
async def startup(self) -> None: if "before_server_start" in self.asgi_app.sanic_app.listeners:
if self.asgi_app.sanic_app.listeners["before_server_start"]:
warnings.warn( warnings.warn(
'You have set a listener for "before_server_start". In ASGI mode it will be ignored. Perhaps you want to run it "after_server_start" instead?' 'You have set a listener for "before_server_start" in ASGI mode. '
"It will be executed as early as possible, but not before "
"the ASGI server is started."
) )
if self.asgi_app.sanic_app.listeners["after_server_stop"]: if "after_server_stop" in self.asgi_app.sanic_app.listeners:
warnings.warn( warnings.warn(
'You have set a listener for "after_server_stop". In ASGI mode it will be ignored. Perhaps you want to run it "before_server_stop" instead?' 'You have set a listener for "after_server_stop" in ASGI mode. '
"It will be executed as late as possible, but not before "
"the ASGI server is stopped."
) )
async def pre_startup(self) -> None:
for handler in self.asgi_app.sanic_app.listeners[
"before_server_start"
]:
response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if isawaitable(response):
await response
async def startup(self) -> None:
for handler in self.asgi_app.sanic_app.listeners["after_server_start"]: for handler in self.asgi_app.sanic_app.listeners["after_server_start"]:
response = handler( response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
@ -127,6 +143,16 @@ class Lifespan:
if isawaitable(response): if isawaitable(response):
await response await response
async def post_shutdown(self) -> None:
for handler in self.asgi_app.sanic_app.listeners[
"before_server_start"
]:
response = handler(
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
)
if isawaitable(response):
await response
async def __call__( async def __call__(
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
) -> None: ) -> None:
@ -164,14 +190,15 @@ class ASGIApp:
instance.do_stream = ( instance.do_stream = (
True if headers.get("expect") == "100-continue" else False True if headers.get("expect") == "100-continue" else False
) )
instance.lifespan = Lifespan(instance)
await instance.pre_startup()
if scope["type"] == "lifespan": if scope["type"] == "lifespan":
lifespan = Lifespan(instance) await instance.lifespan(scope, receive, send)
await lifespan(scope, receive, send)
else: else:
url_bytes = scope.get("root_path", "") + scope["path"] url_bytes = scope.get("root_path", "") + quote(scope["path"])
url_bytes = url_bytes.encode("latin-1") url_bytes = url_bytes.encode("latin-1")
url_bytes += scope["query_string"] url_bytes += b"?" + scope["query_string"]
if scope["type"] == "http": if scope["type"] == "http":
version = scope["http_version"] version = scope["http_version"]
@ -250,10 +277,28 @@ class ASGIApp:
Write the response. Write the response.
""" """
headers = [ try:
(str(name).encode("latin-1"), str(value).encode("latin-1")) headers = [
for name, value in response.headers.items() (str(name).encode("latin-1"), str(value).encode("latin-1"))
] for name, value in response.headers.items()
# if name not in ("Set-Cookie",)
]
except AttributeError:
logger.error(
"Invalid response object for url %s, "
"Expected Type: HTTPResponse, Actual Type: %s",
self.request.url,
type(response),
)
exception = ServerError("Invalid response type")
response = self.sanic_app.error_handler.response(
self.request, exception
)
headers = [
(str(name).encode("latin-1"), str(value).encode("latin-1"))
for name, value in response.headers.items()
if name not in (b"Set-Cookie",)
]
if "content-length" not in response.headers and not isinstance( if "content-length" not in response.headers and not isinstance(
response, StreamingHTTPResponse response, StreamingHTTPResponse
@ -262,6 +307,14 @@ class ASGIApp:
(b"content-length", str(len(response.body)).encode("latin-1")) (b"content-length", str(len(response.body)).encode("latin-1"))
] ]
if response.cookies:
cookies = SimpleCookie()
cookies.load(response.cookies)
headers += [
(b"set-cookie", cookie.encode("utf-8"))
for name, cookie in response.cookies.items()
]
await self.transport.send( await self.transport.send(
{ {
"type": "http.response.start", "type": "http.response.start",

View File

@ -406,6 +406,7 @@ class Router:
if not self.hosts: if not self.hosts:
return self._get(request.path, request.method, "") return self._get(request.path, request.method, "")
# virtual hosts specified; try to match route to the host header # virtual hosts specified; try to match route to the host header
try: try:
return self._get( return self._get(
request.path, request.method, request.headers.get("Host", "") request.path, request.method, request.headers.get("Host", "")

View File

@ -16,6 +16,7 @@ from sanic.log import logger
from sanic.response import text from sanic.response import text
ASGI_HOST = "mockserver"
HOST = "127.0.0.1" HOST = "127.0.0.1"
PORT = 42101 PORT = 42101
@ -275,7 +276,7 @@ class SanicASGIAdapter(requests.asgi.ASGIAdapter):
body = message.get("body", b"") body = message.get("body", b"")
more_body = message.get("more_body", False) more_body = message.get("more_body", False)
if request.method != "HEAD": if request.method != "HEAD":
raw_kwargs["body"] += body raw_kwargs["content"] += body
if not more_body: if not more_body:
response_complete = True response_complete = True
elif message["type"] == "http.response.template": elif message["type"] == "http.response.template":
@ -285,7 +286,7 @@ class SanicASGIAdapter(requests.asgi.ASGIAdapter):
request_complete = False request_complete = False
response_started = False response_started = False
response_complete = False response_complete = False
raw_kwargs = {"body": b""} # type: typing.Dict[str, typing.Any] raw_kwargs = {"content": b""} # type: typing.Dict[str, typing.Any]
template = None template = None
context = None context = None
return_value = None return_value = None
@ -327,11 +328,11 @@ class SanicASGITestClient(requests.ASGISession):
def __init__( def __init__(
self, self,
app: "Sanic", app: "Sanic",
base_url: str = "http://mockserver", base_url: str = "http://{}".format(ASGI_HOST),
suppress_exceptions: bool = False, suppress_exceptions: bool = False,
) -> None: ) -> None:
app.__class__.__call__ = app_call_with_return app.__class__.__call__ = app_call_with_return
app.asgi = True
super().__init__(app) super().__init__(app)
adapter = SanicASGIAdapter( adapter = SanicASGIAdapter(
@ -343,12 +344,16 @@ class SanicASGITestClient(requests.ASGISession):
self.app = app self.app = app
self.base_url = base_url self.base_url = base_url
async def send(self, *args, **kwargs): # async def send(self, prepared_request, *args, **kwargs):
return await super().send(*args, **kwargs) # return await super().send(*args, **kwargs)
async def request(self, method, url, gather_request=True, *args, **kwargs): async def request(self, method, url, gather_request=True, *args, **kwargs):
self.gather_request = gather_request self.gather_request = gather_request
print(url)
response = await super().request(method, url, *args, **kwargs) response = await super().request(method, url, *args, **kwargs)
response.status = response.status_code
response.body = response.content
response.content_type = response.headers.get("content-type")
if hasattr(response, "return_value"): if hasattr(response, "return_value"):
request = response.return_value request = response.return_value
@ -361,124 +366,3 @@ class SanicASGITestClient(requests.ASGISession):
settings = super().merge_environment_settings(*args, **kwargs) settings = super().merge_environment_settings(*args, **kwargs)
settings.update({"gather_return": self.gather_request}) settings.update({"gather_return": self.gather_request})
return settings return settings
# class SanicASGITestClient(requests.ASGISession):
# __test__ = False # For pytest to not discover this up.
# def __init__(
# self,
# app: "Sanic",
# base_url: str = "http://mockserver",
# suppress_exceptions: bool = False,
# ) -> None:
# app.testing = True
# super().__init__(
# app, base_url=base_url, suppress_exceptions=suppress_exceptions
# )
# # adapter = _ASGIAdapter(
# # app, raise_server_exceptions=raise_server_exceptions
# # )
# # self.mount("http://", adapter)
# # self.mount("https://", adapter)
# # self.mount("ws://", adapter)
# # self.mount("wss://", adapter)
# # self.headers.update({"user-agent": "testclient"})
# # self.base_url = base_url
# # def request(
# # self,
# # method: str,
# # url: str = "/",
# # params: typing.Any = None,
# # data: typing.Any = None,
# # headers: typing.MutableMapping[str, str] = None,
# # cookies: typing.Any = None,
# # files: typing.Any = None,
# # auth: typing.Any = None,
# # timeout: typing.Any = None,
# # allow_redirects: bool = None,
# # proxies: typing.MutableMapping[str, str] = None,
# # hooks: typing.Any = None,
# # stream: bool = None,
# # verify: typing.Union[bool, str] = None,
# # cert: typing.Union[str, typing.Tuple[str, str]] = None,
# # json: typing.Any = None,
# # debug=None,
# # gather_request=True,
# # ) -> requests.Response:
# # if debug is not None:
# # self.app.debug = debug
# # url = urljoin(self.base_url, url)
# # response = super().request(
# # method,
# # url,
# # params=params,
# # data=data,
# # headers=headers,
# # cookies=cookies,
# # files=files,
# # auth=auth,
# # timeout=timeout,
# # allow_redirects=allow_redirects,
# # proxies=proxies,
# # hooks=hooks,
# # stream=stream,
# # verify=verify,
# # cert=cert,
# # json=json,
# # )
# # response.status = response.status_code
# # response.body = response.content
# # try:
# # response.json = response.json()
# # except:
# # response.json = None
# # if gather_request:
# # request = response.request
# # parsed = urlparse(request.url)
# # request.scheme = parsed.scheme
# # request.path = parsed.path
# # request.args = parse_qs(parsed.query)
# # return request, response
# # return response
# # def get(self, *args, **kwargs):
# # if "uri" in kwargs:
# # kwargs["url"] = kwargs.pop("uri")
# # return self.request("get", *args, **kwargs)
# # def post(self, *args, **kwargs):
# # if "uri" in kwargs:
# # kwargs["url"] = kwargs.pop("uri")
# # return self.request("post", *args, **kwargs)
# # def put(self, *args, **kwargs):
# # if "uri" in kwargs:
# # kwargs["url"] = kwargs.pop("uri")
# # return self.request("put", *args, **kwargs)
# # def delete(self, *args, **kwargs):
# # if "uri" in kwargs:
# # kwargs["url"] = kwargs.pop("uri")
# # return self.request("delete", *args, **kwargs)
# # def patch(self, *args, **kwargs):
# # if "uri" in kwargs:
# # kwargs["url"] = kwargs.pop("uri")
# # return self.request("patch", *args, **kwargs)
# # def options(self, *args, **kwargs):
# # if "uri" in kwargs:
# # kwargs["url"] = kwargs.pop("uri")
# # return self.request("options", *args, **kwargs)
# # def head(self, *args, **kwargs):
# # return self._sanic_endpoint_test("head", *args, **kwargs)
# # def websocket(self, *args, **kwargs):
# # return self._sanic_endpoint_test("websocket", *args, **kwargs)

View File

@ -57,7 +57,7 @@ def test_asyncio_server_start_serving(app):
def test_app_loop_not_running(app): def test_app_loop_not_running(app):
with pytest.raises(SanicException) as excinfo: with pytest.raises(SanicException) as excinfo:
_ = app.loop app.loop
assert str(excinfo.value) == ( assert str(excinfo.value) == (
"Loop can only be retrieved after the app has started " "Loop can only be retrieved after the app has started "

View File

@ -1,5 +1,5 @@
from sanic.testing import SanicASGITestClient from sanic.testing import SanicASGITestClient
def asgi_client_instantiation(app): def test_asgi_client_instantiation(app):
assert isinstance(app.asgi_client, SanicASGITestClient) assert isinstance(app.asgi_client, SanicASGITestClient)

View File

@ -226,6 +226,7 @@ def test_config_access_log_passing_in_run(app):
assert app.config.ACCESS_LOG == True assert app.config.ACCESS_LOG == True
@pytest.mark.asyncio
async def test_config_access_log_passing_in_create_server(app): async def test_config_access_log_passing_in_create_server(app):
assert app.config.ACCESS_LOG == True assert app.config.ACCESS_LOG == True

View File

@ -27,6 +27,24 @@ def test_cookies(app):
assert response_cookies["right_back"].value == "at you" assert response_cookies["right_back"].value == "at you"
@pytest.mark.asyncio
async def test_cookies_asgi(app):
@app.route("/")
def handler(request):
response = text("Cookies are: {}".format(request.cookies["test"]))
response.cookies["right_back"] = "at you"
return response
request, response = await app.asgi_client.get(
"/", cookies={"test": "working!"}
)
response_cookies = SimpleCookie()
response_cookies.load(response.headers.get("set-cookie", {}))
assert response.text == "Cookies are: working!"
assert response_cookies["right_back"].value == "at you"
@pytest.mark.parametrize("httponly,expected", [(False, False), (True, True)]) @pytest.mark.parametrize("httponly,expected", [(False, False), (True, True)])
def test_false_cookies_encoded(app, httponly, expected): def test_false_cookies_encoded(app, httponly, expected):
@app.route("/") @app.route("/")

View File

@ -24,7 +24,9 @@ old_conn = None
class ReusableSanicConnectionPool(httpcore.ConnectionPool): class ReusableSanicConnectionPool(httpcore.ConnectionPool):
async def acquire_connection(self, origin): async def acquire_connection(self, origin):
global old_conn global old_conn
connection = self.active_connections.pop_by_origin(origin, http2_only=True) connection = self.active_connections.pop_by_origin(
origin, http2_only=True
)
if connection is None: if connection is None:
connection = self.keepalive_connections.pop_by_origin(origin) connection = self.keepalive_connections.pop_by_origin(origin)
@ -187,11 +189,7 @@ class ReuseableSanicTestClient(SanicTestClient):
self._session = ResusableSanicSession() self._session = ResusableSanicSession()
try: try:
response = await getattr(self._session, method.lower())( response = await getattr(self._session, method.lower())(
url, url, verify=False, timeout=request_keepalive, *args, **kwargs
verify=False,
timeout=request_keepalive,
*args,
**kwargs,
) )
except NameError: except NameError:
raise Exception(response.status_code) raise Exception(response.status_code)

View File

@ -110,21 +110,19 @@ def test_redirect_with_header_injection(redirect_app):
@pytest.mark.parametrize("test_str", ["sanic-test", "sanictest", "sanic test"]) @pytest.mark.parametrize("test_str", ["sanic-test", "sanictest", "sanic test"])
async def test_redirect_with_params(app, test_client, test_str): def test_redirect_with_params(app, test_str):
use_in_uri = quote(test_str)
@app.route("/api/v1/test/<test>/") @app.route("/api/v1/test/<test>/")
async def init_handler(request, test): async def init_handler(request, test):
assert test == test_str return redirect("/api/v2/test/{}/".format(use_in_uri))
return redirect("/api/v2/test/{}/".format(quote(test)))
@app.route("/api/v2/test/<test>/") @app.route("/api/v2/test/<test>/")
async def target_handler(request, test): async def target_handler(request, test):
assert test == test_str assert test == test_str
return text("OK") return text("OK")
test_cli = await test_client(app) _, response = app.test_client.get("/api/v1/test/{}/".format(use_in_uri))
response = await test_cli.get("/api/v1/test/{}/".format(quote(test_str)))
assert response.status == 200 assert response.status == 200
txt = await response.text() assert response.content == b"OK"
assert txt == "OK"

View File

@ -1,10 +1,13 @@
import asyncio import asyncio
import contextlib import contextlib
import pytest
from sanic.response import stream, text from sanic.response import stream, text
async def test_request_cancel_when_connection_lost(loop, app, test_client): @pytest.mark.asyncio
async def test_request_cancel_when_connection_lost(app):
app.still_serving_cancelled_request = False app.still_serving_cancelled_request = False
@app.get("/") @app.get("/")
@ -14,10 +17,9 @@ async def test_request_cancel_when_connection_lost(loop, app, test_client):
app.still_serving_cancelled_request = True app.still_serving_cancelled_request = True
return text("OK") return text("OK")
test_cli = await test_client(app)
# schedule client call # schedule client call
task = loop.create_task(test_cli.get("/")) loop = asyncio.get_event_loop()
task = loop.create_task(app.asgi_client.get("/"))
loop.call_later(0.01, task) loop.call_later(0.01, task)
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
@ -33,7 +35,8 @@ async def test_request_cancel_when_connection_lost(loop, app, test_client):
assert app.still_serving_cancelled_request is False assert app.still_serving_cancelled_request is False
async def test_stream_request_cancel_when_conn_lost(loop, app, test_client): @pytest.mark.asyncio
async def test_stream_request_cancel_when_conn_lost(app):
app.still_serving_cancelled_request = False app.still_serving_cancelled_request = False
@app.post("/post/<id>", stream=True) @app.post("/post/<id>", stream=True)
@ -53,10 +56,9 @@ async def test_stream_request_cancel_when_conn_lost(loop, app, test_client):
return stream(streaming) return stream(streaming)
test_cli = await test_client(app)
# schedule client call # schedule client call
task = loop.create_task(test_cli.post("/post/1")) loop = asyncio.get_event_loop()
task = loop.create_task(app.asgi_client.post("/post/1"))
loop.call_later(0.01, task) loop.call_later(0.01, task)
await asyncio.sleep(0.5) await asyncio.sleep(0.5)

View File

@ -111,7 +111,6 @@ def test_request_stream_app(app):
result += body.decode("utf-8") result += body.decode("utf-8")
return text(result) return text(result)
assert app.is_request_stream is True assert app.is_request_stream is True
request, response = app.test_client.get("/get") request, response = app.test_client.get("/get")

View File

@ -13,15 +13,12 @@ class DelayableSanicConnectionPool(httpcore.ConnectionPool):
self._request_delay = request_delay self._request_delay = request_delay
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
async def send( async def send(self, request, stream=False, ssl=None, timeout=None):
self,
request,
stream=False,
ssl=None,
timeout=None,
):
connection = await self.acquire_connection(request.url.origin) connection = await self.acquire_connection(request.url.origin)
if connection.h11_connection is None and connection.h2_connection is None: if (
connection.h11_connection is None
and connection.h2_connection is None
):
await connection.connect(ssl=ssl, timeout=timeout) await connection.connect(ssl=ssl, timeout=timeout)
if self._request_delay: if self._request_delay:
await asyncio.sleep(self._request_delay) await asyncio.sleep(self._request_delay)

File diff suppressed because it is too large Load Diff

View File

@ -292,7 +292,7 @@ def test_stream_response_writes_correct_content_to_transport_when_chunked(
async def mock_drain(): async def mock_drain():
pass pass
def mock_push_data(data): async def mock_push_data(data):
response.protocol.transport.write(data) response.protocol.transport.write(data)
response.protocol.push_data = mock_push_data response.protocol.push_data = mock_push_data
@ -330,7 +330,7 @@ def test_stream_response_writes_correct_content_to_transport_when_not_chunked(
async def mock_drain(): async def mock_drain():
pass pass
def mock_push_data(data): async def mock_push_data(data):
response.protocol.transport.write(data) response.protocol.transport.write(data)
response.protocol.push_data = mock_push_data response.protocol.push_data = mock_push_data

View File

@ -76,6 +76,7 @@ def test_all_listeners(app):
assert app.name + listener_name == output.pop() assert app.name + listener_name == output.pop()
@pytest.mark.asyncio
async def test_trigger_before_events_create_server(app): async def test_trigger_before_events_create_server(app):
class MySanicDb: class MySanicDb:
pass pass