diff --git a/sanic/app.py b/sanic/app.py index ca8c4fe0..01aa07cb 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -179,7 +179,9 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): self.configure_logging = configure_logging self.ctx = ctx or SimpleNamespace() self.debug = None - self.error_handler = error_handler or ErrorHandler() + self.error_handler = error_handler or ErrorHandler( + fallback=self.config.FALLBACK_ERROR_FORMAT, + ) self.is_running = False self.is_stopping = False self.listeners: Dict[str, List[ListenerType]] = defaultdict(list) diff --git a/sanic/blueprints.py b/sanic/blueprints.py index fd7332aa..617ec606 100644 --- a/sanic/blueprints.py +++ b/sanic/blueprints.py @@ -266,6 +266,9 @@ class Blueprint(BaseSanic): opt_version = options.get("version", None) opt_strict_slashes = options.get("strict_slashes", None) opt_version_prefix = options.get("version_prefix", self.version_prefix) + error_format = options.get( + "error_format", app.config.FALLBACK_ERROR_FORMAT + ) routes = [] middleware = [] @@ -313,6 +316,7 @@ class Blueprint(BaseSanic): future.unquote, future.static, version_prefix, + error_format, ) route = app._apply_route(apply_route) diff --git a/sanic/config.py b/sanic/config.py index 2a90c5fb..2971f4e4 100644 --- a/sanic/config.py +++ b/sanic/config.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Any, Dict, Optional, Union from warnings import warn +from sanic.errorpages import check_error_format from sanic.http import Http from .utils import load_module_from_file_location, str_to_bool @@ -20,7 +21,7 @@ BASE_LOGO = """ DEFAULT_CONFIG = { "ACCESS_LOG": True, "EVENT_AUTOREGISTER": False, - "FALLBACK_ERROR_FORMAT": "html", + "FALLBACK_ERROR_FORMAT": "auto", "FORWARDED_FOR_HEADER": "X-Forwarded-For", "FORWARDED_SECRET": None, "GRACEFUL_SHUTDOWN_TIMEOUT": 15.0, # 15 sec @@ -94,6 +95,7 @@ class Config(dict): self.load_environment_vars(SANIC_PREFIX) self._configure_header_size() + self._check_error_format() def __getattr__(self, attr): try: @@ -109,6 +111,8 @@ class Config(dict): "REQUEST_MAX_SIZE", ): self._configure_header_size() + elif attr == "FALLBACK_ERROR_FORMAT": + self._check_error_format() def _configure_header_size(self): Http.set_header_max_size( @@ -117,6 +121,9 @@ class Config(dict): self.REQUEST_MAX_SIZE, ) + def _check_error_format(self): + check_error_format(self.FALLBACK_ERROR_FORMAT) + def load_environment_vars(self, prefix=SANIC_PREFIX): """ Looks for prefixed environment variables and applies diff --git a/sanic/errorpages.py b/sanic/errorpages.py index 5fc10de1..82cdd57a 100644 --- a/sanic/errorpages.py +++ b/sanic/errorpages.py @@ -340,41 +340,138 @@ RENDERERS_BY_CONFIG = { } RENDERERS_BY_CONTENT_TYPE = { - "multipart/form-data": HTMLRenderer, - "application/json": JSONRenderer, "text/plain": TextRenderer, + "application/json": JSONRenderer, + "multipart/form-data": HTMLRenderer, + "text/html": HTMLRenderer, } +CONTENT_TYPE_BY_RENDERERS = { + v: k for k, v in RENDERERS_BY_CONTENT_TYPE.items() +} + +RESPONSE_MAPPING = { + "empty": "html", + "json": "json", + "text": "text", + "raw": "text", + "html": "html", + "file": "html", + "file_stream": "text", + "stream": "text", + "redirect": "html", + "text/plain": "text", + "text/html": "html", + "application/json": "json", +} + + +def check_error_format(format): + if format not in RENDERERS_BY_CONFIG and format != "auto": + raise SanicException(f"Unknown format: {format}") def exception_response( request: Request, exception: Exception, debug: bool, + fallback: str, + base: t.Type[BaseRenderer], renderer: t.Type[t.Optional[BaseRenderer]] = None, ) -> HTTPResponse: """ Render a response for the default FALLBACK exception handler. """ + content_type = None if not renderer: - renderer = HTMLRenderer + # Make sure we have something set + renderer = base + render_format = fallback if request: - if request.app.config.FALLBACK_ERROR_FORMAT == "auto": + # If there is a request, try and get the format + # from the route + if request.route: try: - renderer = JSONRenderer if request.json else HTMLRenderer - except InvalidUsage: + render_format = request.route.ctx.error_format + except AttributeError: + ... + + content_type = request.headers.getone("content-type", "").split( + ";" + )[0] + + acceptable = request.accept + + # If the format is auto still, make a guess + if render_format == "auto": + # First, if there is an Accept header, check if text/html + # is the first option + # According to MDN Web Docs, all major browsers use text/html + # as the primary value in Accept (with the exception of IE 8, + # and, well, if you are supporting IE 8, then you have bigger + # problems to concern yourself with than what default exception + # renderer is used) + # Source: + # https://developer.mozilla.org/en-US/docs/Web/HTTP/Content_negotiation/List_of_default_Accept_values + + if acceptable and acceptable[0].match( + "text/html", + allow_type_wildcard=False, + allow_subtype_wildcard=False, + ): renderer = HTMLRenderer - content_type, *_ = request.headers.getone( - "content-type", "" - ).split(";") - renderer = RENDERERS_BY_CONTENT_TYPE.get( - content_type, renderer - ) + # Second, if there is an Accept header, check if + # application/json is an option, or if the content-type + # is application/json + elif ( + acceptable + and acceptable.match( + "application/json", + allow_type_wildcard=False, + allow_subtype_wildcard=False, + ) + or content_type == "application/json" + ): + renderer = JSONRenderer + + # Third, if there is no Accept header, assume we want text. + # The likely use case here is a raw socket. + elif not acceptable: + renderer = TextRenderer + else: + # Fourth, look to see if there was a JSON body + # When in this situation, the request is probably coming + # from curl, an API client like Postman or Insomnia, or a + # package like requests or httpx + try: + # Give them the benefit of the doubt if they did: + # $ curl localhost:8000 -d '{"foo": "bar"}' + # And provide them with JSONRenderer + renderer = JSONRenderer if request.json else base + except InvalidUsage: + renderer = base else: - render_format = request.app.config.FALLBACK_ERROR_FORMAT renderer = RENDERERS_BY_CONFIG.get(render_format, renderer) + # Lastly, if there is an Accept header, make sure + # our choice is okay + if acceptable: + type_ = CONTENT_TYPE_BY_RENDERERS.get(renderer) # type: ignore + if type_ and type_ not in acceptable: + # If the renderer selected is not in the Accept header + # look through what is in the Accept header, and select + # the first option that matches. Otherwise, just drop back + # to the original default + for accept in acceptable: + mtype = f"{accept.type_}/{accept.subtype}" + maybe = RENDERERS_BY_CONTENT_TYPE.get(mtype) + if maybe: + renderer = maybe + break + else: + renderer = base + renderer = t.cast(t.Type[BaseRenderer], renderer) return renderer(request, exception, debug).render() diff --git a/sanic/handlers.py b/sanic/handlers.py index e33aff76..ffeb76b8 100644 --- a/sanic/handlers.py +++ b/sanic/handlers.py @@ -1,12 +1,13 @@ -from typing import List, Optional +from typing import Dict, List, Optional, Tuple, Type -from sanic.errorpages import exception_response +from sanic.errorpages import BaseRenderer, HTMLRenderer, exception_response from sanic.exceptions import ( ContentRangeError, HeaderNotFound, InvalidRangeType, ) from sanic.log import error_logger +from sanic.models.handler_types import RouteHandler from sanic.response import text @@ -23,10 +24,15 @@ class ErrorHandler: """ - def __init__(self): - self.handlers = [] - self.cached_handlers = {} + # Beginning in v22.3, the base renderer will be TextRenderer + def __init__(self, fallback: str, base: Type[BaseRenderer] = HTMLRenderer): + self.handlers: List[Tuple[Type[BaseException], RouteHandler]] = [] + self.cached_handlers: Dict[ + Tuple[Type[BaseException], Optional[str]], Optional[RouteHandler] + ] = {} self.debug = False + self.fallback = fallback + self.base = base def add(self, exception, handler, route_names: Optional[List[str]] = None): """ @@ -142,7 +148,13 @@ class ErrorHandler: :return: """ self.log(request, exception) - return exception_response(request, exception, self.debug) + return exception_response( + request, + exception, + debug=self.debug, + base=self.base, + fallback=self.fallback, + ) @staticmethod def log(request, exception): diff --git a/sanic/headers.py b/sanic/headers.py index cc05f8e0..dbb8720f 100644 --- a/sanic/headers.py +++ b/sanic/headers.py @@ -35,7 +35,7 @@ _host_re = re.compile( def parse_arg_as_accept(f): def func(self, other, *args, **kwargs): - if not isinstance(other, Accept): + if not isinstance(other, Accept) and other: other = Accept.parse(other) return f(self, other, *args, **kwargs) @@ -181,6 +181,27 @@ class Accept(str): return cls(mtype, MediaType(type_), MediaType(subtype), **params) +class AcceptContainer(list): + def __contains__(self, o: object) -> bool: + return any(item.match(o) for item in self) + + def match( + self, + o: object, + *, + allow_type_wildcard: bool = True, + allow_subtype_wildcard: bool = True, + ) -> bool: + return any( + item.match( + o, + allow_type_wildcard=allow_type_wildcard, + allow_subtype_wildcard=allow_subtype_wildcard, + ) + for item in self + ) + + def parse_content_header(value: str) -> Tuple[str, Options]: """Parse content-type and content-disposition header values. @@ -356,7 +377,7 @@ def _sort_accept_value(accept: Accept): ) -def parse_accept(accept: str) -> List[Accept]: +def parse_accept(accept: str) -> AcceptContainer: """Parse an Accept header and order the acceptable media types in accorsing to RFC 7231, s. 5.3.2 https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2 @@ -370,4 +391,6 @@ def parse_accept(accept: str) -> List[Accept]: accept_list.append(Accept.parse(mtype)) - return sorted(accept_list, key=_sort_accept_value, reverse=True) + return AcceptContainer( + sorted(accept_list, key=_sort_accept_value, reverse=True) + ) diff --git a/sanic/mixins/routes.py b/sanic/mixins/routes.py index 6881f291..3d9733cd 100644 --- a/sanic/mixins/routes.py +++ b/sanic/mixins/routes.py @@ -1,17 +1,20 @@ +from ast import NodeVisitor, Return, parse from functools import partial, wraps -from inspect import signature +from inspect import getsource, signature from mimetypes import guess_type from os import path from pathlib import PurePath from re import sub +from textwrap import dedent from time import gmtime, strftime -from typing import Callable, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Iterable, List, Optional, Set, Tuple, Union from urllib.parse import unquote from sanic_routing.route import Route # type: ignore from sanic.compat import stat_async from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE, HTTP_METHODS +from sanic.errorpages import RESPONSE_MAPPING from sanic.exceptions import ( ContentRangeError, FileNotFound, @@ -61,6 +64,7 @@ class RouteMixin: unquote: bool = False, static: bool = False, version_prefix: str = "/v", + error_format: Optional[str] = None, ) -> RouteWrapper: """ Decorate a function to be registered as a route @@ -103,6 +107,7 @@ class RouteMixin: nonlocal websocket nonlocal static nonlocal version_prefix + nonlocal error_format if isinstance(handler, tuple): # if a handler fn is already wrapped in a route, the handler @@ -128,6 +133,9 @@ class RouteMixin: # subprotocol is unordered, keep it unordered subprotocols = frozenset(subprotocols) + if not error_format or error_format == "auto": + error_format = self._determine_error_format(handler) + route = FutureRoute( handler, uri, @@ -143,6 +151,7 @@ class RouteMixin: unquote, static, version_prefix, + error_format, ) self._future_routes.add(route) @@ -186,6 +195,7 @@ class RouteMixin: name: Optional[str] = None, stream: bool = False, version_prefix: str = "/v", + error_format: Optional[str] = None, ) -> RouteHandler: """A helper method to register class instance or functions as a handler to the application url @@ -236,6 +246,7 @@ class RouteMixin: version=version, name=name, version_prefix=version_prefix, + error_format=error_format, )(handler) return handler @@ -249,6 +260,7 @@ class RouteMixin: name: Optional[str] = None, ignore_body: bool = True, version_prefix: str = "/v", + error_format: Optional[str] = None, ) -> RouteWrapper: """ Add an API URL under the **GET** *HTTP* method @@ -272,6 +284,7 @@ class RouteMixin: name=name, ignore_body=ignore_body, version_prefix=version_prefix, + error_format=error_format, ) def post( @@ -283,6 +296,7 @@ class RouteMixin: version: Optional[int] = None, name: Optional[str] = None, version_prefix: str = "/v", + error_format: Optional[str] = None, ) -> RouteWrapper: """ Add an API URL under the **POST** *HTTP* method @@ -306,6 +320,7 @@ class RouteMixin: version=version, name=name, version_prefix=version_prefix, + error_format=error_format, ) def put( @@ -317,6 +332,7 @@ class RouteMixin: version: Optional[int] = None, name: Optional[str] = None, version_prefix: str = "/v", + error_format: Optional[str] = None, ) -> RouteWrapper: """ Add an API URL under the **PUT** *HTTP* method @@ -340,6 +356,7 @@ class RouteMixin: version=version, name=name, version_prefix=version_prefix, + error_format=error_format, ) def head( @@ -351,6 +368,7 @@ class RouteMixin: name: Optional[str] = None, ignore_body: bool = True, version_prefix: str = "/v", + error_format: Optional[str] = None, ) -> RouteWrapper: """ Add an API URL under the **HEAD** *HTTP* method @@ -382,6 +400,7 @@ class RouteMixin: name=name, ignore_body=ignore_body, version_prefix=version_prefix, + error_format=error_format, ) def options( @@ -393,6 +412,7 @@ class RouteMixin: name: Optional[str] = None, ignore_body: bool = True, version_prefix: str = "/v", + error_format: Optional[str] = None, ) -> RouteWrapper: """ Add an API URL under the **OPTIONS** *HTTP* method @@ -424,6 +444,7 @@ class RouteMixin: name=name, ignore_body=ignore_body, version_prefix=version_prefix, + error_format=error_format, ) def patch( @@ -435,6 +456,7 @@ class RouteMixin: version: Optional[int] = None, name: Optional[str] = None, version_prefix: str = "/v", + error_format: Optional[str] = None, ) -> RouteWrapper: """ Add an API URL under the **PATCH** *HTTP* method @@ -468,6 +490,7 @@ class RouteMixin: version=version, name=name, version_prefix=version_prefix, + error_format=error_format, ) def delete( @@ -479,6 +502,7 @@ class RouteMixin: name: Optional[str] = None, ignore_body: bool = True, version_prefix: str = "/v", + error_format: Optional[str] = None, ) -> RouteWrapper: """ Add an API URL under the **DELETE** *HTTP* method @@ -502,6 +526,7 @@ class RouteMixin: name=name, ignore_body=ignore_body, version_prefix=version_prefix, + error_format=error_format, ) def websocket( @@ -514,6 +539,7 @@ class RouteMixin: name: Optional[str] = None, apply: bool = True, version_prefix: str = "/v", + error_format: Optional[str] = None, ): """ Decorate a function to be registered as a websocket route @@ -540,6 +566,7 @@ class RouteMixin: subprotocols=subprotocols, websocket=True, version_prefix=version_prefix, + error_format=error_format, ) def add_websocket_route( @@ -552,6 +579,7 @@ class RouteMixin: version: Optional[int] = None, name: Optional[str] = None, version_prefix: str = "/v", + error_format: Optional[str] = None, ): """ A helper method to register a function as a websocket route. @@ -580,6 +608,7 @@ class RouteMixin: version=version, name=name, version_prefix=version_prefix, + error_format=error_format, )(handler) def static( @@ -789,8 +818,8 @@ class RouteMixin: ) except Exception: error_logger.exception( - f"Exception in static request handler:\ - path={file_or_directory}, " + f"Exception in static request handler: " + f"path={file_or_directory}, " f"relative_url={__file_uri__}" ) raise @@ -888,3 +917,43 @@ class RouteMixin: )(_handler) return route + + def _determine_error_format(self, handler) -> str: + if not isinstance(handler, CompositionView): + try: + src = dedent(getsource(handler)) + tree = parse(src) + http_response_types = self._get_response_types(tree) + + if len(http_response_types) == 1: + return next(iter(http_response_types)) + except OSError: + ... + + return "auto" + + def _get_response_types(self, node): + types = set() + + class HttpResponseVisitor(NodeVisitor): + def visit_Return(self, node: Return) -> Any: + nonlocal types + + try: + checks = [node.value.func.id] # type: ignore + if node.value.keywords: # type: ignore + checks += [ + k.value + for k in node.value.keywords # type: ignore + if k.arg == "content_type" + ] + + for check in checks: + if check in RESPONSE_MAPPING: + types.add(RESPONSE_MAPPING[check]) + except AttributeError: + ... + + HttpResponseVisitor().visit(node) + + return types diff --git a/sanic/models/futures.py b/sanic/models/futures.py index 8fd21655..fe7d77eb 100644 --- a/sanic/models/futures.py +++ b/sanic/models/futures.py @@ -24,6 +24,7 @@ class FutureRoute(NamedTuple): unquote: bool static: bool version_prefix: str + error_format: Optional[str] class FutureListener(NamedTuple): diff --git a/sanic/request.py b/sanic/request.py index 7f94de2d..c744e3c3 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -34,7 +34,7 @@ from sanic.compat import CancelledErrors, Header from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE from sanic.exceptions import InvalidUsage from sanic.headers import ( - Accept, + AcceptContainer, Options, parse_accept, parse_content_header, @@ -139,7 +139,7 @@ class Request: self.conn_info: Optional[ConnInfo] = None self.ctx = SimpleNamespace() self.parsed_forwarded: Optional[Options] = None - self.parsed_accept: Optional[List[Accept]] = None + self.parsed_accept: Optional[AcceptContainer] = None self.parsed_json = None self.parsed_form = None self.parsed_files = None @@ -301,7 +301,7 @@ class Request: return self.parsed_json @property - def accept(self) -> List[Accept]: + def accept(self) -> AcceptContainer: if self.parsed_accept is None: accept_header = self.headers.getone("accept", "") self.parsed_accept = parse_accept(accept_header) diff --git a/sanic/router.py b/sanic/router.py index 2661acf5..6995ed6d 100644 --- a/sanic/router.py +++ b/sanic/router.py @@ -13,6 +13,7 @@ from sanic_routing.exceptions import ( from sanic_routing.route import Route # type: ignore from sanic.constants import HTTP_METHODS +from sanic.errorpages import check_error_format from sanic.exceptions import MethodNotSupported, NotFound, SanicException from sanic.models.handler_types import RouteHandler @@ -78,6 +79,7 @@ class Router(BaseRouter): unquote: bool = False, static: bool = False, version_prefix: str = "/v", + error_format: Optional[str] = None, ) -> Union[Route, List[Route]]: """ Add a handler to the router @@ -137,6 +139,11 @@ class Router(BaseRouter): route.ctx.stream = stream route.ctx.hosts = hosts route.ctx.static = static + route.ctx.error_format = ( + error_format or self.ctx.app.config.FALLBACK_ERROR_FORMAT + ) + + check_error_format(route.ctx.error_format) routes.append(route) diff --git a/tests/conftest.py b/tests/conftest.py index d24066c5..175e967e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -109,6 +109,7 @@ def sanic_router(app): # noinspection PyProtectedMember def _setup(route_details: tuple) -> Tuple[Router, tuple]: router = Router() + router.ctx.app = app added_router = [] for method, route in route_details: try: diff --git a/tests/test_bad_request.py b/tests/test_bad_request.py index 140fbe8a..7a87d919 100644 --- a/tests/test_bad_request.py +++ b/tests/test_bad_request.py @@ -20,4 +20,4 @@ def test_bad_request_response(app): app.run(host="127.0.0.1", port=42101, debug=False) assert lines[0] == b"HTTP/1.1 400 Bad Request\r\n" - assert b"Bad Request" in lines[-1] + assert b"Bad Request" in lines[-2] diff --git a/tests/test_blueprint_group.py b/tests/test_blueprint_group.py index 45e968b5..09729c15 100644 --- a/tests/test_blueprint_group.py +++ b/tests/test_blueprint_group.py @@ -3,7 +3,12 @@ from pytest import raises from sanic.app import Sanic from sanic.blueprint_group import BlueprintGroup from sanic.blueprints import Blueprint -from sanic.exceptions import Forbidden, InvalidUsage, SanicException, ServerError +from sanic.exceptions import ( + Forbidden, + InvalidUsage, + SanicException, + ServerError, +) from sanic.request import Request from sanic.response import HTTPResponse, text diff --git a/tests/test_errorpages.py b/tests/test_errorpages.py index 495c764f..5af4ca5f 100644 --- a/tests/test_errorpages.py +++ b/tests/test_errorpages.py @@ -1,10 +1,10 @@ import pytest from sanic import Sanic -from sanic.errorpages import exception_response -from sanic.exceptions import NotFound +from sanic.errorpages import HTMLRenderer, exception_response +from sanic.exceptions import NotFound, SanicException from sanic.request import Request -from sanic.response import HTTPResponse +from sanic.response import HTTPResponse, html, json, text @pytest.fixture @@ -20,7 +20,7 @@ def app(): @pytest.fixture def fake_request(app): - return Request(b"/foobar", {}, "1.1", "GET", None, app) + return Request(b"/foobar", {"accept": "*/*"}, "1.1", "GET", None, app) @pytest.mark.parametrize( @@ -47,7 +47,13 @@ def test_should_return_html_valid_setting( try: raise exception("bad stuff") except Exception as e: - response = exception_response(fake_request, e, True) + response = exception_response( + fake_request, + e, + True, + base=HTMLRenderer, + fallback=fake_request.app.config.FALLBACK_ERROR_FORMAT, + ) assert isinstance(response, HTTPResponse) assert response.status == status @@ -74,13 +80,194 @@ def test_auto_fallback_with_content_type(app): app.config.FALLBACK_ERROR_FORMAT = "auto" _, response = app.test_client.get( - "/error", headers={"content-type": "application/json"} + "/error", headers={"content-type": "application/json", "accept": "*/*"} ) assert response.status == 500 assert response.content_type == "application/json" _, response = app.test_client.get( - "/error", headers={"content-type": "text/plain"} + "/error", headers={"content-type": "foo/bar", "accept": "*/*"} + ) + assert response.status == 500 + assert response.content_type == "text/html; charset=utf-8" + + +def test_route_error_format_set_on_auto(app): + @app.get("/text") + def text_response(request): + return text(request.route.ctx.error_format) + + @app.get("/json") + def json_response(request): + return json({"format": request.route.ctx.error_format}) + + @app.get("/html") + def html_response(request): + return html(request.route.ctx.error_format) + + _, response = app.test_client.get("/text") + assert response.text == "text" + + _, response = app.test_client.get("/json") + assert response.json["format"] == "json" + + _, response = app.test_client.get("/html") + assert response.text == "html" + + +def test_route_error_response_from_auto_route(app): + @app.get("/text") + def text_response(request): + raise Exception("oops") + return text("Never gonna see this") + + @app.get("/json") + def json_response(request): + raise Exception("oops") + return json({"message": "Never gonna see this"}) + + @app.get("/html") + def html_response(request): + raise Exception("oops") + return html("

Never gonna see this

") + + _, response = app.test_client.get("/text") + assert response.content_type == "text/plain; charset=utf-8" + + _, response = app.test_client.get("/json") + assert response.content_type == "application/json" + + _, response = app.test_client.get("/html") + assert response.content_type == "text/html; charset=utf-8" + + +def test_route_error_response_from_explicit_format(app): + @app.get("/text", error_format="json") + def text_response(request): + raise Exception("oops") + return text("Never gonna see this") + + @app.get("/json", error_format="text") + def json_response(request): + raise Exception("oops") + return json({"message": "Never gonna see this"}) + + _, response = app.test_client.get("/text") + assert response.content_type == "application/json" + + _, response = app.test_client.get("/json") + assert response.content_type == "text/plain; charset=utf-8" + + +def test_unknown_fallback_format(app): + with pytest.raises(SanicException, match="Unknown format: bad"): + app.config.FALLBACK_ERROR_FORMAT = "bad" + + +def test_route_error_format_unknown(app): + with pytest.raises(SanicException, match="Unknown format: bad"): + + @app.get("/text", error_format="bad") + def handler(request): + ... + + +def test_fallback_with_content_type_mismatch_accept(app): + app.config.FALLBACK_ERROR_FORMAT = "auto" + + _, response = app.test_client.get( + "/error", + headers={"content-type": "application/json", "accept": "text/plain"}, ) assert response.status == 500 assert response.content_type == "text/plain; charset=utf-8" + + _, response = app.test_client.get( + "/error", + headers={"content-type": "text/plain", "accept": "foo/bar"}, + ) + assert response.status == 500 + assert response.content_type == "text/html; charset=utf-8" + + app.router.reset() + + @app.route("/alt1") + @app.route("/alt2", error_format="text") + @app.route("/alt3", error_format="html") + def handler(_): + raise Exception("problem here") + # Yes, we know this return value is unreachable. This is on purpose. + return json({}) + + app.router.finalize() + + _, response = app.test_client.get( + "/alt1", + headers={"accept": "foo/bar"}, + ) + assert response.status == 500 + assert response.content_type == "text/html; charset=utf-8" + _, response = app.test_client.get( + "/alt1", + headers={"accept": "foo/bar,*/*"}, + ) + assert response.status == 500 + assert response.content_type == "application/json" + + _, response = app.test_client.get( + "/alt2", + headers={"accept": "foo/bar"}, + ) + assert response.status == 500 + assert response.content_type == "text/html; charset=utf-8" + _, response = app.test_client.get( + "/alt2", + headers={"accept": "foo/bar,*/*"}, + ) + assert response.status == 500 + assert response.content_type == "text/plain; charset=utf-8" + + _, response = app.test_client.get( + "/alt3", + headers={"accept": "foo/bar"}, + ) + assert response.status == 500 + assert response.content_type == "text/html; charset=utf-8" + + +@pytest.mark.parametrize( + "accept,content_type,expected", + ( + (None, None, "text/plain; charset=utf-8"), + ("foo/bar", None, "text/html; charset=utf-8"), + ("application/json", None, "application/json"), + ("application/json,text/plain", None, "application/json"), + ("text/plain,application/json", None, "application/json"), + ("text/plain,foo/bar", None, "text/plain; charset=utf-8"), + # Following test is valid after v22.3 + # ("text/plain,text/html", None, "text/plain; charset=utf-8"), + ("*/*", "foo/bar", "text/html; charset=utf-8"), + ("*/*", "application/json", "application/json"), + ), +) +def test_combinations_for_auto(fake_request, accept, content_type, expected): + if accept: + fake_request.headers["accept"] = accept + else: + del fake_request.headers["accept"] + + if content_type: + fake_request.headers["content-type"] = content_type + + try: + raise Exception("bad stuff") + except Exception as e: + response = exception_response( + fake_request, + e, + True, + base=HTMLRenderer, + fallback="auto", + ) + + assert response.content_type == expected diff --git a/tests/test_exceptions_handler.py b/tests/test_exceptions_handler.py index ba14102a..dbf9fcbb 100644 --- a/tests/test_exceptions_handler.py +++ b/tests/test_exceptions_handler.py @@ -1,5 +1,7 @@ import asyncio +import pytest + from bs4 import BeautifulSoup from sanic import Sanic @@ -8,9 +10,6 @@ from sanic.handlers import ErrorHandler from sanic.response import stream, text -exception_handler_app = Sanic("test_exception_handler") - - async def sample_streaming_fn(response): await response.write("foo,") await asyncio.sleep(0.001) @@ -21,107 +20,102 @@ class ErrorWithRequestCtx(ServerError): pass -@exception_handler_app.route("/1") -def handler_1(request): - raise InvalidUsage("OK") +@pytest.fixture +def exception_handler_app(): + exception_handler_app = Sanic("test_exception_handler") + + @exception_handler_app.route("/1", error_format="html") + def handler_1(request): + raise InvalidUsage("OK") + + @exception_handler_app.route("/2", error_format="html") + def handler_2(request): + raise ServerError("OK") + + @exception_handler_app.route("/3", error_format="html") + def handler_3(request): + raise NotFound("OK") + + @exception_handler_app.route("/4", error_format="html") + def handler_4(request): + foo = bar # noqa -- F821 + return text(foo) + + @exception_handler_app.route("/5", error_format="html") + def handler_5(request): + class CustomServerError(ServerError): + pass + + raise CustomServerError("Custom server error") + + @exception_handler_app.route("/6/", error_format="html") + def handler_6(request, arg): + try: + foo = 1 / arg + except Exception as e: + raise e from ValueError(f"{arg}") + return text(foo) + + @exception_handler_app.route("/7", error_format="html") + def handler_7(request): + raise Forbidden("go away!") + + @exception_handler_app.route("/8", error_format="html") + def handler_8(request): + + raise ErrorWithRequestCtx("OK") + + @exception_handler_app.exception(ErrorWithRequestCtx, NotFound) + def handler_exception_with_ctx(request, exception): + return text(request.ctx.middleware_ran) + + @exception_handler_app.exception(ServerError) + def handler_exception(request, exception): + return text("OK") + + @exception_handler_app.exception(Forbidden) + async def async_handler_exception(request, exception): + return stream( + sample_streaming_fn, + content_type="text/csv", + ) + + @exception_handler_app.middleware + async def some_request_middleware(request): + request.ctx.middleware_ran = "Done." + + return exception_handler_app -@exception_handler_app.route("/2") -def handler_2(request): - raise ServerError("OK") - - -@exception_handler_app.route("/3") -def handler_3(request): - raise NotFound("OK") - - -@exception_handler_app.route("/4") -def handler_4(request): - foo = bar # noqa -- F821 undefined name 'bar' is done to throw exception - return text(foo) - - -@exception_handler_app.route("/5") -def handler_5(request): - class CustomServerError(ServerError): - pass - - raise CustomServerError("Custom server error") - - -@exception_handler_app.route("/6/") -def handler_6(request, arg): - try: - foo = 1 / arg - except Exception as e: - raise e from ValueError(f"{arg}") - return text(foo) - - -@exception_handler_app.route("/7") -def handler_7(request): - raise Forbidden("go away!") - - -@exception_handler_app.route("/8") -def handler_8(request): - - raise ErrorWithRequestCtx("OK") - - -@exception_handler_app.exception(ErrorWithRequestCtx, NotFound) -def handler_exception_with_ctx(request, exception): - return text(request.ctx.middleware_ran) - - -@exception_handler_app.exception(ServerError) -def handler_exception(request, exception): - return text("OK") - - -@exception_handler_app.exception(Forbidden) -async def async_handler_exception(request, exception): - return stream( - sample_streaming_fn, - content_type="text/csv", - ) - - -@exception_handler_app.middleware -async def some_request_middleware(request): - request.ctx.middleware_ran = "Done." - - -def test_invalid_usage_exception_handler(): +def test_invalid_usage_exception_handler(exception_handler_app): request, response = exception_handler_app.test_client.get("/1") assert response.status == 400 -def test_server_error_exception_handler(): +def test_server_error_exception_handler(exception_handler_app): request, response = exception_handler_app.test_client.get("/2") assert response.status == 200 assert response.text == "OK" -def test_not_found_exception_handler(): +def test_not_found_exception_handler(exception_handler_app): request, response = exception_handler_app.test_client.get("/3") assert response.status == 200 -def test_text_exception__handler(): +def test_text_exception__handler(exception_handler_app): request, response = exception_handler_app.test_client.get("/random") assert response.status == 200 assert response.text == "Done." -def test_async_exception_handler(): +def test_async_exception_handler(exception_handler_app): request, response = exception_handler_app.test_client.get("/7") assert response.status == 200 assert response.text == "foo,bar" -def test_html_traceback_output_in_debug_mode(): +def test_html_traceback_output_in_debug_mode(exception_handler_app): request, response = exception_handler_app.test_client.get("/4", debug=True) assert response.status == 500 soup = BeautifulSoup(response.body, "html.parser") @@ -136,12 +130,12 @@ def test_html_traceback_output_in_debug_mode(): ) == summary_text -def test_inherited_exception_handler(): +def test_inherited_exception_handler(exception_handler_app): request, response = exception_handler_app.test_client.get("/5") assert response.status == 200 -def test_chained_exception_handler(): +def test_chained_exception_handler(exception_handler_app): request, response = exception_handler_app.test_client.get( "/6/0", debug=True ) @@ -153,7 +147,6 @@ def test_chained_exception_handler(): assert "handler_6" in html assert "foo = 1 / arg" in html assert "ValueError" in html - assert "The above exception was the direct cause" in html summary_text = " ".join(soup.select(".summary")[0].text.split()) assert ( @@ -161,7 +154,7 @@ def test_chained_exception_handler(): ) == summary_text -def test_exception_handler_lookup(): +def test_exception_handler_lookup(exception_handler_app): class CustomError(Exception): pass @@ -184,7 +177,7 @@ def test_exception_handler_lookup(): class ModuleNotFoundError(ImportError): pass - handler = ErrorHandler() + handler = ErrorHandler("auto") handler.add(ImportError, import_error_handler) handler.add(CustomError, custom_error_handler) handler.add(ServerError, server_error_handler) @@ -209,7 +202,7 @@ def test_exception_handler_lookup(): ) -def test_exception_handler_processed_request_middleware(): +def test_exception_handler_processed_request_middleware(exception_handler_app): request, response = exception_handler_app.test_client.get("/8") assert response.status == 200 assert response.text == "Done." diff --git a/tests/test_headers.py b/tests/test_headers.py index 928847e0..115bed86 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -5,6 +5,7 @@ import pytest from sanic import headers, text from sanic.exceptions import InvalidHeader, PayloadTooLarge from sanic.http import Http +from sanic.request import Request @pytest.fixture @@ -338,3 +339,31 @@ def test_value_in_accept(value): assert "foo/bar" in acceptable assert "foo/*" in acceptable assert "*/*" in acceptable + + +@pytest.mark.parametrize("value", ("foo/bar", "foo/*")) +def test_value_not_in_accept(value): + acceptable = headers.parse_accept(value) + assert "no/match" not in acceptable + assert "no/*" not in acceptable + + +@pytest.mark.parametrize( + "header,expected", + ( + ( + "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", # noqa: E501 + [ + "text/html", + "application/xhtml+xml", + "image/avif", + "image/webp", + "application/xml;q=0.9", + "*/*;q=0.8", + ], + ), + ), +) +def test_browser_headers(header, expected): + request = Request(b"/", {"accept": header}, "1.1", "GET", None, None) + assert request.accept == expected