From 976da69e79e22c08ba2ea9ffe48bed72bf3290df Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Sun, 9 Jul 2023 10:53:36 +0300 Subject: [PATCH] Add a new exception signal for ALL exceptions raised anywhere in application (#2724) --- sanic/app.py | 4 ++++ sanic/handlers/error.py | 4 +++- sanic/signals.py | 12 ++++++++++++ tests/test_signal_handlers.py | 30 +++++++++++++++++++++++++++--- 4 files changed, 46 insertions(+), 4 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index 527650cc..73874239 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -773,6 +773,10 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta): :raises ServerError: response 500 """ response = None + await self.dispatch( + "server.lifecycle.exception", + context={"exception": exception}, + ) await self.dispatch( "http.lifecycle.exception", inline=True, diff --git a/sanic/handlers/error.py b/sanic/handlers/error.py index d01f8ba1..0a2ff9a2 100644 --- a/sanic/handlers/error.py +++ b/sanic/handlers/error.py @@ -6,7 +6,9 @@ from sanic.errorpages import BaseRenderer, TextRenderer, exception_response from sanic.exceptions import ServerError from sanic.log import error_logger from sanic.models.handler_types import RouteHandler +from sanic.request.types import Request from sanic.response import text +from sanic.response.types import HTTPResponse class ErrorHandler: @@ -148,7 +150,7 @@ class ErrorHandler: return text("An error occurred while handling an error", 500) return response - def default(self, request, exception): + def default(self, request: Request, exception: Exception) -> HTTPResponse: """ Provide a default behavior for the objects of :class:`ErrorHandler`. If a developer chooses to extent the :class:`ErrorHandler` they can diff --git a/sanic/signals.py b/sanic/signals.py index 5710f1ed..6b6fdcce 100644 --- a/sanic/signals.py +++ b/sanic/signals.py @@ -20,6 +20,7 @@ class Event(Enum): SERVER_INIT_BEFORE = "server.init.before" SERVER_SHUTDOWN_AFTER = "server.shutdown.after" SERVER_SHUTDOWN_BEFORE = "server.shutdown.before" + SERVER_LIFECYCLE_EXCEPTION = "server.lifecycle.exception" HTTP_LIFECYCLE_BEGIN = "http.lifecycle.begin" HTTP_LIFECYCLE_COMPLETE = "http.lifecycle.complete" HTTP_LIFECYCLE_EXCEPTION = "http.lifecycle.exception" @@ -43,6 +44,7 @@ RESERVED_NAMESPACES = { Event.SERVER_INIT_BEFORE.value, Event.SERVER_SHUTDOWN_AFTER.value, Event.SERVER_SHUTDOWN_BEFORE.value, + Event.SERVER_LIFECYCLE_EXCEPTION.value, ), "http": ( Event.HTTP_LIFECYCLE_BEGIN.value, @@ -168,6 +170,16 @@ class SignalRouter(BaseRouter): elif maybe_coroutine: return maybe_coroutine return None + except Exception as e: + if self.ctx.app.debug and self.ctx.app.state.verbosity >= 1: + error_logger.exception(e) + + if event != Event.SERVER_LIFECYCLE_EXCEPTION.value: + await self.dispatch( + Event.SERVER_LIFECYCLE_EXCEPTION.value, + context={"exception": e}, + ) + raise e finally: for signal_event in events: signal_event.clear() diff --git a/tests/test_signal_handlers.py b/tests/test_signal_handlers.py index 53eff83f..401353e9 100644 --- a/tests/test_signal_handlers.py +++ b/tests/test_signal_handlers.py @@ -1,18 +1,19 @@ import asyncio import os import signal - from queue import Queue from types import SimpleNamespace +from typing import Optional from unittest.mock import MagicMock import pytest - from sanic_testing.testing import HOST, PORT +from sanic import Sanic from sanic.compat import ctrlc_workaround_for_windows -from sanic.exceptions import BadRequest +from sanic.exceptions import BadRequest, ServerError from sanic.response import HTTPResponse +from sanic.signals import Event async def stop(app, loop): @@ -148,3 +149,26 @@ def test_signals_with_invalid_invocation(app): BadRequest, match="Invalid event registration: Missing event name" ): app.listener(stop) + + +def test_signal_server_lifecycle_exception(app: Sanic): + trigger: Optional[Exception] = None + + @app.route("/hello") + async def hello_route(request): + return HTTPResponse() + + @app.signal(Event.SERVER_LIFECYCLE_EXCEPTION) + async def test_signal(exception: Exception): + nonlocal trigger + trigger = exception + + @app.before_server_start + async def test_before_server_start(app): + raise ServerError("test_before_server_start") + + with pytest.raises(ServerError, match="test_before_server_start"): + app.run(single_process=True) + + assert isinstance(trigger, ServerError) + assert str(trigger) == "test_before_server_start"