Cleanup of code and avoid mixing streaming responses.

This commit is contained in:
L. Kärkkäinen 2020-02-21 17:51:38 +02:00
parent fe64a2764d
commit f609a4850f
3 changed files with 77 additions and 98 deletions

View File

@ -47,11 +47,14 @@ class RequestParameters(dict):
class StreamBuffer:
def __init__(self, buffer_size=100):
def __init__(self, buffer_size=100, protocol=None):
self._queue = asyncio.Queue(buffer_size)
self._protocol = protocol
async def read(self):
""" Stop reading when gets None """
if self._protocol:
self._protocol.flow_control()
payload = await self._queue.get()
self._queue.task_done()
return payload

View File

@ -82,13 +82,12 @@ class HttpProtocol(asyncio.Protocol):
"_last_response_time",
"_is_stream_handler",
"_not_paused",
"_request_handler_task",
"_request_stream_task",
"_keep_alive",
"_header_fragment",
"state",
"_debug",
"_body_chunks",
"_handler_queue",
"_handler_task",
)
def __init__(
@ -112,13 +111,13 @@ class HttpProtocol(asyncio.Protocol):
router=None,
state=None,
debug=False,
**kwargs
**kwargs,
):
self.loop = loop
self.app = app
self.transport = None
self.request = None
self.parser = None
self.parser = HttpRequestParser(self)
self.url = None
self.headers = None
self.router = router
@ -145,8 +144,6 @@ class HttpProtocol(asyncio.Protocol):
self._keep_alive_timeout_handler = None
self._last_request_time = None
self._last_response_time = None
self._request_handler_task = None
self._request_stream_task = None
self._keep_alive = keep_alive
self._header_fragment = b""
self.state = state if state else {}
@ -154,7 +151,38 @@ class HttpProtocol(asyncio.Protocol):
self.state["requests_count"] = 0
self._debug = debug
self._not_paused.set()
self._body_chunks = deque()
self._handler_queue = asyncio.Queue(1)
self._handler_task = self.loop.create_task(self.run_handlers())
async def run_handlers(self):
"""Runs one handler at a time, in correct order.
Otherwise we cannot control incoming requests, this keeps their
responses in order.
Handlers are inserted into the queue when headers are received. Body
may follow later and is handled by a separate queue in req.stream.
"""
while True:
try:
self.flow_control()
handler = await self._handler_queue.get()
await handler
except Exception as e:
traceback.print_exc()
self.transport.close()
raise
def flow_control(self):
"""Backpressure handling to avoid excessive buffering."""
if not self.transport:
return
if self._handler_queue.full() or (
self.request and self.request.stream.is_full()
):
self.transport.pause_reading()
else:
self.transport.resume_reading()
@property
def keep_alive(self):
@ -185,10 +213,8 @@ class HttpProtocol(asyncio.Protocol):
def connection_lost(self, exc):
self.connections.discard(self)
if self._request_handler_task:
self._request_handler_task.cancel()
if self._request_stream_task:
self._request_stream_task.cancel()
if self._handler_task:
self._handler_task.cancel()
if self._request_timeout_handler:
self._request_timeout_handler.cancel()
if self._response_timeout_handler:
@ -214,10 +240,8 @@ class HttpProtocol(asyncio.Protocol):
time_left, self.request_timeout_callback
)
else:
if self._request_stream_task:
self._request_stream_task.cancel()
if self._request_handler_task:
self._request_handler_task.cancel()
if self._handler_task:
self._handler_task.cancel()
self.write_error(RequestTimeout("Request Timeout"))
def response_timeout_callback(self):
@ -230,10 +254,8 @@ class HttpProtocol(asyncio.Protocol):
time_left, self.response_timeout_callback
)
else:
if self._request_stream_task:
self._request_stream_task.cancel()
if self._request_handler_task:
self._request_handler_task.cancel()
if self._handler_task:
self._handler_task.cancel()
self.write_error(ServiceUnavailable("Response Timeout"))
def keep_alive_timeout_callback(self):
@ -266,15 +288,6 @@ class HttpProtocol(asyncio.Protocol):
if self._total_request_size > self.request_max_size:
self.write_error(PayloadTooLarge("Payload Too Large"))
# Create parser if this is the first time we're receiving data
if self.parser is None:
assert self.request is None
self.headers = []
self.parser = HttpRequestParser(self)
# requests count
self.state["requests_count"] = self.state["requests_count"] + 1
# Parse request chunk or close connection
try:
self.parser.feed_data(data)
@ -284,6 +297,13 @@ class HttpProtocol(asyncio.Protocol):
message += "\n" + traceback.format_exc()
self.write_error(InvalidUsage(message))
def on_message_begin(self):
assert self.request is None
self.headers = []
# requests count
self.state["requests_count"] += 1
def on_url(self, url):
if not self.url:
self.url = url
@ -327,9 +347,10 @@ class HttpProtocol(asyncio.Protocol):
if self.request.headers.get(EXPECT_HEADER):
self.expect_handler()
if self.is_request_stream:
self.request.stream = StreamBuffer(self.request_buffer_queue_size)
self.execute_request_handler()
self.request.stream = StreamBuffer(
self.request_buffer_queue_size, protocol=self,
)
self.execute_request_handler()
def expect_handler(self):
"""
@ -347,44 +368,10 @@ class HttpProtocol(asyncio.Protocol):
)
def on_body(self, body):
# body chunks can be put into asyncio.Queue out of order if
# multiple tasks put concurrently and the queue is full in python
# 3.7. so we should not create more than one task putting into the
# queue simultaneously.
self._body_chunks.append(body)
if (
not self._request_stream_task
or self._request_stream_task.done()
):
self._request_stream_task = self.loop.create_task(
self.stream_append()
)
async def body_append(self, body):
if (
self.request is None
or self._request_stream_task is None
or self._request_stream_task.cancelled()
):
return
if self.request.stream.is_full():
self.transport.pause_reading()
await self.request.stream.put(body)
self.transport.resume_reading()
else:
await self.request.stream.put(body)
async def stream_append(self):
while self._body_chunks and self.request:
body = self._body_chunks.popleft()
if self.request.stream.is_full():
self.transport.pause_reading()
await self.request.stream.put(body)
self.transport.resume_reading()
else:
await self.request.stream.put(body)
self._body_chunks.clear()
# Send request body chunks
if body:
self.request.stream._queue.put_nowait(body)
self.flow_control()
def on_message_complete(self):
# Entire request (headers and whole body) is received.
@ -393,14 +380,9 @@ class HttpProtocol(asyncio.Protocol):
self._request_timeout_handler.cancel()
self._request_timeout_handler = None
self._body_chunks.append(None)
if (
not self._request_stream_task
or self._request_stream_task.done()
):
self._request_stream_task = self.loop.create_task(
self.stream_append()
)
# Mark the end of request body
self.request.stream._queue.put_nowait(None)
self.request = None
def execute_request_handler(self):
"""
@ -409,15 +391,17 @@ class HttpProtocol(asyncio.Protocol):
:return: None
"""
# Invoked when request headers have been received
self._response_timeout_handler = self.loop.call_later(
self.response_timeout, self.response_timeout_callback
)
self._last_request_time = time()
self._request_handler_task = self.loop.create_task(
self._handler_queue.put_nowait(
self.request_handler(
self.request, self.write_response, self.stream_response
)
)
self.flow_control()
# -------------------------------------------- #
# Responding
@ -467,17 +451,17 @@ class HttpProtocol(asyncio.Protocol):
try:
keep_alive = self.keep_alive
self.transport.write(
response.output(
self.request.version, keep_alive, self.keep_alive_timeout
)
response.output("1.1", keep_alive, self.keep_alive_timeout)
)
self.log_response(response)
except AttributeError:
if isinstance(response, HTTPResponse):
raise
url = self.url.decode()
res_type = type(response).__name__
logger.error(
"Invalid response object for url %s, "
"Expected Type: HTTPResponse, Actual Type: %s",
self.url,
type(response),
f"Invalid response object for url {url!r}, "
f"Expected Type: HTTPResponse, Actual Type: {res_type}"
)
self.write_error(ServerError("Invalid response type"))
except RuntimeError:
@ -521,9 +505,7 @@ class HttpProtocol(asyncio.Protocol):
try:
keep_alive = self.keep_alive
response.protocol = self
await response.stream(
self.request.version, keep_alive, self.keep_alive_timeout
)
await response.stream("1.1", keep_alive, self.keep_alive_timeout)
self.log_response(response)
except AttributeError:
logger.error(
@ -564,8 +546,7 @@ class HttpProtocol(asyncio.Protocol):
response = None
try:
response = self.error_handler.response(self.request, exception)
version = self.request.version if self.request else "1.1"
self.transport.write(response.output(version))
self.transport.write(response.output("1.1"))
except RuntimeError:
if self._debug:
logger.error(
@ -621,12 +602,8 @@ class HttpProtocol(asyncio.Protocol):
"""This is called when KeepAlive feature is used,
it resets the connection in order for it to be able
to handle receiving another request on the same connection."""
self.parser = None
self.request = None
self.url = None
self.headers = None
self._request_handler_task = None
self._request_stream_task = None
self._total_request_size = 0
def close_if_idle(self):
@ -888,7 +865,7 @@ def serve(
reuse_port=reuse_port,
sock=sock,
backlog=backlog,
**asyncio_server_kwargs
**asyncio_server_kwargs,
)
if run_async:

View File

@ -73,7 +73,6 @@ class SanicTestClient:
):
results = [None, None]
exceptions = []
if gather_request:
def _collect_request(request):