Lifespan and code cleanup
This commit is contained in:
parent
aebe2b5809
commit
3685b4de85
|
@ -54,7 +54,7 @@ class Sanic:
|
|||
logging.config.dictConfig(log_config or LOGGING_CONFIG_DEFAULTS)
|
||||
|
||||
self.name = name
|
||||
self.asgi = True
|
||||
self.asgi = False
|
||||
self.router = router or Router()
|
||||
self.request_class = request_class
|
||||
self.error_handler = error_handler or ErrorHandler()
|
||||
|
@ -1393,5 +1393,6 @@ class Sanic:
|
|||
# -------------------------------------------------------------------- #
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
self.asgi = True
|
||||
asgi_app = await ASGIApp.create(self, scope, receive, send)
|
||||
await asgi_app()
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
import asyncio
|
||||
import warnings
|
||||
|
||||
from functools import partial
|
||||
from http.cookies import SimpleCookie
|
||||
from inspect import isawaitable
|
||||
from typing import Any, Awaitable, Callable, MutableMapping, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
from multidict import CIMultiDict
|
||||
|
||||
from sanic.exceptions import InvalidUsage
|
||||
from sanic.exceptions import InvalidUsage, ServerError
|
||||
from sanic.log import logger
|
||||
from sanic.request import Request
|
||||
from sanic.response import HTTPResponse, StreamingHTTPResponse
|
||||
from sanic.server import StreamBuffer
|
||||
|
@ -102,16 +104,30 @@ class Lifespan:
|
|||
def __init__(self, asgi_app: "ASGIApp") -> None:
|
||||
self.asgi_app = asgi_app
|
||||
|
||||
async def startup(self) -> None:
|
||||
if self.asgi_app.sanic_app.listeners["before_server_start"]:
|
||||
if "before_server_start" in self.asgi_app.sanic_app.listeners:
|
||||
warnings.warn(
|
||||
'You have set a listener for "before_server_start". In ASGI mode it will be ignored. Perhaps you want to run it "after_server_start" instead?'
|
||||
'You have set a listener for "before_server_start" in ASGI mode. '
|
||||
"It will be executed as early as possible, but not before "
|
||||
"the ASGI server is started."
|
||||
)
|
||||
if self.asgi_app.sanic_app.listeners["after_server_stop"]:
|
||||
if "after_server_stop" in self.asgi_app.sanic_app.listeners:
|
||||
warnings.warn(
|
||||
'You have set a listener for "after_server_stop". In ASGI mode it will be ignored. Perhaps you want to run it "before_server_stop" instead?'
|
||||
'You have set a listener for "after_server_stop" in ASGI mode. '
|
||||
"It will be executed as late as possible, but not before "
|
||||
"the ASGI server is stopped."
|
||||
)
|
||||
|
||||
async def pre_startup(self) -> None:
|
||||
for handler in self.asgi_app.sanic_app.listeners[
|
||||
"before_server_start"
|
||||
]:
|
||||
response = handler(
|
||||
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
|
||||
)
|
||||
if isawaitable(response):
|
||||
await response
|
||||
|
||||
async def startup(self) -> None:
|
||||
for handler in self.asgi_app.sanic_app.listeners["after_server_start"]:
|
||||
response = handler(
|
||||
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
|
||||
|
@ -127,6 +143,16 @@ class Lifespan:
|
|||
if isawaitable(response):
|
||||
await response
|
||||
|
||||
async def post_shutdown(self) -> None:
|
||||
for handler in self.asgi_app.sanic_app.listeners[
|
||||
"before_server_start"
|
||||
]:
|
||||
response = handler(
|
||||
self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop
|
||||
)
|
||||
if isawaitable(response):
|
||||
await response
|
||||
|
||||
async def __call__(
|
||||
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
|
||||
) -> None:
|
||||
|
@ -164,14 +190,15 @@ class ASGIApp:
|
|||
instance.do_stream = (
|
||||
True if headers.get("expect") == "100-continue" else False
|
||||
)
|
||||
instance.lifespan = Lifespan(instance)
|
||||
await instance.pre_startup()
|
||||
|
||||
if scope["type"] == "lifespan":
|
||||
lifespan = Lifespan(instance)
|
||||
await lifespan(scope, receive, send)
|
||||
await instance.lifespan(scope, receive, send)
|
||||
else:
|
||||
url_bytes = scope.get("root_path", "") + scope["path"]
|
||||
url_bytes = scope.get("root_path", "") + quote(scope["path"])
|
||||
url_bytes = url_bytes.encode("latin-1")
|
||||
url_bytes += scope["query_string"]
|
||||
url_bytes += b"?" + scope["query_string"]
|
||||
|
||||
if scope["type"] == "http":
|
||||
version = scope["http_version"]
|
||||
|
@ -250,10 +277,28 @@ class ASGIApp:
|
|||
Write the response.
|
||||
"""
|
||||
|
||||
headers = [
|
||||
(str(name).encode("latin-1"), str(value).encode("latin-1"))
|
||||
for name, value in response.headers.items()
|
||||
]
|
||||
try:
|
||||
headers = [
|
||||
(str(name).encode("latin-1"), str(value).encode("latin-1"))
|
||||
for name, value in response.headers.items()
|
||||
# if name not in ("Set-Cookie",)
|
||||
]
|
||||
except AttributeError:
|
||||
logger.error(
|
||||
"Invalid response object for url %s, "
|
||||
"Expected Type: HTTPResponse, Actual Type: %s",
|
||||
self.request.url,
|
||||
type(response),
|
||||
)
|
||||
exception = ServerError("Invalid response type")
|
||||
response = self.sanic_app.error_handler.response(
|
||||
self.request, exception
|
||||
)
|
||||
headers = [
|
||||
(str(name).encode("latin-1"), str(value).encode("latin-1"))
|
||||
for name, value in response.headers.items()
|
||||
if name not in (b"Set-Cookie",)
|
||||
]
|
||||
|
||||
if "content-length" not in response.headers and not isinstance(
|
||||
response, StreamingHTTPResponse
|
||||
|
@ -262,6 +307,14 @@ class ASGIApp:
|
|||
(b"content-length", str(len(response.body)).encode("latin-1"))
|
||||
]
|
||||
|
||||
if response.cookies:
|
||||
cookies = SimpleCookie()
|
||||
cookies.load(response.cookies)
|
||||
headers += [
|
||||
(b"set-cookie", cookie.encode("utf-8"))
|
||||
for name, cookie in response.cookies.items()
|
||||
]
|
||||
|
||||
await self.transport.send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
|
|
|
@ -406,6 +406,7 @@ class Router:
|
|||
if not self.hosts:
|
||||
return self._get(request.path, request.method, "")
|
||||
# virtual hosts specified; try to match route to the host header
|
||||
|
||||
try:
|
||||
return self._get(
|
||||
request.path, request.method, request.headers.get("Host", "")
|
||||
|
|
138
sanic/testing.py
138
sanic/testing.py
|
@ -16,6 +16,7 @@ from sanic.log import logger
|
|||
from sanic.response import text
|
||||
|
||||
|
||||
ASGI_HOST = "mockserver"
|
||||
HOST = "127.0.0.1"
|
||||
PORT = 42101
|
||||
|
||||
|
@ -275,7 +276,7 @@ class SanicASGIAdapter(requests.asgi.ASGIAdapter):
|
|||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
if request.method != "HEAD":
|
||||
raw_kwargs["body"] += body
|
||||
raw_kwargs["content"] += body
|
||||
if not more_body:
|
||||
response_complete = True
|
||||
elif message["type"] == "http.response.template":
|
||||
|
@ -285,7 +286,7 @@ class SanicASGIAdapter(requests.asgi.ASGIAdapter):
|
|||
request_complete = False
|
||||
response_started = False
|
||||
response_complete = False
|
||||
raw_kwargs = {"body": b""} # type: typing.Dict[str, typing.Any]
|
||||
raw_kwargs = {"content": b""} # type: typing.Dict[str, typing.Any]
|
||||
template = None
|
||||
context = None
|
||||
return_value = None
|
||||
|
@ -327,11 +328,11 @@ class SanicASGITestClient(requests.ASGISession):
|
|||
def __init__(
|
||||
self,
|
||||
app: "Sanic",
|
||||
base_url: str = "http://mockserver",
|
||||
base_url: str = "http://{}".format(ASGI_HOST),
|
||||
suppress_exceptions: bool = False,
|
||||
) -> None:
|
||||
app.__class__.__call__ = app_call_with_return
|
||||
|
||||
app.asgi = True
|
||||
super().__init__(app)
|
||||
|
||||
adapter = SanicASGIAdapter(
|
||||
|
@ -343,12 +344,16 @@ class SanicASGITestClient(requests.ASGISession):
|
|||
self.app = app
|
||||
self.base_url = base_url
|
||||
|
||||
async def send(self, *args, **kwargs):
|
||||
return await super().send(*args, **kwargs)
|
||||
# async def send(self, prepared_request, *args, **kwargs):
|
||||
# return await super().send(*args, **kwargs)
|
||||
|
||||
async def request(self, method, url, gather_request=True, *args, **kwargs):
|
||||
self.gather_request = gather_request
|
||||
print(url)
|
||||
response = await super().request(method, url, *args, **kwargs)
|
||||
response.status = response.status_code
|
||||
response.body = response.content
|
||||
response.content_type = response.headers.get("content-type")
|
||||
|
||||
if hasattr(response, "return_value"):
|
||||
request = response.return_value
|
||||
|
@ -361,124 +366,3 @@ class SanicASGITestClient(requests.ASGISession):
|
|||
settings = super().merge_environment_settings(*args, **kwargs)
|
||||
settings.update({"gather_return": self.gather_request})
|
||||
return settings
|
||||
|
||||
|
||||
# class SanicASGITestClient(requests.ASGISession):
|
||||
# __test__ = False # For pytest to not discover this up.
|
||||
|
||||
# def __init__(
|
||||
# self,
|
||||
# app: "Sanic",
|
||||
# base_url: str = "http://mockserver",
|
||||
# suppress_exceptions: bool = False,
|
||||
# ) -> None:
|
||||
# app.testing = True
|
||||
# super().__init__(
|
||||
# app, base_url=base_url, suppress_exceptions=suppress_exceptions
|
||||
# )
|
||||
# # adapter = _ASGIAdapter(
|
||||
# # app, raise_server_exceptions=raise_server_exceptions
|
||||
# # )
|
||||
# # self.mount("http://", adapter)
|
||||
# # self.mount("https://", adapter)
|
||||
# # self.mount("ws://", adapter)
|
||||
# # self.mount("wss://", adapter)
|
||||
# # self.headers.update({"user-agent": "testclient"})
|
||||
# # self.base_url = base_url
|
||||
|
||||
# # def request(
|
||||
# # self,
|
||||
# # method: str,
|
||||
# # url: str = "/",
|
||||
# # params: typing.Any = None,
|
||||
# # data: typing.Any = None,
|
||||
# # headers: typing.MutableMapping[str, str] = None,
|
||||
# # cookies: typing.Any = None,
|
||||
# # files: typing.Any = None,
|
||||
# # auth: typing.Any = None,
|
||||
# # timeout: typing.Any = None,
|
||||
# # allow_redirects: bool = None,
|
||||
# # proxies: typing.MutableMapping[str, str] = None,
|
||||
# # hooks: typing.Any = None,
|
||||
# # stream: bool = None,
|
||||
# # verify: typing.Union[bool, str] = None,
|
||||
# # cert: typing.Union[str, typing.Tuple[str, str]] = None,
|
||||
# # json: typing.Any = None,
|
||||
# # debug=None,
|
||||
# # gather_request=True,
|
||||
# # ) -> requests.Response:
|
||||
# # if debug is not None:
|
||||
# # self.app.debug = debug
|
||||
|
||||
# # url = urljoin(self.base_url, url)
|
||||
# # response = super().request(
|
||||
# # method,
|
||||
# # url,
|
||||
# # params=params,
|
||||
# # data=data,
|
||||
# # headers=headers,
|
||||
# # cookies=cookies,
|
||||
# # files=files,
|
||||
# # auth=auth,
|
||||
# # timeout=timeout,
|
||||
# # allow_redirects=allow_redirects,
|
||||
# # proxies=proxies,
|
||||
# # hooks=hooks,
|
||||
# # stream=stream,
|
||||
# # verify=verify,
|
||||
# # cert=cert,
|
||||
# # json=json,
|
||||
# # )
|
||||
|
||||
# # response.status = response.status_code
|
||||
# # response.body = response.content
|
||||
# # try:
|
||||
# # response.json = response.json()
|
||||
# # except:
|
||||
# # response.json = None
|
||||
|
||||
# # if gather_request:
|
||||
# # request = response.request
|
||||
# # parsed = urlparse(request.url)
|
||||
# # request.scheme = parsed.scheme
|
||||
# # request.path = parsed.path
|
||||
# # request.args = parse_qs(parsed.query)
|
||||
# # return request, response
|
||||
|
||||
# # return response
|
||||
|
||||
# # def get(self, *args, **kwargs):
|
||||
# # if "uri" in kwargs:
|
||||
# # kwargs["url"] = kwargs.pop("uri")
|
||||
# # return self.request("get", *args, **kwargs)
|
||||
|
||||
# # def post(self, *args, **kwargs):
|
||||
# # if "uri" in kwargs:
|
||||
# # kwargs["url"] = kwargs.pop("uri")
|
||||
# # return self.request("post", *args, **kwargs)
|
||||
|
||||
# # def put(self, *args, **kwargs):
|
||||
# # if "uri" in kwargs:
|
||||
# # kwargs["url"] = kwargs.pop("uri")
|
||||
# # return self.request("put", *args, **kwargs)
|
||||
|
||||
# # def delete(self, *args, **kwargs):
|
||||
# # if "uri" in kwargs:
|
||||
# # kwargs["url"] = kwargs.pop("uri")
|
||||
# # return self.request("delete", *args, **kwargs)
|
||||
|
||||
# # def patch(self, *args, **kwargs):
|
||||
# # if "uri" in kwargs:
|
||||
# # kwargs["url"] = kwargs.pop("uri")
|
||||
# # return self.request("patch", *args, **kwargs)
|
||||
|
||||
# # def options(self, *args, **kwargs):
|
||||
# # if "uri" in kwargs:
|
||||
# # kwargs["url"] = kwargs.pop("uri")
|
||||
# # return self.request("options", *args, **kwargs)
|
||||
|
||||
# # def head(self, *args, **kwargs):
|
||||
# # return self._sanic_endpoint_test("head", *args, **kwargs)
|
||||
|
||||
# # def websocket(self, *args, **kwargs):
|
||||
# # return self._sanic_endpoint_test("websocket", *args, **kwargs)
|
||||
|
|
|
@ -57,7 +57,7 @@ def test_asyncio_server_start_serving(app):
|
|||
|
||||
def test_app_loop_not_running(app):
|
||||
with pytest.raises(SanicException) as excinfo:
|
||||
_ = app.loop
|
||||
app.loop
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"Loop can only be retrieved after the app has started "
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from sanic.testing import SanicASGITestClient
|
||||
|
||||
|
||||
def asgi_client_instantiation(app):
|
||||
def test_asgi_client_instantiation(app):
|
||||
assert isinstance(app.asgi_client, SanicASGITestClient)
|
||||
|
|
|
@ -226,6 +226,7 @@ def test_config_access_log_passing_in_run(app):
|
|||
assert app.config.ACCESS_LOG == True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_config_access_log_passing_in_create_server(app):
|
||||
assert app.config.ACCESS_LOG == True
|
||||
|
||||
|
|
|
@ -27,6 +27,24 @@ def test_cookies(app):
|
|||
assert response_cookies["right_back"].value == "at you"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cookies_asgi(app):
|
||||
@app.route("/")
|
||||
def handler(request):
|
||||
response = text("Cookies are: {}".format(request.cookies["test"]))
|
||||
response.cookies["right_back"] = "at you"
|
||||
return response
|
||||
|
||||
request, response = await app.asgi_client.get(
|
||||
"/", cookies={"test": "working!"}
|
||||
)
|
||||
response_cookies = SimpleCookie()
|
||||
response_cookies.load(response.headers.get("set-cookie", {}))
|
||||
|
||||
assert response.text == "Cookies are: working!"
|
||||
assert response_cookies["right_back"].value == "at you"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("httponly,expected", [(False, False), (True, True)])
|
||||
def test_false_cookies_encoded(app, httponly, expected):
|
||||
@app.route("/")
|
||||
|
|
|
@ -24,7 +24,9 @@ old_conn = None
|
|||
class ReusableSanicConnectionPool(httpcore.ConnectionPool):
|
||||
async def acquire_connection(self, origin):
|
||||
global old_conn
|
||||
connection = self.active_connections.pop_by_origin(origin, http2_only=True)
|
||||
connection = self.active_connections.pop_by_origin(
|
||||
origin, http2_only=True
|
||||
)
|
||||
if connection is None:
|
||||
connection = self.keepalive_connections.pop_by_origin(origin)
|
||||
|
||||
|
@ -187,11 +189,7 @@ class ReuseableSanicTestClient(SanicTestClient):
|
|||
self._session = ResusableSanicSession()
|
||||
try:
|
||||
response = await getattr(self._session, method.lower())(
|
||||
url,
|
||||
verify=False,
|
||||
timeout=request_keepalive,
|
||||
*args,
|
||||
**kwargs,
|
||||
url, verify=False, timeout=request_keepalive, *args, **kwargs
|
||||
)
|
||||
except NameError:
|
||||
raise Exception(response.status_code)
|
||||
|
|
|
@ -110,21 +110,19 @@ def test_redirect_with_header_injection(redirect_app):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("test_str", ["sanic-test", "sanictest", "sanic test"])
|
||||
async def test_redirect_with_params(app, test_client, test_str):
|
||||
def test_redirect_with_params(app, test_str):
|
||||
use_in_uri = quote(test_str)
|
||||
|
||||
@app.route("/api/v1/test/<test>/")
|
||||
async def init_handler(request, test):
|
||||
assert test == test_str
|
||||
return redirect("/api/v2/test/{}/".format(quote(test)))
|
||||
return redirect("/api/v2/test/{}/".format(use_in_uri))
|
||||
|
||||
@app.route("/api/v2/test/<test>/")
|
||||
async def target_handler(request, test):
|
||||
assert test == test_str
|
||||
return text("OK")
|
||||
|
||||
test_cli = await test_client(app)
|
||||
|
||||
response = await test_cli.get("/api/v1/test/{}/".format(quote(test_str)))
|
||||
_, response = app.test_client.get("/api/v1/test/{}/".format(use_in_uri))
|
||||
assert response.status == 200
|
||||
|
||||
txt = await response.text()
|
||||
assert txt == "OK"
|
||||
assert response.content == b"OK"
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
import asyncio
|
||||
import contextlib
|
||||
|
||||
import pytest
|
||||
|
||||
from sanic.response import stream, text
|
||||
|
||||
|
||||
async def test_request_cancel_when_connection_lost(loop, app, test_client):
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_cancel_when_connection_lost(app):
|
||||
app.still_serving_cancelled_request = False
|
||||
|
||||
@app.get("/")
|
||||
|
@ -14,10 +17,9 @@ async def test_request_cancel_when_connection_lost(loop, app, test_client):
|
|||
app.still_serving_cancelled_request = True
|
||||
return text("OK")
|
||||
|
||||
test_cli = await test_client(app)
|
||||
|
||||
# schedule client call
|
||||
task = loop.create_task(test_cli.get("/"))
|
||||
loop = asyncio.get_event_loop()
|
||||
task = loop.create_task(app.asgi_client.get("/"))
|
||||
loop.call_later(0.01, task)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
|
@ -33,7 +35,8 @@ async def test_request_cancel_when_connection_lost(loop, app, test_client):
|
|||
assert app.still_serving_cancelled_request is False
|
||||
|
||||
|
||||
async def test_stream_request_cancel_when_conn_lost(loop, app, test_client):
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_request_cancel_when_conn_lost(app):
|
||||
app.still_serving_cancelled_request = False
|
||||
|
||||
@app.post("/post/<id>", stream=True)
|
||||
|
@ -53,10 +56,9 @@ async def test_stream_request_cancel_when_conn_lost(loop, app, test_client):
|
|||
|
||||
return stream(streaming)
|
||||
|
||||
test_cli = await test_client(app)
|
||||
|
||||
# schedule client call
|
||||
task = loop.create_task(test_cli.post("/post/1"))
|
||||
loop = asyncio.get_event_loop()
|
||||
task = loop.create_task(app.asgi_client.post("/post/1"))
|
||||
loop.call_later(0.01, task)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
|
|
|
@ -111,7 +111,6 @@ def test_request_stream_app(app):
|
|||
result += body.decode("utf-8")
|
||||
return text(result)
|
||||
|
||||
|
||||
assert app.is_request_stream is True
|
||||
|
||||
request, response = app.test_client.get("/get")
|
||||
|
|
|
@ -13,15 +13,12 @@ class DelayableSanicConnectionPool(httpcore.ConnectionPool):
|
|||
self._request_delay = request_delay
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request,
|
||||
stream=False,
|
||||
ssl=None,
|
||||
timeout=None,
|
||||
):
|
||||
async def send(self, request, stream=False, ssl=None, timeout=None):
|
||||
connection = await self.acquire_connection(request.url.origin)
|
||||
if connection.h11_connection is None and connection.h2_connection is None:
|
||||
if (
|
||||
connection.h11_connection is None
|
||||
and connection.h2_connection is None
|
||||
):
|
||||
await connection.connect(ssl=ssl, timeout=timeout)
|
||||
if self._request_delay:
|
||||
await asyncio.sleep(self._request_delay)
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -292,7 +292,7 @@ def test_stream_response_writes_correct_content_to_transport_when_chunked(
|
|||
async def mock_drain():
|
||||
pass
|
||||
|
||||
def mock_push_data(data):
|
||||
async def mock_push_data(data):
|
||||
response.protocol.transport.write(data)
|
||||
|
||||
response.protocol.push_data = mock_push_data
|
||||
|
@ -330,7 +330,7 @@ def test_stream_response_writes_correct_content_to_transport_when_not_chunked(
|
|||
async def mock_drain():
|
||||
pass
|
||||
|
||||
def mock_push_data(data):
|
||||
async def mock_push_data(data):
|
||||
response.protocol.transport.write(data)
|
||||
|
||||
response.protocol.push_data = mock_push_data
|
||||
|
|
|
@ -76,6 +76,7 @@ def test_all_listeners(app):
|
|||
assert app.name + listener_name == output.pop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_before_events_create_server(app):
|
||||
class MySanicDb:
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue
Block a user