Blueprint specific exception handlers (#2208)

This commit is contained in:
Adam Hopkins 2021-08-31 12:32:51 +03:00 committed by GitHub
parent 945885d501
commit 69c5dde9bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 57 additions and 31 deletions

View File

@ -334,7 +334,11 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
self.named_response_middleware[_rn].appendleft(middleware)
return middleware
def _apply_exception_handler(self, handler: FutureException):
def _apply_exception_handler(
self,
handler: FutureException,
route_names: Optional[List[str]] = None,
):
"""Decorate a function to be registered as a handler for exceptions
:param exceptions: exceptions
@ -344,9 +348,9 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
for exception in handler.exceptions:
if isinstance(exception, (tuple, list)):
for e in exception:
self.error_handler.add(e, handler.handler)
self.error_handler.add(e, handler.handler, route_names)
else:
self.error_handler.add(exception, handler.handler)
self.error_handler.add(exception, handler.handler, route_names)
return handler.handler
def _apply_listener(self, listener: FutureListener):

View File

@ -338,7 +338,9 @@ class Blueprint(BaseSanic):
# Exceptions
for future in self._future_exceptions:
exception_handlers.append(app._apply_exception_handler(future))
exception_handlers.append(
app._apply_exception_handler(future, route_names)
)
# Event listeners
for listener in self._future_listeners:

View File

@ -1,3 +1,5 @@
from typing import List, Optional
from sanic.errorpages import exception_response
from sanic.exceptions import (
ContentRangeError,
@ -21,15 +23,12 @@ class ErrorHandler:
"""
handlers = None
cached_handlers = None
def __init__(self):
self.handlers = []
self.cached_handlers = {}
self.debug = False
def add(self, exception, handler):
def add(self, exception, handler, route_names: Optional[List[str]] = None):
"""
Add a new exception handler to an already existing handler object.
@ -42,11 +41,16 @@ class ErrorHandler:
:return: None
"""
# self.handlers to be deprecated and removed in version 21.12
# self.handlers is deprecated and will be removed in version 22.3
self.handlers.append((exception, handler))
self.cached_handlers[exception] = handler
def lookup(self, exception):
if route_names:
for route in route_names:
self.cached_handlers[(exception, route)] = handler
else:
self.cached_handlers[(exception, None)] = handler
def lookup(self, exception, route_name: Optional[str]):
"""
Lookup the existing instance of :class:`ErrorHandler` and fetch the
registered handler for a specific type of exception.
@ -61,17 +65,26 @@ class ErrorHandler:
:return: Registered function if found ``None`` otherwise
"""
exception_class = type(exception)
if exception_class in self.cached_handlers:
return self.cached_handlers[exception_class]
for ancestor in type.mro(exception_class):
if ancestor in self.cached_handlers:
handler = self.cached_handlers[ancestor]
self.cached_handlers[exception_class] = handler
for name in (route_name, None):
exception_key = (exception_class, name)
handler = self.cached_handlers.get(exception_key)
if handler:
return handler
if ancestor is BaseException:
break
self.cached_handlers[exception_class] = None
for name in (route_name, None):
for ancestor in type.mro(exception_class):
exception_key = (ancestor, name)
if exception_key in self.cached_handlers:
handler = self.cached_handlers[exception_key]
self.cached_handlers[
(exception_class, route_name)
] = handler
return handler
if ancestor is BaseException:
break
self.cached_handlers[(exception_class, route_name)] = None
handler = None
return handler
@ -89,7 +102,8 @@ class ErrorHandler:
:return: Wrap the return value obtained from :func:`default`
or registered handler for that type of exception.
"""
handler = self.lookup(exception)
route_name = request.name if request else None
handler = self.lookup(exception, route_name)
response = None
try:
if handler:

View File

@ -189,18 +189,24 @@ def test_exception_handler_lookup():
handler.add(CustomError, custom_error_handler)
handler.add(ServerError, server_error_handler)
assert handler.lookup(ImportError()) == import_error_handler
assert handler.lookup(ModuleNotFoundError()) == import_error_handler
assert handler.lookup(CustomError()) == custom_error_handler
assert handler.lookup(ServerError("Error")) == server_error_handler
assert handler.lookup(CustomServerError("Error")) == server_error_handler
assert handler.lookup(ImportError(), None) == import_error_handler
assert handler.lookup(ModuleNotFoundError(), None) == import_error_handler
assert handler.lookup(CustomError(), None) == custom_error_handler
assert handler.lookup(ServerError("Error"), None) == server_error_handler
assert (
handler.lookup(CustomServerError("Error"), None)
== server_error_handler
)
# once again to ensure there is no caching bug
assert handler.lookup(ImportError()) == import_error_handler
assert handler.lookup(ModuleNotFoundError()) == import_error_handler
assert handler.lookup(CustomError()) == custom_error_handler
assert handler.lookup(ServerError("Error")) == server_error_handler
assert handler.lookup(CustomServerError("Error")) == server_error_handler
assert handler.lookup(ImportError(), None) == import_error_handler
assert handler.lookup(ModuleNotFoundError(), None) == import_error_handler
assert handler.lookup(CustomError(), None) == custom_error_handler
assert handler.lookup(ServerError("Error"), None) == server_error_handler
assert (
handler.lookup(CustomServerError("Error"), None)
== server_error_handler
)
def test_exception_handler_processed_request_middleware():