More consistent config setting with post-FALLBACK_ERROR_FORMAT apply (#2310)

* Update unit testing and add more consistent config

* Change init and app values to private

* Cleanup line lengths
This commit is contained in:
Adam Hopkins 2021-11-16 13:07:33 +02:00 committed by GitHub
parent abeb8d0bc0
commit cde02b5936
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 135 additions and 30 deletions

View File

@ -190,14 +190,14 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
self._state: ApplicationState = ApplicationState(app=self) self._state: ApplicationState = ApplicationState(app=self)
self.blueprints: Dict[str, Blueprint] = {} self.blueprints: Dict[str, Blueprint] = {}
self.config: Config = config or Config( self.config: Config = config or Config(
load_env=load_env, env_prefix=env_prefix load_env=load_env,
env_prefix=env_prefix,
app=self,
) )
self.configure_logging: bool = configure_logging self.configure_logging: bool = configure_logging
self.ctx: Any = ctx or SimpleNamespace() self.ctx: Any = ctx or SimpleNamespace()
self.debug = False self.debug = False
self.error_handler: ErrorHandler = error_handler or ErrorHandler( self.error_handler: ErrorHandler = error_handler or ErrorHandler()
fallback=self.config.FALLBACK_ERROR_FORMAT,
)
self.listeners: Dict[str, List[ListenerType[Any]]] = defaultdict(list) self.listeners: Dict[str, List[ListenerType[Any]]] = defaultdict(list)
self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {} self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {}
self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {} self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {}

View File

@ -1,7 +1,9 @@
from __future__ import annotations
from inspect import isclass from inspect import isclass
from os import environ from os import environ
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional, Union from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from warnings import warn from warnings import warn
from sanic.errorpages import check_error_format from sanic.errorpages import check_error_format
@ -9,6 +11,10 @@ from sanic.http import Http
from sanic.utils import load_module_from_file_location, str_to_bool from sanic.utils import load_module_from_file_location, str_to_bool
if TYPE_CHECKING: # no cov
from sanic import Sanic
SANIC_PREFIX = "SANIC_" SANIC_PREFIX = "SANIC_"
@ -73,10 +79,13 @@ class Config(dict):
load_env: Optional[Union[bool, str]] = True, load_env: Optional[Union[bool, str]] = True,
env_prefix: Optional[str] = SANIC_PREFIX, env_prefix: Optional[str] = SANIC_PREFIX,
keep_alive: Optional[bool] = None, keep_alive: Optional[bool] = None,
*,
app: Optional[Sanic] = None,
): ):
defaults = defaults or {} defaults = defaults or {}
super().__init__({**DEFAULT_CONFIG, **defaults}) super().__init__({**DEFAULT_CONFIG, **defaults})
self._app = app
self._LOGO = "" self._LOGO = ""
if keep_alive is not None: if keep_alive is not None:
@ -99,6 +108,7 @@ class Config(dict):
self._configure_header_size() self._configure_header_size()
self._check_error_format() self._check_error_format()
self._init = True
def __getattr__(self, attr): def __getattr__(self, attr):
try: try:
@ -106,8 +116,20 @@ class Config(dict):
except KeyError as ke: except KeyError as ke:
raise AttributeError(f"Config has no '{ke.args[0]}'") raise AttributeError(f"Config has no '{ke.args[0]}'")
def __setattr__(self, attr, value): def __setattr__(self, attr, value) -> None:
self[attr] = value self.update({attr: value})
def __setitem__(self, attr, value) -> None:
self.update({attr: value})
def update(self, *other, **kwargs) -> None:
other_mapping = {k: v for item in other for k, v in dict(item).items()}
super().update(*other, **kwargs)
for attr, value in {**other_mapping, **kwargs}.items():
self._post_set(attr, value)
def _post_set(self, attr, value) -> None:
if self.get("_init"):
if attr in ( if attr in (
"REQUEST_MAX_HEADER_SIZE", "REQUEST_MAX_HEADER_SIZE",
"REQUEST_BUFFER_SIZE", "REQUEST_BUFFER_SIZE",
@ -116,6 +138,14 @@ class Config(dict):
self._configure_header_size() self._configure_header_size()
elif attr == "FALLBACK_ERROR_FORMAT": elif attr == "FALLBACK_ERROR_FORMAT":
self._check_error_format() self._check_error_format()
if self.app and value != self.app.error_handler.fallback:
if self.app.error_handler.fallback != "auto":
warn(
"Overriding non-default ErrorHandler fallback "
"value. Changing from "
f"{self.app.error_handler.fallback} to {value}."
)
self.app.error_handler.fallback = value
elif attr == "LOGO": elif attr == "LOGO":
self._LOGO = value self._LOGO = value
warn( warn(
@ -124,6 +154,10 @@ class Config(dict):
DeprecationWarning, DeprecationWarning,
) )
@property
def app(self):
return self._app
@property @property
def LOGO(self): def LOGO(self):
return self._LOGO return self._LOGO

View File

@ -383,6 +383,7 @@ def exception_response(
""" """
content_type = None content_type = None
print("exception_response", fallback)
if not renderer: if not renderer:
# Make sure we have something set # Make sure we have something set
renderer = base renderer = base
@ -393,6 +394,7 @@ def exception_response(
# from the route # from the route
if request.route: if request.route:
try: try:
if request.route.ctx.error_format:
render_format = request.route.ctx.error_format render_format = request.route.ctx.error_format
except AttributeError: except AttributeError:
... ...

View File

@ -918,7 +918,7 @@ class RouteMixin:
return route return route
def _determine_error_format(self, handler) -> str: def _determine_error_format(self, handler) -> Optional[str]:
if not isinstance(handler, CompositionView): if not isinstance(handler, CompositionView):
try: try:
src = dedent(getsource(handler)) src = dedent(getsource(handler))
@ -930,7 +930,7 @@ class RouteMixin:
except (OSError, TypeError): except (OSError, TypeError):
... ...
return "auto" return None
def _get_response_types(self, node): def _get_response_types(self, node):
types = set() types = set()

View File

@ -139,10 +139,9 @@ class Router(BaseRouter):
route.ctx.stream = stream route.ctx.stream = stream
route.ctx.hosts = hosts route.ctx.hosts = hosts
route.ctx.static = static route.ctx.static = static
route.ctx.error_format = ( route.ctx.error_format = error_format
error_format or self.ctx.app.config.FALLBACK_ERROR_FORMAT
)
if error_format:
check_error_format(route.ctx.error_format) check_error_format(route.ctx.error_format)
routes.append(route) routes.append(route)

View File

@ -1,9 +1,9 @@
from contextlib import contextmanager from contextlib import contextmanager
from email import message
from os import environ from os import environ
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from textwrap import dedent from textwrap import dedent
from unittest.mock import Mock
import pytest import pytest
@ -360,3 +360,31 @@ def test_deprecation_notice_when_setting_logo(app):
) )
with pytest.warns(DeprecationWarning, match=message): with pytest.warns(DeprecationWarning, match=message):
app.config.LOGO = "My Custom Logo" app.config.LOGO = "My Custom Logo"
def test_config_set_methods(app, monkeypatch):
post_set = Mock()
monkeypatch.setattr(Config, "_post_set", post_set)
app.config.FOO = 1
post_set.assert_called_once_with("FOO", 1)
post_set.reset_mock()
app.config["FOO"] = 2
post_set.assert_called_once_with("FOO", 2)
post_set.reset_mock()
app.config.update({"FOO": 3})
post_set.assert_called_once_with("FOO", 3)
post_set.reset_mock()
app.config.update([("FOO", 4)])
post_set.assert_called_once_with("FOO", 4)
post_set.reset_mock()
app.config.update(FOO=5)
post_set.assert_called_once_with("FOO", 5)
post_set.reset_mock()
app.config.update_config({"FOO": 6})
post_set.assert_called_once_with("FOO", 6)

View File

@ -3,6 +3,7 @@ import pytest
from sanic import Sanic from sanic import Sanic
from sanic.errorpages import HTMLRenderer, exception_response from sanic.errorpages import HTMLRenderer, exception_response
from sanic.exceptions import NotFound, SanicException from sanic.exceptions import NotFound, SanicException
from sanic.handlers import ErrorHandler
from sanic.request import Request from sanic.request import Request
from sanic.response import HTTPResponse, html, json, text from sanic.response import HTTPResponse, html, json, text
@ -271,3 +272,44 @@ def test_combinations_for_auto(fake_request, accept, content_type, expected):
) )
assert response.content_type == expected assert response.content_type == expected
def test_allow_fallback_error_format_set_main_process_start(app):
@app.main_process_start
async def start(app, _):
app.config.FALLBACK_ERROR_FORMAT = "text"
request, response = app.test_client.get("/error")
assert request.app.error_handler.fallback == "text"
assert response.status == 500
assert response.content_type == "text/plain; charset=utf-8"
def test_setting_fallback_to_non_default_raise_warning(app):
app.error_handler = ErrorHandler(fallback="text")
assert app.error_handler.fallback == "text"
with pytest.warns(
UserWarning,
match=(
"Overriding non-default ErrorHandler fallback value. "
"Changing from text to auto."
),
):
app.config.FALLBACK_ERROR_FORMAT = "auto"
assert app.error_handler.fallback == "auto"
app.config.FALLBACK_ERROR_FORMAT = "text"
with pytest.warns(
UserWarning,
match=(
"Overriding non-default ErrorHandler fallback value. "
"Changing from text to json."
),
):
app.config.FALLBACK_ERROR_FORMAT = "json"
assert app.error_handler.fallback == "json"