diff --git a/sanic/app.py b/sanic/app.py index 92326b16..04383587 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -334,7 +334,11 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): self.named_response_middleware[_rn].appendleft(middleware) return middleware - def _apply_exception_handler(self, handler: FutureException): + def _apply_exception_handler( + self, + handler: FutureException, + route_names: Optional[List[str]] = None, + ): """Decorate a function to be registered as a handler for exceptions :param exceptions: exceptions @@ -344,9 +348,9 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): for exception in handler.exceptions: if isinstance(exception, (tuple, list)): for e in exception: - self.error_handler.add(e, handler.handler) + self.error_handler.add(e, handler.handler, route_names) else: - self.error_handler.add(exception, handler.handler) + self.error_handler.add(exception, handler.handler, route_names) return handler.handler def _apply_listener(self, listener: FutureListener): diff --git a/sanic/blueprints.py b/sanic/blueprints.py index dfc8240c..31809acf 100644 --- a/sanic/blueprints.py +++ b/sanic/blueprints.py @@ -338,7 +338,9 @@ class Blueprint(BaseSanic): # Exceptions for future in self._future_exceptions: - exception_handlers.append(app._apply_exception_handler(future)) + exception_handlers.append( + app._apply_exception_handler(future, route_names) + ) # Event listeners for listener in self._future_listeners: diff --git a/sanic/handlers.py b/sanic/handlers.py index 1210f251..e33aff76 100644 --- a/sanic/handlers.py +++ b/sanic/handlers.py @@ -1,3 +1,5 @@ +from typing import List, Optional + from sanic.errorpages import exception_response from sanic.exceptions import ( ContentRangeError, @@ -21,15 +23,12 @@ class ErrorHandler: """ - handlers = None - cached_handlers = None - def __init__(self): self.handlers = [] self.cached_handlers = {} self.debug = False - def add(self, exception, handler): + def add(self, exception, handler, route_names: Optional[List[str]] = None): """ Add a new exception handler to an already existing handler object. @@ -42,11 +41,16 @@ class ErrorHandler: :return: None """ - # self.handlers to be deprecated and removed in version 21.12 + # self.handlers is deprecated and will be removed in version 22.3 self.handlers.append((exception, handler)) - self.cached_handlers[exception] = handler - def lookup(self, exception): + if route_names: + for route in route_names: + self.cached_handlers[(exception, route)] = handler + else: + self.cached_handlers[(exception, None)] = handler + + def lookup(self, exception, route_name: Optional[str]): """ Lookup the existing instance of :class:`ErrorHandler` and fetch the registered handler for a specific type of exception. @@ -61,17 +65,26 @@ class ErrorHandler: :return: Registered function if found ``None`` otherwise """ exception_class = type(exception) - if exception_class in self.cached_handlers: - return self.cached_handlers[exception_class] - for ancestor in type.mro(exception_class): - if ancestor in self.cached_handlers: - handler = self.cached_handlers[ancestor] - self.cached_handlers[exception_class] = handler + for name in (route_name, None): + exception_key = (exception_class, name) + handler = self.cached_handlers.get(exception_key) + if handler: return handler - if ancestor is BaseException: - break - self.cached_handlers[exception_class] = None + + for name in (route_name, None): + for ancestor in type.mro(exception_class): + exception_key = (ancestor, name) + if exception_key in self.cached_handlers: + handler = self.cached_handlers[exception_key] + self.cached_handlers[ + (exception_class, route_name) + ] = handler + return handler + + if ancestor is BaseException: + break + self.cached_handlers[(exception_class, route_name)] = None handler = None return handler @@ -89,7 +102,8 @@ class ErrorHandler: :return: Wrap the return value obtained from :func:`default` or registered handler for that type of exception. """ - handler = self.lookup(exception) + route_name = request.name if request else None + handler = self.lookup(exception, route_name) response = None try: if handler: diff --git a/tests/test_exceptions_handler.py b/tests/test_exceptions_handler.py index 44beb278..ba14102a 100644 --- a/tests/test_exceptions_handler.py +++ b/tests/test_exceptions_handler.py @@ -189,18 +189,24 @@ def test_exception_handler_lookup(): handler.add(CustomError, custom_error_handler) handler.add(ServerError, server_error_handler) - assert handler.lookup(ImportError()) == import_error_handler - assert handler.lookup(ModuleNotFoundError()) == import_error_handler - assert handler.lookup(CustomError()) == custom_error_handler - assert handler.lookup(ServerError("Error")) == server_error_handler - assert handler.lookup(CustomServerError("Error")) == server_error_handler + assert handler.lookup(ImportError(), None) == import_error_handler + assert handler.lookup(ModuleNotFoundError(), None) == import_error_handler + assert handler.lookup(CustomError(), None) == custom_error_handler + assert handler.lookup(ServerError("Error"), None) == server_error_handler + assert ( + handler.lookup(CustomServerError("Error"), None) + == server_error_handler + ) # once again to ensure there is no caching bug - assert handler.lookup(ImportError()) == import_error_handler - assert handler.lookup(ModuleNotFoundError()) == import_error_handler - assert handler.lookup(CustomError()) == custom_error_handler - assert handler.lookup(ServerError("Error")) == server_error_handler - assert handler.lookup(CustomServerError("Error")) == server_error_handler + assert handler.lookup(ImportError(), None) == import_error_handler + assert handler.lookup(ModuleNotFoundError(), None) == import_error_handler + assert handler.lookup(CustomError(), None) == custom_error_handler + assert handler.lookup(ServerError("Error"), None) == server_error_handler + assert ( + handler.lookup(CustomServerError("Error"), None) + == server_error_handler + ) def test_exception_handler_processed_request_middleware():