From 9a9f72ad64e919a3bc9cff6a81f2fbccea73f97a Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Sun, 14 Nov 2021 23:21:14 +0200 Subject: [PATCH 1/8] Move builtin signals to enum (#2309) * Move builtin signals to enum * Fix annotations --- sanic/blueprints.py | 3 ++- sanic/mixins/signals.py | 11 ++++---- sanic/signals.py | 57 +++++++++++++++++++++++++++-------------- tests/test_signals.py | 20 +++++++++++++++ 4 files changed, 65 insertions(+), 26 deletions(-) diff --git a/sanic/blueprints.py b/sanic/blueprints.py index e5e1d333..e13cafcd 100644 --- a/sanic/blueprints.py +++ b/sanic/blueprints.py @@ -4,6 +4,7 @@ import asyncio from collections import defaultdict from copy import deepcopy +from enum import Enum from types import SimpleNamespace from typing import ( TYPE_CHECKING, @@ -144,7 +145,7 @@ class Blueprint(BaseSanic): kwargs["apply"] = False return super().exception(*args, **kwargs) - def signal(self, event: str, *args, **kwargs): + def signal(self, event: Union[str, Enum], *args, **kwargs): kwargs["apply"] = False return super().signal(event, *args, **kwargs) diff --git a/sanic/mixins/signals.py b/sanic/mixins/signals.py index 2be9fee2..57b01b46 100644 --- a/sanic/mixins/signals.py +++ b/sanic/mixins/signals.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Dict, Optional, Set +from enum import Enum +from typing import Any, Callable, Dict, Optional, Set, Union from sanic.models.futures import FutureSignal from sanic.models.handler_types import SignalHandler @@ -19,7 +20,7 @@ class SignalMixin: def signal( self, - event: str, + event: Union[str, Enum], *, apply: bool = True, condition: Dict[str, Any] = None, @@ -41,13 +42,11 @@ class SignalMixin: filtering, defaults to None :type condition: Dict[str, Any], optional """ + event_value = str(event.value) if isinstance(event, Enum) else event def decorator(handler: SignalHandler): - nonlocal event - nonlocal apply - future_signal = FutureSignal( - handler, event, HashableDict(condition or {}) + handler, event_value, HashableDict(condition or {}) ) self._future_signals.add(future_signal) diff --git a/sanic/signals.py b/sanic/signals.py index 9da7eccd..7bb510fa 100644 --- a/sanic/signals.py +++ b/sanic/signals.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio +from enum import Enum from inspect import isawaitable from typing import Any, Dict, List, Optional, Tuple, Union @@ -14,29 +15,47 @@ from sanic.log import error_logger, logger from sanic.models.handler_types import SignalHandler +class Event(Enum): + SERVER_INIT_AFTER = "server.init.after" + SERVER_INIT_BEFORE = "server.init.before" + SERVER_SHUTDOWN_AFTER = "server.shutdown.after" + SERVER_SHUTDOWN_BEFORE = "server.shutdown.before" + HTTP_LIFECYCLE_BEGIN = "http.lifecycle.begin" + HTTP_LIFECYCLE_COMPLETE = "http.lifecycle.complete" + HTTP_LIFECYCLE_EXCEPTION = "http.lifecycle.exception" + HTTP_LIFECYCLE_HANDLE = "http.lifecycle.handle" + HTTP_LIFECYCLE_READ_BODY = "http.lifecycle.read_body" + HTTP_LIFECYCLE_READ_HEAD = "http.lifecycle.read_head" + HTTP_LIFECYCLE_REQUEST = "http.lifecycle.request" + HTTP_LIFECYCLE_RESPONSE = "http.lifecycle.response" + HTTP_ROUTING_AFTER = "http.routing.after" + HTTP_ROUTING_BEFORE = "http.routing.before" + HTTP_LIFECYCLE_SEND = "http.lifecycle.send" + HTTP_MIDDLEWARE_AFTER = "http.middleware.after" + HTTP_MIDDLEWARE_BEFORE = "http.middleware.before" + + RESERVED_NAMESPACES = { "server": ( - # "server.main.start", - # "server.main.stop", - "server.init.before", - "server.init.after", - "server.shutdown.before", - "server.shutdown.after", + Event.SERVER_INIT_AFTER.value, + Event.SERVER_INIT_BEFORE.value, + Event.SERVER_SHUTDOWN_AFTER.value, + Event.SERVER_SHUTDOWN_BEFORE.value, ), "http": ( - "http.lifecycle.begin", - "http.lifecycle.complete", - "http.lifecycle.exception", - "http.lifecycle.handle", - "http.lifecycle.read_body", - "http.lifecycle.read_head", - "http.lifecycle.request", - "http.lifecycle.response", - "http.routing.after", - "http.routing.before", - "http.lifecycle.send", - "http.middleware.after", - "http.middleware.before", + Event.HTTP_LIFECYCLE_BEGIN.value, + Event.HTTP_LIFECYCLE_COMPLETE.value, + Event.HTTP_LIFECYCLE_EXCEPTION.value, + Event.HTTP_LIFECYCLE_HANDLE.value, + Event.HTTP_LIFECYCLE_READ_BODY.value, + Event.HTTP_LIFECYCLE_READ_HEAD.value, + Event.HTTP_LIFECYCLE_REQUEST.value, + Event.HTTP_LIFECYCLE_RESPONSE.value, + Event.HTTP_ROUTING_AFTER.value, + Event.HTTP_ROUTING_BEFORE.value, + Event.HTTP_LIFECYCLE_SEND.value, + Event.HTTP_MIDDLEWARE_AFTER.value, + Event.HTTP_MIDDLEWARE_BEFORE.value, ), } diff --git a/tests/test_signals.py b/tests/test_signals.py index 9b8a9495..51aea3c8 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -1,5 +1,6 @@ import asyncio +from enum import Enum from inspect import isawaitable import pytest @@ -50,6 +51,25 @@ def test_invalid_signal(app, signal): ... +@pytest.mark.asyncio +async def test_dispatch_signal_with_enum_event(app): + counter = 0 + + class FooEnum(Enum): + FOO_BAR_BAZ = "foo.bar.baz" + + @app.signal(FooEnum.FOO_BAR_BAZ) + def sync_signal(*_): + nonlocal counter + + counter += 1 + + app.signal_router.finalize() + + await app.dispatch("foo.bar.baz") + assert counter == 1 + + @pytest.mark.asyncio async def test_dispatch_signal_triggers_multiple_handlers(app): counter = 0 From abeb8d0bc0ce6c4e7ec18c794e9ecade4826f090 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Tue, 16 Nov 2021 10:16:32 +0200 Subject: [PATCH 2/8] Provide list of reloaded files (#2307) * Allow access to reloaded files * Return to simple boolean values * Resolve before adding to changed files --- sanic/reloader_helpers.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/sanic/reloader_helpers.py b/sanic/reloader_helpers.py index 3a91a8f0..3c726edb 100644 --- a/sanic/reloader_helpers.py +++ b/sanic/reloader_helpers.py @@ -47,16 +47,18 @@ def _get_args_for_reloading(): return [sys.executable] + sys.argv -def restart_with_reloader(): +def restart_with_reloader(changed=None): """Create a new process and a subprocess in it with the same arguments as this one. """ + reloaded = ",".join(changed) if changed else "" return subprocess.Popen( _get_args_for_reloading(), env={ **os.environ, "SANIC_SERVER_RUNNING": "true", "SANIC_RELOADER_PROCESS": "true", + "SANIC_RELOADED_FILES": reloaded, }, ) @@ -94,24 +96,27 @@ def watchdog(sleep_interval, app): try: while True: - need_reload = False + changed = set() for filename in itertools.chain( _iter_module_files(), *(d.glob("**/*") for d in app.reload_dirs), ): try: - check = _check_file(filename, mtimes) + if _check_file(filename, mtimes): + path = ( + filename + if isinstance(filename, str) + else filename.resolve() + ) + changed.add(str(path)) except OSError: continue - if check: - need_reload = True - - if need_reload: + if changed: worker_process.terminate() worker_process.wait() - worker_process = restart_with_reloader() + worker_process = restart_with_reloader(changed) sleep(sleep_interval) except KeyboardInterrupt: From cde02b5936838e7a1574ba094e44d987176848d9 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Tue, 16 Nov 2021 13:07:33 +0200 Subject: [PATCH 3/8] More consistent config setting with post-FALLBACK_ERROR_FORMAT apply (#2310) * Update unit testing and add more consistent config * Change init and app values to private * Cleanup line lengths --- sanic/app.py | 8 ++--- sanic/config.py | 70 +++++++++++++++++++++++++++++----------- sanic/errorpages.py | 4 ++- sanic/mixins/routes.py | 4 +-- sanic/router.py | 7 ++-- tests/test_config.py | 30 ++++++++++++++++- tests/test_errorpages.py | 42 ++++++++++++++++++++++++ 7 files changed, 135 insertions(+), 30 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index fb4ed4eb..30af974a 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -190,14 +190,14 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): self._state: ApplicationState = ApplicationState(app=self) self.blueprints: Dict[str, Blueprint] = {} self.config: Config = config or Config( - load_env=load_env, env_prefix=env_prefix + load_env=load_env, + env_prefix=env_prefix, + app=self, ) self.configure_logging: bool = configure_logging self.ctx: Any = ctx or SimpleNamespace() self.debug = False - self.error_handler: ErrorHandler = error_handler or ErrorHandler( - fallback=self.config.FALLBACK_ERROR_FORMAT, - ) + self.error_handler: ErrorHandler = error_handler or ErrorHandler() self.listeners: Dict[str, List[ListenerType[Any]]] = defaultdict(list) self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {} self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {} diff --git a/sanic/config.py b/sanic/config.py index 496ceadb..ebe1a9a6 100644 --- a/sanic/config.py +++ b/sanic/config.py @@ -1,7 +1,9 @@ +from __future__ import annotations + from inspect import isclass from os import environ from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Union from warnings import warn from sanic.errorpages import check_error_format @@ -9,6 +11,10 @@ from sanic.http import Http from sanic.utils import load_module_from_file_location, str_to_bool +if TYPE_CHECKING: # no cov + from sanic import Sanic + + SANIC_PREFIX = "SANIC_" @@ -73,10 +79,13 @@ class Config(dict): load_env: Optional[Union[bool, str]] = True, env_prefix: Optional[str] = SANIC_PREFIX, keep_alive: Optional[bool] = None, + *, + app: Optional[Sanic] = None, ): defaults = defaults or {} super().__init__({**DEFAULT_CONFIG, **defaults}) + self._app = app self._LOGO = "" if keep_alive is not None: @@ -99,6 +108,7 @@ class Config(dict): self._configure_header_size() self._check_error_format() + self._init = True def __getattr__(self, attr): try: @@ -106,23 +116,47 @@ class Config(dict): except KeyError as ke: raise AttributeError(f"Config has no '{ke.args[0]}'") - def __setattr__(self, attr, value): - self[attr] = value - if attr in ( - "REQUEST_MAX_HEADER_SIZE", - "REQUEST_BUFFER_SIZE", - "REQUEST_MAX_SIZE", - ): - self._configure_header_size() - elif attr == "FALLBACK_ERROR_FORMAT": - self._check_error_format() - elif attr == "LOGO": - self._LOGO = value - warn( - "Setting the config.LOGO is deprecated and will no longer " - "be supported starting in v22.6.", - DeprecationWarning, - ) + def __setattr__(self, attr, value) -> None: + self.update({attr: value}) + + def __setitem__(self, attr, value) -> None: + self.update({attr: value}) + + def update(self, *other, **kwargs) -> None: + other_mapping = {k: v for item in other for k, v in dict(item).items()} + super().update(*other, **kwargs) + for attr, value in {**other_mapping, **kwargs}.items(): + self._post_set(attr, value) + + def _post_set(self, attr, value) -> None: + if self.get("_init"): + if attr in ( + "REQUEST_MAX_HEADER_SIZE", + "REQUEST_BUFFER_SIZE", + "REQUEST_MAX_SIZE", + ): + self._configure_header_size() + elif attr == "FALLBACK_ERROR_FORMAT": + self._check_error_format() + if self.app and value != self.app.error_handler.fallback: + if self.app.error_handler.fallback != "auto": + warn( + "Overriding non-default ErrorHandler fallback " + "value. Changing from " + f"{self.app.error_handler.fallback} to {value}." + ) + self.app.error_handler.fallback = value + elif attr == "LOGO": + self._LOGO = value + warn( + "Setting the config.LOGO is deprecated and will no longer " + "be supported starting in v22.6.", + DeprecationWarning, + ) + + @property + def app(self): + return self._app @property def LOGO(self): diff --git a/sanic/errorpages.py b/sanic/errorpages.py index 82cdd57a..d046c29d 100644 --- a/sanic/errorpages.py +++ b/sanic/errorpages.py @@ -383,6 +383,7 @@ def exception_response( """ content_type = None + print("exception_response", fallback) if not renderer: # Make sure we have something set renderer = base @@ -393,7 +394,8 @@ def exception_response( # from the route if request.route: try: - render_format = request.route.ctx.error_format + if request.route.ctx.error_format: + render_format = request.route.ctx.error_format except AttributeError: ... diff --git a/sanic/mixins/routes.py b/sanic/mixins/routes.py index 8467a2e3..7139cd3c 100644 --- a/sanic/mixins/routes.py +++ b/sanic/mixins/routes.py @@ -918,7 +918,7 @@ class RouteMixin: return route - def _determine_error_format(self, handler) -> str: + def _determine_error_format(self, handler) -> Optional[str]: if not isinstance(handler, CompositionView): try: src = dedent(getsource(handler)) @@ -930,7 +930,7 @@ class RouteMixin: except (OSError, TypeError): ... - return "auto" + return None def _get_response_types(self, node): types = set() diff --git a/sanic/router.py b/sanic/router.py index 6995ed6d..b15c2a3e 100644 --- a/sanic/router.py +++ b/sanic/router.py @@ -139,11 +139,10 @@ class Router(BaseRouter): route.ctx.stream = stream route.ctx.hosts = hosts route.ctx.static = static - route.ctx.error_format = ( - error_format or self.ctx.app.config.FALLBACK_ERROR_FORMAT - ) + route.ctx.error_format = error_format - check_error_format(route.ctx.error_format) + if error_format: + check_error_format(route.ctx.error_format) routes.append(route) diff --git a/tests/test_config.py b/tests/test_config.py index f3447666..67324f1e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,9 +1,9 @@ from contextlib import contextmanager -from email import message from os import environ from pathlib import Path from tempfile import TemporaryDirectory from textwrap import dedent +from unittest.mock import Mock import pytest @@ -360,3 +360,31 @@ def test_deprecation_notice_when_setting_logo(app): ) with pytest.warns(DeprecationWarning, match=message): app.config.LOGO = "My Custom Logo" + + +def test_config_set_methods(app, monkeypatch): + post_set = Mock() + monkeypatch.setattr(Config, "_post_set", post_set) + + app.config.FOO = 1 + post_set.assert_called_once_with("FOO", 1) + post_set.reset_mock() + + app.config["FOO"] = 2 + post_set.assert_called_once_with("FOO", 2) + post_set.reset_mock() + + app.config.update({"FOO": 3}) + post_set.assert_called_once_with("FOO", 3) + post_set.reset_mock() + + app.config.update([("FOO", 4)]) + post_set.assert_called_once_with("FOO", 4) + post_set.reset_mock() + + app.config.update(FOO=5) + post_set.assert_called_once_with("FOO", 5) + post_set.reset_mock() + + app.config.update_config({"FOO": 6}) + post_set.assert_called_once_with("FOO", 6) diff --git a/tests/test_errorpages.py b/tests/test_errorpages.py index 5af4ca5f..84949fde 100644 --- a/tests/test_errorpages.py +++ b/tests/test_errorpages.py @@ -3,6 +3,7 @@ import pytest from sanic import Sanic from sanic.errorpages import HTMLRenderer, exception_response from sanic.exceptions import NotFound, SanicException +from sanic.handlers import ErrorHandler from sanic.request import Request from sanic.response import HTTPResponse, html, json, text @@ -271,3 +272,44 @@ def test_combinations_for_auto(fake_request, accept, content_type, expected): ) assert response.content_type == expected + + +def test_allow_fallback_error_format_set_main_process_start(app): + @app.main_process_start + async def start(app, _): + app.config.FALLBACK_ERROR_FORMAT = "text" + + request, response = app.test_client.get("/error") + assert request.app.error_handler.fallback == "text" + assert response.status == 500 + assert response.content_type == "text/plain; charset=utf-8" + + +def test_setting_fallback_to_non_default_raise_warning(app): + app.error_handler = ErrorHandler(fallback="text") + + assert app.error_handler.fallback == "text" + + with pytest.warns( + UserWarning, + match=( + "Overriding non-default ErrorHandler fallback value. " + "Changing from text to auto." + ), + ): + app.config.FALLBACK_ERROR_FORMAT = "auto" + + assert app.error_handler.fallback == "auto" + + app.config.FALLBACK_ERROR_FORMAT = "text" + + with pytest.warns( + UserWarning, + match=( + "Overriding non-default ErrorHandler fallback value. " + "Changing from text to json." + ), + ): + app.config.FALLBACK_ERROR_FORMAT = "json" + + assert app.error_handler.fallback == "json" From b731a6b48c8bb6148e46df79d39a635657c9c1aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= <98187+Tronic@users.noreply.github.com> Date: Tue, 16 Nov 2021 13:03:27 -0800 Subject: [PATCH 4/8] Make HTTP connections start in IDLE stage, avoiding delays and error messages (#2268) * Make all new connections start in IDLE stage, and switch to REQUEST stage only once any bytes are received from client. This makes new connections without any request obey keepalive timeout rather than request timeout like they currently do. * Revert typo * Remove request timeout endpoint test which is no longer working (still tested by mocking). Fix mock timeout test setup. Co-authored-by: L. Karkkainen --- sanic/http.py | 22 +++---- tests/test_request_timeout.py | 109 ---------------------------------- tests/test_timeout_logic.py | 1 + 3 files changed, 9 insertions(+), 123 deletions(-) delete mode 100644 tests/test_request_timeout.py diff --git a/sanic/http.py b/sanic/http.py index d30e4c82..6f59ef25 100644 --- a/sanic/http.py +++ b/sanic/http.py @@ -105,7 +105,6 @@ class Http(metaclass=TouchUpMeta): self.keep_alive = True self.stage: Stage = Stage.IDLE self.dispatch = self.protocol.app.dispatch - self.init_for_request() def init_for_request(self): """Init/reset all per-request variables.""" @@ -129,14 +128,20 @@ class Http(metaclass=TouchUpMeta): """ HTTP 1.1 connection handler """ - while True: # As long as connection stays keep-alive + # Handle requests while the connection stays reusable + while self.keep_alive and self.stage is Stage.IDLE: + self.init_for_request() + # Wait for incoming bytes (in IDLE stage) + if not self.recv_buffer: + await self._receive_more() + self.stage = Stage.REQUEST try: # Receive and handle a request - self.stage = Stage.REQUEST self.response_func = self.http1_response_header await self.http1_request_header() + self.stage = Stage.HANDLER self.request.conn_info = self.protocol.conn_info await self.protocol.request_handler(self.request) @@ -187,16 +192,6 @@ class Http(metaclass=TouchUpMeta): if self.response: self.response.stream = None - # Exit and disconnect if no more requests can be taken - if self.stage is not Stage.IDLE or not self.keep_alive: - break - - self.init_for_request() - - # Wait for the next request - if not self.recv_buffer: - await self._receive_more() - async def http1_request_header(self): # no cov """ Receive and parse request header into self.request. @@ -299,7 +294,6 @@ class Http(metaclass=TouchUpMeta): # Remove header and its trailing CRLF del buf[: pos + 4] - self.stage = Stage.HANDLER self.request, request.stream = request, self self.protocol.state["requests_count"] += 1 diff --git a/tests/test_request_timeout.py b/tests/test_request_timeout.py deleted file mode 100644 index 48e23f1d..00000000 --- a/tests/test_request_timeout.py +++ /dev/null @@ -1,109 +0,0 @@ -import asyncio - -import httpcore -import httpx -import pytest - -from sanic_testing.testing import SanicTestClient - -from sanic import Sanic -from sanic.response import text - - -class DelayableHTTPConnection(httpcore._async.connection.AsyncHTTPConnection): - async def arequest(self, *args, **kwargs): - await asyncio.sleep(2) - return await super().arequest(*args, **kwargs) - - async def _open_socket(self, *args, **kwargs): - retval = await super()._open_socket(*args, **kwargs) - if self._request_delay: - await asyncio.sleep(self._request_delay) - return retval - - -class DelayableSanicConnectionPool(httpcore.AsyncConnectionPool): - def __init__(self, request_delay=None, *args, **kwargs): - self._request_delay = request_delay - super().__init__(*args, **kwargs) - - async def _add_to_pool(self, connection, timeout): - connection.__class__ = DelayableHTTPConnection - connection._request_delay = self._request_delay - await super()._add_to_pool(connection, timeout) - - -class DelayableSanicSession(httpx.AsyncClient): - def __init__(self, request_delay=None, *args, **kwargs) -> None: - transport = DelayableSanicConnectionPool(request_delay=request_delay) - super().__init__(transport=transport, *args, **kwargs) - - -class DelayableSanicTestClient(SanicTestClient): - def __init__(self, app, request_delay=None): - super().__init__(app) - self._request_delay = request_delay - self._loop = None - - def get_new_session(self): - return DelayableSanicSession(request_delay=self._request_delay) - - -@pytest.fixture -def request_no_timeout_app(): - app = Sanic("test_request_no_timeout") - app.config.REQUEST_TIMEOUT = 0.6 - - @app.route("/1") - async def handler2(request): - return text("OK") - - return app - - -@pytest.fixture -def request_timeout_default_app(): - app = Sanic("test_request_timeout_default") - app.config.REQUEST_TIMEOUT = 0.6 - - @app.route("/1") - async def handler1(request): - return text("OK") - - @app.websocket("/ws1") - async def ws_handler1(request, ws): - await ws.send("OK") - - return app - - -def test_default_server_error_request_timeout(request_timeout_default_app): - client = DelayableSanicTestClient(request_timeout_default_app, 2) - _, response = client.get("/1") - assert response.status == 408 - assert "Request Timeout" in response.text - - -def test_default_server_error_request_dont_timeout(request_no_timeout_app): - client = DelayableSanicTestClient(request_no_timeout_app, 0.2) - _, response = client.get("/1") - assert response.status == 200 - assert response.text == "OK" - - -def test_default_server_error_websocket_request_timeout( - request_timeout_default_app, -): - - headers = { - "Upgrade": "websocket", - "Connection": "upgrade", - "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==", - "Sec-WebSocket-Version": "13", - } - - client = DelayableSanicTestClient(request_timeout_default_app, 2) - _, response = client.get("/ws1", headers=headers) - - assert response.status == 408 - assert "Request Timeout" in response.text diff --git a/tests/test_timeout_logic.py b/tests/test_timeout_logic.py index 05249f11..497deda9 100644 --- a/tests/test_timeout_logic.py +++ b/tests/test_timeout_logic.py @@ -26,6 +26,7 @@ def protocol(app, mock_transport): protocol = HttpProtocol(loop=loop, app=app) protocol.connection_made(mock_transport) protocol._setup_connection() + protocol._http.init_for_request() protocol._task = Mock(spec=asyncio.Task) protocol._task.cancel = Mock() return protocol From 85e7b712b90a82bbf7f771732495515181272c62 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Wed, 17 Nov 2021 17:29:41 +0200 Subject: [PATCH 5/8] Allow early Blueprint registrations to still apply later added objects (#2260) --- sanic/app.py | 4 ++ sanic/blueprints.py | 126 +++++++++++++++++++++++++---------- sanic/models/futures.py | 4 ++ tests/test_blueprint_copy.py | 4 +- tests/test_blueprints.py | 28 ++++++++ 5 files changed, 129 insertions(+), 37 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index 30af974a..d78c67de 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -72,6 +72,7 @@ from sanic.models.futures import ( FutureException, FutureListener, FutureMiddleware, + FutureRegistry, FutureRoute, FutureSignal, FutureStatic, @@ -115,6 +116,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): "_future_exceptions", "_future_listeners", "_future_middleware", + "_future_registry", "_future_routes", "_future_signals", "_future_statics", @@ -187,6 +189,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): self._test_manager: Any = None self._blueprint_order: List[Blueprint] = [] self._delayed_tasks: List[str] = [] + self._future_registry: FutureRegistry = FutureRegistry() self._state: ApplicationState = ApplicationState(app=self) self.blueprints: Dict[str, Blueprint] = {} self.config: Config = config or Config( @@ -1625,6 +1628,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): raise e async def _startup(self): + self._future_registry.clear() self.signalize() self.finalize() ErrorHandler.finalize(self.error_handler) diff --git a/sanic/blueprints.py b/sanic/blueprints.py index e13cafcd..290773fa 100644 --- a/sanic/blueprints.py +++ b/sanic/blueprints.py @@ -4,7 +4,9 @@ import asyncio from collections import defaultdict from copy import deepcopy -from enum import Enum +from functools import wraps +from inspect import isfunction +from itertools import chain from types import SimpleNamespace from typing import ( TYPE_CHECKING, @@ -13,7 +15,9 @@ from typing import ( Iterable, List, Optional, + Sequence, Set, + Tuple, Union, ) @@ -36,6 +40,32 @@ if TYPE_CHECKING: from sanic import Sanic # noqa +def lazy(func, as_decorator=True): + @wraps(func) + def decorator(bp, *args, **kwargs): + nonlocal as_decorator + kwargs["apply"] = False + pass_handler = None + + if args and isfunction(args[0]): + as_decorator = False + + def wrapper(handler): + future = func(bp, *args, **kwargs) + if as_decorator: + future = future(handler) + + if bp.registered: + for app in bp.apps: + bp.register(app, {}) + + return future + + return wrapper if as_decorator else wrapper(pass_handler) + + return decorator + + class Blueprint(BaseSanic): """ In *Sanic* terminology, a **Blueprint** is a logical collection of @@ -125,29 +155,16 @@ class Blueprint(BaseSanic): ) return self._apps - def route(self, *args, **kwargs): - kwargs["apply"] = False - return super().route(*args, **kwargs) + @property + def registered(self) -> bool: + return bool(self._apps) - def static(self, *args, **kwargs): - kwargs["apply"] = False - return super().static(*args, **kwargs) - - def middleware(self, *args, **kwargs): - kwargs["apply"] = False - return super().middleware(*args, **kwargs) - - def listener(self, *args, **kwargs): - kwargs["apply"] = False - return super().listener(*args, **kwargs) - - def exception(self, *args, **kwargs): - kwargs["apply"] = False - return super().exception(*args, **kwargs) - - def signal(self, event: Union[str, Enum], *args, **kwargs): - kwargs["apply"] = False - return super().signal(event, *args, **kwargs) + exception = lazy(BaseSanic.exception) + listener = lazy(BaseSanic.listener) + middleware = lazy(BaseSanic.middleware) + route = lazy(BaseSanic.route) + signal = lazy(BaseSanic.signal) + static = lazy(BaseSanic.static, as_decorator=False) def reset(self): self._apps: Set[Sanic] = set() @@ -284,6 +301,7 @@ class Blueprint(BaseSanic): middleware = [] exception_handlers = [] listeners = defaultdict(list) + registered = set() # Routes for future in self._future_routes: @@ -310,12 +328,15 @@ class Blueprint(BaseSanic): ) name = app._generate_name(future.name) + host = future.host or self.host + if isinstance(host, list): + host = tuple(host) apply_route = FutureRoute( future.handler, uri[1:] if uri.startswith("//") else uri, future.methods, - future.host or self.host, + host, strict_slashes, future.stream, version, @@ -329,6 +350,10 @@ class Blueprint(BaseSanic): error_format, ) + if (self, apply_route) in app._future_registry: + continue + + registered.add(apply_route) route = app._apply_route(apply_route) operation = ( routes.extend if isinstance(route, list) else routes.append @@ -340,6 +365,11 @@ class Blueprint(BaseSanic): # Prepend the blueprint URI prefix if available uri = url_prefix + future.uri if url_prefix else future.uri apply_route = FutureStatic(uri, *future[1:]) + + if (self, apply_route) in app._future_registry: + continue + + registered.add(apply_route) route = app._apply_static(apply_route) routes.append(route) @@ -348,30 +378,51 @@ class Blueprint(BaseSanic): if route_names: # Middleware for future in self._future_middleware: + if (self, future) in app._future_registry: + continue middleware.append(app._apply_middleware(future, route_names)) # Exceptions for future in self._future_exceptions: + if (self, future) in app._future_registry: + continue exception_handlers.append( app._apply_exception_handler(future, route_names) ) # Event listeners - for listener in self._future_listeners: - listeners[listener.event].append(app._apply_listener(listener)) + for future in self._future_listeners: + if (self, future) in app._future_registry: + continue + listeners[future.event].append(app._apply_listener(future)) # Signals - for signal in self._future_signals: - signal.condition.update({"blueprint": self.name}) - app._apply_signal(signal) + for future in self._future_signals: + if (self, future) in app._future_registry: + continue + future.condition.update({"blueprint": self.name}) + app._apply_signal(future) - self.routes = [route for route in routes if isinstance(route, Route)] - self.websocket_routes = [ + self.routes += [route for route in routes if isinstance(route, Route)] + self.websocket_routes += [ route for route in self.routes if route.ctx.websocket ] - self.middlewares = middleware - self.exceptions = exception_handlers - self.listeners = dict(listeners) + self.middlewares += middleware + self.exceptions += exception_handlers + self.listeners.update(dict(listeners)) + + if self.registered: + self.register_futures( + self.apps, + self, + chain( + registered, + self._future_middleware, + self._future_exceptions, + self._future_listeners, + self._future_signals, + ), + ) async def dispatch(self, *args, **kwargs): condition = kwargs.pop("condition", {}) @@ -403,3 +454,10 @@ class Blueprint(BaseSanic): value = v break return value + + @staticmethod + def register_futures( + apps: Set[Sanic], bp: Blueprint, futures: Sequence[Tuple[Any, ...]] + ): + for app in apps: + app._future_registry.update(set((bp, item) for item in futures)) diff --git a/sanic/models/futures.py b/sanic/models/futures.py index fe7d77eb..74ee92b9 100644 --- a/sanic/models/futures.py +++ b/sanic/models/futures.py @@ -60,3 +60,7 @@ class FutureSignal(NamedTuple): handler: SignalHandler event: str condition: Optional[Dict[str, str]] + + +class FutureRegistry(set): + ... diff --git a/tests/test_blueprint_copy.py b/tests/test_blueprint_copy.py index 033e2e20..ca8cd67e 100644 --- a/tests/test_blueprint_copy.py +++ b/tests/test_blueprint_copy.py @@ -1,6 +1,4 @@ -from copy import deepcopy - -from sanic import Blueprint, Sanic, blueprints, response +from sanic import Blueprint, Sanic from sanic.response import text diff --git a/tests/test_blueprints.py b/tests/test_blueprints.py index b6a23151..3aa4487a 100644 --- a/tests/test_blueprints.py +++ b/tests/test_blueprints.py @@ -1088,3 +1088,31 @@ def test_bp_set_attribute_warning(): "and will be removed in version 21.12. You should change your " "Blueprint instance to use instance.ctx.foo instead." ) + + +def test_early_registration(app): + assert len(app.router.routes) == 0 + + bp = Blueprint("bp") + + @bp.get("/one") + async def one(_): + return text("one") + + app.blueprint(bp) + + assert len(app.router.routes) == 1 + + @bp.get("/two") + async def two(_): + return text("two") + + @bp.get("/three") + async def three(_): + return text("three") + + assert len(app.router.routes) == 3 + + for path in ("one", "two", "three"): + _, response = app.test_client.get(f"/{path}") + assert response.text == path From 0860bfe1f19e3b051a31eb12a5c2a13475a8eb2d Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Wed, 17 Nov 2021 19:36:36 +0200 Subject: [PATCH 6/8] Merge release 21.9.2 (#2313) --- sanic/app.py | 4 +++- sanic/handlers.py | 9 ++++++++- tests/test_errorpages.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index d78c67de..566266e0 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -1631,7 +1631,9 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): self._future_registry.clear() self.signalize() self.finalize() - ErrorHandler.finalize(self.error_handler) + ErrorHandler.finalize( + self.error_handler, fallback=self.config.FALLBACK_ERROR_FORMAT + ) TouchUp.run(self) async def _server_event( diff --git a/sanic/handlers.py b/sanic/handlers.py index af667c9a..046e56e1 100644 --- a/sanic/handlers.py +++ b/sanic/handlers.py @@ -38,7 +38,14 @@ class ErrorHandler: self.base = base @classmethod - def finalize(cls, error_handler): + def finalize(cls, error_handler, fallback: Optional[str] = None): + if ( + fallback + and fallback != "auto" + and error_handler.fallback == "auto" + ): + error_handler.fallback = fallback + if not isinstance(error_handler, cls): error_logger.warning( f"Error handler is non-conforming: {type(error_handler)}" diff --git a/tests/test_errorpages.py b/tests/test_errorpages.py index 84949fde..1843f6a7 100644 --- a/tests/test_errorpages.py +++ b/tests/test_errorpages.py @@ -1,6 +1,7 @@ import pytest from sanic import Sanic +from sanic.config import Config from sanic.errorpages import HTMLRenderer, exception_response from sanic.exceptions import NotFound, SanicException from sanic.handlers import ErrorHandler @@ -313,3 +314,31 @@ def test_setting_fallback_to_non_default_raise_warning(app): app.config.FALLBACK_ERROR_FORMAT = "json" assert app.error_handler.fallback == "json" + + +def test_allow_fallback_error_format_in_config_injection(): + class MyConfig(Config): + FALLBACK_ERROR_FORMAT = "text" + + app = Sanic("test", config=MyConfig()) + + @app.route("/error", methods=["GET", "POST"]) + def err(request): + raise Exception("something went wrong") + + request, response = app.test_client.get("/error") + assert request.app.error_handler.fallback == "text" + assert response.status == 500 + assert response.content_type == "text/plain; charset=utf-8" + + +def test_allow_fallback_error_format_in_config_replacement(app): + class MyConfig(Config): + FALLBACK_ERROR_FORMAT = "text" + + app.config = MyConfig() + + request, response = app.test_client.get("/error") + assert request.app.error_handler.fallback == "text" + assert response.status == 500 + assert response.content_type == "text/plain; charset=utf-8" From 95631b9686376990421419f6243829cd758c4b58 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Thu, 18 Nov 2021 14:53:06 +0200 Subject: [PATCH 7/8] Coffee please (#2316) * Coffee please * Add unit tests --- sanic/app.py | 6 ++++- sanic/application/logo.py | 13 +++++++++-- sanic/application/state.py | 1 + tests/test_coffee.py | 48 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 65 insertions(+), 3 deletions(-) create mode 100644 tests/test_coffee.py diff --git a/sanic/app.py b/sanic/app.py index 566266e0..c801cd3d 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -960,6 +960,10 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): # Execution # -------------------------------------------------------------------- # + def make_coffee(self, *args, **kwargs): + self.state.coffee = True + self.run(*args, **kwargs) + def run( self, host: Optional[str] = None, @@ -1562,7 +1566,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): extra.update(self.config.MOTD_DISPLAY) logo = ( - get_logo() + get_logo(coffee=self.state.coffee) if self.config.LOGO == "" or self.config.LOGO is True else self.config.LOGO ) diff --git a/sanic/application/logo.py b/sanic/application/logo.py index 9e3bb2fa..56b8c0b1 100644 --- a/sanic/application/logo.py +++ b/sanic/application/logo.py @@ -10,6 +10,15 @@ BASE_LOGO = """ Build Fast. Run Fast. """ +COFFEE_LOGO = """\033[48;2;255;13;104m \033[0m +\033[38;2;255;255;255;48;2;255;13;104m ▄████████▄ \033[0m +\033[38;2;255;255;255;48;2;255;13;104m ██ ██▀▀▄ \033[0m +\033[38;2;255;255;255;48;2;255;13;104m ███████████ █ \033[0m +\033[38;2;255;255;255;48;2;255;13;104m ███████████▄▄▀ \033[0m +\033[38;2;255;255;255;48;2;255;13;104m ▀███████▀ \033[0m +\033[48;2;255;13;104m \033[0m +Dark roast. No sugar.""" + COLOR_LOGO = """\033[48;2;255;13;104m \033[0m \033[38;2;255;255;255;48;2;255;13;104m ▄███ █████ ██ \033[0m \033[38;2;255;255;255;48;2;255;13;104m ██ \033[0m @@ -32,9 +41,9 @@ FULL_COLOR_LOGO = """ ansi_pattern = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") -def get_logo(full=False): +def get_logo(full=False, coffee=False): logo = ( - (FULL_COLOR_LOGO if full else COLOR_LOGO) + (FULL_COLOR_LOGO if full else (COFFEE_LOGO if coffee else COLOR_LOGO)) if sys.stdout.isatty() else BASE_LOGO ) diff --git a/sanic/application/state.py b/sanic/application/state.py index b03c30da..eb180708 100644 --- a/sanic/application/state.py +++ b/sanic/application/state.py @@ -34,6 +34,7 @@ class Mode(StrEnum): class ApplicationState: app: Sanic asgi: bool = field(default=False) + coffee: bool = field(default=False) fast: bool = field(default=False) host: str = field(default="") mode: Mode = field(default=Mode.PRODUCTION) diff --git a/tests/test_coffee.py b/tests/test_coffee.py new file mode 100644 index 00000000..6143f17f --- /dev/null +++ b/tests/test_coffee.py @@ -0,0 +1,48 @@ +import logging + +from unittest.mock import patch + +import pytest + +from sanic.application.logo import COFFEE_LOGO, get_logo +from sanic.exceptions import SanicException + + +def has_sugar(value): + if value: + raise SanicException("I said no sugar please") + + return False + + +@pytest.mark.parametrize("sugar", (True, False)) +def test_no_sugar(sugar): + if sugar: + with pytest.raises(SanicException): + assert has_sugar(sugar) + else: + assert not has_sugar(sugar) + + +def test_get_logo_returns_expected_logo(): + with patch("sys.stdout.isatty") as isatty: + isatty.return_value = True + logo = get_logo(coffee=True) + assert logo is COFFEE_LOGO + + +def test_logo_true(app, caplog): + @app.after_server_start + async def shutdown(*_): + app.stop() + + with patch("sys.stdout.isatty") as isatty: + isatty.return_value = True + with caplog.at_level(logging.DEBUG): + app.make_coffee() + + # Only in the regular logo + assert " ▄███ █████ ██ " not in caplog.text + + # Only in the coffee logo + assert " ██ ██▀▀▄ " in caplog.text From 523db190a732177eda5a641768667173ba2e2452 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Thu, 18 Nov 2021 17:47:27 +0200 Subject: [PATCH 8/8] Add contextual exceptions (#2290) --- sanic/errorpages.py | 143 +++++++++++++++++++++++++++------------ sanic/exceptions.py | 6 +- tests/test_exceptions.py | 117 ++++++++++++++++++++++++++++++++ 3 files changed, 220 insertions(+), 46 deletions(-) diff --git a/sanic/errorpages.py b/sanic/errorpages.py index d046c29d..66ff6c95 100644 --- a/sanic/errorpages.py +++ b/sanic/errorpages.py @@ -25,12 +25,13 @@ from sanic.request import Request from sanic.response import HTTPResponse, html, json, text +dumps: t.Callable[..., str] try: from ujson import dumps dumps = partial(dumps, escape_forward_slashes=False) except ImportError: # noqa - from json import dumps # type: ignore + from json import dumps FALLBACK_TEXT = ( @@ -45,6 +46,8 @@ class BaseRenderer: Base class that all renderers must inherit from. """ + dumps = staticmethod(dumps) + def __init__(self, request, exception, debug): self.request = request self.exception = exception @@ -112,14 +115,16 @@ class HTMLRenderer(BaseRenderer): TRACEBACK_STYLE = """ html { font-family: sans-serif } h2 { color: #888; } - .tb-wrapper p { margin: 0 } + .tb-wrapper p, dl, dd { margin: 0 } .frame-border { margin: 1rem } - .frame-line > * { padding: 0.3rem 0.6rem } - .frame-line { margin-bottom: 0.3rem } - .frame-code { font-size: 16px; padding-left: 4ch } - .tb-wrapper { border: 1px solid #eee } - .tb-header { background: #eee; padding: 0.3rem; font-weight: bold } - .frame-descriptor { background: #e2eafb; font-size: 14px } + .frame-line > *, dt, dd { padding: 0.3rem 0.6rem } + .frame-line, dl { margin-bottom: 0.3rem } + .frame-code, dd { font-size: 16px; padding-left: 4ch } + .tb-wrapper, dl { border: 1px solid #eee } + .tb-header,.obj-header { + background: #eee; padding: 0.3rem; font-weight: bold + } + .frame-descriptor, dt { background: #e2eafb; font-size: 14px } """ TRACEBACK_WRAPPER_HTML = ( "
{exc_name}: {exc_value}
" @@ -138,6 +143,11 @@ class HTMLRenderer(BaseRenderer): "

{0.line}" "" ) + OBJECT_WRAPPER_HTML = ( + "

{title}
" + "
{display_html}
" + ) + OBJECT_DISPLAY_HTML = "
{key}
{value}
" OUTPUT_HTML = ( "" "{title}\n" @@ -152,7 +162,7 @@ class HTMLRenderer(BaseRenderer): title=self.title, text=self.text, style=self.TRACEBACK_STYLE, - body=self._generate_body(), + body=self._generate_body(full=True), ), status=self.status, ) @@ -163,7 +173,7 @@ class HTMLRenderer(BaseRenderer): title=self.title, text=self.text, style=self.TRACEBACK_STYLE, - body="", + body=self._generate_body(full=False), ), status=self.status, headers=self.headers, @@ -177,27 +187,49 @@ class HTMLRenderer(BaseRenderer): def title(self): return escape(f"⚠️ {super().title}") - def _generate_body(self): - _, exc_value, __ = sys.exc_info() - exceptions = [] - while exc_value: - exceptions.append(self._format_exc(exc_value)) - exc_value = exc_value.__cause__ + def _generate_body(self, *, full): + lines = [] + if full: + _, exc_value, __ = sys.exc_info() + exceptions = [] + while exc_value: + exceptions.append(self._format_exc(exc_value)) + exc_value = exc_value.__cause__ + + traceback_html = self.TRACEBACK_BORDER.join(reversed(exceptions)) + appname = escape(self.request.app.name) + name = escape(self.exception.__class__.__name__) + value = escape(self.exception) + path = escape(self.request.path) + lines += [ + f"

Traceback of {appname} " "(most recent call last):

", + f"{traceback_html}", + "

", + f"{name}: {value} " + f"while handling path {path}", + "

", + ] + + for attr, display in (("context", True), ("extra", bool(full))): + info = getattr(self.exception, attr, None) + if info and display: + lines.append(self._generate_object_display(info, attr)) - traceback_html = self.TRACEBACK_BORDER.join(reversed(exceptions)) - appname = escape(self.request.app.name) - name = escape(self.exception.__class__.__name__) - value = escape(self.exception) - path = escape(self.request.path) - lines = [ - f"

Traceback of {appname} (most recent call last):

", - f"{traceback_html}", - "

", - f"{name}: {value} while handling path {path}", - "

", - ] return "\n".join(lines) + def _generate_object_display( + self, obj: t.Dict[str, t.Any], descriptor: str + ) -> str: + display = "".join( + self.OBJECT_DISPLAY_HTML.format(key=key, value=value) + for key, value in obj.items() + ) + return self.OBJECT_WRAPPER_HTML.format( + title=descriptor.title(), + display_html=display, + obj_type=descriptor.lower(), + ) + def _format_exc(self, exc): frames = extract_tb(exc.__traceback__) frame_html = "".join( @@ -224,7 +256,7 @@ class TextRenderer(BaseRenderer): title=self.title, text=self.text, bar=("=" * len(self.title)), - body=self._generate_body(), + body=self._generate_body(full=True), ), status=self.status, ) @@ -235,7 +267,7 @@ class TextRenderer(BaseRenderer): title=self.title, text=self.text, bar=("=" * len(self.title)), - body="", + body=self._generate_body(full=False), ), status=self.status, headers=self.headers, @@ -245,21 +277,31 @@ class TextRenderer(BaseRenderer): def title(self): return f"⚠️ {super().title}" - def _generate_body(self): - _, exc_value, __ = sys.exc_info() - exceptions = [] + def _generate_body(self, *, full): + lines = [] + if full: + _, exc_value, __ = sys.exc_info() + exceptions = [] - lines = [ - f"{self.exception.__class__.__name__}: {self.exception} while " - f"handling path {self.request.path}", - f"Traceback of {self.request.app.name} (most recent call last):\n", - ] + lines += [ + f"{self.exception.__class__.__name__}: {self.exception} while " + f"handling path {self.request.path}", + f"Traceback of {self.request.app.name} " + "(most recent call last):\n", + ] - while exc_value: - exceptions.append(self._format_exc(exc_value)) - exc_value = exc_value.__cause__ + while exc_value: + exceptions.append(self._format_exc(exc_value)) + exc_value = exc_value.__cause__ - return "\n".join(lines + exceptions[::-1]) + lines += exceptions[::-1] + + for attr, display in (("context", True), ("extra", bool(full))): + info = getattr(self.exception, attr, None) + if info and display: + lines += self._generate_object_display_list(info, attr) + + return "\n".join(lines) def _format_exc(self, exc): frames = "\n\n".join( @@ -272,6 +314,13 @@ class TextRenderer(BaseRenderer): ) return f"{self.SPACER}{exc.__class__.__name__}: {exc}\n{frames}" + def _generate_object_display_list(self, obj, descriptor): + lines = [f"\n{descriptor.title()}"] + for key, value in obj.items(): + display = self.dumps(value) + lines.append(f"{self.SPACER * 2}{key}: {display}") + return lines + class JSONRenderer(BaseRenderer): """ @@ -280,11 +329,11 @@ class JSONRenderer(BaseRenderer): def full(self) -> HTTPResponse: output = self._generate_output(full=True) - return json(output, status=self.status, dumps=dumps) + return json(output, status=self.status, dumps=self.dumps) def minimal(self) -> HTTPResponse: output = self._generate_output(full=False) - return json(output, status=self.status, dumps=dumps) + return json(output, status=self.status, dumps=self.dumps) def _generate_output(self, *, full): output = { @@ -293,6 +342,11 @@ class JSONRenderer(BaseRenderer): "message": self.text, } + for attr, display in (("context", True), ("extra", bool(full))): + info = getattr(self.exception, attr, None) + if info and display: + output[attr] = info + if full: _, exc_value, __ = sys.exc_info() exceptions = [] @@ -383,7 +437,6 @@ def exception_response( """ content_type = None - print("exception_response", fallback) if not renderer: # Make sure we have something set renderer = base diff --git a/sanic/exceptions.py b/sanic/exceptions.py index 1bb06f1d..6459f15a 100644 --- a/sanic/exceptions.py +++ b/sanic/exceptions.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Any, Dict, Optional, Union from sanic.helpers import STATUS_CODES @@ -11,7 +11,11 @@ class SanicException(Exception): message: Optional[Union[str, bytes]] = None, status_code: Optional[int] = None, quiet: Optional[bool] = None, + context: Optional[Dict[str, Any]] = None, + extra: Optional[Dict[str, Any]] = None, ) -> None: + self.context = context + self.extra = extra if message is None: if self.message: message = self.message diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 0485137a..eea97935 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -18,6 +18,16 @@ from sanic.exceptions import ( from sanic.response import text +def dl_to_dict(soup, css_class): + keys, values = [], [] + for dl in soup.find_all("dl", {"class": css_class}): + for dt in dl.find_all("dt"): + keys.append(dt.text.strip()) + for dd in dl.find_all("dd"): + values.append(dd.text.strip()) + return dict(zip(keys, values)) + + class SanicExceptionTestException(Exception): pass @@ -264,3 +274,110 @@ def test_exception_in_ws_logged(caplog): error_logs = [r for r in caplog.record_tuples if r[0] == "sanic.error"] assert error_logs[1][1] == logging.ERROR assert "Exception occurred while handling uri:" in error_logs[1][2] + + +@pytest.mark.parametrize("debug", (True, False)) +def test_contextual_exception_context(debug): + app = Sanic(__name__) + + class TeapotError(SanicException): + status_code = 418 + message = "Sorry, I cannot brew coffee" + + def fail(): + raise TeapotError(context={"foo": "bar"}) + + app.post("/coffee/json", error_format="json")(lambda _: fail()) + app.post("/coffee/html", error_format="html")(lambda _: fail()) + app.post("/coffee/text", error_format="text")(lambda _: fail()) + + _, response = app.test_client.post("/coffee/json", debug=debug) + assert response.status == 418 + assert response.json["message"] == "Sorry, I cannot brew coffee" + assert response.json["context"] == {"foo": "bar"} + + _, response = app.test_client.post("/coffee/html", debug=debug) + soup = BeautifulSoup(response.body, "html.parser") + dl = dl_to_dict(soup, "context") + assert response.status == 418 + assert "Sorry, I cannot brew coffee" in soup.find("p").text + assert dl == {"foo": "bar"} + + _, response = app.test_client.post("/coffee/text", debug=debug) + lines = list(map(lambda x: x.decode(), response.body.split(b"\n"))) + idx = lines.index("Context") + 1 + assert response.status == 418 + assert lines[2] == "Sorry, I cannot brew coffee" + assert lines[idx] == ' foo: "bar"' + + +@pytest.mark.parametrize("debug", (True, False)) +def test_contextual_exception_extra(debug): + app = Sanic(__name__) + + class TeapotError(SanicException): + status_code = 418 + + @property + def message(self): + return f"Found {self.extra['foo']}" + + def fail(): + raise TeapotError(extra={"foo": "bar"}) + + app.post("/coffee/json", error_format="json")(lambda _: fail()) + app.post("/coffee/html", error_format="html")(lambda _: fail()) + app.post("/coffee/text", error_format="text")(lambda _: fail()) + + _, response = app.test_client.post("/coffee/json", debug=debug) + assert response.status == 418 + assert response.json["message"] == "Found bar" + if debug: + assert response.json["extra"] == {"foo": "bar"} + else: + assert "extra" not in response.json + + _, response = app.test_client.post("/coffee/html", debug=debug) + soup = BeautifulSoup(response.body, "html.parser") + dl = dl_to_dict(soup, "extra") + assert response.status == 418 + assert "Found bar" in soup.find("p").text + if debug: + assert dl == {"foo": "bar"} + else: + assert not dl + + _, response = app.test_client.post("/coffee/text", debug=debug) + lines = list(map(lambda x: x.decode(), response.body.split(b"\n"))) + assert response.status == 418 + assert lines[2] == "Found bar" + if debug: + idx = lines.index("Extra") + 1 + assert lines[idx] == ' foo: "bar"' + else: + assert "Extra" not in lines + + +@pytest.mark.parametrize("override", (True, False)) +def test_contextual_exception_functional_message(override): + app = Sanic(__name__) + + class TeapotError(SanicException): + status_code = 418 + + @property + def message(self): + return f"Received foo={self.context['foo']}" + + @app.post("/coffee", error_format="json") + async def make_coffee(_): + error_args = {"context": {"foo": "bar"}} + if override: + error_args["message"] = "override" + raise TeapotError(**error_args) + + _, response = app.test_client.post("/coffee", debug=True) + error_message = "override" if override else "Received foo=bar" + assert response.status == 418 + assert response.json["message"] == error_message + assert response.json["context"] == {"foo": "bar"}