From ecbe5c839f8a67775255c0b94c81cfba441ffdd1 Mon Sep 17 00:00:00 2001 From: Junyeong Jeong Date: Fri, 22 Nov 2019 00:33:50 +0900 Subject: [PATCH] pass request_buffer_queue_size argument to HttpProtocol (#1717) * pass request_buffer_queue_size argument to HttpProtocol * fix to use simultaneously only one task to put body to stream buffer * add a test code for REQUEST_BUFFER_QUEUE_SIZE --- sanic/request.py | 4 +++ sanic/server.py | 45 ++++++++++++++++++------- tests/test_request_buffer_queue_size.py | 36 ++++++++++++++++++++ 3 files changed, 72 insertions(+), 13 deletions(-) create mode 100644 tests/test_request_buffer_queue_size.py diff --git a/sanic/request.py b/sanic/request.py index 9712aeb3..246eb351 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -62,6 +62,10 @@ class StreamBuffer: def is_full(self): return self._queue.full() + @property + def buffer_size(self): + return self._queue.maxsize + class Request: """Properties of an HTTP request such as URL, headers, etc.""" diff --git a/sanic/server.py b/sanic/server.py index fa38e435..41af81c0 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -2,6 +2,7 @@ import asyncio import os import traceback +from collections import deque from functools import partial from inspect import isawaitable from multiprocessing import Process @@ -148,6 +149,7 @@ class HttpProtocol(asyncio.Protocol): self.state["requests_count"] = 0 self._debug = debug self._not_paused.set() + self._body_chunks = deque() @property def keep_alive(self): @@ -347,19 +349,30 @@ class HttpProtocol(asyncio.Protocol): def on_body(self, body): if self.is_request_stream and self._is_stream_handler: - self._request_stream_task = self.loop.create_task( - self.body_append(body) - ) + # Body chunks can be put into asyncio.Queue out of order if + # multiple tasks put concurrently and the queue is full in Python + # 3.7, so we should not create more than one task putting into the + # queue simultaneously. 
+ self._body_chunks.append(body) + if ( + not self._request_stream_task + or self._request_stream_task.done() + ): + self._request_stream_task = self.loop.create_task( + self.stream_append() + ) else: self.request.body_push(body) - async def body_append(self, body): - if self.request.stream.is_full(): - self.transport.pause_reading() - await self.request.stream.put(body) - self.transport.resume_reading() - else: - await self.request.stream.put(body) + async def stream_append(self): + while self._body_chunks: + body = self._body_chunks.popleft() + if self.request.stream.is_full(): + self.transport.pause_reading() + await self.request.stream.put(body) + self.transport.resume_reading() + else: + await self.request.stream.put(body) def on_message_complete(self): # Entire request (headers and whole body) is received. @@ -368,9 +381,14 @@ class HttpProtocol(asyncio.Protocol): self._request_timeout_handler.cancel() self._request_timeout_handler = None if self.is_request_stream and self._is_stream_handler: - self._request_stream_task = self.loop.create_task( - self.request.stream.put(None) - ) + self._body_chunks.append(None) + if ( + not self._request_stream_task + or self._request_stream_task.done() + ): + self._request_stream_task = self.loop.create_task( + self.stream_append() + ) return self.request.body_finish() self.execute_request_handler() @@ -818,6 +836,7 @@ def serve( response_timeout=response_timeout, keep_alive_timeout=keep_alive_timeout, request_max_size=request_max_size, + request_buffer_queue_size=request_buffer_queue_size, request_class=request_class, access_log=access_log, keep_alive=keep_alive, diff --git a/tests/test_request_buffer_queue_size.py b/tests/test_request_buffer_queue_size.py new file mode 100644 index 00000000..1a9cfdf3 --- /dev/null +++ b/tests/test_request_buffer_queue_size.py @@ -0,0 +1,36 @@ +import io + +from sanic.response import text + +data = "abc" * 10_000_000 + + +def test_request_buffer_queue_size(app): + default_buf_qsz = 
app.config.get("REQUEST_BUFFER_QUEUE_SIZE") + qsz = 1 + while qsz == default_buf_qsz: + qsz += 1 + app.config.REQUEST_BUFFER_QUEUE_SIZE = qsz + + @app.post("/post", stream=True) + async def post(request): + assert request.stream.buffer_size == qsz + print("request.stream.buffer_size =", request.stream.buffer_size) + + bio = io.BytesIO() + while True: + bdata = await request.stream.read() + if not bdata: + break + bio.write(bdata) + + head = bdata[:3].decode("utf-8") + tail = bdata[3:][-3:].decode("utf-8") + print(head, "...", tail) + + bio.seek(0) + return text(bio.read().decode("utf-8")) + + request, response = app.test_client.post("/post", data=data) + assert response.status == 200 + assert response.text == data