Router tweaks (#2031)
* Add trailing slash when defined and strict_slashes * Add partial matching, and fix some issues with url_for * Cover additional edge cases * cleanup tests
This commit is contained in:
		
							
								
								
									
										36
									
								
								sanic/app.py
									
									
									
									
									
								
							
							
						
						
									
										36
									
								
								sanic/app.py
									
									
									
									
									
								
							| @@ -393,17 +393,22 @@ class Sanic(BaseSanic): | ||||
|         if getattr(route.ctx, "static", None): | ||||
|             filename = kwargs.pop("filename", "") | ||||
|             # it's static folder | ||||
|             if "file_uri" in uri: | ||||
|                 folder_ = uri.split("<file_uri:", 1)[0] | ||||
|             if "__file_uri__" in uri: | ||||
|                 folder_ = uri.split("<__file_uri__:", 1)[0] | ||||
|                 if folder_.endswith("/"): | ||||
|                     folder_ = folder_[:-1] | ||||
|  | ||||
|                 if filename.startswith("/"): | ||||
|                     filename = filename[1:] | ||||
|  | ||||
|                 kwargs["file_uri"] = filename | ||||
|                 kwargs["__file_uri__"] = filename | ||||
|  | ||||
|         if uri != "/" and uri.endswith("/"): | ||||
|         if ( | ||||
|             uri != "/" | ||||
|             and uri.endswith("/") | ||||
|             and not route.strict | ||||
|             and not route.raw_path[:-1] | ||||
|         ): | ||||
|             uri = uri[:-1] | ||||
|  | ||||
|         if not uri.startswith("/"): | ||||
| @@ -573,25 +578,27 @@ class Sanic(BaseSanic): | ||||
|         # Define `response` var here to remove warnings about | ||||
|         # allocation before assignment below. | ||||
|         response = None | ||||
|         name = None | ||||
|         try: | ||||
|             # Fetch handler from router | ||||
|             ( | ||||
|                 route, | ||||
|                 handler, | ||||
|                 kwargs, | ||||
|                 uri, | ||||
|                 name, | ||||
|                 ignore_body, | ||||
|             ) = self.router.get(request) | ||||
|             request.name = name | ||||
|  | ||||
|             request._match_info = kwargs | ||||
|             request.route = route | ||||
|             request.name = route.name | ||||
|             request.uri_template = f"/{route.path}" | ||||
|             request.endpoint = request.name | ||||
|  | ||||
|             if ( | ||||
|                 request.stream | ||||
|                 and request.stream.request_body | ||||
|                 and not ignore_body | ||||
|                 and not route.ctx.ignore_body | ||||
|             ): | ||||
|                 if self.router.is_stream_handler(request): | ||||
|  | ||||
|                 if hasattr(handler, "is_stream"): | ||||
|                     # Streaming handler: lift the size limit | ||||
|                     request.stream.request_max_size = float("inf") | ||||
|                 else: | ||||
| @@ -602,15 +609,15 @@ class Sanic(BaseSanic): | ||||
|             # Request Middleware | ||||
|             # -------------------------------------------- # | ||||
|             response = await self._run_request_middleware( | ||||
|                 request, request_name=name | ||||
|                 request, request_name=route.name | ||||
|             ) | ||||
|  | ||||
|             # No middleware results | ||||
|             if not response: | ||||
|                 # -------------------------------------------- # | ||||
|                 # Execute Handler | ||||
|                 # -------------------------------------------- # | ||||
|  | ||||
|                 request.uri_template = f"/{uri}" | ||||
|                 if handler is None: | ||||
|                     raise ServerError( | ||||
|                         ( | ||||
| @@ -619,12 +626,11 @@ class Sanic(BaseSanic): | ||||
|                         ) | ||||
|                     ) | ||||
|  | ||||
|                 request.endpoint = request.name | ||||
|  | ||||
|                 # Run response handler | ||||
|                 response = handler(request, **kwargs) | ||||
|                 if isawaitable(response): | ||||
|                     response = await response | ||||
|  | ||||
|             if response: | ||||
|                 response = await request.respond(response) | ||||
|             else: | ||||
|   | ||||
| @@ -611,7 +611,7 @@ class RouteMixin: | ||||
|                 else: | ||||
|                     break | ||||
|  | ||||
|         if not name:  # noq | ||||
|         if not name:  # noqa | ||||
|             raise ValueError("Could not generate a name for handler") | ||||
|  | ||||
|         if not name.startswith(f"{self.name}."): | ||||
| @@ -627,19 +627,19 @@ class RouteMixin: | ||||
|         stream_large_files, | ||||
|         request, | ||||
|         content_type=None, | ||||
|         file_uri=None, | ||||
|         __file_uri__=None, | ||||
|     ): | ||||
|         # Using this to determine if the URL is trying to break out of the path | ||||
|         # served.  os.path.realpath seems to be very slow | ||||
|         if file_uri and "../" in file_uri: | ||||
|         if __file_uri__ and "../" in __file_uri__: | ||||
|             raise InvalidUsage("Invalid URL") | ||||
|         # Merge served directory and requested file if provided | ||||
|         # Strip all / that in the beginning of the URL to help prevent python | ||||
|         # from herping a derp and treating the uri as an absolute path | ||||
|         root_path = file_path = file_or_directory | ||||
|         if file_uri: | ||||
|         if __file_uri__: | ||||
|             file_path = path.join( | ||||
|                 file_or_directory, sub("^[/]*", "", file_uri) | ||||
|                 file_or_directory, sub("^[/]*", "", __file_uri__) | ||||
|             ) | ||||
|  | ||||
|         # URL decode the path sent by the browser otherwise we won't be able to | ||||
| @@ -648,10 +648,12 @@ class RouteMixin: | ||||
|         if not file_path.startswith(path.abspath(unquote(root_path))): | ||||
|             error_logger.exception( | ||||
|                 f"File not found: path={file_or_directory}, " | ||||
|                 f"relative_url={file_uri}" | ||||
|                 f"relative_url={__file_uri__}" | ||||
|             ) | ||||
|             raise FileNotFound( | ||||
|                 "File not found", path=file_or_directory, relative_url=file_uri | ||||
|                 "File not found", | ||||
|                 path=file_or_directory, | ||||
|                 relative_url=__file_uri__, | ||||
|             ) | ||||
|         try: | ||||
|             headers = {} | ||||
| @@ -719,10 +721,12 @@ class RouteMixin: | ||||
|         except Exception: | ||||
|             error_logger.exception( | ||||
|                 f"File not found: path={file_or_directory}, " | ||||
|                 f"relative_url={file_uri}" | ||||
|                 f"relative_url={__file_uri__}" | ||||
|             ) | ||||
|             raise FileNotFound( | ||||
|                 "File not found", path=file_or_directory, relative_url=file_uri | ||||
|                 "File not found", | ||||
|                 path=file_or_directory, | ||||
|                 relative_url=__file_uri__, | ||||
|             ) | ||||
|  | ||||
|     def _register_static( | ||||
| @@ -772,7 +776,7 @@ class RouteMixin: | ||||
|         # If we're not trying to match a file directly, | ||||
|         # serve from the folder | ||||
|         if not path.isfile(file_or_directory): | ||||
|             uri += "/<file_uri>" | ||||
|             uri += "/<__file_uri__>" | ||||
|  | ||||
|         # special prefix for static files | ||||
|         # if not static.name.startswith("_static_"): | ||||
|   | ||||
| @@ -12,6 +12,8 @@ from typing import ( | ||||
|     Union, | ||||
| ) | ||||
|  | ||||
| from sanic_routing.route import Route  # type: ignore | ||||
|  | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from sanic.server import ConnInfo | ||||
| @@ -104,6 +106,7 @@ class Request: | ||||
|         "parsed_forwarded", | ||||
|         "raw_url", | ||||
|         "request_middleware_started", | ||||
|         "route", | ||||
|         "stream", | ||||
|         "transport", | ||||
|         "uri_template", | ||||
| @@ -151,6 +154,7 @@ class Request: | ||||
|         self._match_info: Dict[str, Any] = {} | ||||
|         self.stream: Optional[Http] = None | ||||
|         self.endpoint: Optional[str] = None | ||||
|         self.route: Optional[Route] = None | ||||
|  | ||||
|     def __repr__(self): | ||||
|         class_name = self.__class__.__name__ | ||||
| @@ -431,6 +435,7 @@ class Request: | ||||
|         :return: Incoming cookies on the request | ||||
|         :rtype: Dict[str, str] | ||||
|         """ | ||||
|  | ||||
|         if self._cookies is None: | ||||
|             cookie = self.headers.get("Cookie") | ||||
|             if cookie is not None: | ||||
|   | ||||
| @@ -9,12 +9,13 @@ from sanic_routing.exceptions import ( | ||||
| from sanic_routing.route import Route  # type: ignore | ||||
|  | ||||
| from sanic.constants import HTTP_METHODS | ||||
| from sanic.exceptions import MethodNotSupported, NotFound | ||||
| from sanic.exceptions import MethodNotSupported, NotFound, SanicException | ||||
| from sanic.handlers import RouteHandler | ||||
| from sanic.request import Request | ||||
|  | ||||
|  | ||||
| ROUTER_CACHE_SIZE = 1024 | ||||
| ALLOWED_LABELS = ("__file_uri__",) | ||||
|  | ||||
|  | ||||
| class Router(BaseRouter): | ||||
| @@ -33,7 +34,7 @@ class Router(BaseRouter): | ||||
|     @lru_cache(maxsize=ROUTER_CACHE_SIZE) | ||||
|     def _get( | ||||
|         self, path, method, host | ||||
|     ) -> Tuple[RouteHandler, Dict[str, Any], str, str, bool]: | ||||
|     ) -> Tuple[Route, RouteHandler, Dict[str, Any]]: | ||||
|         try: | ||||
|             route, handler, params = self.resolve( | ||||
|                 path=path, | ||||
| @@ -50,14 +51,14 @@ class Router(BaseRouter): | ||||
|             ) | ||||
|  | ||||
|         return ( | ||||
|             route, | ||||
|             handler, | ||||
|             params, | ||||
|             route.path, | ||||
|             route.name, | ||||
|             route.ctx.ignore_body, | ||||
|         ) | ||||
|  | ||||
|     def get(self, request: Request): | ||||
|     def get(  # type: ignore | ||||
|         self, request: Request | ||||
|     ) -> Tuple[Route, RouteHandler, Dict[str, Any]]: | ||||
|         """ | ||||
|         Retrieve a `Route` object containg the details about how to handle | ||||
|         a response for a given request | ||||
| @@ -66,14 +67,13 @@ class Router(BaseRouter): | ||||
|         :type request: Request | ||||
|         :return: details needed for handling the request and returning the | ||||
|             correct response | ||||
|         :rtype: Tuple[ RouteHandler, Tuple[Any, ...], Dict[str, Any], str, str, | ||||
|             Optional[str], bool, ] | ||||
|         :rtype: Tuple[ Route, RouteHandler, Dict[str, Any]] | ||||
|         """ | ||||
|         return self._get( | ||||
|             request.path, request.method, request.headers.get("host") | ||||
|         ) | ||||
|  | ||||
|     def add( | ||||
|     def add(  # type: ignore | ||||
|         self, | ||||
|         uri: str, | ||||
|         methods: Iterable[str], | ||||
| @@ -138,7 +138,7 @@ class Router(BaseRouter): | ||||
|             if host: | ||||
|                 params.update({"requirements": {"host": host}}) | ||||
|  | ||||
|             route = super().add(**params) | ||||
|             route = super().add(**params)  # type: ignore | ||||
|             route.ctx.ignore_body = ignore_body | ||||
|             route.ctx.stream = stream | ||||
|             route.ctx.hosts = hosts | ||||
| @@ -150,23 +150,6 @@ class Router(BaseRouter): | ||||
|             return routes[0] | ||||
|         return routes | ||||
|  | ||||
|     def is_stream_handler(self, request) -> bool: | ||||
|         """ | ||||
|         Handler for request is stream or not. | ||||
|  | ||||
|         :param request: Request object | ||||
|         :return: bool | ||||
|         """ | ||||
|         try: | ||||
|             handler = self.get(request)[0] | ||||
|         except (NotFound, MethodNotSupported): | ||||
|             return False | ||||
|         if hasattr(handler, "view_class") and hasattr( | ||||
|             handler.view_class, request.method.lower() | ||||
|         ): | ||||
|             handler = getattr(handler.view_class, request.method.lower()) | ||||
|         return hasattr(handler, "is_stream") | ||||
|  | ||||
|     @lru_cache(maxsize=ROUTER_CACHE_SIZE) | ||||
|     def find_route_by_view_name(self, view_name, name=None): | ||||
|         """ | ||||
| @@ -204,3 +187,15 @@ class Router(BaseRouter): | ||||
|     @property | ||||
|     def routes_regex(self): | ||||
|         return self.regex_routes | ||||
|  | ||||
|     def finalize(self, *args, **kwargs): | ||||
|         super().finalize(*args, **kwargs) | ||||
|  | ||||
|         for route in self.dynamic_routes.values(): | ||||
|             if any( | ||||
|                 label.startswith("__") and label not in ALLOWED_LABELS | ||||
|                 for label in route.labels | ||||
|             ): | ||||
|                 raise SanicException( | ||||
|                     f"Invalid route: {route}. Parameter names cannot use '__'." | ||||
|                 ) | ||||
|   | ||||
| @@ -39,7 +39,7 @@ class TestSanicRouteResolution: | ||||
|             iterations=1000, | ||||
|             rounds=1000, | ||||
|         ) | ||||
|         assert await result[0](None) == 1 | ||||
|         assert await result[1](None) == 1 | ||||
|  | ||||
|     @mark.asyncio | ||||
|     async def test_resolve_route_with_typed_args( | ||||
| @@ -72,4 +72,4 @@ class TestSanicRouteResolution: | ||||
|             iterations=1000, | ||||
|             rounds=1000, | ||||
|         ) | ||||
|         assert await result[0](None) == 1 | ||||
|         assert await result[1](None) == 1 | ||||
|   | ||||
| @@ -4,7 +4,7 @@ import sys | ||||
|  | ||||
| from inspect import isawaitable | ||||
| from os import environ | ||||
| from unittest.mock import patch | ||||
| from unittest.mock import Mock, patch | ||||
|  | ||||
| import pytest | ||||
|  | ||||
| @@ -123,7 +123,7 @@ def test_app_route_raise_value_error(app): | ||||
|  | ||||
| def test_app_handle_request_handler_is_none(app, monkeypatch): | ||||
|     def mockreturn(*args, **kwargs): | ||||
|         return None, {}, "", "", False | ||||
|         return Mock(), None, {} | ||||
|  | ||||
|     # Not sure how to make app.router.get() return None, so use mock here. | ||||
|     monkeypatch.setattr(app.router, "get", mockreturn) | ||||
|   | ||||
| @@ -752,7 +752,7 @@ def test_static_blueprint_name(static_file_directory, file_name): | ||||
|     app.blueprint(bp) | ||||
|  | ||||
|     uri = app.url_for("static", name="static.testing") | ||||
|     assert uri == "/static/test.file" | ||||
|     assert uri == "/static/test.file/" | ||||
|  | ||||
|     _, response = app.test_client.get("/static/test.file") | ||||
|     assert response.status == 404 | ||||
|   | ||||
| @@ -126,7 +126,6 @@ def test_html_traceback_output_in_debug_mode(): | ||||
|     assert response.status == 500 | ||||
|     soup = BeautifulSoup(response.body, "html.parser") | ||||
|     html = str(soup) | ||||
|     print(html) | ||||
|  | ||||
|     assert "response = handler(request, **kwargs)" in html | ||||
|     assert "handler_4" in html | ||||
|   | ||||
| @@ -59,7 +59,6 @@ def write_app(filename, **runargs): | ||||
| def scanner(proc): | ||||
|     for line in proc.stdout: | ||||
|         line = line.decode().strip() | ||||
|         print(">", line) | ||||
|         if line.startswith("complete"): | ||||
|             yield line | ||||
|  | ||||
|   | ||||
| @@ -74,3 +74,12 @@ def test_custom_generator(): | ||||
|         "/", headers={"SOME-OTHER-REQUEST-ID": f"{REQUEST_ID}"} | ||||
|     ) | ||||
|     assert request.id == REQUEST_ID * 2 | ||||
|  | ||||
|  | ||||
| def test_route_assigned_to_request(app): | ||||
|     @app.get("/") | ||||
|     async def get(request): | ||||
|         return response.empty() | ||||
|  | ||||
|     request, _ = app.test_client.get("/") | ||||
|     assert request.route is list(app.router.routes.values())[0] | ||||
|   | ||||
| @@ -646,7 +646,6 @@ def test_streaming_echo(): | ||||
|                 data = await reader.read(4096) | ||||
|                 assert data | ||||
|                 buffer += data | ||||
|                 print(res) | ||||
|             assert buffer[size : size + 2] == b"\r\n" | ||||
|             ret, buffer = buffer[:size], buffer[size + 2 :] | ||||
|             return ret | ||||
|   | ||||
| @@ -1,15 +1,20 @@ | ||||
| import asyncio | ||||
| import re | ||||
|  | ||||
| from unittest.mock import Mock | ||||
|  | ||||
| import pytest | ||||
|  | ||||
| from sanic_routing.exceptions import ParameterNameConflicts, RouteExists | ||||
| from sanic_routing.exceptions import ( | ||||
|     InvalidUsage, | ||||
|     ParameterNameConflicts, | ||||
|     RouteExists, | ||||
| ) | ||||
| from sanic_testing.testing import SanicTestClient | ||||
|  | ||||
| from sanic import Blueprint, Sanic | ||||
| from sanic.constants import HTTP_METHODS | ||||
| from sanic.exceptions import NotFound | ||||
| from sanic.exceptions import NotFound, SanicException | ||||
| from sanic.request import Request | ||||
| from sanic.response import json, text | ||||
|  | ||||
| @@ -189,7 +194,6 @@ def test_versioned_routes_get(app, method): | ||||
|             return text("OK") | ||||
|  | ||||
|     else: | ||||
|         print(func) | ||||
|         raise Exception(f"Method: {method} is not callable") | ||||
|  | ||||
|     client_method = getattr(app.test_client, method) | ||||
| @@ -1113,3 +1117,59 @@ def test_route_invalid_host(app): | ||||
|     assert str(excinfo.value) == ( | ||||
|         "Expected either string or Iterable of " "host strings, not {!r}" | ||||
|     ).format(host) | ||||
|  | ||||
|  | ||||
| def test_route_with_regex_group(app): | ||||
|     @app.route("/path/to/<ext:file\.(txt)>") | ||||
|     async def handler(request, ext): | ||||
|         return text(ext) | ||||
|  | ||||
|     _, response = app.test_client.get("/path/to/file.txt") | ||||
|     assert response.text == "txt" | ||||
|  | ||||
|  | ||||
| def test_route_with_regex_named_group(app): | ||||
|     @app.route(r"/path/to/<ext:file\.(?P<ext>txt)>") | ||||
|     async def handler(request, ext): | ||||
|         return text(ext) | ||||
|  | ||||
|     _, response = app.test_client.get("/path/to/file.txt") | ||||
|     assert response.text == "txt" | ||||
|  | ||||
|  | ||||
| def test_route_with_regex_named_group_invalid(app): | ||||
|     @app.route(r"/path/to/<ext:file\.(?P<wrong>txt)>") | ||||
|     async def handler(request, ext): | ||||
|         return text(ext) | ||||
|  | ||||
|     with pytest.raises(InvalidUsage) as e: | ||||
|         app.router.finalize() | ||||
|  | ||||
|     assert e.match( | ||||
|         re.escape("Named group (wrong) must match your named parameter (ext)") | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def test_route_with_regex_group_ambiguous(app): | ||||
|     @app.route("/path/to/<ext:file(?:\.)(txt)>") | ||||
|     async def handler(request, ext): | ||||
|         return text(ext) | ||||
|  | ||||
|     with pytest.raises(InvalidUsage) as e: | ||||
|         app.router.finalize() | ||||
|  | ||||
|     assert e.match( | ||||
|         re.escape( | ||||
|             "Could not compile pattern file(?:\.)(txt). Try using a named " | ||||
|             "group instead: '(?P<ext>your_matching_group)'" | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def test_route_with_bad_named_param(app): | ||||
|     @app.route("/foo/<__bar__>") | ||||
|     async def handler(request): | ||||
|         return text("...") | ||||
|  | ||||
|     with pytest.raises(SanicException): | ||||
|         app.router.finalize() | ||||
|   | ||||
| @@ -344,3 +344,23 @@ def test_methodview_naming(methodview_app): | ||||
|  | ||||
|     assert viewone_url == "/view_one" | ||||
|     assert viewtwo_url == "/view_two" | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize( | ||||
|     "path,version,expected", | ||||
|     ( | ||||
|         ("/foo", 1, "/v1/foo"), | ||||
|         ("/foo", 1.1, "/v1.1/foo"), | ||||
|         ("/foo", "1", "/v1/foo"), | ||||
|         ("/foo", "1.1", "/v1.1/foo"), | ||||
|         ("/foo", "1.0.1", "/v1.0.1/foo"), | ||||
|         ("/foo", "v1.0.1", "/v1.0.1/foo"), | ||||
|     ), | ||||
| ) | ||||
| def test_versioning(app, path, version, expected): | ||||
|     @app.route(path, version=version) | ||||
|     def handler(*_): | ||||
|         ... | ||||
|  | ||||
|     url = app.url_for("handler") | ||||
|     assert url == expected | ||||
|   | ||||
| @@ -88,3 +88,19 @@ def test_websocket_bp_route_name(app): | ||||
|  | ||||
| # TODO: add test with a route with multiple hosts | ||||
| # TODO: add test with a route with _host in url_for | ||||
| @pytest.mark.parametrize( | ||||
|     "path,strict,expected", | ||||
|     ( | ||||
|         ("/foo", False, "/foo"), | ||||
|         ("/foo/", False, "/foo"), | ||||
|         ("/foo", True, "/foo"), | ||||
|         ("/foo/", True, "/foo/"), | ||||
|     ), | ||||
| ) | ||||
| def test_trailing_slash_url_for(app, path, strict, expected): | ||||
|     @app.route(path, strict_slashes=strict) | ||||
|     def handler(*_): | ||||
|         ... | ||||
|  | ||||
|     url = app.url_for("handler") | ||||
|     assert url == expected | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Adam Hopkins
					Adam Hopkins