Add convenience method for exception reporting (#2792)

This commit is contained in:
Adam Hopkins 2023-07-18 00:21:55 +03:00 committed by GitHub
parent 31d7ba8f8c
commit 9cbe1fb8ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 160 additions and 16 deletions

View File

@ -17,7 +17,7 @@ from asyncio import (
from asyncio.futures import Future
from collections import defaultdict, deque
from contextlib import contextmanager, suppress
from functools import partial
from functools import partial, wraps
from inspect import isawaitable
from os import environ
from socket import socket
@ -87,7 +87,7 @@ from sanic.request import Request
from sanic.response import BaseHTTPResponse, HTTPResponse, ResponseStream
from sanic.router import Router
from sanic.server.websockets.impl import ConnectionClosed
from sanic.signals import Signal, SignalRouter
from sanic.signals import Event, Signal, SignalRouter
from sanic.touchup import TouchUp, TouchUpMeta
from sanic.types.shared_ctx import SharedContext
from sanic.worker.inspector import Inspector
@ -605,6 +605,19 @@ class Sanic(
raise NotFound("Could not find signal %s" % event)
return await wait_for(signal.ctx.event.wait(), timeout=timeout)
def report_exception(
self, handler: Callable[[Sanic, Exception], Coroutine[Any, Any, None]]
):
@wraps(handler)
async def report(exception: Exception) -> None:
await handler(self, exception)
self.add_signal(
handler=report, event=Event.SERVER_EXCEPTION_REPORT.value
)
return report
def enable_websocket(self, enable=True):
"""Enable or disable the support for websocket.
@ -876,8 +889,10 @@ class Sanic(
:raises ServerError: response 500
"""
response = None
if not getattr(exception, "__dispatched__", False):
... # DO NOT REMOVE THIS LINE. IT IS NEEDED FOR TOUCHUP.
await self.dispatch(
"server.lifecycle.exception",
"server.exception.report",
context={"exception": exception},
)
await self.dispatch(
@ -1310,13 +1325,28 @@ class Sanic(
app,
loop,
):
async def do(task):
try:
if callable(task):
try:
task = task(app)
except TypeError:
task = task()
if isawaitable(task):
await task
except CancelledError:
error_logger.warning(
f"Task {task} was cancelled before it completed."
)
raise
except Exception as e:
await app.dispatch(
"server.exception.report",
context={"exception": e},
)
raise
return task
return do(task)
@classmethod
def _loop_add_task(

View File

@ -16,11 +16,11 @@ from sanic.models.handler_types import SignalHandler
class Event(Enum):
SERVER_EXCEPTION_REPORT = "server.exception.report"
SERVER_INIT_AFTER = "server.init.after"
SERVER_INIT_BEFORE = "server.init.before"
SERVER_SHUTDOWN_AFTER = "server.shutdown.after"
SERVER_SHUTDOWN_BEFORE = "server.shutdown.before"
SERVER_LIFECYCLE_EXCEPTION = "server.lifecycle.exception"
HTTP_LIFECYCLE_BEGIN = "http.lifecycle.begin"
HTTP_LIFECYCLE_COMPLETE = "http.lifecycle.complete"
HTTP_LIFECYCLE_EXCEPTION = "http.lifecycle.exception"
@ -40,11 +40,11 @@ class Event(Enum):
RESERVED_NAMESPACES = {
"server": (
Event.SERVER_EXCEPTION_REPORT.value,
Event.SERVER_INIT_AFTER.value,
Event.SERVER_INIT_BEFORE.value,
Event.SERVER_SHUTDOWN_AFTER.value,
Event.SERVER_SHUTDOWN_BEFORE.value,
Event.SERVER_LIFECYCLE_EXCEPTION.value,
),
"http": (
Event.HTTP_LIFECYCLE_BEGIN.value,
@ -174,11 +174,12 @@ class SignalRouter(BaseRouter):
if self.ctx.app.debug and self.ctx.app.state.verbosity >= 1:
error_logger.exception(e)
if event != Event.SERVER_LIFECYCLE_EXCEPTION.value:
if event != Event.SERVER_EXCEPTION_REPORT.value:
await self.dispatch(
Event.SERVER_LIFECYCLE_EXCEPTION.value,
Event.SERVER_EXCEPTION_REPORT.value,
context={"exception": e},
)
setattr(e, "__dispatched__", True)
raise e
finally:
for signal_event in events:

View File

@ -160,7 +160,7 @@ def test_signal_server_lifecycle_exception(app: Sanic):
async def hello_route(request):
return HTTPResponse()
@app.signal(Event.SERVER_LIFECYCLE_EXCEPTION)
@app.signal(Event.SERVER_EXCEPTION_REPORT)
async def test_signal(exception: Exception):
nonlocal trigger
trigger = exception

View File

@ -2,6 +2,7 @@ import asyncio
from enum import Enum
from inspect import isawaitable
from itertools import count
import pytest
@ -9,6 +10,7 @@ from sanic_routing.exceptions import NotFound
from sanic import Blueprint, Sanic, empty
from sanic.exceptions import InvalidSignal, SanicException
from sanic.signals import Event
def test_add_signal(app):
@ -427,3 +429,114 @@ def test_signal_reservation(app, event, expected):
app.signal(event)(lambda: ...)
else:
app.signal(event)(lambda: ...)
@pytest.mark.asyncio
async def test_report_exception(app: Sanic):
@app.report_exception
async def catch_any_exception(app: Sanic, exception: Exception):
...
@app.route("/")
async def handler(request):
1 / 0
app.signal_router.finalize()
registered_signal_handlers = [
handler
for handler, *_ in app.signal_router.get(
Event.SERVER_EXCEPTION_REPORT.value
)
]
assert catch_any_exception in registered_signal_handlers
def test_report_exception_runs(app: Sanic):
event = asyncio.Event()
@app.report_exception
async def catch_any_exception(app: Sanic, exception: Exception):
event.set()
@app.route("/")
async def handler(request):
1 / 0
app.test_client.get("/")
assert event.is_set()
def test_report_exception_runs_once_inline(app: Sanic):
event = asyncio.Event()
c = count()
@app.report_exception
async def catch_any_exception(app: Sanic, exception: Exception):
event.set()
next(c)
@app.route("/")
async def handler(request):
...
@app.signal(Event.HTTP_ROUTING_AFTER.value)
async def after_routing(**_):
1 / 0
app.test_client.get("/")
assert event.is_set()
assert next(c) == 1
def test_report_exception_runs_once_custom(app: Sanic):
event = asyncio.Event()
c = count()
@app.report_exception
async def catch_any_exception(app: Sanic, exception: Exception):
event.set()
next(c)
@app.route("/")
async def handler(request):
await app.dispatch("one.two.three")
return empty()
@app.signal("one.two.three")
async def one_two_three(**_):
1 / 0
app.test_client.get("/")
assert event.is_set()
assert next(c) == 1
def test_report_exception_runs_task(app: Sanic):
c = count()
async def task_1():
next(c)
async def task_2(app):
next(c)
@app.report_exception
async def catch_any_exception(app: Sanic, exception: Exception):
next(c)
@app.route("/")
async def handler(request):
app.add_task(task_1)
app.add_task(task_1())
app.add_task(task_2)
app.add_task(task_2(app))
return empty()
app.test_client.get("/")
assert next(c) == 4