710024125e
* Remove server config args that can be read directly from app. * Linter Co-authored-by: L. Kärkkäinen <tronic@users.noreply.github.com>
945 lines
31 KiB
Python
945 lines
31 KiB
Python
import asyncio
|
|
import multiprocessing
|
|
import os
|
|
import sys
|
|
import traceback
|
|
|
|
from collections import deque
|
|
from functools import partial
|
|
from inspect import isawaitable
|
|
from signal import SIG_IGN, SIGINT, SIGTERM, Signals
|
|
from signal import signal as signal_func
|
|
from socket import SO_REUSEADDR, SOL_SOCKET, socket
|
|
from time import time
|
|
|
|
from httptools import HttpRequestParser # type: ignore
|
|
from httptools.parser.errors import HttpParserError # type: ignore
|
|
|
|
from sanic.compat import Header, ctrlc_workaround_for_windows
|
|
from sanic.exceptions import (
|
|
HeaderExpectationFailed,
|
|
InvalidUsage,
|
|
PayloadTooLarge,
|
|
RequestTimeout,
|
|
ServerError,
|
|
ServiceUnavailable,
|
|
)
|
|
from sanic.log import access_logger, logger
|
|
from sanic.request import EXPECT_HEADER, Request, StreamBuffer
|
|
from sanic.response import HTTPResponse
|
|
|
|
|
|
try:
|
|
import uvloop # type: ignore
|
|
|
|
if not isinstance(asyncio.get_event_loop_policy(), uvloop.EventLoopPolicy):
|
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
|
except ImportError:
|
|
pass
|
|
|
|
OS_IS_WINDOWS = os.name == "nt"
|
|
|
|
|
|
class Signal:
|
|
stopped = False
|
|
|
|
|
|
class HttpProtocol(asyncio.Protocol):
|
|
"""
|
|
This class provides a basic HTTP implementation of the sanic framework.
|
|
"""
|
|
|
|
__slots__ = (
|
|
# app
|
|
"app",
|
|
# event loop, connection
|
|
"loop",
|
|
"transport",
|
|
"connections",
|
|
"signal",
|
|
# request params
|
|
"parser",
|
|
"request",
|
|
"url",
|
|
"headers",
|
|
# request config
|
|
"request_handler",
|
|
"request_timeout",
|
|
"response_timeout",
|
|
"keep_alive_timeout",
|
|
"request_max_size",
|
|
"request_buffer_queue_size",
|
|
"request_class",
|
|
"is_request_stream",
|
|
"error_handler",
|
|
# enable or disable access log purpose
|
|
"access_log",
|
|
# connection management
|
|
"_total_request_size",
|
|
"_request_timeout_handler",
|
|
"_response_timeout_handler",
|
|
"_keep_alive_timeout_handler",
|
|
"_last_request_time",
|
|
"_last_response_time",
|
|
"_is_stream_handler",
|
|
"_not_paused",
|
|
"_request_handler_task",
|
|
"_request_stream_task",
|
|
"_keep_alive",
|
|
"_header_fragment",
|
|
"state",
|
|
"_body_chunks",
|
|
)
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
loop,
|
|
app,
|
|
signal=Signal(),
|
|
connections=None,
|
|
state=None,
|
|
**kwargs,
|
|
):
|
|
asyncio.set_event_loop(loop)
|
|
self.loop = loop
|
|
deprecated_loop = self.loop if sys.version_info < (3, 7) else None
|
|
self.app = app
|
|
self.transport = None
|
|
self.request = None
|
|
self.parser = None
|
|
self.url = None
|
|
self.headers = None
|
|
self.signal = signal
|
|
self.access_log = self.app.config.ACCESS_LOG
|
|
self.connections = connections if connections is not None else set()
|
|
self.request_handler = self.app.handle_request
|
|
self.error_handler = self.app.error_handler
|
|
self.request_timeout = self.app.config.REQUEST_TIMEOUT
|
|
self.request_buffer_queue_size = (
|
|
self.app.config.REQUEST_BUFFER_QUEUE_SIZE
|
|
)
|
|
self.response_timeout = self.app.config.RESPONSE_TIMEOUT
|
|
self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT
|
|
self.request_max_size = self.app.config.REQUEST_MAX_SIZE
|
|
self.request_class = self.app.request_class or Request
|
|
self.is_request_stream = self.app.is_request_stream
|
|
self._is_stream_handler = False
|
|
self._not_paused = asyncio.Event(loop=deprecated_loop)
|
|
self._total_request_size = 0
|
|
self._request_timeout_handler = None
|
|
self._response_timeout_handler = None
|
|
self._keep_alive_timeout_handler = None
|
|
self._last_request_time = None
|
|
self._last_response_time = None
|
|
self._request_handler_task = None
|
|
self._request_stream_task = None
|
|
self._keep_alive = self.app.config.KEEP_ALIVE
|
|
self._header_fragment = b""
|
|
self.state = state if state else {}
|
|
if "requests_count" not in self.state:
|
|
self.state["requests_count"] = 0
|
|
self._not_paused.set()
|
|
self._body_chunks = deque()
|
|
|
|
@property
|
|
def keep_alive(self):
|
|
"""
|
|
Check if the connection needs to be kept alive based on the params
|
|
attached to the `_keep_alive` attribute, :attr:`Signal.stopped`
|
|
and :func:`HttpProtocol.parser.should_keep_alive`
|
|
|
|
:return: ``True`` if connection is to be kept alive ``False`` else
|
|
"""
|
|
return (
|
|
self._keep_alive
|
|
and not self.signal.stopped
|
|
and self.parser.should_keep_alive()
|
|
)
|
|
|
|
# -------------------------------------------- #
|
|
# Connection
|
|
# -------------------------------------------- #
|
|
|
|
def connection_made(self, transport):
|
|
self.connections.add(self)
|
|
self._request_timeout_handler = self.loop.call_later(
|
|
self.request_timeout, self.request_timeout_callback
|
|
)
|
|
self.transport = transport
|
|
self._last_request_time = time()
|
|
|
|
def connection_lost(self, exc):
|
|
self.connections.discard(self)
|
|
if self._request_handler_task:
|
|
self._request_handler_task.cancel()
|
|
if self._request_stream_task:
|
|
self._request_stream_task.cancel()
|
|
if self._request_timeout_handler:
|
|
self._request_timeout_handler.cancel()
|
|
if self._response_timeout_handler:
|
|
self._response_timeout_handler.cancel()
|
|
if self._keep_alive_timeout_handler:
|
|
self._keep_alive_timeout_handler.cancel()
|
|
|
|
def pause_writing(self):
|
|
self._not_paused.clear()
|
|
|
|
def resume_writing(self):
|
|
self._not_paused.set()
|
|
|
|
def request_timeout_callback(self):
|
|
# See the docstring in the RequestTimeout exception, to see
|
|
# exactly what this timeout is checking for.
|
|
# Check if elapsed time since request initiated exceeds our
|
|
# configured maximum request timeout value
|
|
time_elapsed = time() - self._last_request_time
|
|
if time_elapsed < self.request_timeout:
|
|
time_left = self.request_timeout - time_elapsed
|
|
self._request_timeout_handler = self.loop.call_later(
|
|
time_left, self.request_timeout_callback
|
|
)
|
|
else:
|
|
if self._request_stream_task:
|
|
self._request_stream_task.cancel()
|
|
if self._request_handler_task:
|
|
self._request_handler_task.cancel()
|
|
self.write_error(RequestTimeout("Request Timeout"))
|
|
|
|
def response_timeout_callback(self):
|
|
# Check if elapsed time since response was initiated exceeds our
|
|
# configured maximum request timeout value
|
|
time_elapsed = time() - self._last_request_time
|
|
if time_elapsed < self.response_timeout:
|
|
time_left = self.response_timeout - time_elapsed
|
|
self._response_timeout_handler = self.loop.call_later(
|
|
time_left, self.response_timeout_callback
|
|
)
|
|
else:
|
|
if self._request_stream_task:
|
|
self._request_stream_task.cancel()
|
|
if self._request_handler_task:
|
|
self._request_handler_task.cancel()
|
|
self.write_error(ServiceUnavailable("Response Timeout"))
|
|
|
|
def keep_alive_timeout_callback(self):
|
|
"""
|
|
Check if elapsed time since last response exceeds our configured
|
|
maximum keep alive timeout value and if so, close the transport
|
|
pipe and let the response writer handle the error.
|
|
|
|
:return: None
|
|
"""
|
|
time_elapsed = time() - self._last_response_time
|
|
if time_elapsed < self.keep_alive_timeout:
|
|
time_left = self.keep_alive_timeout - time_elapsed
|
|
self._keep_alive_timeout_handler = self.loop.call_later(
|
|
time_left, self.keep_alive_timeout_callback
|
|
)
|
|
else:
|
|
logger.debug("KeepAlive Timeout. Closing connection.")
|
|
self.transport.close()
|
|
self.transport = None
|
|
|
|
# -------------------------------------------- #
|
|
# Parsing
|
|
# -------------------------------------------- #
|
|
|
|
def data_received(self, data):
|
|
# Check for the request itself getting too large and exceeding
|
|
# memory limits
|
|
self._total_request_size += len(data)
|
|
if self._total_request_size > self.request_max_size:
|
|
self.write_error(PayloadTooLarge("Payload Too Large"))
|
|
|
|
# Create parser if this is the first time we're receiving data
|
|
if self.parser is None:
|
|
assert self.request is None
|
|
self.headers = []
|
|
self.parser = HttpRequestParser(self)
|
|
|
|
# requests count
|
|
self.state["requests_count"] = self.state["requests_count"] + 1
|
|
|
|
# Parse request chunk or close connection
|
|
try:
|
|
self.parser.feed_data(data)
|
|
except HttpParserError:
|
|
message = "Bad Request"
|
|
if self.app.debug:
|
|
message += "\n" + traceback.format_exc()
|
|
self.write_error(InvalidUsage(message))
|
|
|
|
def on_url(self, url):
|
|
if not self.url:
|
|
self.url = url
|
|
else:
|
|
self.url += url
|
|
|
|
def on_header(self, name, value):
|
|
self._header_fragment += name
|
|
|
|
if value is not None:
|
|
if (
|
|
self._header_fragment == b"Content-Length"
|
|
and int(value) > self.request_max_size
|
|
):
|
|
self.write_error(PayloadTooLarge("Payload Too Large"))
|
|
try:
|
|
value = value.decode()
|
|
except UnicodeDecodeError:
|
|
value = value.decode("latin_1")
|
|
self.headers.append(
|
|
(self._header_fragment.decode().casefold(), value)
|
|
)
|
|
|
|
self._header_fragment = b""
|
|
|
|
def on_headers_complete(self):
|
|
self.request = self.request_class(
|
|
url_bytes=self.url,
|
|
headers=Header(self.headers),
|
|
version=self.parser.get_http_version(),
|
|
method=self.parser.get_method().decode(),
|
|
transport=self.transport,
|
|
app=self.app,
|
|
)
|
|
# Remove any existing KeepAlive handler here,
|
|
# It will be recreated if required on the new request.
|
|
if self._keep_alive_timeout_handler:
|
|
self._keep_alive_timeout_handler.cancel()
|
|
self._keep_alive_timeout_handler = None
|
|
|
|
if self.request.headers.get(EXPECT_HEADER):
|
|
self.expect_handler()
|
|
|
|
if self.is_request_stream:
|
|
self._is_stream_handler = self.app.router.is_stream_handler(
|
|
self.request
|
|
)
|
|
if self._is_stream_handler:
|
|
self.request.stream = StreamBuffer(
|
|
self.request_buffer_queue_size
|
|
)
|
|
self.execute_request_handler()
|
|
|
|
def expect_handler(self):
|
|
"""
|
|
Handler for Expect Header.
|
|
"""
|
|
expect = self.request.headers.get(EXPECT_HEADER)
|
|
if self.request.version == "1.1":
|
|
if expect.lower() == "100-continue":
|
|
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
|
|
else:
|
|
self.write_error(
|
|
HeaderExpectationFailed(f"Unknown Expect: {expect}")
|
|
)
|
|
|
|
def on_body(self, body):
|
|
if self.is_request_stream and self._is_stream_handler:
|
|
# body chunks can be put into asyncio.Queue out of order if
|
|
# multiple tasks put concurrently and the queue is full in python
|
|
# 3.7. so we should not create more than one task putting into the
|
|
# queue simultaneously.
|
|
self._body_chunks.append(body)
|
|
if (
|
|
not self._request_stream_task
|
|
or self._request_stream_task.done()
|
|
):
|
|
self._request_stream_task = self.loop.create_task(
|
|
self.stream_append()
|
|
)
|
|
else:
|
|
self.request.body_push(body)
|
|
|
|
async def body_append(self, body):
|
|
if (
|
|
self.request is None
|
|
or self._request_stream_task is None
|
|
or self._request_stream_task.cancelled()
|
|
):
|
|
return
|
|
|
|
if self.request.stream.is_full():
|
|
self.transport.pause_reading()
|
|
await self.request.stream.put(body)
|
|
self.transport.resume_reading()
|
|
else:
|
|
await self.request.stream.put(body)
|
|
|
|
async def stream_append(self):
|
|
while self._body_chunks:
|
|
body = self._body_chunks.popleft()
|
|
if self.request.stream.is_full():
|
|
self.transport.pause_reading()
|
|
await self.request.stream.put(body)
|
|
self.transport.resume_reading()
|
|
else:
|
|
await self.request.stream.put(body)
|
|
|
|
def on_message_complete(self):
|
|
# Entire request (headers and whole body) is received.
|
|
# We can cancel and remove the request timeout handler now.
|
|
if self._request_timeout_handler:
|
|
self._request_timeout_handler.cancel()
|
|
self._request_timeout_handler = None
|
|
if self.is_request_stream and self._is_stream_handler:
|
|
self._body_chunks.append(None)
|
|
if (
|
|
not self._request_stream_task
|
|
or self._request_stream_task.done()
|
|
):
|
|
self._request_stream_task = self.loop.create_task(
|
|
self.stream_append()
|
|
)
|
|
return
|
|
self.request.body_finish()
|
|
self.execute_request_handler()
|
|
|
|
def execute_request_handler(self):
|
|
"""
|
|
Invoke the request handler defined by the
|
|
:func:`sanic.app.Sanic.handle_request` method
|
|
|
|
:return: None
|
|
"""
|
|
self._response_timeout_handler = self.loop.call_later(
|
|
self.response_timeout, self.response_timeout_callback
|
|
)
|
|
self._last_request_time = time()
|
|
self._request_handler_task = self.loop.create_task(
|
|
self.request_handler(
|
|
self.request, self.write_response, self.stream_response
|
|
)
|
|
)
|
|
|
|
# -------------------------------------------- #
|
|
# Responding
|
|
# -------------------------------------------- #
|
|
def log_response(self, response):
|
|
"""
|
|
Helper method provided to enable the logging of responses in case if
|
|
the :attr:`HttpProtocol.access_log` is enabled.
|
|
|
|
:param response: Response generated for the current request
|
|
|
|
:type response: :class:`sanic.response.HTTPResponse` or
|
|
:class:`sanic.response.StreamingHTTPResponse`
|
|
|
|
:return: None
|
|
"""
|
|
if self.access_log:
|
|
extra = {"status": getattr(response, "status", 0)}
|
|
|
|
if isinstance(response, HTTPResponse):
|
|
extra["byte"] = len(response.body)
|
|
else:
|
|
extra["byte"] = -1
|
|
|
|
extra["host"] = "UNKNOWN"
|
|
if self.request is not None:
|
|
if self.request.ip:
|
|
extra["host"] = f"{self.request.ip}:{self.request.port}"
|
|
|
|
extra["request"] = f"{self.request.method} {self.request.url}"
|
|
else:
|
|
extra["request"] = "nil"
|
|
|
|
access_logger.info("", extra=extra)
|
|
|
|
def write_response(self, response):
|
|
"""
|
|
Writes response content synchronously to the transport.
|
|
"""
|
|
if self._response_timeout_handler:
|
|
self._response_timeout_handler.cancel()
|
|
self._response_timeout_handler = None
|
|
try:
|
|
keep_alive = self.keep_alive
|
|
self.transport.write(
|
|
response.output(
|
|
self.request.version, keep_alive, self.keep_alive_timeout
|
|
)
|
|
)
|
|
self.log_response(response)
|
|
except AttributeError:
|
|
logger.error(
|
|
"Invalid response object for url %s, "
|
|
"Expected Type: HTTPResponse, Actual Type: %s",
|
|
self.url,
|
|
type(response),
|
|
)
|
|
self.write_error(ServerError("Invalid response type"))
|
|
except RuntimeError:
|
|
if self.app.debug:
|
|
logger.error(
|
|
"Connection lost before response written @ %s",
|
|
self.request.ip,
|
|
)
|
|
keep_alive = False
|
|
except Exception as e:
|
|
self.bail_out(f"Writing response failed, connection closed {e!r}")
|
|
finally:
|
|
if not keep_alive:
|
|
self.transport.close()
|
|
self.transport = None
|
|
else:
|
|
self._keep_alive_timeout_handler = self.loop.call_later(
|
|
self.keep_alive_timeout, self.keep_alive_timeout_callback
|
|
)
|
|
self._last_response_time = time()
|
|
self.cleanup()
|
|
|
|
async def drain(self):
|
|
await self._not_paused.wait()
|
|
|
|
async def push_data(self, data):
|
|
self.transport.write(data)
|
|
|
|
async def stream_response(self, response):
|
|
"""
|
|
Streams a response to the client asynchronously. Attaches
|
|
the transport to the response so the response consumer can
|
|
write to the response as needed.
|
|
"""
|
|
if self._response_timeout_handler:
|
|
self._response_timeout_handler.cancel()
|
|
self._response_timeout_handler = None
|
|
|
|
try:
|
|
keep_alive = self.keep_alive
|
|
response.protocol = self
|
|
await response.stream(
|
|
self.request.version, keep_alive, self.keep_alive_timeout
|
|
)
|
|
self.log_response(response)
|
|
except AttributeError:
|
|
logger.error(
|
|
"Invalid response object for url %s, "
|
|
"Expected Type: HTTPResponse, Actual Type: %s",
|
|
self.url,
|
|
type(response),
|
|
)
|
|
self.write_error(ServerError("Invalid response type"))
|
|
except RuntimeError:
|
|
if self.app.debug:
|
|
logger.error(
|
|
"Connection lost before response written @ %s",
|
|
self.request.ip,
|
|
)
|
|
keep_alive = False
|
|
except Exception as e:
|
|
self.bail_out(f"Writing response failed, connection closed {e!r}")
|
|
finally:
|
|
if not keep_alive:
|
|
self.transport.close()
|
|
self.transport = None
|
|
else:
|
|
self._keep_alive_timeout_handler = self.loop.call_later(
|
|
self.keep_alive_timeout, self.keep_alive_timeout_callback
|
|
)
|
|
self._last_response_time = time()
|
|
self.cleanup()
|
|
|
|
def write_error(self, exception):
|
|
# An error _is_ a response.
|
|
# Don't throw a response timeout, when a response _is_ given.
|
|
if self._response_timeout_handler:
|
|
self._response_timeout_handler.cancel()
|
|
self._response_timeout_handler = None
|
|
response = None
|
|
try:
|
|
response = self.error_handler.response(self.request, exception)
|
|
version = self.request.version if self.request else "1.1"
|
|
self.transport.write(response.output(version))
|
|
except RuntimeError:
|
|
if self.app.debug:
|
|
logger.error(
|
|
"Connection lost before error written @ %s",
|
|
self.request.ip if self.request else "Unknown",
|
|
)
|
|
except Exception as e:
|
|
self.bail_out(
|
|
f"Writing error failed, connection closed {e!r}",
|
|
from_error=True,
|
|
)
|
|
finally:
|
|
if self.parser and (
|
|
self.keep_alive or getattr(response, "status", 0) == 408
|
|
):
|
|
self.log_response(response)
|
|
try:
|
|
self.transport.close()
|
|
except AttributeError:
|
|
logger.debug("Connection lost before server could close it.")
|
|
|
|
def bail_out(self, message, from_error=False):
|
|
"""
|
|
In case if the transport pipes are closed and the sanic app encounters
|
|
an error while writing data to the transport pipe, we log the error
|
|
with proper details.
|
|
|
|
:param message: Error message to display
|
|
:param from_error: If the bail out was invoked while handling an
|
|
exception scenario.
|
|
|
|
:type message: str
|
|
:type from_error: bool
|
|
|
|
:return: None
|
|
"""
|
|
if from_error or self.transport is None or self.transport.is_closing():
|
|
logger.error(
|
|
"Transport closed @ %s and exception "
|
|
"experienced during error handling",
|
|
(
|
|
self.transport.get_extra_info("peername")
|
|
if self.transport is not None
|
|
else "N/A"
|
|
),
|
|
)
|
|
logger.debug("Exception:", exc_info=True)
|
|
else:
|
|
self.write_error(ServerError(message))
|
|
logger.error(message)
|
|
|
|
def cleanup(self):
|
|
"""This is called when KeepAlive feature is used,
|
|
it resets the connection in order for it to be able
|
|
to handle receiving another request on the same connection."""
|
|
self.parser = None
|
|
self.request = None
|
|
self.url = None
|
|
self.headers = None
|
|
self._request_handler_task = None
|
|
self._request_stream_task = None
|
|
self._total_request_size = 0
|
|
self._is_stream_handler = False
|
|
|
|
def close_if_idle(self):
|
|
"""Close the connection if a request is not being sent or received
|
|
|
|
:return: boolean - True if closed, false if staying open
|
|
"""
|
|
if not self.parser and self.transport is not None:
|
|
self.transport.close()
|
|
return True
|
|
return False
|
|
|
|
def close(self):
|
|
"""
|
|
Force close the connection.
|
|
"""
|
|
if self.transport is not None:
|
|
self.transport.close()
|
|
self.transport = None
|
|
|
|
|
|
def trigger_events(events, loop):
|
|
"""Trigger event callbacks (functions or async)
|
|
|
|
:param events: one or more sync or async functions to execute
|
|
:param loop: event loop
|
|
"""
|
|
for event in events:
|
|
result = event(loop)
|
|
if isawaitable(result):
|
|
loop.run_until_complete(result)
|
|
|
|
|
|
class AsyncioServer:
|
|
"""
|
|
Wraps an asyncio server with functionality that might be useful to
|
|
a user who needs to manage the server lifecycle manually.
|
|
"""
|
|
|
|
__slots__ = (
|
|
"loop",
|
|
"serve_coro",
|
|
"_after_start",
|
|
"_before_stop",
|
|
"_after_stop",
|
|
"server",
|
|
"connections",
|
|
)
|
|
|
|
def __init__(
|
|
self,
|
|
loop,
|
|
serve_coro,
|
|
connections,
|
|
after_start,
|
|
before_stop,
|
|
after_stop,
|
|
):
|
|
# Note, Sanic already called "before_server_start" events
|
|
# before this helper was even created. So we don't need it here.
|
|
self.loop = loop
|
|
self.serve_coro = serve_coro
|
|
self._after_start = after_start
|
|
self._before_stop = before_stop
|
|
self._after_stop = after_stop
|
|
self.server = None
|
|
self.connections = connections
|
|
|
|
def after_start(self):
|
|
"""Trigger "after_server_start" events"""
|
|
trigger_events(self._after_start, self.loop)
|
|
|
|
def before_stop(self):
|
|
"""Trigger "before_server_stop" events"""
|
|
trigger_events(self._before_stop, self.loop)
|
|
|
|
def after_stop(self):
|
|
"""Trigger "after_server_stop" events"""
|
|
trigger_events(self._after_stop, self.loop)
|
|
|
|
def is_serving(self):
|
|
if self.server:
|
|
return self.server.is_serving()
|
|
return False
|
|
|
|
def wait_closed(self):
|
|
if self.server:
|
|
return self.server.wait_closed()
|
|
|
|
def close(self):
|
|
if self.server:
|
|
self.server.close()
|
|
coro = self.wait_closed()
|
|
task = asyncio.ensure_future(coro, loop=self.loop)
|
|
return task
|
|
|
|
def start_serving(self):
|
|
if self.server:
|
|
try:
|
|
return self.server.start_serving()
|
|
except AttributeError:
|
|
raise NotImplementedError(
|
|
"server.start_serving not available in this version "
|
|
"of asyncio or uvloop."
|
|
)
|
|
|
|
def serve_forever(self):
|
|
if self.server:
|
|
try:
|
|
return self.server.serve_forever()
|
|
except AttributeError:
|
|
raise NotImplementedError(
|
|
"server.serve_forever not available in this version "
|
|
"of asyncio or uvloop."
|
|
)
|
|
|
|
def __await__(self):
|
|
"""Starts the asyncio server, returns AsyncServerCoro"""
|
|
task = asyncio.ensure_future(self.serve_coro)
|
|
while not task.done():
|
|
yield
|
|
self.server = task.result()
|
|
return self
|
|
|
|
|
|
def serve(
|
|
host,
|
|
port,
|
|
app,
|
|
before_start=None,
|
|
after_start=None,
|
|
before_stop=None,
|
|
after_stop=None,
|
|
ssl=None,
|
|
sock=None,
|
|
reuse_port=False,
|
|
loop=None,
|
|
protocol=HttpProtocol,
|
|
backlog=100,
|
|
register_sys_signals=True,
|
|
run_multiple=False,
|
|
run_async=False,
|
|
connections=None,
|
|
signal=Signal(),
|
|
state=None,
|
|
asyncio_server_kwargs=None,
|
|
):
|
|
"""Start asynchronous HTTP Server on an individual process.
|
|
|
|
:param host: Address to host on
|
|
:param port: Port to host on
|
|
:param before_start: function to be executed before the server starts
|
|
listening. Takes arguments `app` instance and `loop`
|
|
:param after_start: function to be executed after the server starts
|
|
listening. Takes arguments `app` instance and `loop`
|
|
:param before_stop: function to be executed when a stop signal is
|
|
received before it is respected. Takes arguments
|
|
`app` instance and `loop`
|
|
:param after_stop: function to be executed when a stop signal is
|
|
received after it is respected. Takes arguments
|
|
`app` instance and `loop`
|
|
:param ssl: SSLContext
|
|
:param sock: Socket for the server to accept connections from
|
|
:param reuse_port: `True` for multiple workers
|
|
:param loop: asyncio compatible event loop
|
|
:param run_async: bool: Do not create a new event loop for the server,
|
|
and return an AsyncServer object rather than running it
|
|
:param asyncio_server_kwargs: key-value args for asyncio/uvloop
|
|
create_server method
|
|
:return: Nothing
|
|
"""
|
|
if not run_async:
|
|
# create new event_loop after fork
|
|
loop = asyncio.new_event_loop()
|
|
asyncio.set_event_loop(loop)
|
|
|
|
if app.debug:
|
|
loop.set_debug(app.debug)
|
|
|
|
app.asgi = False
|
|
|
|
connections = connections if connections is not None else set()
|
|
server = partial(
|
|
protocol,
|
|
loop=loop,
|
|
connections=connections,
|
|
signal=signal,
|
|
app=app,
|
|
state=state,
|
|
)
|
|
asyncio_server_kwargs = (
|
|
asyncio_server_kwargs if asyncio_server_kwargs else {}
|
|
)
|
|
server_coroutine = loop.create_server(
|
|
server,
|
|
host,
|
|
port,
|
|
ssl=ssl,
|
|
reuse_port=reuse_port,
|
|
sock=sock,
|
|
backlog=backlog,
|
|
**asyncio_server_kwargs,
|
|
)
|
|
|
|
if run_async:
|
|
return AsyncioServer(
|
|
loop=loop,
|
|
serve_coro=server_coroutine,
|
|
connections=connections,
|
|
after_start=after_start,
|
|
before_stop=before_stop,
|
|
after_stop=after_stop,
|
|
)
|
|
|
|
trigger_events(before_start, loop)
|
|
|
|
try:
|
|
http_server = loop.run_until_complete(server_coroutine)
|
|
except BaseException:
|
|
logger.exception("Unable to start server")
|
|
return
|
|
|
|
trigger_events(after_start, loop)
|
|
|
|
# Ignore SIGINT when run_multiple
|
|
if run_multiple:
|
|
signal_func(SIGINT, SIG_IGN)
|
|
|
|
# Register signals for graceful termination
|
|
if register_sys_signals:
|
|
if OS_IS_WINDOWS:
|
|
ctrlc_workaround_for_windows(app)
|
|
else:
|
|
for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]:
|
|
loop.add_signal_handler(_signal, app.stop)
|
|
pid = os.getpid()
|
|
try:
|
|
logger.info("Starting worker [%s]", pid)
|
|
loop.run_forever()
|
|
finally:
|
|
logger.info("Stopping worker [%s]", pid)
|
|
|
|
# Run the on_stop function if provided
|
|
trigger_events(before_stop, loop)
|
|
|
|
# Wait for event loop to finish and all connections to drain
|
|
http_server.close()
|
|
loop.run_until_complete(http_server.wait_closed())
|
|
|
|
# Complete all tasks on the loop
|
|
signal.stopped = True
|
|
for connection in connections:
|
|
connection.close_if_idle()
|
|
|
|
# Gracefully shutdown timeout.
|
|
# We should provide graceful_shutdown_timeout,
|
|
# instead of letting connection hangs forever.
|
|
# Let's roughly calcucate time.
|
|
graceful = app.config.GRACEFUL_SHUTDOWN_TIMEOUT
|
|
start_shutdown = 0
|
|
while connections and (start_shutdown < graceful):
|
|
loop.run_until_complete(asyncio.sleep(0.1))
|
|
start_shutdown = start_shutdown + 0.1
|
|
|
|
# Force close non-idle connection after waiting for
|
|
# graceful_shutdown_timeout
|
|
coros = []
|
|
for conn in connections:
|
|
if hasattr(conn, "websocket") and conn.websocket:
|
|
coros.append(conn.websocket.close_connection())
|
|
else:
|
|
conn.close()
|
|
|
|
_shutdown = asyncio.gather(*coros)
|
|
loop.run_until_complete(_shutdown)
|
|
|
|
trigger_events(after_stop, loop)
|
|
|
|
loop.close()
|
|
|
|
|
|
def serve_multiple(server_settings, workers):
|
|
"""Start multiple server processes simultaneously. Stop on interrupt
|
|
and terminate signals, and drain connections when complete.
|
|
|
|
:param server_settings: kw arguments to be passed to the serve function
|
|
:param workers: number of workers to launch
|
|
:param stop_event: if provided, is used as a stop signal
|
|
:return:
|
|
"""
|
|
server_settings["reuse_port"] = True
|
|
server_settings["run_multiple"] = True
|
|
|
|
# Handling when custom socket is not provided.
|
|
if server_settings.get("sock") is None:
|
|
sock = socket()
|
|
sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
|
|
sock.bind((server_settings["host"], server_settings["port"]))
|
|
sock.set_inheritable(True)
|
|
server_settings["sock"] = sock
|
|
server_settings["host"] = None
|
|
server_settings["port"] = None
|
|
|
|
processes = []
|
|
|
|
def sig_handler(signal, frame):
|
|
logger.info("Received signal %s. Shutting down.", Signals(signal).name)
|
|
for process in processes:
|
|
os.kill(process.pid, SIGTERM)
|
|
|
|
signal_func(SIGINT, lambda s, f: sig_handler(s, f))
|
|
signal_func(SIGTERM, lambda s, f: sig_handler(s, f))
|
|
mp = multiprocessing.get_context("fork")
|
|
|
|
for _ in range(workers):
|
|
process = mp.Process(target=serve, kwargs=server_settings)
|
|
process.daemon = True
|
|
process.start()
|
|
processes.append(process)
|
|
|
|
for process in processes:
|
|
process.join()
|
|
|
|
# the above processes will block this until they're stopped
|
|
for process in processes:
|
|
process.terminate()
|
|
server_settings.get("sock").close()
|