diff --git a/docs/sanic/streaming.md b/docs/sanic/streaming.md index a785322a..e6305aa8 100644 --- a/docs/sanic/streaming.md +++ b/docs/sanic/streaming.md @@ -41,6 +41,18 @@ async def handler(request): return stream(streaming) +@app.post('/cancel_stream', stream=True) +async def cancel_stream(request): + request.cancel_stream() + result = '' + while True: + body = await request.stream.get() + if body is None: + break + result += body.decode('utf-8') + return text(result) + + @bp.put('/bp_stream', stream=True) async def bp_handler(request): result = '' diff --git a/examples/request_stream/server.py b/examples/request_stream/server.py index 37acfd5d..5a324684 100644 --- a/examples/request_stream/server.py +++ b/examples/request_stream/server.py @@ -9,6 +9,18 @@ bp = Blueprint('blueprint_request_stream') app = Sanic('request_stream') +@app.post('/cancel_stream', stream=True) +async def cancel_stream(request): + request.cancel_stream() + result = '' + while True: + body = await request.stream.get() + if body is None: + break + result += body.decode('utf-8') + return text(result) + + class SimpleView(HTTPMethodView): @stream_decorator diff --git a/sanic/request.py b/sanic/request.py index f3de36f8..6cee9583 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -45,7 +45,7 @@ class Request(dict): __slots__ = ( 'app', 'headers', 'version', 'method', '_cookies', 'transport', 'body', 'parsed_json', 'parsed_args', 'parsed_form', 'parsed_files', - '_ip', '_parsed_url', 'uri_template', 'stream' + '_ip', '_parsed_url', 'uri_template', 'stream', 'streamable' ) def __init__(self, url_bytes, headers, version, method, transport): @@ -67,6 +67,7 @@ class Request(dict): self.uri_template = None self._cookies = None self.stream = None + self.streamable = True @property def json(self): @@ -195,6 +196,9 @@ class Request(dict): self.query_string, None)) + def cancel_stream(self): + self.streamable = False + File = namedtuple('File', ['type', 'body', 'name']) diff --git a/sanic/server.py b/sanic/server.py index f3106226..370e285f 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -69,7 +69,7 @@ class HttpProtocol(asyncio.Protocol): 'has_log', # connection management '_total_request_size', '_timeout_handler', '_last_communication_time', - '_is_stream_handler') + '_is_stream_handler', '_is_cancel_stream') def __init__(self, *, loop, request_handler, error_handler, signal=Signal(), connections=set(), request_timeout=60, @@ -93,6 +93,7 @@ class HttpProtocol(asyncio.Protocol): self.request_class = request_class or Request self.is_request_stream = is_request_stream self._is_stream_handler = False + self._is_cancel_stream = False self._total_request_size = 0 self._timeout_handler = None self._last_request_time = None @@ -185,11 +186,20 @@ class HttpProtocol(asyncio.Protocol): if self._is_stream_handler: self.request.stream = asyncio.Queue() self.execute_request_handler() + if not self.request.streamable: + self._is_cancel_stream = True + self.transport.close() def on_body(self, body): if self.is_request_stream and self._is_stream_handler: - self._request_stream_task = self.loop.create_task( - self.request.stream.put(body)) + if self.request.streamable: + self._request_stream_task = self.loop.create_task( + self.request.stream.put(body)) + elif self._is_cancel_stream is True: + return + else: + self._is_cancel_stream = True + self.transport.close() return self.request.body.append(body) @@ -346,6 +356,7 @@ class HttpProtocol(asyncio.Protocol): self._request_stream_task = None self._total_request_size = 0 self._is_stream_handler = False + self._is_cancel_stream = False def close_if_idle(self): """Close the connection if a request is not being sent or received diff --git a/tests/test_request_stream.py b/tests/test_request_stream.py index 0ceccf60..c04009aa 100644 --- a/tests/test_request_stream.py +++ b/tests/test_request_stream.py @@ -6,7 +6,7 @@ from sanic.views import HTTPMethodView from sanic.views import stream as stream_decorator from sanic.response import stream, text -data = "abc" * 100000 +data = 'abc' * 100000 def test_request_stream_method_view(): @@ -49,6 +49,21 @@ def test_request_stream_app(): app = Sanic('test_request_stream_app') + contents = [] + + @app.post('/cancel_stream', stream=True) + async def cancel_stream(request): + assert isinstance(request.stream, asyncio.Queue) + request.cancel_stream() + result = '' + while True: + body = await request.stream.get() + contents.append(body) + if body is None: + break + result += body.decode('utf-8') + return text(result) + @app.get('/get') async def get(request): assert request.stream is None @@ -122,6 +137,8 @@ def test_request_stream_app(): assert app.is_request_stream is True + assert len(contents) == 0 + request, response = app.test_client.get('/get') assert response.status == 200 assert response.text == 'GET'