Async http protocol loop.

L. Kärkkäinen 2020-02-24 16:23:53 +02:00
parent f609a4850f
commit f6a0b4a497
2 changed files with 138 additions and 290 deletions


@@ -953,8 +953,8 @@ class Sanic:
handler, args, kwargs, uri, name = self.router.get(request)
# Non-streaming handlers have their body preloaded
if not self.router.is_stream_handler(request):
await request.receive_body()
#if not self.router.is_stream_handler(request):
# await request.receive_body()
# -------------------------------------------- #
# Request Middleware


@@ -1,4 +1,5 @@
import asyncio
import enum
import os
import sys
import traceback
@@ -10,6 +11,7 @@ from multiprocessing import Process
from signal import SIG_IGN, SIGINT, SIGTERM, Signals
from signal import signal as signal_func
from socket import SO_REUSEADDR, SOL_SOCKET, socket
from time import monotonic as current_time
from time import time
from httptools import HttpRequestParser # type: ignore
@@ -21,6 +23,7 @@ from sanic.exceptions import (
InvalidUsage,
PayloadTooLarge,
RequestTimeout,
SanicException,
ServerError,
ServiceUnavailable,
)
@@ -42,6 +45,13 @@ class Signal:
stopped = False
class Status(enum.Enum):
IDLE = 0 # Waiting for request
REQUEST = 1 # Request headers being received
EXPECT = 2 # Sender wants 100-continue
HANDLER = 3 # Headers done, handler running
RESPONSE = 4 # Response headers sent
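These states map one-to-one onto the timeouts that check_timeouts() enforces further down. As a reading aid, a hypothetical helper (not part of the commit) expressing that mapping:

def governing_timeout(status, protocol):
    # Hypothetical helper: which timeout governs a connection in a given
    # state (mirrors the branches of check_timeouts() below).
    if status is Status.IDLE:
        return protocol.keep_alive_timeout   # idle keep-alive connection
    if status is Status.REQUEST:
        return protocol.request_timeout      # still receiving headers
    return protocol.response_timeout         # EXPECT / HANDLER / RESPONSE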
class HttpProtocol(asyncio.Protocol):
"""
This class provides a basic HTTP implementation for the sanic framework.
@@ -56,10 +66,7 @@ class HttpProtocol(asyncio.Protocol):
"connections",
"signal",
# request params
"parser",
"request",
"url",
"headers",
# request config
"request_handler",
"request_timeout",
@@ -68,26 +75,24 @@ class HttpProtocol(asyncio.Protocol):
"request_max_size",
"request_buffer_queue_size",
"request_class",
"is_request_stream",
"router",
"error_handler",
# enable or disable access log purpose
"access_log",
# connection management
"_total_request_size",
"_request_timeout_handler",
"_response_timeout_handler",
"_keep_alive_timeout_handler",
"_last_request_time",
"_status",
"_time",
"_last_response_time",
"_is_stream_handler",
"_not_paused",
"_keep_alive",
"_header_fragment",
"keep_alive",
"state",
"url",
"_debug",
"_handler_queue",
"_handler_task",
"_buffer",
"_can_write",
"_data_received",
"_task",
)
def __init__(
@@ -113,13 +118,12 @@ class HttpProtocol(asyncio.Protocol):
debug=False,
**kwargs,
):
deprecated_loop = loop if sys.version_info < (3, 8) else None
self.loop = loop
self.app = app
self.url = None
self.transport = None
self.request = None
self.parser = HttpRequestParser(self)
self.url = None
self.headers = None
self.router = router
self.signal = signal
self.access_log = access_log
@@ -132,72 +136,16 @@ class HttpProtocol(asyncio.Protocol):
self.keep_alive_timeout = keep_alive_timeout
self.request_max_size = request_max_size
self.request_class = request_class or Request
self.is_request_stream = is_request_stream
self._is_stream_handler = False
if sys.version_info.minor >= 8:
self._not_paused = asyncio.Event()
else:
self._not_paused = asyncio.Event(loop=loop)
self._total_request_size = 0
self._request_timeout_handler = None
self._response_timeout_handler = None
self._keep_alive_timeout_handler = None
self._last_request_time = None
self._last_response_time = None
self._keep_alive = keep_alive
self._header_fragment = b""
self.keep_alive = keep_alive
self.state = state if state else {}
if "requests_count" not in self.state:
self.state["requests_count"] = 0
self._debug = debug
self._not_paused.set()
self._handler_queue = asyncio.Queue(1)
self._handler_task = self.loop.create_task(self.run_handlers())
async def run_handlers(self):
"""Runs one handler at a time, in correct order.
Otherwise we cannot control incoming requests, this keeps their
responses in order.
Handlers are inserted into the queue when headers are received. Body
may follow later and is handled by a separate queue in req.stream.
"""
while True:
try:
self.flow_control()
handler = await self._handler_queue.get()
await handler
except Exception as e:
traceback.print_exc()
self.transport.close()
raise
def flow_control(self):
"""Backpressure handling to avoid excessive buffering."""
if not self.transport:
return
if self._handler_queue.full() or (
self.request and self.request.stream.is_full()
):
self.transport.pause_reading()
else:
self.transport.resume_reading()
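The same pause/resume idea survives into the new code (receive_more() below resumes reading). Distilled into a standalone sketch, with assumed names:

def apply_backpressure(transport, overloaded: bool) -> None:
    # Pausing stops data_received() callbacks; the kernel socket buffer then
    # fills and TCP flow control slows the sender, instead of this process
    # buffering without bound.
    if overloaded:
        transport.pause_reading()
    else:
        transport.resume_reading()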
@property
def keep_alive(self):
"""
Check if the connection needs to be kept alive, based on the
`_keep_alive` attribute, :attr:`Signal.stopped` and
:func:`HttpProtocol.parser.should_keep_alive`.
:return: ``True`` if the connection should be kept alive, ``False`` otherwise
"""
return (
self._keep_alive
and not self.signal.stopped
and self.parser.should_keep_alive()
)
self._buffer = bytearray()
self._data_received = asyncio.Event(loop=deprecated_loop)
self._can_write = asyncio.Event(loop=deprecated_loop)
self._can_write.set()
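These three fields carry the whole receive path of the new design: callbacks fill a buffer, and a coroutine awaits an event to consume it. A self-contained sketch of the pattern (simplified, not the commit's code):

import asyncio

class PullProtocol(asyncio.Protocol):
    """Minimal pull-style receive loop (illustrative sketch)."""
    def __init__(self):
        self.buffer = bytearray()
        self.data_received_event = asyncio.Event()

    def data_received(self, data):
        self.buffer += data
        self.data_received_event.set()     # wake the consuming coroutine

    async def receive_more(self):
        self.data_received_event.clear()
        await self.data_received_event.wait()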
# -------------------------------------------- #
# Connection
@@ -205,152 +153,120 @@ class HttpProtocol(asyncio.Protocol):
def connection_made(self, transport):
self.connections.add(self)
self._request_timeout_handler = self.loop.call_later(
self.request_timeout, self.request_timeout_callback
)
self.transport = transport
self._last_request_time = time()
self._status, self._time = Status.IDLE, current_time()
self.check_timeouts()
self._task = self.loop.create_task(self.http1())
def connection_lost(self, exc):
self.connections.discard(self)
if self._handler_task:
self._handler_task.cancel()
if self._request_timeout_handler:
self._request_timeout_handler.cancel()
if self._response_timeout_handler:
self._response_timeout_handler.cancel()
if self._keep_alive_timeout_handler:
self._keep_alive_timeout_handler.cancel()
if self._task:
self._task.cancel()
def pause_writing(self):
self._not_paused.clear()
self._can_write.clear()
def resume_writing(self):
self._not_paused.set()
self._can_write.set()
def request_timeout_callback(self):
# See the docstring of the RequestTimeout exception for exactly
# what this timeout checks: whether the elapsed time since the
# request was initiated exceeds the configured request timeout.
time_elapsed = time() - self._last_request_time
if time_elapsed < self.request_timeout:
time_left = self.request_timeout - time_elapsed
self._request_timeout_handler = self.loop.call_later(
time_left, self.request_timeout_callback
)
else:
if self._handler_task:
self._handler_task.cancel()
self.write_error(RequestTimeout("Request Timeout"))
def response_timeout_callback(self):
# Check if elapsed time since the response was initiated exceeds
# the configured maximum response timeout value
time_elapsed = time() - self._last_request_time
if time_elapsed < self.response_timeout:
time_left = self.response_timeout - time_elapsed
self._response_timeout_handler = self.loop.call_later(
time_left, self.response_timeout_callback
)
else:
if self._handler_task:
self._handler_task.cancel()
self.write_error(ServiceUnavailable("Response Timeout"))
def keep_alive_timeout_callback(self):
"""
Check if the elapsed time since the last response exceeds the
configured keep alive timeout and, if so, close the transport
pipe and let the response writer handle the error.
:return: None
"""
time_elapsed = time() - self._last_response_time
if time_elapsed < self.keep_alive_timeout:
time_left = self.keep_alive_timeout - time_elapsed
self._keep_alive_timeout_handler = self.loop.call_later(
time_left, self.keep_alive_timeout_callback
)
else:
logger.debug("KeepAlive Timeout. Closing connection.")
self.transport.close()
self.transport = None
async def receive_more(self):
"""Wait until more data is received into self._buffer."""
self.transport.resume_reading()
self._data_received.clear()
await self._data_received.wait()
# -------------------------------------------- #
# Parsing
# -------------------------------------------- #
def data_received(self, data):
    # Check for the request itself getting too large and exceeding
    # memory limits
    self._total_request_size += len(data)
    if self._total_request_size > self.request_max_size:
        self.write_error(PayloadTooLarge("Payload Too Large"))
    # Parse request chunk or close connection
    try:
        self.parser.feed_data(data)
    except HttpParserError:
        message = "Bad Request"
        if self._debug:
            message += "\n" + traceback.format_exc()
        self.write_error(InvalidUsage(message))
    if not data:
        return self.close()
    self._buffer += data
    if len(self._buffer) > self.request_max_size:
        self.transport.pause_reading()
    if self._data_received:
        self._data_received.set()
def on_message_begin(self):
    assert self.request is None
    self.headers = []
    # requests count
    self.state["requests_count"] += 1
def on_url(self, url):
    if not self.url:
        self.url = url
    else:
        self.url += url
def check_timeouts(self):
    """Runs itself once a second to enforce any expired timeouts."""
    duration = current_time() - self._time
    status = self._status
    if status == Status.IDLE and duration > self.keep_alive_timeout:
        logger.debug("KeepAlive Timeout. Closing connection.")
    elif status == Status.REQUEST and duration > self.request_timeout:
        self.write_error(RequestTimeout("Request Timeout"))
    elif (
        status.value > Status.REQUEST.value
        and duration > self.response_timeout
    ):
        self.write_error(ServiceUnavailable("Response Timeout"))
    else:
        self.loop.call_later(1, self.check_timeouts)
        return
    self.close()
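The timer pattern used here, distilled (illustrative, assumed names): one self-rescheduling call_later chain replaces the three per-phase timeout handlers of the old callback-based code.

import asyncio

def start_watchdog(loop: asyncio.AbstractEventLoop, expired, on_expiry,
                   interval: float = 1.0) -> None:
    def tick() -> None:
        if expired():
            on_expiry()                      # e.g. error response + close
        else:
            loop.call_later(interval, tick)  # re-arm for the next second
    loop.call_later(interval, tick)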
def on_header(self, name, value):
    self._header_fragment += name
    if value is not None:
        if (
            self._header_fragment == b"Content-Length"
            and int(value) > self.request_max_size
        ):
            self.write_error(PayloadTooLarge("Payload Too Large"))
        try:
            value = value.decode()
        except UnicodeDecodeError:
            value = value.decode("latin_1")
        self.headers.append(
            (self._header_fragment.decode().casefold(), value)
        )
        self._header_fragment = b""
def on_headers_complete(self):
    self.request = self.request_class(
        url_bytes=self.url,
        headers=Header(self.headers),
        version=self.parser.get_http_version(),
        method=self.parser.get_method().decode(),
        transport=self.transport,
        app=self.app,
    )
    # Remove any existing KeepAlive handler here; it will be
    # recreated if required on the new request.
    if self._keep_alive_timeout_handler:
        self._keep_alive_timeout_handler.cancel()
        self._keep_alive_timeout_handler = None
    if self.request.headers.get(EXPECT_HEADER):
        self.expect_handler()
    self.request.stream = StreamBuffer(
        self.request_buffer_queue_size, protocol=self,
    )
    self.execute_request_handler()
async def http1(self):
    """HTTP 1.1 connection handler"""
    try:
        buf = self._buffer
        while self.keep_alive:
            # Read request header
            pos = 0
            self._time = current_time()
            while len(buf) < self.request_max_size:
                if buf:
                    self._status = Status.REQUEST
                pos = buf.find(b"\r\n\r\n", pos)
                if pos >= 0:
                    break
                pos = max(0, len(buf) - 3)
                await self.receive_more()
            else:
                raise PayloadTooLarge("Payload Too Large")
            self._total_request_size = pos + 4
            reqline, *headers = buf[:pos].decode().split("\r\n")
            method, self.url, protocol = reqline.split(" ")
            assert protocol in ("HTTP/1.0", "HTTP/1.1")
            del buf[:pos + 4]
            headers = Header(
                (name.lower(), value.lstrip())
                for name, value in (h.split(":", 1) for h in headers)
            )
            if headers.get(EXPECT_HEADER):
                self._status = Status.EXPECT
                self.expect_handler()
            self.state["requests_count"] += 1
            # Run handler
            self.request = self.request_class(
                url_bytes=self.url.encode(),
                headers=headers,
                version="1.1",
                method=method,
                transport=self.transport,
                app=self.app,
            )
            request_body = (
                int(headers.get("content-length", 0)) or
                headers.get("transfer-encoding") == "chunked"
            )
            if request_body:
                self.request.stream = StreamBuffer(
                    self.request_buffer_queue_size, protocol=self,
                )
                self.request.stream._queue.put_nowait(None)
            self._status, self._time = Status.HANDLER, current_time()
            await self.request_handler(
                self.request, self.write_response, self.stream_response
            )
            self._status, self._time = Status.IDLE, current_time()
    except SanicException as e:
        self.write_error(e)
    except Exception as e:
        print(repr(e))
    finally:
        self.close()
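The header parsing in http1() replaces httptools callbacks with plain buffer splitting. A runnable worked example of the same steps, outside the class (illustrative):

buf = bytearray(b"GET /path HTTP/1.1\r\nHost: example.com\r\nContent-Length: 0\r\n\r\n")
pos = buf.find(b"\r\n\r\n")                    # end of the request head
reqline, *headers = buf[:pos].decode().split("\r\n")
method, url, protocol = reqline.split(" ")
pairs = [
    (name.lower(), value.lstrip())
    for name, value in (h.split(":", 1) for h in headers)
]
del buf[:pos + 4]                              # consume head, keep body bytes
assert (method, url, protocol) == ("GET", "/path", "HTTP/1.1")
assert ("host", "example.com") in pairs and not buf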
def expect_handler(self):
"""
@@ -367,42 +283,6 @@ class HttpProtocol(asyncio.Protocol):
)
)
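The body of expect_handler is unchanged and elided by the hunk above. For orientation, a minimal 100-continue responder of the kind it implements might look like this (sketch, not the actual method):

def respond_continue(transport, expect_value: str) -> None:
    # Reply to "Expect: 100-continue" so the client starts sending the body.
    if expect_value.lower() == "100-continue":
        transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")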
def on_body(self, body):
# Send request body chunks
if body:
self.request.stream._queue.put_nowait(body)
self.flow_control()
def on_message_complete(self):
# Entire request (headers and whole body) is received.
# We can cancel and remove the request timeout handler now.
if self._request_timeout_handler:
self._request_timeout_handler.cancel()
self._request_timeout_handler = None
# Mark the end of request body
self.request.stream._queue.put_nowait(None)
self.request = None
def execute_request_handler(self):
"""
Invoke the request handler defined by the
:func:`sanic.app.Sanic.handle_request` method
:return: None
"""
# Invoked when request headers have been received
self._response_timeout_handler = self.loop.call_later(
self.response_timeout, self.response_timeout_callback
)
self._last_request_time = time()
self._handler_queue.put_nowait(
self.request_handler(
self.request, self.write_response, self.stream_response
)
)
self.flow_control()
# -------------------------------------------- #
# Responding
# -------------------------------------------- #
@@ -445,13 +325,11 @@ class HttpProtocol(asyncio.Protocol):
"""
Writes response content synchronously to the transport.
"""
if self._response_timeout_handler:
self._response_timeout_handler.cancel()
self._response_timeout_handler = None
try:
keep_alive = self.keep_alive
self._status, self._time = Status.RESPONSE, current_time()
self._last_response_time = self._time
self.transport.write(
response.output("1.1", keep_alive, self.keep_alive_timeout)
response.output("1.1", self.keep_alive, self.keep_alive_timeout)
)
self.log_response(response)
except AttributeError:
@@ -470,24 +348,20 @@ class HttpProtocol(asyncio.Protocol):
"Connection lost before response written @ %s",
self.request.ip,
)
keep_alive = False
self.keep_alive = False
except Exception as e:
self.bail_out(
"Writing response failed, connection closed {}".format(repr(e))
)
finally:
if not keep_alive:
if not self.keep_alive:
self.transport.close()
self.transport = None
else:
self._keep_alive_timeout_handler = self.loop.call_later(
self.keep_alive_timeout, self.keep_alive_timeout_callback
)
self._last_response_time = time()
self.cleanup()
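response.output() serializes the status line, headers and body into a single bytes blob. Roughly this shape, as an illustrative sketch (not Sanic's actual implementation):

def serialize_response(status: int, reason: str, headers: dict,
                       body: bytes, keep_alive: bool, timeout: int) -> bytes:
    # Illustrative only: status line, headers, then blank line + body.
    head = f"HTTP/1.1 {status} {reason}\r\n"
    head += "".join(f"{name}: {value}\r\n" for name, value in headers.items())
    head += f"Connection: {'keep-alive' if keep_alive else 'close'}\r\n"
    if keep_alive:
        head += f"Keep-Alive: {timeout}\r\n"
    return head.encode() + b"\r\n" + body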
async def drain(self):
await self._not_paused.wait()
await self._can_write.wait()
async def push_data(self, data):
self.transport.write(data)
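drain() is the write-side counterpart of receive_more(): a producer awaits it between chunks so that pause_writing()/resume_writing() throttle output. A usage sketch with a hypothetical producer:

async def stream_chunks(protocol, chunks) -> None:
    # Hypothetical helper: waits out write-side backpressure per chunk.
    for chunk in chunks:
        await protocol.drain()        # blocked while _can_write is cleared
        await protocol.push_data(chunk)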
@@ -498,14 +372,10 @@ class HttpProtocol(asyncio.Protocol):
the transport to the response so the response consumer can
write to the response as needed.
"""
if self._response_timeout_handler:
self._response_timeout_handler.cancel()
self._response_timeout_handler = None
try:
keep_alive = self.keep_alive
self._status, self._time = Status.RESPONSE, current_time()
response.protocol = self
await response.stream("1.1", keep_alive, self.keep_alive_timeout)
await response.stream("1.1", self.keep_alive, self.keep_alive_timeout)
self.log_response(response)
except AttributeError:
logger.error(
@@ -521,28 +391,15 @@ class HttpProtocol(asyncio.Protocol):
"Connection lost before response written @ %s",
self.request.ip,
)
keep_alive = False
self.keep_alive = False
except Exception as e:
self.bail_out(
"Writing response failed, connection closed {}".format(repr(e))
)
finally:
if not keep_alive:
self.transport.close()
self.transport = None
else:
self._keep_alive_timeout_handler = self.loop.call_later(
self.keep_alive_timeout, self.keep_alive_timeout_callback
)
self._last_response_time = time()
self.cleanup()
def write_error(self, exception):
# An error _is_ a response.
# Don't throw a response timeout when a response _is_ given.
if self._response_timeout_handler:
self._response_timeout_handler.cancel()
self._response_timeout_handler = None
response = None
try:
response = self.error_handler.response(self.request, exception)
@@ -559,14 +416,9 @@ class HttpProtocol(asyncio.Protocol):
from_error=True,
)
finally:
if self.parser and (
self.keep_alive or getattr(response, "status", 0) == 408
):
if self.keep_alive or getattr(response, "status") == 408:
self.log_response(response)
try:
self.transport.close()
except AttributeError:
logger.debug("Connection lost before server could close it.")
self.close()
def bail_out(self, message, from_error=False):
"""
@@ -598,21 +450,13 @@ class HttpProtocol(asyncio.Protocol):
self.write_error(ServerError(message))
logger.error(message)
def cleanup(self):
"""This is called when KeepAlive feature is used,
it resets the connection in order for it to be able
to handle receiving another request on the same connection."""
self.url = None
self.headers = None
self._total_request_size = 0
def close_if_idle(self):
"""Close the connection if a request is not being sent or received
:return: boolean - ``True`` if closed, ``False`` if staying open
"""
if not self.parser:
self.transport.close()
if self._status == Status.IDLE:
self.close()
return True
return False
@@ -621,8 +465,12 @@ class HttpProtocol(asyncio.Protocol):
Force close the connection.
"""
if self.transport is not None:
self.transport.close()
self.transport = None
try:
self.keep_alive = False
self._task.cancel()
self.transport.close()
finally:
self.transport = None
def trigger_events(events, loop):