diff --git a/docs/sanic/streaming.md b/docs/sanic/streaming.md index a785322a..e9380338 100644 --- a/docs/sanic/streaming.md +++ b/docs/sanic/streaming.md @@ -2,7 +2,7 @@ ## Request Streaming -Sanic allows you to get request data by stream, as below. When the request ends, `request.stream.get()` returns `None`. Only post, put and patch decorator have stream argument. +Sanic allows you to get request data by stream, as below. When the request ends, `request.stream.get()` returns `None`. In order to do flow controll, calling `request.stream.task_done()` right after processing is required. Only post, put and patch decorator have stream argument. ```python from sanic import Sanic @@ -26,6 +26,7 @@ class SimpleView(HTTPMethodView): if body is None: break result += body.decode('utf-8') + request.stream.task_done() return text(result) @@ -38,6 +39,7 @@ async def handler(request): break body = body.decode('utf-8').replace('1', 'A') response.write(body) + request.stream.task_done() return stream(streaming) @@ -49,6 +51,7 @@ async def bp_handler(request): if body is None: break result += body.decode('utf-8').replace('1', 'A') + request.stream.task_done() return text(result) @@ -59,6 +62,7 @@ async def post_handler(request): if body is None: break result += body.decode('utf-8') + request.stream.task_done() return text(result) app.blueprint(bp) diff --git a/sanic/server.py b/sanic/server.py index 15ae4708..fe0bb53d 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -78,6 +78,7 @@ class HttpProtocol(asyncio.Protocol): def __init__(self, *, loop, request_handler, error_handler, signal=Signal(), connections=set(), request_timeout=60, response_timeout=60, keep_alive_timeout=5, + request_max_queue_size=20, request_max_size=None, request_class=None, access_log=True, keep_alive=True, is_request_stream=False, router=None, state=None, debug=False, **kwargs): @@ -96,9 +97,11 @@ class HttpProtocol(asyncio.Protocol): self.request_timeout = request_timeout self.response_timeout = response_timeout self.keep_alive_timeout = keep_alive_timeout + self.request_max_queue_size = request_max_queue_size self.request_max_size = request_max_size self.request_class = request_class or Request self.is_request_stream = is_request_stream + self._paused = False self._is_stream_handler = False self._total_request_size = 0 self._request_timeout_handler = None @@ -230,6 +233,18 @@ class HttpProtocol(asyncio.Protocol): exception = InvalidUsage(message) self.write_error(exception) + if self.is_request_stream and not self._paused and \ + self.request is not None and self.request.stream: + if self.request.stream.qsize() > self.request_max_queue_size: + self.transport.pause_reading() + self._paused = True + self.loop.create_task(self.resume_reading()) + + async def resume_reading(self): + await self.request.stream.join() + self.transport.resume_reading() + self._paused = False + def on_url(self, url): if not self.url: self.url = url diff --git a/tests/test_request_stream.py b/tests/test_request_stream.py index 4ca4e44e..a046b947 100644 --- a/tests/test_request_stream.py +++ b/tests/test_request_stream.py @@ -29,6 +29,7 @@ def test_request_stream_method_view(): if body is None: break result += body.decode('utf-8') + request.stream.task_done() return text(result) app.add_route(SimpleView.as_view(), '/method_view') @@ -84,6 +85,7 @@ def test_request_stream_app(): if body is None: break response.write(body.decode('utf-8')) + request.stream.task_done() return stream(streaming) @app.put('/_put') @@ -101,6 +103,7 @@ def test_request_stream_app(): if body is None: break response.write(body.decode('utf-8')) + request.stream.task_done() return stream(streaming) @app.patch('/_patch') @@ -118,6 +121,7 @@ def test_request_stream_app(): if body is None: break response.write(body.decode('utf-8')) + request.stream.task_done() return stream(streaming) assert app.is_request_stream is True @@ -178,6 +182,7 @@ def test_request_stream_handle_exception(): if body is None: break response.write(body.decode('utf-8')) + request.stream.task_done() return stream(streaming) # 404 @@ -232,6 +237,7 @@ def test_request_stream_blueprint(): if body is None: break response.write(body.decode('utf-8')) + request.stream.task_done() return stream(streaming) @bp.put('/_put') @@ -249,6 +255,7 @@ def test_request_stream_blueprint(): if body is None: break response.write(body.decode('utf-8')) + request.stream.task_done() return stream(streaming) @bp.patch('/_patch') @@ -266,6 +273,7 @@ def test_request_stream_blueprint(): if body is None: break response.write(body.decode('utf-8')) + request.stream.task_done() return stream(streaming) app.blueprint(bp) @@ -330,6 +338,7 @@ def test_request_stream_composition_view(): if body is None: break result += body.decode('utf-8') + request.stream.task_done() return text(result) view = CompositionView() @@ -369,6 +378,7 @@ def test_request_stream(): if body is None: break result += body.decode('utf-8') + request.stream.task_done() return text(result) @app.post('/stream', stream=True) @@ -381,6 +391,7 @@ def test_request_stream(): if body is None: break response.write(body.decode('utf-8')) + request.stream.task_done() return stream(streaming) @app.get('/get') @@ -397,6 +408,7 @@ def test_request_stream(): if body is None: break result += body.decode('utf-8') + request.stream.task_done() return text(result) @bp.get('/bp_get') @@ -416,6 +428,7 @@ def test_request_stream(): if body is None: break result += body.decode('utf-8') + request.stream.task_done() return text(result) app.add_route(SimpleView.as_view(), '/method_view')