From af1d289a457f811a255b743e02b631dd922c3223 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Wed, 17 Nov 2021 09:05:24 +0200 Subject: [PATCH] Add error format from config replacement objects --- sanic/app.py | 4 +++- sanic/handlers.py | 9 ++++++++- tests/test_errorpages.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index d24eeca0..0be52e32 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -1472,7 +1472,9 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): async def _startup(self): self.signalize() self.finalize() - ErrorHandler.finalize(self.error_handler) + ErrorHandler.finalize( + self.error_handler, fallback=self.config.FALLBACK_ERROR_FORMAT + ) TouchUp.run(self) async def _server_event( diff --git a/sanic/handlers.py b/sanic/handlers.py index fd718f06..35deff0f 100644 --- a/sanic/handlers.py +++ b/sanic/handlers.py @@ -38,7 +38,14 @@ class ErrorHandler: self.base = base @classmethod - def finalize(cls, error_handler): + def finalize(cls, error_handler, fallback: Optional[str] = None): + if ( + fallback + and fallback != "auto" + and error_handler.fallback == "auto" + ): + error_handler.fallback = fallback + if not isinstance(error_handler, cls): error_logger.warning( f"Error handler is non-conforming: {type(error_handler)}" diff --git a/tests/test_errorpages.py b/tests/test_errorpages.py index 84949fde..1843f6a7 100644 --- a/tests/test_errorpages.py +++ b/tests/test_errorpages.py @@ -1,6 +1,7 @@ import pytest from sanic import Sanic +from sanic.config import Config from sanic.errorpages import HTMLRenderer, exception_response from sanic.exceptions import NotFound, SanicException from sanic.handlers import ErrorHandler @@ -313,3 +314,31 @@ def test_setting_fallback_to_non_default_raise_warning(app): app.config.FALLBACK_ERROR_FORMAT = "json" assert app.error_handler.fallback == "json" + + +def test_allow_fallback_error_format_in_config_injection(): + class MyConfig(Config): + FALLBACK_ERROR_FORMAT = "text" + + app = Sanic("test", config=MyConfig()) + + @app.route("/error", methods=["GET", "POST"]) + def err(request): + raise Exception("something went wrong") + + request, response = app.test_client.get("/error") + assert request.app.error_handler.fallback == "text" + assert response.status == 500 + assert response.content_type == "text/plain; charset=utf-8" + + +def test_allow_fallback_error_format_in_config_replacement(app): + class MyConfig(Config): + FALLBACK_ERROR_FORMAT = "text" + + app.config = MyConfig() + + request, response = app.test_client.get("/error") + assert request.app.error_handler.fallback == "text" + assert response.status == 500 + assert response.content_type == "text/plain; charset=utf-8"