Router tweaks (#2031)

* Add trailing slash when defined and strict_slashes

* Add partial matching, and fix some issues with url_for

* Cover additional edge cases

* cleanup tests
This commit is contained in:
Adam Hopkins 2021-03-01 15:30:52 +02:00 committed by GitHub
parent 4b968dc611
commit 27f64ddae2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 175 additions and 63 deletions

View File

@ -393,17 +393,22 @@ class Sanic(BaseSanic):
if getattr(route.ctx, "static", None):
filename = kwargs.pop("filename", "")
# it's static folder
if "file_uri" in uri:
folder_ = uri.split("<file_uri:", 1)[0]
if "__file_uri__" in uri:
folder_ = uri.split("<__file_uri__:", 1)[0]
if folder_.endswith("/"):
folder_ = folder_[:-1]
if filename.startswith("/"):
filename = filename[1:]
kwargs["file_uri"] = filename
kwargs["__file_uri__"] = filename
if uri != "/" and uri.endswith("/"):
if (
uri != "/"
and uri.endswith("/")
and not route.strict
and not route.raw_path[:-1]
):
uri = uri[:-1]
if not uri.startswith("/"):
@ -573,25 +578,27 @@ class Sanic(BaseSanic):
# Define `response` var here to remove warnings about
# allocation before assignment below.
response = None
name = None
try:
# Fetch handler from router
(
route,
handler,
kwargs,
uri,
name,
ignore_body,
) = self.router.get(request)
request.name = name
request._match_info = kwargs
request.route = route
request.name = route.name
request.uri_template = f"/{route.path}"
request.endpoint = request.name
if (
request.stream
and request.stream.request_body
and not ignore_body
and not route.ctx.ignore_body
):
if self.router.is_stream_handler(request):
if hasattr(handler, "is_stream"):
# Streaming handler: lift the size limit
request.stream.request_max_size = float("inf")
else:
@ -602,15 +609,15 @@ class Sanic(BaseSanic):
# Request Middleware
# -------------------------------------------- #
response = await self._run_request_middleware(
request, request_name=name
request, request_name=route.name
)
# No middleware results
if not response:
# -------------------------------------------- #
# Execute Handler
# -------------------------------------------- #
request.uri_template = f"/{uri}"
if handler is None:
raise ServerError(
(
@ -619,12 +626,11 @@ class Sanic(BaseSanic):
)
)
request.endpoint = request.name
# Run response handler
response = handler(request, **kwargs)
if isawaitable(response):
response = await response
if response:
response = await request.respond(response)
else:

View File

@ -611,7 +611,7 @@ class RouteMixin:
else:
break
if not name: # noq
if not name: # noqa
raise ValueError("Could not generate a name for handler")
if not name.startswith(f"{self.name}."):
@ -627,19 +627,19 @@ class RouteMixin:
stream_large_files,
request,
content_type=None,
file_uri=None,
__file_uri__=None,
):
# Using this to determine if the URL is trying to break out of the path
# served. os.path.realpath seems to be very slow
if file_uri and "../" in file_uri:
if __file_uri__ and "../" in __file_uri__:
raise InvalidUsage("Invalid URL")
# Merge served directory and requested file if provided
# Strip all / that in the beginning of the URL to help prevent python
# from herping a derp and treating the uri as an absolute path
root_path = file_path = file_or_directory
if file_uri:
if __file_uri__:
file_path = path.join(
file_or_directory, sub("^[/]*", "", file_uri)
file_or_directory, sub("^[/]*", "", __file_uri__)
)
# URL decode the path sent by the browser otherwise we won't be able to
@ -648,10 +648,12 @@ class RouteMixin:
if not file_path.startswith(path.abspath(unquote(root_path))):
error_logger.exception(
f"File not found: path={file_or_directory}, "
f"relative_url={file_uri}"
f"relative_url={__file_uri__}"
)
raise FileNotFound(
"File not found", path=file_or_directory, relative_url=file_uri
"File not found",
path=file_or_directory,
relative_url=__file_uri__,
)
try:
headers = {}
@ -719,10 +721,12 @@ class RouteMixin:
except Exception:
error_logger.exception(
f"File not found: path={file_or_directory}, "
f"relative_url={file_uri}"
f"relative_url={__file_uri__}"
)
raise FileNotFound(
"File not found", path=file_or_directory, relative_url=file_uri
"File not found",
path=file_or_directory,
relative_url=__file_uri__,
)
def _register_static(
@ -772,7 +776,7 @@ class RouteMixin:
# If we're not trying to match a file directly,
# serve from the folder
if not path.isfile(file_or_directory):
uri += "/<file_uri>"
uri += "/<__file_uri__>"
# special prefix for static files
# if not static.name.startswith("_static_"):

View File

@ -12,6 +12,8 @@ from typing import (
Union,
)
from sanic_routing.route import Route # type: ignore
if TYPE_CHECKING:
from sanic.server import ConnInfo
@ -104,6 +106,7 @@ class Request:
"parsed_forwarded",
"raw_url",
"request_middleware_started",
"route",
"stream",
"transport",
"uri_template",
@ -151,6 +154,7 @@ class Request:
self._match_info: Dict[str, Any] = {}
self.stream: Optional[Http] = None
self.endpoint: Optional[str] = None
self.route: Optional[Route] = None
def __repr__(self):
class_name = self.__class__.__name__
@ -431,6 +435,7 @@ class Request:
:return: Incoming cookies on the request
:rtype: Dict[str, str]
"""
if self._cookies is None:
cookie = self.headers.get("Cookie")
if cookie is not None:

View File

@ -9,12 +9,13 @@ from sanic_routing.exceptions import (
from sanic_routing.route import Route # type: ignore
from sanic.constants import HTTP_METHODS
from sanic.exceptions import MethodNotSupported, NotFound
from sanic.exceptions import MethodNotSupported, NotFound, SanicException
from sanic.handlers import RouteHandler
from sanic.request import Request
ROUTER_CACHE_SIZE = 1024
ALLOWED_LABELS = ("__file_uri__",)
class Router(BaseRouter):
@ -33,7 +34,7 @@ class Router(BaseRouter):
@lru_cache(maxsize=ROUTER_CACHE_SIZE)
def _get(
self, path, method, host
) -> Tuple[RouteHandler, Dict[str, Any], str, str, bool]:
) -> Tuple[Route, RouteHandler, Dict[str, Any]]:
try:
route, handler, params = self.resolve(
path=path,
@ -50,14 +51,14 @@ class Router(BaseRouter):
)
return (
route,
handler,
params,
route.path,
route.name,
route.ctx.ignore_body,
)
def get(self, request: Request):
def get( # type: ignore
self, request: Request
) -> Tuple[Route, RouteHandler, Dict[str, Any]]:
"""
Retrieve a `Route` object containg the details about how to handle
a response for a given request
@ -66,14 +67,13 @@ class Router(BaseRouter):
:type request: Request
:return: details needed for handling the request and returning the
correct response
:rtype: Tuple[ RouteHandler, Tuple[Any, ...], Dict[str, Any], str, str,
Optional[str], bool, ]
:rtype: Tuple[ Route, RouteHandler, Dict[str, Any]]
"""
return self._get(
request.path, request.method, request.headers.get("host")
)
def add(
def add( # type: ignore
self,
uri: str,
methods: Iterable[str],
@ -138,7 +138,7 @@ class Router(BaseRouter):
if host:
params.update({"requirements": {"host": host}})
route = super().add(**params)
route = super().add(**params) # type: ignore
route.ctx.ignore_body = ignore_body
route.ctx.stream = stream
route.ctx.hosts = hosts
@ -150,23 +150,6 @@ class Router(BaseRouter):
return routes[0]
return routes
def is_stream_handler(self, request) -> bool:
"""
Handler for request is stream or not.
:param request: Request object
:return: bool
"""
try:
handler = self.get(request)[0]
except (NotFound, MethodNotSupported):
return False
if hasattr(handler, "view_class") and hasattr(
handler.view_class, request.method.lower()
):
handler = getattr(handler.view_class, request.method.lower())
return hasattr(handler, "is_stream")
@lru_cache(maxsize=ROUTER_CACHE_SIZE)
def find_route_by_view_name(self, view_name, name=None):
"""
@ -204,3 +187,15 @@ class Router(BaseRouter):
@property
def routes_regex(self):
return self.regex_routes
def finalize(self, *args, **kwargs):
super().finalize(*args, **kwargs)
for route in self.dynamic_routes.values():
if any(
label.startswith("__") and label not in ALLOWED_LABELS
for label in route.labels
):
raise SanicException(
f"Invalid route: {route}. Parameter names cannot use '__'."
)

View File

@ -39,7 +39,7 @@ class TestSanicRouteResolution:
iterations=1000,
rounds=1000,
)
assert await result[0](None) == 1
assert await result[1](None) == 1
@mark.asyncio
async def test_resolve_route_with_typed_args(
@ -72,4 +72,4 @@ class TestSanicRouteResolution:
iterations=1000,
rounds=1000,
)
assert await result[0](None) == 1
assert await result[1](None) == 1

View File

@ -4,7 +4,7 @@ import sys
from inspect import isawaitable
from os import environ
from unittest.mock import patch
from unittest.mock import Mock, patch
import pytest
@ -123,7 +123,7 @@ def test_app_route_raise_value_error(app):
def test_app_handle_request_handler_is_none(app, monkeypatch):
def mockreturn(*args, **kwargs):
return None, {}, "", "", False
return Mock(), None, {}
# Not sure how to make app.router.get() return None, so use mock here.
monkeypatch.setattr(app.router, "get", mockreturn)

View File

@ -752,7 +752,7 @@ def test_static_blueprint_name(static_file_directory, file_name):
app.blueprint(bp)
uri = app.url_for("static", name="static.testing")
assert uri == "/static/test.file"
assert uri == "/static/test.file/"
_, response = app.test_client.get("/static/test.file")
assert response.status == 404

View File

@ -126,7 +126,6 @@ def test_html_traceback_output_in_debug_mode():
assert response.status == 500
soup = BeautifulSoup(response.body, "html.parser")
html = str(soup)
print(html)
assert "response = handler(request, **kwargs)" in html
assert "handler_4" in html

View File

@ -59,7 +59,6 @@ def write_app(filename, **runargs):
def scanner(proc):
for line in proc.stdout:
line = line.decode().strip()
print(">", line)
if line.startswith("complete"):
yield line

View File

@ -74,3 +74,12 @@ def test_custom_generator():
"/", headers={"SOME-OTHER-REQUEST-ID": f"{REQUEST_ID}"}
)
assert request.id == REQUEST_ID * 2
def test_route_assigned_to_request(app):
@app.get("/")
async def get(request):
return response.empty()
request, _ = app.test_client.get("/")
assert request.route is list(app.router.routes.values())[0]

View File

@ -646,7 +646,6 @@ def test_streaming_echo():
data = await reader.read(4096)
assert data
buffer += data
print(res)
assert buffer[size : size + 2] == b"\r\n"
ret, buffer = buffer[:size], buffer[size + 2 :]
return ret

View File

@ -1,15 +1,20 @@
import asyncio
import re
from unittest.mock import Mock
import pytest
from sanic_routing.exceptions import ParameterNameConflicts, RouteExists
from sanic_routing.exceptions import (
InvalidUsage,
ParameterNameConflicts,
RouteExists,
)
from sanic_testing.testing import SanicTestClient
from sanic import Blueprint, Sanic
from sanic.constants import HTTP_METHODS
from sanic.exceptions import NotFound
from sanic.exceptions import NotFound, SanicException
from sanic.request import Request
from sanic.response import json, text
@ -189,7 +194,6 @@ def test_versioned_routes_get(app, method):
return text("OK")
else:
print(func)
raise Exception(f"Method: {method} is not callable")
client_method = getattr(app.test_client, method)
@ -1113,3 +1117,59 @@ def test_route_invalid_host(app):
assert str(excinfo.value) == (
"Expected either string or Iterable of " "host strings, not {!r}"
).format(host)
def test_route_with_regex_group(app):
@app.route("/path/to/<ext:file\.(txt)>")
async def handler(request, ext):
return text(ext)
_, response = app.test_client.get("/path/to/file.txt")
assert response.text == "txt"
def test_route_with_regex_named_group(app):
@app.route(r"/path/to/<ext:file\.(?P<ext>txt)>")
async def handler(request, ext):
return text(ext)
_, response = app.test_client.get("/path/to/file.txt")
assert response.text == "txt"
def test_route_with_regex_named_group_invalid(app):
@app.route(r"/path/to/<ext:file\.(?P<wrong>txt)>")
async def handler(request, ext):
return text(ext)
with pytest.raises(InvalidUsage) as e:
app.router.finalize()
assert e.match(
re.escape("Named group (wrong) must match your named parameter (ext)")
)
def test_route_with_regex_group_ambiguous(app):
@app.route("/path/to/<ext:file(?:\.)(txt)>")
async def handler(request, ext):
return text(ext)
with pytest.raises(InvalidUsage) as e:
app.router.finalize()
assert e.match(
re.escape(
"Could not compile pattern file(?:\.)(txt). Try using a named "
"group instead: '(?P<ext>your_matching_group)'"
)
)
def test_route_with_bad_named_param(app):
@app.route("/foo/<__bar__>")
async def handler(request):
return text("...")
with pytest.raises(SanicException):
app.router.finalize()

View File

@ -344,3 +344,23 @@ def test_methodview_naming(methodview_app):
assert viewone_url == "/view_one"
assert viewtwo_url == "/view_two"
@pytest.mark.parametrize(
"path,version,expected",
(
("/foo", 1, "/v1/foo"),
("/foo", 1.1, "/v1.1/foo"),
("/foo", "1", "/v1/foo"),
("/foo", "1.1", "/v1.1/foo"),
("/foo", "1.0.1", "/v1.0.1/foo"),
("/foo", "v1.0.1", "/v1.0.1/foo"),
),
)
def test_versioning(app, path, version, expected):
@app.route(path, version=version)
def handler(*_):
...
url = app.url_for("handler")
assert url == expected

View File

@ -88,3 +88,19 @@ def test_websocket_bp_route_name(app):
# TODO: add test with a route with multiple hosts
# TODO: add test with a route with _host in url_for
@pytest.mark.parametrize(
"path,strict,expected",
(
("/foo", False, "/foo"),
("/foo/", False, "/foo"),
("/foo", True, "/foo"),
("/foo/", True, "/foo/"),
),
)
def test_trailing_slash_url_for(app, path, strict, expected):
@app.route(path, strict_slashes=strict)
def handler(*_):
...
url = app.url_for("handler")
assert url == expected