From cde02b5936838e7a1574ba094e44d987176848d9 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Tue, 16 Nov 2021 13:07:33 +0200 Subject: [PATCH] More consistent config setting with post-FALLBACK_ERROR_FORMAT apply (#2310) * Update unit testing and add more consistent config * Change init and app values to private * Cleanup line lengths --- sanic/app.py | 8 ++--- sanic/config.py | 70 +++++++++++++++++++++++++++++----------- sanic/errorpages.py | 4 ++- sanic/mixins/routes.py | 4 +-- sanic/router.py | 7 ++-- tests/test_config.py | 30 ++++++++++++++++- tests/test_errorpages.py | 42 ++++++++++++++++++++++++ 7 files changed, 135 insertions(+), 30 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index fb4ed4eb..30af974a 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -190,14 +190,14 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): self._state: ApplicationState = ApplicationState(app=self) self.blueprints: Dict[str, Blueprint] = {} self.config: Config = config or Config( - load_env=load_env, env_prefix=env_prefix + load_env=load_env, + env_prefix=env_prefix, + app=self, ) self.configure_logging: bool = configure_logging self.ctx: Any = ctx or SimpleNamespace() self.debug = False - self.error_handler: ErrorHandler = error_handler or ErrorHandler( - fallback=self.config.FALLBACK_ERROR_FORMAT, - ) + self.error_handler: ErrorHandler = error_handler or ErrorHandler() self.listeners: Dict[str, List[ListenerType[Any]]] = defaultdict(list) self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {} self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {} diff --git a/sanic/config.py b/sanic/config.py index 496ceadb..ebe1a9a6 100644 --- a/sanic/config.py +++ b/sanic/config.py @@ -1,7 +1,9 @@ +from __future__ import annotations + from inspect import isclass from os import environ from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Union from warnings import warn from sanic.errorpages import check_error_format @@ -9,6 +11,10 @@ from sanic.http import Http from sanic.utils import load_module_from_file_location, str_to_bool +if TYPE_CHECKING: # no cov + from sanic import Sanic + + SANIC_PREFIX = "SANIC_" @@ -73,10 +79,13 @@ class Config(dict): load_env: Optional[Union[bool, str]] = True, env_prefix: Optional[str] = SANIC_PREFIX, keep_alive: Optional[bool] = None, + *, + app: Optional[Sanic] = None, ): defaults = defaults or {} super().__init__({**DEFAULT_CONFIG, **defaults}) + self._app = app self._LOGO = "" if keep_alive is not None: @@ -99,6 +108,7 @@ class Config(dict): self._configure_header_size() self._check_error_format() + self._init = True def __getattr__(self, attr): try: @@ -106,23 +116,47 @@ class Config(dict): except KeyError as ke: raise AttributeError(f"Config has no '{ke.args[0]}'") - def __setattr__(self, attr, value): - self[attr] = value - if attr in ( - "REQUEST_MAX_HEADER_SIZE", - "REQUEST_BUFFER_SIZE", - "REQUEST_MAX_SIZE", - ): - self._configure_header_size() - elif attr == "FALLBACK_ERROR_FORMAT": - self._check_error_format() - elif attr == "LOGO": - self._LOGO = value - warn( - "Setting the config.LOGO is deprecated and will no longer " - "be supported starting in v22.6.", - DeprecationWarning, - ) + def __setattr__(self, attr, value) -> None: + self.update({attr: value}) + + def __setitem__(self, attr, value) -> None: + self.update({attr: value}) + + def update(self, *other, **kwargs) -> None: + other_mapping = {k: v for item in other for k, v in dict(item).items()} + super().update(*other, **kwargs) + for attr, value in {**other_mapping, **kwargs}.items(): + self._post_set(attr, value) + + def _post_set(self, attr, value) -> None: + if self.get("_init"): + if attr in ( + "REQUEST_MAX_HEADER_SIZE", + "REQUEST_BUFFER_SIZE", + "REQUEST_MAX_SIZE", + ): + self._configure_header_size() + elif attr == "FALLBACK_ERROR_FORMAT": + self._check_error_format() + if self.app and value != self.app.error_handler.fallback: + if self.app.error_handler.fallback != "auto": + warn( + "Overriding non-default ErrorHandler fallback " + "value. Changing from " + f"{self.app.error_handler.fallback} to {value}." + ) + self.app.error_handler.fallback = value + elif attr == "LOGO": + self._LOGO = value + warn( + "Setting the config.LOGO is deprecated and will no longer " + "be supported starting in v22.6.", + DeprecationWarning, + ) + + @property + def app(self): + return self._app @property def LOGO(self): diff --git a/sanic/errorpages.py b/sanic/errorpages.py index 82cdd57a..d046c29d 100644 --- a/sanic/errorpages.py +++ b/sanic/errorpages.py @@ -383,6 +383,7 @@ def exception_response( """ content_type = None + print("exception_response", fallback) if not renderer: # Make sure we have something set renderer = base @@ -393,7 +394,8 @@ def exception_response( # from the route if request.route: try: - render_format = request.route.ctx.error_format + if request.route.ctx.error_format: + render_format = request.route.ctx.error_format except AttributeError: ... diff --git a/sanic/mixins/routes.py b/sanic/mixins/routes.py index 8467a2e3..7139cd3c 100644 --- a/sanic/mixins/routes.py +++ b/sanic/mixins/routes.py @@ -918,7 +918,7 @@ class RouteMixin: return route - def _determine_error_format(self, handler) -> str: + def _determine_error_format(self, handler) -> Optional[str]: if not isinstance(handler, CompositionView): try: src = dedent(getsource(handler)) @@ -930,7 +930,7 @@ class RouteMixin: except (OSError, TypeError): ... - return "auto" + return None def _get_response_types(self, node): types = set() diff --git a/sanic/router.py b/sanic/router.py index 6995ed6d..b15c2a3e 100644 --- a/sanic/router.py +++ b/sanic/router.py @@ -139,11 +139,10 @@ class Router(BaseRouter): route.ctx.stream = stream route.ctx.hosts = hosts route.ctx.static = static - route.ctx.error_format = ( - error_format or self.ctx.app.config.FALLBACK_ERROR_FORMAT - ) + route.ctx.error_format = error_format - check_error_format(route.ctx.error_format) + if error_format: + check_error_format(route.ctx.error_format) routes.append(route) diff --git a/tests/test_config.py b/tests/test_config.py index f3447666..67324f1e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,9 +1,9 @@ from contextlib import contextmanager -from email import message from os import environ from pathlib import Path from tempfile import TemporaryDirectory from textwrap import dedent +from unittest.mock import Mock import pytest @@ -360,3 +360,31 @@ def test_deprecation_notice_when_setting_logo(app): ) with pytest.warns(DeprecationWarning, match=message): app.config.LOGO = "My Custom Logo" + + +def test_config_set_methods(app, monkeypatch): + post_set = Mock() + monkeypatch.setattr(Config, "_post_set", post_set) + + app.config.FOO = 1 + post_set.assert_called_once_with("FOO", 1) + post_set.reset_mock() + + app.config["FOO"] = 2 + post_set.assert_called_once_with("FOO", 2) + post_set.reset_mock() + + app.config.update({"FOO": 3}) + post_set.assert_called_once_with("FOO", 3) + post_set.reset_mock() + + app.config.update([("FOO", 4)]) + post_set.assert_called_once_with("FOO", 4) + post_set.reset_mock() + + app.config.update(FOO=5) + post_set.assert_called_once_with("FOO", 5) + post_set.reset_mock() + + app.config.update_config({"FOO": 6}) + post_set.assert_called_once_with("FOO", 6) diff --git a/tests/test_errorpages.py b/tests/test_errorpages.py index 5af4ca5f..84949fde 100644 --- a/tests/test_errorpages.py +++ b/tests/test_errorpages.py @@ -3,6 +3,7 @@ import pytest from sanic import Sanic from sanic.errorpages import HTMLRenderer, exception_response from sanic.exceptions import NotFound, SanicException +from sanic.handlers import ErrorHandler from sanic.request import Request from sanic.response import HTTPResponse, html, json, text @@ -271,3 +272,44 @@ def test_combinations_for_auto(fake_request, accept, content_type, expected): ) assert response.content_type == expected + + +def test_allow_fallback_error_format_set_main_process_start(app): + @app.main_process_start + async def start(app, _): + app.config.FALLBACK_ERROR_FORMAT = "text" + + request, response = app.test_client.get("/error") + assert request.app.error_handler.fallback == "text" + assert response.status == 500 + assert response.content_type == "text/plain; charset=utf-8" + + +def test_setting_fallback_to_non_default_raise_warning(app): + app.error_handler = ErrorHandler(fallback="text") + + assert app.error_handler.fallback == "text" + + with pytest.warns( + UserWarning, + match=( + "Overriding non-default ErrorHandler fallback value. " + "Changing from text to auto." + ), + ): + app.config.FALLBACK_ERROR_FORMAT = "auto" + + assert app.error_handler.fallback == "auto" + + app.config.FALLBACK_ERROR_FORMAT = "text" + + with pytest.warns( + UserWarning, + match=( + "Overriding non-default ErrorHandler fallback value. " + "Changing from text to json." + ), + ): + app.config.FALLBACK_ERROR_FORMAT = "json" + + assert app.error_handler.fallback == "json"