sanic/sanic/server.py

518 lines
18 KiB
Python
Raw Normal View History

2016-10-15 20:59:00 +01:00
import asyncio
import os
import traceback
from functools import partial
2016-10-15 20:59:00 +01:00
from inspect import isawaitable
2017-02-27 00:31:39 +00:00
from multiprocessing import Process
from signal import (
SIGTERM, SIGINT,
signal as signal_func,
Signals
)
from socket import (
socket,
SOL_SOCKET,
SO_REUSEADDR,
)
2016-10-28 11:13:03 +01:00
from time import time
from httptools import HttpRequestParser
from httptools.parser.errors import HttpParserError
2016-10-15 20:59:00 +01:00
try:
import uvloop as async_loop
2016-10-16 14:01:59 +01:00
except ImportError:
2016-10-15 20:59:00 +01:00
async_loop = asyncio
from sanic.log import log, netlog
from sanic.response import HTTPResponse
2017-02-16 02:54:00 +00:00
from sanic.request import Request
from sanic.exceptions import (
RequestTimeout, PayloadTooLarge, InvalidUsage, ServerError)
2016-10-15 20:59:00 +01:00
2017-01-17 00:12:42 +00:00
# Cached wall-clock time from time.time(); None until update_current_time()
# first runs, after which that function refreshes it every second.
current_time = None
2016-10-15 20:59:00 +01:00
class Signal:
    """Shutdown flag shared between ``serve`` and its connections.

    ``serve`` sets ``stopped`` once it begins shutting down, and
    ``HttpProtocol.keep_alive`` reports False while it is set, so responses
    stop keeping connections open.
    """

    stopped = False  # class-level default; instances override it on shutdown
class CIDict(dict):
    """Case Insensitive dict where all keys are converted to lowercase

    This does not maintain the inputted case when calling items() or keys()
    in favor of speed, since headers are case insensitive.

    Keys are casefolded on every path that stores or looks one up -
    construction, item assignment, ``update``, ``setdefault``, ``get``,
    ``pop`` and ``in`` - so all keys must be strings.
    """

    def __init__(self, *args, **kwargs):
        # dict.__init__ does not go through __setitem__, so feed the initial
        # items through update() to casefold their keys too.
        super().__init__()
        self.update(*args, **kwargs)

    def update(self, *args, **kwargs):
        """Insert items like ``dict.update``, casefolding every key.

        When two keys differ only by case, the later one wins, matching the
        last-wins rule of a plain dict.
        """
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

    def get(self, key, default=None):
        return super().get(key.casefold(), default)

    def setdefault(self, key, default=None):
        return super().setdefault(key.casefold(), default)

    def pop(self, key, *args):
        # *args forwards the optional default exactly as dict.pop takes it.
        return super().pop(key.casefold(), *args)

    def __getitem__(self, key):
        return super().__getitem__(key.casefold())

    def __setitem__(self, key, value):
        return super().__setitem__(key.casefold(), value)

    def __contains__(self, key):
        return super().__contains__(key.casefold())
class HttpProtocol(asyncio.Protocol):
    """asyncio protocol that parses HTTP requests and dispatches them.

    Incoming bytes are fed to an ``httptools`` ``HttpRequestParser``.  Once a
    message is complete, ``request_handler`` is scheduled as a task with the
    request and the two response writers (``write_response`` and
    ``stream_response``).  On keep-alive connections the same instance
    serves consecutive requests; ``cleanup`` resets the per-request state
    between them.
    """

    # Every attribute assigned on instances must be declared here: since
    # Python 3.7 ``asyncio.BaseProtocol`` defines ``__slots__ = ()``, so
    # instances have no ``__dict__`` and assigning an undeclared attribute
    # raises AttributeError.
    __slots__ = (
        # event loop, connection
        'loop', 'transport', 'connections', 'signal',
        # request params
        'parser', 'request', 'url', 'headers',
        # request config
        'request_handler', 'error_handler', 'request_timeout',
        'request_max_size', 'request_class',
        # enable or disable access log / error log purpose
        'has_log',
        # connection management
        '_total_request_size', '_timeout_handler', '_last_request_time',
        '_request_handler_task', '_keep_alive')

    def __init__(self, *, loop, request_handler, error_handler,
                 signal=None, connections=None, request_timeout=60,
                 request_max_size=None, request_class=None, has_log=True,
                 keep_alive=True):
        """
        :param loop: event loop used for timers and request handler tasks
        :param request_handler: callable invoked as
            ``request_handler(request, write_response, stream_response)``;
            its return value is scheduled with ``loop.create_task``
        :param error_handler: object whose ``response(request, exception)``
            builds the error response
        :param signal: shared ``Signal``; a new one is created when omitted
        :param connections: shared set of live protocols; a new one is
            created when omitted
        :param request_timeout: seconds without request activity before the
            connection is answered with ``RequestTimeout``
        :param request_max_size: maximum request size in bytes, or ``None``
            for no limit
        :param request_class: class used to build requests (``Request`` when
            omitted)
        :param has_log: enable the access log
        :param keep_alive: allow keep-alive connections
        """
        self.loop = loop
        self.transport = None
        self.request = None
        self.parser = None
        self.url = None
        self.headers = None
        # Fresh objects per instance instead of shared mutable defaults.
        self.signal = signal if signal is not None else Signal()
        self.has_log = has_log
        self.connections = connections if connections is not None else set()
        self.request_handler = request_handler
        self.error_handler = error_handler
        self.request_timeout = request_timeout
        self.request_max_size = request_max_size
        self.request_class = request_class or Request
        self._total_request_size = 0
        self._timeout_handler = None
        self._last_request_time = None
        self._request_handler_task = None
        self._keep_alive = keep_alive

    @property
    def keep_alive(self):
        """True when the connection should stay open after this response:
        keep-alive is enabled, the server is not stopping, and the parser
        reports that the client wants keep-alive."""
        return (self._keep_alive
                and not self.signal.stopped
                and self.parser.should_keep_alive())

    # -------------------------------------------- #
    # Connection
    # -------------------------------------------- #

    def connection_made(self, transport):
        """Register the connection and arm the request timeout timer."""
        self.connections.add(self)
        self._timeout_handler = self.loop.call_later(
            self.request_timeout, self.connection_timeout)
        self.transport = transport
        self._last_request_time = current_time

    def connection_lost(self, exc):
        """Unregister the connection and stop its timeout timer."""
        self.connections.discard(self)
        self._timeout_handler.cancel()

    def connection_timeout(self):
        """Timer callback armed by ``connection_made``.

        If fewer than ``request_timeout`` seconds have passed since the last
        request activity, re-arm the timer for the remaining time.
        Otherwise, cancel any running handler and answer with a
        ``RequestTimeout`` error; ``write_error`` closes the connection.
        """
        time_elapsed = current_time - self._last_request_time
        if time_elapsed < self.request_timeout:
            time_left = self.request_timeout - time_elapsed
            self._timeout_handler = (
                self.loop.call_later(time_left, self.connection_timeout))
        else:
            if self._request_handler_task:
                self._request_handler_task.cancel()
            exception = RequestTimeout('Request Timeout')
            self.write_error(exception)

    # -------------------------------------------- #
    # Parsing
    # -------------------------------------------- #

    def _exceeds_max_size(self, size):
        """True when ``size`` is over ``request_max_size`` (``None`` means
        unlimited)."""
        return (self.request_max_size is not None
                and size > self.request_max_size)

    def data_received(self, data):
        """Feed a chunk of raw bytes to the parser, enforcing the size limit.
        """
        # Check for the request itself getting too large and exceeding
        # memory limits
        self._total_request_size += len(data)
        if self._exceeds_max_size(self._total_request_size):
            exception = PayloadTooLarge('Payload Too Large')
            self.write_error(exception)
            # write_error has closed the transport; stop parsing this data.
            return

        # Create parser if this is the first time we're receiving data
        if self.parser is None:
            assert self.request is None
            self.headers = []
            self.parser = HttpRequestParser(self)

        # Parse request chunk or close connection
        try:
            self.parser.feed_data(data)
        except HttpParserError:
            exception = InvalidUsage('Bad Request')
            self.write_error(exception)

    def on_url(self, url):
        """Parser callback: store the raw request URL bytes."""
        self.url = url

    def on_header(self, name, value):
        """Parser callback: collect one header with a casefolded name.

        A declared Content-Length over ``request_max_size`` is rejected
        early.  Header names are compared case-insensitively because
        clients may send any case.
        """
        if (name.lower() == b'content-length'
                and self.request_max_size is not None
                and int(value) > self.request_max_size):
            exception = PayloadTooLarge('Payload Too Large')
            self.write_error(exception)

        self.headers.append((name.decode().casefold(), value.decode()))

    def on_headers_complete(self):
        """Parser callback: build the request object from the parsed head."""
        self.request = self.request_class(
            url_bytes=self.url,
            headers=CIDict(self.headers),
            version=self.parser.get_http_version(),
            method=self.parser.get_method().decode(),
            transport=self.transport
        )

    def on_body(self, body):
        """Parser callback: buffer a body chunk (joined on completion)."""
        self.request.body.append(body)

    def on_message_complete(self):
        """Parser callback: join the body and schedule the request handler.
        """
        self.request.body = b''.join(self.request.body)
        self._request_handler_task = self.loop.create_task(
            self.request_handler(
                self.request,
                self.write_response,
                self.stream_response))

    # -------------------------------------------- #
    # Responding
    # -------------------------------------------- #

    def _peer_address(self):
        """Return the current request's client address as ``'host:port'``.

        Returns ``''`` when there is no request or the address cannot be
        formatted.  Only the first two items of ``request.ip`` are used,
        which presumably also covers longer peernames such as IPv6
        4-tuples.
        """
        if self.request is None:
            return ''
        ip = self.request.ip
        try:
            return '%s:%d' % (ip[0], ip[1])
        except (TypeError, IndexError):
            return ''

    def _log_access(self, status, byte_count, request_line=None):
        """Emit one access-log record on ``netlog``.

        :param status: response status to log
        :param byte_count: body length in bytes, or -1 when unknown
        :param request_line: text for the ``request`` field; defaults to
            ``'<method> <url>'`` of the current request
        """
        if request_line is None:
            request_line = '%s %s' % (self.request.method, self.request.url)
        netlog.info('', extra={
            'status': status,
            'byte': byte_count,
            'host': self._peer_address(),
            'request': request_line,
        })

    def write_response(self, response):
        """
        Writes response content synchronously to the transport.

        The connection is closed afterwards unless it is keep-alive, in
        which case the request state is reset for the next request.
        """
        # Close the connection if we fail before keep-alive is decided;
        # this also keeps the name bound for the ``finally`` clause.
        keep_alive = False
        try:
            keep_alive = self.keep_alive
            self.transport.write(
                response.output(
                    self.request.version, keep_alive,
                    self.request_timeout))
            if self.has_log:
                self._log_access(response.status, len(response.body))
        except AttributeError:
            log.error(
                ('Invalid response object for url {}, '
                 'Expected Type: HTTPResponse, Actual Type: {}').format(
                    self.url, type(response)))
            self.write_error(ServerError('Invalid response type'))
        except RuntimeError:
            log.error(
                'Connection lost before response written @ {}'.format(
                    self.request.ip))
        except Exception as e:
            self.bail_out(
                "Writing response failed, connection closed {}".format(
                    repr(e)))
        finally:
            if not keep_alive:
                self.transport.close()
            else:
                self._last_request_time = current_time
                self.cleanup()

    async def stream_response(self, response):
        """
        Streams a response to the client asynchronously. Attaches
        the transport to the response so the response consumer can
        write to the response as needed.
        """
        # See write_response: default to closing if we fail early.
        keep_alive = False
        try:
            keep_alive = self.keep_alive
            response.transport = self.transport
            await response.stream(
                self.request.version, keep_alive, self.request_timeout)
            if self.has_log:
                # The streamed body length is not known here.
                self._log_access(response.status, -1)
        except AttributeError:
            log.error(
                ('Invalid response object for url {}, '
                 'Expected Type: HTTPResponse, Actual Type: {}').format(
                    self.url, type(response)))
            self.write_error(ServerError('Invalid response type'))
        except RuntimeError:
            log.error(
                'Connection lost before response written @ {}'.format(
                    self.request.ip))
        except Exception as e:
            self.bail_out(
                "Writing response failed, connection closed {}".format(
                    repr(e)))
        finally:
            if not keep_alive:
                self.transport.close()
            else:
                self._last_request_time = current_time
                self.cleanup()

    def write_error(self, exception):
        """Write the error handler's response for ``exception``, then close
        the connection unconditionally."""
        # Bound up front so the ``finally`` clause can tell whether the
        # error handler produced a response at all.
        response = None
        try:
            response = self.error_handler.response(self.request, exception)
            version = self.request.version if self.request else '1.1'
            self.transport.write(response.output(version))
        except RuntimeError:
            log.error(
                'Connection lost before error written @ {}'.format(
                    self.request.ip if self.request else 'Unknown'))
        except Exception as e:
            self.bail_out(
                "Writing error failed, connection closed {}".format(repr(e)),
                from_error=True)
        finally:
            # Without a response object there is nothing to put in the
            # access log; the failure itself was already logged above.
            if self.has_log and response is not None:
                if isinstance(response, HTTPResponse):
                    byte_count = len(response.body)
                else:
                    byte_count = -1
                if self.request:
                    request_line = '%s %s' % (self.request.method,
                                              self.request.url)
                else:
                    request_line = str(self.request) + str(self.url)
                self._log_access(getattr(response, 'status', None),
                                 byte_count, request_line)
            self.transport.close()

    def bail_out(self, message, from_error=False):
        """Handle a failure that occurred while writing to the client.

        If it happened while an error was already being written
        (``from_error``) or the transport is closing, only log it.
        Otherwise, send a ``ServerError`` built from ``message`` and log it.
        """
        if from_error or self.transport.is_closing():
            log.error(
                ("Transport closed @ {} and exception "
                 "experienced during error handling").format(
                    self.transport.get_extra_info('peername')))
            log.debug(
                'Exception:\n{}'.format(traceback.format_exc()))
        else:
            exception = ServerError(message)
            self.write_error(exception)
            log.error(message)

    def cleanup(self):
        """Reset per-request state so the connection can take a new request.
        """
        self.parser = None
        self.request = None
        self.url = None
        self.headers = None
        self._request_handler_task = None
        self._total_request_size = 0

    def close_if_idle(self):
        """Close the connection if a request is not being sent or received

        :return: boolean - True if closed, false if staying open
        """
        # The parser exists only while a request is in progress: it is
        # created on the first data chunk and dropped by cleanup().
        if not self.parser:
            self.transport.close()
            return True
        return False
def update_current_time(loop):
    """Refresh the module-level ``current_time`` cache and re-arm itself.

    Reading the clock once per second, rather than at the end of every
    keep-alive request, keeps the request path cheap; the value is only used
    for coarse timeout bookkeeping.

    :param loop: event loop on which the next refresh is scheduled
    :return: None
    """
    global current_time
    current_time = time()
    # call_later forwards extra positional args to the callback, so the
    # next refresh receives the same loop one second from now.
    loop.call_later(1, update_current_time, loop)
def trigger_events(events, loop):
    """Trigger event callbacks (functions or async)

    Each callback is called as ``callback(loop)``; if it returns an
    awaitable, that awaitable is run to completion on ``loop`` before the
    next callback starts.

    :param events: an iterable of sync or async functions to execute, a
        single such function, or ``None`` for no events
    :param loop: event loop
    """
    if events is None:
        # serve() passes its hook parameters straight through, and they
        # default to None.
        return
    if callable(events):
        events = (events,)
    for event in events:
        result = event(loop)
        if isawaitable(result):
            loop.run_until_complete(result)
def serve(host, port, request_handler, error_handler, before_start=None,
          after_start=None, before_stop=None, after_stop=None, debug=False,
          request_timeout=60, ssl=None, sock=None, request_max_size=None,
          reuse_port=False, loop=None, protocol=HttpProtocol, backlog=100,
          register_sys_signals=True, run_async=False, connections=None,
          signal=None, request_class=None, has_log=True, keep_alive=True):
    """Start asynchronous HTTP Server on an individual process.

    :param host: Address to host on
    :param port: Port to host on
    :param request_handler: Sanic request handler with middleware
    :param error_handler: Sanic error handler with middleware
    :param before_start: callbacks executed before the server starts
                         listening; each is called as ``callback(loop)``
                         (presumably with the `app` instance bound in by
                         the caller). ``None`` for none.
    :param after_start: callbacks executed after the server starts
                        listening; called like `before_start`
    :param before_stop: callbacks executed when a stop signal is received,
                        before connections are drained; called like
                        `before_start`
    :param after_stop: callbacks executed after connections are drained,
                       just before the loop is closed; called like
                       `before_start`
    :param debug: enables debug output (slows server)
    :param request_timeout: time in seconds
    :param ssl: SSLContext
    :param sock: Socket for the server to accept connections from
    :param request_max_size: size in bytes, `None` for no limit
    :param reuse_port: `True` for multiple workers
    :param loop: asyncio compatible event loop; only used when `run_async`
                 is True, otherwise a new loop is created
    :param protocol: subclass of asyncio protocol class
    :param backlog: maximum number of queued connections passed to
                    ``create_server``
    :param register_sys_signals: stop the loop on SIGINT/SIGTERM
    :param run_async: return the ``create_server`` coroutine instead of
                      running the server
    :param connections: shared set of live connections; a new one is
                        created when omitted
    :param signal: ``Signal`` shared with every connection and marked
                   stopped during shutdown; a new one is created per call
                   when omitted
    :param request_class: Request class to use
    :param has_log: disable/enable access log and error log
    :param keep_alive: allow keep-alive connections
    :return: the server coroutine when `run_async` is True, else Nothing
    """
    if signal is None:
        # A shared default instance would stay ``stopped`` after the first
        # shutdown and silently disable keep-alive on every later run.
        signal = Signal()

    if not run_async:
        loop = async_loop.new_event_loop()
        asyncio.set_event_loop(loop)

    if debug:
        loop.set_debug(debug)

    trigger_events(before_start or (), loop)

    connections = connections if connections is not None else set()
    server = partial(
        protocol,
        loop=loop,
        connections=connections,
        signal=signal,
        request_handler=request_handler,
        error_handler=error_handler,
        request_timeout=request_timeout,
        request_max_size=request_max_size,
        request_class=request_class,
        has_log=has_log,
        keep_alive=keep_alive,
    )

    server_coroutine = loop.create_server(
        server,
        host,
        port,
        ssl=ssl,
        reuse_port=reuse_port,
        sock=sock,
        backlog=backlog
    )

    # Instead of pulling time at the end of every request,
    # pull it once per second
    loop.call_soon(partial(update_current_time, loop))

    if run_async:
        return server_coroutine

    try:
        http_server = loop.run_until_complete(server_coroutine)
    except Exception:
        log.exception("Unable to start server")
        return

    trigger_events(after_start or (), loop)

    # Register signals for graceful termination
    if register_sys_signals:
        for _signal in (SIGINT, SIGTERM):
            try:
                loop.add_signal_handler(_signal, loop.stop)
            except NotImplementedError:
                log.warning('Sanic tried to use loop.add_signal_handler '
                            'but it is not implemented on this platform.')

    pid = os.getpid()
    try:
        log.info('Starting worker [{}]'.format(pid))
        loop.run_forever()
    finally:
        log.info("Stopping worker [{}]".format(pid))

        # Run the on_stop function if provided
        trigger_events(before_stop or (), loop)

        # Wait for event loop to finish and all connections to drain
        http_server.close()
        loop.run_until_complete(http_server.wait_closed())

        # Complete all tasks on the loop
        signal.stopped = True
        for connection in connections:
            connection.close_if_idle()

        while connections:
            loop.run_until_complete(asyncio.sleep(0.1))

        trigger_events(after_stop or (), loop)

        loop.close()
def serve_multiple(server_settings, workers):
    """Start multiple server processes simultaneously. Stop on interrupt
    and terminate signals, and drain connections when complete.

    :param server_settings: kw arguments to be passed to the serve function.
        Mutated in place: ``reuse_port`` is forced to True, and when no
        ``sock`` is given a shared listening socket is created and stored
        there (with ``host``/``port`` reset to None).
    :param workers: number of workers to launch
    :return: Nothing
    """
    server_settings['reuse_port'] = True

    # Handling when custom socket is not provided.
    if server_settings.get('sock') is None:
        sock = socket()
        sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
        try:
            sock.bind((server_settings['host'], server_settings['port']))
        except OSError:
            sock.close()
            raise
        sock.set_inheritable(True)
        server_settings['sock'] = sock
        server_settings['host'] = None
        server_settings['port'] = None

    # Bound before the handlers are installed so a signal arriving early
    # finds an (empty) list rather than an unbound name.
    processes = []

    def sig_handler(signum, frame):
        log.info("Received signal {}. Shutting down.".format(
            Signals(signum).name))
        for process in processes:
            os.kill(process.pid, SIGINT)

    signal_func(SIGINT, sig_handler)
    signal_func(SIGTERM, sig_handler)

    try:
        for _ in range(workers):
            process = Process(target=serve, kwargs=server_settings)
            process.daemon = True
            process.start()
            processes.append(process)

        # the processes block this until they're stopped
        for process in processes:
            process.join()
    finally:
        # Runs even if starting a worker fails, so no started worker
        # outlives the parent and the listening socket is released.
        for process in processes:
            process.terminate()
        server_settings.get('sock').close()