Smarter auto fallback (#2162)
* Smarter auto fallback * remove config from blueprints * Add tests for error formatting * Add check for proper format * Fix some tests * Add some tests * docstring * Add accept matching * Add some more tests on matching * Fix contains bug, earlier return on MediaType eq * Add matching flags for wildcards * Add mathing controls to accept * Cleanup dev cruft * Add cleanup and resolve OSError relating to test implementation * Fix test * Fix some typos
This commit is contained in:
parent
b5f2bd9b0e
commit
cf1d2148ac
|
@ -179,7 +179,9 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
|
|||
self.configure_logging = configure_logging
|
||||
self.ctx = ctx or SimpleNamespace()
|
||||
self.debug = None
|
||||
self.error_handler = error_handler or ErrorHandler()
|
||||
self.error_handler = error_handler or ErrorHandler(
|
||||
fallback=self.config.FALLBACK_ERROR_FORMAT,
|
||||
)
|
||||
self.is_running = False
|
||||
self.is_stopping = False
|
||||
self.listeners: Dict[str, List[ListenerType]] = defaultdict(list)
|
||||
|
|
|
@ -266,6 +266,9 @@ class Blueprint(BaseSanic):
|
|||
opt_version = options.get("version", None)
|
||||
opt_strict_slashes = options.get("strict_slashes", None)
|
||||
opt_version_prefix = options.get("version_prefix", self.version_prefix)
|
||||
error_format = options.get(
|
||||
"error_format", app.config.FALLBACK_ERROR_FORMAT
|
||||
)
|
||||
|
||||
routes = []
|
||||
middleware = []
|
||||
|
@ -313,6 +316,7 @@ class Blueprint(BaseSanic):
|
|||
future.unquote,
|
||||
future.static,
|
||||
version_prefix,
|
||||
error_format,
|
||||
)
|
||||
|
||||
route = app._apply_route(apply_route)
|
||||
|
|
|
@ -4,6 +4,7 @@ from pathlib import Path
|
|||
from typing import Any, Dict, Optional, Union
|
||||
from warnings import warn
|
||||
|
||||
from sanic.errorpages import check_error_format
|
||||
from sanic.http import Http
|
||||
|
||||
from .utils import load_module_from_file_location, str_to_bool
|
||||
|
@ -20,7 +21,7 @@ BASE_LOGO = """
|
|||
DEFAULT_CONFIG = {
|
||||
"ACCESS_LOG": True,
|
||||
"EVENT_AUTOREGISTER": False,
|
||||
"FALLBACK_ERROR_FORMAT": "html",
|
||||
"FALLBACK_ERROR_FORMAT": "auto",
|
||||
"FORWARDED_FOR_HEADER": "X-Forwarded-For",
|
||||
"FORWARDED_SECRET": None,
|
||||
"GRACEFUL_SHUTDOWN_TIMEOUT": 15.0, # 15 sec
|
||||
|
@ -94,6 +95,7 @@ class Config(dict):
|
|||
self.load_environment_vars(SANIC_PREFIX)
|
||||
|
||||
self._configure_header_size()
|
||||
self._check_error_format()
|
||||
|
||||
def __getattr__(self, attr):
|
||||
try:
|
||||
|
@ -109,6 +111,8 @@ class Config(dict):
|
|||
"REQUEST_MAX_SIZE",
|
||||
):
|
||||
self._configure_header_size()
|
||||
elif attr == "FALLBACK_ERROR_FORMAT":
|
||||
self._check_error_format()
|
||||
|
||||
def _configure_header_size(self):
|
||||
Http.set_header_max_size(
|
||||
|
@ -117,6 +121,9 @@ class Config(dict):
|
|||
self.REQUEST_MAX_SIZE,
|
||||
)
|
||||
|
||||
def _check_error_format(self):
|
||||
check_error_format(self.FALLBACK_ERROR_FORMAT)
|
||||
|
||||
def load_environment_vars(self, prefix=SANIC_PREFIX):
|
||||
"""
|
||||
Looks for prefixed environment variables and applies
|
||||
|
|
|
@ -340,41 +340,138 @@ RENDERERS_BY_CONFIG = {
|
|||
}
|
||||
|
||||
RENDERERS_BY_CONTENT_TYPE = {
|
||||
"multipart/form-data": HTMLRenderer,
|
||||
"application/json": JSONRenderer,
|
||||
"text/plain": TextRenderer,
|
||||
"application/json": JSONRenderer,
|
||||
"multipart/form-data": HTMLRenderer,
|
||||
"text/html": HTMLRenderer,
|
||||
}
|
||||
CONTENT_TYPE_BY_RENDERERS = {
|
||||
v: k for k, v in RENDERERS_BY_CONTENT_TYPE.items()
|
||||
}
|
||||
|
||||
RESPONSE_MAPPING = {
|
||||
"empty": "html",
|
||||
"json": "json",
|
||||
"text": "text",
|
||||
"raw": "text",
|
||||
"html": "html",
|
||||
"file": "html",
|
||||
"file_stream": "text",
|
||||
"stream": "text",
|
||||
"redirect": "html",
|
||||
"text/plain": "text",
|
||||
"text/html": "html",
|
||||
"application/json": "json",
|
||||
}
|
||||
|
||||
|
||||
def check_error_format(format):
|
||||
if format not in RENDERERS_BY_CONFIG and format != "auto":
|
||||
raise SanicException(f"Unknown format: {format}")
|
||||
|
||||
|
||||
def exception_response(
|
||||
request: Request,
|
||||
exception: Exception,
|
||||
debug: bool,
|
||||
fallback: str,
|
||||
base: t.Type[BaseRenderer],
|
||||
renderer: t.Type[t.Optional[BaseRenderer]] = None,
|
||||
) -> HTTPResponse:
|
||||
"""
|
||||
Render a response for the default FALLBACK exception handler.
|
||||
"""
|
||||
content_type = None
|
||||
|
||||
if not renderer:
|
||||
renderer = HTMLRenderer
|
||||
# Make sure we have something set
|
||||
renderer = base
|
||||
render_format = fallback
|
||||
|
||||
if request:
|
||||
if request.app.config.FALLBACK_ERROR_FORMAT == "auto":
|
||||
# If there is a request, try and get the format
|
||||
# from the route
|
||||
if request.route:
|
||||
try:
|
||||
renderer = JSONRenderer if request.json else HTMLRenderer
|
||||
except InvalidUsage:
|
||||
render_format = request.route.ctx.error_format
|
||||
except AttributeError:
|
||||
...
|
||||
|
||||
content_type = request.headers.getone("content-type", "").split(
|
||||
";"
|
||||
)[0]
|
||||
|
||||
acceptable = request.accept
|
||||
|
||||
# If the format is auto still, make a guess
|
||||
if render_format == "auto":
|
||||
# First, if there is an Accept header, check if text/html
|
||||
# is the first option
|
||||
# According to MDN Web Docs, all major browsers use text/html
|
||||
# as the primary value in Accept (with the exception of IE 8,
|
||||
# and, well, if you are supporting IE 8, then you have bigger
|
||||
# problems to concern yourself with than what default exception
|
||||
# renderer is used)
|
||||
# Source:
|
||||
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Content_negotiation/List_of_default_Accept_values
|
||||
|
||||
if acceptable and acceptable[0].match(
|
||||
"text/html",
|
||||
allow_type_wildcard=False,
|
||||
allow_subtype_wildcard=False,
|
||||
):
|
||||
renderer = HTMLRenderer
|
||||
|
||||
content_type, *_ = request.headers.getone(
|
||||
"content-type", ""
|
||||
).split(";")
|
||||
renderer = RENDERERS_BY_CONTENT_TYPE.get(
|
||||
content_type, renderer
|
||||
)
|
||||
# Second, if there is an Accept header, check if
|
||||
# application/json is an option, or if the content-type
|
||||
# is application/json
|
||||
elif (
|
||||
acceptable
|
||||
and acceptable.match(
|
||||
"application/json",
|
||||
allow_type_wildcard=False,
|
||||
allow_subtype_wildcard=False,
|
||||
)
|
||||
or content_type == "application/json"
|
||||
):
|
||||
renderer = JSONRenderer
|
||||
|
||||
# Third, if there is no Accept header, assume we want text.
|
||||
# The likely use case here is a raw socket.
|
||||
elif not acceptable:
|
||||
renderer = TextRenderer
|
||||
else:
|
||||
# Fourth, look to see if there was a JSON body
|
||||
# When in this situation, the request is probably coming
|
||||
# from curl, an API client like Postman or Insomnia, or a
|
||||
# package like requests or httpx
|
||||
try:
|
||||
# Give them the benefit of the doubt if they did:
|
||||
# $ curl localhost:8000 -d '{"foo": "bar"}'
|
||||
# And provide them with JSONRenderer
|
||||
renderer = JSONRenderer if request.json else base
|
||||
except InvalidUsage:
|
||||
renderer = base
|
||||
else:
|
||||
render_format = request.app.config.FALLBACK_ERROR_FORMAT
|
||||
renderer = RENDERERS_BY_CONFIG.get(render_format, renderer)
|
||||
|
||||
# Lastly, if there is an Accept header, make sure
|
||||
# our choice is okay
|
||||
if acceptable:
|
||||
type_ = CONTENT_TYPE_BY_RENDERERS.get(renderer) # type: ignore
|
||||
if type_ and type_ not in acceptable:
|
||||
# If the renderer selected is not in the Accept header
|
||||
# look through what is in the Accept header, and select
|
||||
# the first option that matches. Otherwise, just drop back
|
||||
# to the original default
|
||||
for accept in acceptable:
|
||||
mtype = f"{accept.type_}/{accept.subtype}"
|
||||
maybe = RENDERERS_BY_CONTENT_TYPE.get(mtype)
|
||||
if maybe:
|
||||
renderer = maybe
|
||||
break
|
||||
else:
|
||||
renderer = base
|
||||
|
||||
renderer = t.cast(t.Type[BaseRenderer], renderer)
|
||||
return renderer(request, exception, debug).render()
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
from sanic.errorpages import exception_response
|
||||
from sanic.errorpages import BaseRenderer, HTMLRenderer, exception_response
|
||||
from sanic.exceptions import (
|
||||
ContentRangeError,
|
||||
HeaderNotFound,
|
||||
InvalidRangeType,
|
||||
)
|
||||
from sanic.log import error_logger
|
||||
from sanic.models.handler_types import RouteHandler
|
||||
from sanic.response import text
|
||||
|
||||
|
||||
|
@ -23,10 +24,15 @@ class ErrorHandler:
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.handlers = []
|
||||
self.cached_handlers = {}
|
||||
# Beginning in v22.3, the base renderer will be TextRenderer
|
||||
def __init__(self, fallback: str, base: Type[BaseRenderer] = HTMLRenderer):
|
||||
self.handlers: List[Tuple[Type[BaseException], RouteHandler]] = []
|
||||
self.cached_handlers: Dict[
|
||||
Tuple[Type[BaseException], Optional[str]], Optional[RouteHandler]
|
||||
] = {}
|
||||
self.debug = False
|
||||
self.fallback = fallback
|
||||
self.base = base
|
||||
|
||||
def add(self, exception, handler, route_names: Optional[List[str]] = None):
|
||||
"""
|
||||
|
@ -142,7 +148,13 @@ class ErrorHandler:
|
|||
:return:
|
||||
"""
|
||||
self.log(request, exception)
|
||||
return exception_response(request, exception, self.debug)
|
||||
return exception_response(
|
||||
request,
|
||||
exception,
|
||||
debug=self.debug,
|
||||
base=self.base,
|
||||
fallback=self.fallback,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def log(request, exception):
|
||||
|
|
|
@ -35,7 +35,7 @@ _host_re = re.compile(
|
|||
|
||||
def parse_arg_as_accept(f):
|
||||
def func(self, other, *args, **kwargs):
|
||||
if not isinstance(other, Accept):
|
||||
if not isinstance(other, Accept) and other:
|
||||
other = Accept.parse(other)
|
||||
return f(self, other, *args, **kwargs)
|
||||
|
||||
|
@ -181,6 +181,27 @@ class Accept(str):
|
|||
return cls(mtype, MediaType(type_), MediaType(subtype), **params)
|
||||
|
||||
|
||||
class AcceptContainer(list):
|
||||
def __contains__(self, o: object) -> bool:
|
||||
return any(item.match(o) for item in self)
|
||||
|
||||
def match(
|
||||
self,
|
||||
o: object,
|
||||
*,
|
||||
allow_type_wildcard: bool = True,
|
||||
allow_subtype_wildcard: bool = True,
|
||||
) -> bool:
|
||||
return any(
|
||||
item.match(
|
||||
o,
|
||||
allow_type_wildcard=allow_type_wildcard,
|
||||
allow_subtype_wildcard=allow_subtype_wildcard,
|
||||
)
|
||||
for item in self
|
||||
)
|
||||
|
||||
|
||||
def parse_content_header(value: str) -> Tuple[str, Options]:
|
||||
"""Parse content-type and content-disposition header values.
|
||||
|
||||
|
@ -356,7 +377,7 @@ def _sort_accept_value(accept: Accept):
|
|||
)
|
||||
|
||||
|
||||
def parse_accept(accept: str) -> List[Accept]:
|
||||
def parse_accept(accept: str) -> AcceptContainer:
|
||||
"""Parse an Accept header and order the acceptable media types in
|
||||
accorsing to RFC 7231, s. 5.3.2
|
||||
https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2
|
||||
|
@ -370,4 +391,6 @@ def parse_accept(accept: str) -> List[Accept]:
|
|||
|
||||
accept_list.append(Accept.parse(mtype))
|
||||
|
||||
return sorted(accept_list, key=_sort_accept_value, reverse=True)
|
||||
return AcceptContainer(
|
||||
sorted(accept_list, key=_sort_accept_value, reverse=True)
|
||||
)
|
||||
|
|
|
@ -1,17 +1,20 @@
|
|||
from ast import NodeVisitor, Return, parse
|
||||
from functools import partial, wraps
|
||||
from inspect import signature
|
||||
from inspect import getsource, signature
|
||||
from mimetypes import guess_type
|
||||
from os import path
|
||||
from pathlib import PurePath
|
||||
from re import sub
|
||||
from textwrap import dedent
|
||||
from time import gmtime, strftime
|
||||
from typing import Callable, Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Callable, Iterable, List, Optional, Set, Tuple, Union
|
||||
from urllib.parse import unquote
|
||||
|
||||
from sanic_routing.route import Route # type: ignore
|
||||
|
||||
from sanic.compat import stat_async
|
||||
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE, HTTP_METHODS
|
||||
from sanic.errorpages import RESPONSE_MAPPING
|
||||
from sanic.exceptions import (
|
||||
ContentRangeError,
|
||||
FileNotFound,
|
||||
|
@ -61,6 +64,7 @@ class RouteMixin:
|
|||
unquote: bool = False,
|
||||
static: bool = False,
|
||||
version_prefix: str = "/v",
|
||||
error_format: Optional[str] = None,
|
||||
) -> RouteWrapper:
|
||||
"""
|
||||
Decorate a function to be registered as a route
|
||||
|
@ -103,6 +107,7 @@ class RouteMixin:
|
|||
nonlocal websocket
|
||||
nonlocal static
|
||||
nonlocal version_prefix
|
||||
nonlocal error_format
|
||||
|
||||
if isinstance(handler, tuple):
|
||||
# if a handler fn is already wrapped in a route, the handler
|
||||
|
@ -128,6 +133,9 @@ class RouteMixin:
|
|||
# subprotocol is unordered, keep it unordered
|
||||
subprotocols = frozenset(subprotocols)
|
||||
|
||||
if not error_format or error_format == "auto":
|
||||
error_format = self._determine_error_format(handler)
|
||||
|
||||
route = FutureRoute(
|
||||
handler,
|
||||
uri,
|
||||
|
@ -143,6 +151,7 @@ class RouteMixin:
|
|||
unquote,
|
||||
static,
|
||||
version_prefix,
|
||||
error_format,
|
||||
)
|
||||
|
||||
self._future_routes.add(route)
|
||||
|
@ -186,6 +195,7 @@ class RouteMixin:
|
|||
name: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
version_prefix: str = "/v",
|
||||
error_format: Optional[str] = None,
|
||||
) -> RouteHandler:
|
||||
"""A helper method to register class instance or
|
||||
functions as a handler to the application url
|
||||
|
@ -236,6 +246,7 @@ class RouteMixin:
|
|||
version=version,
|
||||
name=name,
|
||||
version_prefix=version_prefix,
|
||||
error_format=error_format,
|
||||
)(handler)
|
||||
return handler
|
||||
|
||||
|
@ -249,6 +260,7 @@ class RouteMixin:
|
|||
name: Optional[str] = None,
|
||||
ignore_body: bool = True,
|
||||
version_prefix: str = "/v",
|
||||
error_format: Optional[str] = None,
|
||||
) -> RouteWrapper:
|
||||
"""
|
||||
Add an API URL under the **GET** *HTTP* method
|
||||
|
@ -272,6 +284,7 @@ class RouteMixin:
|
|||
name=name,
|
||||
ignore_body=ignore_body,
|
||||
version_prefix=version_prefix,
|
||||
error_format=error_format,
|
||||
)
|
||||
|
||||
def post(
|
||||
|
@ -283,6 +296,7 @@ class RouteMixin:
|
|||
version: Optional[int] = None,
|
||||
name: Optional[str] = None,
|
||||
version_prefix: str = "/v",
|
||||
error_format: Optional[str] = None,
|
||||
) -> RouteWrapper:
|
||||
"""
|
||||
Add an API URL under the **POST** *HTTP* method
|
||||
|
@ -306,6 +320,7 @@ class RouteMixin:
|
|||
version=version,
|
||||
name=name,
|
||||
version_prefix=version_prefix,
|
||||
error_format=error_format,
|
||||
)
|
||||
|
||||
def put(
|
||||
|
@ -317,6 +332,7 @@ class RouteMixin:
|
|||
version: Optional[int] = None,
|
||||
name: Optional[str] = None,
|
||||
version_prefix: str = "/v",
|
||||
error_format: Optional[str] = None,
|
||||
) -> RouteWrapper:
|
||||
"""
|
||||
Add an API URL under the **PUT** *HTTP* method
|
||||
|
@ -340,6 +356,7 @@ class RouteMixin:
|
|||
version=version,
|
||||
name=name,
|
||||
version_prefix=version_prefix,
|
||||
error_format=error_format,
|
||||
)
|
||||
|
||||
def head(
|
||||
|
@ -351,6 +368,7 @@ class RouteMixin:
|
|||
name: Optional[str] = None,
|
||||
ignore_body: bool = True,
|
||||
version_prefix: str = "/v",
|
||||
error_format: Optional[str] = None,
|
||||
) -> RouteWrapper:
|
||||
"""
|
||||
Add an API URL under the **HEAD** *HTTP* method
|
||||
|
@ -382,6 +400,7 @@ class RouteMixin:
|
|||
name=name,
|
||||
ignore_body=ignore_body,
|
||||
version_prefix=version_prefix,
|
||||
error_format=error_format,
|
||||
)
|
||||
|
||||
def options(
|
||||
|
@ -393,6 +412,7 @@ class RouteMixin:
|
|||
name: Optional[str] = None,
|
||||
ignore_body: bool = True,
|
||||
version_prefix: str = "/v",
|
||||
error_format: Optional[str] = None,
|
||||
) -> RouteWrapper:
|
||||
"""
|
||||
Add an API URL under the **OPTIONS** *HTTP* method
|
||||
|
@ -424,6 +444,7 @@ class RouteMixin:
|
|||
name=name,
|
||||
ignore_body=ignore_body,
|
||||
version_prefix=version_prefix,
|
||||
error_format=error_format,
|
||||
)
|
||||
|
||||
def patch(
|
||||
|
@ -435,6 +456,7 @@ class RouteMixin:
|
|||
version: Optional[int] = None,
|
||||
name: Optional[str] = None,
|
||||
version_prefix: str = "/v",
|
||||
error_format: Optional[str] = None,
|
||||
) -> RouteWrapper:
|
||||
"""
|
||||
Add an API URL under the **PATCH** *HTTP* method
|
||||
|
@ -468,6 +490,7 @@ class RouteMixin:
|
|||
version=version,
|
||||
name=name,
|
||||
version_prefix=version_prefix,
|
||||
error_format=error_format,
|
||||
)
|
||||
|
||||
def delete(
|
||||
|
@ -479,6 +502,7 @@ class RouteMixin:
|
|||
name: Optional[str] = None,
|
||||
ignore_body: bool = True,
|
||||
version_prefix: str = "/v",
|
||||
error_format: Optional[str] = None,
|
||||
) -> RouteWrapper:
|
||||
"""
|
||||
Add an API URL under the **DELETE** *HTTP* method
|
||||
|
@ -502,6 +526,7 @@ class RouteMixin:
|
|||
name=name,
|
||||
ignore_body=ignore_body,
|
||||
version_prefix=version_prefix,
|
||||
error_format=error_format,
|
||||
)
|
||||
|
||||
def websocket(
|
||||
|
@ -514,6 +539,7 @@ class RouteMixin:
|
|||
name: Optional[str] = None,
|
||||
apply: bool = True,
|
||||
version_prefix: str = "/v",
|
||||
error_format: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Decorate a function to be registered as a websocket route
|
||||
|
@ -540,6 +566,7 @@ class RouteMixin:
|
|||
subprotocols=subprotocols,
|
||||
websocket=True,
|
||||
version_prefix=version_prefix,
|
||||
error_format=error_format,
|
||||
)
|
||||
|
||||
def add_websocket_route(
|
||||
|
@ -552,6 +579,7 @@ class RouteMixin:
|
|||
version: Optional[int] = None,
|
||||
name: Optional[str] = None,
|
||||
version_prefix: str = "/v",
|
||||
error_format: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
A helper method to register a function as a websocket route.
|
||||
|
@ -580,6 +608,7 @@ class RouteMixin:
|
|||
version=version,
|
||||
name=name,
|
||||
version_prefix=version_prefix,
|
||||
error_format=error_format,
|
||||
)(handler)
|
||||
|
||||
def static(
|
||||
|
@ -789,8 +818,8 @@ class RouteMixin:
|
|||
)
|
||||
except Exception:
|
||||
error_logger.exception(
|
||||
f"Exception in static request handler:\
|
||||
path={file_or_directory}, "
|
||||
f"Exception in static request handler: "
|
||||
f"path={file_or_directory}, "
|
||||
f"relative_url={__file_uri__}"
|
||||
)
|
||||
raise
|
||||
|
@ -888,3 +917,43 @@ class RouteMixin:
|
|||
)(_handler)
|
||||
|
||||
return route
|
||||
|
||||
def _determine_error_format(self, handler) -> str:
|
||||
if not isinstance(handler, CompositionView):
|
||||
try:
|
||||
src = dedent(getsource(handler))
|
||||
tree = parse(src)
|
||||
http_response_types = self._get_response_types(tree)
|
||||
|
||||
if len(http_response_types) == 1:
|
||||
return next(iter(http_response_types))
|
||||
except OSError:
|
||||
...
|
||||
|
||||
return "auto"
|
||||
|
||||
def _get_response_types(self, node):
|
||||
types = set()
|
||||
|
||||
class HttpResponseVisitor(NodeVisitor):
|
||||
def visit_Return(self, node: Return) -> Any:
|
||||
nonlocal types
|
||||
|
||||
try:
|
||||
checks = [node.value.func.id] # type: ignore
|
||||
if node.value.keywords: # type: ignore
|
||||
checks += [
|
||||
k.value
|
||||
for k in node.value.keywords # type: ignore
|
||||
if k.arg == "content_type"
|
||||
]
|
||||
|
||||
for check in checks:
|
||||
if check in RESPONSE_MAPPING:
|
||||
types.add(RESPONSE_MAPPING[check])
|
||||
except AttributeError:
|
||||
...
|
||||
|
||||
HttpResponseVisitor().visit(node)
|
||||
|
||||
return types
|
||||
|
|
|
@ -24,6 +24,7 @@ class FutureRoute(NamedTuple):
|
|||
unquote: bool
|
||||
static: bool
|
||||
version_prefix: str
|
||||
error_format: Optional[str]
|
||||
|
||||
|
||||
class FutureListener(NamedTuple):
|
||||
|
|
|
@ -34,7 +34,7 @@ from sanic.compat import CancelledErrors, Header
|
|||
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE
|
||||
from sanic.exceptions import InvalidUsage
|
||||
from sanic.headers import (
|
||||
Accept,
|
||||
AcceptContainer,
|
||||
Options,
|
||||
parse_accept,
|
||||
parse_content_header,
|
||||
|
@ -139,7 +139,7 @@ class Request:
|
|||
self.conn_info: Optional[ConnInfo] = None
|
||||
self.ctx = SimpleNamespace()
|
||||
self.parsed_forwarded: Optional[Options] = None
|
||||
self.parsed_accept: Optional[List[Accept]] = None
|
||||
self.parsed_accept: Optional[AcceptContainer] = None
|
||||
self.parsed_json = None
|
||||
self.parsed_form = None
|
||||
self.parsed_files = None
|
||||
|
@ -301,7 +301,7 @@ class Request:
|
|||
return self.parsed_json
|
||||
|
||||
@property
|
||||
def accept(self) -> List[Accept]:
|
||||
def accept(self) -> AcceptContainer:
|
||||
if self.parsed_accept is None:
|
||||
accept_header = self.headers.getone("accept", "")
|
||||
self.parsed_accept = parse_accept(accept_header)
|
||||
|
|
|
@ -13,6 +13,7 @@ from sanic_routing.exceptions import (
|
|||
from sanic_routing.route import Route # type: ignore
|
||||
|
||||
from sanic.constants import HTTP_METHODS
|
||||
from sanic.errorpages import check_error_format
|
||||
from sanic.exceptions import MethodNotSupported, NotFound, SanicException
|
||||
from sanic.models.handler_types import RouteHandler
|
||||
|
||||
|
@ -78,6 +79,7 @@ class Router(BaseRouter):
|
|||
unquote: bool = False,
|
||||
static: bool = False,
|
||||
version_prefix: str = "/v",
|
||||
error_format: Optional[str] = None,
|
||||
) -> Union[Route, List[Route]]:
|
||||
"""
|
||||
Add a handler to the router
|
||||
|
@ -137,6 +139,11 @@ class Router(BaseRouter):
|
|||
route.ctx.stream = stream
|
||||
route.ctx.hosts = hosts
|
||||
route.ctx.static = static
|
||||
route.ctx.error_format = (
|
||||
error_format or self.ctx.app.config.FALLBACK_ERROR_FORMAT
|
||||
)
|
||||
|
||||
check_error_format(route.ctx.error_format)
|
||||
|
||||
routes.append(route)
|
||||
|
||||
|
|
|
@ -109,6 +109,7 @@ def sanic_router(app):
|
|||
# noinspection PyProtectedMember
|
||||
def _setup(route_details: tuple) -> Tuple[Router, tuple]:
|
||||
router = Router()
|
||||
router.ctx.app = app
|
||||
added_router = []
|
||||
for method, route in route_details:
|
||||
try:
|
||||
|
|
|
@ -20,4 +20,4 @@ def test_bad_request_response(app):
|
|||
|
||||
app.run(host="127.0.0.1", port=42101, debug=False)
|
||||
assert lines[0] == b"HTTP/1.1 400 Bad Request\r\n"
|
||||
assert b"Bad Request" in lines[-1]
|
||||
assert b"Bad Request" in lines[-2]
|
||||
|
|
|
@ -3,7 +3,12 @@ from pytest import raises
|
|||
from sanic.app import Sanic
|
||||
from sanic.blueprint_group import BlueprintGroup
|
||||
from sanic.blueprints import Blueprint
|
||||
from sanic.exceptions import Forbidden, InvalidUsage, SanicException, ServerError
|
||||
from sanic.exceptions import (
|
||||
Forbidden,
|
||||
InvalidUsage,
|
||||
SanicException,
|
||||
ServerError,
|
||||
)
|
||||
from sanic.request import Request
|
||||
from sanic.response import HTTPResponse, text
|
||||
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import pytest
|
||||
|
||||
from sanic import Sanic
|
||||
from sanic.errorpages import exception_response
|
||||
from sanic.exceptions import NotFound
|
||||
from sanic.errorpages import HTMLRenderer, exception_response
|
||||
from sanic.exceptions import NotFound, SanicException
|
||||
from sanic.request import Request
|
||||
from sanic.response import HTTPResponse
|
||||
from sanic.response import HTTPResponse, html, json, text
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -20,7 +20,7 @@ def app():
|
|||
|
||||
@pytest.fixture
|
||||
def fake_request(app):
|
||||
return Request(b"/foobar", {}, "1.1", "GET", None, app)
|
||||
return Request(b"/foobar", {"accept": "*/*"}, "1.1", "GET", None, app)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -47,7 +47,13 @@ def test_should_return_html_valid_setting(
|
|||
try:
|
||||
raise exception("bad stuff")
|
||||
except Exception as e:
|
||||
response = exception_response(fake_request, e, True)
|
||||
response = exception_response(
|
||||
fake_request,
|
||||
e,
|
||||
True,
|
||||
base=HTMLRenderer,
|
||||
fallback=fake_request.app.config.FALLBACK_ERROR_FORMAT,
|
||||
)
|
||||
|
||||
assert isinstance(response, HTTPResponse)
|
||||
assert response.status == status
|
||||
|
@ -74,13 +80,194 @@ def test_auto_fallback_with_content_type(app):
|
|||
app.config.FALLBACK_ERROR_FORMAT = "auto"
|
||||
|
||||
_, response = app.test_client.get(
|
||||
"/error", headers={"content-type": "application/json"}
|
||||
"/error", headers={"content-type": "application/json", "accept": "*/*"}
|
||||
)
|
||||
assert response.status == 500
|
||||
assert response.content_type == "application/json"
|
||||
|
||||
_, response = app.test_client.get(
|
||||
"/error", headers={"content-type": "text/plain"}
|
||||
"/error", headers={"content-type": "foo/bar", "accept": "*/*"}
|
||||
)
|
||||
assert response.status == 500
|
||||
assert response.content_type == "text/html; charset=utf-8"
|
||||
|
||||
|
||||
def test_route_error_format_set_on_auto(app):
|
||||
@app.get("/text")
|
||||
def text_response(request):
|
||||
return text(request.route.ctx.error_format)
|
||||
|
||||
@app.get("/json")
|
||||
def json_response(request):
|
||||
return json({"format": request.route.ctx.error_format})
|
||||
|
||||
@app.get("/html")
|
||||
def html_response(request):
|
||||
return html(request.route.ctx.error_format)
|
||||
|
||||
_, response = app.test_client.get("/text")
|
||||
assert response.text == "text"
|
||||
|
||||
_, response = app.test_client.get("/json")
|
||||
assert response.json["format"] == "json"
|
||||
|
||||
_, response = app.test_client.get("/html")
|
||||
assert response.text == "html"
|
||||
|
||||
|
||||
def test_route_error_response_from_auto_route(app):
|
||||
@app.get("/text")
|
||||
def text_response(request):
|
||||
raise Exception("oops")
|
||||
return text("Never gonna see this")
|
||||
|
||||
@app.get("/json")
|
||||
def json_response(request):
|
||||
raise Exception("oops")
|
||||
return json({"message": "Never gonna see this"})
|
||||
|
||||
@app.get("/html")
|
||||
def html_response(request):
|
||||
raise Exception("oops")
|
||||
return html("<h1>Never gonna see this</h1>")
|
||||
|
||||
_, response = app.test_client.get("/text")
|
||||
assert response.content_type == "text/plain; charset=utf-8"
|
||||
|
||||
_, response = app.test_client.get("/json")
|
||||
assert response.content_type == "application/json"
|
||||
|
||||
_, response = app.test_client.get("/html")
|
||||
assert response.content_type == "text/html; charset=utf-8"
|
||||
|
||||
|
||||
def test_route_error_response_from_explicit_format(app):
|
||||
@app.get("/text", error_format="json")
|
||||
def text_response(request):
|
||||
raise Exception("oops")
|
||||
return text("Never gonna see this")
|
||||
|
||||
@app.get("/json", error_format="text")
|
||||
def json_response(request):
|
||||
raise Exception("oops")
|
||||
return json({"message": "Never gonna see this"})
|
||||
|
||||
_, response = app.test_client.get("/text")
|
||||
assert response.content_type == "application/json"
|
||||
|
||||
_, response = app.test_client.get("/json")
|
||||
assert response.content_type == "text/plain; charset=utf-8"
|
||||
|
||||
|
||||
def test_unknown_fallback_format(app):
|
||||
with pytest.raises(SanicException, match="Unknown format: bad"):
|
||||
app.config.FALLBACK_ERROR_FORMAT = "bad"
|
||||
|
||||
|
||||
def test_route_error_format_unknown(app):
|
||||
with pytest.raises(SanicException, match="Unknown format: bad"):
|
||||
|
||||
@app.get("/text", error_format="bad")
|
||||
def handler(request):
|
||||
...
|
||||
|
||||
|
||||
def test_fallback_with_content_type_mismatch_accept(app):
|
||||
app.config.FALLBACK_ERROR_FORMAT = "auto"
|
||||
|
||||
_, response = app.test_client.get(
|
||||
"/error",
|
||||
headers={"content-type": "application/json", "accept": "text/plain"},
|
||||
)
|
||||
assert response.status == 500
|
||||
assert response.content_type == "text/plain; charset=utf-8"
|
||||
|
||||
_, response = app.test_client.get(
|
||||
"/error",
|
||||
headers={"content-type": "text/plain", "accept": "foo/bar"},
|
||||
)
|
||||
assert response.status == 500
|
||||
assert response.content_type == "text/html; charset=utf-8"
|
||||
|
||||
app.router.reset()
|
||||
|
||||
@app.route("/alt1")
|
||||
@app.route("/alt2", error_format="text")
|
||||
@app.route("/alt3", error_format="html")
|
||||
def handler(_):
|
||||
raise Exception("problem here")
|
||||
# Yes, we know this return value is unreachable. This is on purpose.
|
||||
return json({})
|
||||
|
||||
app.router.finalize()
|
||||
|
||||
_, response = app.test_client.get(
|
||||
"/alt1",
|
||||
headers={"accept": "foo/bar"},
|
||||
)
|
||||
assert response.status == 500
|
||||
assert response.content_type == "text/html; charset=utf-8"
|
||||
_, response = app.test_client.get(
|
||||
"/alt1",
|
||||
headers={"accept": "foo/bar,*/*"},
|
||||
)
|
||||
assert response.status == 500
|
||||
assert response.content_type == "application/json"
|
||||
|
||||
_, response = app.test_client.get(
|
||||
"/alt2",
|
||||
headers={"accept": "foo/bar"},
|
||||
)
|
||||
assert response.status == 500
|
||||
assert response.content_type == "text/html; charset=utf-8"
|
||||
_, response = app.test_client.get(
|
||||
"/alt2",
|
||||
headers={"accept": "foo/bar,*/*"},
|
||||
)
|
||||
assert response.status == 500
|
||||
assert response.content_type == "text/plain; charset=utf-8"
|
||||
|
||||
_, response = app.test_client.get(
|
||||
"/alt3",
|
||||
headers={"accept": "foo/bar"},
|
||||
)
|
||||
assert response.status == 500
|
||||
assert response.content_type == "text/html; charset=utf-8"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"accept,content_type,expected",
|
||||
(
|
||||
(None, None, "text/plain; charset=utf-8"),
|
||||
("foo/bar", None, "text/html; charset=utf-8"),
|
||||
("application/json", None, "application/json"),
|
||||
("application/json,text/plain", None, "application/json"),
|
||||
("text/plain,application/json", None, "application/json"),
|
||||
("text/plain,foo/bar", None, "text/plain; charset=utf-8"),
|
||||
# Following test is valid after v22.3
|
||||
# ("text/plain,text/html", None, "text/plain; charset=utf-8"),
|
||||
("*/*", "foo/bar", "text/html; charset=utf-8"),
|
||||
("*/*", "application/json", "application/json"),
|
||||
),
|
||||
)
|
||||
def test_combinations_for_auto(fake_request, accept, content_type, expected):
|
||||
if accept:
|
||||
fake_request.headers["accept"] = accept
|
||||
else:
|
||||
del fake_request.headers["accept"]
|
||||
|
||||
if content_type:
|
||||
fake_request.headers["content-type"] = content_type
|
||||
|
||||
try:
|
||||
raise Exception("bad stuff")
|
||||
except Exception as e:
|
||||
response = exception_response(
|
||||
fake_request,
|
||||
e,
|
||||
True,
|
||||
base=HTMLRenderer,
|
||||
fallback="auto",
|
||||
)
|
||||
|
||||
assert response.content_type == expected
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from sanic import Sanic
|
||||
|
@ -8,9 +10,6 @@ from sanic.handlers import ErrorHandler
|
|||
from sanic.response import stream, text
|
||||
|
||||
|
||||
exception_handler_app = Sanic("test_exception_handler")
|
||||
|
||||
|
||||
async def sample_streaming_fn(response):
|
||||
await response.write("foo,")
|
||||
await asyncio.sleep(0.001)
|
||||
|
@ -21,107 +20,102 @@ class ErrorWithRequestCtx(ServerError):
|
|||
pass
|
||||
|
||||
|
||||
@exception_handler_app.route("/1")
|
||||
def handler_1(request):
|
||||
raise InvalidUsage("OK")
|
||||
@pytest.fixture
|
||||
def exception_handler_app():
|
||||
exception_handler_app = Sanic("test_exception_handler")
|
||||
|
||||
@exception_handler_app.route("/1", error_format="html")
|
||||
def handler_1(request):
|
||||
raise InvalidUsage("OK")
|
||||
|
||||
@exception_handler_app.route("/2", error_format="html")
|
||||
def handler_2(request):
|
||||
raise ServerError("OK")
|
||||
|
||||
@exception_handler_app.route("/3", error_format="html")
|
||||
def handler_3(request):
|
||||
raise NotFound("OK")
|
||||
|
||||
@exception_handler_app.route("/4", error_format="html")
|
||||
def handler_4(request):
|
||||
foo = bar # noqa -- F821
|
||||
return text(foo)
|
||||
|
||||
@exception_handler_app.route("/5", error_format="html")
|
||||
def handler_5(request):
|
||||
class CustomServerError(ServerError):
|
||||
pass
|
||||
|
||||
raise CustomServerError("Custom server error")
|
||||
|
||||
@exception_handler_app.route("/6/<arg:int>", error_format="html")
|
||||
def handler_6(request, arg):
|
||||
try:
|
||||
foo = 1 / arg
|
||||
except Exception as e:
|
||||
raise e from ValueError(f"{arg}")
|
||||
return text(foo)
|
||||
|
||||
@exception_handler_app.route("/7", error_format="html")
|
||||
def handler_7(request):
|
||||
raise Forbidden("go away!")
|
||||
|
||||
@exception_handler_app.route("/8", error_format="html")
|
||||
def handler_8(request):
|
||||
|
||||
raise ErrorWithRequestCtx("OK")
|
||||
|
||||
@exception_handler_app.exception(ErrorWithRequestCtx, NotFound)
|
||||
def handler_exception_with_ctx(request, exception):
|
||||
return text(request.ctx.middleware_ran)
|
||||
|
||||
@exception_handler_app.exception(ServerError)
|
||||
def handler_exception(request, exception):
|
||||
return text("OK")
|
||||
|
||||
@exception_handler_app.exception(Forbidden)
|
||||
async def async_handler_exception(request, exception):
|
||||
return stream(
|
||||
sample_streaming_fn,
|
||||
content_type="text/csv",
|
||||
)
|
||||
|
||||
@exception_handler_app.middleware
|
||||
async def some_request_middleware(request):
|
||||
request.ctx.middleware_ran = "Done."
|
||||
|
||||
return exception_handler_app
|
||||
|
||||
|
||||
@exception_handler_app.route("/2")
|
||||
def handler_2(request):
|
||||
raise ServerError("OK")
|
||||
|
||||
|
||||
@exception_handler_app.route("/3")
|
||||
def handler_3(request):
|
||||
raise NotFound("OK")
|
||||
|
||||
|
||||
@exception_handler_app.route("/4")
|
||||
def handler_4(request):
|
||||
foo = bar # noqa -- F821 undefined name 'bar' is done to throw exception
|
||||
return text(foo)
|
||||
|
||||
|
||||
@exception_handler_app.route("/5")
|
||||
def handler_5(request):
|
||||
class CustomServerError(ServerError):
|
||||
pass
|
||||
|
||||
raise CustomServerError("Custom server error")
|
||||
|
||||
|
||||
@exception_handler_app.route("/6/<arg:int>")
|
||||
def handler_6(request, arg):
|
||||
try:
|
||||
foo = 1 / arg
|
||||
except Exception as e:
|
||||
raise e from ValueError(f"{arg}")
|
||||
return text(foo)
|
||||
|
||||
|
||||
@exception_handler_app.route("/7")
|
||||
def handler_7(request):
|
||||
raise Forbidden("go away!")
|
||||
|
||||
|
||||
@exception_handler_app.route("/8")
|
||||
def handler_8(request):
|
||||
|
||||
raise ErrorWithRequestCtx("OK")
|
||||
|
||||
|
||||
@exception_handler_app.exception(ErrorWithRequestCtx, NotFound)
|
||||
def handler_exception_with_ctx(request, exception):
|
||||
return text(request.ctx.middleware_ran)
|
||||
|
||||
|
||||
@exception_handler_app.exception(ServerError)
|
||||
def handler_exception(request, exception):
|
||||
return text("OK")
|
||||
|
||||
|
||||
@exception_handler_app.exception(Forbidden)
|
||||
async def async_handler_exception(request, exception):
|
||||
return stream(
|
||||
sample_streaming_fn,
|
||||
content_type="text/csv",
|
||||
)
|
||||
|
||||
|
||||
@exception_handler_app.middleware
|
||||
async def some_request_middleware(request):
|
||||
request.ctx.middleware_ran = "Done."
|
||||
|
||||
|
||||
def test_invalid_usage_exception_handler():
|
||||
def test_invalid_usage_exception_handler(exception_handler_app):
|
||||
request, response = exception_handler_app.test_client.get("/1")
|
||||
assert response.status == 400
|
||||
|
||||
|
||||
def test_server_error_exception_handler():
|
||||
def test_server_error_exception_handler(exception_handler_app):
|
||||
request, response = exception_handler_app.test_client.get("/2")
|
||||
assert response.status == 200
|
||||
assert response.text == "OK"
|
||||
|
||||
|
||||
def test_not_found_exception_handler():
|
||||
def test_not_found_exception_handler(exception_handler_app):
|
||||
request, response = exception_handler_app.test_client.get("/3")
|
||||
assert response.status == 200
|
||||
|
||||
|
||||
def test_text_exception__handler():
|
||||
def test_text_exception__handler(exception_handler_app):
|
||||
request, response = exception_handler_app.test_client.get("/random")
|
||||
assert response.status == 200
|
||||
assert response.text == "Done."
|
||||
|
||||
|
||||
def test_async_exception_handler():
|
||||
def test_async_exception_handler(exception_handler_app):
|
||||
request, response = exception_handler_app.test_client.get("/7")
|
||||
assert response.status == 200
|
||||
assert response.text == "foo,bar"
|
||||
|
||||
|
||||
def test_html_traceback_output_in_debug_mode():
|
||||
def test_html_traceback_output_in_debug_mode(exception_handler_app):
|
||||
request, response = exception_handler_app.test_client.get("/4", debug=True)
|
||||
assert response.status == 500
|
||||
soup = BeautifulSoup(response.body, "html.parser")
|
||||
|
@ -136,12 +130,12 @@ def test_html_traceback_output_in_debug_mode():
|
|||
) == summary_text
|
||||
|
||||
|
||||
def test_inherited_exception_handler():
|
||||
def test_inherited_exception_handler(exception_handler_app):
|
||||
request, response = exception_handler_app.test_client.get("/5")
|
||||
assert response.status == 200
|
||||
|
||||
|
||||
def test_chained_exception_handler():
|
||||
def test_chained_exception_handler(exception_handler_app):
|
||||
request, response = exception_handler_app.test_client.get(
|
||||
"/6/0", debug=True
|
||||
)
|
||||
|
@ -153,7 +147,6 @@ def test_chained_exception_handler():
|
|||
assert "handler_6" in html
|
||||
assert "foo = 1 / arg" in html
|
||||
assert "ValueError" in html
|
||||
assert "The above exception was the direct cause" in html
|
||||
|
||||
summary_text = " ".join(soup.select(".summary")[0].text.split())
|
||||
assert (
|
||||
|
@ -161,7 +154,7 @@ def test_chained_exception_handler():
|
|||
) == summary_text
|
||||
|
||||
|
||||
def test_exception_handler_lookup():
|
||||
def test_exception_handler_lookup(exception_handler_app):
|
||||
class CustomError(Exception):
|
||||
pass
|
||||
|
||||
|
@ -184,7 +177,7 @@ def test_exception_handler_lookup():
|
|||
class ModuleNotFoundError(ImportError):
|
||||
pass
|
||||
|
||||
handler = ErrorHandler()
|
||||
handler = ErrorHandler("auto")
|
||||
handler.add(ImportError, import_error_handler)
|
||||
handler.add(CustomError, custom_error_handler)
|
||||
handler.add(ServerError, server_error_handler)
|
||||
|
@ -209,7 +202,7 @@ def test_exception_handler_lookup():
|
|||
)
|
||||
|
||||
|
||||
def test_exception_handler_processed_request_middleware():
|
||||
def test_exception_handler_processed_request_middleware(exception_handler_app):
|
||||
request, response = exception_handler_app.test_client.get("/8")
|
||||
assert response.status == 200
|
||||
assert response.text == "Done."
|
||||
|
|
|
@ -5,6 +5,7 @@ import pytest
|
|||
from sanic import headers, text
|
||||
from sanic.exceptions import InvalidHeader, PayloadTooLarge
|
||||
from sanic.http import Http
|
||||
from sanic.request import Request
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -338,3 +339,31 @@ def test_value_in_accept(value):
|
|||
assert "foo/bar" in acceptable
|
||||
assert "foo/*" in acceptable
|
||||
assert "*/*" in acceptable
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value", ("foo/bar", "foo/*"))
|
||||
def test_value_not_in_accept(value):
|
||||
acceptable = headers.parse_accept(value)
|
||||
assert "no/match" not in acceptable
|
||||
assert "no/*" not in acceptable
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"header,expected",
|
||||
(
|
||||
(
|
||||
"text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", # noqa: E501
|
||||
[
|
||||
"text/html",
|
||||
"application/xhtml+xml",
|
||||
"image/avif",
|
||||
"image/webp",
|
||||
"application/xml;q=0.9",
|
||||
"*/*;q=0.8",
|
||||
],
|
||||
),
|
||||
),
|
||||
)
|
||||
def test_browser_headers(header, expected):
|
||||
request = Request(b"/", {"accept": header}, "1.1", "GET", None, None)
|
||||
assert request.accept == expected
|
||||
|
|
Loading…
Reference in New Issue
Block a user