Smarter auto fallback (#2162)

* Smarter auto fallback

* remove config from blueprints

* Add tests for error formatting

* Add check for proper format

* Fix some tests

* Add some tests

* docstring

* Add accept matching

* Add some more tests on matching

* Fix contains bug, earlier return on MediaType eq

* Add matching flags for wildcards

* Add mathing controls to accept

* Cleanup dev cruft

* Add cleanup and resolve OSError relating to test implementation

* Fix test

* Fix some typos
This commit is contained in:
Adam Hopkins 2021-09-29 23:53:49 +03:00 committed by GitHub
parent b5f2bd9b0e
commit cf1d2148ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 562 additions and 125 deletions

View File

@ -179,7 +179,9 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
self.configure_logging = configure_logging self.configure_logging = configure_logging
self.ctx = ctx or SimpleNamespace() self.ctx = ctx or SimpleNamespace()
self.debug = None self.debug = None
self.error_handler = error_handler or ErrorHandler() self.error_handler = error_handler or ErrorHandler(
fallback=self.config.FALLBACK_ERROR_FORMAT,
)
self.is_running = False self.is_running = False
self.is_stopping = False self.is_stopping = False
self.listeners: Dict[str, List[ListenerType]] = defaultdict(list) self.listeners: Dict[str, List[ListenerType]] = defaultdict(list)

View File

@ -266,6 +266,9 @@ class Blueprint(BaseSanic):
opt_version = options.get("version", None) opt_version = options.get("version", None)
opt_strict_slashes = options.get("strict_slashes", None) opt_strict_slashes = options.get("strict_slashes", None)
opt_version_prefix = options.get("version_prefix", self.version_prefix) opt_version_prefix = options.get("version_prefix", self.version_prefix)
error_format = options.get(
"error_format", app.config.FALLBACK_ERROR_FORMAT
)
routes = [] routes = []
middleware = [] middleware = []
@ -313,6 +316,7 @@ class Blueprint(BaseSanic):
future.unquote, future.unquote,
future.static, future.static,
version_prefix, version_prefix,
error_format,
) )
route = app._apply_route(apply_route) route = app._apply_route(apply_route)

View File

@ -4,6 +4,7 @@ from pathlib import Path
from typing import Any, Dict, Optional, Union from typing import Any, Dict, Optional, Union
from warnings import warn from warnings import warn
from sanic.errorpages import check_error_format
from sanic.http import Http from sanic.http import Http
from .utils import load_module_from_file_location, str_to_bool from .utils import load_module_from_file_location, str_to_bool
@ -20,7 +21,7 @@ BASE_LOGO = """
DEFAULT_CONFIG = { DEFAULT_CONFIG = {
"ACCESS_LOG": True, "ACCESS_LOG": True,
"EVENT_AUTOREGISTER": False, "EVENT_AUTOREGISTER": False,
"FALLBACK_ERROR_FORMAT": "html", "FALLBACK_ERROR_FORMAT": "auto",
"FORWARDED_FOR_HEADER": "X-Forwarded-For", "FORWARDED_FOR_HEADER": "X-Forwarded-For",
"FORWARDED_SECRET": None, "FORWARDED_SECRET": None,
"GRACEFUL_SHUTDOWN_TIMEOUT": 15.0, # 15 sec "GRACEFUL_SHUTDOWN_TIMEOUT": 15.0, # 15 sec
@ -94,6 +95,7 @@ class Config(dict):
self.load_environment_vars(SANIC_PREFIX) self.load_environment_vars(SANIC_PREFIX)
self._configure_header_size() self._configure_header_size()
self._check_error_format()
def __getattr__(self, attr): def __getattr__(self, attr):
try: try:
@ -109,6 +111,8 @@ class Config(dict):
"REQUEST_MAX_SIZE", "REQUEST_MAX_SIZE",
): ):
self._configure_header_size() self._configure_header_size()
elif attr == "FALLBACK_ERROR_FORMAT":
self._check_error_format()
def _configure_header_size(self): def _configure_header_size(self):
Http.set_header_max_size( Http.set_header_max_size(
@ -117,6 +121,9 @@ class Config(dict):
self.REQUEST_MAX_SIZE, self.REQUEST_MAX_SIZE,
) )
def _check_error_format(self):
check_error_format(self.FALLBACK_ERROR_FORMAT)
def load_environment_vars(self, prefix=SANIC_PREFIX): def load_environment_vars(self, prefix=SANIC_PREFIX):
""" """
Looks for prefixed environment variables and applies Looks for prefixed environment variables and applies

View File

@ -340,41 +340,138 @@ RENDERERS_BY_CONFIG = {
} }
RENDERERS_BY_CONTENT_TYPE = { RENDERERS_BY_CONTENT_TYPE = {
"multipart/form-data": HTMLRenderer,
"application/json": JSONRenderer,
"text/plain": TextRenderer, "text/plain": TextRenderer,
"application/json": JSONRenderer,
"multipart/form-data": HTMLRenderer,
"text/html": HTMLRenderer,
} }
CONTENT_TYPE_BY_RENDERERS = {
v: k for k, v in RENDERERS_BY_CONTENT_TYPE.items()
}
RESPONSE_MAPPING = {
"empty": "html",
"json": "json",
"text": "text",
"raw": "text",
"html": "html",
"file": "html",
"file_stream": "text",
"stream": "text",
"redirect": "html",
"text/plain": "text",
"text/html": "html",
"application/json": "json",
}
def check_error_format(format):
if format not in RENDERERS_BY_CONFIG and format != "auto":
raise SanicException(f"Unknown format: {format}")
def exception_response( def exception_response(
request: Request, request: Request,
exception: Exception, exception: Exception,
debug: bool, debug: bool,
fallback: str,
base: t.Type[BaseRenderer],
renderer: t.Type[t.Optional[BaseRenderer]] = None, renderer: t.Type[t.Optional[BaseRenderer]] = None,
) -> HTTPResponse: ) -> HTTPResponse:
""" """
Render a response for the default FALLBACK exception handler. Render a response for the default FALLBACK exception handler.
""" """
content_type = None
if not renderer: if not renderer:
renderer = HTMLRenderer # Make sure we have something set
renderer = base
render_format = fallback
if request: if request:
if request.app.config.FALLBACK_ERROR_FORMAT == "auto": # If there is a request, try and get the format
# from the route
if request.route:
try: try:
renderer = JSONRenderer if request.json else HTMLRenderer render_format = request.route.ctx.error_format
except InvalidUsage: except AttributeError:
...
content_type = request.headers.getone("content-type", "").split(
";"
)[0]
acceptable = request.accept
# If the format is auto still, make a guess
if render_format == "auto":
# First, if there is an Accept header, check if text/html
# is the first option
# According to MDN Web Docs, all major browsers use text/html
# as the primary value in Accept (with the exception of IE 8,
# and, well, if you are supporting IE 8, then you have bigger
# problems to concern yourself with than what default exception
# renderer is used)
# Source:
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Content_negotiation/List_of_default_Accept_values
if acceptable and acceptable[0].match(
"text/html",
allow_type_wildcard=False,
allow_subtype_wildcard=False,
):
renderer = HTMLRenderer renderer = HTMLRenderer
content_type, *_ = request.headers.getone( # Second, if there is an Accept header, check if
"content-type", "" # application/json is an option, or if the content-type
).split(";") # is application/json
renderer = RENDERERS_BY_CONTENT_TYPE.get( elif (
content_type, renderer acceptable
) and acceptable.match(
"application/json",
allow_type_wildcard=False,
allow_subtype_wildcard=False,
)
or content_type == "application/json"
):
renderer = JSONRenderer
# Third, if there is no Accept header, assume we want text.
# The likely use case here is a raw socket.
elif not acceptable:
renderer = TextRenderer
else:
# Fourth, look to see if there was a JSON body
# When in this situation, the request is probably coming
# from curl, an API client like Postman or Insomnia, or a
# package like requests or httpx
try:
# Give them the benefit of the doubt if they did:
# $ curl localhost:8000 -d '{"foo": "bar"}'
# And provide them with JSONRenderer
renderer = JSONRenderer if request.json else base
except InvalidUsage:
renderer = base
else: else:
render_format = request.app.config.FALLBACK_ERROR_FORMAT
renderer = RENDERERS_BY_CONFIG.get(render_format, renderer) renderer = RENDERERS_BY_CONFIG.get(render_format, renderer)
# Lastly, if there is an Accept header, make sure
# our choice is okay
if acceptable:
type_ = CONTENT_TYPE_BY_RENDERERS.get(renderer) # type: ignore
if type_ and type_ not in acceptable:
# If the renderer selected is not in the Accept header
# look through what is in the Accept header, and select
# the first option that matches. Otherwise, just drop back
# to the original default
for accept in acceptable:
mtype = f"{accept.type_}/{accept.subtype}"
maybe = RENDERERS_BY_CONTENT_TYPE.get(mtype)
if maybe:
renderer = maybe
break
else:
renderer = base
renderer = t.cast(t.Type[BaseRenderer], renderer) renderer = t.cast(t.Type[BaseRenderer], renderer)
return renderer(request, exception, debug).render() return renderer(request, exception, debug).render()

View File

@ -1,12 +1,13 @@
from typing import List, Optional from typing import Dict, List, Optional, Tuple, Type
from sanic.errorpages import exception_response from sanic.errorpages import BaseRenderer, HTMLRenderer, exception_response
from sanic.exceptions import ( from sanic.exceptions import (
ContentRangeError, ContentRangeError,
HeaderNotFound, HeaderNotFound,
InvalidRangeType, InvalidRangeType,
) )
from sanic.log import error_logger from sanic.log import error_logger
from sanic.models.handler_types import RouteHandler
from sanic.response import text from sanic.response import text
@ -23,10 +24,15 @@ class ErrorHandler:
""" """
def __init__(self): # Beginning in v22.3, the base renderer will be TextRenderer
self.handlers = [] def __init__(self, fallback: str, base: Type[BaseRenderer] = HTMLRenderer):
self.cached_handlers = {} self.handlers: List[Tuple[Type[BaseException], RouteHandler]] = []
self.cached_handlers: Dict[
Tuple[Type[BaseException], Optional[str]], Optional[RouteHandler]
] = {}
self.debug = False self.debug = False
self.fallback = fallback
self.base = base
def add(self, exception, handler, route_names: Optional[List[str]] = None): def add(self, exception, handler, route_names: Optional[List[str]] = None):
""" """
@ -142,7 +148,13 @@ class ErrorHandler:
:return: :return:
""" """
self.log(request, exception) self.log(request, exception)
return exception_response(request, exception, self.debug) return exception_response(
request,
exception,
debug=self.debug,
base=self.base,
fallback=self.fallback,
)
@staticmethod @staticmethod
def log(request, exception): def log(request, exception):

View File

@ -35,7 +35,7 @@ _host_re = re.compile(
def parse_arg_as_accept(f): def parse_arg_as_accept(f):
def func(self, other, *args, **kwargs): def func(self, other, *args, **kwargs):
if not isinstance(other, Accept): if not isinstance(other, Accept) and other:
other = Accept.parse(other) other = Accept.parse(other)
return f(self, other, *args, **kwargs) return f(self, other, *args, **kwargs)
@ -181,6 +181,27 @@ class Accept(str):
return cls(mtype, MediaType(type_), MediaType(subtype), **params) return cls(mtype, MediaType(type_), MediaType(subtype), **params)
class AcceptContainer(list):
def __contains__(self, o: object) -> bool:
return any(item.match(o) for item in self)
def match(
self,
o: object,
*,
allow_type_wildcard: bool = True,
allow_subtype_wildcard: bool = True,
) -> bool:
return any(
item.match(
o,
allow_type_wildcard=allow_type_wildcard,
allow_subtype_wildcard=allow_subtype_wildcard,
)
for item in self
)
def parse_content_header(value: str) -> Tuple[str, Options]: def parse_content_header(value: str) -> Tuple[str, Options]:
"""Parse content-type and content-disposition header values. """Parse content-type and content-disposition header values.
@ -356,7 +377,7 @@ def _sort_accept_value(accept: Accept):
) )
def parse_accept(accept: str) -> List[Accept]: def parse_accept(accept: str) -> AcceptContainer:
"""Parse an Accept header and order the acceptable media types in """Parse an Accept header and order the acceptable media types in
accorsing to RFC 7231, s. 5.3.2 accorsing to RFC 7231, s. 5.3.2
https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2 https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2
@ -370,4 +391,6 @@ def parse_accept(accept: str) -> List[Accept]:
accept_list.append(Accept.parse(mtype)) accept_list.append(Accept.parse(mtype))
return sorted(accept_list, key=_sort_accept_value, reverse=True) return AcceptContainer(
sorted(accept_list, key=_sort_accept_value, reverse=True)
)

View File

@ -1,17 +1,20 @@
from ast import NodeVisitor, Return, parse
from functools import partial, wraps from functools import partial, wraps
from inspect import signature from inspect import getsource, signature
from mimetypes import guess_type from mimetypes import guess_type
from os import path from os import path
from pathlib import PurePath from pathlib import PurePath
from re import sub from re import sub
from textwrap import dedent
from time import gmtime, strftime from time import gmtime, strftime
from typing import Callable, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Callable, Iterable, List, Optional, Set, Tuple, Union
from urllib.parse import unquote from urllib.parse import unquote
from sanic_routing.route import Route # type: ignore from sanic_routing.route import Route # type: ignore
from sanic.compat import stat_async from sanic.compat import stat_async
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE, HTTP_METHODS from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE, HTTP_METHODS
from sanic.errorpages import RESPONSE_MAPPING
from sanic.exceptions import ( from sanic.exceptions import (
ContentRangeError, ContentRangeError,
FileNotFound, FileNotFound,
@ -61,6 +64,7 @@ class RouteMixin:
unquote: bool = False, unquote: bool = False,
static: bool = False, static: bool = False,
version_prefix: str = "/v", version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper: ) -> RouteWrapper:
""" """
Decorate a function to be registered as a route Decorate a function to be registered as a route
@ -103,6 +107,7 @@ class RouteMixin:
nonlocal websocket nonlocal websocket
nonlocal static nonlocal static
nonlocal version_prefix nonlocal version_prefix
nonlocal error_format
if isinstance(handler, tuple): if isinstance(handler, tuple):
# if a handler fn is already wrapped in a route, the handler # if a handler fn is already wrapped in a route, the handler
@ -128,6 +133,9 @@ class RouteMixin:
# subprotocol is unordered, keep it unordered # subprotocol is unordered, keep it unordered
subprotocols = frozenset(subprotocols) subprotocols = frozenset(subprotocols)
if not error_format or error_format == "auto":
error_format = self._determine_error_format(handler)
route = FutureRoute( route = FutureRoute(
handler, handler,
uri, uri,
@ -143,6 +151,7 @@ class RouteMixin:
unquote, unquote,
static, static,
version_prefix, version_prefix,
error_format,
) )
self._future_routes.add(route) self._future_routes.add(route)
@ -186,6 +195,7 @@ class RouteMixin:
name: Optional[str] = None, name: Optional[str] = None,
stream: bool = False, stream: bool = False,
version_prefix: str = "/v", version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteHandler: ) -> RouteHandler:
"""A helper method to register class instance or """A helper method to register class instance or
functions as a handler to the application url functions as a handler to the application url
@ -236,6 +246,7 @@ class RouteMixin:
version=version, version=version,
name=name, name=name,
version_prefix=version_prefix, version_prefix=version_prefix,
error_format=error_format,
)(handler) )(handler)
return handler return handler
@ -249,6 +260,7 @@ class RouteMixin:
name: Optional[str] = None, name: Optional[str] = None,
ignore_body: bool = True, ignore_body: bool = True,
version_prefix: str = "/v", version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper: ) -> RouteWrapper:
""" """
Add an API URL under the **GET** *HTTP* method Add an API URL under the **GET** *HTTP* method
@ -272,6 +284,7 @@ class RouteMixin:
name=name, name=name,
ignore_body=ignore_body, ignore_body=ignore_body,
version_prefix=version_prefix, version_prefix=version_prefix,
error_format=error_format,
) )
def post( def post(
@ -283,6 +296,7 @@ class RouteMixin:
version: Optional[int] = None, version: Optional[int] = None,
name: Optional[str] = None, name: Optional[str] = None,
version_prefix: str = "/v", version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper: ) -> RouteWrapper:
""" """
Add an API URL under the **POST** *HTTP* method Add an API URL under the **POST** *HTTP* method
@ -306,6 +320,7 @@ class RouteMixin:
version=version, version=version,
name=name, name=name,
version_prefix=version_prefix, version_prefix=version_prefix,
error_format=error_format,
) )
def put( def put(
@ -317,6 +332,7 @@ class RouteMixin:
version: Optional[int] = None, version: Optional[int] = None,
name: Optional[str] = None, name: Optional[str] = None,
version_prefix: str = "/v", version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper: ) -> RouteWrapper:
""" """
Add an API URL under the **PUT** *HTTP* method Add an API URL under the **PUT** *HTTP* method
@ -340,6 +356,7 @@ class RouteMixin:
version=version, version=version,
name=name, name=name,
version_prefix=version_prefix, version_prefix=version_prefix,
error_format=error_format,
) )
def head( def head(
@ -351,6 +368,7 @@ class RouteMixin:
name: Optional[str] = None, name: Optional[str] = None,
ignore_body: bool = True, ignore_body: bool = True,
version_prefix: str = "/v", version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper: ) -> RouteWrapper:
""" """
Add an API URL under the **HEAD** *HTTP* method Add an API URL under the **HEAD** *HTTP* method
@ -382,6 +400,7 @@ class RouteMixin:
name=name, name=name,
ignore_body=ignore_body, ignore_body=ignore_body,
version_prefix=version_prefix, version_prefix=version_prefix,
error_format=error_format,
) )
def options( def options(
@ -393,6 +412,7 @@ class RouteMixin:
name: Optional[str] = None, name: Optional[str] = None,
ignore_body: bool = True, ignore_body: bool = True,
version_prefix: str = "/v", version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper: ) -> RouteWrapper:
""" """
Add an API URL under the **OPTIONS** *HTTP* method Add an API URL under the **OPTIONS** *HTTP* method
@ -424,6 +444,7 @@ class RouteMixin:
name=name, name=name,
ignore_body=ignore_body, ignore_body=ignore_body,
version_prefix=version_prefix, version_prefix=version_prefix,
error_format=error_format,
) )
def patch( def patch(
@ -435,6 +456,7 @@ class RouteMixin:
version: Optional[int] = None, version: Optional[int] = None,
name: Optional[str] = None, name: Optional[str] = None,
version_prefix: str = "/v", version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper: ) -> RouteWrapper:
""" """
Add an API URL under the **PATCH** *HTTP* method Add an API URL under the **PATCH** *HTTP* method
@ -468,6 +490,7 @@ class RouteMixin:
version=version, version=version,
name=name, name=name,
version_prefix=version_prefix, version_prefix=version_prefix,
error_format=error_format,
) )
def delete( def delete(
@ -479,6 +502,7 @@ class RouteMixin:
name: Optional[str] = None, name: Optional[str] = None,
ignore_body: bool = True, ignore_body: bool = True,
version_prefix: str = "/v", version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper: ) -> RouteWrapper:
""" """
Add an API URL under the **DELETE** *HTTP* method Add an API URL under the **DELETE** *HTTP* method
@ -502,6 +526,7 @@ class RouteMixin:
name=name, name=name,
ignore_body=ignore_body, ignore_body=ignore_body,
version_prefix=version_prefix, version_prefix=version_prefix,
error_format=error_format,
) )
def websocket( def websocket(
@ -514,6 +539,7 @@ class RouteMixin:
name: Optional[str] = None, name: Optional[str] = None,
apply: bool = True, apply: bool = True,
version_prefix: str = "/v", version_prefix: str = "/v",
error_format: Optional[str] = None,
): ):
""" """
Decorate a function to be registered as a websocket route Decorate a function to be registered as a websocket route
@ -540,6 +566,7 @@ class RouteMixin:
subprotocols=subprotocols, subprotocols=subprotocols,
websocket=True, websocket=True,
version_prefix=version_prefix, version_prefix=version_prefix,
error_format=error_format,
) )
def add_websocket_route( def add_websocket_route(
@ -552,6 +579,7 @@ class RouteMixin:
version: Optional[int] = None, version: Optional[int] = None,
name: Optional[str] = None, name: Optional[str] = None,
version_prefix: str = "/v", version_prefix: str = "/v",
error_format: Optional[str] = None,
): ):
""" """
A helper method to register a function as a websocket route. A helper method to register a function as a websocket route.
@ -580,6 +608,7 @@ class RouteMixin:
version=version, version=version,
name=name, name=name,
version_prefix=version_prefix, version_prefix=version_prefix,
error_format=error_format,
)(handler) )(handler)
def static( def static(
@ -789,8 +818,8 @@ class RouteMixin:
) )
except Exception: except Exception:
error_logger.exception( error_logger.exception(
f"Exception in static request handler:\ f"Exception in static request handler: "
path={file_or_directory}, " f"path={file_or_directory}, "
f"relative_url={__file_uri__}" f"relative_url={__file_uri__}"
) )
raise raise
@ -888,3 +917,43 @@ class RouteMixin:
)(_handler) )(_handler)
return route return route
def _determine_error_format(self, handler) -> str:
if not isinstance(handler, CompositionView):
try:
src = dedent(getsource(handler))
tree = parse(src)
http_response_types = self._get_response_types(tree)
if len(http_response_types) == 1:
return next(iter(http_response_types))
except OSError:
...
return "auto"
def _get_response_types(self, node):
types = set()
class HttpResponseVisitor(NodeVisitor):
def visit_Return(self, node: Return) -> Any:
nonlocal types
try:
checks = [node.value.func.id] # type: ignore
if node.value.keywords: # type: ignore
checks += [
k.value
for k in node.value.keywords # type: ignore
if k.arg == "content_type"
]
for check in checks:
if check in RESPONSE_MAPPING:
types.add(RESPONSE_MAPPING[check])
except AttributeError:
...
HttpResponseVisitor().visit(node)
return types

View File

@ -24,6 +24,7 @@ class FutureRoute(NamedTuple):
unquote: bool unquote: bool
static: bool static: bool
version_prefix: str version_prefix: str
error_format: Optional[str]
class FutureListener(NamedTuple): class FutureListener(NamedTuple):

View File

@ -34,7 +34,7 @@ from sanic.compat import CancelledErrors, Header
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE
from sanic.exceptions import InvalidUsage from sanic.exceptions import InvalidUsage
from sanic.headers import ( from sanic.headers import (
Accept, AcceptContainer,
Options, Options,
parse_accept, parse_accept,
parse_content_header, parse_content_header,
@ -139,7 +139,7 @@ class Request:
self.conn_info: Optional[ConnInfo] = None self.conn_info: Optional[ConnInfo] = None
self.ctx = SimpleNamespace() self.ctx = SimpleNamespace()
self.parsed_forwarded: Optional[Options] = None self.parsed_forwarded: Optional[Options] = None
self.parsed_accept: Optional[List[Accept]] = None self.parsed_accept: Optional[AcceptContainer] = None
self.parsed_json = None self.parsed_json = None
self.parsed_form = None self.parsed_form = None
self.parsed_files = None self.parsed_files = None
@ -301,7 +301,7 @@ class Request:
return self.parsed_json return self.parsed_json
@property @property
def accept(self) -> List[Accept]: def accept(self) -> AcceptContainer:
if self.parsed_accept is None: if self.parsed_accept is None:
accept_header = self.headers.getone("accept", "") accept_header = self.headers.getone("accept", "")
self.parsed_accept = parse_accept(accept_header) self.parsed_accept = parse_accept(accept_header)

View File

@ -13,6 +13,7 @@ from sanic_routing.exceptions import (
from sanic_routing.route import Route # type: ignore from sanic_routing.route import Route # type: ignore
from sanic.constants import HTTP_METHODS from sanic.constants import HTTP_METHODS
from sanic.errorpages import check_error_format
from sanic.exceptions import MethodNotSupported, NotFound, SanicException from sanic.exceptions import MethodNotSupported, NotFound, SanicException
from sanic.models.handler_types import RouteHandler from sanic.models.handler_types import RouteHandler
@ -78,6 +79,7 @@ class Router(BaseRouter):
unquote: bool = False, unquote: bool = False,
static: bool = False, static: bool = False,
version_prefix: str = "/v", version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> Union[Route, List[Route]]: ) -> Union[Route, List[Route]]:
""" """
Add a handler to the router Add a handler to the router
@ -137,6 +139,11 @@ class Router(BaseRouter):
route.ctx.stream = stream route.ctx.stream = stream
route.ctx.hosts = hosts route.ctx.hosts = hosts
route.ctx.static = static route.ctx.static = static
route.ctx.error_format = (
error_format or self.ctx.app.config.FALLBACK_ERROR_FORMAT
)
check_error_format(route.ctx.error_format)
routes.append(route) routes.append(route)

View File

@ -109,6 +109,7 @@ def sanic_router(app):
# noinspection PyProtectedMember # noinspection PyProtectedMember
def _setup(route_details: tuple) -> Tuple[Router, tuple]: def _setup(route_details: tuple) -> Tuple[Router, tuple]:
router = Router() router = Router()
router.ctx.app = app
added_router = [] added_router = []
for method, route in route_details: for method, route in route_details:
try: try:

View File

@ -20,4 +20,4 @@ def test_bad_request_response(app):
app.run(host="127.0.0.1", port=42101, debug=False) app.run(host="127.0.0.1", port=42101, debug=False)
assert lines[0] == b"HTTP/1.1 400 Bad Request\r\n" assert lines[0] == b"HTTP/1.1 400 Bad Request\r\n"
assert b"Bad Request" in lines[-1] assert b"Bad Request" in lines[-2]

View File

@ -3,7 +3,12 @@ from pytest import raises
from sanic.app import Sanic from sanic.app import Sanic
from sanic.blueprint_group import BlueprintGroup from sanic.blueprint_group import BlueprintGroup
from sanic.blueprints import Blueprint from sanic.blueprints import Blueprint
from sanic.exceptions import Forbidden, InvalidUsage, SanicException, ServerError from sanic.exceptions import (
Forbidden,
InvalidUsage,
SanicException,
ServerError,
)
from sanic.request import Request from sanic.request import Request
from sanic.response import HTTPResponse, text from sanic.response import HTTPResponse, text

View File

@ -1,10 +1,10 @@
import pytest import pytest
from sanic import Sanic from sanic import Sanic
from sanic.errorpages import exception_response from sanic.errorpages import HTMLRenderer, exception_response
from sanic.exceptions import NotFound from sanic.exceptions import NotFound, SanicException
from sanic.request import Request from sanic.request import Request
from sanic.response import HTTPResponse from sanic.response import HTTPResponse, html, json, text
@pytest.fixture @pytest.fixture
@ -20,7 +20,7 @@ def app():
@pytest.fixture @pytest.fixture
def fake_request(app): def fake_request(app):
return Request(b"/foobar", {}, "1.1", "GET", None, app) return Request(b"/foobar", {"accept": "*/*"}, "1.1", "GET", None, app)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -47,7 +47,13 @@ def test_should_return_html_valid_setting(
try: try:
raise exception("bad stuff") raise exception("bad stuff")
except Exception as e: except Exception as e:
response = exception_response(fake_request, e, True) response = exception_response(
fake_request,
e,
True,
base=HTMLRenderer,
fallback=fake_request.app.config.FALLBACK_ERROR_FORMAT,
)
assert isinstance(response, HTTPResponse) assert isinstance(response, HTTPResponse)
assert response.status == status assert response.status == status
@ -74,13 +80,194 @@ def test_auto_fallback_with_content_type(app):
app.config.FALLBACK_ERROR_FORMAT = "auto" app.config.FALLBACK_ERROR_FORMAT = "auto"
_, response = app.test_client.get( _, response = app.test_client.get(
"/error", headers={"content-type": "application/json"} "/error", headers={"content-type": "application/json", "accept": "*/*"}
) )
assert response.status == 500 assert response.status == 500
assert response.content_type == "application/json" assert response.content_type == "application/json"
_, response = app.test_client.get( _, response = app.test_client.get(
"/error", headers={"content-type": "text/plain"} "/error", headers={"content-type": "foo/bar", "accept": "*/*"}
)
assert response.status == 500
assert response.content_type == "text/html; charset=utf-8"
def test_route_error_format_set_on_auto(app):
@app.get("/text")
def text_response(request):
return text(request.route.ctx.error_format)
@app.get("/json")
def json_response(request):
return json({"format": request.route.ctx.error_format})
@app.get("/html")
def html_response(request):
return html(request.route.ctx.error_format)
_, response = app.test_client.get("/text")
assert response.text == "text"
_, response = app.test_client.get("/json")
assert response.json["format"] == "json"
_, response = app.test_client.get("/html")
assert response.text == "html"
def test_route_error_response_from_auto_route(app):
@app.get("/text")
def text_response(request):
raise Exception("oops")
return text("Never gonna see this")
@app.get("/json")
def json_response(request):
raise Exception("oops")
return json({"message": "Never gonna see this"})
@app.get("/html")
def html_response(request):
raise Exception("oops")
return html("<h1>Never gonna see this</h1>")
_, response = app.test_client.get("/text")
assert response.content_type == "text/plain; charset=utf-8"
_, response = app.test_client.get("/json")
assert response.content_type == "application/json"
_, response = app.test_client.get("/html")
assert response.content_type == "text/html; charset=utf-8"
def test_route_error_response_from_explicit_format(app):
@app.get("/text", error_format="json")
def text_response(request):
raise Exception("oops")
return text("Never gonna see this")
@app.get("/json", error_format="text")
def json_response(request):
raise Exception("oops")
return json({"message": "Never gonna see this"})
_, response = app.test_client.get("/text")
assert response.content_type == "application/json"
_, response = app.test_client.get("/json")
assert response.content_type == "text/plain; charset=utf-8"
def test_unknown_fallback_format(app):
with pytest.raises(SanicException, match="Unknown format: bad"):
app.config.FALLBACK_ERROR_FORMAT = "bad"
def test_route_error_format_unknown(app):
with pytest.raises(SanicException, match="Unknown format: bad"):
@app.get("/text", error_format="bad")
def handler(request):
...
def test_fallback_with_content_type_mismatch_accept(app):
app.config.FALLBACK_ERROR_FORMAT = "auto"
_, response = app.test_client.get(
"/error",
headers={"content-type": "application/json", "accept": "text/plain"},
) )
assert response.status == 500 assert response.status == 500
assert response.content_type == "text/plain; charset=utf-8" assert response.content_type == "text/plain; charset=utf-8"
_, response = app.test_client.get(
"/error",
headers={"content-type": "text/plain", "accept": "foo/bar"},
)
assert response.status == 500
assert response.content_type == "text/html; charset=utf-8"
app.router.reset()
@app.route("/alt1")
@app.route("/alt2", error_format="text")
@app.route("/alt3", error_format="html")
def handler(_):
raise Exception("problem here")
# Yes, we know this return value is unreachable. This is on purpose.
return json({})
app.router.finalize()
_, response = app.test_client.get(
"/alt1",
headers={"accept": "foo/bar"},
)
assert response.status == 500
assert response.content_type == "text/html; charset=utf-8"
_, response = app.test_client.get(
"/alt1",
headers={"accept": "foo/bar,*/*"},
)
assert response.status == 500
assert response.content_type == "application/json"
_, response = app.test_client.get(
"/alt2",
headers={"accept": "foo/bar"},
)
assert response.status == 500
assert response.content_type == "text/html; charset=utf-8"
_, response = app.test_client.get(
"/alt2",
headers={"accept": "foo/bar,*/*"},
)
assert response.status == 500
assert response.content_type == "text/plain; charset=utf-8"
_, response = app.test_client.get(
"/alt3",
headers={"accept": "foo/bar"},
)
assert response.status == 500
assert response.content_type == "text/html; charset=utf-8"
@pytest.mark.parametrize(
"accept,content_type,expected",
(
(None, None, "text/plain; charset=utf-8"),
("foo/bar", None, "text/html; charset=utf-8"),
("application/json", None, "application/json"),
("application/json,text/plain", None, "application/json"),
("text/plain,application/json", None, "application/json"),
("text/plain,foo/bar", None, "text/plain; charset=utf-8"),
# Following test is valid after v22.3
# ("text/plain,text/html", None, "text/plain; charset=utf-8"),
("*/*", "foo/bar", "text/html; charset=utf-8"),
("*/*", "application/json", "application/json"),
),
)
def test_combinations_for_auto(fake_request, accept, content_type, expected):
if accept:
fake_request.headers["accept"] = accept
else:
del fake_request.headers["accept"]
if content_type:
fake_request.headers["content-type"] = content_type
try:
raise Exception("bad stuff")
except Exception as e:
response = exception_response(
fake_request,
e,
True,
base=HTMLRenderer,
fallback="auto",
)
assert response.content_type == expected

View File

@ -1,5 +1,7 @@
import asyncio import asyncio
import pytest
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from sanic import Sanic from sanic import Sanic
@ -8,9 +10,6 @@ from sanic.handlers import ErrorHandler
from sanic.response import stream, text from sanic.response import stream, text
exception_handler_app = Sanic("test_exception_handler")
async def sample_streaming_fn(response): async def sample_streaming_fn(response):
await response.write("foo,") await response.write("foo,")
await asyncio.sleep(0.001) await asyncio.sleep(0.001)
@ -21,107 +20,102 @@ class ErrorWithRequestCtx(ServerError):
pass pass
@exception_handler_app.route("/1") @pytest.fixture
def handler_1(request): def exception_handler_app():
raise InvalidUsage("OK") exception_handler_app = Sanic("test_exception_handler")
@exception_handler_app.route("/1", error_format="html")
def handler_1(request):
raise InvalidUsage("OK")
@exception_handler_app.route("/2", error_format="html")
def handler_2(request):
raise ServerError("OK")
@exception_handler_app.route("/3", error_format="html")
def handler_3(request):
raise NotFound("OK")
@exception_handler_app.route("/4", error_format="html")
def handler_4(request):
foo = bar # noqa -- F821
return text(foo)
@exception_handler_app.route("/5", error_format="html")
def handler_5(request):
class CustomServerError(ServerError):
pass
raise CustomServerError("Custom server error")
@exception_handler_app.route("/6/<arg:int>", error_format="html")
def handler_6(request, arg):
try:
foo = 1 / arg
except Exception as e:
raise e from ValueError(f"{arg}")
return text(foo)
@exception_handler_app.route("/7", error_format="html")
def handler_7(request):
raise Forbidden("go away!")
@exception_handler_app.route("/8", error_format="html")
def handler_8(request):
raise ErrorWithRequestCtx("OK")
@exception_handler_app.exception(ErrorWithRequestCtx, NotFound)
def handler_exception_with_ctx(request, exception):
return text(request.ctx.middleware_ran)
@exception_handler_app.exception(ServerError)
def handler_exception(request, exception):
return text("OK")
@exception_handler_app.exception(Forbidden)
async def async_handler_exception(request, exception):
return stream(
sample_streaming_fn,
content_type="text/csv",
)
@exception_handler_app.middleware
async def some_request_middleware(request):
request.ctx.middleware_ran = "Done."
return exception_handler_app
@exception_handler_app.route("/2") def test_invalid_usage_exception_handler(exception_handler_app):
def handler_2(request):
raise ServerError("OK")
@exception_handler_app.route("/3")
def handler_3(request):
raise NotFound("OK")
@exception_handler_app.route("/4")
def handler_4(request):
foo = bar # noqa -- F821 undefined name 'bar' is done to throw exception
return text(foo)
@exception_handler_app.route("/5")
def handler_5(request):
class CustomServerError(ServerError):
pass
raise CustomServerError("Custom server error")
@exception_handler_app.route("/6/<arg:int>")
def handler_6(request, arg):
try:
foo = 1 / arg
except Exception as e:
raise e from ValueError(f"{arg}")
return text(foo)
@exception_handler_app.route("/7")
def handler_7(request):
raise Forbidden("go away!")
@exception_handler_app.route("/8")
def handler_8(request):
raise ErrorWithRequestCtx("OK")
@exception_handler_app.exception(ErrorWithRequestCtx, NotFound)
def handler_exception_with_ctx(request, exception):
return text(request.ctx.middleware_ran)
@exception_handler_app.exception(ServerError)
def handler_exception(request, exception):
return text("OK")
@exception_handler_app.exception(Forbidden)
async def async_handler_exception(request, exception):
return stream(
sample_streaming_fn,
content_type="text/csv",
)
@exception_handler_app.middleware
async def some_request_middleware(request):
request.ctx.middleware_ran = "Done."
def test_invalid_usage_exception_handler():
request, response = exception_handler_app.test_client.get("/1") request, response = exception_handler_app.test_client.get("/1")
assert response.status == 400 assert response.status == 400
def test_server_error_exception_handler(): def test_server_error_exception_handler(exception_handler_app):
request, response = exception_handler_app.test_client.get("/2") request, response = exception_handler_app.test_client.get("/2")
assert response.status == 200 assert response.status == 200
assert response.text == "OK" assert response.text == "OK"
def test_not_found_exception_handler(): def test_not_found_exception_handler(exception_handler_app):
request, response = exception_handler_app.test_client.get("/3") request, response = exception_handler_app.test_client.get("/3")
assert response.status == 200 assert response.status == 200
def test_text_exception__handler(): def test_text_exception__handler(exception_handler_app):
request, response = exception_handler_app.test_client.get("/random") request, response = exception_handler_app.test_client.get("/random")
assert response.status == 200 assert response.status == 200
assert response.text == "Done." assert response.text == "Done."
def test_async_exception_handler(): def test_async_exception_handler(exception_handler_app):
request, response = exception_handler_app.test_client.get("/7") request, response = exception_handler_app.test_client.get("/7")
assert response.status == 200 assert response.status == 200
assert response.text == "foo,bar" assert response.text == "foo,bar"
def test_html_traceback_output_in_debug_mode(): def test_html_traceback_output_in_debug_mode(exception_handler_app):
request, response = exception_handler_app.test_client.get("/4", debug=True) request, response = exception_handler_app.test_client.get("/4", debug=True)
assert response.status == 500 assert response.status == 500
soup = BeautifulSoup(response.body, "html.parser") soup = BeautifulSoup(response.body, "html.parser")
@ -136,12 +130,12 @@ def test_html_traceback_output_in_debug_mode():
) == summary_text ) == summary_text
def test_inherited_exception_handler(): def test_inherited_exception_handler(exception_handler_app):
request, response = exception_handler_app.test_client.get("/5") request, response = exception_handler_app.test_client.get("/5")
assert response.status == 200 assert response.status == 200
def test_chained_exception_handler(): def test_chained_exception_handler(exception_handler_app):
request, response = exception_handler_app.test_client.get( request, response = exception_handler_app.test_client.get(
"/6/0", debug=True "/6/0", debug=True
) )
@ -153,7 +147,6 @@ def test_chained_exception_handler():
assert "handler_6" in html assert "handler_6" in html
assert "foo = 1 / arg" in html assert "foo = 1 / arg" in html
assert "ValueError" in html assert "ValueError" in html
assert "The above exception was the direct cause" in html
summary_text = " ".join(soup.select(".summary")[0].text.split()) summary_text = " ".join(soup.select(".summary")[0].text.split())
assert ( assert (
@ -161,7 +154,7 @@ def test_chained_exception_handler():
) == summary_text ) == summary_text
def test_exception_handler_lookup(): def test_exception_handler_lookup(exception_handler_app):
class CustomError(Exception): class CustomError(Exception):
pass pass
@ -184,7 +177,7 @@ def test_exception_handler_lookup():
class ModuleNotFoundError(ImportError): class ModuleNotFoundError(ImportError):
pass pass
handler = ErrorHandler() handler = ErrorHandler("auto")
handler.add(ImportError, import_error_handler) handler.add(ImportError, import_error_handler)
handler.add(CustomError, custom_error_handler) handler.add(CustomError, custom_error_handler)
handler.add(ServerError, server_error_handler) handler.add(ServerError, server_error_handler)
@ -209,7 +202,7 @@ def test_exception_handler_lookup():
) )
def test_exception_handler_processed_request_middleware(): def test_exception_handler_processed_request_middleware(exception_handler_app):
request, response = exception_handler_app.test_client.get("/8") request, response = exception_handler_app.test_client.get("/8")
assert response.status == 200 assert response.status == 200
assert response.text == "Done." assert response.text == "Done."

View File

@ -5,6 +5,7 @@ import pytest
from sanic import headers, text from sanic import headers, text
from sanic.exceptions import InvalidHeader, PayloadTooLarge from sanic.exceptions import InvalidHeader, PayloadTooLarge
from sanic.http import Http from sanic.http import Http
from sanic.request import Request
@pytest.fixture @pytest.fixture
@ -338,3 +339,31 @@ def test_value_in_accept(value):
assert "foo/bar" in acceptable assert "foo/bar" in acceptable
assert "foo/*" in acceptable assert "foo/*" in acceptable
assert "*/*" in acceptable assert "*/*" in acceptable
@pytest.mark.parametrize("value", ("foo/bar", "foo/*"))
def test_value_not_in_accept(value):
acceptable = headers.parse_accept(value)
assert "no/match" not in acceptable
assert "no/*" not in acceptable
@pytest.mark.parametrize(
"header,expected",
(
(
"text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", # noqa: E501
[
"text/html",
"application/xhtml+xml",
"image/avif",
"image/webp",
"application/xml;q=0.9",
"*/*;q=0.8",
],
),
),
)
def test_browser_headers(header, expected):
request = Request(b"/", {"accept": header}, "1.1", "GET", None, None)
assert request.accept == expected