From 27f64ddae2bd5a3ce40387f493d3e1b1068cdac7 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Mon, 1 Mar 2021 15:30:52 +0200 Subject: [PATCH] Router tweaks (#2031) * Add trailing slash when defined and strict_slashes * Add partial matching, and fix some issues with url_for * Cover additional edge cases * cleanup tests --- sanic/app.py | 36 +++++----- sanic/mixins/routes.py | 24 ++++--- sanic/request.py | 5 ++ sanic/router.py | 49 +++++++------- .../test_route_resolution_benchmark.py | 4 +- tests/test_app.py | 4 +- tests/test_blueprints.py | 2 +- tests/test_exceptions_handler.py | 1 - tests/test_reloader.py | 1 - tests/test_request.py | 9 +++ tests/test_request_stream.py | 1 - tests/test_routes.py | 66 ++++++++++++++++++- tests/test_url_building.py | 20 ++++++ tests/test_url_for.py | 16 +++++ 14 files changed, 175 insertions(+), 63 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index 5dc2efeb..c3c2b9ea 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -393,17 +393,22 @@ class Sanic(BaseSanic): if getattr(route.ctx, "static", None): filename = kwargs.pop("filename", "") # it's static folder - if "file_uri" in uri: - folder_ = uri.split(" Tuple[RouteHandler, Dict[str, Any], str, str, bool]: + ) -> Tuple[Route, RouteHandler, Dict[str, Any]]: try: route, handler, params = self.resolve( path=path, @@ -50,14 +51,14 @@ class Router(BaseRouter): ) return ( + route, handler, params, - route.path, - route.name, - route.ctx.ignore_body, ) - def get(self, request: Request): + def get( # type: ignore + self, request: Request + ) -> Tuple[Route, RouteHandler, Dict[str, Any]]: """ Retrieve a `Route` object containg the details about how to handle a response for a given request @@ -66,14 +67,13 @@ class Router(BaseRouter): :type request: Request :return: details needed for handling the request and returning the correct response - :rtype: Tuple[ RouteHandler, Tuple[Any, ...], Dict[str, Any], str, str, - Optional[str], bool, ] + :rtype: Tuple[ Route, RouteHandler, Dict[str, Any]] """ return self._get( request.path, request.method, request.headers.get("host") ) - def add( + def add( # type: ignore self, uri: str, methods: Iterable[str], @@ -138,7 +138,7 @@ class Router(BaseRouter): if host: params.update({"requirements": {"host": host}}) - route = super().add(**params) + route = super().add(**params) # type: ignore route.ctx.ignore_body = ignore_body route.ctx.stream = stream route.ctx.hosts = hosts @@ -150,23 +150,6 @@ class Router(BaseRouter): return routes[0] return routes - def is_stream_handler(self, request) -> bool: - """ - Handler for request is stream or not. - - :param request: Request object - :return: bool - """ - try: - handler = self.get(request)[0] - except (NotFound, MethodNotSupported): - return False - if hasattr(handler, "view_class") and hasattr( - handler.view_class, request.method.lower() - ): - handler = getattr(handler.view_class, request.method.lower()) - return hasattr(handler, "is_stream") - @lru_cache(maxsize=ROUTER_CACHE_SIZE) def find_route_by_view_name(self, view_name, name=None): """ @@ -204,3 +187,15 @@ class Router(BaseRouter): @property def routes_regex(self): return self.regex_routes + + def finalize(self, *args, **kwargs): + super().finalize(*args, **kwargs) + + for route in self.dynamic_routes.values(): + if any( + label.startswith("__") and label not in ALLOWED_LABELS + for label in route.labels + ): + raise SanicException( + f"Invalid route: {route}. Parameter names cannot use '__'." + ) diff --git a/tests/benchmark/test_route_resolution_benchmark.py b/tests/benchmark/test_route_resolution_benchmark.py index 467254a4..7d857bb3 100644 --- a/tests/benchmark/test_route_resolution_benchmark.py +++ b/tests/benchmark/test_route_resolution_benchmark.py @@ -39,7 +39,7 @@ class TestSanicRouteResolution: iterations=1000, rounds=1000, ) - assert await result[0](None) == 1 + assert await result[1](None) == 1 @mark.asyncio async def test_resolve_route_with_typed_args( @@ -72,4 +72,4 @@ class TestSanicRouteResolution: iterations=1000, rounds=1000, ) - assert await result[0](None) == 1 + assert await result[1](None) == 1 diff --git a/tests/test_app.py b/tests/test_app.py index 6ed0293e..edcb3b1d 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -4,7 +4,7 @@ import sys from inspect import isawaitable from os import environ -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest @@ -123,7 +123,7 @@ def test_app_route_raise_value_error(app): def test_app_handle_request_handler_is_none(app, monkeypatch): def mockreturn(*args, **kwargs): - return None, {}, "", "", False + return Mock(), None, {} # Not sure how to make app.router.get() return None, so use mock here. monkeypatch.setattr(app.router, "get", mockreturn) diff --git a/tests/test_blueprints.py b/tests/test_blueprints.py index 88055b57..d7bf231c 100644 --- a/tests/test_blueprints.py +++ b/tests/test_blueprints.py @@ -752,7 +752,7 @@ def test_static_blueprint_name(static_file_directory, file_name): app.blueprint(bp) uri = app.url_for("static", name="static.testing") - assert uri == "/static/test.file" + assert uri == "/static/test.file/" _, response = app.test_client.get("/static/test.file") assert response.status == 404 diff --git a/tests/test_exceptions_handler.py b/tests/test_exceptions_handler.py index 9c724182..e6fd42eb 100644 --- a/tests/test_exceptions_handler.py +++ b/tests/test_exceptions_handler.py @@ -126,7 +126,6 @@ def test_html_traceback_output_in_debug_mode(): assert response.status == 500 soup = BeautifulSoup(response.body, "html.parser") html = str(soup) - print(html) assert "response = handler(request, **kwargs)" in html assert "handler_4" in html diff --git a/tests/test_reloader.py b/tests/test_reloader.py index d2e5ff6b..90b29343 100644 --- a/tests/test_reloader.py +++ b/tests/test_reloader.py @@ -59,7 +59,6 @@ def write_app(filename, **runargs): def scanner(proc): for line in proc.stdout: line = line.decode().strip() - print(">", line) if line.startswith("complete"): yield line diff --git a/tests/test_request.py b/tests/test_request.py index 2ac35efe..43d96386 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -74,3 +74,12 @@ def test_custom_generator(): "/", headers={"SOME-OTHER-REQUEST-ID": f"{REQUEST_ID}"} ) assert request.id == REQUEST_ID * 2 + + +def test_route_assigned_to_request(app): + @app.get("/") + async def get(request): + return response.empty() + + request, _ = app.test_client.get("/") + assert request.route is list(app.router.routes.values())[0] diff --git a/tests/test_request_stream.py b/tests/test_request_stream.py index 0261bc98..e39524c6 100644 --- a/tests/test_request_stream.py +++ b/tests/test_request_stream.py @@ -646,7 +646,6 @@ def test_streaming_echo(): data = await reader.read(4096) assert data buffer += data - print(res) assert buffer[size : size + 2] == b"\r\n" ret, buffer = buffer[:size], buffer[size + 2 :] return ret diff --git a/tests/test_routes.py b/tests/test_routes.py index f3477400..08a5e92c 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -1,15 +1,20 @@ import asyncio +import re from unittest.mock import Mock import pytest -from sanic_routing.exceptions import ParameterNameConflicts, RouteExists +from sanic_routing.exceptions import ( + InvalidUsage, + ParameterNameConflicts, + RouteExists, +) from sanic_testing.testing import SanicTestClient from sanic import Blueprint, Sanic from sanic.constants import HTTP_METHODS -from sanic.exceptions import NotFound +from sanic.exceptions import NotFound, SanicException from sanic.request import Request from sanic.response import json, text @@ -189,7 +194,6 @@ def test_versioned_routes_get(app, method): return text("OK") else: - print(func) raise Exception(f"Method: {method} is not callable") client_method = getattr(app.test_client, method) @@ -1113,3 +1117,59 @@ def test_route_invalid_host(app): assert str(excinfo.value) == ( "Expected either string or Iterable of " "host strings, not {!r}" ).format(host) + + +def test_route_with_regex_group(app): + @app.route("/path/to/") + async def handler(request, ext): + return text(ext) + + _, response = app.test_client.get("/path/to/file.txt") + assert response.text == "txt" + + +def test_route_with_regex_named_group(app): + @app.route(r"/path/to/txt)>") + async def handler(request, ext): + return text(ext) + + _, response = app.test_client.get("/path/to/file.txt") + assert response.text == "txt" + + +def test_route_with_regex_named_group_invalid(app): + @app.route(r"/path/to/txt)>") + async def handler(request, ext): + return text(ext) + + with pytest.raises(InvalidUsage) as e: + app.router.finalize() + + assert e.match( + re.escape("Named group (wrong) must match your named parameter (ext)") + ) + + +def test_route_with_regex_group_ambiguous(app): + @app.route("/path/to/") + async def handler(request, ext): + return text(ext) + + with pytest.raises(InvalidUsage) as e: + app.router.finalize() + + assert e.match( + re.escape( + "Could not compile pattern file(?:\.)(txt). Try using a named " + "group instead: '(?Pyour_matching_group)'" + ) + ) + + +def test_route_with_bad_named_param(app): + @app.route("/foo/<__bar__>") + async def handler(request): + return text("...") + + with pytest.raises(SanicException): + app.router.finalize() diff --git a/tests/test_url_building.py b/tests/test_url_building.py index 5d9fcf41..2898dbd9 100644 --- a/tests/test_url_building.py +++ b/tests/test_url_building.py @@ -344,3 +344,23 @@ def test_methodview_naming(methodview_app): assert viewone_url == "/view_one" assert viewtwo_url == "/view_two" + + +@pytest.mark.parametrize( + "path,version,expected", + ( + ("/foo", 1, "/v1/foo"), + ("/foo", 1.1, "/v1.1/foo"), + ("/foo", "1", "/v1/foo"), + ("/foo", "1.1", "/v1.1/foo"), + ("/foo", "1.0.1", "/v1.0.1/foo"), + ("/foo", "v1.0.1", "/v1.0.1/foo"), + ), +) +def test_versioning(app, path, version, expected): + @app.route(path, version=version) + def handler(*_): + ... + + url = app.url_for("handler") + assert url == expected diff --git a/tests/test_url_for.py b/tests/test_url_for.py index bf9a4722..d623cc4a 100644 --- a/tests/test_url_for.py +++ b/tests/test_url_for.py @@ -88,3 +88,19 @@ def test_websocket_bp_route_name(app): # TODO: add test with a route with multiple hosts # TODO: add test with a route with _host in url_for +@pytest.mark.parametrize( + "path,strict,expected", + ( + ("/foo", False, "/foo"), + ("/foo/", False, "/foo"), + ("/foo", True, "/foo"), + ("/foo/", True, "/foo/"), + ), +) +def test_trailing_slash_url_for(app, path, strict, expected): + @app.route(path, strict_slashes=strict) + def handler(*_): + ... + + url = app.url_for("handler") + assert url == expected