From f609a4850fa4a46923d42cebf87ab262d62cf374 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Fri, 21 Feb 2020 17:51:38 +0200 Subject: [PATCH] Cleanup of code and avoid mixing streaming responses. --- sanic/request.py | 5 +- sanic/server.py | 169 ++++++++++++++++++++--------------------------- sanic/testing.py | 1 - 3 files changed, 77 insertions(+), 98 deletions(-) diff --git a/sanic/request.py b/sanic/request.py index 6cb1b73f..ebe43a4f 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -47,11 +47,14 @@ class RequestParameters(dict): class StreamBuffer: - def __init__(self, buffer_size=100): + def __init__(self, buffer_size=100, protocol=None): self._queue = asyncio.Queue(buffer_size) + self._protocol = protocol async def read(self): """ Stop reading when gets None """ + if self._protocol: + self._protocol.flow_control() payload = await self._queue.get() self._queue.task_done() return payload diff --git a/sanic/server.py b/sanic/server.py index 97242555..c2df8056 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -82,13 +82,12 @@ class HttpProtocol(asyncio.Protocol): "_last_response_time", "_is_stream_handler", "_not_paused", - "_request_handler_task", - "_request_stream_task", "_keep_alive", "_header_fragment", "state", "_debug", - "_body_chunks", + "_handler_queue", + "_handler_task", ) def __init__( @@ -112,13 +111,13 @@ class HttpProtocol(asyncio.Protocol): router=None, state=None, debug=False, - **kwargs + **kwargs, ): self.loop = loop self.app = app self.transport = None self.request = None - self.parser = None + self.parser = HttpRequestParser(self) self.url = None self.headers = None self.router = router @@ -145,8 +144,6 @@ class HttpProtocol(asyncio.Protocol): self._keep_alive_timeout_handler = None self._last_request_time = None self._last_response_time = None - self._request_handler_task = None - self._request_stream_task = None self._keep_alive = keep_alive self._header_fragment = b"" self.state = state if state else {} @@ -154,7 +151,38 @@ class HttpProtocol(asyncio.Protocol): self.state["requests_count"] = 0 self._debug = debug self._not_paused.set() - self._body_chunks = deque() + self._handler_queue = asyncio.Queue(1) + self._handler_task = self.loop.create_task(self.run_handlers()) + + async def run_handlers(self): + """Runs one handler at a time, in correct order. + + Otherwise we cannot control incoming requests, this keeps their + responses in order. + + Handlers are inserted into the queue when headers are received. Body + may follow later and is handled by a separate queue in req.stream. + """ + while True: + try: + self.flow_control() + handler = await self._handler_queue.get() + await handler + except Exception as e: + traceback.print_exc() + self.transport.close() + raise + + def flow_control(self): + """Backpressure handling to avoid excessive buffering.""" + if not self.transport: + return + if self._handler_queue.full() or ( + self.request and self.request.stream.is_full() + ): + self.transport.pause_reading() + else: + self.transport.resume_reading() @property def keep_alive(self): @@ -185,10 +213,8 @@ class HttpProtocol(asyncio.Protocol): def connection_lost(self, exc): self.connections.discard(self) - if self._request_handler_task: - self._request_handler_task.cancel() - if self._request_stream_task: - self._request_stream_task.cancel() + if self._handler_task: + self._handler_task.cancel() if self._request_timeout_handler: self._request_timeout_handler.cancel() if self._response_timeout_handler: @@ -214,10 +240,8 @@ class HttpProtocol(asyncio.Protocol): time_left, self.request_timeout_callback ) else: - if self._request_stream_task: - self._request_stream_task.cancel() - if self._request_handler_task: - self._request_handler_task.cancel() + if self._handler_task: + self._handler_task.cancel() self.write_error(RequestTimeout("Request Timeout")) def response_timeout_callback(self): @@ -230,10 +254,8 @@ class HttpProtocol(asyncio.Protocol): time_left, self.response_timeout_callback ) else: - if self._request_stream_task: - self._request_stream_task.cancel() - if self._request_handler_task: - self._request_handler_task.cancel() + if self._handler_task: + self._handler_task.cancel() self.write_error(ServiceUnavailable("Response Timeout")) def keep_alive_timeout_callback(self): @@ -266,15 +288,6 @@ class HttpProtocol(asyncio.Protocol): if self._total_request_size > self.request_max_size: self.write_error(PayloadTooLarge("Payload Too Large")) - # Create parser if this is the first time we're receiving data - if self.parser is None: - assert self.request is None - self.headers = [] - self.parser = HttpRequestParser(self) - - # requests count - self.state["requests_count"] = self.state["requests_count"] + 1 - # Parse request chunk or close connection try: self.parser.feed_data(data) @@ -284,6 +297,13 @@ class HttpProtocol(asyncio.Protocol): message += "\n" + traceback.format_exc() self.write_error(InvalidUsage(message)) + def on_message_begin(self): + assert self.request is None + self.headers = [] + + # requests count + self.state["requests_count"] += 1 + def on_url(self, url): if not self.url: self.url = url @@ -327,9 +347,10 @@ class HttpProtocol(asyncio.Protocol): if self.request.headers.get(EXPECT_HEADER): self.expect_handler() - if self.is_request_stream: - self.request.stream = StreamBuffer(self.request_buffer_queue_size) - self.execute_request_handler() + self.request.stream = StreamBuffer( + self.request_buffer_queue_size, protocol=self, + ) + self.execute_request_handler() def expect_handler(self): """ @@ -347,44 +368,10 @@ class HttpProtocol(asyncio.Protocol): ) def on_body(self, body): - # body chunks can be put into asyncio.Queue out of order if - # multiple tasks put concurrently and the queue is full in python - # 3.7. so we should not create more than one task putting into the - # queue simultaneously. - self._body_chunks.append(body) - if ( - not self._request_stream_task - or self._request_stream_task.done() - ): - self._request_stream_task = self.loop.create_task( - self.stream_append() - ) - - async def body_append(self, body): - if ( - self.request is None - or self._request_stream_task is None - or self._request_stream_task.cancelled() - ): - return - - if self.request.stream.is_full(): - self.transport.pause_reading() - await self.request.stream.put(body) - self.transport.resume_reading() - else: - await self.request.stream.put(body) - - async def stream_append(self): - while self._body_chunks and self.request: - body = self._body_chunks.popleft() - if self.request.stream.is_full(): - self.transport.pause_reading() - await self.request.stream.put(body) - self.transport.resume_reading() - else: - await self.request.stream.put(body) - self._body_chunks.clear() + # Send request body chunks + if body: + self.request.stream._queue.put_nowait(body) + self.flow_control() def on_message_complete(self): # Entire request (headers and whole body) is received. @@ -393,14 +380,9 @@ class HttpProtocol(asyncio.Protocol): self._request_timeout_handler.cancel() self._request_timeout_handler = None - self._body_chunks.append(None) - if ( - not self._request_stream_task - or self._request_stream_task.done() - ): - self._request_stream_task = self.loop.create_task( - self.stream_append() - ) + # Mark the end of request body + self.request.stream._queue.put_nowait(None) + self.request = None def execute_request_handler(self): """ @@ -409,15 +391,17 @@ class HttpProtocol(asyncio.Protocol): :return: None """ + # Invoked when request headers have been received self._response_timeout_handler = self.loop.call_later( self.response_timeout, self.response_timeout_callback ) self._last_request_time = time() - self._request_handler_task = self.loop.create_task( + self._handler_queue.put_nowait( self.request_handler( self.request, self.write_response, self.stream_response ) ) + self.flow_control() # -------------------------------------------- # # Responding @@ -467,17 +451,17 @@ class HttpProtocol(asyncio.Protocol): try: keep_alive = self.keep_alive self.transport.write( - response.output( - self.request.version, keep_alive, self.keep_alive_timeout - ) + response.output("1.1", keep_alive, self.keep_alive_timeout) ) self.log_response(response) except AttributeError: + if isinstance(response, HTTPResponse): + raise + url = self.url.decode() + res_type = type(response).__name__ logger.error( - "Invalid response object for url %s, " - "Expected Type: HTTPResponse, Actual Type: %s", - self.url, - type(response), + f"Invalid response object for url {url!r}, " + f"Expected Type: HTTPResponse, Actual Type: {res_type}" ) self.write_error(ServerError("Invalid response type")) except RuntimeError: @@ -521,9 +505,7 @@ class HttpProtocol(asyncio.Protocol): try: keep_alive = self.keep_alive response.protocol = self - await response.stream( - self.request.version, keep_alive, self.keep_alive_timeout - ) + await response.stream("1.1", keep_alive, self.keep_alive_timeout) self.log_response(response) except AttributeError: logger.error( @@ -564,8 +546,7 @@ class HttpProtocol(asyncio.Protocol): response = None try: response = self.error_handler.response(self.request, exception) - version = self.request.version if self.request else "1.1" - self.transport.write(response.output(version)) + self.transport.write(response.output("1.1")) except RuntimeError: if self._debug: logger.error( @@ -621,12 +602,8 @@ class HttpProtocol(asyncio.Protocol): """This is called when KeepAlive feature is used, it resets the connection in order for it to be able to handle receiving another request on the same connection.""" - self.parser = None - self.request = None self.url = None self.headers = None - self._request_handler_task = None - self._request_stream_task = None self._total_request_size = 0 def close_if_idle(self): @@ -888,7 +865,7 @@ def serve( reuse_port=reuse_port, sock=sock, backlog=backlog, - **asyncio_server_kwargs + **asyncio_server_kwargs, ) if run_async: diff --git a/sanic/testing.py b/sanic/testing.py index c836a943..f4b297da 100644 --- a/sanic/testing.py +++ b/sanic/testing.py @@ -73,7 +73,6 @@ class SanicTestClient: ): results = [None, None] exceptions = [] - if gather_request: def _collect_request(request):