Cleanup of code and avoid mixing streaming responses.

This commit is contained in:
L. Kärkkäinen 2020-02-21 17:51:38 +02:00
parent fe64a2764d
commit f609a4850f
3 changed files with 77 additions and 98 deletions

View File

@ -47,11 +47,14 @@ class RequestParameters(dict):
class StreamBuffer: class StreamBuffer:
def __init__(self, buffer_size=100): def __init__(self, buffer_size=100, protocol=None):
self._queue = asyncio.Queue(buffer_size) self._queue = asyncio.Queue(buffer_size)
self._protocol = protocol
async def read(self): async def read(self):
""" Stop reading when gets None """ """ Stop reading when gets None """
if self._protocol:
self._protocol.flow_control()
payload = await self._queue.get() payload = await self._queue.get()
self._queue.task_done() self._queue.task_done()
return payload return payload

View File

@ -82,13 +82,12 @@ class HttpProtocol(asyncio.Protocol):
"_last_response_time", "_last_response_time",
"_is_stream_handler", "_is_stream_handler",
"_not_paused", "_not_paused",
"_request_handler_task",
"_request_stream_task",
"_keep_alive", "_keep_alive",
"_header_fragment", "_header_fragment",
"state", "state",
"_debug", "_debug",
"_body_chunks", "_handler_queue",
"_handler_task",
) )
def __init__( def __init__(
@ -112,13 +111,13 @@ class HttpProtocol(asyncio.Protocol):
router=None, router=None,
state=None, state=None,
debug=False, debug=False,
**kwargs **kwargs,
): ):
self.loop = loop self.loop = loop
self.app = app self.app = app
self.transport = None self.transport = None
self.request = None self.request = None
self.parser = None self.parser = HttpRequestParser(self)
self.url = None self.url = None
self.headers = None self.headers = None
self.router = router self.router = router
@ -145,8 +144,6 @@ class HttpProtocol(asyncio.Protocol):
self._keep_alive_timeout_handler = None self._keep_alive_timeout_handler = None
self._last_request_time = None self._last_request_time = None
self._last_response_time = None self._last_response_time = None
self._request_handler_task = None
self._request_stream_task = None
self._keep_alive = keep_alive self._keep_alive = keep_alive
self._header_fragment = b"" self._header_fragment = b""
self.state = state if state else {} self.state = state if state else {}
@ -154,7 +151,38 @@ class HttpProtocol(asyncio.Protocol):
self.state["requests_count"] = 0 self.state["requests_count"] = 0
self._debug = debug self._debug = debug
self._not_paused.set() self._not_paused.set()
self._body_chunks = deque() self._handler_queue = asyncio.Queue(1)
self._handler_task = self.loop.create_task(self.run_handlers())
async def run_handlers(self):
"""Runs one handler at a time, in correct order.
Otherwise we cannot control incoming requests, this keeps their
responses in order.
Handlers are inserted into the queue when headers are received. Body
may follow later and is handled by a separate queue in req.stream.
"""
while True:
try:
self.flow_control()
handler = await self._handler_queue.get()
await handler
except Exception as e:
traceback.print_exc()
self.transport.close()
raise
def flow_control(self):
"""Backpressure handling to avoid excessive buffering."""
if not self.transport:
return
if self._handler_queue.full() or (
self.request and self.request.stream.is_full()
):
self.transport.pause_reading()
else:
self.transport.resume_reading()
@property @property
def keep_alive(self): def keep_alive(self):
@ -185,10 +213,8 @@ class HttpProtocol(asyncio.Protocol):
def connection_lost(self, exc): def connection_lost(self, exc):
self.connections.discard(self) self.connections.discard(self)
if self._request_handler_task: if self._handler_task:
self._request_handler_task.cancel() self._handler_task.cancel()
if self._request_stream_task:
self._request_stream_task.cancel()
if self._request_timeout_handler: if self._request_timeout_handler:
self._request_timeout_handler.cancel() self._request_timeout_handler.cancel()
if self._response_timeout_handler: if self._response_timeout_handler:
@ -214,10 +240,8 @@ class HttpProtocol(asyncio.Protocol):
time_left, self.request_timeout_callback time_left, self.request_timeout_callback
) )
else: else:
if self._request_stream_task: if self._handler_task:
self._request_stream_task.cancel() self._handler_task.cancel()
if self._request_handler_task:
self._request_handler_task.cancel()
self.write_error(RequestTimeout("Request Timeout")) self.write_error(RequestTimeout("Request Timeout"))
def response_timeout_callback(self): def response_timeout_callback(self):
@ -230,10 +254,8 @@ class HttpProtocol(asyncio.Protocol):
time_left, self.response_timeout_callback time_left, self.response_timeout_callback
) )
else: else:
if self._request_stream_task: if self._handler_task:
self._request_stream_task.cancel() self._handler_task.cancel()
if self._request_handler_task:
self._request_handler_task.cancel()
self.write_error(ServiceUnavailable("Response Timeout")) self.write_error(ServiceUnavailable("Response Timeout"))
def keep_alive_timeout_callback(self): def keep_alive_timeout_callback(self):
@ -266,15 +288,6 @@ class HttpProtocol(asyncio.Protocol):
if self._total_request_size > self.request_max_size: if self._total_request_size > self.request_max_size:
self.write_error(PayloadTooLarge("Payload Too Large")) self.write_error(PayloadTooLarge("Payload Too Large"))
# Create parser if this is the first time we're receiving data
if self.parser is None:
assert self.request is None
self.headers = []
self.parser = HttpRequestParser(self)
# requests count
self.state["requests_count"] = self.state["requests_count"] + 1
# Parse request chunk or close connection # Parse request chunk or close connection
try: try:
self.parser.feed_data(data) self.parser.feed_data(data)
@ -284,6 +297,13 @@ class HttpProtocol(asyncio.Protocol):
message += "\n" + traceback.format_exc() message += "\n" + traceback.format_exc()
self.write_error(InvalidUsage(message)) self.write_error(InvalidUsage(message))
def on_message_begin(self):
assert self.request is None
self.headers = []
# requests count
self.state["requests_count"] += 1
def on_url(self, url): def on_url(self, url):
if not self.url: if not self.url:
self.url = url self.url = url
@ -327,8 +347,9 @@ class HttpProtocol(asyncio.Protocol):
if self.request.headers.get(EXPECT_HEADER): if self.request.headers.get(EXPECT_HEADER):
self.expect_handler() self.expect_handler()
if self.is_request_stream: self.request.stream = StreamBuffer(
self.request.stream = StreamBuffer(self.request_buffer_queue_size) self.request_buffer_queue_size, protocol=self,
)
self.execute_request_handler() self.execute_request_handler()
def expect_handler(self): def expect_handler(self):
@ -347,44 +368,10 @@ class HttpProtocol(asyncio.Protocol):
) )
def on_body(self, body): def on_body(self, body):
# body chunks can be put into asyncio.Queue out of order if # Send request body chunks
# multiple tasks put concurrently and the queue is full in python if body:
# 3.7. so we should not create more than one task putting into the self.request.stream._queue.put_nowait(body)
# queue simultaneously. self.flow_control()
self._body_chunks.append(body)
if (
not self._request_stream_task
or self._request_stream_task.done()
):
self._request_stream_task = self.loop.create_task(
self.stream_append()
)
async def body_append(self, body):
if (
self.request is None
or self._request_stream_task is None
or self._request_stream_task.cancelled()
):
return
if self.request.stream.is_full():
self.transport.pause_reading()
await self.request.stream.put(body)
self.transport.resume_reading()
else:
await self.request.stream.put(body)
async def stream_append(self):
while self._body_chunks and self.request:
body = self._body_chunks.popleft()
if self.request.stream.is_full():
self.transport.pause_reading()
await self.request.stream.put(body)
self.transport.resume_reading()
else:
await self.request.stream.put(body)
self._body_chunks.clear()
def on_message_complete(self): def on_message_complete(self):
# Entire request (headers and whole body) is received. # Entire request (headers and whole body) is received.
@ -393,14 +380,9 @@ class HttpProtocol(asyncio.Protocol):
self._request_timeout_handler.cancel() self._request_timeout_handler.cancel()
self._request_timeout_handler = None self._request_timeout_handler = None
self._body_chunks.append(None) # Mark the end of request body
if ( self.request.stream._queue.put_nowait(None)
not self._request_stream_task self.request = None
or self._request_stream_task.done()
):
self._request_stream_task = self.loop.create_task(
self.stream_append()
)
def execute_request_handler(self): def execute_request_handler(self):
""" """
@ -409,15 +391,17 @@ class HttpProtocol(asyncio.Protocol):
:return: None :return: None
""" """
# Invoked when request headers have been received
self._response_timeout_handler = self.loop.call_later( self._response_timeout_handler = self.loop.call_later(
self.response_timeout, self.response_timeout_callback self.response_timeout, self.response_timeout_callback
) )
self._last_request_time = time() self._last_request_time = time()
self._request_handler_task = self.loop.create_task( self._handler_queue.put_nowait(
self.request_handler( self.request_handler(
self.request, self.write_response, self.stream_response self.request, self.write_response, self.stream_response
) )
) )
self.flow_control()
# -------------------------------------------- # # -------------------------------------------- #
# Responding # Responding
@ -467,17 +451,17 @@ class HttpProtocol(asyncio.Protocol):
try: try:
keep_alive = self.keep_alive keep_alive = self.keep_alive
self.transport.write( self.transport.write(
response.output( response.output("1.1", keep_alive, self.keep_alive_timeout)
self.request.version, keep_alive, self.keep_alive_timeout
)
) )
self.log_response(response) self.log_response(response)
except AttributeError: except AttributeError:
if isinstance(response, HTTPResponse):
raise
url = self.url.decode()
res_type = type(response).__name__
logger.error( logger.error(
"Invalid response object for url %s, " f"Invalid response object for url {url!r}, "
"Expected Type: HTTPResponse, Actual Type: %s", f"Expected Type: HTTPResponse, Actual Type: {res_type}"
self.url,
type(response),
) )
self.write_error(ServerError("Invalid response type")) self.write_error(ServerError("Invalid response type"))
except RuntimeError: except RuntimeError:
@ -521,9 +505,7 @@ class HttpProtocol(asyncio.Protocol):
try: try:
keep_alive = self.keep_alive keep_alive = self.keep_alive
response.protocol = self response.protocol = self
await response.stream( await response.stream("1.1", keep_alive, self.keep_alive_timeout)
self.request.version, keep_alive, self.keep_alive_timeout
)
self.log_response(response) self.log_response(response)
except AttributeError: except AttributeError:
logger.error( logger.error(
@ -564,8 +546,7 @@ class HttpProtocol(asyncio.Protocol):
response = None response = None
try: try:
response = self.error_handler.response(self.request, exception) response = self.error_handler.response(self.request, exception)
version = self.request.version if self.request else "1.1" self.transport.write(response.output("1.1"))
self.transport.write(response.output(version))
except RuntimeError: except RuntimeError:
if self._debug: if self._debug:
logger.error( logger.error(
@ -621,12 +602,8 @@ class HttpProtocol(asyncio.Protocol):
"""This is called when KeepAlive feature is used, """This is called when KeepAlive feature is used,
it resets the connection in order for it to be able it resets the connection in order for it to be able
to handle receiving another request on the same connection.""" to handle receiving another request on the same connection."""
self.parser = None
self.request = None
self.url = None self.url = None
self.headers = None self.headers = None
self._request_handler_task = None
self._request_stream_task = None
self._total_request_size = 0 self._total_request_size = 0
def close_if_idle(self): def close_if_idle(self):
@ -888,7 +865,7 @@ def serve(
reuse_port=reuse_port, reuse_port=reuse_port,
sock=sock, sock=sock,
backlog=backlog, backlog=backlog,
**asyncio_server_kwargs **asyncio_server_kwargs,
) )
if run_async: if run_async:

View File

@ -73,7 +73,6 @@ class SanicTestClient:
): ):
results = [None, None] results = [None, None]
exceptions = [] exceptions = []
if gather_request: if gather_request:
def _collect_request(request): def _collect_request(request):