diff --git a/sanic/handlers.py b/sanic/handlers.py index 0b1ee89c..64df2c2c 100644 --- a/sanic/handlers.py +++ b/sanic/handlers.py @@ -16,9 +16,12 @@ from sanic.response import text, html class ErrorHandler: handlers = None + cached_handlers = None + _missing = object() def __init__(self): - self.handlers = {} + self.handlers = [] + self.cached_handlers = {} self.debug = False def _render_traceback_html(self, exception, request): @@ -37,7 +40,18 @@ class ErrorHandler: path=request.path) def add(self, exception, handler): - self.handlers[exception] = handler + self.handlers.append((exception, handler)) + + def lookup(self, exception): + handler = self.cached_handlers.get(exception, self._missing) + if handler is self._missing: + for exception_class, handler in self.handlers: + if isinstance(exception, exception_class): + self.cached_handlers[type(exception)] = handler + return handler + self.cached_handlers[type(exception)] = None + handler = None + return handler def response(self, request, exception): """Fetches and executes an exception handler and returns a response @@ -47,9 +61,13 @@ class ErrorHandler: :param exception: Exception to handle :return: Response object """ - handler = self.handlers.get(type(exception), self.default) + handler = self.lookup(exception) + response = None try: - response = handler(request=request, exception=exception) + if handler: + response = handler(request=request, exception=exception) + if response is None: + response = self.default(request=request, exception=exception) except Exception: self.log(format_exc()) if self.debug: diff --git a/tests/test_exceptions_handler.py b/tests/test_exceptions_handler.py index aca3345d..da8a74e9 100644 --- a/tests/test_exceptions_handler.py +++ b/tests/test_exceptions_handler.py @@ -1,6 +1,7 @@ from sanic import Sanic from sanic.response import text from sanic.exceptions import InvalidUsage, ServerError, NotFound +from sanic.handlers import ErrorHandler from bs4 import BeautifulSoup exception_handler_app = Sanic('test_exception_handler') @@ -30,7 +31,7 @@ def handler_4(request): @exception_handler_app.route('/5') def handler_5(request): class CustomServerError(ServerError): - status_code=200 + pass raise CustomServerError('Custom server error') @@ -81,3 +82,44 @@ def test_html_traceback_output_in_debug_mode(): def test_inherited_exception_handler(): request, response = exception_handler_app.test_client.get('/5') assert response.status == 200 + + +def test_exception_handler_lookup(): + class CustomError(Exception): + pass + + class CustomServerError(ServerError): + pass + + def custom_error_handler(): + pass + + def server_error_handler(): + pass + + def import_error_handler(): + pass + + try: + ModuleNotFoundError + except: + class ModuleNotFoundError(ImportError): + pass + + handler = ErrorHandler() + handler.add(ImportError, import_error_handler) + handler.add(CustomError, custom_error_handler) + handler.add(ServerError, server_error_handler) + + assert handler.lookup(ImportError()) == import_error_handler + assert handler.lookup(ModuleNotFoundError()) == import_error_handler + assert handler.lookup(CustomError()) == custom_error_handler + assert handler.lookup(ServerError('Error')) == server_error_handler + assert handler.lookup(CustomServerError('Error')) == server_error_handler + + # once again to ensure there is no caching bug + assert handler.lookup(ImportError()) == import_error_handler + assert handler.lookup(ModuleNotFoundError()) == import_error_handler + assert handler.lookup(CustomError()) == custom_error_handler + assert handler.lookup(ServerError('Error')) == server_error_handler + assert handler.lookup(CustomServerError('Error')) == server_error_handler