Smarter auto fallback (#2162)

* Smarter auto fallback

* remove config from blueprints

* Add tests for error formatting

* Add check for proper format

* Fix some tests

* Add some tests

* docstring

* Add accept matching

* Add some more tests on matching

* Fix contains bug, earlier return on MediaType eq

* Add matching flags for wildcards

* Add mathing controls to accept

* Cleanup dev cruft

* Add cleanup and resolve OSError relating to test implementation

* Fix test

* Fix some typos
This commit is contained in:
Adam Hopkins
2021-09-29 23:53:49 +03:00
committed by GitHub
parent b5f2bd9b0e
commit cf1d2148ac
16 changed files with 562 additions and 125 deletions

View File

@@ -179,7 +179,9 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
self.configure_logging = configure_logging
self.ctx = ctx or SimpleNamespace()
self.debug = None
self.error_handler = error_handler or ErrorHandler()
self.error_handler = error_handler or ErrorHandler(
fallback=self.config.FALLBACK_ERROR_FORMAT,
)
self.is_running = False
self.is_stopping = False
self.listeners: Dict[str, List[ListenerType]] = defaultdict(list)

View File

@@ -266,6 +266,9 @@ class Blueprint(BaseSanic):
opt_version = options.get("version", None)
opt_strict_slashes = options.get("strict_slashes", None)
opt_version_prefix = options.get("version_prefix", self.version_prefix)
error_format = options.get(
"error_format", app.config.FALLBACK_ERROR_FORMAT
)
routes = []
middleware = []
@@ -313,6 +316,7 @@ class Blueprint(BaseSanic):
future.unquote,
future.static,
version_prefix,
error_format,
)
route = app._apply_route(apply_route)

View File

@@ -4,6 +4,7 @@ from pathlib import Path
from typing import Any, Dict, Optional, Union
from warnings import warn
from sanic.errorpages import check_error_format
from sanic.http import Http
from .utils import load_module_from_file_location, str_to_bool
@@ -20,7 +21,7 @@ BASE_LOGO = """
DEFAULT_CONFIG = {
"ACCESS_LOG": True,
"EVENT_AUTOREGISTER": False,
"FALLBACK_ERROR_FORMAT": "html",
"FALLBACK_ERROR_FORMAT": "auto",
"FORWARDED_FOR_HEADER": "X-Forwarded-For",
"FORWARDED_SECRET": None,
"GRACEFUL_SHUTDOWN_TIMEOUT": 15.0, # 15 sec
@@ -94,6 +95,7 @@ class Config(dict):
self.load_environment_vars(SANIC_PREFIX)
self._configure_header_size()
self._check_error_format()
def __getattr__(self, attr):
try:
@@ -109,6 +111,8 @@ class Config(dict):
"REQUEST_MAX_SIZE",
):
self._configure_header_size()
elif attr == "FALLBACK_ERROR_FORMAT":
self._check_error_format()
def _configure_header_size(self):
Http.set_header_max_size(
@@ -117,6 +121,9 @@ class Config(dict):
self.REQUEST_MAX_SIZE,
)
def _check_error_format(self):
check_error_format(self.FALLBACK_ERROR_FORMAT)
def load_environment_vars(self, prefix=SANIC_PREFIX):
"""
Looks for prefixed environment variables and applies

View File

@@ -340,41 +340,138 @@ RENDERERS_BY_CONFIG = {
}
RENDERERS_BY_CONTENT_TYPE = {
"multipart/form-data": HTMLRenderer,
"application/json": JSONRenderer,
"text/plain": TextRenderer,
"application/json": JSONRenderer,
"multipart/form-data": HTMLRenderer,
"text/html": HTMLRenderer,
}
CONTENT_TYPE_BY_RENDERERS = {
v: k for k, v in RENDERERS_BY_CONTENT_TYPE.items()
}
RESPONSE_MAPPING = {
"empty": "html",
"json": "json",
"text": "text",
"raw": "text",
"html": "html",
"file": "html",
"file_stream": "text",
"stream": "text",
"redirect": "html",
"text/plain": "text",
"text/html": "html",
"application/json": "json",
}
def check_error_format(format):
if format not in RENDERERS_BY_CONFIG and format != "auto":
raise SanicException(f"Unknown format: {format}")
def exception_response(
request: Request,
exception: Exception,
debug: bool,
fallback: str,
base: t.Type[BaseRenderer],
renderer: t.Type[t.Optional[BaseRenderer]] = None,
) -> HTTPResponse:
"""
Render a response for the default FALLBACK exception handler.
"""
content_type = None
if not renderer:
renderer = HTMLRenderer
# Make sure we have something set
renderer = base
render_format = fallback
if request:
if request.app.config.FALLBACK_ERROR_FORMAT == "auto":
# If there is a request, try and get the format
# from the route
if request.route:
try:
renderer = JSONRenderer if request.json else HTMLRenderer
except InvalidUsage:
render_format = request.route.ctx.error_format
except AttributeError:
...
content_type = request.headers.getone("content-type", "").split(
";"
)[0]
acceptable = request.accept
# If the format is auto still, make a guess
if render_format == "auto":
# First, if there is an Accept header, check if text/html
# is the first option
# According to MDN Web Docs, all major browsers use text/html
# as the primary value in Accept (with the exception of IE 8,
# and, well, if you are supporting IE 8, then you have bigger
# problems to concern yourself with than what default exception
# renderer is used)
# Source:
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Content_negotiation/List_of_default_Accept_values
if acceptable and acceptable[0].match(
"text/html",
allow_type_wildcard=False,
allow_subtype_wildcard=False,
):
renderer = HTMLRenderer
content_type, *_ = request.headers.getone(
"content-type", ""
).split(";")
renderer = RENDERERS_BY_CONTENT_TYPE.get(
content_type, renderer
)
# Second, if there is an Accept header, check if
# application/json is an option, or if the content-type
# is application/json
elif (
acceptable
and acceptable.match(
"application/json",
allow_type_wildcard=False,
allow_subtype_wildcard=False,
)
or content_type == "application/json"
):
renderer = JSONRenderer
# Third, if there is no Accept header, assume we want text.
# The likely use case here is a raw socket.
elif not acceptable:
renderer = TextRenderer
else:
# Fourth, look to see if there was a JSON body
# When in this situation, the request is probably coming
# from curl, an API client like Postman or Insomnia, or a
# package like requests or httpx
try:
# Give them the benefit of the doubt if they did:
# $ curl localhost:8000 -d '{"foo": "bar"}'
# And provide them with JSONRenderer
renderer = JSONRenderer if request.json else base
except InvalidUsage:
renderer = base
else:
render_format = request.app.config.FALLBACK_ERROR_FORMAT
renderer = RENDERERS_BY_CONFIG.get(render_format, renderer)
# Lastly, if there is an Accept header, make sure
# our choice is okay
if acceptable:
type_ = CONTENT_TYPE_BY_RENDERERS.get(renderer) # type: ignore
if type_ and type_ not in acceptable:
# If the renderer selected is not in the Accept header
# look through what is in the Accept header, and select
# the first option that matches. Otherwise, just drop back
# to the original default
for accept in acceptable:
mtype = f"{accept.type_}/{accept.subtype}"
maybe = RENDERERS_BY_CONTENT_TYPE.get(mtype)
if maybe:
renderer = maybe
break
else:
renderer = base
renderer = t.cast(t.Type[BaseRenderer], renderer)
return renderer(request, exception, debug).render()

View File

@@ -1,12 +1,13 @@
from typing import List, Optional
from typing import Dict, List, Optional, Tuple, Type
from sanic.errorpages import exception_response
from sanic.errorpages import BaseRenderer, HTMLRenderer, exception_response
from sanic.exceptions import (
ContentRangeError,
HeaderNotFound,
InvalidRangeType,
)
from sanic.log import error_logger
from sanic.models.handler_types import RouteHandler
from sanic.response import text
@@ -23,10 +24,15 @@ class ErrorHandler:
"""
def __init__(self):
self.handlers = []
self.cached_handlers = {}
# Beginning in v22.3, the base renderer will be TextRenderer
def __init__(self, fallback: str, base: Type[BaseRenderer] = HTMLRenderer):
self.handlers: List[Tuple[Type[BaseException], RouteHandler]] = []
self.cached_handlers: Dict[
Tuple[Type[BaseException], Optional[str]], Optional[RouteHandler]
] = {}
self.debug = False
self.fallback = fallback
self.base = base
def add(self, exception, handler, route_names: Optional[List[str]] = None):
"""
@@ -142,7 +148,13 @@ class ErrorHandler:
:return:
"""
self.log(request, exception)
return exception_response(request, exception, self.debug)
return exception_response(
request,
exception,
debug=self.debug,
base=self.base,
fallback=self.fallback,
)
@staticmethod
def log(request, exception):

View File

@@ -35,7 +35,7 @@ _host_re = re.compile(
def parse_arg_as_accept(f):
def func(self, other, *args, **kwargs):
if not isinstance(other, Accept):
if not isinstance(other, Accept) and other:
other = Accept.parse(other)
return f(self, other, *args, **kwargs)
@@ -181,6 +181,27 @@ class Accept(str):
return cls(mtype, MediaType(type_), MediaType(subtype), **params)
class AcceptContainer(list):
def __contains__(self, o: object) -> bool:
return any(item.match(o) for item in self)
def match(
self,
o: object,
*,
allow_type_wildcard: bool = True,
allow_subtype_wildcard: bool = True,
) -> bool:
return any(
item.match(
o,
allow_type_wildcard=allow_type_wildcard,
allow_subtype_wildcard=allow_subtype_wildcard,
)
for item in self
)
def parse_content_header(value: str) -> Tuple[str, Options]:
"""Parse content-type and content-disposition header values.
@@ -356,7 +377,7 @@ def _sort_accept_value(accept: Accept):
)
def parse_accept(accept: str) -> List[Accept]:
def parse_accept(accept: str) -> AcceptContainer:
"""Parse an Accept header and order the acceptable media types in
accorsing to RFC 7231, s. 5.3.2
https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2
@@ -370,4 +391,6 @@ def parse_accept(accept: str) -> List[Accept]:
accept_list.append(Accept.parse(mtype))
return sorted(accept_list, key=_sort_accept_value, reverse=True)
return AcceptContainer(
sorted(accept_list, key=_sort_accept_value, reverse=True)
)

View File

@@ -1,17 +1,20 @@
from ast import NodeVisitor, Return, parse
from functools import partial, wraps
from inspect import signature
from inspect import getsource, signature
from mimetypes import guess_type
from os import path
from pathlib import PurePath
from re import sub
from textwrap import dedent
from time import gmtime, strftime
from typing import Callable, Iterable, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Iterable, List, Optional, Set, Tuple, Union
from urllib.parse import unquote
from sanic_routing.route import Route # type: ignore
from sanic.compat import stat_async
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE, HTTP_METHODS
from sanic.errorpages import RESPONSE_MAPPING
from sanic.exceptions import (
ContentRangeError,
FileNotFound,
@@ -61,6 +64,7 @@ class RouteMixin:
unquote: bool = False,
static: bool = False,
version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper:
"""
Decorate a function to be registered as a route
@@ -103,6 +107,7 @@ class RouteMixin:
nonlocal websocket
nonlocal static
nonlocal version_prefix
nonlocal error_format
if isinstance(handler, tuple):
# if a handler fn is already wrapped in a route, the handler
@@ -128,6 +133,9 @@ class RouteMixin:
# subprotocol is unordered, keep it unordered
subprotocols = frozenset(subprotocols)
if not error_format or error_format == "auto":
error_format = self._determine_error_format(handler)
route = FutureRoute(
handler,
uri,
@@ -143,6 +151,7 @@ class RouteMixin:
unquote,
static,
version_prefix,
error_format,
)
self._future_routes.add(route)
@@ -186,6 +195,7 @@ class RouteMixin:
name: Optional[str] = None,
stream: bool = False,
version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteHandler:
"""A helper method to register class instance or
functions as a handler to the application url
@@ -236,6 +246,7 @@ class RouteMixin:
version=version,
name=name,
version_prefix=version_prefix,
error_format=error_format,
)(handler)
return handler
@@ -249,6 +260,7 @@ class RouteMixin:
name: Optional[str] = None,
ignore_body: bool = True,
version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper:
"""
Add an API URL under the **GET** *HTTP* method
@@ -272,6 +284,7 @@ class RouteMixin:
name=name,
ignore_body=ignore_body,
version_prefix=version_prefix,
error_format=error_format,
)
def post(
@@ -283,6 +296,7 @@ class RouteMixin:
version: Optional[int] = None,
name: Optional[str] = None,
version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper:
"""
Add an API URL under the **POST** *HTTP* method
@@ -306,6 +320,7 @@ class RouteMixin:
version=version,
name=name,
version_prefix=version_prefix,
error_format=error_format,
)
def put(
@@ -317,6 +332,7 @@ class RouteMixin:
version: Optional[int] = None,
name: Optional[str] = None,
version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper:
"""
Add an API URL under the **PUT** *HTTP* method
@@ -340,6 +356,7 @@ class RouteMixin:
version=version,
name=name,
version_prefix=version_prefix,
error_format=error_format,
)
def head(
@@ -351,6 +368,7 @@ class RouteMixin:
name: Optional[str] = None,
ignore_body: bool = True,
version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper:
"""
Add an API URL under the **HEAD** *HTTP* method
@@ -382,6 +400,7 @@ class RouteMixin:
name=name,
ignore_body=ignore_body,
version_prefix=version_prefix,
error_format=error_format,
)
def options(
@@ -393,6 +412,7 @@ class RouteMixin:
name: Optional[str] = None,
ignore_body: bool = True,
version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper:
"""
Add an API URL under the **OPTIONS** *HTTP* method
@@ -424,6 +444,7 @@ class RouteMixin:
name=name,
ignore_body=ignore_body,
version_prefix=version_prefix,
error_format=error_format,
)
def patch(
@@ -435,6 +456,7 @@ class RouteMixin:
version: Optional[int] = None,
name: Optional[str] = None,
version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper:
"""
Add an API URL under the **PATCH** *HTTP* method
@@ -468,6 +490,7 @@ class RouteMixin:
version=version,
name=name,
version_prefix=version_prefix,
error_format=error_format,
)
def delete(
@@ -479,6 +502,7 @@ class RouteMixin:
name: Optional[str] = None,
ignore_body: bool = True,
version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper:
"""
Add an API URL under the **DELETE** *HTTP* method
@@ -502,6 +526,7 @@ class RouteMixin:
name=name,
ignore_body=ignore_body,
version_prefix=version_prefix,
error_format=error_format,
)
def websocket(
@@ -514,6 +539,7 @@ class RouteMixin:
name: Optional[str] = None,
apply: bool = True,
version_prefix: str = "/v",
error_format: Optional[str] = None,
):
"""
Decorate a function to be registered as a websocket route
@@ -540,6 +566,7 @@ class RouteMixin:
subprotocols=subprotocols,
websocket=True,
version_prefix=version_prefix,
error_format=error_format,
)
def add_websocket_route(
@@ -552,6 +579,7 @@ class RouteMixin:
version: Optional[int] = None,
name: Optional[str] = None,
version_prefix: str = "/v",
error_format: Optional[str] = None,
):
"""
A helper method to register a function as a websocket route.
@@ -580,6 +608,7 @@ class RouteMixin:
version=version,
name=name,
version_prefix=version_prefix,
error_format=error_format,
)(handler)
def static(
@@ -789,8 +818,8 @@ class RouteMixin:
)
except Exception:
error_logger.exception(
f"Exception in static request handler:\
path={file_or_directory}, "
f"Exception in static request handler: "
f"path={file_or_directory}, "
f"relative_url={__file_uri__}"
)
raise
@@ -888,3 +917,43 @@ class RouteMixin:
)(_handler)
return route
def _determine_error_format(self, handler) -> str:
if not isinstance(handler, CompositionView):
try:
src = dedent(getsource(handler))
tree = parse(src)
http_response_types = self._get_response_types(tree)
if len(http_response_types) == 1:
return next(iter(http_response_types))
except OSError:
...
return "auto"
def _get_response_types(self, node):
types = set()
class HttpResponseVisitor(NodeVisitor):
def visit_Return(self, node: Return) -> Any:
nonlocal types
try:
checks = [node.value.func.id] # type: ignore
if node.value.keywords: # type: ignore
checks += [
k.value
for k in node.value.keywords # type: ignore
if k.arg == "content_type"
]
for check in checks:
if check in RESPONSE_MAPPING:
types.add(RESPONSE_MAPPING[check])
except AttributeError:
...
HttpResponseVisitor().visit(node)
return types

View File

@@ -24,6 +24,7 @@ class FutureRoute(NamedTuple):
unquote: bool
static: bool
version_prefix: str
error_format: Optional[str]
class FutureListener(NamedTuple):

View File

@@ -34,7 +34,7 @@ from sanic.compat import CancelledErrors, Header
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE
from sanic.exceptions import InvalidUsage
from sanic.headers import (
Accept,
AcceptContainer,
Options,
parse_accept,
parse_content_header,
@@ -139,7 +139,7 @@ class Request:
self.conn_info: Optional[ConnInfo] = None
self.ctx = SimpleNamespace()
self.parsed_forwarded: Optional[Options] = None
self.parsed_accept: Optional[List[Accept]] = None
self.parsed_accept: Optional[AcceptContainer] = None
self.parsed_json = None
self.parsed_form = None
self.parsed_files = None
@@ -301,7 +301,7 @@ class Request:
return self.parsed_json
@property
def accept(self) -> List[Accept]:
def accept(self) -> AcceptContainer:
if self.parsed_accept is None:
accept_header = self.headers.getone("accept", "")
self.parsed_accept = parse_accept(accept_header)

View File

@@ -13,6 +13,7 @@ from sanic_routing.exceptions import (
from sanic_routing.route import Route # type: ignore
from sanic.constants import HTTP_METHODS
from sanic.errorpages import check_error_format
from sanic.exceptions import MethodNotSupported, NotFound, SanicException
from sanic.models.handler_types import RouteHandler
@@ -78,6 +79,7 @@ class Router(BaseRouter):
unquote: bool = False,
static: bool = False,
version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> Union[Route, List[Route]]:
"""
Add a handler to the router
@@ -137,6 +139,11 @@ class Router(BaseRouter):
route.ctx.stream = stream
route.ctx.hosts = hosts
route.ctx.static = static
route.ctx.error_format = (
error_format or self.ctx.app.config.FALLBACK_ERROR_FORMAT
)
check_error_format(route.ctx.error_format)
routes.append(route)