diff --git a/sanic/app.py b/sanic/app.py index ccc7680f..5760ebca 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -54,7 +54,7 @@ class Sanic: logging.config.dictConfig(log_config or LOGGING_CONFIG_DEFAULTS) self.name = name - self.asgi = True + self.asgi = False self.router = router or Router() self.request_class = request_class self.error_handler = error_handler or ErrorHandler() @@ -1393,5 +1393,6 @@ class Sanic: # -------------------------------------------------------------------- # async def __call__(self, scope, receive, send): + self.asgi = True asgi_app = await ASGIApp.create(self, scope, receive, send) await asgi_app() diff --git a/sanic/asgi.py b/sanic/asgi.py index 56460e69..336e477f 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -1,13 +1,15 @@ import asyncio import warnings -from functools import partial +from http.cookies import SimpleCookie from inspect import isawaitable from typing import Any, Awaitable, Callable, MutableMapping, Union +from urllib.parse import quote from multidict import CIMultiDict -from sanic.exceptions import InvalidUsage +from sanic.exceptions import InvalidUsage, ServerError +from sanic.log import logger from sanic.request import Request from sanic.response import HTTPResponse, StreamingHTTPResponse from sanic.server import StreamBuffer @@ -102,16 +104,30 @@ class Lifespan: def __init__(self, asgi_app: "ASGIApp") -> None: self.asgi_app = asgi_app - async def startup(self) -> None: - if self.asgi_app.sanic_app.listeners["before_server_start"]: + if "before_server_start" in self.asgi_app.sanic_app.listeners: warnings.warn( - 'You have set a listener for "before_server_start". In ASGI mode it will be ignored. Perhaps you want to run it "after_server_start" instead?' + 'You have set a listener for "before_server_start" in ASGI mode. ' + "It will be executed as early as possible, but not before " + "the ASGI server is started." ) - if self.asgi_app.sanic_app.listeners["after_server_stop"]: + if "after_server_stop" in self.asgi_app.sanic_app.listeners: warnings.warn( - 'You have set a listener for "after_server_stop". In ASGI mode it will be ignored. Perhaps you want to run it "before_server_stop" instead?' + 'You have set a listener for "after_server_stop" in ASGI mode. ' + "It will be executed as late as possible, but not before " + "the ASGI server is stopped." ) + async def pre_startup(self) -> None: + for handler in self.asgi_app.sanic_app.listeners[ + "before_server_start" + ]: + response = handler( + self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop + ) + if isawaitable(response): + await response + + async def startup(self) -> None: for handler in self.asgi_app.sanic_app.listeners["after_server_start"]: response = handler( self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop @@ -127,6 +143,16 @@ class Lifespan: if isawaitable(response): await response + async def post_shutdown(self) -> None: + for handler in self.asgi_app.sanic_app.listeners[ + "before_server_start" + ]: + response = handler( + self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop + ) + if isawaitable(response): + await response + async def __call__( self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend ) -> None: @@ -164,14 +190,15 @@ class ASGIApp: instance.do_stream = ( True if headers.get("expect") == "100-continue" else False ) + instance.lifespan = Lifespan(instance) + await instance.pre_startup() if scope["type"] == "lifespan": - lifespan = Lifespan(instance) - await lifespan(scope, receive, send) + await instance.lifespan(scope, receive, send) else: - url_bytes = scope.get("root_path", "") + scope["path"] + url_bytes = scope.get("root_path", "") + quote(scope["path"]) url_bytes = url_bytes.encode("latin-1") - url_bytes += scope["query_string"] + url_bytes += b"?" + scope["query_string"] if scope["type"] == "http": version = scope["http_version"] @@ -250,10 +277,28 @@ class ASGIApp: Write the response. """ - headers = [ - (str(name).encode("latin-1"), str(value).encode("latin-1")) - for name, value in response.headers.items() - ] + try: + headers = [ + (str(name).encode("latin-1"), str(value).encode("latin-1")) + for name, value in response.headers.items() + # if name not in ("Set-Cookie",) + ] + except AttributeError: + logger.error( + "Invalid response object for url %s, " + "Expected Type: HTTPResponse, Actual Type: %s", + self.request.url, + type(response), + ) + exception = ServerError("Invalid response type") + response = self.sanic_app.error_handler.response( + self.request, exception + ) + headers = [ + (str(name).encode("latin-1"), str(value).encode("latin-1")) + for name, value in response.headers.items() + if name not in (b"Set-Cookie",) + ] if "content-length" not in response.headers and not isinstance( response, StreamingHTTPResponse @@ -262,6 +307,14 @@ class ASGIApp: (b"content-length", str(len(response.body)).encode("latin-1")) ] + if response.cookies: + cookies = SimpleCookie() + cookies.load(response.cookies) + headers += [ + (b"set-cookie", cookie.encode("utf-8")) + for name, cookie in response.cookies.items() + ] + await self.transport.send( { "type": "http.response.start", diff --git a/sanic/router.py b/sanic/router.py index 4c1ea0a0..63119446 100644 --- a/sanic/router.py +++ b/sanic/router.py @@ -406,6 +406,7 @@ class Router: if not self.hosts: return self._get(request.path, request.method, "") # virtual hosts specified; try to match route to the host header + try: return self._get( request.path, request.method, request.headers.get("Host", "") diff --git a/sanic/testing.py b/sanic/testing.py index d7211f3d..6f32896a 100644 --- a/sanic/testing.py +++ b/sanic/testing.py @@ -16,6 +16,7 @@ from sanic.log import logger from sanic.response import text +ASGI_HOST = "mockserver" HOST = "127.0.0.1" PORT = 42101 @@ -275,7 +276,7 @@ class SanicASGIAdapter(requests.asgi.ASGIAdapter): body = message.get("body", b"") more_body = message.get("more_body", False) if request.method != "HEAD": - raw_kwargs["body"] += body + raw_kwargs["content"] += body if not more_body: response_complete = True elif message["type"] == "http.response.template": @@ -285,7 +286,7 @@ class SanicASGIAdapter(requests.asgi.ASGIAdapter): request_complete = False response_started = False response_complete = False - raw_kwargs = {"body": b""} # type: typing.Dict[str, typing.Any] + raw_kwargs = {"content": b""} # type: typing.Dict[str, typing.Any] template = None context = None return_value = None @@ -327,11 +328,11 @@ class SanicASGITestClient(requests.ASGISession): def __init__( self, app: "Sanic", - base_url: str = "http://mockserver", + base_url: str = "http://{}".format(ASGI_HOST), suppress_exceptions: bool = False, ) -> None: app.__class__.__call__ = app_call_with_return - + app.asgi = True super().__init__(app) adapter = SanicASGIAdapter( @@ -343,12 +344,16 @@ class SanicASGITestClient(requests.ASGISession): self.app = app self.base_url = base_url - async def send(self, *args, **kwargs): - return await super().send(*args, **kwargs) + # async def send(self, prepared_request, *args, **kwargs): + # return await super().send(*args, **kwargs) async def request(self, method, url, gather_request=True, *args, **kwargs): self.gather_request = gather_request + print(url) response = await super().request(method, url, *args, **kwargs) + response.status = response.status_code + response.body = response.content + response.content_type = response.headers.get("content-type") if hasattr(response, "return_value"): request = response.return_value @@ -361,124 +366,3 @@ class SanicASGITestClient(requests.ASGISession): settings = super().merge_environment_settings(*args, **kwargs) settings.update({"gather_return": self.gather_request}) return settings - - -# class SanicASGITestClient(requests.ASGISession): -# __test__ = False # For pytest to not discover this up. - -# def __init__( -# self, -# app: "Sanic", -# base_url: str = "http://mockserver", -# suppress_exceptions: bool = False, -# ) -> None: -# app.testing = True -# super().__init__( -# app, base_url=base_url, suppress_exceptions=suppress_exceptions -# ) -# # adapter = _ASGIAdapter( -# # app, raise_server_exceptions=raise_server_exceptions -# # ) -# # self.mount("http://", adapter) -# # self.mount("https://", adapter) -# # self.mount("ws://", adapter) -# # self.mount("wss://", adapter) -# # self.headers.update({"user-agent": "testclient"}) -# # self.base_url = base_url - -# # def request( -# # self, -# # method: str, -# # url: str = "/", -# # params: typing.Any = None, -# # data: typing.Any = None, -# # headers: typing.MutableMapping[str, str] = None, -# # cookies: typing.Any = None, -# # files: typing.Any = None, -# # auth: typing.Any = None, -# # timeout: typing.Any = None, -# # allow_redirects: bool = None, -# # proxies: typing.MutableMapping[str, str] = None, -# # hooks: typing.Any = None, -# # stream: bool = None, -# # verify: typing.Union[bool, str] = None, -# # cert: typing.Union[str, typing.Tuple[str, str]] = None, -# # json: typing.Any = None, -# # debug=None, -# # gather_request=True, -# # ) -> requests.Response: -# # if debug is not None: -# # self.app.debug = debug - -# # url = urljoin(self.base_url, url) -# # response = super().request( -# # method, -# # url, -# # params=params, -# # data=data, -# # headers=headers, -# # cookies=cookies, -# # files=files, -# # auth=auth, -# # timeout=timeout, -# # allow_redirects=allow_redirects, -# # proxies=proxies, -# # hooks=hooks, -# # stream=stream, -# # verify=verify, -# # cert=cert, -# # json=json, -# # ) - -# # response.status = response.status_code -# # response.body = response.content -# # try: -# # response.json = response.json() -# # except: -# # response.json = None - -# # if gather_request: -# # request = response.request -# # parsed = urlparse(request.url) -# # request.scheme = parsed.scheme -# # request.path = parsed.path -# # request.args = parse_qs(parsed.query) -# # return request, response - -# # return response - -# # def get(self, *args, **kwargs): -# # if "uri" in kwargs: -# # kwargs["url"] = kwargs.pop("uri") -# # return self.request("get", *args, **kwargs) - -# # def post(self, *args, **kwargs): -# # if "uri" in kwargs: -# # kwargs["url"] = kwargs.pop("uri") -# # return self.request("post", *args, **kwargs) - -# # def put(self, *args, **kwargs): -# # if "uri" in kwargs: -# # kwargs["url"] = kwargs.pop("uri") -# # return self.request("put", *args, **kwargs) - -# # def delete(self, *args, **kwargs): -# # if "uri" in kwargs: -# # kwargs["url"] = kwargs.pop("uri") -# # return self.request("delete", *args, **kwargs) - -# # def patch(self, *args, **kwargs): -# # if "uri" in kwargs: -# # kwargs["url"] = kwargs.pop("uri") -# # return self.request("patch", *args, **kwargs) - -# # def options(self, *args, **kwargs): -# # if "uri" in kwargs: -# # kwargs["url"] = kwargs.pop("uri") -# # return self.request("options", *args, **kwargs) - -# # def head(self, *args, **kwargs): -# # return self._sanic_endpoint_test("head", *args, **kwargs) - -# # def websocket(self, *args, **kwargs): -# # return self._sanic_endpoint_test("websocket", *args, **kwargs) diff --git a/tests/test_app.py b/tests/test_app.py index 5ddae42d..deb050b6 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -57,7 +57,7 @@ def test_asyncio_server_start_serving(app): def test_app_loop_not_running(app): with pytest.raises(SanicException) as excinfo: - _ = app.loop + app.loop assert str(excinfo.value) == ( "Loop can only be retrieved after the app has started " diff --git a/tests/test_asgi.py b/tests/test_asgi.py index d51b4f2f..d0fa1d91 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -1,5 +1,5 @@ from sanic.testing import SanicASGITestClient -def asgi_client_instantiation(app): +def test_asgi_client_instantiation(app): assert isinstance(app.asgi_client, SanicASGITestClient) diff --git a/tests/test_config.py b/tests/test_config.py index 7b203311..2445d02c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -226,6 +226,7 @@ def test_config_access_log_passing_in_run(app): assert app.config.ACCESS_LOG == True +@pytest.mark.asyncio async def test_config_access_log_passing_in_create_server(app): assert app.config.ACCESS_LOG == True diff --git a/tests/test_cookies.py b/tests/test_cookies.py index a77fda2f..737a752d 100644 --- a/tests/test_cookies.py +++ b/tests/test_cookies.py @@ -27,6 +27,24 @@ def test_cookies(app): assert response_cookies["right_back"].value == "at you" +@pytest.mark.asyncio +async def test_cookies_asgi(app): + @app.route("/") + def handler(request): + response = text("Cookies are: {}".format(request.cookies["test"])) + response.cookies["right_back"] = "at you" + return response + + request, response = await app.asgi_client.get( + "/", cookies={"test": "working!"} + ) + response_cookies = SimpleCookie() + response_cookies.load(response.headers.get("set-cookie", {})) + + assert response.text == "Cookies are: working!" + assert response_cookies["right_back"].value == "at you" + + @pytest.mark.parametrize("httponly,expected", [(False, False), (True, True)]) def test_false_cookies_encoded(app, httponly, expected): @app.route("/") diff --git a/tests/test_keep_alive_timeout.py b/tests/test_keep_alive_timeout.py index c6fc0831..672d78ac 100644 --- a/tests/test_keep_alive_timeout.py +++ b/tests/test_keep_alive_timeout.py @@ -24,7 +24,9 @@ old_conn = None class ReusableSanicConnectionPool(httpcore.ConnectionPool): async def acquire_connection(self, origin): global old_conn - connection = self.active_connections.pop_by_origin(origin, http2_only=True) + connection = self.active_connections.pop_by_origin( + origin, http2_only=True + ) if connection is None: connection = self.keepalive_connections.pop_by_origin(origin) @@ -187,11 +189,7 @@ class ReuseableSanicTestClient(SanicTestClient): self._session = ResusableSanicSession() try: response = await getattr(self._session, method.lower())( - url, - verify=False, - timeout=request_keepalive, - *args, - **kwargs, + url, verify=False, timeout=request_keepalive, *args, **kwargs ) except NameError: raise Exception(response.status_code) diff --git a/tests/test_redirect.py b/tests/test_redirect.py index 86c4ace3..7d0c0edf 100644 --- a/tests/test_redirect.py +++ b/tests/test_redirect.py @@ -110,21 +110,19 @@ def test_redirect_with_header_injection(redirect_app): @pytest.mark.parametrize("test_str", ["sanic-test", "sanictest", "sanic test"]) -async def test_redirect_with_params(app, test_client, test_str): +def test_redirect_with_params(app, test_str): + use_in_uri = quote(test_str) + @app.route("/api/v1/test//") async def init_handler(request, test): - assert test == test_str - return redirect("/api/v2/test/{}/".format(quote(test))) + return redirect("/api/v2/test/{}/".format(use_in_uri)) @app.route("/api/v2/test//") async def target_handler(request, test): assert test == test_str return text("OK") - test_cli = await test_client(app) - - response = await test_cli.get("/api/v1/test/{}/".format(quote(test_str))) + _, response = app.test_client.get("/api/v1/test/{}/".format(use_in_uri)) assert response.status == 200 - txt = await response.text() - assert txt == "OK" + assert response.content == b"OK" diff --git a/tests/test_request_cancel.py b/tests/test_request_cancel.py index e9499f6d..b5d90882 100644 --- a/tests/test_request_cancel.py +++ b/tests/test_request_cancel.py @@ -1,10 +1,13 @@ import asyncio import contextlib +import pytest + from sanic.response import stream, text -async def test_request_cancel_when_connection_lost(loop, app, test_client): +@pytest.mark.asyncio +async def test_request_cancel_when_connection_lost(app): app.still_serving_cancelled_request = False @app.get("/") @@ -14,10 +17,9 @@ async def test_request_cancel_when_connection_lost(loop, app, test_client): app.still_serving_cancelled_request = True return text("OK") - test_cli = await test_client(app) - # schedule client call - task = loop.create_task(test_cli.get("/")) + loop = asyncio.get_event_loop() + task = loop.create_task(app.asgi_client.get("/")) loop.call_later(0.01, task) await asyncio.sleep(0.5) @@ -33,7 +35,8 @@ async def test_request_cancel_when_connection_lost(loop, app, test_client): assert app.still_serving_cancelled_request is False -async def test_stream_request_cancel_when_conn_lost(loop, app, test_client): +@pytest.mark.asyncio +async def test_stream_request_cancel_when_conn_lost(app): app.still_serving_cancelled_request = False @app.post("/post/", stream=True) @@ -53,10 +56,9 @@ async def test_stream_request_cancel_when_conn_lost(loop, app, test_client): return stream(streaming) - test_cli = await test_client(app) - # schedule client call - task = loop.create_task(test_cli.post("/post/1")) + loop = asyncio.get_event_loop() + task = loop.create_task(app.asgi_client.post("/post/1")) loop.call_later(0.01, task) await asyncio.sleep(0.5) diff --git a/tests/test_request_stream.py b/tests/test_request_stream.py index d845dc85..65472a1e 100644 --- a/tests/test_request_stream.py +++ b/tests/test_request_stream.py @@ -111,7 +111,6 @@ def test_request_stream_app(app): result += body.decode("utf-8") return text(result) - assert app.is_request_stream is True request, response = app.test_client.get("/get") diff --git a/tests/test_request_timeout.py b/tests/test_request_timeout.py index 3a41e462..e3e02d7c 100644 --- a/tests/test_request_timeout.py +++ b/tests/test_request_timeout.py @@ -13,15 +13,12 @@ class DelayableSanicConnectionPool(httpcore.ConnectionPool): self._request_delay = request_delay super().__init__(*args, **kwargs) - async def send( - self, - request, - stream=False, - ssl=None, - timeout=None, - ): + async def send(self, request, stream=False, ssl=None, timeout=None): connection = await self.acquire_connection(request.url.origin) - if connection.h11_connection is None and connection.h2_connection is None: + if ( + connection.h11_connection is None + and connection.h2_connection is None + ): await connection.connect(ssl=ssl, timeout=timeout) if self._request_delay: await asyncio.sleep(self._request_delay) diff --git a/tests/test_requests.py b/tests/test_requests.py index 64a919e8..ea1946dd 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -12,7 +12,7 @@ from sanic import Blueprint, Sanic from sanic.exceptions import ServerError from sanic.request import DEFAULT_HTTP_CONTENT_TYPE, RequestParameters from sanic.response import json, text -from sanic.testing import HOST, PORT +from sanic.testing import ASGI_HOST, HOST, PORT # ------------------------------------------------------------ # @@ -72,6 +72,17 @@ def test_text(app): assert response.text == "Hello" +@pytest.mark.asyncio +async def test_text_asgi(app): + @app.route("/") + async def handler(request): + return text("Hello") + + request, response = await app.asgi_client.get("/") + + assert response.text == "Hello" + + def test_headers(app): @app.route("/") async def handler(request): @@ -83,6 +94,18 @@ def test_headers(app): assert response.headers.get("spam") == "great" +@pytest.mark.asyncio +async def test_headers_asgi(app): + @app.route("/") + async def handler(request): + headers = {"spam": "great"} + return text("Hello", headers=headers) + + request, response = await app.asgi_client.get("/") + + assert response.headers.get("spam") == "great" + + def test_non_str_headers(app): @app.route("/") async def handler(request): @@ -94,6 +117,18 @@ def test_non_str_headers(app): assert response.headers.get("answer") == "42" +@pytest.mark.asyncio +async def test_non_str_headers_asgi(app): + @app.route("/") + async def handler(request): + headers = {"answer": 42} + return text("Hello", headers=headers) + + request, response = await app.asgi_client.get("/") + + assert response.headers.get("answer") == "42" + + def test_invalid_response(app): @app.exception(ServerError) def handler_exception(request, exception): @@ -108,6 +143,21 @@ def test_invalid_response(app): assert response.text == "Internal Server Error." +@pytest.mark.asyncio +async def test_invalid_response_asgi(app): + @app.exception(ServerError) + def handler_exception(request, exception): + return text("Internal Server Error.", 500) + + @app.route("/") + async def handler(request): + return "This should fail" + + request, response = await app.asgi_client.get("/") + assert response.status == 500 + assert response.text == "Internal Server Error." + + def test_json(app): @app.route("/") async def handler(request): @@ -120,6 +170,19 @@ def test_json(app): assert results.get("test") is True +@pytest.mark.asyncio +async def test_json_asgi(app): + @app.route("/") + async def handler(request): + return json({"test": True}) + + request, response = await app.asgi_client.get("/") + + results = json_loads(response.text) + + assert results.get("test") is True + + def test_empty_json(app): @app.route("/") async def handler(request): @@ -131,6 +194,18 @@ def test_empty_json(app): assert response.text == "null" +@pytest.mark.asyncio +async def test_empty_json_asgi(app): + @app.route("/") + async def handler(request): + assert request.json is None + return json(request.json) + + request, response = await app.asgi_client.get("/") + assert response.status == 200 + assert response.text == "null" + + def test_invalid_json(app): @app.route("/") async def handler(request): @@ -142,6 +217,18 @@ def test_invalid_json(app): assert response.status == 400 +@pytest.mark.asyncio +async def test_invalid_json_asgi(app): + @app.route("/") + async def handler(request): + return json(request.json) + + data = "I am not json" + request, response = await app.asgi_client.get("/", data=data) + + assert response.status == 400 + + def test_query_string(app): @app.route("/") async def handler(request): @@ -158,6 +245,23 @@ def test_query_string(app): assert request.args.get("test3", default="My value") == "My value" +@pytest.mark.asyncio +async def test_query_string_asgi(app): + @app.route("/") + async def handler(request): + return text("OK") + + request, response = await app.asgi_client.get( + "/", params=[("test1", "1"), ("test2", "false"), ("test2", "true")] + ) + + assert request.args.get("test1") == "1" + assert request.args.get("test2") == "false" + assert request.args.getlist("test2") == ["false", "true"] + assert request.args.getlist("test1") == ["1"] + assert request.args.get("test3", default="My value") == "My value" + + def test_uri_template(app): @app.route("/foo//bar/") async def handler(request, id, name): @@ -167,6 +271,16 @@ def test_uri_template(app): assert request.uri_template == "/foo//bar/" +@pytest.mark.asyncio +async def test_uri_template_asgi(app): + @app.route("/foo//bar/") + async def handler(request, id, name): + return text("OK") + + request, response = await app.asgi_client.get("/foo/123/bar/baz") + assert request.uri_template == "/foo//bar/" + + def test_token(app): @app.route("/") async def handler(request): @@ -211,6 +325,51 @@ def test_token(app): assert request.token is None +@pytest.mark.asyncio +async def test_token_asgi(app): + @app.route("/") + async def handler(request): + return text("OK") + + # uuid4 generated token. + token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" + headers = { + "content-type": "application/json", + "Authorization": "{}".format(token), + } + + request, response = await app.asgi_client.get("/", headers=headers) + + assert request.token == token + + token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" + headers = { + "content-type": "application/json", + "Authorization": "Token {}".format(token), + } + + request, response = await app.asgi_client.get("/", headers=headers) + + assert request.token == token + + token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" + headers = { + "content-type": "application/json", + "Authorization": "Bearer {}".format(token), + } + + request, response = await app.asgi_client.get("/", headers=headers) + + assert request.token == token + + # no Authorization headers + headers = {"content-type": "application/json"} + + request, response = await app.asgi_client.get("/", headers=headers) + + assert request.token is None + + def test_content_type(app): @app.route("/") async def handler(request): @@ -226,6 +385,22 @@ def test_content_type(app): assert response.text == "application/json" +@pytest.mark.asyncio +async def test_content_type_asgi(app): + @app.route("/") + async def handler(request): + return text(request.content_type) + + request, response = await app.asgi_client.get("/") + assert request.content_type == DEFAULT_HTTP_CONTENT_TYPE + assert response.text == DEFAULT_HTTP_CONTENT_TYPE + + headers = {"content-type": "application/json"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.content_type == "application/json" + assert response.text == "application/json" + + def test_remote_addr_with_two_proxies(app): app.config.PROXIES_COUNT = 2 @@ -265,6 +440,46 @@ def test_remote_addr_with_two_proxies(app): assert response.text == "127.0.0.1" +@pytest.mark.asyncio +async def test_remote_addr_with_two_proxies_asgi(app): + app.config.PROXIES_COUNT = 2 + + @app.route("/") + async def handler(request): + return text(request.remote_addr) + + headers = {"X-Real-IP": "127.0.0.2", "X-Forwarded-For": "127.0.1.1"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "127.0.0.2" + assert response.text == "127.0.0.2" + + headers = {"X-Forwarded-For": "127.0.1.1"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "" + assert response.text == "" + + headers = {"X-Forwarded-For": "127.0.0.1, 127.0.1.2"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "127.0.0.1" + assert response.text == "127.0.0.1" + + request, response = await app.asgi_client.get("/") + assert request.remote_addr == "" + assert response.text == "" + + headers = {"X-Forwarded-For": "127.0.0.1, , ,,127.0.1.2"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "127.0.0.1" + assert response.text == "127.0.0.1" + + headers = { + "X-Forwarded-For": ", 127.0.2.2, , ,127.0.0.1, , ,,127.0.1.2" + } + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "127.0.0.1" + assert response.text == "127.0.0.1" + + def test_remote_addr_with_infinite_number_of_proxies(app): app.config.PROXIES_COUNT = -1 @@ -290,6 +505,32 @@ def test_remote_addr_with_infinite_number_of_proxies(app): assert response.text == "127.0.0.5" +@pytest.mark.asyncio +async def test_remote_addr_with_infinite_number_of_proxies_asgi(app): + app.config.PROXIES_COUNT = -1 + + @app.route("/") + async def handler(request): + return text(request.remote_addr) + + headers = {"X-Real-IP": "127.0.0.2", "X-Forwarded-For": "127.0.1.1"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "127.0.0.2" + assert response.text == "127.0.0.2" + + headers = {"X-Forwarded-For": "127.0.1.1"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "127.0.1.1" + assert response.text == "127.0.1.1" + + headers = { + "X-Forwarded-For": "127.0.0.5, 127.0.0.4, 127.0.0.3, 127.0.0.2, 127.0.0.1" + } + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "127.0.0.5" + assert response.text == "127.0.0.5" + + def test_remote_addr_without_proxy(app): app.config.PROXIES_COUNT = 0 @@ -313,6 +554,30 @@ def test_remote_addr_without_proxy(app): assert response.text == "" +@pytest.mark.asyncio +async def test_remote_addr_without_proxy_asgi(app): + app.config.PROXIES_COUNT = 0 + + @app.route("/") + async def handler(request): + return text(request.remote_addr) + + headers = {"X-Real-IP": "127.0.0.2", "X-Forwarded-For": "127.0.1.1"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "" + assert response.text == "" + + headers = {"X-Forwarded-For": "127.0.1.1"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "" + assert response.text == "" + + headers = {"X-Forwarded-For": "127.0.0.1, 127.0.1.2"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "" + assert response.text == "" + + def test_remote_addr_custom_headers(app): app.config.PROXIES_COUNT = 1 app.config.REAL_IP_HEADER = "Client-IP" @@ -338,6 +603,32 @@ def test_remote_addr_custom_headers(app): assert response.text == "127.0.0.2" +@pytest.mark.asyncio +async def test_remote_addr_custom_headers_asgi(app): + app.config.PROXIES_COUNT = 1 + app.config.REAL_IP_HEADER = "Client-IP" + app.config.FORWARDED_FOR_HEADER = "Forwarded" + + @app.route("/") + async def handler(request): + return text(request.remote_addr) + + headers = {"X-Real-IP": "127.0.0.2", "Forwarded": "127.0.1.1"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "127.0.1.1" + assert response.text == "127.0.1.1" + + headers = {"X-Forwarded-For": "127.0.1.1"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "" + assert response.text == "" + + headers = {"Client-IP": "127.0.0.2", "Forwarded": "127.0.1.1"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "127.0.0.2" + assert response.text == "127.0.0.2" + + def test_match_info(app): @app.route("/api/v1/user//") async def handler(request, user_id): @@ -349,6 +640,18 @@ def test_match_info(app): assert json_loads(response.text) == {"user_id": "sanic_user"} +@pytest.mark.asyncio +async def test_match_info_asgi(app): + @app.route("/api/v1/user//") + async def handler(request, user_id): + return json(request.match_info) + + request, response = await app.asgi_client.get("/api/v1/user/sanic_user/") + + assert request.match_info == {"user_id": "sanic_user"} + assert json_loads(response.text) == {"user_id": "sanic_user"} + + # ------------------------------------------------------------ # # POST # ------------------------------------------------------------ # @@ -371,6 +674,24 @@ def test_post_json(app): assert response.text == "OK" +@pytest.mark.asyncio +async def test_post_json_asgi(app): + @app.route("/", methods=["POST"]) + async def handler(request): + return text("OK") + + payload = {"test": "OK"} + headers = {"content-type": "application/json"} + + request, response = await app.asgi_client.post( + "/", data=json_dumps(payload), headers=headers + ) + + assert request.json.get("test") == "OK" + assert request.json.get("test") == "OK" # for request.parsed_json + assert response.text == "OK" + + def test_post_form_urlencoded(app): @app.route("/", methods=["POST"]) async def handler(request): @@ -387,6 +708,23 @@ def test_post_form_urlencoded(app): assert request.form.get("test") == "OK" # For request.parsed_form +@pytest.mark.asyncio +async def test_post_form_urlencoded_asgi(app): + @app.route("/", methods=["POST"]) + async def handler(request): + return text("OK") + + payload = "test=OK" + headers = {"content-type": "application/x-www-form-urlencoded"} + + request, response = await app.asgi_client.post( + "/", data=payload, headers=headers + ) + + assert request.form.get("test") == "OK" + assert request.form.get("test") == "OK" # For request.parsed_form + + @pytest.mark.parametrize( "payload", [ @@ -414,6 +752,36 @@ def test_post_form_multipart_form_data(app, payload): assert request.form.get("test") == "OK" +@pytest.mark.parametrize( + "payload", + [ + "------sanic\r\n" + 'Content-Disposition: form-data; name="test"\r\n' + "\r\n" + "OK\r\n" + "------sanic--\r\n", + "------sanic\r\n" + 'content-disposition: form-data; name="test"\r\n' + "\r\n" + "OK\r\n" + "------sanic--\r\n", + ], +) +@pytest.mark.asyncio +async def test_post_form_multipart_form_data_asgi(app, payload): + @app.route("/", methods=["POST"]) + async def handler(request): + return text("OK") + + headers = {"content-type": "multipart/form-data; boundary=----sanic"} + + request, response = await app.asgi_client.post( + "/", data=payload, headers=headers + ) + + assert request.form.get("test") == "OK" + + @pytest.mark.parametrize( "path,query,expected_url", [ @@ -439,6 +807,32 @@ def test_url_attributes_no_ssl(app, path, query, expected_url): assert parsed.netloc == request.host +@pytest.mark.parametrize( + "path,query,expected_url", + [ + ("/foo", "", "http://{}/foo"), + ("/bar/baz", "", "http://{}/bar/baz"), + ("/moo/boo", "arg1=val1", "http://{}/moo/boo?arg1=val1"), + ], +) +@pytest.mark.asyncio +async def test_url_attributes_no_ssl_asgi(app, path, query, expected_url): + async def handler(request): + return text("OK") + + app.add_route(handler, path) + + request, response = await app.asgi_client.get(path + "?{}".format(query)) + assert request.url == expected_url.format(ASGI_HOST) + + parsed = urlparse(request.url) + + assert parsed.scheme == request.scheme + assert parsed.path == request.path + assert parsed.query == request.query_string + assert parsed.netloc == request.host + + @pytest.mark.parametrize( "path,query,expected_url", [ @@ -540,6 +934,23 @@ def test_form_with_multiple_values(app): assert request.form.getlist("selectedItems") == ["v1", "v2", "v3"] +@pytest.mark.asyncio +async def test_form_with_multiple_values_asgi(app): + @app.route("/", methods=["POST"]) + async def handler(request): + return text("OK") + + payload = "selectedItems=v1&selectedItems=v2&selectedItems=v3" + + headers = {"content-type": "application/x-www-form-urlencoded"} + + request, response = await app.asgi_client.post( + "/", data=payload, headers=headers + ) + + assert request.form.getlist("selectedItems") == ["v1", "v2", "v3"] + + def test_request_string_representation(app): @app.route("/", methods=["GET"]) async def get(request): @@ -549,6 +960,16 @@ def test_request_string_representation(app): assert repr(request) == "" +@pytest.mark.asyncio +async def test_request_string_representation_asgi(app): + @app.route("/", methods=["GET"]) + async def get(request): + return text("OK") + + request, _ = await app.asgi_client.get("/") + assert repr(request) == "" + + @pytest.mark.parametrize( "payload,filename", [ @@ -613,6 +1034,71 @@ def test_request_multipart_files(app, payload, filename): assert request.files.get("test").name == filename +@pytest.mark.parametrize( + "payload,filename", + [ + ( + "------sanic\r\n" + 'Content-Disposition: form-data; filename="filename"; name="test"\r\n' + "\r\n" + "OK\r\n" + "------sanic--\r\n", + "filename", + ), + ( + "------sanic\r\n" + 'content-disposition: form-data; filename="filename"; name="test"\r\n' + "\r\n" + 'content-type: application/json; {"field": "value"}\r\n' + "------sanic--\r\n", + "filename", + ), + ( + "------sanic\r\n" + 'Content-Disposition: form-data; filename=""; name="test"\r\n' + "\r\n" + "OK\r\n" + "------sanic--\r\n", + "", + ), + ( + "------sanic\r\n" + 'content-disposition: form-data; filename=""; name="test"\r\n' + "\r\n" + 'content-type: application/json; {"field": "value"}\r\n' + "------sanic--\r\n", + "", + ), + ( + "------sanic\r\n" + 'Content-Disposition: form-data; filename*="utf-8\'\'filename_%C2%A0_test"; name="test"\r\n' + "\r\n" + "OK\r\n" + "------sanic--\r\n", + "filename_\u00A0_test", + ), + ( + "------sanic\r\n" + 'content-disposition: form-data; filename*="utf-8\'\'filename_%C2%A0_test"; name="test"\r\n' + "\r\n" + 'content-type: application/json; {"field": "value"}\r\n' + "------sanic--\r\n", + "filename_\u00A0_test", + ), + ], +) +@pytest.mark.asyncio +async def test_request_multipart_files_asgi(app, payload, filename): + @app.route("/", methods=["POST"]) + async def post(request): + return text("OK") + + headers = {"content-type": "multipart/form-data; boundary=----sanic"} + + request, _ = await app.asgi_client.post("/", data=payload, headers=headers) + assert request.files.get("test").name == filename + + def test_request_multipart_file_with_json_content_type(app): @app.route("/", methods=["POST"]) async def post(request): @@ -634,6 +1120,28 @@ def test_request_multipart_file_with_json_content_type(app): assert request.files.get("file").type == "application/json" +@pytest.mark.asyncio +async def test_request_multipart_file_with_json_content_type_asgi(app): + @app.route("/", methods=["POST"]) + async def post(request): + return text("OK") + + payload = ( + "------sanic\r\n" + 'Content-Disposition: form-data; name="file"; filename="test.json"\r\n' + "Content-Type: application/json\r\n" + "Content-Length: 0" + "\r\n" + "\r\n" + "------sanic--" + ) + + headers = {"content-type": "multipart/form-data; boundary=------sanic"} + + request, _ = await app.asgi_client.post("/", data=payload, headers=headers) + assert request.files.get("file").type == "application/json" + + def test_request_multipart_file_without_field_name(app, caplog): @app.route("/", methods=["POST"]) async def post(request): @@ -694,6 +1202,39 @@ def test_request_multipart_file_duplicate_filed_name(app): ] +@pytest.mark.asyncio +async def test_request_multipart_file_duplicate_filed_name_asgi(app): + @app.route("/", methods=["POST"]) + async def post(request): + return text("OK") + + payload = ( + "--e73ffaa8b1b2472b8ec848de833cb05b\r\n" + 'Content-Disposition: form-data; name="file"\r\n' + "Content-Type: application/octet-stream\r\n" + "Content-Length: 15\r\n" + "\r\n" + '{"test":"json"}\r\n' + "--e73ffaa8b1b2472b8ec848de833cb05b\r\n" + 'Content-Disposition: form-data; name="file"\r\n' + "Content-Type: application/octet-stream\r\n" + "Content-Length: 15\r\n" + "\r\n" + '{"test":"json2"}\r\n' + "--e73ffaa8b1b2472b8ec848de833cb05b--\r\n" + ) + + headers = { + "Content-Type": "multipart/form-data; boundary=e73ffaa8b1b2472b8ec848de833cb05b" + } + + request, _ = await app.asgi_client.post("/", data=payload, headers=headers) + assert request.form.getlist("file") == [ + '{"test":"json"}', + '{"test":"json2"}', + ] + + def test_request_multipart_with_multiple_files_and_type(app): @app.route("/", methods=["POST"]) async def post(request): @@ -713,6 +1254,26 @@ def test_request_multipart_with_multiple_files_and_type(app): assert request.files.getlist("file")[1].type == "application/pdf" +@pytest.mark.asyncio +async def test_request_multipart_with_multiple_files_and_type_asgi(app): + @app.route("/", methods=["POST"]) + async def post(request): + return text("OK") + + payload = ( + '------sanic\r\nContent-Disposition: form-data; name="file"; filename="test.json"' + "\r\nContent-Type: application/json\r\n\r\n\r\n" + '------sanic\r\nContent-Disposition: form-data; name="file"; filename="some_file.pdf"\r\n' + "Content-Type: application/pdf\r\n\r\n\r\n------sanic--" + ) + headers = {"content-type": "multipart/form-data; boundary=------sanic"} + + request, _ = await app.asgi_client.post("/", data=payload, headers=headers) + assert len(request.files.getlist("file")) == 2 + assert request.files.getlist("file")[0].type == "application/json" + assert request.files.getlist("file")[1].type == "application/pdf" + + def test_request_repr(app): @app.get("/") def handler(request): @@ -725,6 +1286,19 @@ def test_request_repr(app): assert repr(request) == "" +@pytest.mark.asyncio +async def test_request_repr_asgi(app): + @app.get("/") + def handler(request): + return text("pass") + + request, response = await app.asgi_client.get("/") + assert repr(request) == "" + + request.method = None + assert repr(request) == "" + + def test_request_bool(app): @app.get("/") def handler(request): @@ -759,6 +1333,29 @@ def test_request_parsing_form_failed(app, caplog): ) +@pytest.mark.asyncio +async def test_request_parsing_form_failed_asgi(app, caplog): + @app.route("/", methods=["POST"]) + async def handler(request): + return text("OK") + + payload = "test=OK" + headers = {"content-type": "multipart/form-data"} + + request, response = await app.asgi_client.post( + "/", data=payload, headers=headers + ) + + with caplog.at_level(logging.ERROR): + request.form + + assert caplog.record_tuples[-1] == ( + "sanic.error", + logging.ERROR, + "Failed when parsing form", + ) + + def test_request_args_no_query_string(app): @app.get("/") def handler(request): @@ -769,6 +1366,17 @@ def test_request_args_no_query_string(app): assert request.args == {} +@pytest.mark.asyncio +async def test_request_args_no_query_string_await(app): + @app.get("/") + def handler(request): + return text("pass") + + request, response = await app.asgi_client.get("/") + + assert request.args == {} + + def test_request_raw_args(app): params = {"test": "OK"} @@ -782,6 +1390,20 @@ def test_request_raw_args(app): assert request.raw_args == params +@pytest.mark.asyncio +async def test_request_raw_args_asgi(app): + + params = {"test": "OK"} + + @app.get("/") + def handler(request): + return text("pass") + + request, response = await app.asgi_client.get("/", params=params) + + assert request.raw_args == params + + def test_request_query_args(app): # test multiple params with the same key params = [("test", "value1"), ("test", "value2")] @@ -818,6 +1440,43 @@ def test_request_query_args(app): assert not request.query_args +@pytest.mark.asyncio +async def test_request_query_args_asgi(app): + # test multiple params with the same key + params = [("test", "value1"), ("test", "value2")] + + @app.get("/") + def handler(request): + return text("pass") + + request, response = await app.asgi_client.get("/", params=params) + + assert request.query_args == params + + # test cached value + assert ( + request.parsed_not_grouped_args[(False, False, "utf-8", "replace")] + == request.query_args + ) + + # test params directly in the url + request, response = await app.asgi_client.get("/?test=value1&test=value2") + + assert request.query_args == params + + # test unique params + params = [("test1", "value1"), ("test2", "value2")] + + request, response = await app.asgi_client.get("/", params=params) + + assert request.query_args == params + + # test no params + request, response = await app.asgi_client.get("/") + + assert not request.query_args + + def test_request_query_args_custom_parsing(app): @app.get("/") def handler(request): @@ -851,6 +1510,40 @@ def test_request_query_args_custom_parsing(app): ) +@pytest.mark.asyncio +async def test_request_query_args_custom_parsing_asgi(app): + @app.get("/") + def handler(request): + return text("pass") + + request, response = await app.asgi_client.get( + "/?test1=value1&test2=&test3=value3" + ) + + assert request.get_query_args(keep_blank_values=True) == [ + ("test1", "value1"), + ("test2", ""), + ("test3", "value3"), + ] + assert request.query_args == [("test1", "value1"), ("test3", "value3")] + assert request.get_query_args(keep_blank_values=False) == [ + ("test1", "value1"), + ("test3", "value3"), + ] + + assert request.get_args(keep_blank_values=True) == RequestParameters( + {"test1": ["value1"], "test2": [""], "test3": ["value3"]} + ) + + assert request.args == RequestParameters( + {"test1": ["value1"], "test3": ["value3"]} + ) + + assert request.get_args(keep_blank_values=False) == RequestParameters( + {"test1": ["value1"], "test3": ["value3"]} + ) + + def test_request_cookies(app): cookies = {"test": "OK"} @@ -865,6 +1558,21 @@ def test_request_cookies(app): assert request.cookies == cookies # For request._cookies +@pytest.mark.asyncio +async def test_request_cookies_asgi(app): + + cookies = {"test": "OK"} + + @app.get("/") + def handler(request): + return text("OK") + + request, response = await app.asgi_client.get("/", cookies=cookies) + + assert request.cookies == cookies + assert request.cookies == cookies # For request._cookies + + def test_request_cookies_without_cookies(app): @app.get("/") def handler(request): @@ -875,6 +1583,17 @@ def test_request_cookies_without_cookies(app): assert request.cookies == {} +@pytest.mark.asyncio +async def test_request_cookies_without_cookies_asgi(app): + @app.get("/") + def handler(request): + return text("OK") + + request, response = await app.asgi_client.get("/") + + assert request.cookies == {} + + def test_request_port(app): @app.get("/") def handler(request): @@ -894,6 +1613,26 @@ def test_request_port(app): assert hasattr(request, "_port") +@pytest.mark.asyncio +async def test_request_port_asgi(app): + @app.get("/") + def handler(request): + return text("OK") + + request, response = await app.asgi_client.get("/") + + port = request.port + assert isinstance(port, int) + + delattr(request, "_socket") + delattr(request, "_port") + + port = request.port + assert isinstance(port, int) + assert hasattr(request, "_socket") + assert hasattr(request, "_port") + + def test_request_socket(app): @app.get("/") def handler(request): @@ -927,6 +1666,17 @@ def test_request_form_invalid_content_type(app): assert request.form == {} +@pytest.mark.asyncio +async def test_request_form_invalid_content_type_asgi(app): + @app.route("/", methods=["POST"]) + async def post(request): + return text("OK") + + request, response = await app.asgi_client.post("/", json={"test": "OK"}) + + assert request.form == {} + + def test_endpoint_basic(): app = Sanic() @@ -939,6 +1689,19 @@ def test_endpoint_basic(): assert request.endpoint == "test_requests.my_unique_handler" +@pytest.mark.asyncio +async def test_endpoint_basic_asgi(): + app = Sanic() + + @app.route("/") + def my_unique_handler(request): + return text("Hello") + + request, response = await app.asgi_client.get("/") + + assert request.endpoint == "test_requests.my_unique_handler" + + def test_endpoint_named_app(): app = Sanic("named") @@ -951,6 +1714,19 @@ def test_endpoint_named_app(): assert request.endpoint == "named.my_unique_handler" +@pytest.mark.asyncio +async def test_endpoint_named_app_asgi(): + app = Sanic("named") + + @app.route("/") + def my_unique_handler(request): + return text("Hello") + + request, response = await app.asgi_client.get("/") + + assert request.endpoint == "named.my_unique_handler" + + def test_endpoint_blueprint(): bp = Blueprint("my_blueprint", url_prefix="/bp") @@ -964,3 +1740,19 @@ def test_endpoint_blueprint(): request, response = app.test_client.get("/bp") assert request.endpoint == "named.my_blueprint.bp_root" + + +@pytest.mark.asyncio +async def test_endpoint_blueprint_asgi(): + bp = Blueprint("my_blueprint", url_prefix="/bp") + + @bp.route("/") + async def bp_root(request): + return text("Hello") + + app = Sanic("named") + app.blueprint(bp) + + request, response = await app.asgi_client.get("/bp") + + assert request.endpoint == "named.my_blueprint.bp_root" diff --git a/tests/test_response.py b/tests/test_response.py index c47dd1db..8feadb06 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -292,7 +292,7 @@ def test_stream_response_writes_correct_content_to_transport_when_chunked( async def mock_drain(): pass - def mock_push_data(data): + async def mock_push_data(data): response.protocol.transport.write(data) response.protocol.push_data = mock_push_data @@ -330,7 +330,7 @@ def test_stream_response_writes_correct_content_to_transport_when_not_chunked( async def mock_drain(): pass - def mock_push_data(data): + async def mock_push_data(data): response.protocol.transport.write(data) response.protocol.push_data = mock_push_data diff --git a/tests/test_server_events.py b/tests/test_server_events.py index be17e801..412f9fa6 100644 --- a/tests/test_server_events.py +++ b/tests/test_server_events.py @@ -76,6 +76,7 @@ def test_all_listeners(app): assert app.name + listener_name == output.pop() +@pytest.mark.asyncio async def test_trigger_before_events_create_server(app): class MySanicDb: pass