diff --git a/sanic/app.py b/sanic/app.py index e966a4fd..57b68fc4 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -17,7 +17,7 @@ from asyncio import ( from asyncio.futures import Future from collections import defaultdict, deque from contextlib import contextmanager, suppress -from functools import partial +from functools import partial, wraps from inspect import isawaitable from os import environ from socket import socket @@ -87,7 +87,7 @@ from sanic.request import Request from sanic.response import BaseHTTPResponse, HTTPResponse, ResponseStream from sanic.router import Router from sanic.server.websockets.impl import ConnectionClosed -from sanic.signals import Signal, SignalRouter +from sanic.signals import Event, Signal, SignalRouter from sanic.touchup import TouchUp, TouchUpMeta from sanic.types.shared_ctx import SharedContext from sanic.worker.inspector import Inspector @@ -605,6 +605,19 @@ class Sanic( raise NotFound("Could not find signal %s" % event) return await wait_for(signal.ctx.event.wait(), timeout=timeout) + def report_exception( + self, handler: Callable[[Sanic, Exception], Coroutine[Any, Any, None]] + ): + @wraps(handler) + async def report(exception: Exception) -> None: + await handler(self, exception) + + self.add_signal( + handler=report, event=Event.SERVER_EXCEPTION_REPORT.value + ) + + return report + def enable_websocket(self, enable=True): """Enable or disable the support for websocket. @@ -876,10 +889,12 @@ class Sanic( :raises ServerError: response 500 """ response = None - await self.dispatch( - "server.lifecycle.exception", - context={"exception": exception}, - ) + if not getattr(exception, "__dispatched__", False): + ... # DO NOT REMOVE THIS LINE. IT IS NEEDED FOR TOUCHUP. + await self.dispatch( + "server.exception.report", + context={"exception": exception}, + ) await self.dispatch( "http.lifecycle.exception", inline=True, @@ -1310,13 +1325,28 @@ class Sanic( app, loop, ): - if callable(task): + async def do(task): try: - task = task(app) - except TypeError: - task = task() + if callable(task): + try: + task = task(app) + except TypeError: + task = task() + if isawaitable(task): + await task + except CancelledError: + error_logger.warning( + f"Task {task} was cancelled before it completed." + ) + raise + except Exception as e: + await app.dispatch( + "server.exception.report", + context={"exception": e}, + ) + raise - return task + return do(task) @classmethod def _loop_add_task( diff --git a/sanic/signals.py b/sanic/signals.py index 6b6fdcce..3075dbaa 100644 --- a/sanic/signals.py +++ b/sanic/signals.py @@ -16,11 +16,11 @@ from sanic.models.handler_types import SignalHandler class Event(Enum): + SERVER_EXCEPTION_REPORT = "server.exception.report" SERVER_INIT_AFTER = "server.init.after" SERVER_INIT_BEFORE = "server.init.before" SERVER_SHUTDOWN_AFTER = "server.shutdown.after" SERVER_SHUTDOWN_BEFORE = "server.shutdown.before" - SERVER_LIFECYCLE_EXCEPTION = "server.lifecycle.exception" HTTP_LIFECYCLE_BEGIN = "http.lifecycle.begin" HTTP_LIFECYCLE_COMPLETE = "http.lifecycle.complete" HTTP_LIFECYCLE_EXCEPTION = "http.lifecycle.exception" @@ -40,11 +40,11 @@ class Event(Enum): RESERVED_NAMESPACES = { "server": ( + Event.SERVER_EXCEPTION_REPORT.value, Event.SERVER_INIT_AFTER.value, Event.SERVER_INIT_BEFORE.value, Event.SERVER_SHUTDOWN_AFTER.value, Event.SERVER_SHUTDOWN_BEFORE.value, - Event.SERVER_LIFECYCLE_EXCEPTION.value, ), "http": ( Event.HTTP_LIFECYCLE_BEGIN.value, @@ -174,11 +174,12 @@ class SignalRouter(BaseRouter): if self.ctx.app.debug and self.ctx.app.state.verbosity >= 1: error_logger.exception(e) - if event != Event.SERVER_LIFECYCLE_EXCEPTION.value: + if event != Event.SERVER_EXCEPTION_REPORT.value: await self.dispatch( - Event.SERVER_LIFECYCLE_EXCEPTION.value, + Event.SERVER_EXCEPTION_REPORT.value, context={"exception": e}, ) + setattr(e, "__dispatched__", True) raise e finally: for signal_event in events: diff --git a/tests/test_signal_handlers.py b/tests/test_signal_handlers.py index 43f40e13..728c9c17 100644 --- a/tests/test_signal_handlers.py +++ b/tests/test_signal_handlers.py @@ -160,7 +160,7 @@ def test_signal_server_lifecycle_exception(app: Sanic): async def hello_route(request): return HTTPResponse() - @app.signal(Event.SERVER_LIFECYCLE_EXCEPTION) + @app.signal(Event.SERVER_EXCEPTION_REPORT) async def test_signal(exception: Exception): nonlocal trigger trigger = exception diff --git a/tests/test_signals.py b/tests/test_signals.py index 49a684e7..8a564a98 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -2,6 +2,7 @@ import asyncio from enum import Enum from inspect import isawaitable +from itertools import count import pytest @@ -9,6 +10,7 @@ from sanic_routing.exceptions import NotFound from sanic import Blueprint, Sanic, empty from sanic.exceptions import InvalidSignal, SanicException +from sanic.signals import Event def test_add_signal(app): @@ -427,3 +429,114 @@ def test_signal_reservation(app, event, expected): app.signal(event)(lambda: ...) else: app.signal(event)(lambda: ...) + + +@pytest.mark.asyncio +async def test_report_exception(app: Sanic): + @app.report_exception + async def catch_any_exception(app: Sanic, exception: Exception): + ... + + @app.route("/") + async def handler(request): + 1 / 0 + + app.signal_router.finalize() + + registered_signal_handlers = [ + handler + for handler, *_ in app.signal_router.get( + Event.SERVER_EXCEPTION_REPORT.value + ) + ] + + assert catch_any_exception in registered_signal_handlers + + +def test_report_exception_runs(app: Sanic): + event = asyncio.Event() + + @app.report_exception + async def catch_any_exception(app: Sanic, exception: Exception): + event.set() + + @app.route("/") + async def handler(request): + 1 / 0 + + app.test_client.get("/") + + assert event.is_set() + + +def test_report_exception_runs_once_inline(app: Sanic): + event = asyncio.Event() + c = count() + + @app.report_exception + async def catch_any_exception(app: Sanic, exception: Exception): + event.set() + next(c) + + @app.route("/") + async def handler(request): + ... + + @app.signal(Event.HTTP_ROUTING_AFTER.value) + async def after_routing(**_): + 1 / 0 + + app.test_client.get("/") + + assert event.is_set() + assert next(c) == 1 + + +def test_report_exception_runs_once_custom(app: Sanic): + event = asyncio.Event() + c = count() + + @app.report_exception + async def catch_any_exception(app: Sanic, exception: Exception): + event.set() + next(c) + + @app.route("/") + async def handler(request): + await app.dispatch("one.two.three") + return empty() + + @app.signal("one.two.three") + async def one_two_three(**_): + 1 / 0 + + app.test_client.get("/") + + assert event.is_set() + assert next(c) == 1 + + +def test_report_exception_runs_task(app: Sanic): + c = count() + + async def task_1(): + next(c) + + async def task_2(app): + next(c) + + @app.report_exception + async def catch_any_exception(app: Sanic, exception: Exception): + next(c) + + @app.route("/") + async def handler(request): + app.add_task(task_1) + app.add_task(task_1()) + app.add_task(task_2) + app.add_task(task_2(app)) + return empty() + + app.test_client.get("/") + + assert next(c) == 4