diff --git a/sanic/exceptions.py b/sanic/exceptions.py index bc052fbd..369a87a2 100644 --- a/sanic/exceptions.py +++ b/sanic/exceptions.py @@ -34,6 +34,10 @@ class RequestTimeout(SanicException): status_code = 408 +class PayloadTooLarge(SanicException): + status_code = 413 + + class Handler: handlers = None diff --git a/sanic/server.py b/sanic/server.py index a3074ecf..534436fa 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -14,7 +14,7 @@ except ImportError: from .log import log from .request import Request -from .exceptions import RequestTimeout +from .exceptions import RequestTimeout, PayloadTooLarge class Signal: @@ -81,9 +81,8 @@ class HttpProtocol(asyncio.Protocol): else: if self._request_handler_task: self._request_handler_task.cancel() - response = self.error_handler.response( - self.request, RequestTimeout('Request Timeout')) - self.write_response(response) + exception = RequestTimeout('Request Timeout') + self.write_error(exception) # -------------------------------------------- # # Parsing @@ -94,9 +93,8 @@ class HttpProtocol(asyncio.Protocol): # memory limits self._total_request_size += len(data) if self._total_request_size > self.request_max_size: - return self.bail_out( - "Request too large ({}), connection closed".format( - self._total_request_size)) + exception = PayloadTooLarge('Payload Too Large') + self.write_error(exception) # Create parser if this is the first time we're receiving data if self.parser is None: @@ -116,8 +114,8 @@ class HttpProtocol(asyncio.Protocol): def on_header(self, name, value): if name == b'Content-Length' and int(value) > self.request_max_size: - return self.bail_out( - "Request body too large ({}), connection closed".format(value)) + exception = PayloadTooLarge('Payload Too Large') + self.write_error(exception) self.headers.append((name.decode(), value.decode('utf-8'))) @@ -164,6 +162,16 @@ class HttpProtocol(asyncio.Protocol): self.bail_out( "Writing response failed, connection closed {}".format(e)) + def write_error(self, exception): + try: + response = self.error_handler.response(self.request, exception) + version = self.request.version if self.request else '1.1' + self.transport.write(response.output(version)) + self.transport.close() + except Exception as e: + self.bail_out( + "Writing error failed, connection closed {}".format(e)) + def bail_out(self, message): log.debug(message) self.transport.close() diff --git a/tests/test_payload_too_large.py b/tests/test_payload_too_large.py new file mode 100644 index 00000000..e8eec09e --- /dev/null +++ b/tests/test_payload_too_large.py @@ -0,0 +1,54 @@ +from sanic import Sanic +from sanic.response import text +from sanic.exceptions import PayloadTooLarge +from sanic.utils import sanic_endpoint_test + +data_received_app = Sanic('data_received') +data_received_app.config.REQUEST_MAX_SIZE = 1 +data_received_default_app = Sanic('data_received_default') +data_received_default_app.config.REQUEST_MAX_SIZE = 1 +on_header_default_app = Sanic('on_header') +on_header_default_app.config.REQUEST_MAX_SIZE = 500 + + +@data_received_app.route('/1') +async def handler1(request): + return text('OK') + + +@data_received_app.exception(PayloadTooLarge) +def handler_exception(request, exception): + return text('Payload Too Large from error_handler.', 413) + + +def test_payload_too_large_from_error_handler(): + response = sanic_endpoint_test( + data_received_app, uri='/1', gather_request=False) + assert response.status == 413 + assert response.text == 'Payload Too Large from error_handler.' + + +@data_received_default_app.route('/1') +async def handler2(request): + return text('OK') + + +def test_payload_too_large_at_data_received_default(): + response = sanic_endpoint_test( + data_received_default_app, uri='/1', gather_request=False) + assert response.status == 413 + assert response.text == 'Error: Payload Too Large' + + +@on_header_default_app.route('/1') +async def handler3(request): + return text('OK') + + +def test_payload_too_large_at_on_header_default(): + data = 'a' * 1000 + response = sanic_endpoint_test( + on_header_default_app, method='post', uri='/1', + gather_request=False, data=data) + assert response.status == 413 + assert response.text == 'Error: Payload Too Large'