Add #379 again and make related test rework

Original PR: https://github.com/channelcat/sanic/pull/379
This commit is contained in:
Jeong YunWon
2017-03-06 23:00:04 +09:00
parent 875790e862
commit 429e90183b
2 changed files with 65 additions and 5 deletions

View File

@@ -16,9 +16,12 @@ from sanic.response import text, html
class ErrorHandler: class ErrorHandler:
handlers = None handlers = None
cached_handlers = None
_missing = object()
def __init__(self): def __init__(self):
self.handlers = {} self.handlers = []
self.cached_handlers = {}
self.debug = False self.debug = False
def _render_traceback_html(self, exception, request): def _render_traceback_html(self, exception, request):
@@ -37,7 +40,18 @@ class ErrorHandler:
path=request.path) path=request.path)
def add(self, exception, handler): def add(self, exception, handler):
self.handlers[exception] = handler self.handlers.append((exception, handler))
def lookup(self, exception):
handler = self.cached_handlers.get(exception, self._missing)
if handler is self._missing:
for exception_class, handler in self.handlers:
if isinstance(exception, exception_class):
self.cached_handlers[type(exception)] = handler
return handler
self.cached_handlers[type(exception)] = None
handler = None
return handler
def response(self, request, exception): def response(self, request, exception):
"""Fetches and executes an exception handler and returns a response """Fetches and executes an exception handler and returns a response
@@ -47,9 +61,13 @@ class ErrorHandler:
:param exception: Exception to handle :param exception: Exception to handle
:return: Response object :return: Response object
""" """
handler = self.handlers.get(type(exception), self.default) handler = self.lookup(exception)
response = None
try: try:
if handler:
response = handler(request=request, exception=exception) response = handler(request=request, exception=exception)
if response is None:
response = self.default(request=request, exception=exception)
except Exception: except Exception:
self.log(format_exc()) self.log(format_exc())
if self.debug: if self.debug:

View File

@@ -1,6 +1,7 @@
from sanic import Sanic from sanic import Sanic
from sanic.response import text from sanic.response import text
from sanic.exceptions import InvalidUsage, ServerError, NotFound from sanic.exceptions import InvalidUsage, ServerError, NotFound
from sanic.handlers import ErrorHandler
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
exception_handler_app = Sanic('test_exception_handler') exception_handler_app = Sanic('test_exception_handler')
@@ -30,7 +31,7 @@ def handler_4(request):
@exception_handler_app.route('/5') @exception_handler_app.route('/5')
def handler_5(request): def handler_5(request):
class CustomServerError(ServerError): class CustomServerError(ServerError):
status_code=200 pass
raise CustomServerError('Custom server error') raise CustomServerError('Custom server error')
@@ -81,3 +82,44 @@ def test_html_traceback_output_in_debug_mode():
def test_inherited_exception_handler(): def test_inherited_exception_handler():
request, response = exception_handler_app.test_client.get('/5') request, response = exception_handler_app.test_client.get('/5')
assert response.status == 200 assert response.status == 200
def test_exception_handler_lookup():
class CustomError(Exception):
pass
class CustomServerError(ServerError):
pass
def custom_error_handler():
pass
def server_error_handler():
pass
def import_error_handler():
pass
try:
ModuleNotFoundError
except:
class ModuleNotFoundError(ImportError):
pass
handler = ErrorHandler()
handler.add(ImportError, import_error_handler)
handler.add(CustomError, custom_error_handler)
handler.add(ServerError, server_error_handler)
assert handler.lookup(ImportError()) == import_error_handler
assert handler.lookup(ModuleNotFoundError()) == import_error_handler
assert handler.lookup(CustomError()) == custom_error_handler
assert handler.lookup(ServerError('Error')) == server_error_handler
assert handler.lookup(CustomServerError('Error')) == server_error_handler
# once again to ensure there is no caching bug
assert handler.lookup(ImportError()) == import_error_handler
assert handler.lookup(ModuleNotFoundError()) == import_error_handler
assert handler.lookup(CustomError()) == custom_error_handler
assert handler.lookup(ServerError('Error')) == server_error_handler
assert handler.lookup(CustomServerError('Error')) == server_error_handler