diff --git a/sanic/app.py b/sanic/app.py index 5da7c925..1b9a48fc 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -953,8 +953,8 @@ class Sanic: handler, args, kwargs, uri, name = self.router.get(request) # Non-streaming handlers have their body preloaded - if not self.router.is_stream_handler(request): - await request.receive_body() + #if not self.router.is_stream_handler(request): + # await request.receive_body() # -------------------------------------------- # # Request Middleware diff --git a/sanic/server.py b/sanic/server.py index c2df8056..679d3bb0 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -1,4 +1,5 @@ import asyncio +import enum import os import sys import traceback @@ -10,6 +11,7 @@ from multiprocessing import Process from signal import SIG_IGN, SIGINT, SIGTERM, Signals from signal import signal as signal_func from socket import SO_REUSEADDR, SOL_SOCKET, socket +from time import monotonic as current_time from time import time from httptools import HttpRequestParser # type: ignore @@ -21,6 +23,7 @@ from sanic.exceptions import ( InvalidUsage, PayloadTooLarge, RequestTimeout, + SanicException, ServerError, ServiceUnavailable, ) @@ -42,6 +45,13 @@ class Signal: stopped = False +class Status(enum.Enum): + IDLE = 0 # Waiting for request + REQUEST = 1 # Request headers being received + EXPECT = 2 # Sender wants 100-continue + HANDLER = 3 # Headers done, handler running + RESPONSE = 4 # Response headers sent + class HttpProtocol(asyncio.Protocol): """ This class provides a basic HTTP implementation of the sanic framework. @@ -56,10 +66,7 @@ class HttpProtocol(asyncio.Protocol): "connections", "signal", # request params - "parser", "request", - "url", - "headers", # request config "request_handler", "request_timeout", @@ -68,26 +75,24 @@ class HttpProtocol(asyncio.Protocol): "request_max_size", "request_buffer_queue_size", "request_class", - "is_request_stream", "router", "error_handler", # enable or disable access log purpose "access_log", # connection management "_total_request_size", - "_request_timeout_handler", - "_response_timeout_handler", - "_keep_alive_timeout_handler", - "_last_request_time", + "_status", + "_time", "_last_response_time", - "_is_stream_handler", - "_not_paused", - "_keep_alive", - "_header_fragment", + "keep_alive", "state", + "url", "_debug", - "_handler_queue", "_handler_task", + "_buffer", + "_can_write", + "_data_received", + "_task", ) def __init__( @@ -113,13 +118,12 @@ class HttpProtocol(asyncio.Protocol): debug=False, **kwargs, ): + deprecated_loop = loop if sys.version_info < (3, 8) else None self.loop = loop self.app = app + self.url = None self.transport = None self.request = None - self.parser = HttpRequestParser(self) - self.url = None - self.headers = None self.router = router self.signal = signal self.access_log = access_log @@ -132,72 +136,16 @@ class HttpProtocol(asyncio.Protocol): self.keep_alive_timeout = keep_alive_timeout self.request_max_size = request_max_size self.request_class = request_class or Request - self.is_request_stream = is_request_stream - self._is_stream_handler = False - if sys.version_info.minor >= 8: - self._not_paused = asyncio.Event() - else: - self._not_paused = asyncio.Event(loop=loop) self._total_request_size = 0 - self._request_timeout_handler = None - self._response_timeout_handler = None - self._keep_alive_timeout_handler = None - self._last_request_time = None - self._last_response_time = None - self._keep_alive = keep_alive - self._header_fragment = b"" + self.keep_alive = keep_alive self.state = state if state else {} if "requests_count" not in self.state: self.state["requests_count"] = 0 self._debug = debug - self._not_paused.set() - self._handler_queue = asyncio.Queue(1) - self._handler_task = self.loop.create_task(self.run_handlers()) - - async def run_handlers(self): - """Runs one handler at a time, in correct order. - - Otherwise we cannot control incoming requests, this keeps their - responses in order. - - Handlers are inserted into the queue when headers are received. Body - may follow later and is handled by a separate queue in req.stream. - """ - while True: - try: - self.flow_control() - handler = await self._handler_queue.get() - await handler - except Exception as e: - traceback.print_exc() - self.transport.close() - raise - - def flow_control(self): - """Backpressure handling to avoid excessive buffering.""" - if not self.transport: - return - if self._handler_queue.full() or ( - self.request and self.request.stream.is_full() - ): - self.transport.pause_reading() - else: - self.transport.resume_reading() - - @property - def keep_alive(self): - """ - Check if the connection needs to be kept alive based on the params - attached to the `_keep_alive` attribute, :attr:`Signal.stopped` - and :func:`HttpProtocol.parser.should_keep_alive` - - :return: ``True`` if connection is to be kept alive ``False`` else - """ - return ( - self._keep_alive - and not self.signal.stopped - and self.parser.should_keep_alive() - ) + self._buffer = bytearray() + self._data_received = asyncio.Event(loop=deprecated_loop) + self._can_write = asyncio.Event(loop=deprecated_loop) + self._can_write.set() # -------------------------------------------- # # Connection @@ -205,152 +153,120 @@ class HttpProtocol(asyncio.Protocol): def connection_made(self, transport): self.connections.add(self) - self._request_timeout_handler = self.loop.call_later( - self.request_timeout, self.request_timeout_callback - ) self.transport = transport - self._last_request_time = time() + self._status, self._time = Status.IDLE, current_time() + self.check_timeouts() + self._task = self.loop.create_task(self.http1()) def connection_lost(self, exc): self.connections.discard(self) - if self._handler_task: - self._handler_task.cancel() - if self._request_timeout_handler: - self._request_timeout_handler.cancel() - if self._response_timeout_handler: - self._response_timeout_handler.cancel() - if self._keep_alive_timeout_handler: - self._keep_alive_timeout_handler.cancel() + if self._task: + self._task.cancel() def pause_writing(self): - self._not_paused.clear() + self._can_write.clear() def resume_writing(self): - self._not_paused.set() + self._can_write.set() - def request_timeout_callback(self): - # See the docstring in the RequestTimeout exception, to see - # exactly what this timeout is checking for. - # Check if elapsed time since request initiated exceeds our - # configured maximum request timeout value - time_elapsed = time() - self._last_request_time - if time_elapsed < self.request_timeout: - time_left = self.request_timeout - time_elapsed - self._request_timeout_handler = self.loop.call_later( - time_left, self.request_timeout_callback - ) - else: - if self._handler_task: - self._handler_task.cancel() - self.write_error(RequestTimeout("Request Timeout")) - - def response_timeout_callback(self): - # Check if elapsed time since response was initiated exceeds our - # configured maximum request timeout value - time_elapsed = time() - self._last_request_time - if time_elapsed < self.response_timeout: - time_left = self.response_timeout - time_elapsed - self._response_timeout_handler = self.loop.call_later( - time_left, self.response_timeout_callback - ) - else: - if self._handler_task: - self._handler_task.cancel() - self.write_error(ServiceUnavailable("Response Timeout")) - - def keep_alive_timeout_callback(self): - """ - Check if elapsed time since last response exceeds our configured - maximum keep alive timeout value and if so, close the transport - pipe and let the response writer handle the error. - - :return: None - """ - time_elapsed = time() - self._last_response_time - if time_elapsed < self.keep_alive_timeout: - time_left = self.keep_alive_timeout - time_elapsed - self._keep_alive_timeout_handler = self.loop.call_later( - time_left, self.keep_alive_timeout_callback - ) - else: - logger.debug("KeepAlive Timeout. Closing connection.") - self.transport.close() - self.transport = None + async def receive_more(self): + """Wait until more data is received into self._buffer.""" + self.transport.resume_reading() + self._data_received.clear() + await self._data_received.wait() # -------------------------------------------- # # Parsing # -------------------------------------------- # def data_received(self, data): - # Check for the request itself getting too large and exceeding - # memory limits - self._total_request_size += len(data) - if self._total_request_size > self.request_max_size: - self.write_error(PayloadTooLarge("Payload Too Large")) + if not data: + return self.close() + self._buffer += data + if len(self._buffer) > self.request_max_size: + self.transport.pause_reading() - # Parse request chunk or close connection - try: - self.parser.feed_data(data) - except HttpParserError: - message = "Bad Request" - if self._debug: - message += "\n" + traceback.format_exc() - self.write_error(InvalidUsage(message)) + if self._data_received: + self._data_received.set() - def on_message_begin(self): - assert self.request is None - self.headers = [] - - # requests count - self.state["requests_count"] += 1 - - def on_url(self, url): - if not self.url: - self.url = url + def check_timeouts(self): + """Runs itself once a second to enforce any expired timeouts.""" + duration = current_time() - self._time + status = self._status + if (status == Status.IDLE and duration > self.keep_alive_timeout): + logger.debug("KeepAlive Timeout. Closing connection.") + elif (status == Status.REQUEST and duration > self.request_timeout): + self.write_error(RequestTimeout("Request Timeout")) + elif (status.value > Status.REQUEST.value and duration > self.response_timeout): + self.write_error(ServiceUnavailable("Response Timeout")) else: - self.url += url + self.loop.call_later(1, self.check_timeouts) + return + self.close() - def on_header(self, name, value): - self._header_fragment += name - if value is not None: - if ( - self._header_fragment == b"Content-Length" - and int(value) > self.request_max_size - ): - self.write_error(PayloadTooLarge("Payload Too Large")) - try: - value = value.decode() - except UnicodeDecodeError: - value = value.decode("latin_1") - self.headers.append( - (self._header_fragment.decode().casefold(), value) - ) + async def http1(self): + """HTTP 1.1 connection handler""" + try: + buf = self._buffer + while self.keep_alive: + # Read request header + pos = 0 + self._time = current_time() + while len(buf) < self.request_max_size: + if buf: + self._status = Status.REQUEST + pos = buf.find(b"\r\n\r\n", pos) + if pos >= 0: + break + pos = max(0, len(buf) - 3) + await self.receive_more() + else: + raise PayloadTooLarge("Payload Too Large") - self._header_fragment = b"" + self._total_request_size = pos + 4 + reqline, *headers = buf[:pos].decode().split("\r\n") + method, self.url, protocol = reqline.split(" ") + assert protocol in ("HTTP/1.0", "HTTP/1.1") + del buf[:pos + 4] + headers = Header( + (name.lower(), value.lstrip()) + for name, value in (h.split(":", 1) for h in headers) + ) + if headers.get(EXPECT_HEADER): + self._status = Status.EXPECT + self.expect_handler() + self.state["requests_count"] += 1 + # Run handler + self.request = self.request_class( + url_bytes=self.url.encode(), + headers=headers, + version="1.1", + method=method, + transport=self.transport, + app=self.app, + ) + request_body = ( + int(headers.get("content-length", 0)) or + headers.get("transfer-encoding") == "chunked" + ) + if request_body: + self.request.stream = StreamBuffer( + self.request_buffer_queue_size, protocol=self, + ) + self.request.stream._queue.put_nowait(None) - def on_headers_complete(self): - self.request = self.request_class( - url_bytes=self.url, - headers=Header(self.headers), - version=self.parser.get_http_version(), - method=self.parser.get_method().decode(), - transport=self.transport, - app=self.app, - ) - # Remove any existing KeepAlive handler here, - # It will be recreated if required on the new request. - if self._keep_alive_timeout_handler: - self._keep_alive_timeout_handler.cancel() - self._keep_alive_timeout_handler = None - - if self.request.headers.get(EXPECT_HEADER): - self.expect_handler() - - self.request.stream = StreamBuffer( - self.request_buffer_queue_size, protocol=self, - ) - self.execute_request_handler() + self._status, self._time = Status.HANDLER, current_time() + await self.request_handler( + self.request, self.write_response, self.stream_response + ) + self._status, self._time = Status.IDLE, current_time() + except SanicException as e: + self.write_error(e) + except Exception as e: + print(repr(e)) + finally: + self.close() def expect_handler(self): """ @@ -367,42 +283,6 @@ class HttpProtocol(asyncio.Protocol): ) ) - def on_body(self, body): - # Send request body chunks - if body: - self.request.stream._queue.put_nowait(body) - self.flow_control() - - def on_message_complete(self): - # Entire request (headers and whole body) is received. - # We can cancel and remove the request timeout handler now. - if self._request_timeout_handler: - self._request_timeout_handler.cancel() - self._request_timeout_handler = None - - # Mark the end of request body - self.request.stream._queue.put_nowait(None) - self.request = None - - def execute_request_handler(self): - """ - Invoke the request handler defined by the - :func:`sanic.app.Sanic.handle_request` method - - :return: None - """ - # Invoked when request headers have been received - self._response_timeout_handler = self.loop.call_later( - self.response_timeout, self.response_timeout_callback - ) - self._last_request_time = time() - self._handler_queue.put_nowait( - self.request_handler( - self.request, self.write_response, self.stream_response - ) - ) - self.flow_control() - # -------------------------------------------- # # Responding # -------------------------------------------- # @@ -445,13 +325,11 @@ class HttpProtocol(asyncio.Protocol): """ Writes response content synchronously to the transport. """ - if self._response_timeout_handler: - self._response_timeout_handler.cancel() - self._response_timeout_handler = None try: - keep_alive = self.keep_alive + self._status, self._time = Status.RESPONSE, current_time() + self._last_response_time = self._time self.transport.write( - response.output("1.1", keep_alive, self.keep_alive_timeout) + response.output("1.1", self.keep_alive, self.keep_alive_timeout) ) self.log_response(response) except AttributeError: @@ -470,24 +348,20 @@ class HttpProtocol(asyncio.Protocol): "Connection lost before response written @ %s", self.request.ip, ) - keep_alive = False + self.keep_alive = False except Exception as e: self.bail_out( "Writing response failed, connection closed {}".format(repr(e)) ) finally: - if not keep_alive: + if not self.keep_alive: self.transport.close() self.transport = None else: - self._keep_alive_timeout_handler = self.loop.call_later( - self.keep_alive_timeout, self.keep_alive_timeout_callback - ) self._last_response_time = time() - self.cleanup() async def drain(self): - await self._not_paused.wait() + await self._can_write.wait() async def push_data(self, data): self.transport.write(data) @@ -498,14 +372,10 @@ class HttpProtocol(asyncio.Protocol): the transport to the response so the response consumer can write to the response as needed. """ - if self._response_timeout_handler: - self._response_timeout_handler.cancel() - self._response_timeout_handler = None - try: - keep_alive = self.keep_alive + self._status, self._time = Status.RESPONSE, current_time() response.protocol = self - await response.stream("1.1", keep_alive, self.keep_alive_timeout) + await response.stream("1.1", self.keep_alive, self.keep_alive_timeout) self.log_response(response) except AttributeError: logger.error( @@ -521,28 +391,15 @@ class HttpProtocol(asyncio.Protocol): "Connection lost before response written @ %s", self.request.ip, ) - keep_alive = False + self.keep_alive = False except Exception as e: self.bail_out( "Writing response failed, connection closed {}".format(repr(e)) ) - finally: - if not keep_alive: - self.transport.close() - self.transport = None - else: - self._keep_alive_timeout_handler = self.loop.call_later( - self.keep_alive_timeout, self.keep_alive_timeout_callback - ) - self._last_response_time = time() - self.cleanup() def write_error(self, exception): # An error _is_ a response. # Don't throw a response timeout, when a response _is_ given. - if self._response_timeout_handler: - self._response_timeout_handler.cancel() - self._response_timeout_handler = None response = None try: response = self.error_handler.response(self.request, exception) @@ -559,14 +416,9 @@ class HttpProtocol(asyncio.Protocol): from_error=True, ) finally: - if self.parser and ( - self.keep_alive or getattr(response, "status", 0) == 408 - ): + if self.keep_alive or getattr(response, "status") == 408: self.log_response(response) - try: - self.transport.close() - except AttributeError: - logger.debug("Connection lost before server could close it.") + self.close() def bail_out(self, message, from_error=False): """ @@ -598,21 +450,13 @@ class HttpProtocol(asyncio.Protocol): self.write_error(ServerError(message)) logger.error(message) - def cleanup(self): - """This is called when KeepAlive feature is used, - it resets the connection in order for it to be able - to handle receiving another request on the same connection.""" - self.url = None - self.headers = None - self._total_request_size = 0 - def close_if_idle(self): """Close the connection if a request is not being sent or received :return: boolean - True if closed, false if staying open """ - if not self.parser: - self.transport.close() + if self._status == Status.IDLE: + self.close() return True return False @@ -621,8 +465,12 @@ class HttpProtocol(asyncio.Protocol): Force close the connection. """ if self.transport is not None: - self.transport.close() - self.transport = None + try: + self.keep_alive = False + self._task.cancel() + self.transport.close() + finally: + self.transport = None def trigger_events(events, loop):