diff --git a/sanic/exceptions.py b/sanic/exceptions.py index d986cd08..dd4beee8 100644 --- a/sanic/exceptions.py +++ b/sanic/exceptions.py @@ -139,9 +139,12 @@ class PayloadTooLarge(SanicException): class Handler: handlers = None + cached_handlers = None + _missing = object() def __init__(self): - self.handlers = {} + self.handlers = [] + self.cached_handlers = {} self.debug = False def _render_traceback_html(self, exception, request): @@ -160,7 +163,18 @@ class Handler: uri=request.url) def add(self, exception, handler): - self.handlers[exception] = handler + self.handlers.append((exception, handler)) + + def lookup(self, exception): + handler = self.cached_handlers.get(exception, self._missing) + if handler is self._missing: + for exception_class, handler in self.handlers: + if isinstance(exception, exception_class): + self.cached_handlers[type(exception)] = handler + return handler + self.cached_handlers[type(exception)] = None + handler = None + return handler def response(self, request, exception): """ @@ -170,9 +184,12 @@ class Handler: :param exception: Exception to handle :return: Response object """ - handler = self.handlers.get(type(exception), self.default) + handler = self.lookup(exception) try: - response = handler(request=request, exception=exception) + response = handler and handler( + request=request, exception=exception) + if response is None: + response = self.default(request=request, exception=exception) except: log.error(format_exc()) if self.debug: diff --git a/tests/test_exceptions_handler.py b/tests/test_exceptions_handler.py index c56713b6..d11f7380 100644 --- a/tests/test_exceptions_handler.py +++ b/tests/test_exceptions_handler.py @@ -28,6 +28,13 @@ def handler_4(request): return text(foo) +@exception_handler_app.route('/5') +def handler_5(request): + class CustomServerError(ServerError): + pass + raise CustomServerError('Custom server error') + + @exception_handler_app.exception(NotFound, ServerError) def handler_exception(request, exception): return text("OK") @@ -71,3 +78,8 @@ def test_html_traceback_output_in_debug_mode(): assert ( "NameError: name 'bar' " "is not defined while handling uri /4") == summary_text + + +def test_inherited_exception_handler(): + request, response = sanic_endpoint_test(exception_handler_app, uri='/5') + assert response.status == 200