Router tweaks (#2031)
* Add trailing slash when defined and strict_slashes * Add partial matching, and fix some issues with url_for * Cover additional edge cases * cleanup tests
This commit is contained in:
parent
4b968dc611
commit
27f64ddae2
36
sanic/app.py
36
sanic/app.py
|
@ -393,17 +393,22 @@ class Sanic(BaseSanic):
|
|||
if getattr(route.ctx, "static", None):
|
||||
filename = kwargs.pop("filename", "")
|
||||
# it's static folder
|
||||
if "file_uri" in uri:
|
||||
folder_ = uri.split("<file_uri:", 1)[0]
|
||||
if "__file_uri__" in uri:
|
||||
folder_ = uri.split("<__file_uri__:", 1)[0]
|
||||
if folder_.endswith("/"):
|
||||
folder_ = folder_[:-1]
|
||||
|
||||
if filename.startswith("/"):
|
||||
filename = filename[1:]
|
||||
|
||||
kwargs["file_uri"] = filename
|
||||
kwargs["__file_uri__"] = filename
|
||||
|
||||
if uri != "/" and uri.endswith("/"):
|
||||
if (
|
||||
uri != "/"
|
||||
and uri.endswith("/")
|
||||
and not route.strict
|
||||
and not route.raw_path[:-1]
|
||||
):
|
||||
uri = uri[:-1]
|
||||
|
||||
if not uri.startswith("/"):
|
||||
|
@ -573,25 +578,27 @@ class Sanic(BaseSanic):
|
|||
# Define `response` var here to remove warnings about
|
||||
# allocation before assignment below.
|
||||
response = None
|
||||
name = None
|
||||
try:
|
||||
# Fetch handler from router
|
||||
(
|
||||
route,
|
||||
handler,
|
||||
kwargs,
|
||||
uri,
|
||||
name,
|
||||
ignore_body,
|
||||
) = self.router.get(request)
|
||||
request.name = name
|
||||
|
||||
request._match_info = kwargs
|
||||
request.route = route
|
||||
request.name = route.name
|
||||
request.uri_template = f"/{route.path}"
|
||||
request.endpoint = request.name
|
||||
|
||||
if (
|
||||
request.stream
|
||||
and request.stream.request_body
|
||||
and not ignore_body
|
||||
and not route.ctx.ignore_body
|
||||
):
|
||||
if self.router.is_stream_handler(request):
|
||||
|
||||
if hasattr(handler, "is_stream"):
|
||||
# Streaming handler: lift the size limit
|
||||
request.stream.request_max_size = float("inf")
|
||||
else:
|
||||
|
@ -602,15 +609,15 @@ class Sanic(BaseSanic):
|
|||
# Request Middleware
|
||||
# -------------------------------------------- #
|
||||
response = await self._run_request_middleware(
|
||||
request, request_name=name
|
||||
request, request_name=route.name
|
||||
)
|
||||
|
||||
# No middleware results
|
||||
if not response:
|
||||
# -------------------------------------------- #
|
||||
# Execute Handler
|
||||
# -------------------------------------------- #
|
||||
|
||||
request.uri_template = f"/{uri}"
|
||||
if handler is None:
|
||||
raise ServerError(
|
||||
(
|
||||
|
@ -619,12 +626,11 @@ class Sanic(BaseSanic):
|
|||
)
|
||||
)
|
||||
|
||||
request.endpoint = request.name
|
||||
|
||||
# Run response handler
|
||||
response = handler(request, **kwargs)
|
||||
if isawaitable(response):
|
||||
response = await response
|
||||
|
||||
if response:
|
||||
response = await request.respond(response)
|
||||
else:
|
||||
|
|
|
@ -611,7 +611,7 @@ class RouteMixin:
|
|||
else:
|
||||
break
|
||||
|
||||
if not name: # noq
|
||||
if not name: # noqa
|
||||
raise ValueError("Could not generate a name for handler")
|
||||
|
||||
if not name.startswith(f"{self.name}."):
|
||||
|
@ -627,19 +627,19 @@ class RouteMixin:
|
|||
stream_large_files,
|
||||
request,
|
||||
content_type=None,
|
||||
file_uri=None,
|
||||
__file_uri__=None,
|
||||
):
|
||||
# Using this to determine if the URL is trying to break out of the path
|
||||
# served. os.path.realpath seems to be very slow
|
||||
if file_uri and "../" in file_uri:
|
||||
if __file_uri__ and "../" in __file_uri__:
|
||||
raise InvalidUsage("Invalid URL")
|
||||
# Merge served directory and requested file if provided
|
||||
# Strip all / that in the beginning of the URL to help prevent python
|
||||
# from herping a derp and treating the uri as an absolute path
|
||||
root_path = file_path = file_or_directory
|
||||
if file_uri:
|
||||
if __file_uri__:
|
||||
file_path = path.join(
|
||||
file_or_directory, sub("^[/]*", "", file_uri)
|
||||
file_or_directory, sub("^[/]*", "", __file_uri__)
|
||||
)
|
||||
|
||||
# URL decode the path sent by the browser otherwise we won't be able to
|
||||
|
@ -648,10 +648,12 @@ class RouteMixin:
|
|||
if not file_path.startswith(path.abspath(unquote(root_path))):
|
||||
error_logger.exception(
|
||||
f"File not found: path={file_or_directory}, "
|
||||
f"relative_url={file_uri}"
|
||||
f"relative_url={__file_uri__}"
|
||||
)
|
||||
raise FileNotFound(
|
||||
"File not found", path=file_or_directory, relative_url=file_uri
|
||||
"File not found",
|
||||
path=file_or_directory,
|
||||
relative_url=__file_uri__,
|
||||
)
|
||||
try:
|
||||
headers = {}
|
||||
|
@ -719,10 +721,12 @@ class RouteMixin:
|
|||
except Exception:
|
||||
error_logger.exception(
|
||||
f"File not found: path={file_or_directory}, "
|
||||
f"relative_url={file_uri}"
|
||||
f"relative_url={__file_uri__}"
|
||||
)
|
||||
raise FileNotFound(
|
||||
"File not found", path=file_or_directory, relative_url=file_uri
|
||||
"File not found",
|
||||
path=file_or_directory,
|
||||
relative_url=__file_uri__,
|
||||
)
|
||||
|
||||
def _register_static(
|
||||
|
@ -772,7 +776,7 @@ class RouteMixin:
|
|||
# If we're not trying to match a file directly,
|
||||
# serve from the folder
|
||||
if not path.isfile(file_or_directory):
|
||||
uri += "/<file_uri>"
|
||||
uri += "/<__file_uri__>"
|
||||
|
||||
# special prefix for static files
|
||||
# if not static.name.startswith("_static_"):
|
||||
|
|
|
@ -12,6 +12,8 @@ from typing import (
|
|||
Union,
|
||||
)
|
||||
|
||||
from sanic_routing.route import Route # type: ignore
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sanic.server import ConnInfo
|
||||
|
@ -104,6 +106,7 @@ class Request:
|
|||
"parsed_forwarded",
|
||||
"raw_url",
|
||||
"request_middleware_started",
|
||||
"route",
|
||||
"stream",
|
||||
"transport",
|
||||
"uri_template",
|
||||
|
@ -151,6 +154,7 @@ class Request:
|
|||
self._match_info: Dict[str, Any] = {}
|
||||
self.stream: Optional[Http] = None
|
||||
self.endpoint: Optional[str] = None
|
||||
self.route: Optional[Route] = None
|
||||
|
||||
def __repr__(self):
|
||||
class_name = self.__class__.__name__
|
||||
|
@ -431,6 +435,7 @@ class Request:
|
|||
:return: Incoming cookies on the request
|
||||
:rtype: Dict[str, str]
|
||||
"""
|
||||
|
||||
if self._cookies is None:
|
||||
cookie = self.headers.get("Cookie")
|
||||
if cookie is not None:
|
||||
|
|
|
@ -9,12 +9,13 @@ from sanic_routing.exceptions import (
|
|||
from sanic_routing.route import Route # type: ignore
|
||||
|
||||
from sanic.constants import HTTP_METHODS
|
||||
from sanic.exceptions import MethodNotSupported, NotFound
|
||||
from sanic.exceptions import MethodNotSupported, NotFound, SanicException
|
||||
from sanic.handlers import RouteHandler
|
||||
from sanic.request import Request
|
||||
|
||||
|
||||
ROUTER_CACHE_SIZE = 1024
|
||||
ALLOWED_LABELS = ("__file_uri__",)
|
||||
|
||||
|
||||
class Router(BaseRouter):
|
||||
|
@ -33,7 +34,7 @@ class Router(BaseRouter):
|
|||
@lru_cache(maxsize=ROUTER_CACHE_SIZE)
|
||||
def _get(
|
||||
self, path, method, host
|
||||
) -> Tuple[RouteHandler, Dict[str, Any], str, str, bool]:
|
||||
) -> Tuple[Route, RouteHandler, Dict[str, Any]]:
|
||||
try:
|
||||
route, handler, params = self.resolve(
|
||||
path=path,
|
||||
|
@ -50,14 +51,14 @@ class Router(BaseRouter):
|
|||
)
|
||||
|
||||
return (
|
||||
route,
|
||||
handler,
|
||||
params,
|
||||
route.path,
|
||||
route.name,
|
||||
route.ctx.ignore_body,
|
||||
)
|
||||
|
||||
def get(self, request: Request):
|
||||
def get( # type: ignore
|
||||
self, request: Request
|
||||
) -> Tuple[Route, RouteHandler, Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve a `Route` object containg the details about how to handle
|
||||
a response for a given request
|
||||
|
@ -66,14 +67,13 @@ class Router(BaseRouter):
|
|||
:type request: Request
|
||||
:return: details needed for handling the request and returning the
|
||||
correct response
|
||||
:rtype: Tuple[ RouteHandler, Tuple[Any, ...], Dict[str, Any], str, str,
|
||||
Optional[str], bool, ]
|
||||
:rtype: Tuple[ Route, RouteHandler, Dict[str, Any]]
|
||||
"""
|
||||
return self._get(
|
||||
request.path, request.method, request.headers.get("host")
|
||||
)
|
||||
|
||||
def add(
|
||||
def add( # type: ignore
|
||||
self,
|
||||
uri: str,
|
||||
methods: Iterable[str],
|
||||
|
@ -138,7 +138,7 @@ class Router(BaseRouter):
|
|||
if host:
|
||||
params.update({"requirements": {"host": host}})
|
||||
|
||||
route = super().add(**params)
|
||||
route = super().add(**params) # type: ignore
|
||||
route.ctx.ignore_body = ignore_body
|
||||
route.ctx.stream = stream
|
||||
route.ctx.hosts = hosts
|
||||
|
@ -150,23 +150,6 @@ class Router(BaseRouter):
|
|||
return routes[0]
|
||||
return routes
|
||||
|
||||
def is_stream_handler(self, request) -> bool:
|
||||
"""
|
||||
Handler for request is stream or not.
|
||||
|
||||
:param request: Request object
|
||||
:return: bool
|
||||
"""
|
||||
try:
|
||||
handler = self.get(request)[0]
|
||||
except (NotFound, MethodNotSupported):
|
||||
return False
|
||||
if hasattr(handler, "view_class") and hasattr(
|
||||
handler.view_class, request.method.lower()
|
||||
):
|
||||
handler = getattr(handler.view_class, request.method.lower())
|
||||
return hasattr(handler, "is_stream")
|
||||
|
||||
@lru_cache(maxsize=ROUTER_CACHE_SIZE)
|
||||
def find_route_by_view_name(self, view_name, name=None):
|
||||
"""
|
||||
|
@ -204,3 +187,15 @@ class Router(BaseRouter):
|
|||
@property
|
||||
def routes_regex(self):
|
||||
return self.regex_routes
|
||||
|
||||
def finalize(self, *args, **kwargs):
|
||||
super().finalize(*args, **kwargs)
|
||||
|
||||
for route in self.dynamic_routes.values():
|
||||
if any(
|
||||
label.startswith("__") and label not in ALLOWED_LABELS
|
||||
for label in route.labels
|
||||
):
|
||||
raise SanicException(
|
||||
f"Invalid route: {route}. Parameter names cannot use '__'."
|
||||
)
|
||||
|
|
|
@ -39,7 +39,7 @@ class TestSanicRouteResolution:
|
|||
iterations=1000,
|
||||
rounds=1000,
|
||||
)
|
||||
assert await result[0](None) == 1
|
||||
assert await result[1](None) == 1
|
||||
|
||||
@mark.asyncio
|
||||
async def test_resolve_route_with_typed_args(
|
||||
|
@ -72,4 +72,4 @@ class TestSanicRouteResolution:
|
|||
iterations=1000,
|
||||
rounds=1000,
|
||||
)
|
||||
assert await result[0](None) == 1
|
||||
assert await result[1](None) == 1
|
||||
|
|
|
@ -4,7 +4,7 @@ import sys
|
|||
|
||||
from inspect import isawaitable
|
||||
from os import environ
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -123,7 +123,7 @@ def test_app_route_raise_value_error(app):
|
|||
|
||||
def test_app_handle_request_handler_is_none(app, monkeypatch):
|
||||
def mockreturn(*args, **kwargs):
|
||||
return None, {}, "", "", False
|
||||
return Mock(), None, {}
|
||||
|
||||
# Not sure how to make app.router.get() return None, so use mock here.
|
||||
monkeypatch.setattr(app.router, "get", mockreturn)
|
||||
|
|
|
@ -752,7 +752,7 @@ def test_static_blueprint_name(static_file_directory, file_name):
|
|||
app.blueprint(bp)
|
||||
|
||||
uri = app.url_for("static", name="static.testing")
|
||||
assert uri == "/static/test.file"
|
||||
assert uri == "/static/test.file/"
|
||||
|
||||
_, response = app.test_client.get("/static/test.file")
|
||||
assert response.status == 404
|
||||
|
|
|
@ -126,7 +126,6 @@ def test_html_traceback_output_in_debug_mode():
|
|||
assert response.status == 500
|
||||
soup = BeautifulSoup(response.body, "html.parser")
|
||||
html = str(soup)
|
||||
print(html)
|
||||
|
||||
assert "response = handler(request, **kwargs)" in html
|
||||
assert "handler_4" in html
|
||||
|
|
|
@ -59,7 +59,6 @@ def write_app(filename, **runargs):
|
|||
def scanner(proc):
|
||||
for line in proc.stdout:
|
||||
line = line.decode().strip()
|
||||
print(">", line)
|
||||
if line.startswith("complete"):
|
||||
yield line
|
||||
|
||||
|
|
|
@ -74,3 +74,12 @@ def test_custom_generator():
|
|||
"/", headers={"SOME-OTHER-REQUEST-ID": f"{REQUEST_ID}"}
|
||||
)
|
||||
assert request.id == REQUEST_ID * 2
|
||||
|
||||
|
||||
def test_route_assigned_to_request(app):
|
||||
@app.get("/")
|
||||
async def get(request):
|
||||
return response.empty()
|
||||
|
||||
request, _ = app.test_client.get("/")
|
||||
assert request.route is list(app.router.routes.values())[0]
|
||||
|
|
|
@ -646,7 +646,6 @@ def test_streaming_echo():
|
|||
data = await reader.read(4096)
|
||||
assert data
|
||||
buffer += data
|
||||
print(res)
|
||||
assert buffer[size : size + 2] == b"\r\n"
|
||||
ret, buffer = buffer[:size], buffer[size + 2 :]
|
||||
return ret
|
||||
|
|
|
@ -1,15 +1,20 @@
|
|||
import asyncio
|
||||
import re
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from sanic_routing.exceptions import ParameterNameConflicts, RouteExists
|
||||
from sanic_routing.exceptions import (
|
||||
InvalidUsage,
|
||||
ParameterNameConflicts,
|
||||
RouteExists,
|
||||
)
|
||||
from sanic_testing.testing import SanicTestClient
|
||||
|
||||
from sanic import Blueprint, Sanic
|
||||
from sanic.constants import HTTP_METHODS
|
||||
from sanic.exceptions import NotFound
|
||||
from sanic.exceptions import NotFound, SanicException
|
||||
from sanic.request import Request
|
||||
from sanic.response import json, text
|
||||
|
||||
|
@ -189,7 +194,6 @@ def test_versioned_routes_get(app, method):
|
|||
return text("OK")
|
||||
|
||||
else:
|
||||
print(func)
|
||||
raise Exception(f"Method: {method} is not callable")
|
||||
|
||||
client_method = getattr(app.test_client, method)
|
||||
|
@ -1113,3 +1117,59 @@ def test_route_invalid_host(app):
|
|||
assert str(excinfo.value) == (
|
||||
"Expected either string or Iterable of " "host strings, not {!r}"
|
||||
).format(host)
|
||||
|
||||
|
||||
def test_route_with_regex_group(app):
|
||||
@app.route("/path/to/<ext:file\.(txt)>")
|
||||
async def handler(request, ext):
|
||||
return text(ext)
|
||||
|
||||
_, response = app.test_client.get("/path/to/file.txt")
|
||||
assert response.text == "txt"
|
||||
|
||||
|
||||
def test_route_with_regex_named_group(app):
|
||||
@app.route(r"/path/to/<ext:file\.(?P<ext>txt)>")
|
||||
async def handler(request, ext):
|
||||
return text(ext)
|
||||
|
||||
_, response = app.test_client.get("/path/to/file.txt")
|
||||
assert response.text == "txt"
|
||||
|
||||
|
||||
def test_route_with_regex_named_group_invalid(app):
|
||||
@app.route(r"/path/to/<ext:file\.(?P<wrong>txt)>")
|
||||
async def handler(request, ext):
|
||||
return text(ext)
|
||||
|
||||
with pytest.raises(InvalidUsage) as e:
|
||||
app.router.finalize()
|
||||
|
||||
assert e.match(
|
||||
re.escape("Named group (wrong) must match your named parameter (ext)")
|
||||
)
|
||||
|
||||
|
||||
def test_route_with_regex_group_ambiguous(app):
|
||||
@app.route("/path/to/<ext:file(?:\.)(txt)>")
|
||||
async def handler(request, ext):
|
||||
return text(ext)
|
||||
|
||||
with pytest.raises(InvalidUsage) as e:
|
||||
app.router.finalize()
|
||||
|
||||
assert e.match(
|
||||
re.escape(
|
||||
"Could not compile pattern file(?:\.)(txt). Try using a named "
|
||||
"group instead: '(?P<ext>your_matching_group)'"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_route_with_bad_named_param(app):
|
||||
@app.route("/foo/<__bar__>")
|
||||
async def handler(request):
|
||||
return text("...")
|
||||
|
||||
with pytest.raises(SanicException):
|
||||
app.router.finalize()
|
||||
|
|
|
@ -344,3 +344,23 @@ def test_methodview_naming(methodview_app):
|
|||
|
||||
assert viewone_url == "/view_one"
|
||||
assert viewtwo_url == "/view_two"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path,version,expected",
|
||||
(
|
||||
("/foo", 1, "/v1/foo"),
|
||||
("/foo", 1.1, "/v1.1/foo"),
|
||||
("/foo", "1", "/v1/foo"),
|
||||
("/foo", "1.1", "/v1.1/foo"),
|
||||
("/foo", "1.0.1", "/v1.0.1/foo"),
|
||||
("/foo", "v1.0.1", "/v1.0.1/foo"),
|
||||
),
|
||||
)
|
||||
def test_versioning(app, path, version, expected):
|
||||
@app.route(path, version=version)
|
||||
def handler(*_):
|
||||
...
|
||||
|
||||
url = app.url_for("handler")
|
||||
assert url == expected
|
||||
|
|
|
@ -88,3 +88,19 @@ def test_websocket_bp_route_name(app):
|
|||
|
||||
# TODO: add test with a route with multiple hosts
|
||||
# TODO: add test with a route with _host in url_for
|
||||
@pytest.mark.parametrize(
|
||||
"path,strict,expected",
|
||||
(
|
||||
("/foo", False, "/foo"),
|
||||
("/foo/", False, "/foo"),
|
||||
("/foo", True, "/foo"),
|
||||
("/foo/", True, "/foo/"),
|
||||
),
|
||||
)
|
||||
def test_trailing_slash_url_for(app, path, strict, expected):
|
||||
@app.route(path, strict_slashes=strict)
|
||||
def handler(*_):
|
||||
...
|
||||
|
||||
url = app.url_for("handler")
|
||||
assert url == expected
|
||||
|
|
Loading…
Reference in New Issue
Block a user