diff --git a/examples/request_timeout.py b/examples/request_timeout.py new file mode 100644 index 00000000..261f423a --- /dev/null +++ b/examples/request_timeout.py @@ -0,0 +1,21 @@ +from sanic import Sanic +import asyncio +from sanic.response import text +from sanic.config import Config +from sanic.exceptions import RequestTimeout + +Config.REQUEST_TIMEOUT = 1 +app = Sanic(__name__) + + +@app.route('/') +async def test(request): + await asyncio.sleep(3) + return text('Hello, world!') + + +@app.exception(RequestTimeout) +def timeout(request, exception): + return text('RequestTimeout from error_handler.', 408) + +app.run(host='0.0.0.0', port=8000) diff --git a/sanic/exceptions.py b/sanic/exceptions.py index e21aca63..bc052fbd 100644 --- a/sanic/exceptions.py +++ b/sanic/exceptions.py @@ -30,6 +30,10 @@ class FileNotFound(NotFound): self.relative_url = relative_url +class RequestTimeout(SanicException): + status_code = 408 + + class Handler: handlers = None diff --git a/sanic/sanic.py b/sanic/sanic.py index 33e16af7..98bb230d 100644 --- a/sanic/sanic.py +++ b/sanic/sanic.py @@ -263,6 +263,7 @@ class Sanic: 'sock': sock, 'debug': debug, 'request_handler': self.handle_request, + 'error_handler': self.error_handler, 'request_timeout': self.config.REQUEST_TIMEOUT, 'request_max_size': self.config.REQUEST_MAX_SIZE, 'loop': loop diff --git a/sanic/server.py b/sanic/server.py index 6301d18f..a3074ecf 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -14,6 +14,7 @@ except ImportError: from .log import log from .request import Request +from .exceptions import RequestTimeout class Signal: @@ -34,8 +35,8 @@ class HttpProtocol(asyncio.Protocol): # connection management '_total_request_size', '_timeout_handler', '_last_communication_time') - def __init__(self, *, loop, request_handler, signal=Signal(), - connections={}, request_timeout=60, + def __init__(self, *, loop, request_handler, error_handler, + signal=Signal(), connections={}, request_timeout=60, request_max_size=None): self.loop = loop self.transport = None @@ -46,11 +47,13 @@ class HttpProtocol(asyncio.Protocol): self.signal = signal self.connections = connections self.request_handler = request_handler + self.error_handler = error_handler self.request_timeout = request_timeout self.request_max_size = request_max_size self._total_request_size = 0 self._timeout_handler = None self._last_request_time = None + self._request_handler_task = None # -------------------------------------------- # # Connection @@ -76,7 +79,11 @@ class HttpProtocol(asyncio.Protocol): self._timeout_handler = \ self.loop.call_later(time_left, self.connection_timeout) else: - self.bail_out("Request timed out, connection closed") + if self._request_handler_task: + self._request_handler_task.cancel() + response = self.error_handler.response( + self.request, RequestTimeout('Request Timeout')) + self.write_response(response) # -------------------------------------------- # # Parsing @@ -133,7 +140,7 @@ class HttpProtocol(asyncio.Protocol): self.request.body = body def on_message_complete(self): - self.loop.create_task( + self._request_handler_task = self.loop.create_task( self.request_handler(self.request, self.write_response)) # -------------------------------------------- # @@ -166,6 +173,7 @@ class HttpProtocol(asyncio.Protocol): self.request = None self.url = None self.headers = None + self._request_handler_task = None self._total_request_size = 0 def close_if_idle(self): @@ -205,8 +213,8 @@ def trigger_events(events, loop): loop.run_until_complete(result) -def serve(host, port, request_handler, before_start=None, after_start=None, - before_stop=None, after_stop=None, +def serve(host, port, request_handler, error_handler, before_start=None, + after_start=None, before_stop=None, after_stop=None, debug=False, request_timeout=60, sock=None, request_max_size=None, reuse_port=False, loop=None): """ @@ -241,6 +249,7 @@ def serve(host, port, request_handler, before_start=None, after_start=None, connections=connections, signal=signal, request_handler=request_handler, + error_handler=error_handler, request_timeout=request_timeout, request_max_size=request_max_size, ), host, port, reuse_port=reuse_port, sock=sock) diff --git a/tests/test_request_timeout.py b/tests/test_request_timeout.py new file mode 100644 index 00000000..7b8cfb21 --- /dev/null +++ b/tests/test_request_timeout.py @@ -0,0 +1,40 @@ +from sanic import Sanic +import asyncio +from sanic.response import text +from sanic.exceptions import RequestTimeout +from sanic.utils import sanic_endpoint_test +from sanic.config import Config + +Config.REQUEST_TIMEOUT = 1 +request_timeout_app = Sanic('test_request_timeout') +request_timeout_default_app = Sanic('test_request_timeout_default') + + +@request_timeout_app.route('/1') +async def handler_1(request): + await asyncio.sleep(1) + return text('OK') + + +@request_timeout_app.exception(RequestTimeout) +def handler_exception(request, exception): + return text('Request Timeout from error_handler.', 408) + + +def test_server_error_request_timeout(): + request, response = sanic_endpoint_test(request_timeout_app, uri='/1') + assert response.status == 408 + assert response.text == 'Request Timeout from error_handler.' + + +@request_timeout_default_app.route('/1') +async def handler_2(request): + await asyncio.sleep(1) + return text('OK') + + +def test_default_server_error_request_timeout(): + request, response = sanic_endpoint_test( + request_timeout_default_app, uri='/1') + assert response.status == 408 + assert response.text == 'Error: Request Timeout'