Async http protocol loop.

This commit is contained in:
L. Kärkkäinen 2020-02-24 16:23:53 +02:00
parent f609a4850f
commit f6a0b4a497
2 changed files with 138 additions and 290 deletions

View File

@ -953,8 +953,8 @@ class Sanic:
handler, args, kwargs, uri, name = self.router.get(request) handler, args, kwargs, uri, name = self.router.get(request)
# Non-streaming handlers have their body preloaded # Non-streaming handlers have their body preloaded
if not self.router.is_stream_handler(request): #if not self.router.is_stream_handler(request):
await request.receive_body() # await request.receive_body()
# -------------------------------------------- # # -------------------------------------------- #
# Request Middleware # Request Middleware

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import enum
import os import os
import sys import sys
import traceback import traceback
@ -10,6 +11,7 @@ from multiprocessing import Process
from signal import SIG_IGN, SIGINT, SIGTERM, Signals from signal import SIG_IGN, SIGINT, SIGTERM, Signals
from signal import signal as signal_func from signal import signal as signal_func
from socket import SO_REUSEADDR, SOL_SOCKET, socket from socket import SO_REUSEADDR, SOL_SOCKET, socket
from time import monotonic as current_time
from time import time from time import time
from httptools import HttpRequestParser # type: ignore from httptools import HttpRequestParser # type: ignore
@ -21,6 +23,7 @@ from sanic.exceptions import (
InvalidUsage, InvalidUsage,
PayloadTooLarge, PayloadTooLarge,
RequestTimeout, RequestTimeout,
SanicException,
ServerError, ServerError,
ServiceUnavailable, ServiceUnavailable,
) )
@ -42,6 +45,13 @@ class Signal:
stopped = False stopped = False
class Status(enum.Enum):
IDLE = 0 # Waiting for request
REQUEST = 1 # Request headers being received
EXPECT = 2 # Sender wants 100-continue
HANDLER = 3 # Headers done, handler running
RESPONSE = 4 # Response headers sent
class HttpProtocol(asyncio.Protocol): class HttpProtocol(asyncio.Protocol):
""" """
This class provides a basic HTTP implementation of the sanic framework. This class provides a basic HTTP implementation of the sanic framework.
@ -56,10 +66,7 @@ class HttpProtocol(asyncio.Protocol):
"connections", "connections",
"signal", "signal",
# request params # request params
"parser",
"request", "request",
"url",
"headers",
# request config # request config
"request_handler", "request_handler",
"request_timeout", "request_timeout",
@ -68,26 +75,24 @@ class HttpProtocol(asyncio.Protocol):
"request_max_size", "request_max_size",
"request_buffer_queue_size", "request_buffer_queue_size",
"request_class", "request_class",
"is_request_stream",
"router", "router",
"error_handler", "error_handler",
# enable or disable access log purpose # enable or disable access log purpose
"access_log", "access_log",
# connection management # connection management
"_total_request_size", "_total_request_size",
"_request_timeout_handler", "_status",
"_response_timeout_handler", "_time",
"_keep_alive_timeout_handler",
"_last_request_time",
"_last_response_time", "_last_response_time",
"_is_stream_handler", "keep_alive",
"_not_paused",
"_keep_alive",
"_header_fragment",
"state", "state",
"url",
"_debug", "_debug",
"_handler_queue",
"_handler_task", "_handler_task",
"_buffer",
"_can_write",
"_data_received",
"_task",
) )
def __init__( def __init__(
@ -113,13 +118,12 @@ class HttpProtocol(asyncio.Protocol):
debug=False, debug=False,
**kwargs, **kwargs,
): ):
deprecated_loop = loop if sys.version_info < (3, 8) else None
self.loop = loop self.loop = loop
self.app = app self.app = app
self.url = None
self.transport = None self.transport = None
self.request = None self.request = None
self.parser = HttpRequestParser(self)
self.url = None
self.headers = None
self.router = router self.router = router
self.signal = signal self.signal = signal
self.access_log = access_log self.access_log = access_log
@ -132,72 +136,16 @@ class HttpProtocol(asyncio.Protocol):
self.keep_alive_timeout = keep_alive_timeout self.keep_alive_timeout = keep_alive_timeout
self.request_max_size = request_max_size self.request_max_size = request_max_size
self.request_class = request_class or Request self.request_class = request_class or Request
self.is_request_stream = is_request_stream
self._is_stream_handler = False
if sys.version_info.minor >= 8:
self._not_paused = asyncio.Event()
else:
self._not_paused = asyncio.Event(loop=loop)
self._total_request_size = 0 self._total_request_size = 0
self._request_timeout_handler = None self.keep_alive = keep_alive
self._response_timeout_handler = None
self._keep_alive_timeout_handler = None
self._last_request_time = None
self._last_response_time = None
self._keep_alive = keep_alive
self._header_fragment = b""
self.state = state if state else {} self.state = state if state else {}
if "requests_count" not in self.state: if "requests_count" not in self.state:
self.state["requests_count"] = 0 self.state["requests_count"] = 0
self._debug = debug self._debug = debug
self._not_paused.set() self._buffer = bytearray()
self._handler_queue = asyncio.Queue(1) self._data_received = asyncio.Event(loop=deprecated_loop)
self._handler_task = self.loop.create_task(self.run_handlers()) self._can_write = asyncio.Event(loop=deprecated_loop)
self._can_write.set()
async def run_handlers(self):
"""Runs one handler at a time, in correct order.
Otherwise we cannot control incoming requests, this keeps their
responses in order.
Handlers are inserted into the queue when headers are received. Body
may follow later and is handled by a separate queue in req.stream.
"""
while True:
try:
self.flow_control()
handler = await self._handler_queue.get()
await handler
except Exception as e:
traceback.print_exc()
self.transport.close()
raise
def flow_control(self):
"""Backpressure handling to avoid excessive buffering."""
if not self.transport:
return
if self._handler_queue.full() or (
self.request and self.request.stream.is_full()
):
self.transport.pause_reading()
else:
self.transport.resume_reading()
@property
def keep_alive(self):
"""
Check if the connection needs to be kept alive based on the params
attached to the `_keep_alive` attribute, :attr:`Signal.stopped`
and :func:`HttpProtocol.parser.should_keep_alive`
:return: ``True`` if connection is to be kept alive ``False`` else
"""
return (
self._keep_alive
and not self.signal.stopped
and self.parser.should_keep_alive()
)
# -------------------------------------------- # # -------------------------------------------- #
# Connection # Connection
@ -205,152 +153,120 @@ class HttpProtocol(asyncio.Protocol):
def connection_made(self, transport): def connection_made(self, transport):
self.connections.add(self) self.connections.add(self)
self._request_timeout_handler = self.loop.call_later(
self.request_timeout, self.request_timeout_callback
)
self.transport = transport self.transport = transport
self._last_request_time = time() self._status, self._time = Status.IDLE, current_time()
self.check_timeouts()
self._task = self.loop.create_task(self.http1())
def connection_lost(self, exc): def connection_lost(self, exc):
self.connections.discard(self) self.connections.discard(self)
if self._handler_task: if self._task:
self._handler_task.cancel() self._task.cancel()
if self._request_timeout_handler:
self._request_timeout_handler.cancel()
if self._response_timeout_handler:
self._response_timeout_handler.cancel()
if self._keep_alive_timeout_handler:
self._keep_alive_timeout_handler.cancel()
def pause_writing(self): def pause_writing(self):
self._not_paused.clear() self._can_write.clear()
def resume_writing(self): def resume_writing(self):
self._not_paused.set() self._can_write.set()
def request_timeout_callback(self): async def receive_more(self):
# See the docstring in the RequestTimeout exception, to see """Wait until more data is received into self._buffer."""
# exactly what this timeout is checking for. self.transport.resume_reading()
# Check if elapsed time since request initiated exceeds our self._data_received.clear()
# configured maximum request timeout value await self._data_received.wait()
time_elapsed = time() - self._last_request_time
if time_elapsed < self.request_timeout:
time_left = self.request_timeout - time_elapsed
self._request_timeout_handler = self.loop.call_later(
time_left, self.request_timeout_callback
)
else:
if self._handler_task:
self._handler_task.cancel()
self.write_error(RequestTimeout("Request Timeout"))
def response_timeout_callback(self):
# Check if elapsed time since response was initiated exceeds our
# configured maximum request timeout value
time_elapsed = time() - self._last_request_time
if time_elapsed < self.response_timeout:
time_left = self.response_timeout - time_elapsed
self._response_timeout_handler = self.loop.call_later(
time_left, self.response_timeout_callback
)
else:
if self._handler_task:
self._handler_task.cancel()
self.write_error(ServiceUnavailable("Response Timeout"))
def keep_alive_timeout_callback(self):
"""
Check if elapsed time since last response exceeds our configured
maximum keep alive timeout value and if so, close the transport
pipe and let the response writer handle the error.
:return: None
"""
time_elapsed = time() - self._last_response_time
if time_elapsed < self.keep_alive_timeout:
time_left = self.keep_alive_timeout - time_elapsed
self._keep_alive_timeout_handler = self.loop.call_later(
time_left, self.keep_alive_timeout_callback
)
else:
logger.debug("KeepAlive Timeout. Closing connection.")
self.transport.close()
self.transport = None
# -------------------------------------------- # # -------------------------------------------- #
# Parsing # Parsing
# -------------------------------------------- # # -------------------------------------------- #
def data_received(self, data): def data_received(self, data):
# Check for the request itself getting too large and exceeding if not data:
# memory limits return self.close()
self._total_request_size += len(data) self._buffer += data
if self._total_request_size > self.request_max_size: if len(self._buffer) > self.request_max_size:
self.write_error(PayloadTooLarge("Payload Too Large")) self.transport.pause_reading()
# Parse request chunk or close connection if self._data_received:
try: self._data_received.set()
self.parser.feed_data(data)
except HttpParserError:
message = "Bad Request"
if self._debug:
message += "\n" + traceback.format_exc()
self.write_error(InvalidUsage(message))
def on_message_begin(self): def check_timeouts(self):
assert self.request is None """Runs itself once a second to enforce any expired timeouts."""
self.headers = [] duration = current_time() - self._time
status = self._status
# requests count if (status == Status.IDLE and duration > self.keep_alive_timeout):
self.state["requests_count"] += 1 logger.debug("KeepAlive Timeout. Closing connection.")
elif (status == Status.REQUEST and duration > self.request_timeout):
def on_url(self, url): self.write_error(RequestTimeout("Request Timeout"))
if not self.url: elif (status.value > Status.REQUEST.value and duration > self.response_timeout):
self.url = url self.write_error(ServiceUnavailable("Response Timeout"))
else: else:
self.url += url self.loop.call_later(1, self.check_timeouts)
return
self.close()
def on_header(self, name, value):
self._header_fragment += name
if value is not None: async def http1(self):
if ( """HTTP 1.1 connection handler"""
self._header_fragment == b"Content-Length" try:
and int(value) > self.request_max_size buf = self._buffer
): while self.keep_alive:
self.write_error(PayloadTooLarge("Payload Too Large")) # Read request header
try: pos = 0
value = value.decode() self._time = current_time()
except UnicodeDecodeError: while len(buf) < self.request_max_size:
value = value.decode("latin_1") if buf:
self.headers.append( self._status = Status.REQUEST
(self._header_fragment.decode().casefold(), value) pos = buf.find(b"\r\n\r\n", pos)
) if pos >= 0:
break
pos = max(0, len(buf) - 3)
await self.receive_more()
else:
raise PayloadTooLarge("Payload Too Large")
self._header_fragment = b"" self._total_request_size = pos + 4
reqline, *headers = buf[:pos].decode().split("\r\n")
method, self.url, protocol = reqline.split(" ")
assert protocol in ("HTTP/1.0", "HTTP/1.1")
del buf[:pos + 4]
headers = Header(
(name.lower(), value.lstrip())
for name, value in (h.split(":", 1) for h in headers)
)
if headers.get(EXPECT_HEADER):
self._status = Status.EXPECT
self.expect_handler()
self.state["requests_count"] += 1
# Run handler
self.request = self.request_class(
url_bytes=self.url.encode(),
headers=headers,
version="1.1",
method=method,
transport=self.transport,
app=self.app,
)
request_body = (
int(headers.get("content-length", 0)) or
headers.get("transfer-encoding") == "chunked"
)
if request_body:
self.request.stream = StreamBuffer(
self.request_buffer_queue_size, protocol=self,
)
self.request.stream._queue.put_nowait(None)
def on_headers_complete(self): self._status, self._time = Status.HANDLER, current_time()
self.request = self.request_class( await self.request_handler(
url_bytes=self.url, self.request, self.write_response, self.stream_response
headers=Header(self.headers), )
version=self.parser.get_http_version(), self._status, self._time = Status.IDLE, current_time()
method=self.parser.get_method().decode(), except SanicException as e:
transport=self.transport, self.write_error(e)
app=self.app, except Exception as e:
) print(repr(e))
# Remove any existing KeepAlive handler here, finally:
# It will be recreated if required on the new request. self.close()
if self._keep_alive_timeout_handler:
self._keep_alive_timeout_handler.cancel()
self._keep_alive_timeout_handler = None
if self.request.headers.get(EXPECT_HEADER):
self.expect_handler()
self.request.stream = StreamBuffer(
self.request_buffer_queue_size, protocol=self,
)
self.execute_request_handler()
def expect_handler(self): def expect_handler(self):
""" """
@ -367,42 +283,6 @@ class HttpProtocol(asyncio.Protocol):
) )
) )
def on_body(self, body):
# Send request body chunks
if body:
self.request.stream._queue.put_nowait(body)
self.flow_control()
def on_message_complete(self):
# Entire request (headers and whole body) is received.
# We can cancel and remove the request timeout handler now.
if self._request_timeout_handler:
self._request_timeout_handler.cancel()
self._request_timeout_handler = None
# Mark the end of request body
self.request.stream._queue.put_nowait(None)
self.request = None
def execute_request_handler(self):
"""
Invoke the request handler defined by the
:func:`sanic.app.Sanic.handle_request` method
:return: None
"""
# Invoked when request headers have been received
self._response_timeout_handler = self.loop.call_later(
self.response_timeout, self.response_timeout_callback
)
self._last_request_time = time()
self._handler_queue.put_nowait(
self.request_handler(
self.request, self.write_response, self.stream_response
)
)
self.flow_control()
# -------------------------------------------- # # -------------------------------------------- #
# Responding # Responding
# -------------------------------------------- # # -------------------------------------------- #
@ -445,13 +325,11 @@ class HttpProtocol(asyncio.Protocol):
""" """
Writes response content synchronously to the transport. Writes response content synchronously to the transport.
""" """
if self._response_timeout_handler:
self._response_timeout_handler.cancel()
self._response_timeout_handler = None
try: try:
keep_alive = self.keep_alive self._status, self._time = Status.RESPONSE, current_time()
self._last_response_time = self._time
self.transport.write( self.transport.write(
response.output("1.1", keep_alive, self.keep_alive_timeout) response.output("1.1", self.keep_alive, self.keep_alive_timeout)
) )
self.log_response(response) self.log_response(response)
except AttributeError: except AttributeError:
@ -470,24 +348,20 @@ class HttpProtocol(asyncio.Protocol):
"Connection lost before response written @ %s", "Connection lost before response written @ %s",
self.request.ip, self.request.ip,
) )
keep_alive = False self.keep_alive = False
except Exception as e: except Exception as e:
self.bail_out( self.bail_out(
"Writing response failed, connection closed {}".format(repr(e)) "Writing response failed, connection closed {}".format(repr(e))
) )
finally: finally:
if not keep_alive: if not self.keep_alive:
self.transport.close() self.transport.close()
self.transport = None self.transport = None
else: else:
self._keep_alive_timeout_handler = self.loop.call_later(
self.keep_alive_timeout, self.keep_alive_timeout_callback
)
self._last_response_time = time() self._last_response_time = time()
self.cleanup()
async def drain(self): async def drain(self):
await self._not_paused.wait() await self._can_write.wait()
async def push_data(self, data): async def push_data(self, data):
self.transport.write(data) self.transport.write(data)
@ -498,14 +372,10 @@ class HttpProtocol(asyncio.Protocol):
the transport to the response so the response consumer can the transport to the response so the response consumer can
write to the response as needed. write to the response as needed.
""" """
if self._response_timeout_handler:
self._response_timeout_handler.cancel()
self._response_timeout_handler = None
try: try:
keep_alive = self.keep_alive self._status, self._time = Status.RESPONSE, current_time()
response.protocol = self response.protocol = self
await response.stream("1.1", keep_alive, self.keep_alive_timeout) await response.stream("1.1", self.keep_alive, self.keep_alive_timeout)
self.log_response(response) self.log_response(response)
except AttributeError: except AttributeError:
logger.error( logger.error(
@ -521,28 +391,15 @@ class HttpProtocol(asyncio.Protocol):
"Connection lost before response written @ %s", "Connection lost before response written @ %s",
self.request.ip, self.request.ip,
) )
keep_alive = False self.keep_alive = False
except Exception as e: except Exception as e:
self.bail_out( self.bail_out(
"Writing response failed, connection closed {}".format(repr(e)) "Writing response failed, connection closed {}".format(repr(e))
) )
finally:
if not keep_alive:
self.transport.close()
self.transport = None
else:
self._keep_alive_timeout_handler = self.loop.call_later(
self.keep_alive_timeout, self.keep_alive_timeout_callback
)
self._last_response_time = time()
self.cleanup()
def write_error(self, exception): def write_error(self, exception):
# An error _is_ a response. # An error _is_ a response.
# Don't throw a response timeout, when a response _is_ given. # Don't throw a response timeout, when a response _is_ given.
if self._response_timeout_handler:
self._response_timeout_handler.cancel()
self._response_timeout_handler = None
response = None response = None
try: try:
response = self.error_handler.response(self.request, exception) response = self.error_handler.response(self.request, exception)
@ -559,14 +416,9 @@ class HttpProtocol(asyncio.Protocol):
from_error=True, from_error=True,
) )
finally: finally:
if self.parser and ( if self.keep_alive or getattr(response, "status") == 408:
self.keep_alive or getattr(response, "status", 0) == 408
):
self.log_response(response) self.log_response(response)
try: self.close()
self.transport.close()
except AttributeError:
logger.debug("Connection lost before server could close it.")
def bail_out(self, message, from_error=False): def bail_out(self, message, from_error=False):
""" """
@ -598,21 +450,13 @@ class HttpProtocol(asyncio.Protocol):
self.write_error(ServerError(message)) self.write_error(ServerError(message))
logger.error(message) logger.error(message)
def cleanup(self):
"""This is called when KeepAlive feature is used,
it resets the connection in order for it to be able
to handle receiving another request on the same connection."""
self.url = None
self.headers = None
self._total_request_size = 0
def close_if_idle(self): def close_if_idle(self):
"""Close the connection if a request is not being sent or received """Close the connection if a request is not being sent or received
:return: boolean - True if closed, false if staying open :return: boolean - True if closed, false if staying open
""" """
if not self.parser: if self._status == Status.IDLE:
self.transport.close() self.close()
return True return True
return False return False
@ -621,8 +465,12 @@ class HttpProtocol(asyncio.Protocol):
Force close the connection. Force close the connection.
""" """
if self.transport is not None: if self.transport is not None:
self.transport.close() try:
self.transport = None self.keep_alive = False
self._task.cancel()
self.transport.close()
finally:
self.transport = None
def trigger_events(events, loop): def trigger_events(events, loop):