diff --git a/sanic/asgi.py b/sanic/asgi.py index 205b7187..be598ec1 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -164,10 +164,12 @@ class ASGIApp: Read and stream the body in chunks from an incoming ASGI message. """ message = await self.transport.receive() + body = message.get("body", b"") if not message.get("more_body", False): self.request_body = False - return None - return message.get("body", b"") + if not body: + return None + return body async def __aiter__(self): while self.request_body: diff --git a/sanic/blueprint_group.py b/sanic/blueprint_group.py index c8a36534..27716d8f 100644 --- a/sanic/blueprint_group.py +++ b/sanic/blueprint_group.py @@ -1,9 +1,13 @@ from collections.abc import MutableSequence -from typing import List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union import sanic +if TYPE_CHECKING: + from sanic.blueprints import Blueprint + + class BlueprintGroup(MutableSequence): """ This class provides a mechanism to implement a Blueprint Group @@ -56,7 +60,12 @@ class BlueprintGroup(MutableSequence): __slots__ = ("_blueprints", "_url_prefix", "_version", "_strict_slashes") - def __init__(self, url_prefix=None, version=None, strict_slashes=None): + def __init__( + self, + url_prefix: Optional[str] = None, + version: Optional[Union[int, str, float]] = None, + strict_slashes: Optional[bool] = None, + ): """ Create a new Blueprint Group @@ -65,13 +74,13 @@ class BlueprintGroup(MutableSequence): inherited by each of the Blueprint :param strict_slashes: URL Strict slash behavior indicator """ - self._blueprints = [] + self._blueprints: List[Blueprint] = [] self._url_prefix = url_prefix self._version = version self._strict_slashes = strict_slashes @property - def url_prefix(self) -> str: + def url_prefix(self) -> Optional[Union[int, str, float]]: """ Retrieve the URL prefix being used for the Current Blueprint Group diff --git a/sanic/blueprints.py b/sanic/blueprints.py index 2d1e6a1d..aec661b2 100644 --- a/sanic/blueprints.py +++ b/sanic/blueprints.py @@ -70,7 +70,7 @@ class Blueprint(BaseSanic): name: str, url_prefix: Optional[str] = None, host: Optional[str] = None, - version: Optional[int] = None, + version: Optional[Union[int, str, float]] = None, strict_slashes: Optional[bool] = None, ): super().__init__() diff --git a/sanic/http.py b/sanic/http.py index 19ed89d3..24e45d4b 100644 --- a/sanic/http.py +++ b/sanic/http.py @@ -82,6 +82,7 @@ class Http: "request_max_size", "response", "response_func", + "response_size", "response_bytes_left", "upgrade_websocket", ] @@ -272,6 +273,7 @@ class Http: size = len(data) headers = res.headers status = res.status + self.response_size = size if not isinstance(status, int) or status < 200: raise RuntimeError(f"Invalid response status {status!r}") @@ -426,7 +428,9 @@ class Http: req, res = self.request, self.response extra = { "status": getattr(res, "status", 0), - "byte": getattr(self, "response_bytes_left", -1), + "byte": getattr( + self, "response_bytes_left", getattr(self, "response_size", -1) + ), "host": "UNKNOWN", "request": "nil", } diff --git a/sanic/mixins/routes.py b/sanic/mixins/routes.py index be494a2a..0b5df17e 100644 --- a/sanic/mixins/routes.py +++ b/sanic/mixins/routes.py @@ -45,7 +45,7 @@ class RouteMixin: host: Optional[str] = None, strict_slashes: Optional[bool] = None, stream: bool = False, - version: Optional[int] = None, + version: Optional[Union[int, str, float]] = None, name: Optional[str] = None, ignore_body: bool = False, apply: bool = True, diff --git a/sanic/router.py b/sanic/router.py index 83067c27..47dd69f4 100644 --- a/sanic/router.py +++ b/sanic/router.py @@ -33,7 +33,7 @@ class Router(BaseRouter): return self.resolve( path=path, method=method, - extra={"host": host}, + extra={"host": host} if host else None, ) except RoutingNotFound as e: raise NotFound("Requested URL {} not found".format(e.path)) @@ -161,7 +161,7 @@ class Router(BaseRouter): @property def routes_all(self): - return self.routes + return {route.parts: route for route in self.routes} @property def routes_static(self): diff --git a/sanic/signals.py b/sanic/signals.py index 0e6f73f1..87263f1f 100644 --- a/sanic/signals.py +++ b/sanic/signals.py @@ -5,7 +5,7 @@ import asyncio from inspect import isawaitable from typing import Any, Dict, List, Optional, Tuple, Union -from sanic_routing import BaseRouter, Route # type: ignore +from sanic_routing import BaseRouter, Route, RouteGroup # type: ignore from sanic_routing.exceptions import NotFound # type: ignore from sanic_routing.utils import path_to_parts # type: ignore @@ -20,17 +20,11 @@ RESERVED_NAMESPACES = ( class Signal(Route): - def get_handler(self, raw_path, method, _): - method = method or self.router.DEFAULT_METHOD - raw_path = raw_path.lstrip(self.router.delimiter) - try: - return self.handlers[raw_path][method] - except (IndexError, KeyError): - raise self.router.method_handler_exception( - f"Method '{method}' not found on {self}", - method=method, - allowed_methods=set(self.methods[raw_path]), - ) + ... + + +class SignalGroup(RouteGroup): + ... class SignalRouter(BaseRouter): @@ -38,6 +32,7 @@ class SignalRouter(BaseRouter): super().__init__( delimiter=".", route_class=Signal, + group_class=SignalGroup, stacking=True, ) self.ctx.loop = None @@ -49,7 +44,13 @@ class SignalRouter(BaseRouter): ): extra = condition or {} try: - return self.resolve(f".{event}", extra=extra) + group, param_basket = self.find_route( + f".{event}", + self.DEFAULT_METHOD, + self, + {"__params__": {}}, + extra=extra, + ) except NotFound: message = "Could not find signal %s" terms: List[Union[str, Optional[Dict[str, str]]]] = [event] @@ -58,16 +59,20 @@ class SignalRouter(BaseRouter): terms.append(extra) raise NotFound(message % tuple(terms)) + params = param_basket.pop("__params__") + return group, [route.handler for route in group], params + async def _dispatch( self, event: str, context: Optional[Dict[str, Any]] = None, condition: Optional[Dict[str, str]] = None, ) -> None: - signal, handlers, params = self.get(event, condition=condition) + group, handlers, params = self.get(event, condition=condition) - signal_event = signal.ctx.event - signal_event.set() + events = [signal.ctx.event for signal in group] + for signal_event in events: + signal_event.set() if context: params.update(context) @@ -78,7 +83,8 @@ class SignalRouter(BaseRouter): if isawaitable(maybe_coroutine): await maybe_coroutine finally: - signal_event.clear() + for signal_event in events: + signal_event.clear() async def dispatch( self, @@ -116,7 +122,7 @@ class SignalRouter(BaseRouter): handler, requirements=condition, name=name, - overwrite=True, + append=True, ) # type: ignore def finalize(self, do_compile: bool = True): @@ -125,7 +131,7 @@ class SignalRouter(BaseRouter): except RuntimeError: raise RuntimeError("Cannot finalize signals outside of event loop") - for signal in self.routes.values(): + for signal in self.routes: signal.ctx.event = asyncio.Event() return super().finalize(do_compile=do_compile) diff --git a/setup.py b/setup.py index b2d31cc1..3a43b11e 100644 --- a/setup.py +++ b/setup.py @@ -83,7 +83,7 @@ ujson = "ujson>=1.35" + env_dependency uvloop = "uvloop>=0.5.3" + env_dependency requirements = [ - "sanic-routing", + "sanic-routing>=0.6.0", "httptools>=0.0.10", uvloop, ujson, diff --git a/tests/test_config.py b/tests/test_config.py index b1336497..ce790800 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -80,7 +80,7 @@ def test_dont_load_env(): del environ["SANIC_TEST_ANSWER"] -@pytest.mark.parametrize('load_env', [None, False, "", "MYAPP_"]) +@pytest.mark.parametrize("load_env", [None, False, "", "MYAPP_"]) def test_load_env_deprecation(load_env): with pytest.warns(DeprecationWarning, match=r"21\.12"): _ = Sanic(name=__name__, load_env=load_env) @@ -93,7 +93,7 @@ def test_load_env_prefix(): del environ["MYAPP_TEST_ANSWER"] -@pytest.mark.parametrize('env_prefix', [None, ""]) +@pytest.mark.parametrize("env_prefix", [None, ""]) def test_empty_load_env_prefix(env_prefix): environ["SANIC_TEST_ANSWER"] = "42" app = Sanic(name=__name__, env_prefix=env_prefix) diff --git a/tests/test_named_routes.py b/tests/test_named_routes.py index e748d529..39108594 100644 --- a/tests/test_named_routes.py +++ b/tests/test_named_routes.py @@ -209,13 +209,13 @@ def test_named_static_routes(): return text("OK2") assert app.router.routes_all[("test",)].name == "app.route_test" - assert app.router.routes_static[("test",)].name == "app.route_test" + assert app.router.routes_static[("test",)][0].name == "app.route_test" assert app.url_for("route_test") == "/test" with pytest.raises(URLBuildError): app.url_for("handler1") assert app.router.routes_all[("pizazz",)].name == "app.route_pizazz" - assert app.router.routes_static[("pizazz",)].name == "app.route_pizazz" + assert app.router.routes_static[("pizazz",)][0].name == "app.route_pizazz" assert app.url_for("route_pizazz") == "/pizazz" with pytest.raises(URLBuildError): app.url_for("handler2") @@ -347,13 +347,13 @@ def test_static_add_named_route(): app.add_route(handler2, "/test2", name="route_test2") assert app.router.routes_all[("test",)].name == "app.route_test" - assert app.router.routes_static[("test",)].name == "app.route_test" + assert app.router.routes_static[("test",)][0].name == "app.route_test" assert app.url_for("route_test") == "/test" with pytest.raises(URLBuildError): app.url_for("handler1") assert app.router.routes_all[("test2",)].name == "app.route_test2" - assert app.router.routes_static[("test2",)].name == "app.route_test2" + assert app.router.routes_static[("test2",)][0].name == "app.route_test2" assert app.url_for("route_test2") == "/test2" with pytest.raises(URLBuildError): app.url_for("handler2") diff --git a/tests/test_request.py b/tests/test_request.py index 0cbf0994..049b152f 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -106,7 +106,7 @@ def test_route_assigned_to_request(app): return response.empty() request, _ = app.test_client.get("/") - assert request.route is list(app.router.routes.values())[0] + assert request.route is list(app.router.routes)[0] def test_protocol_attribute(app): diff --git a/tests/test_requests.py b/tests/test_requests.py index 665f936d..b459a605 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -253,6 +253,31 @@ async def test_empty_json_asgi(app): assert response.body == b"null" +def test_echo_json(app): + @app.post("/") + async def handler(request): + return json(request.json) + + data = {"foo": "bar"} + request, response = app.test_client.post("/", json=data) + + assert response.status == 200 + assert response.json == data + + +@pytest.mark.asyncio +async def test_echo_json_asgi(app): + @app.post("/") + async def handler(request): + return json(request.json) + + data = {"foo": "bar"} + request, response = await app.asgi_client.post("/", json=data) + + assert response.status == 200 + assert response.json == data + + def test_invalid_json(app): @app.post("/") async def handler(request): @@ -291,18 +316,18 @@ def test_query_string(app): assert request.args.getlist("test1") == ["1"] assert request.args.get("test3", default="My value") == "My value" + def test_popped_stays_popped(app): @app.route("/") async def handler(request): return text("OK") - request, response = app.test_client.get( - "/", params=[("test1", "1")] - ) + request, response = app.test_client.get("/", params=[("test1", "1")]) assert request.args.pop("test1") == ["1"] assert "test1" not in request.args + @pytest.mark.asyncio async def test_query_string_asgi(app): @app.route("/") @@ -2170,3 +2195,72 @@ def test_safe_method_with_body(app): assert request.body == data.encode("utf-8") assert request.json.get("test") == "OK" assert response.body == b"OK" + + +def test_conflicting_body_methods_overload(app): + @app.put("/") + @app.put("/p/") + @app.put("/p/") + async def put(request, foo=None): + return json( + {"name": request.route.name, "body": str(request.body), "foo": foo} + ) + + @app.delete("/p/") + async def delete(request, foo): + return json( + {"name": request.route.name, "body": str(request.body), "foo": foo} + ) + + payload = {"test": "OK"} + data = str(json_dumps(payload).encode()) + + _, response = app.test_client.put("/", json=payload) + assert response.status == 200 + assert response.json == { + "name": "test_conflicting_body_methods_overload.put", + "foo": None, + "body": data, + } + _, response = app.test_client.put("/p", json=payload) + assert response.status == 200 + assert response.json == { + "name": "test_conflicting_body_methods_overload.put", + "foo": None, + "body": data, + } + _, response = app.test_client.put("/p/test", json=payload) + assert response.status == 200 + assert response.json == { + "name": "test_conflicting_body_methods_overload.put", + "foo": "test", + "body": data, + } + _, response = app.test_client.delete("/p/test") + assert response.status == 200 + assert response.json == { + "name": "test_conflicting_body_methods_overload.delete", + "foo": "test", + "body": str("".encode()), + } + + +def test_handler_overload(app): + @app.get( + "/long/sub/route/param_a//param_b/" + ) + @app.post("/long/sub/route/") + def handler(request, **kwargs): + return json(kwargs) + + _, response = app.test_client.get( + "/long/sub/route/param_a/foo/param_b/bar" + ) + assert response.status == 200 + assert response.json == { + "param_a": "foo", + "param_b": "bar", + } + _, response = app.test_client.post("/long/sub/route") + assert response.status == 200 + assert response.json == {} diff --git a/tests/test_response.py b/tests/test_response.py index 4c107eae..f9803a41 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -60,7 +60,9 @@ def test_method_not_allowed(): } request, response = app.test_client.post("/") - assert set(response.headers["Allow"].split(", ")) == {"GET", "HEAD"} + assert set(response.headers["Allow"].split(", ")) == { + "GET", + } app.router.reset() @@ -73,7 +75,6 @@ def test_method_not_allowed(): assert set(response.headers["Allow"].split(", ")) == { "GET", "POST", - "HEAD", } assert response.headers["Content-Length"] == "0" @@ -82,7 +83,6 @@ def test_method_not_allowed(): assert set(response.headers["Allow"].split(", ")) == { "GET", "POST", - "HEAD", } assert response.headers["Content-Length"] == "0" diff --git a/tests/test_routes.py b/tests/test_routes.py index 4b2927ef..06b4d799 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -543,9 +543,6 @@ def test_dynamic_route_regex(app): async def handler(request, folder_id): return text("OK") - app.router.finalize() - print(app.router.find_route_src) - request, response = app.test_client.get("/folder/test") assert response.status == 200 @@ -587,6 +584,9 @@ def test_dynamic_route_path(app): async def handler(request, path): return text("OK") + app.router.finalize() + print(app.router.find_route_src) + request, response = app.test_client.get("/path/1/info") assert response.status == 200 @@ -1008,14 +1008,8 @@ def test_unmergeable_overload_routes(app): async def handler2(request): return text("OK1") - assert ( - len( - dict(list(app.router.static_routes.values())[0].handlers)[ - "overload_whole" - ] - ) - == 3 - ) + assert len(app.router.static_routes) == 1 + assert len(app.router.static_routes[("overload_whole",)].methods) == 3 request, response = app.test_client.get("/overload_whole") assert response.text == "OK1" diff --git a/tests/test_signals.py b/tests/test_signals.py index 915d92d6..05fd6d11 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -28,7 +28,8 @@ def test_add_signal_decorator(app): async def async_signal(*_): ... - assert len(app.signal_router.routes) == 1 + assert len(app.signal_router.routes) == 2 + assert len(app.signal_router.dynamic_routes) == 1 @pytest.mark.parametrize( @@ -79,13 +80,13 @@ async def test_dispatch_signal_triggers_triggers_event(app): def sync_signal(*args): nonlocal app nonlocal counter - signal, *_ = app.signal_router.get("foo.bar.baz") - counter += signal.ctx.event.is_set() + group, *_ = app.signal_router.get("foo.bar.baz") + for signal in group: + counter += signal.ctx.event.is_set() app.signal_router.finalize() await app.dispatch("foo.bar.baz") - signal, *_ = app.signal_router.get("foo.bar.baz") assert counter == 1 @@ -224,7 +225,7 @@ async def test_dispatch_signal_triggers_event_on_bp(app): app.blueprint(bp) app.signal_router.finalize() - signal, *_ = app.signal_router.get( + signal_group, *_ = app.signal_router.get( "foo.bar.baz", condition={"blueprint": "bp"} ) @@ -233,7 +234,8 @@ async def test_dispatch_signal_triggers_event_on_bp(app): assert isawaitable(waiter) fut = asyncio.ensure_future(do_wait()) - signal.ctx.event.set() + for signal in signal_group: + signal.ctx.event.set() await fut assert bp_counter == 1 diff --git a/tests/test_static.py b/tests/test_static.py index acef3468..d702ca69 100644 --- a/tests/test_static.py +++ b/tests/test_static.py @@ -490,3 +490,20 @@ def test_no_stack_trace_on_not_found(app, static_file_directory, caplog): assert counter[logging.INFO] == 5 assert logging.ERROR not in counter assert response.text == "No file: /static/non_existing_file.file" + + +def test_multiple_statics(app, static_file_directory): + app.static("/file", get_file_path(static_file_directory, "test.file")) + app.static("/png", get_file_path(static_file_directory, "python.png")) + + _, response = app.test_client.get("/file") + assert response.status == 200 + assert response.body == get_file_content( + static_file_directory, "test.file" + ) + + _, response = app.test_client.get("/png") + assert response.status == 200 + assert response.body == get_file_content( + static_file_directory, "python.png" + )