Router tweaks (#2031)

* Add trailing slash when defined and strict_slashes

* Add partial matching, and fix some issues with url_for

* Cover additional edge cases

* cleanup tests
This commit is contained in:
Adam Hopkins 2021-03-01 15:30:52 +02:00 committed by GitHub
parent 4b968dc611
commit 27f64ddae2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 175 additions and 63 deletions

View File

@ -393,17 +393,22 @@ class Sanic(BaseSanic):
if getattr(route.ctx, "static", None): if getattr(route.ctx, "static", None):
filename = kwargs.pop("filename", "") filename = kwargs.pop("filename", "")
# it's static folder # it's static folder
if "file_uri" in uri: if "__file_uri__" in uri:
folder_ = uri.split("<file_uri:", 1)[0] folder_ = uri.split("<__file_uri__:", 1)[0]
if folder_.endswith("/"): if folder_.endswith("/"):
folder_ = folder_[:-1] folder_ = folder_[:-1]
if filename.startswith("/"): if filename.startswith("/"):
filename = filename[1:] filename = filename[1:]
kwargs["file_uri"] = filename kwargs["__file_uri__"] = filename
if uri != "/" and uri.endswith("/"): if (
uri != "/"
and uri.endswith("/")
and not route.strict
and not route.raw_path[:-1]
):
uri = uri[:-1] uri = uri[:-1]
if not uri.startswith("/"): if not uri.startswith("/"):
@ -573,25 +578,27 @@ class Sanic(BaseSanic):
# Define `response` var here to remove warnings about # Define `response` var here to remove warnings about
# allocation before assignment below. # allocation before assignment below.
response = None response = None
name = None
try: try:
# Fetch handler from router # Fetch handler from router
( (
route,
handler, handler,
kwargs, kwargs,
uri,
name,
ignore_body,
) = self.router.get(request) ) = self.router.get(request)
request.name = name
request._match_info = kwargs request._match_info = kwargs
request.route = route
request.name = route.name
request.uri_template = f"/{route.path}"
request.endpoint = request.name
if ( if (
request.stream request.stream
and request.stream.request_body and request.stream.request_body
and not ignore_body and not route.ctx.ignore_body
): ):
if self.router.is_stream_handler(request):
if hasattr(handler, "is_stream"):
# Streaming handler: lift the size limit # Streaming handler: lift the size limit
request.stream.request_max_size = float("inf") request.stream.request_max_size = float("inf")
else: else:
@ -602,15 +609,15 @@ class Sanic(BaseSanic):
# Request Middleware # Request Middleware
# -------------------------------------------- # # -------------------------------------------- #
response = await self._run_request_middleware( response = await self._run_request_middleware(
request, request_name=name request, request_name=route.name
) )
# No middleware results # No middleware results
if not response: if not response:
# -------------------------------------------- # # -------------------------------------------- #
# Execute Handler # Execute Handler
# -------------------------------------------- # # -------------------------------------------- #
request.uri_template = f"/{uri}"
if handler is None: if handler is None:
raise ServerError( raise ServerError(
( (
@ -619,12 +626,11 @@ class Sanic(BaseSanic):
) )
) )
request.endpoint = request.name
# Run response handler # Run response handler
response = handler(request, **kwargs) response = handler(request, **kwargs)
if isawaitable(response): if isawaitable(response):
response = await response response = await response
if response: if response:
response = await request.respond(response) response = await request.respond(response)
else: else:

View File

@ -611,7 +611,7 @@ class RouteMixin:
else: else:
break break
if not name: # noq if not name: # noqa
raise ValueError("Could not generate a name for handler") raise ValueError("Could not generate a name for handler")
if not name.startswith(f"{self.name}."): if not name.startswith(f"{self.name}."):
@ -627,19 +627,19 @@ class RouteMixin:
stream_large_files, stream_large_files,
request, request,
content_type=None, content_type=None,
file_uri=None, __file_uri__=None,
): ):
# Using this to determine if the URL is trying to break out of the path # Using this to determine if the URL is trying to break out of the path
# served. os.path.realpath seems to be very slow # served. os.path.realpath seems to be very slow
if file_uri and "../" in file_uri: if __file_uri__ and "../" in __file_uri__:
raise InvalidUsage("Invalid URL") raise InvalidUsage("Invalid URL")
# Merge served directory and requested file if provided # Merge served directory and requested file if provided
# Strip all / that in the beginning of the URL to help prevent python # Strip all / that in the beginning of the URL to help prevent python
# from herping a derp and treating the uri as an absolute path # from herping a derp and treating the uri as an absolute path
root_path = file_path = file_or_directory root_path = file_path = file_or_directory
if file_uri: if __file_uri__:
file_path = path.join( file_path = path.join(
file_or_directory, sub("^[/]*", "", file_uri) file_or_directory, sub("^[/]*", "", __file_uri__)
) )
# URL decode the path sent by the browser otherwise we won't be able to # URL decode the path sent by the browser otherwise we won't be able to
@ -648,10 +648,12 @@ class RouteMixin:
if not file_path.startswith(path.abspath(unquote(root_path))): if not file_path.startswith(path.abspath(unquote(root_path))):
error_logger.exception( error_logger.exception(
f"File not found: path={file_or_directory}, " f"File not found: path={file_or_directory}, "
f"relative_url={file_uri}" f"relative_url={__file_uri__}"
) )
raise FileNotFound( raise FileNotFound(
"File not found", path=file_or_directory, relative_url=file_uri "File not found",
path=file_or_directory,
relative_url=__file_uri__,
) )
try: try:
headers = {} headers = {}
@ -719,10 +721,12 @@ class RouteMixin:
except Exception: except Exception:
error_logger.exception( error_logger.exception(
f"File not found: path={file_or_directory}, " f"File not found: path={file_or_directory}, "
f"relative_url={file_uri}" f"relative_url={__file_uri__}"
) )
raise FileNotFound( raise FileNotFound(
"File not found", path=file_or_directory, relative_url=file_uri "File not found",
path=file_or_directory,
relative_url=__file_uri__,
) )
def _register_static( def _register_static(
@ -772,7 +776,7 @@ class RouteMixin:
# If we're not trying to match a file directly, # If we're not trying to match a file directly,
# serve from the folder # serve from the folder
if not path.isfile(file_or_directory): if not path.isfile(file_or_directory):
uri += "/<file_uri>" uri += "/<__file_uri__>"
# special prefix for static files # special prefix for static files
# if not static.name.startswith("_static_"): # if not static.name.startswith("_static_"):

View File

@ -12,6 +12,8 @@ from typing import (
Union, Union,
) )
from sanic_routing.route import Route # type: ignore
if TYPE_CHECKING: if TYPE_CHECKING:
from sanic.server import ConnInfo from sanic.server import ConnInfo
@ -104,6 +106,7 @@ class Request:
"parsed_forwarded", "parsed_forwarded",
"raw_url", "raw_url",
"request_middleware_started", "request_middleware_started",
"route",
"stream", "stream",
"transport", "transport",
"uri_template", "uri_template",
@ -151,6 +154,7 @@ class Request:
self._match_info: Dict[str, Any] = {} self._match_info: Dict[str, Any] = {}
self.stream: Optional[Http] = None self.stream: Optional[Http] = None
self.endpoint: Optional[str] = None self.endpoint: Optional[str] = None
self.route: Optional[Route] = None
def __repr__(self): def __repr__(self):
class_name = self.__class__.__name__ class_name = self.__class__.__name__
@ -431,6 +435,7 @@ class Request:
:return: Incoming cookies on the request :return: Incoming cookies on the request
:rtype: Dict[str, str] :rtype: Dict[str, str]
""" """
if self._cookies is None: if self._cookies is None:
cookie = self.headers.get("Cookie") cookie = self.headers.get("Cookie")
if cookie is not None: if cookie is not None:

View File

@ -9,12 +9,13 @@ from sanic_routing.exceptions import (
from sanic_routing.route import Route # type: ignore from sanic_routing.route import Route # type: ignore
from sanic.constants import HTTP_METHODS from sanic.constants import HTTP_METHODS
from sanic.exceptions import MethodNotSupported, NotFound from sanic.exceptions import MethodNotSupported, NotFound, SanicException
from sanic.handlers import RouteHandler from sanic.handlers import RouteHandler
from sanic.request import Request from sanic.request import Request
ROUTER_CACHE_SIZE = 1024 ROUTER_CACHE_SIZE = 1024
ALLOWED_LABELS = ("__file_uri__",)
class Router(BaseRouter): class Router(BaseRouter):
@ -33,7 +34,7 @@ class Router(BaseRouter):
@lru_cache(maxsize=ROUTER_CACHE_SIZE) @lru_cache(maxsize=ROUTER_CACHE_SIZE)
def _get( def _get(
self, path, method, host self, path, method, host
) -> Tuple[RouteHandler, Dict[str, Any], str, str, bool]: ) -> Tuple[Route, RouteHandler, Dict[str, Any]]:
try: try:
route, handler, params = self.resolve( route, handler, params = self.resolve(
path=path, path=path,
@ -50,14 +51,14 @@ class Router(BaseRouter):
) )
return ( return (
route,
handler, handler,
params, params,
route.path,
route.name,
route.ctx.ignore_body,
) )
def get(self, request: Request): def get( # type: ignore
self, request: Request
) -> Tuple[Route, RouteHandler, Dict[str, Any]]:
""" """
Retrieve a `Route` object containg the details about how to handle Retrieve a `Route` object containg the details about how to handle
a response for a given request a response for a given request
@ -66,14 +67,13 @@ class Router(BaseRouter):
:type request: Request :type request: Request
:return: details needed for handling the request and returning the :return: details needed for handling the request and returning the
correct response correct response
:rtype: Tuple[ RouteHandler, Tuple[Any, ...], Dict[str, Any], str, str, :rtype: Tuple[ Route, RouteHandler, Dict[str, Any]]
Optional[str], bool, ]
""" """
return self._get( return self._get(
request.path, request.method, request.headers.get("host") request.path, request.method, request.headers.get("host")
) )
def add( def add( # type: ignore
self, self,
uri: str, uri: str,
methods: Iterable[str], methods: Iterable[str],
@ -138,7 +138,7 @@ class Router(BaseRouter):
if host: if host:
params.update({"requirements": {"host": host}}) params.update({"requirements": {"host": host}})
route = super().add(**params) route = super().add(**params) # type: ignore
route.ctx.ignore_body = ignore_body route.ctx.ignore_body = ignore_body
route.ctx.stream = stream route.ctx.stream = stream
route.ctx.hosts = hosts route.ctx.hosts = hosts
@ -150,23 +150,6 @@ class Router(BaseRouter):
return routes[0] return routes[0]
return routes return routes
def is_stream_handler(self, request) -> bool:
"""
Handler for request is stream or not.
:param request: Request object
:return: bool
"""
try:
handler = self.get(request)[0]
except (NotFound, MethodNotSupported):
return False
if hasattr(handler, "view_class") and hasattr(
handler.view_class, request.method.lower()
):
handler = getattr(handler.view_class, request.method.lower())
return hasattr(handler, "is_stream")
@lru_cache(maxsize=ROUTER_CACHE_SIZE) @lru_cache(maxsize=ROUTER_CACHE_SIZE)
def find_route_by_view_name(self, view_name, name=None): def find_route_by_view_name(self, view_name, name=None):
""" """
@ -204,3 +187,15 @@ class Router(BaseRouter):
@property @property
def routes_regex(self): def routes_regex(self):
return self.regex_routes return self.regex_routes
def finalize(self, *args, **kwargs):
super().finalize(*args, **kwargs)
for route in self.dynamic_routes.values():
if any(
label.startswith("__") and label not in ALLOWED_LABELS
for label in route.labels
):
raise SanicException(
f"Invalid route: {route}. Parameter names cannot use '__'."
)

View File

@ -39,7 +39,7 @@ class TestSanicRouteResolution:
iterations=1000, iterations=1000,
rounds=1000, rounds=1000,
) )
assert await result[0](None) == 1 assert await result[1](None) == 1
@mark.asyncio @mark.asyncio
async def test_resolve_route_with_typed_args( async def test_resolve_route_with_typed_args(
@ -72,4 +72,4 @@ class TestSanicRouteResolution:
iterations=1000, iterations=1000,
rounds=1000, rounds=1000,
) )
assert await result[0](None) == 1 assert await result[1](None) == 1

View File

@ -4,7 +4,7 @@ import sys
from inspect import isawaitable from inspect import isawaitable
from os import environ from os import environ
from unittest.mock import patch from unittest.mock import Mock, patch
import pytest import pytest
@ -123,7 +123,7 @@ def test_app_route_raise_value_error(app):
def test_app_handle_request_handler_is_none(app, monkeypatch): def test_app_handle_request_handler_is_none(app, monkeypatch):
def mockreturn(*args, **kwargs): def mockreturn(*args, **kwargs):
return None, {}, "", "", False return Mock(), None, {}
# Not sure how to make app.router.get() return None, so use mock here. # Not sure how to make app.router.get() return None, so use mock here.
monkeypatch.setattr(app.router, "get", mockreturn) monkeypatch.setattr(app.router, "get", mockreturn)

View File

@ -752,7 +752,7 @@ def test_static_blueprint_name(static_file_directory, file_name):
app.blueprint(bp) app.blueprint(bp)
uri = app.url_for("static", name="static.testing") uri = app.url_for("static", name="static.testing")
assert uri == "/static/test.file" assert uri == "/static/test.file/"
_, response = app.test_client.get("/static/test.file") _, response = app.test_client.get("/static/test.file")
assert response.status == 404 assert response.status == 404

View File

@ -126,7 +126,6 @@ def test_html_traceback_output_in_debug_mode():
assert response.status == 500 assert response.status == 500
soup = BeautifulSoup(response.body, "html.parser") soup = BeautifulSoup(response.body, "html.parser")
html = str(soup) html = str(soup)
print(html)
assert "response = handler(request, **kwargs)" in html assert "response = handler(request, **kwargs)" in html
assert "handler_4" in html assert "handler_4" in html

View File

@ -59,7 +59,6 @@ def write_app(filename, **runargs):
def scanner(proc): def scanner(proc):
for line in proc.stdout: for line in proc.stdout:
line = line.decode().strip() line = line.decode().strip()
print(">", line)
if line.startswith("complete"): if line.startswith("complete"):
yield line yield line

View File

@ -74,3 +74,12 @@ def test_custom_generator():
"/", headers={"SOME-OTHER-REQUEST-ID": f"{REQUEST_ID}"} "/", headers={"SOME-OTHER-REQUEST-ID": f"{REQUEST_ID}"}
) )
assert request.id == REQUEST_ID * 2 assert request.id == REQUEST_ID * 2
def test_route_assigned_to_request(app):
@app.get("/")
async def get(request):
return response.empty()
request, _ = app.test_client.get("/")
assert request.route is list(app.router.routes.values())[0]

View File

@ -646,7 +646,6 @@ def test_streaming_echo():
data = await reader.read(4096) data = await reader.read(4096)
assert data assert data
buffer += data buffer += data
print(res)
assert buffer[size : size + 2] == b"\r\n" assert buffer[size : size + 2] == b"\r\n"
ret, buffer = buffer[:size], buffer[size + 2 :] ret, buffer = buffer[:size], buffer[size + 2 :]
return ret return ret

View File

@ -1,15 +1,20 @@
import asyncio import asyncio
import re
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from sanic_routing.exceptions import ParameterNameConflicts, RouteExists from sanic_routing.exceptions import (
InvalidUsage,
ParameterNameConflicts,
RouteExists,
)
from sanic_testing.testing import SanicTestClient from sanic_testing.testing import SanicTestClient
from sanic import Blueprint, Sanic from sanic import Blueprint, Sanic
from sanic.constants import HTTP_METHODS from sanic.constants import HTTP_METHODS
from sanic.exceptions import NotFound from sanic.exceptions import NotFound, SanicException
from sanic.request import Request from sanic.request import Request
from sanic.response import json, text from sanic.response import json, text
@ -189,7 +194,6 @@ def test_versioned_routes_get(app, method):
return text("OK") return text("OK")
else: else:
print(func)
raise Exception(f"Method: {method} is not callable") raise Exception(f"Method: {method} is not callable")
client_method = getattr(app.test_client, method) client_method = getattr(app.test_client, method)
@ -1113,3 +1117,59 @@ def test_route_invalid_host(app):
assert str(excinfo.value) == ( assert str(excinfo.value) == (
"Expected either string or Iterable of " "host strings, not {!r}" "Expected either string or Iterable of " "host strings, not {!r}"
).format(host) ).format(host)
def test_route_with_regex_group(app):
@app.route("/path/to/<ext:file\.(txt)>")
async def handler(request, ext):
return text(ext)
_, response = app.test_client.get("/path/to/file.txt")
assert response.text == "txt"
def test_route_with_regex_named_group(app):
@app.route(r"/path/to/<ext:file\.(?P<ext>txt)>")
async def handler(request, ext):
return text(ext)
_, response = app.test_client.get("/path/to/file.txt")
assert response.text == "txt"
def test_route_with_regex_named_group_invalid(app):
@app.route(r"/path/to/<ext:file\.(?P<wrong>txt)>")
async def handler(request, ext):
return text(ext)
with pytest.raises(InvalidUsage) as e:
app.router.finalize()
assert e.match(
re.escape("Named group (wrong) must match your named parameter (ext)")
)
def test_route_with_regex_group_ambiguous(app):
@app.route("/path/to/<ext:file(?:\.)(txt)>")
async def handler(request, ext):
return text(ext)
with pytest.raises(InvalidUsage) as e:
app.router.finalize()
assert e.match(
re.escape(
"Could not compile pattern file(?:\.)(txt). Try using a named "
"group instead: '(?P<ext>your_matching_group)'"
)
)
def test_route_with_bad_named_param(app):
@app.route("/foo/<__bar__>")
async def handler(request):
return text("...")
with pytest.raises(SanicException):
app.router.finalize()

View File

@ -344,3 +344,23 @@ def test_methodview_naming(methodview_app):
assert viewone_url == "/view_one" assert viewone_url == "/view_one"
assert viewtwo_url == "/view_two" assert viewtwo_url == "/view_two"
@pytest.mark.parametrize(
"path,version,expected",
(
("/foo", 1, "/v1/foo"),
("/foo", 1.1, "/v1.1/foo"),
("/foo", "1", "/v1/foo"),
("/foo", "1.1", "/v1.1/foo"),
("/foo", "1.0.1", "/v1.0.1/foo"),
("/foo", "v1.0.1", "/v1.0.1/foo"),
),
)
def test_versioning(app, path, version, expected):
@app.route(path, version=version)
def handler(*_):
...
url = app.url_for("handler")
assert url == expected

View File

@ -88,3 +88,19 @@ def test_websocket_bp_route_name(app):
# TODO: add test with a route with multiple hosts # TODO: add test with a route with multiple hosts
# TODO: add test with a route with _host in url_for # TODO: add test with a route with _host in url_for
@pytest.mark.parametrize(
"path,strict,expected",
(
("/foo", False, "/foo"),
("/foo/", False, "/foo"),
("/foo", True, "/foo"),
("/foo/", True, "/foo/"),
),
)
def test_trailing_slash_url_for(app, path, strict, expected):
@app.route(path, strict_slashes=strict)
def handler(*_):
...
url = app.url_for("handler")
assert url == expected