pass request_buffer_queue_size argument to HttpProtocol (#1717)
* pass request_buffer_queue_size argument to HttpProtocol * fix to use simultaneously only one task to put body to stream buffer * add a test code for REQUEST_BUFFER_QUEUE_SIZE
This commit is contained in:
parent
ed1f367a8a
commit
ecbe5c839f
|
@ -62,6 +62,10 @@ class StreamBuffer:
|
|||
def is_full(self):
|
||||
return self._queue.full()
|
||||
|
||||
@property
|
||||
def buffer_size(self):
|
||||
return self._queue.maxsize
|
||||
|
||||
|
||||
class Request:
|
||||
"""Properties of an HTTP request such as URL, headers, etc."""
|
||||
|
|
|
@ -2,6 +2,7 @@ import asyncio
|
|||
import os
|
||||
import traceback
|
||||
|
||||
from collections import deque
|
||||
from functools import partial
|
||||
from inspect import isawaitable
|
||||
from multiprocessing import Process
|
||||
|
@ -148,6 +149,7 @@ class HttpProtocol(asyncio.Protocol):
|
|||
self.state["requests_count"] = 0
|
||||
self._debug = debug
|
||||
self._not_paused.set()
|
||||
self._body_chunks = deque()
|
||||
|
||||
@property
|
||||
def keep_alive(self):
|
||||
|
@ -347,19 +349,30 @@ class HttpProtocol(asyncio.Protocol):
|
|||
|
||||
def on_body(self, body):
|
||||
if self.is_request_stream and self._is_stream_handler:
|
||||
self._request_stream_task = self.loop.create_task(
|
||||
self.body_append(body)
|
||||
)
|
||||
# body chunks can be put into asyncio.Queue out of order if
|
||||
# multiple tasks put concurrently and the queue is full in python
|
||||
# 3.7. so we should not create more than one task putting into the
|
||||
# queue simultaneously.
|
||||
self._body_chunks.append(body)
|
||||
if (
|
||||
not self._request_stream_task
|
||||
or self._request_stream_task.done()
|
||||
):
|
||||
self._request_stream_task = self.loop.create_task(
|
||||
self.stream_append()
|
||||
)
|
||||
else:
|
||||
self.request.body_push(body)
|
||||
|
||||
async def body_append(self, body):
|
||||
if self.request.stream.is_full():
|
||||
self.transport.pause_reading()
|
||||
await self.request.stream.put(body)
|
||||
self.transport.resume_reading()
|
||||
else:
|
||||
await self.request.stream.put(body)
|
||||
async def stream_append(self):
|
||||
while self._body_chunks:
|
||||
body = self._body_chunks.popleft()
|
||||
if self.request.stream.is_full():
|
||||
self.transport.pause_reading()
|
||||
await self.request.stream.put(body)
|
||||
self.transport.resume_reading()
|
||||
else:
|
||||
await self.request.stream.put(body)
|
||||
|
||||
def on_message_complete(self):
|
||||
# Entire request (headers and whole body) is received.
|
||||
|
@ -368,9 +381,14 @@ class HttpProtocol(asyncio.Protocol):
|
|||
self._request_timeout_handler.cancel()
|
||||
self._request_timeout_handler = None
|
||||
if self.is_request_stream and self._is_stream_handler:
|
||||
self._request_stream_task = self.loop.create_task(
|
||||
self.request.stream.put(None)
|
||||
)
|
||||
self._body_chunks.append(None)
|
||||
if (
|
||||
not self._request_stream_task
|
||||
or self._request_stream_task.done()
|
||||
):
|
||||
self._request_stream_task = self.loop.create_task(
|
||||
self.stream_append()
|
||||
)
|
||||
return
|
||||
self.request.body_finish()
|
||||
self.execute_request_handler()
|
||||
|
@ -818,6 +836,7 @@ def serve(
|
|||
response_timeout=response_timeout,
|
||||
keep_alive_timeout=keep_alive_timeout,
|
||||
request_max_size=request_max_size,
|
||||
request_buffer_queue_size=request_buffer_queue_size,
|
||||
request_class=request_class,
|
||||
access_log=access_log,
|
||||
keep_alive=keep_alive,
|
||||
|
|
36
tests/test_request_buffer_queue_size.py
Normal file
36
tests/test_request_buffer_queue_size.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
import io
|
||||
|
||||
from sanic.response import text
|
||||
|
||||
data = "abc" * 10_000_000
|
||||
|
||||
|
||||
def test_request_buffer_queue_size(app):
|
||||
default_buf_qsz = app.config.get("REQUEST_BUFFER_QUEUE_SIZE")
|
||||
qsz = 1
|
||||
while qsz == default_buf_qsz:
|
||||
qsz += 1
|
||||
app.config.REQUEST_BUFFER_QUEUE_SIZE = qsz
|
||||
|
||||
@app.post("/post", stream=True)
|
||||
async def post(request):
|
||||
assert request.stream.buffer_size == qsz
|
||||
print("request.stream.buffer_size =", request.stream.buffer_size)
|
||||
|
||||
bio = io.BytesIO()
|
||||
while True:
|
||||
bdata = await request.stream.read()
|
||||
if not bdata:
|
||||
break
|
||||
bio.write(bdata)
|
||||
|
||||
head = bdata[:3].decode("utf-8")
|
||||
tail = bdata[3:][-3:].decode("utf-8")
|
||||
print(head, "...", tail)
|
||||
|
||||
bio.seek(0)
|
||||
return text(bio.read().decode("utf-8"))
|
||||
|
||||
request, response = app.test_client.post("/post", data=data)
|
||||
assert response.status == 200
|
||||
assert response.text == data
|
Loading…
Reference in New Issue
Block a user