Smarter auto fallback (#2162)

* Smarter auto fallback

* remove config from blueprints

* Add tests for error formatting

* Add check for proper format

* Fix some tests

* Add some tests

* docstring

* Add accept matching

* Add some more tests on matching

* Fix contains bug, earlier return on MediaType eq

* Add matching flags for wildcards

* Add mathing controls to accept

* Cleanup dev cruft

* Add cleanup and resolve OSError relating to test implementation

* Fix test

* Fix some typos
This commit is contained in:
Adam Hopkins 2021-09-29 23:53:49 +03:00 committed by GitHub
parent b5f2bd9b0e
commit cf1d2148ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 562 additions and 125 deletions

View File

@ -179,7 +179,9 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
self.configure_logging = configure_logging
self.ctx = ctx or SimpleNamespace()
self.debug = None
self.error_handler = error_handler or ErrorHandler()
self.error_handler = error_handler or ErrorHandler(
fallback=self.config.FALLBACK_ERROR_FORMAT,
)
self.is_running = False
self.is_stopping = False
self.listeners: Dict[str, List[ListenerType]] = defaultdict(list)

View File

@ -266,6 +266,9 @@ class Blueprint(BaseSanic):
opt_version = options.get("version", None)
opt_strict_slashes = options.get("strict_slashes", None)
opt_version_prefix = options.get("version_prefix", self.version_prefix)
error_format = options.get(
"error_format", app.config.FALLBACK_ERROR_FORMAT
)
routes = []
middleware = []
@ -313,6 +316,7 @@ class Blueprint(BaseSanic):
future.unquote,
future.static,
version_prefix,
error_format,
)
route = app._apply_route(apply_route)

View File

@ -4,6 +4,7 @@ from pathlib import Path
from typing import Any, Dict, Optional, Union
from warnings import warn
from sanic.errorpages import check_error_format
from sanic.http import Http
from .utils import load_module_from_file_location, str_to_bool
@ -20,7 +21,7 @@ BASE_LOGO = """
DEFAULT_CONFIG = {
"ACCESS_LOG": True,
"EVENT_AUTOREGISTER": False,
"FALLBACK_ERROR_FORMAT": "html",
"FALLBACK_ERROR_FORMAT": "auto",
"FORWARDED_FOR_HEADER": "X-Forwarded-For",
"FORWARDED_SECRET": None,
"GRACEFUL_SHUTDOWN_TIMEOUT": 15.0, # 15 sec
@ -94,6 +95,7 @@ class Config(dict):
self.load_environment_vars(SANIC_PREFIX)
self._configure_header_size()
self._check_error_format()
def __getattr__(self, attr):
try:
@ -109,6 +111,8 @@ class Config(dict):
"REQUEST_MAX_SIZE",
):
self._configure_header_size()
elif attr == "FALLBACK_ERROR_FORMAT":
self._check_error_format()
def _configure_header_size(self):
Http.set_header_max_size(
@ -117,6 +121,9 @@ class Config(dict):
self.REQUEST_MAX_SIZE,
)
def _check_error_format(self):
check_error_format(self.FALLBACK_ERROR_FORMAT)
def load_environment_vars(self, prefix=SANIC_PREFIX):
"""
Looks for prefixed environment variables and applies

View File

@ -340,41 +340,138 @@ RENDERERS_BY_CONFIG = {
}
RENDERERS_BY_CONTENT_TYPE = {
"multipart/form-data": HTMLRenderer,
"application/json": JSONRenderer,
"text/plain": TextRenderer,
"application/json": JSONRenderer,
"multipart/form-data": HTMLRenderer,
"text/html": HTMLRenderer,
}
CONTENT_TYPE_BY_RENDERERS = {
v: k for k, v in RENDERERS_BY_CONTENT_TYPE.items()
}
RESPONSE_MAPPING = {
"empty": "html",
"json": "json",
"text": "text",
"raw": "text",
"html": "html",
"file": "html",
"file_stream": "text",
"stream": "text",
"redirect": "html",
"text/plain": "text",
"text/html": "html",
"application/json": "json",
}
def check_error_format(format):
if format not in RENDERERS_BY_CONFIG and format != "auto":
raise SanicException(f"Unknown format: {format}")
def exception_response(
request: Request,
exception: Exception,
debug: bool,
fallback: str,
base: t.Type[BaseRenderer],
renderer: t.Type[t.Optional[BaseRenderer]] = None,
) -> HTTPResponse:
"""
Render a response for the default FALLBACK exception handler.
"""
content_type = None
if not renderer:
renderer = HTMLRenderer
# Make sure we have something set
renderer = base
render_format = fallback
if request:
if request.app.config.FALLBACK_ERROR_FORMAT == "auto":
# If there is a request, try and get the format
# from the route
if request.route:
try:
renderer = JSONRenderer if request.json else HTMLRenderer
except InvalidUsage:
render_format = request.route.ctx.error_format
except AttributeError:
...
content_type = request.headers.getone("content-type", "").split(
";"
)[0]
acceptable = request.accept
# If the format is auto still, make a guess
if render_format == "auto":
# First, if there is an Accept header, check if text/html
# is the first option
# According to MDN Web Docs, all major browsers use text/html
# as the primary value in Accept (with the exception of IE 8,
# and, well, if you are supporting IE 8, then you have bigger
# problems to concern yourself with than what default exception
# renderer is used)
# Source:
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Content_negotiation/List_of_default_Accept_values
if acceptable and acceptable[0].match(
"text/html",
allow_type_wildcard=False,
allow_subtype_wildcard=False,
):
renderer = HTMLRenderer
content_type, *_ = request.headers.getone(
"content-type", ""
).split(";")
renderer = RENDERERS_BY_CONTENT_TYPE.get(
content_type, renderer
)
# Second, if there is an Accept header, check if
# application/json is an option, or if the content-type
# is application/json
elif (
acceptable
and acceptable.match(
"application/json",
allow_type_wildcard=False,
allow_subtype_wildcard=False,
)
or content_type == "application/json"
):
renderer = JSONRenderer
# Third, if there is no Accept header, assume we want text.
# The likely use case here is a raw socket.
elif not acceptable:
renderer = TextRenderer
else:
# Fourth, look to see if there was a JSON body
# When in this situation, the request is probably coming
# from curl, an API client like Postman or Insomnia, or a
# package like requests or httpx
try:
# Give them the benefit of the doubt if they did:
# $ curl localhost:8000 -d '{"foo": "bar"}'
# And provide them with JSONRenderer
renderer = JSONRenderer if request.json else base
except InvalidUsage:
renderer = base
else:
render_format = request.app.config.FALLBACK_ERROR_FORMAT
renderer = RENDERERS_BY_CONFIG.get(render_format, renderer)
# Lastly, if there is an Accept header, make sure
# our choice is okay
if acceptable:
type_ = CONTENT_TYPE_BY_RENDERERS.get(renderer) # type: ignore
if type_ and type_ not in acceptable:
# If the renderer selected is not in the Accept header
# look through what is in the Accept header, and select
# the first option that matches. Otherwise, just drop back
# to the original default
for accept in acceptable:
mtype = f"{accept.type_}/{accept.subtype}"
maybe = RENDERERS_BY_CONTENT_TYPE.get(mtype)
if maybe:
renderer = maybe
break
else:
renderer = base
renderer = t.cast(t.Type[BaseRenderer], renderer)
return renderer(request, exception, debug).render()

View File

@ -1,12 +1,13 @@
from typing import List, Optional
from typing import Dict, List, Optional, Tuple, Type
from sanic.errorpages import exception_response
from sanic.errorpages import BaseRenderer, HTMLRenderer, exception_response
from sanic.exceptions import (
ContentRangeError,
HeaderNotFound,
InvalidRangeType,
)
from sanic.log import error_logger
from sanic.models.handler_types import RouteHandler
from sanic.response import text
@ -23,10 +24,15 @@ class ErrorHandler:
"""
def __init__(self):
self.handlers = []
self.cached_handlers = {}
# Beginning in v22.3, the base renderer will be TextRenderer
def __init__(self, fallback: str, base: Type[BaseRenderer] = HTMLRenderer):
self.handlers: List[Tuple[Type[BaseException], RouteHandler]] = []
self.cached_handlers: Dict[
Tuple[Type[BaseException], Optional[str]], Optional[RouteHandler]
] = {}
self.debug = False
self.fallback = fallback
self.base = base
def add(self, exception, handler, route_names: Optional[List[str]] = None):
"""
@ -142,7 +148,13 @@ class ErrorHandler:
:return:
"""
self.log(request, exception)
return exception_response(request, exception, self.debug)
return exception_response(
request,
exception,
debug=self.debug,
base=self.base,
fallback=self.fallback,
)
@staticmethod
def log(request, exception):

View File

@ -35,7 +35,7 @@ _host_re = re.compile(
def parse_arg_as_accept(f):
def func(self, other, *args, **kwargs):
if not isinstance(other, Accept):
if not isinstance(other, Accept) and other:
other = Accept.parse(other)
return f(self, other, *args, **kwargs)
@ -181,6 +181,27 @@ class Accept(str):
return cls(mtype, MediaType(type_), MediaType(subtype), **params)
class AcceptContainer(list):
def __contains__(self, o: object) -> bool:
return any(item.match(o) for item in self)
def match(
self,
o: object,
*,
allow_type_wildcard: bool = True,
allow_subtype_wildcard: bool = True,
) -> bool:
return any(
item.match(
o,
allow_type_wildcard=allow_type_wildcard,
allow_subtype_wildcard=allow_subtype_wildcard,
)
for item in self
)
def parse_content_header(value: str) -> Tuple[str, Options]:
"""Parse content-type and content-disposition header values.
@ -356,7 +377,7 @@ def _sort_accept_value(accept: Accept):
)
def parse_accept(accept: str) -> List[Accept]:
def parse_accept(accept: str) -> AcceptContainer:
"""Parse an Accept header and order the acceptable media types in
accorsing to RFC 7231, s. 5.3.2
https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2
@ -370,4 +391,6 @@ def parse_accept(accept: str) -> List[Accept]:
accept_list.append(Accept.parse(mtype))
return sorted(accept_list, key=_sort_accept_value, reverse=True)
return AcceptContainer(
sorted(accept_list, key=_sort_accept_value, reverse=True)
)

View File

@ -1,17 +1,20 @@
from ast import NodeVisitor, Return, parse
from functools import partial, wraps
from inspect import signature
from inspect import getsource, signature
from mimetypes import guess_type
from os import path
from pathlib import PurePath
from re import sub
from textwrap import dedent
from time import gmtime, strftime
from typing import Callable, Iterable, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Iterable, List, Optional, Set, Tuple, Union
from urllib.parse import unquote
from sanic_routing.route import Route # type: ignore
from sanic.compat import stat_async
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE, HTTP_METHODS
from sanic.errorpages import RESPONSE_MAPPING
from sanic.exceptions import (
ContentRangeError,
FileNotFound,
@ -61,6 +64,7 @@ class RouteMixin:
unquote: bool = False,
static: bool = False,
version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper:
"""
Decorate a function to be registered as a route
@ -103,6 +107,7 @@ class RouteMixin:
nonlocal websocket
nonlocal static
nonlocal version_prefix
nonlocal error_format
if isinstance(handler, tuple):
# if a handler fn is already wrapped in a route, the handler
@ -128,6 +133,9 @@ class RouteMixin:
# subprotocol is unordered, keep it unordered
subprotocols = frozenset(subprotocols)
if not error_format or error_format == "auto":
error_format = self._determine_error_format(handler)
route = FutureRoute(
handler,
uri,
@ -143,6 +151,7 @@ class RouteMixin:
unquote,
static,
version_prefix,
error_format,
)
self._future_routes.add(route)
@ -186,6 +195,7 @@ class RouteMixin:
name: Optional[str] = None,
stream: bool = False,
version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteHandler:
"""A helper method to register class instance or
functions as a handler to the application url
@ -236,6 +246,7 @@ class RouteMixin:
version=version,
name=name,
version_prefix=version_prefix,
error_format=error_format,
)(handler)
return handler
@ -249,6 +260,7 @@ class RouteMixin:
name: Optional[str] = None,
ignore_body: bool = True,
version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper:
"""
Add an API URL under the **GET** *HTTP* method
@ -272,6 +284,7 @@ class RouteMixin:
name=name,
ignore_body=ignore_body,
version_prefix=version_prefix,
error_format=error_format,
)
def post(
@ -283,6 +296,7 @@ class RouteMixin:
version: Optional[int] = None,
name: Optional[str] = None,
version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper:
"""
Add an API URL under the **POST** *HTTP* method
@ -306,6 +320,7 @@ class RouteMixin:
version=version,
name=name,
version_prefix=version_prefix,
error_format=error_format,
)
def put(
@ -317,6 +332,7 @@ class RouteMixin:
version: Optional[int] = None,
name: Optional[str] = None,
version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper:
"""
Add an API URL under the **PUT** *HTTP* method
@ -340,6 +356,7 @@ class RouteMixin:
version=version,
name=name,
version_prefix=version_prefix,
error_format=error_format,
)
def head(
@ -351,6 +368,7 @@ class RouteMixin:
name: Optional[str] = None,
ignore_body: bool = True,
version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper:
"""
Add an API URL under the **HEAD** *HTTP* method
@ -382,6 +400,7 @@ class RouteMixin:
name=name,
ignore_body=ignore_body,
version_prefix=version_prefix,
error_format=error_format,
)
def options(
@ -393,6 +412,7 @@ class RouteMixin:
name: Optional[str] = None,
ignore_body: bool = True,
version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper:
"""
Add an API URL under the **OPTIONS** *HTTP* method
@ -424,6 +444,7 @@ class RouteMixin:
name=name,
ignore_body=ignore_body,
version_prefix=version_prefix,
error_format=error_format,
)
def patch(
@ -435,6 +456,7 @@ class RouteMixin:
version: Optional[int] = None,
name: Optional[str] = None,
version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper:
"""
Add an API URL under the **PATCH** *HTTP* method
@ -468,6 +490,7 @@ class RouteMixin:
version=version,
name=name,
version_prefix=version_prefix,
error_format=error_format,
)
def delete(
@ -479,6 +502,7 @@ class RouteMixin:
name: Optional[str] = None,
ignore_body: bool = True,
version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> RouteWrapper:
"""
Add an API URL under the **DELETE** *HTTP* method
@ -502,6 +526,7 @@ class RouteMixin:
name=name,
ignore_body=ignore_body,
version_prefix=version_prefix,
error_format=error_format,
)
def websocket(
@ -514,6 +539,7 @@ class RouteMixin:
name: Optional[str] = None,
apply: bool = True,
version_prefix: str = "/v",
error_format: Optional[str] = None,
):
"""
Decorate a function to be registered as a websocket route
@ -540,6 +566,7 @@ class RouteMixin:
subprotocols=subprotocols,
websocket=True,
version_prefix=version_prefix,
error_format=error_format,
)
def add_websocket_route(
@ -552,6 +579,7 @@ class RouteMixin:
version: Optional[int] = None,
name: Optional[str] = None,
version_prefix: str = "/v",
error_format: Optional[str] = None,
):
"""
A helper method to register a function as a websocket route.
@ -580,6 +608,7 @@ class RouteMixin:
version=version,
name=name,
version_prefix=version_prefix,
error_format=error_format,
)(handler)
def static(
@ -789,8 +818,8 @@ class RouteMixin:
)
except Exception:
error_logger.exception(
f"Exception in static request handler:\
path={file_or_directory}, "
f"Exception in static request handler: "
f"path={file_or_directory}, "
f"relative_url={__file_uri__}"
)
raise
@ -888,3 +917,43 @@ class RouteMixin:
)(_handler)
return route
def _determine_error_format(self, handler) -> str:
if not isinstance(handler, CompositionView):
try:
src = dedent(getsource(handler))
tree = parse(src)
http_response_types = self._get_response_types(tree)
if len(http_response_types) == 1:
return next(iter(http_response_types))
except OSError:
...
return "auto"
def _get_response_types(self, node):
types = set()
class HttpResponseVisitor(NodeVisitor):
def visit_Return(self, node: Return) -> Any:
nonlocal types
try:
checks = [node.value.func.id] # type: ignore
if node.value.keywords: # type: ignore
checks += [
k.value
for k in node.value.keywords # type: ignore
if k.arg == "content_type"
]
for check in checks:
if check in RESPONSE_MAPPING:
types.add(RESPONSE_MAPPING[check])
except AttributeError:
...
HttpResponseVisitor().visit(node)
return types

View File

@ -24,6 +24,7 @@ class FutureRoute(NamedTuple):
unquote: bool
static: bool
version_prefix: str
error_format: Optional[str]
class FutureListener(NamedTuple):

View File

@ -34,7 +34,7 @@ from sanic.compat import CancelledErrors, Header
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE
from sanic.exceptions import InvalidUsage
from sanic.headers import (
Accept,
AcceptContainer,
Options,
parse_accept,
parse_content_header,
@ -139,7 +139,7 @@ class Request:
self.conn_info: Optional[ConnInfo] = None
self.ctx = SimpleNamespace()
self.parsed_forwarded: Optional[Options] = None
self.parsed_accept: Optional[List[Accept]] = None
self.parsed_accept: Optional[AcceptContainer] = None
self.parsed_json = None
self.parsed_form = None
self.parsed_files = None
@ -301,7 +301,7 @@ class Request:
return self.parsed_json
@property
def accept(self) -> List[Accept]:
def accept(self) -> AcceptContainer:
if self.parsed_accept is None:
accept_header = self.headers.getone("accept", "")
self.parsed_accept = parse_accept(accept_header)

View File

@ -13,6 +13,7 @@ from sanic_routing.exceptions import (
from sanic_routing.route import Route # type: ignore
from sanic.constants import HTTP_METHODS
from sanic.errorpages import check_error_format
from sanic.exceptions import MethodNotSupported, NotFound, SanicException
from sanic.models.handler_types import RouteHandler
@ -78,6 +79,7 @@ class Router(BaseRouter):
unquote: bool = False,
static: bool = False,
version_prefix: str = "/v",
error_format: Optional[str] = None,
) -> Union[Route, List[Route]]:
"""
Add a handler to the router
@ -137,6 +139,11 @@ class Router(BaseRouter):
route.ctx.stream = stream
route.ctx.hosts = hosts
route.ctx.static = static
route.ctx.error_format = (
error_format or self.ctx.app.config.FALLBACK_ERROR_FORMAT
)
check_error_format(route.ctx.error_format)
routes.append(route)

View File

@ -109,6 +109,7 @@ def sanic_router(app):
# noinspection PyProtectedMember
def _setup(route_details: tuple) -> Tuple[Router, tuple]:
router = Router()
router.ctx.app = app
added_router = []
for method, route in route_details:
try:

View File

@ -20,4 +20,4 @@ def test_bad_request_response(app):
app.run(host="127.0.0.1", port=42101, debug=False)
assert lines[0] == b"HTTP/1.1 400 Bad Request\r\n"
assert b"Bad Request" in lines[-1]
assert b"Bad Request" in lines[-2]

View File

@ -3,7 +3,12 @@ from pytest import raises
from sanic.app import Sanic
from sanic.blueprint_group import BlueprintGroup
from sanic.blueprints import Blueprint
from sanic.exceptions import Forbidden, InvalidUsage, SanicException, ServerError
from sanic.exceptions import (
Forbidden,
InvalidUsage,
SanicException,
ServerError,
)
from sanic.request import Request
from sanic.response import HTTPResponse, text

View File

@ -1,10 +1,10 @@
import pytest
from sanic import Sanic
from sanic.errorpages import exception_response
from sanic.exceptions import NotFound
from sanic.errorpages import HTMLRenderer, exception_response
from sanic.exceptions import NotFound, SanicException
from sanic.request import Request
from sanic.response import HTTPResponse
from sanic.response import HTTPResponse, html, json, text
@pytest.fixture
@ -20,7 +20,7 @@ def app():
@pytest.fixture
def fake_request(app):
return Request(b"/foobar", {}, "1.1", "GET", None, app)
return Request(b"/foobar", {"accept": "*/*"}, "1.1", "GET", None, app)
@pytest.mark.parametrize(
@ -47,7 +47,13 @@ def test_should_return_html_valid_setting(
try:
raise exception("bad stuff")
except Exception as e:
response = exception_response(fake_request, e, True)
response = exception_response(
fake_request,
e,
True,
base=HTMLRenderer,
fallback=fake_request.app.config.FALLBACK_ERROR_FORMAT,
)
assert isinstance(response, HTTPResponse)
assert response.status == status
@ -74,13 +80,194 @@ def test_auto_fallback_with_content_type(app):
app.config.FALLBACK_ERROR_FORMAT = "auto"
_, response = app.test_client.get(
"/error", headers={"content-type": "application/json"}
"/error", headers={"content-type": "application/json", "accept": "*/*"}
)
assert response.status == 500
assert response.content_type == "application/json"
_, response = app.test_client.get(
"/error", headers={"content-type": "text/plain"}
"/error", headers={"content-type": "foo/bar", "accept": "*/*"}
)
assert response.status == 500
assert response.content_type == "text/html; charset=utf-8"
def test_route_error_format_set_on_auto(app):
@app.get("/text")
def text_response(request):
return text(request.route.ctx.error_format)
@app.get("/json")
def json_response(request):
return json({"format": request.route.ctx.error_format})
@app.get("/html")
def html_response(request):
return html(request.route.ctx.error_format)
_, response = app.test_client.get("/text")
assert response.text == "text"
_, response = app.test_client.get("/json")
assert response.json["format"] == "json"
_, response = app.test_client.get("/html")
assert response.text == "html"
def test_route_error_response_from_auto_route(app):
@app.get("/text")
def text_response(request):
raise Exception("oops")
return text("Never gonna see this")
@app.get("/json")
def json_response(request):
raise Exception("oops")
return json({"message": "Never gonna see this"})
@app.get("/html")
def html_response(request):
raise Exception("oops")
return html("<h1>Never gonna see this</h1>")
_, response = app.test_client.get("/text")
assert response.content_type == "text/plain; charset=utf-8"
_, response = app.test_client.get("/json")
assert response.content_type == "application/json"
_, response = app.test_client.get("/html")
assert response.content_type == "text/html; charset=utf-8"
def test_route_error_response_from_explicit_format(app):
@app.get("/text", error_format="json")
def text_response(request):
raise Exception("oops")
return text("Never gonna see this")
@app.get("/json", error_format="text")
def json_response(request):
raise Exception("oops")
return json({"message": "Never gonna see this"})
_, response = app.test_client.get("/text")
assert response.content_type == "application/json"
_, response = app.test_client.get("/json")
assert response.content_type == "text/plain; charset=utf-8"
def test_unknown_fallback_format(app):
with pytest.raises(SanicException, match="Unknown format: bad"):
app.config.FALLBACK_ERROR_FORMAT = "bad"
def test_route_error_format_unknown(app):
with pytest.raises(SanicException, match="Unknown format: bad"):
@app.get("/text", error_format="bad")
def handler(request):
...
def test_fallback_with_content_type_mismatch_accept(app):
app.config.FALLBACK_ERROR_FORMAT = "auto"
_, response = app.test_client.get(
"/error",
headers={"content-type": "application/json", "accept": "text/plain"},
)
assert response.status == 500
assert response.content_type == "text/plain; charset=utf-8"
_, response = app.test_client.get(
"/error",
headers={"content-type": "text/plain", "accept": "foo/bar"},
)
assert response.status == 500
assert response.content_type == "text/html; charset=utf-8"
app.router.reset()
@app.route("/alt1")
@app.route("/alt2", error_format="text")
@app.route("/alt3", error_format="html")
def handler(_):
raise Exception("problem here")
# Yes, we know this return value is unreachable. This is on purpose.
return json({})
app.router.finalize()
_, response = app.test_client.get(
"/alt1",
headers={"accept": "foo/bar"},
)
assert response.status == 500
assert response.content_type == "text/html; charset=utf-8"
_, response = app.test_client.get(
"/alt1",
headers={"accept": "foo/bar,*/*"},
)
assert response.status == 500
assert response.content_type == "application/json"
_, response = app.test_client.get(
"/alt2",
headers={"accept": "foo/bar"},
)
assert response.status == 500
assert response.content_type == "text/html; charset=utf-8"
_, response = app.test_client.get(
"/alt2",
headers={"accept": "foo/bar,*/*"},
)
assert response.status == 500
assert response.content_type == "text/plain; charset=utf-8"
_, response = app.test_client.get(
"/alt3",
headers={"accept": "foo/bar"},
)
assert response.status == 500
assert response.content_type == "text/html; charset=utf-8"
@pytest.mark.parametrize(
"accept,content_type,expected",
(
(None, None, "text/plain; charset=utf-8"),
("foo/bar", None, "text/html; charset=utf-8"),
("application/json", None, "application/json"),
("application/json,text/plain", None, "application/json"),
("text/plain,application/json", None, "application/json"),
("text/plain,foo/bar", None, "text/plain; charset=utf-8"),
# Following test is valid after v22.3
# ("text/plain,text/html", None, "text/plain; charset=utf-8"),
("*/*", "foo/bar", "text/html; charset=utf-8"),
("*/*", "application/json", "application/json"),
),
)
def test_combinations_for_auto(fake_request, accept, content_type, expected):
if accept:
fake_request.headers["accept"] = accept
else:
del fake_request.headers["accept"]
if content_type:
fake_request.headers["content-type"] = content_type
try:
raise Exception("bad stuff")
except Exception as e:
response = exception_response(
fake_request,
e,
True,
base=HTMLRenderer,
fallback="auto",
)
assert response.content_type == expected

View File

@ -1,5 +1,7 @@
import asyncio
import pytest
from bs4 import BeautifulSoup
from sanic import Sanic
@ -8,9 +10,6 @@ from sanic.handlers import ErrorHandler
from sanic.response import stream, text
exception_handler_app = Sanic("test_exception_handler")
async def sample_streaming_fn(response):
await response.write("foo,")
await asyncio.sleep(0.001)
@ -21,107 +20,102 @@ class ErrorWithRequestCtx(ServerError):
pass
@exception_handler_app.route("/1")
def handler_1(request):
raise InvalidUsage("OK")
@pytest.fixture
def exception_handler_app():
exception_handler_app = Sanic("test_exception_handler")
@exception_handler_app.route("/1", error_format="html")
def handler_1(request):
raise InvalidUsage("OK")
@exception_handler_app.route("/2", error_format="html")
def handler_2(request):
raise ServerError("OK")
@exception_handler_app.route("/3", error_format="html")
def handler_3(request):
raise NotFound("OK")
@exception_handler_app.route("/4", error_format="html")
def handler_4(request):
foo = bar # noqa -- F821
return text(foo)
@exception_handler_app.route("/5", error_format="html")
def handler_5(request):
class CustomServerError(ServerError):
pass
raise CustomServerError("Custom server error")
@exception_handler_app.route("/6/<arg:int>", error_format="html")
def handler_6(request, arg):
try:
foo = 1 / arg
except Exception as e:
raise e from ValueError(f"{arg}")
return text(foo)
@exception_handler_app.route("/7", error_format="html")
def handler_7(request):
raise Forbidden("go away!")
@exception_handler_app.route("/8", error_format="html")
def handler_8(request):
raise ErrorWithRequestCtx("OK")
@exception_handler_app.exception(ErrorWithRequestCtx, NotFound)
def handler_exception_with_ctx(request, exception):
return text(request.ctx.middleware_ran)
@exception_handler_app.exception(ServerError)
def handler_exception(request, exception):
return text("OK")
@exception_handler_app.exception(Forbidden)
async def async_handler_exception(request, exception):
return stream(
sample_streaming_fn,
content_type="text/csv",
)
@exception_handler_app.middleware
async def some_request_middleware(request):
request.ctx.middleware_ran = "Done."
return exception_handler_app
@exception_handler_app.route("/2")
def handler_2(request):
raise ServerError("OK")
@exception_handler_app.route("/3")
def handler_3(request):
raise NotFound("OK")
@exception_handler_app.route("/4")
def handler_4(request):
foo = bar # noqa -- F821 undefined name 'bar' is done to throw exception
return text(foo)
@exception_handler_app.route("/5")
def handler_5(request):
class CustomServerError(ServerError):
pass
raise CustomServerError("Custom server error")
@exception_handler_app.route("/6/<arg:int>")
def handler_6(request, arg):
try:
foo = 1 / arg
except Exception as e:
raise e from ValueError(f"{arg}")
return text(foo)
@exception_handler_app.route("/7")
def handler_7(request):
raise Forbidden("go away!")
@exception_handler_app.route("/8")
def handler_8(request):
raise ErrorWithRequestCtx("OK")
@exception_handler_app.exception(ErrorWithRequestCtx, NotFound)
def handler_exception_with_ctx(request, exception):
return text(request.ctx.middleware_ran)
@exception_handler_app.exception(ServerError)
def handler_exception(request, exception):
return text("OK")
@exception_handler_app.exception(Forbidden)
async def async_handler_exception(request, exception):
return stream(
sample_streaming_fn,
content_type="text/csv",
)
@exception_handler_app.middleware
async def some_request_middleware(request):
request.ctx.middleware_ran = "Done."
def test_invalid_usage_exception_handler():
def test_invalid_usage_exception_handler(exception_handler_app):
request, response = exception_handler_app.test_client.get("/1")
assert response.status == 400
def test_server_error_exception_handler():
def test_server_error_exception_handler(exception_handler_app):
request, response = exception_handler_app.test_client.get("/2")
assert response.status == 200
assert response.text == "OK"
def test_not_found_exception_handler():
def test_not_found_exception_handler(exception_handler_app):
request, response = exception_handler_app.test_client.get("/3")
assert response.status == 200
def test_text_exception__handler():
def test_text_exception__handler(exception_handler_app):
request, response = exception_handler_app.test_client.get("/random")
assert response.status == 200
assert response.text == "Done."
def test_async_exception_handler():
def test_async_exception_handler(exception_handler_app):
request, response = exception_handler_app.test_client.get("/7")
assert response.status == 200
assert response.text == "foo,bar"
def test_html_traceback_output_in_debug_mode():
def test_html_traceback_output_in_debug_mode(exception_handler_app):
request, response = exception_handler_app.test_client.get("/4", debug=True)
assert response.status == 500
soup = BeautifulSoup(response.body, "html.parser")
@ -136,12 +130,12 @@ def test_html_traceback_output_in_debug_mode():
) == summary_text
def test_inherited_exception_handler():
def test_inherited_exception_handler(exception_handler_app):
request, response = exception_handler_app.test_client.get("/5")
assert response.status == 200
def test_chained_exception_handler():
def test_chained_exception_handler(exception_handler_app):
request, response = exception_handler_app.test_client.get(
"/6/0", debug=True
)
@ -153,7 +147,6 @@ def test_chained_exception_handler():
assert "handler_6" in html
assert "foo = 1 / arg" in html
assert "ValueError" in html
assert "The above exception was the direct cause" in html
summary_text = " ".join(soup.select(".summary")[0].text.split())
assert (
@ -161,7 +154,7 @@ def test_chained_exception_handler():
) == summary_text
def test_exception_handler_lookup():
def test_exception_handler_lookup(exception_handler_app):
class CustomError(Exception):
pass
@ -184,7 +177,7 @@ def test_exception_handler_lookup():
class ModuleNotFoundError(ImportError):
pass
handler = ErrorHandler()
handler = ErrorHandler("auto")
handler.add(ImportError, import_error_handler)
handler.add(CustomError, custom_error_handler)
handler.add(ServerError, server_error_handler)
@ -209,7 +202,7 @@ def test_exception_handler_lookup():
)
def test_exception_handler_processed_request_middleware():
def test_exception_handler_processed_request_middleware(exception_handler_app):
request, response = exception_handler_app.test_client.get("/8")
assert response.status == 200
assert response.text == "Done."

View File

@ -5,6 +5,7 @@ import pytest
from sanic import headers, text
from sanic.exceptions import InvalidHeader, PayloadTooLarge
from sanic.http import Http
from sanic.request import Request
@pytest.fixture
@ -338,3 +339,31 @@ def test_value_in_accept(value):
assert "foo/bar" in acceptable
assert "foo/*" in acceptable
assert "*/*" in acceptable
@pytest.mark.parametrize("value", ("foo/bar", "foo/*"))
def test_value_not_in_accept(value):
acceptable = headers.parse_accept(value)
assert "no/match" not in acceptable
assert "no/*" not in acceptable
@pytest.mark.parametrize(
"header,expected",
(
(
"text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", # noqa: E501
[
"text/html",
"application/xhtml+xml",
"image/avif",
"image/webp",
"application/xml;q=0.9",
"*/*;q=0.8",
],
),
),
)
def test_browser_headers(header, expected):
request = Request(b"/", {"accept": header}, "1.1", "GET", None, None)
assert request.accept == expected