* This commit adds handlers for the asyncio/uvloop protocol callbacks `pause_writing` and `resume_writing`. These are needed for the built-in TCP flow control provided by uvloop and asyncio to function correctly. This is somewhat of a breaking change, because the `write` function in user streaming callbacks must now be `await`ed. This is necessary because the HTTP protocol may now be paused, in which case a call to `write` has to wait on an async event until the protocol is resumed. Updated examples and tests to reflect this change. This change does not apply to websocket connections; a similar change to websocket connections may be required to match it. * Fix a couple of PEP8 errors caused by the previous rebase. * Update docs: add `await` syntax to `response.write` in the response-streaming docs. * Remove commented-out code from a test file.
724 lines
26 KiB
Python
724 lines
26 KiB
Python
import asyncio
|
|
import os
|
|
import traceback
|
|
from functools import partial
|
|
from inspect import isawaitable
|
|
from multiprocessing import Process
|
|
from signal import (
|
|
SIGTERM, SIGINT, SIG_IGN,
|
|
signal as signal_func,
|
|
Signals
|
|
)
|
|
from socket import (
|
|
socket,
|
|
SOL_SOCKET,
|
|
SO_REUSEADDR,
|
|
)
|
|
from time import time
|
|
|
|
from httptools import HttpRequestParser
|
|
from httptools.parser.errors import HttpParserError
|
|
from multidict import CIMultiDict
|
|
|
|
# Use uvloop's event loop policy when uvloop is installed; if the import
# fails, asyncio's default event loop policy is left untouched.
try:
    import uvloop
    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
    pass
|
|
|
|
from sanic.log import logger, access_logger
|
|
from sanic.response import HTTPResponse
|
|
from sanic.request import Request
|
|
from sanic.exceptions import (
|
|
RequestTimeout, PayloadTooLarge, InvalidUsage, ServerError,
|
|
ServiceUnavailable)
|
|
|
|
# Cached timestamp read by the protocol's timeout callbacks. It stays None
# until update_current_time() first runs, which then keeps it refreshed.
current_time = None
|
|
|
|
|
|
class Signal:
    """Shutdown flag handed to the server and to each protocol instance.

    ``stopped`` starts out ``False`` at class level; assigning
    ``instance.stopped = True`` marks that particular signal object as
    stopped without touching the class default.
    """

    # Class-level default: a freshly created signal is not stopped.
    stopped = False
|
|
|
|
|
|
class HttpProtocol(asyncio.Protocol):
|
|
__slots__ = (
|
|
# event loop, connection
|
|
'loop', 'transport', 'connections', 'signal',
|
|
# request params
|
|
'parser', 'request', 'url', 'headers',
|
|
# request config
|
|
'request_handler', 'request_timeout', 'response_timeout',
|
|
'keep_alive_timeout', 'request_max_size', 'request_class',
|
|
'is_request_stream', 'router',
|
|
# enable or disable access log purpose
|
|
'access_log',
|
|
# connection management
|
|
'_total_request_size', '_request_timeout_handler',
|
|
'_response_timeout_handler', '_keep_alive_timeout_handler',
|
|
'_last_request_time', '_last_response_time', '_is_stream_handler',
|
|
'_not_paused')
|
|
|
|
def __init__(self, *, loop, request_handler, error_handler,
|
|
signal=Signal(), connections=set(), request_timeout=60,
|
|
response_timeout=60, keep_alive_timeout=5,
|
|
request_max_size=None, request_class=None, access_log=True,
|
|
keep_alive=True, is_request_stream=False, router=None,
|
|
state=None, debug=False, **kwargs):
|
|
self.loop = loop
|
|
self.transport = None
|
|
self.request = None
|
|
self.parser = None
|
|
self.url = None
|
|
self.headers = None
|
|
self.router = router
|
|
self.signal = signal
|
|
self.access_log = access_log
|
|
self.connections = connections
|
|
self.request_handler = request_handler
|
|
self.error_handler = error_handler
|
|
self.request_timeout = request_timeout
|
|
self.response_timeout = response_timeout
|
|
self.keep_alive_timeout = keep_alive_timeout
|
|
self.request_max_size = request_max_size
|
|
self.request_class = request_class or Request
|
|
self.is_request_stream = is_request_stream
|
|
self._is_stream_handler = False
|
|
self._not_paused = asyncio.Event(loop=loop)
|
|
self._total_request_size = 0
|
|
self._request_timeout_handler = None
|
|
self._response_timeout_handler = None
|
|
self._keep_alive_timeout_handler = None
|
|
self._last_request_time = None
|
|
self._last_response_time = None
|
|
self._request_handler_task = None
|
|
self._request_stream_task = None
|
|
self._keep_alive = keep_alive
|
|
self._header_fragment = b''
|
|
self.state = state if state else {}
|
|
if 'requests_count' not in self.state:
|
|
self.state['requests_count'] = 0
|
|
self._debug = debug
|
|
self._not_paused.set()
|
|
|
|
@property
|
|
def keep_alive(self):
|
|
return (
|
|
self._keep_alive and
|
|
not self.signal.stopped and
|
|
self.parser.should_keep_alive())
|
|
|
|
# -------------------------------------------- #
|
|
# Connection
|
|
# -------------------------------------------- #
|
|
|
|
def connection_made(self, transport):
|
|
self.connections.add(self)
|
|
self._request_timeout_handler = self.loop.call_later(
|
|
self.request_timeout, self.request_timeout_callback)
|
|
self.transport = transport
|
|
self._last_request_time = current_time
|
|
|
|
def connection_lost(self, exc):
|
|
self.connections.discard(self)
|
|
if self._request_timeout_handler:
|
|
self._request_timeout_handler.cancel()
|
|
if self._response_timeout_handler:
|
|
self._response_timeout_handler.cancel()
|
|
if self._keep_alive_timeout_handler:
|
|
self._keep_alive_timeout_handler.cancel()
|
|
|
|
def pause_writing(self):
|
|
self._not_paused.clear()
|
|
|
|
def resume_writing(self):
|
|
self._not_paused.set()
|
|
|
|
def request_timeout_callback(self):
|
|
# See the docstring in the RequestTimeout exception, to see
|
|
# exactly what this timeout is checking for.
|
|
# Check if elapsed time since request initiated exceeds our
|
|
# configured maximum request timeout value
|
|
time_elapsed = current_time - self._last_request_time
|
|
if time_elapsed < self.request_timeout:
|
|
time_left = self.request_timeout - time_elapsed
|
|
self._request_timeout_handler = (
|
|
self.loop.call_later(time_left,
|
|
self.request_timeout_callback)
|
|
)
|
|
else:
|
|
if self._request_stream_task:
|
|
self._request_stream_task.cancel()
|
|
if self._request_handler_task:
|
|
self._request_handler_task.cancel()
|
|
try:
|
|
raise RequestTimeout('Request Timeout')
|
|
except RequestTimeout as exception:
|
|
self.write_error(exception)
|
|
|
|
def response_timeout_callback(self):
|
|
# Check if elapsed time since response was initiated exceeds our
|
|
# configured maximum request timeout value
|
|
time_elapsed = current_time - self._last_request_time
|
|
if time_elapsed < self.response_timeout:
|
|
time_left = self.response_timeout - time_elapsed
|
|
self._response_timeout_handler = (
|
|
self.loop.call_later(time_left,
|
|
self.response_timeout_callback)
|
|
)
|
|
else:
|
|
if self._request_stream_task:
|
|
self._request_stream_task.cancel()
|
|
if self._request_handler_task:
|
|
self._request_handler_task.cancel()
|
|
try:
|
|
raise ServiceUnavailable('Response Timeout')
|
|
except ServiceUnavailable as exception:
|
|
self.write_error(exception)
|
|
|
|
def keep_alive_timeout_callback(self):
|
|
# Check if elapsed time since last response exceeds our configured
|
|
# maximum keep alive timeout value
|
|
time_elapsed = current_time - self._last_response_time
|
|
if time_elapsed < self.keep_alive_timeout:
|
|
time_left = self.keep_alive_timeout - time_elapsed
|
|
self._keep_alive_timeout_handler = (
|
|
self.loop.call_later(time_left,
|
|
self.keep_alive_timeout_callback)
|
|
)
|
|
else:
|
|
logger.debug('KeepAlive Timeout. Closing connection.')
|
|
self.transport.close()
|
|
self.transport = None
|
|
|
|
# -------------------------------------------- #
|
|
# Parsing
|
|
# -------------------------------------------- #
|
|
|
|
def data_received(self, data):
|
|
# Check for the request itself getting too large and exceeding
|
|
# memory limits
|
|
self._total_request_size += len(data)
|
|
if self._total_request_size > self.request_max_size:
|
|
exception = PayloadTooLarge('Payload Too Large')
|
|
self.write_error(exception)
|
|
|
|
# Create parser if this is the first time we're receiving data
|
|
if self.parser is None:
|
|
assert self.request is None
|
|
self.headers = []
|
|
self.parser = HttpRequestParser(self)
|
|
|
|
# requests count
|
|
self.state['requests_count'] = self.state['requests_count'] + 1
|
|
|
|
# Parse request chunk or close connection
|
|
try:
|
|
self.parser.feed_data(data)
|
|
except HttpParserError:
|
|
message = 'Bad Request'
|
|
if self._debug:
|
|
message += '\n' + traceback.format_exc()
|
|
exception = InvalidUsage(message)
|
|
self.write_error(exception)
|
|
|
|
def on_url(self, url):
|
|
if not self.url:
|
|
self.url = url
|
|
else:
|
|
self.url += url
|
|
|
|
def on_header(self, name, value):
|
|
self._header_fragment += name
|
|
|
|
if value is not None:
|
|
if self._header_fragment == b'Content-Length' \
|
|
and int(value) > self.request_max_size:
|
|
exception = PayloadTooLarge('Payload Too Large')
|
|
self.write_error(exception)
|
|
try:
|
|
value = value.decode()
|
|
except UnicodeDecodeError:
|
|
value = value.decode('latin_1')
|
|
self.headers.append(
|
|
(self._header_fragment.decode().casefold(), value))
|
|
|
|
self._header_fragment = b''
|
|
|
|
def on_headers_complete(self):
|
|
self.request = self.request_class(
|
|
url_bytes=self.url,
|
|
headers=CIMultiDict(self.headers),
|
|
version=self.parser.get_http_version(),
|
|
method=self.parser.get_method().decode(),
|
|
transport=self.transport
|
|
)
|
|
# Remove any existing KeepAlive handler here,
|
|
# It will be recreated if required on the new request.
|
|
if self._keep_alive_timeout_handler:
|
|
self._keep_alive_timeout_handler.cancel()
|
|
self._keep_alive_timeout_handler = None
|
|
if self.is_request_stream:
|
|
self._is_stream_handler = self.router.is_stream_handler(
|
|
self.request)
|
|
if self._is_stream_handler:
|
|
self.request.stream = asyncio.Queue()
|
|
self.execute_request_handler()
|
|
|
|
def on_body(self, body):
|
|
if self.is_request_stream and self._is_stream_handler:
|
|
self._request_stream_task = self.loop.create_task(
|
|
self.request.stream.put(body))
|
|
return
|
|
self.request.body.append(body)
|
|
|
|
def on_message_complete(self):
|
|
# Entire request (headers and whole body) is received.
|
|
# We can cancel and remove the request timeout handler now.
|
|
if self._request_timeout_handler:
|
|
self._request_timeout_handler.cancel()
|
|
self._request_timeout_handler = None
|
|
if self.is_request_stream and self._is_stream_handler:
|
|
self._request_stream_task = self.loop.create_task(
|
|
self.request.stream.put(None))
|
|
return
|
|
self.request.body = b''.join(self.request.body)
|
|
self.execute_request_handler()
|
|
|
|
def execute_request_handler(self):
|
|
self._response_timeout_handler = self.loop.call_later(
|
|
self.response_timeout, self.response_timeout_callback)
|
|
self._last_request_time = current_time
|
|
self._request_handler_task = self.loop.create_task(
|
|
self.request_handler(
|
|
self.request,
|
|
self.write_response,
|
|
self.stream_response))
|
|
|
|
# -------------------------------------------- #
|
|
# Responding
|
|
# -------------------------------------------- #
|
|
def log_response(self, response):
|
|
if self.access_log:
|
|
extra = {
|
|
'status': getattr(response, 'status', 0),
|
|
}
|
|
|
|
if isinstance(response, HTTPResponse):
|
|
extra['byte'] = len(response.body)
|
|
else:
|
|
extra['byte'] = -1
|
|
|
|
extra['host'] = 'UNKNOWN'
|
|
if self.request is not None:
|
|
if self.request.ip:
|
|
extra['host'] = '{0}:{1}'.format(self.request.ip,
|
|
self.request.port)
|
|
|
|
extra['request'] = '{0} {1}'.format(self.request.method,
|
|
self.request.url)
|
|
else:
|
|
extra['request'] = 'nil'
|
|
|
|
access_logger.info('', extra=extra)
|
|
|
|
def write_response(self, response):
|
|
"""
|
|
Writes response content synchronously to the transport.
|
|
"""
|
|
if self._response_timeout_handler:
|
|
self._response_timeout_handler.cancel()
|
|
self._response_timeout_handler = None
|
|
try:
|
|
keep_alive = self.keep_alive
|
|
self.transport.write(
|
|
response.output(
|
|
self.request.version, keep_alive,
|
|
self.keep_alive_timeout))
|
|
self.log_response(response)
|
|
except AttributeError:
|
|
logger.error('Invalid response object for url %s, '
|
|
'Expected Type: HTTPResponse, Actual Type: %s',
|
|
self.url, type(response))
|
|
self.write_error(ServerError('Invalid response type'))
|
|
except RuntimeError:
|
|
if self._debug:
|
|
logger.error('Connection lost before response written @ %s',
|
|
self.request.ip)
|
|
keep_alive = False
|
|
except Exception as e:
|
|
self.bail_out(
|
|
"Writing response failed, connection closed {}".format(
|
|
repr(e)))
|
|
finally:
|
|
if not keep_alive:
|
|
self.transport.close()
|
|
self.transport = None
|
|
else:
|
|
self._keep_alive_timeout_handler = self.loop.call_later(
|
|
self.keep_alive_timeout,
|
|
self.keep_alive_timeout_callback)
|
|
self._last_response_time = current_time
|
|
self.cleanup()
|
|
|
|
async def drain(self):
|
|
await self._not_paused.wait()
|
|
|
|
def push_data(self, data):
|
|
self.transport.write(data)
|
|
|
|
async def stream_response(self, response):
|
|
"""
|
|
Streams a response to the client asynchronously. Attaches
|
|
the transport to the response so the response consumer can
|
|
write to the response as needed.
|
|
"""
|
|
if self._response_timeout_handler:
|
|
self._response_timeout_handler.cancel()
|
|
self._response_timeout_handler = None
|
|
|
|
try:
|
|
keep_alive = self.keep_alive
|
|
response.protocol = self
|
|
await response.stream(
|
|
self.request.version, keep_alive, self.keep_alive_timeout)
|
|
self.log_response(response)
|
|
except AttributeError:
|
|
logger.error('Invalid response object for url %s, '
|
|
'Expected Type: HTTPResponse, Actual Type: %s',
|
|
self.url, type(response))
|
|
self.write_error(ServerError('Invalid response type'))
|
|
except RuntimeError:
|
|
if self._debug:
|
|
logger.error('Connection lost before response written @ %s',
|
|
self.request.ip)
|
|
keep_alive = False
|
|
except Exception as e:
|
|
self.bail_out(
|
|
"Writing response failed, connection closed {}".format(
|
|
repr(e)))
|
|
finally:
|
|
if not keep_alive:
|
|
self.transport.close()
|
|
self.transport = None
|
|
else:
|
|
self._keep_alive_timeout_handler = self.loop.call_later(
|
|
self.keep_alive_timeout,
|
|
self.keep_alive_timeout_callback)
|
|
self._last_response_time = current_time
|
|
self.cleanup()
|
|
|
|
def write_error(self, exception):
|
|
# An error _is_ a response.
|
|
# Don't throw a response timeout, when a response _is_ given.
|
|
if self._response_timeout_handler:
|
|
self._response_timeout_handler.cancel()
|
|
self._response_timeout_handler = None
|
|
response = None
|
|
try:
|
|
response = self.error_handler.response(self.request, exception)
|
|
version = self.request.version if self.request else '1.1'
|
|
self.transport.write(response.output(version))
|
|
except RuntimeError:
|
|
if self._debug:
|
|
logger.error('Connection lost before error written @ %s',
|
|
self.request.ip if self.request else 'Unknown')
|
|
except Exception as e:
|
|
self.bail_out(
|
|
"Writing error failed, connection closed {}".format(
|
|
repr(e)), from_error=True
|
|
)
|
|
finally:
|
|
if self.parser and (self.keep_alive
|
|
or getattr(response, 'status', 0) == 408):
|
|
self.log_response(response)
|
|
try:
|
|
self.transport.close()
|
|
except AttributeError as e:
|
|
logger.debug('Connection lost before server could close it.')
|
|
|
|
def bail_out(self, message, from_error=False):
|
|
if from_error or self.transport.is_closing():
|
|
logger.error("Transport closed @ %s and exception "
|
|
"experienced during error handling",
|
|
self.transport.get_extra_info('peername'))
|
|
logger.debug('Exception:\n%s', traceback.format_exc())
|
|
else:
|
|
exception = ServerError(message)
|
|
self.write_error(exception)
|
|
logger.error(message)
|
|
|
|
def cleanup(self):
|
|
"""This is called when KeepAlive feature is used,
|
|
it resets the connection in order for it to be able
|
|
to handle receiving another request on the same connection."""
|
|
self.parser = None
|
|
self.request = None
|
|
self.url = None
|
|
self.headers = None
|
|
self._request_handler_task = None
|
|
self._request_stream_task = None
|
|
self._total_request_size = 0
|
|
self._is_stream_handler = False
|
|
|
|
def close_if_idle(self):
|
|
"""Close the connection if a request is not being sent or received
|
|
|
|
:return: boolean - True if closed, false if staying open
|
|
"""
|
|
if not self.parser:
|
|
self.transport.close()
|
|
return True
|
|
return False
|
|
|
|
def close(self):
|
|
"""
|
|
Force close the connection.
|
|
"""
|
|
if self.transport is not None:
|
|
self.transport.close()
|
|
self.transport = None
|
|
|
|
|
|
def update_current_time(loop):
    """Refresh the cached module-level timestamp and re-arm the refresh.

    Stores ``time()`` in the global ``current_time`` and schedules this
    function to run again on ``loop`` one second later.

    :param loop: event loop used to schedule the next refresh
    :return:
    """
    global current_time
    current_time = time()
    next_refresh = partial(update_current_time, loop)
    loop.call_later(1, next_refresh)
|
|
|
|
|
|
def trigger_events(events, loop):
    """Trigger event callbacks (functions or async)

    Each callback is called with ``loop``; when the result is awaitable it
    is run to completion on the loop before the next callback is called.

    :param events: one or more sync or async functions to execute;
                   ``None`` or an empty collection means nothing to run
    :param loop: event loop
    """
    if not events:
        # Nothing registered (e.g. the hook was left at its None default).
        return
    for event in events:
        result = event(loop)
        if isawaitable(result):
            loop.run_until_complete(result)
|
|
|
|
|
|
def serve(host, port, request_handler, error_handler, before_start=None,
          after_start=None, before_stop=None, after_stop=None, debug=False,
          request_timeout=60, response_timeout=60, keep_alive_timeout=5,
          ssl=None, sock=None, request_max_size=None, reuse_port=False,
          loop=None, protocol=HttpProtocol, backlog=100,
          register_sys_signals=True, run_multiple=False, run_async=False,
          connections=None, signal=None, request_class=None,
          access_log=True, keep_alive=True, is_request_stream=False,
          router=None, websocket_max_size=None, websocket_max_queue=None,
          websocket_read_limit=2 ** 16, websocket_write_limit=2 ** 16,
          state=None, graceful_shutdown_timeout=15.0):
    """Start asynchronous HTTP Server on an individual process.

    :param host: Address to host on
    :param port: Port to host on
    :param request_handler: Sanic request handler with middleware
    :param error_handler: Sanic error handler with middleware
    :param before_start: functions to be executed before the server starts
                         listening. Each is called with the event loop
    :param after_start: functions to be executed after the server starts
                        listening. Each is called with the event loop
    :param before_stop: functions to be executed when a stop signal is
                        received before it is respected. Each is called
                        with the event loop
    :param after_stop: functions to be executed when a stop signal is
                       received after it is respected. Each is called
                       with the event loop
    :param debug: enables debug output (slows server)
    :param request_timeout: time in seconds
    :param response_timeout: time in seconds
    :param keep_alive_timeout: time in seconds
    :param ssl: SSLContext
    :param sock: Socket for the server to accept connections from
    :param request_max_size: size in bytes, `None` for no limit
    :param reuse_port: `True` for multiple workers
    :param loop: asyncio compatible event loop; only used when
                 ``run_async`` is true, otherwise a new loop is created
    :param protocol: subclass of asyncio protocol class
    :param backlog: backlog value passed to ``loop.create_server``
    :param register_sys_signals: register ``loop.stop`` as the handler for
                                 the termination signals
    :param run_multiple: ignore SIGINT in this process and stop on
                         SIGTERM only
    :param run_async: return the server coroutine instead of running it
    :param connections: set tracking the open connections; created when
                        omitted
    :param signal: object whose ``stopped`` flag is set on shutdown; a new
                   ``Signal`` is created when omitted
    :param request_class: Request class to use
    :param access_log: disable/enable access log
    :param keep_alive: disable/enable keep-alive, passed to the protocol
    :param websocket_max_size: enforces the maximum size for
                               incoming messages in bytes.
    :param websocket_max_queue: sets the maximum length of the queue
                                that holds incoming messages.
    :param websocket_read_limit: sets the high-water limit of the buffer for
                                 incoming bytes, the low-water limit is half
                                 the high-water limit.
    :param websocket_write_limit: sets the high-water limit of the buffer for
                                  outgoing bytes, the low-water limit is a
                                  quarter of the high-water limit.
    :param is_request_stream: disable/enable Request.stream
    :param router: Router object
    :param state: dict passed to every protocol instance
    :param graceful_shutdown_timeout: approximate time in seconds to wait
                                      for open connections to finish before
                                      they are force closed
    :return: the server coroutine when ``run_async`` is true, otherwise
             nothing (the call blocks until the server has stopped)
    """
    if not run_async:
        # create new event_loop after fork
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    if debug:
        loop.set_debug(debug)

    # Fresh objects per call instead of a shared mutable default: a shared
    # Signal would stay ``stopped`` after the first server shut down.
    if signal is None:
        signal = Signal()
    connections = connections if connections is not None else set()
    server = partial(
        protocol,
        loop=loop,
        connections=connections,
        signal=signal,
        request_handler=request_handler,
        error_handler=error_handler,
        request_timeout=request_timeout,
        response_timeout=response_timeout,
        keep_alive_timeout=keep_alive_timeout,
        request_max_size=request_max_size,
        request_class=request_class,
        access_log=access_log,
        keep_alive=keep_alive,
        is_request_stream=is_request_stream,
        router=router,
        websocket_max_size=websocket_max_size,
        websocket_max_queue=websocket_max_queue,
        websocket_read_limit=websocket_read_limit,
        websocket_write_limit=websocket_write_limit,
        state=state,
        debug=debug,
    )

    server_coroutine = loop.create_server(
        server,
        host,
        port,
        ssl=ssl,
        reuse_port=reuse_port,
        sock=sock,
        backlog=backlog
    )

    # Instead of pulling time at the end of every request,
    # cache it and refresh it once per second
    loop.call_soon(partial(update_current_time, loop))

    if run_async:
        return server_coroutine

    # ``or ()`` keeps hooks left at their None default from being iterated.
    trigger_events(before_start or (), loop)

    try:
        http_server = loop.run_until_complete(server_coroutine)
    except BaseException:
        logger.exception("Unable to start server")
        return

    trigger_events(after_start or (), loop)

    # Ignore SIGINT when run_multiple
    if run_multiple:
        signal_func(SIGINT, SIG_IGN)

    # Register signals for graceful termination
    if register_sys_signals:
        _signals = (SIGTERM,) if run_multiple else (SIGINT, SIGTERM)
        for _signal in _signals:
            try:
                loop.add_signal_handler(_signal, loop.stop)
            except NotImplementedError:
                logger.warning('Sanic tried to use loop.add_signal_handler '
                               'but it is not implemented on this platform.')
    pid = os.getpid()
    try:
        logger.info('Starting worker [%s]', pid)
        loop.run_forever()
    finally:
        logger.info("Stopping worker [%s]", pid)

        # Run the on_stop function if provided
        trigger_events(before_stop or (), loop)

        # Wait for event loop to finish and all connections to drain
        http_server.close()
        loop.run_until_complete(http_server.wait_closed())

        # Complete all tasks on the loop
        signal.stopped = True
        for connection in connections:
            connection.close_if_idle()

        # Gracefully shutdown timeout.
        # We should provide graceful_shutdown_timeout,
        # instead of letting connections hang forever.
        # Let's roughly calculate time.
        start_shutdown = 0
        while connections and (start_shutdown < graceful_shutdown_timeout):
            loop.run_until_complete(asyncio.sleep(0.1))
            start_shutdown = start_shutdown + 0.1

        # Force close non-idle connection after waiting for
        # graceful_shutdown_timeout
        coros = []
        for conn in connections:
            if hasattr(conn, "websocket") and conn.websocket:
                coros.append(
                    conn.websocket.close_connection()
                )
            else:
                conn.close()

        if coros:
            # gather() is awaited from inside a coroutine so that it picks
            # up the running loop; its explicit ``loop`` argument is not
            # accepted by newer Python versions.
            async def _close_websockets():
                await asyncio.gather(*coros)

            loop.run_until_complete(_close_websockets())

        trigger_events(after_stop or (), loop)

        loop.close()
|
|
|
|
|
|
def serve_multiple(server_settings, workers):
    """Start multiple server processes simultaneously. Stop on interrupt
    and terminate signals, and drain connections when complete.

    ``server_settings`` is modified in place: ``reuse_port`` and
    ``run_multiple`` are forced on and, when no socket was supplied, a
    listening socket is created, stored under ``sock`` and ``host`` /
    ``port`` are reset to ``None``.

    :param server_settings: kw arguments to be passed to the serve function
    :param workers: number of workers to launch
    :return:
    """
    server_settings['reuse_port'] = True
    server_settings['run_multiple'] = True

    # Handling when custom socket is not provided.
    if server_settings.get('sock') is None:
        sock = socket()
        sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
        sock.bind((server_settings['host'], server_settings['port']))
        sock.set_inheritable(True)
        server_settings['sock'] = sock
        server_settings['host'] = None
        server_settings['port'] = None

    # Bound before the handlers are installed so that a signal arriving
    # early finds an (empty) list instead of an unbound name.
    processes = []

    def sig_handler(signum, frame):
        logger.info("Received signal %s. Shutting down.",
                    Signals(signum).name)
        for process in processes:
            os.kill(process.pid, SIGTERM)

    signal_func(SIGINT, sig_handler)
    signal_func(SIGTERM, sig_handler)

    for _ in range(workers):
        process = Process(target=serve, kwargs=server_settings)
        process.daemon = True
        process.start()
        processes.append(process)

    for process in processes:
        process.join()

    # the above processes will block this until they're stopped
    for process in processes:
        process.terminate()
    server_settings.get('sock').close()
|