Move server.py into its own module (#2230)

* Move server.py into its own module

* Change monkeypatch path on test_logging.py
This commit is contained in:
Adam Hopkins 2021-08-31 11:51:32 +03:00 committed by GitHub
parent 9d0b54c90d
commit 945885d501
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 951 additions and 890 deletions

View File

@ -0,0 +1,52 @@
from types import SimpleNamespace
from sanic.models.protocol_types import TransportProtocol
class Signal:
stopped = False
class ConnInfo:
"""
Local and remote addresses and SSL status info.
"""
__slots__ = (
"client_port",
"client",
"client_ip",
"ctx",
"peername",
"server_port",
"server",
"sockname",
"ssl",
)
def __init__(self, transport: TransportProtocol, unix=None):
self.ctx = SimpleNamespace()
self.peername = None
self.server = self.client = ""
self.server_port = self.client_port = 0
self.client_ip = ""
self.sockname = addr = transport.get_extra_info("sockname")
self.ssl: bool = bool(transport.get_extra_info("sslcontext"))
if isinstance(addr, str): # UNIX socket
self.server = unix or addr
return
# IPv4 (ip, port) or IPv6 (ip, port, flowinfo, scopeid)
if isinstance(addr, tuple):
self.server = addr[0] if len(addr) == 2 else f"[{addr[0]}]"
self.server_port = addr[1]
# self.server gets non-standard port appended
if addr[1] != (443 if self.ssl else 80):
self.server = f"{self.server}:{addr[1]}"
self.peername = addr = transport.get_extra_info("peername")
if isinstance(addr, tuple):
self.client = addr[0] if len(addr) == 2 else f"[{addr[0]}]"
self.client_ip = addr[0]
self.client_port = addr[1]

View File

@ -1,889 +0,0 @@
from __future__ import annotations
from ssl import SSLContext
from types import SimpleNamespace
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
Optional,
Type,
Union,
)
from sanic.touchup.meta import TouchUpMeta
if TYPE_CHECKING:
from sanic.app import Sanic
import asyncio
import multiprocessing
import os
import secrets
import socket
import stat
from asyncio import CancelledError
from asyncio.transports import Transport
from functools import partial
from inspect import isawaitable
from ipaddress import ip_address
from signal import SIG_IGN, SIGINT, SIGTERM, Signals
from signal import signal as signal_func
from time import monotonic as current_time
from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows
from sanic.config import Config
from sanic.exceptions import RequestTimeout, SanicException, ServiceUnavailable
from sanic.http import Http, Stage
from sanic.log import error_logger, logger
from sanic.models.protocol_types import TransportProtocol
from sanic.request import Request
try:
import uvloop # type: ignore
if not isinstance(asyncio.get_event_loop_policy(), uvloop.EventLoopPolicy):
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
pass
class Signal:
stopped = False
class ConnInfo:
"""
Local and remote addresses and SSL status info.
"""
__slots__ = (
"client_port",
"client",
"client_ip",
"ctx",
"peername",
"server_port",
"server",
"sockname",
"ssl",
)
def __init__(self, transport: TransportProtocol, unix=None):
self.ctx = SimpleNamespace()
self.peername = None
self.server = self.client = ""
self.server_port = self.client_port = 0
self.client_ip = ""
self.sockname = addr = transport.get_extra_info("sockname")
self.ssl: bool = bool(transport.get_extra_info("sslcontext"))
if isinstance(addr, str): # UNIX socket
self.server = unix or addr
return
# IPv4 (ip, port) or IPv6 (ip, port, flowinfo, scopeid)
if isinstance(addr, tuple):
self.server = addr[0] if len(addr) == 2 else f"[{addr[0]}]"
self.server_port = addr[1]
# self.server gets non-standard port appended
if addr[1] != (443 if self.ssl else 80):
self.server = f"{self.server}:{addr[1]}"
self.peername = addr = transport.get_extra_info("peername")
if isinstance(addr, tuple):
self.client = addr[0] if len(addr) == 2 else f"[{addr[0]}]"
self.client_ip = addr[0]
self.client_port = addr[1]
class SanicProtocol(asyncio.Protocol):
__slots__ = (
"app",
# event loop, connection
"loop",
"transport",
"connections",
"conn_info",
"signal",
"_can_write",
"_time",
"_task",
"_unix",
"_data_received",
)
def __init__(
self,
*,
loop,
app: Sanic,
signal=None,
connections=None,
unix=None,
**kwargs,
):
asyncio.set_event_loop(loop)
self.loop = loop
self.app: Sanic = app
self.signal = signal or Signal()
self.transport: Optional[Transport] = None
self.connections = connections if connections is not None else set()
self.conn_info: Optional[ConnInfo] = None
self._can_write = asyncio.Event()
self._can_write.set()
self._unix = unix
self._time = 0.0 # type: float
self._task = None # type: Optional[asyncio.Task]
self._data_received = asyncio.Event()
@property
def ctx(self):
if self.conn_info is not None:
return self.conn_info.ctx
else:
return None
async def send(self, data):
"""
Generic data write implementation with backpressure control.
"""
await self._can_write.wait()
if self.transport.is_closing():
raise CancelledError
self.transport.write(data)
self._time = current_time()
async def receive_more(self):
"""
Wait until more data is received into the Server protocol's buffer
"""
self.transport.resume_reading()
self._data_received.clear()
await self._data_received.wait()
def close(self):
"""
Force close the connection.
"""
# Cause a call to connection_lost where further cleanup occurs
if self.transport:
self.transport.close()
self.transport = None
# asyncio.Protocol API Callbacks #
# ------------------------------ #
def connection_made(self, transport):
"""
Generic connection-made, with no connection_task, and no recv_buffer.
Override this for protocol-specific connection implementations.
"""
try:
transport.set_write_buffer_limits(low=16384, high=65536)
self.connections.add(self)
self.transport = transport
self.conn_info = ConnInfo(self.transport, unix=self._unix)
except Exception:
error_logger.exception("protocol.connect_made")
def connection_lost(self, exc):
try:
self.connections.discard(self)
self.resume_writing()
if self._task:
self._task.cancel()
except BaseException:
error_logger.exception("protocol.connection_lost")
def pause_writing(self):
self._can_write.clear()
def resume_writing(self):
self._can_write.set()
def data_received(self, data: bytes):
try:
self._time = current_time()
if not data:
return self.close()
if self._data_received:
self._data_received.set()
except BaseException:
error_logger.exception("protocol.data_received")
class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta):
"""
This class provides implements the HTTP 1.1 protocol on top of our
Sanic Server transport
"""
__touchup__ = (
"send",
"connection_task",
)
__slots__ = (
# request params
"request",
# request config
"request_handler",
"request_timeout",
"response_timeout",
"keep_alive_timeout",
"request_max_size",
"request_class",
"error_handler",
# enable or disable access log purpose
"access_log",
# connection management
"state",
"url",
"_handler_task",
"_http",
"_exception",
"recv_buffer",
)
def __init__(
self,
*,
loop,
app: Sanic,
signal=None,
connections=None,
state=None,
unix=None,
**kwargs,
):
super().__init__(
loop=loop,
app=app,
signal=signal,
connections=connections,
unix=unix,
)
self.url = None
self.request: Optional[Request] = None
self.access_log = self.app.config.ACCESS_LOG
self.request_handler = self.app.handle_request
self.error_handler = self.app.error_handler
self.request_timeout = self.app.config.REQUEST_TIMEOUT
self.response_timeout = self.app.config.RESPONSE_TIMEOUT
self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT
self.request_max_size = self.app.config.REQUEST_MAX_SIZE
self.request_class = self.app.request_class or Request
self.state = state if state else {}
if "requests_count" not in self.state:
self.state["requests_count"] = 0
self._exception = None
def _setup_connection(self):
self._http = Http(self)
self._time = current_time()
self.check_timeouts()
async def connection_task(self): # no cov
"""
Run a HTTP connection.
Timeouts and some additional error handling occur here, while most of
everything else happens in class Http or in code called from there.
"""
try:
self._setup_connection()
await self.app.dispatch(
"http.lifecycle.begin",
inline=True,
context={"conn_info": self.conn_info},
)
await self._http.http1()
except CancelledError:
pass
except Exception:
error_logger.exception("protocol.connection_task uncaught")
finally:
if self.app.debug and self._http:
ip = self.transport.get_extra_info("peername")
error_logger.error(
"Connection lost before response written"
f" @ {ip} {self._http.request}"
)
self._http = None
self._task = None
try:
self.close()
except BaseException:
error_logger.exception("Closing failed")
finally:
await self.app.dispatch(
"http.lifecycle.complete",
inline=True,
context={"conn_info": self.conn_info},
)
...
def check_timeouts(self):
"""
Runs itself periodically to enforce any expired timeouts.
"""
try:
if not self._task:
return
duration = current_time() - self._time
stage = self._http.stage
if stage is Stage.IDLE and duration > self.keep_alive_timeout:
logger.debug("KeepAlive Timeout. Closing connection.")
elif stage is Stage.REQUEST and duration > self.request_timeout:
logger.debug("Request Timeout. Closing connection.")
self._http.exception = RequestTimeout("Request Timeout")
elif stage is Stage.HANDLER and self._http.upgrade_websocket:
logger.debug("Handling websocket. Timeouts disabled.")
return
elif (
stage in (Stage.HANDLER, Stage.RESPONSE, Stage.FAILED)
and duration > self.response_timeout
):
logger.debug("Response Timeout. Closing connection.")
self._http.exception = ServiceUnavailable("Response Timeout")
else:
interval = (
min(
self.keep_alive_timeout,
self.request_timeout,
self.response_timeout,
)
/ 2
)
self.loop.call_later(max(0.1, interval), self.check_timeouts)
return
self._task.cancel()
except Exception:
error_logger.exception("protocol.check_timeouts")
async def send(self, data): # no cov
"""
Writes HTTP data with backpressure control.
"""
await self._can_write.wait()
if self.transport.is_closing():
raise CancelledError
await self.app.dispatch(
"http.lifecycle.send",
inline=True,
context={"data": data},
)
self.transport.write(data)
self._time = current_time()
def close_if_idle(self) -> bool:
"""
Close the connection if a request is not being sent or received
:return: boolean - True if closed, false if staying open
"""
if self._http is None or self._http.stage is Stage.IDLE:
self.close()
return True
return False
# -------------------------------------------- #
# Only asyncio.Protocol callbacks below this
# -------------------------------------------- #
def connection_made(self, transport):
"""
HTTP-protocol-specific new connection handler
"""
try:
# TODO: Benchmark to find suitable write buffer limits
transport.set_write_buffer_limits(low=16384, high=65536)
self.connections.add(self)
self.transport = transport
self._task = self.loop.create_task(self.connection_task())
self.recv_buffer = bytearray()
self.conn_info = ConnInfo(self.transport, unix=self._unix)
except Exception:
error_logger.exception("protocol.connect_made")
def data_received(self, data: bytes):
try:
self._time = current_time()
if not data:
return self.close()
self.recv_buffer += data
if (
len(self.recv_buffer) >= self.app.config.REQUEST_BUFFER_SIZE
and self.transport
):
self.transport.pause_reading()
if self._data_received:
self._data_received.set()
except Exception:
error_logger.exception("protocol.data_received")
def trigger_events(events: Optional[Iterable[Callable[..., Any]]], loop):
"""
Trigger event callbacks (functions or async)
:param events: one or more sync or async functions to execute
:param loop: event loop
"""
if events:
for event in events:
result = event(loop)
if isawaitable(result):
loop.run_until_complete(result)
class AsyncioServer:
"""
Wraps an asyncio server with functionality that might be useful to
a user who needs to manage the server lifecycle manually.
"""
__slots__ = ("app", "connections", "loop", "serve_coro", "server", "init")
def __init__(
self,
app,
loop,
serve_coro,
connections,
):
# Note, Sanic already called "before_server_start" events
# before this helper was even created. So we don't need it here.
self.app = app
self.connections = connections
self.loop = loop
self.serve_coro = serve_coro
self.server = None
self.init = False
def startup(self):
"""
Trigger "before_server_start" events
"""
self.init = True
return self.app._startup()
def before_start(self):
"""
Trigger "before_server_start" events
"""
return self._server_event("init", "before")
def after_start(self):
"""
Trigger "after_server_start" events
"""
return self._server_event("init", "after")
def before_stop(self):
"""
Trigger "before_server_stop" events
"""
return self._server_event("shutdown", "before")
def after_stop(self):
"""
Trigger "after_server_stop" events
"""
return self._server_event("shutdown", "after")
def is_serving(self) -> bool:
if self.server:
return self.server.is_serving()
return False
def wait_closed(self):
if self.server:
return self.server.wait_closed()
def close(self):
if self.server:
self.server.close()
coro = self.wait_closed()
task = asyncio.ensure_future(coro, loop=self.loop)
return task
def start_serving(self):
if self.server:
try:
return self.server.start_serving()
except AttributeError:
raise NotImplementedError(
"server.start_serving not available in this version "
"of asyncio or uvloop."
)
def serve_forever(self):
if self.server:
try:
return self.server.serve_forever()
except AttributeError:
raise NotImplementedError(
"server.serve_forever not available in this version "
"of asyncio or uvloop."
)
def _server_event(self, concern: str, action: str):
if not self.init:
raise SanicException(
"Cannot dispatch server event without "
"first running server.startup()"
)
return self.app._server_event(concern, action, loop=self.loop)
def __await__(self):
"""
Starts the asyncio server, returns AsyncServerCoro
"""
task = asyncio.ensure_future(self.serve_coro)
while not task.done():
yield
self.server = task.result()
return self
def serve(
host,
port,
app: Sanic,
ssl: Optional[SSLContext] = None,
sock: Optional[socket.socket] = None,
unix: Optional[str] = None,
reuse_port: bool = False,
loop=None,
protocol: Type[asyncio.Protocol] = HttpProtocol,
backlog: int = 100,
register_sys_signals: bool = True,
run_multiple: bool = False,
run_async: bool = False,
connections=None,
signal=Signal(),
state=None,
asyncio_server_kwargs=None,
):
"""Start asynchronous HTTP Server on an individual process.
:param host: Address to host on
:param port: Port to host on
:param before_start: function to be executed before the server starts
listening. Takes arguments `app` instance and `loop`
:param after_start: function to be executed after the server starts
listening. Takes arguments `app` instance and `loop`
:param before_stop: function to be executed when a stop signal is
received before it is respected. Takes arguments
`app` instance and `loop`
:param after_stop: function to be executed when a stop signal is
received after it is respected. Takes arguments
`app` instance and `loop`
:param ssl: SSLContext
:param sock: Socket for the server to accept connections from
:param unix: Unix socket to listen on instead of TCP port
:param reuse_port: `True` for multiple workers
:param loop: asyncio compatible event loop
:param run_async: bool: Do not create a new event loop for the server,
and return an AsyncServer object rather than running it
:param asyncio_server_kwargs: key-value args for asyncio/uvloop
create_server method
:return: Nothing
"""
if not run_async and not loop:
# create new event_loop after fork
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
if app.debug:
loop.set_debug(app.debug)
app.asgi = False
connections = connections if connections is not None else set()
protocol_kwargs = _build_protocol_kwargs(protocol, app.config)
server = partial(
protocol,
loop=loop,
connections=connections,
signal=signal,
app=app,
state=state,
unix=unix,
**protocol_kwargs,
)
asyncio_server_kwargs = (
asyncio_server_kwargs if asyncio_server_kwargs else {}
)
# UNIX sockets are always bound by us (to preserve semantics between modes)
if unix:
sock = bind_unix_socket(unix, backlog=backlog)
server_coroutine = loop.create_server(
server,
None if sock else host,
None if sock else port,
ssl=ssl,
reuse_port=reuse_port,
sock=sock,
backlog=backlog,
**asyncio_server_kwargs,
)
if run_async:
return AsyncioServer(
app=app,
loop=loop,
serve_coro=server_coroutine,
connections=connections,
)
loop.run_until_complete(app._startup())
loop.run_until_complete(app._server_event("init", "before"))
try:
http_server = loop.run_until_complete(server_coroutine)
except BaseException:
error_logger.exception("Unable to start server")
return
# Ignore SIGINT when run_multiple
if run_multiple:
signal_func(SIGINT, SIG_IGN)
# Register signals for graceful termination
if register_sys_signals:
if OS_IS_WINDOWS:
ctrlc_workaround_for_windows(app)
else:
for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]:
loop.add_signal_handler(_signal, app.stop)
loop.run_until_complete(app._server_event("init", "after"))
pid = os.getpid()
try:
logger.info("Starting worker [%s]", pid)
loop.run_forever()
finally:
logger.info("Stopping worker [%s]", pid)
# Run the on_stop function if provided
loop.run_until_complete(app._server_event("shutdown", "before"))
# Wait for event loop to finish and all connections to drain
http_server.close()
loop.run_until_complete(http_server.wait_closed())
# Complete all tasks on the loop
signal.stopped = True
for connection in connections:
connection.close_if_idle()
# Gracefully shutdown timeout.
# We should provide graceful_shutdown_timeout,
# instead of letting connection hangs forever.
# Let's roughly calcucate time.
graceful = app.config.GRACEFUL_SHUTDOWN_TIMEOUT
start_shutdown: float = 0
while connections and (start_shutdown < graceful):
loop.run_until_complete(asyncio.sleep(0.1))
start_shutdown = start_shutdown + 0.1
# Force close non-idle connection after waiting for
# graceful_shutdown_timeout
coros = []
for conn in connections:
if hasattr(conn, "websocket") and conn.websocket:
coros.append(conn.websocket.close_connection())
else:
conn.close()
_shutdown = asyncio.gather(*coros)
loop.run_until_complete(_shutdown)
loop.run_until_complete(app._server_event("shutdown", "after"))
remove_unix_socket(unix)
def _build_protocol_kwargs(
protocol: Type[asyncio.Protocol], config: Config
) -> Dict[str, Union[int, float]]:
if hasattr(protocol, "websocket_handshake"):
return {
"websocket_max_size": config.WEBSOCKET_MAX_SIZE,
"websocket_max_queue": config.WEBSOCKET_MAX_QUEUE,
"websocket_read_limit": config.WEBSOCKET_READ_LIMIT,
"websocket_write_limit": config.WEBSOCKET_WRITE_LIMIT,
"websocket_ping_timeout": config.WEBSOCKET_PING_TIMEOUT,
"websocket_ping_interval": config.WEBSOCKET_PING_INTERVAL,
}
return {}
def bind_socket(host: str, port: int, *, backlog=100) -> socket.socket:
"""Create TCP server socket.
:param host: IPv4, IPv6 or hostname may be specified
:param port: TCP port number
:param backlog: Maximum number of connections to queue
:return: socket.socket object
"""
try: # IP address: family must be specified for IPv6 at least
ip = ip_address(host)
host = str(ip)
sock = socket.socket(
socket.AF_INET6 if ip.version == 6 else socket.AF_INET
)
except ValueError: # Hostname, may become AF_INET or AF_INET6
sock = socket.socket()
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((host, port))
sock.listen(backlog)
return sock
def bind_unix_socket(path: str, *, mode=0o666, backlog=100) -> socket.socket:
"""Create unix socket.
:param path: filesystem path
:param backlog: Maximum number of connections to queue
:return: socket.socket object
"""
"""Open or atomically replace existing socket with zero downtime."""
# Sanitise and pre-verify socket path
path = os.path.abspath(path)
folder = os.path.dirname(path)
if not os.path.isdir(folder):
raise FileNotFoundError(f"Socket folder does not exist: {folder}")
try:
if not stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode):
raise FileExistsError(f"Existing file is not a socket: {path}")
except FileNotFoundError:
pass
# Create new socket with a random temporary name
tmp_path = f"{path}.{secrets.token_urlsafe()}"
sock = socket.socket(socket.AF_UNIX)
try:
# Critical section begins (filename races)
sock.bind(tmp_path)
try:
os.chmod(tmp_path, mode)
# Start listening before rename to avoid connection failures
sock.listen(backlog)
os.rename(tmp_path, path)
except: # noqa: E722
try:
os.unlink(tmp_path)
finally:
raise
except: # noqa: E722
try:
sock.close()
finally:
raise
return sock
def remove_unix_socket(path: Optional[str]) -> None:
"""Remove dead unix socket during server exit."""
if not path:
return
try:
if stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode):
# Is it actually dead (doesn't belong to a new server instance)?
with socket.socket(socket.AF_UNIX) as testsock:
try:
testsock.connect(path)
except ConnectionRefusedError:
os.unlink(path)
except FileNotFoundError:
pass
def serve_single(server_settings):
main_start = server_settings.pop("main_start", None)
main_stop = server_settings.pop("main_stop", None)
if not server_settings.get("run_async"):
# create new event_loop after fork
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server_settings["loop"] = loop
trigger_events(main_start, server_settings["loop"])
serve(**server_settings)
trigger_events(main_stop, server_settings["loop"])
server_settings["loop"].close()
def serve_multiple(server_settings, workers):
"""Start multiple server processes simultaneously. Stop on interrupt
and terminate signals, and drain connections when complete.
:param server_settings: kw arguments to be passed to the serve function
:param workers: number of workers to launch
:param stop_event: if provided, is used as a stop signal
:return:
"""
server_settings["reuse_port"] = True
server_settings["run_multiple"] = True
main_start = server_settings.pop("main_start", None)
main_stop = server_settings.pop("main_stop", None)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
trigger_events(main_start, loop)
# Create a listening socket or use the one in settings
sock = server_settings.get("sock")
unix = server_settings["unix"]
backlog = server_settings["backlog"]
if unix:
sock = bind_unix_socket(unix, backlog=backlog)
server_settings["unix"] = unix
if sock is None:
sock = bind_socket(
server_settings["host"], server_settings["port"], backlog=backlog
)
sock.set_inheritable(True)
server_settings["sock"] = sock
server_settings["host"] = None
server_settings["port"] = None
processes = []
def sig_handler(signal, frame):
logger.info("Received signal %s. Shutting down.", Signals(signal).name)
for process in processes:
os.kill(process.pid, SIGTERM)
signal_func(SIGINT, lambda s, f: sig_handler(s, f))
signal_func(SIGTERM, lambda s, f: sig_handler(s, f))
mp = multiprocessing.get_context("fork")
for _ in range(workers):
process = mp.Process(target=serve, kwargs=server_settings)
process.daemon = True
process.start()
processes.append(process)
for process in processes:
process.join()
# the above processes will block this until they're stopped
for process in processes:
process.terminate()
trigger_events(main_stop, loop)
sock.close()
loop.close()
remove_unix_socket(unix)

26
sanic/server/__init__.py Normal file
View File

@ -0,0 +1,26 @@
import asyncio
from sanic.models.server_types import ConnInfo, Signal
from sanic.server.async_server import AsyncioServer
from sanic.server.protocols.http_protocol import HttpProtocol
from sanic.server.runners import serve, serve_multiple, serve_single
try:
import uvloop # type: ignore
if not isinstance(asyncio.get_event_loop_policy(), uvloop.EventLoopPolicy):
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
pass
__all__ = (
"AsyncioServer",
"ConnInfo",
"HttpProtocol",
"Signal",
"serve",
"serve_multiple",
"serve_single",
)

View File

@ -0,0 +1,115 @@
from __future__ import annotations
import asyncio
from sanic.exceptions import SanicException
class AsyncioServer:
"""
Wraps an asyncio server with functionality that might be useful to
a user who needs to manage the server lifecycle manually.
"""
__slots__ = ("app", "connections", "loop", "serve_coro", "server", "init")
def __init__(
self,
app,
loop,
serve_coro,
connections,
):
# Note, Sanic already called "before_server_start" events
# before this helper was even created. So we don't need it here.
self.app = app
self.connections = connections
self.loop = loop
self.serve_coro = serve_coro
self.server = None
self.init = False
def startup(self):
"""
Trigger "before_server_start" events
"""
self.init = True
return self.app._startup()
def before_start(self):
"""
Trigger "before_server_start" events
"""
return self._server_event("init", "before")
def after_start(self):
"""
Trigger "after_server_start" events
"""
return self._server_event("init", "after")
def before_stop(self):
"""
Trigger "before_server_stop" events
"""
return self._server_event("shutdown", "before")
def after_stop(self):
"""
Trigger "after_server_stop" events
"""
return self._server_event("shutdown", "after")
def is_serving(self) -> bool:
if self.server:
return self.server.is_serving()
return False
def wait_closed(self):
if self.server:
return self.server.wait_closed()
def close(self):
if self.server:
self.server.close()
coro = self.wait_closed()
task = asyncio.ensure_future(coro, loop=self.loop)
return task
def start_serving(self):
if self.server:
try:
return self.server.start_serving()
except AttributeError:
raise NotImplementedError(
"server.start_serving not available in this version "
"of asyncio or uvloop."
)
def serve_forever(self):
if self.server:
try:
return self.server.serve_forever()
except AttributeError:
raise NotImplementedError(
"server.serve_forever not available in this version "
"of asyncio or uvloop."
)
def _server_event(self, concern: str, action: str):
if not self.init:
raise SanicException(
"Cannot dispatch server event without "
"first running server.startup()"
)
return self.app._server_event(concern, action, loop=self.loop)
def __await__(self):
"""
Starts the asyncio server, returns AsyncServerCoro
"""
task = asyncio.ensure_future(self.serve_coro)
while not task.done():
yield
self.server = task.result()
return self

16
sanic/server/events.py Normal file
View File

@ -0,0 +1,16 @@
from inspect import isawaitable
from typing import Any, Callable, Iterable, Optional
def trigger_events(events: Optional[Iterable[Callable[..., Any]]], loop):
"""
Trigger event callbacks (functions or async)
:param events: one or more sync or async functions to execute
:param loop: event loop
"""
if events:
for event in events:
result = event(loop)
if isawaitable(result):
loop.run_until_complete(result)

View File

View File

@ -0,0 +1,132 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from sanic.app import Sanic
import asyncio
from asyncio import CancelledError
from asyncio.transports import Transport
from time import monotonic as current_time
from sanic.log import error_logger
from sanic.models.server_types import ConnInfo, Signal
class SanicProtocol(asyncio.Protocol):
__slots__ = (
"app",
# event loop, connection
"loop",
"transport",
"connections",
"conn_info",
"signal",
"_can_write",
"_time",
"_task",
"_unix",
"_data_received",
)
def __init__(
self,
*,
loop,
app: Sanic,
signal=None,
connections=None,
unix=None,
**kwargs,
):
asyncio.set_event_loop(loop)
self.loop = loop
self.app: Sanic = app
self.signal = signal or Signal()
self.transport: Optional[Transport] = None
self.connections = connections if connections is not None else set()
self.conn_info: Optional[ConnInfo] = None
self._can_write = asyncio.Event()
self._can_write.set()
self._unix = unix
self._time = 0.0 # type: float
self._task = None # type: Optional[asyncio.Task]
self._data_received = asyncio.Event()
@property
def ctx(self):
if self.conn_info is not None:
return self.conn_info.ctx
else:
return None
async def send(self, data):
"""
Generic data write implementation with backpressure control.
"""
await self._can_write.wait()
if self.transport.is_closing():
raise CancelledError
self.transport.write(data)
self._time = current_time()
async def receive_more(self):
"""
Wait until more data is received into the Server protocol's buffer
"""
self.transport.resume_reading()
self._data_received.clear()
await self._data_received.wait()
def close(self):
"""
Force close the connection.
"""
# Cause a call to connection_lost where further cleanup occurs
if self.transport:
self.transport.close()
self.transport = None
# asyncio.Protocol API Callbacks #
# ------------------------------ #
def connection_made(self, transport):
"""
Generic connection-made, with no connection_task, and no recv_buffer.
Override this for protocol-specific connection implementations.
"""
try:
transport.set_write_buffer_limits(low=16384, high=65536)
self.connections.add(self)
self.transport = transport
self.conn_info = ConnInfo(self.transport, unix=self._unix)
except Exception:
error_logger.exception("protocol.connect_made")
def connection_lost(self, exc):
try:
self.connections.discard(self)
self.resume_writing()
if self._task:
self._task.cancel()
except BaseException:
error_logger.exception("protocol.connection_lost")
def pause_writing(self):
self._can_write.clear()
def resume_writing(self):
self._can_write.set()
def data_received(self, data: bytes):
try:
self._time = current_time()
if not data:
return self.close()
if self._data_received:
self._data_received.set()
except BaseException:
error_logger.exception("protocol.data_received")

View File

@ -0,0 +1,233 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from sanic.touchup.meta import TouchUpMeta
if TYPE_CHECKING:
from sanic.app import Sanic
from asyncio import CancelledError
from time import monotonic as current_time
from sanic.exceptions import RequestTimeout, ServiceUnavailable
from sanic.http import Http, Stage
from sanic.log import error_logger, logger
from sanic.models.server_types import ConnInfo
from sanic.request import Request
from sanic.server.protocols.base_protocol import SanicProtocol
class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta):
"""
This class provides implements the HTTP 1.1 protocol on top of our
Sanic Server transport
"""
__touchup__ = (
"send",
"connection_task",
)
__slots__ = (
# request params
"request",
# request config
"request_handler",
"request_timeout",
"response_timeout",
"keep_alive_timeout",
"request_max_size",
"request_class",
"error_handler",
# enable or disable access log purpose
"access_log",
# connection management
"state",
"url",
"_handler_task",
"_http",
"_exception",
"recv_buffer",
)
def __init__(
self,
*,
loop,
app: Sanic,
signal=None,
connections=None,
state=None,
unix=None,
**kwargs,
):
super().__init__(
loop=loop,
app=app,
signal=signal,
connections=connections,
unix=unix,
)
self.url = None
self.request: Optional[Request] = None
self.access_log = self.app.config.ACCESS_LOG
self.request_handler = self.app.handle_request
self.error_handler = self.app.error_handler
self.request_timeout = self.app.config.REQUEST_TIMEOUT
self.response_timeout = self.app.config.RESPONSE_TIMEOUT
self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT
self.request_max_size = self.app.config.REQUEST_MAX_SIZE
self.request_class = self.app.request_class or Request
self.state = state if state else {}
if "requests_count" not in self.state:
self.state["requests_count"] = 0
self._exception = None
def _setup_connection(self):
self._http = Http(self)
self._time = current_time()
self.check_timeouts()
async def connection_task(self): # no cov
"""
Run a HTTP connection.
Timeouts and some additional error handling occur here, while most of
everything else happens in class Http or in code called from there.
"""
try:
self._setup_connection()
await self.app.dispatch(
"http.lifecycle.begin",
inline=True,
context={"conn_info": self.conn_info},
)
await self._http.http1()
except CancelledError:
pass
except Exception:
error_logger.exception("protocol.connection_task uncaught")
finally:
if self.app.debug and self._http:
ip = self.transport.get_extra_info("peername")
error_logger.error(
"Connection lost before response written"
f" @ {ip} {self._http.request}"
)
self._http = None
self._task = None
try:
self.close()
except BaseException:
error_logger.exception("Closing failed")
finally:
await self.app.dispatch(
"http.lifecycle.complete",
inline=True,
context={"conn_info": self.conn_info},
)
# Important to keep this Ellipsis here for the TouchUp module
...
def check_timeouts(self):
"""
Runs itself periodically to enforce any expired timeouts.
"""
try:
if not self._task:
return
duration = current_time() - self._time
stage = self._http.stage
if stage is Stage.IDLE and duration > self.keep_alive_timeout:
logger.debug("KeepAlive Timeout. Closing connection.")
elif stage is Stage.REQUEST and duration > self.request_timeout:
logger.debug("Request Timeout. Closing connection.")
self._http.exception = RequestTimeout("Request Timeout")
elif stage is Stage.HANDLER and self._http.upgrade_websocket:
logger.debug("Handling websocket. Timeouts disabled.")
return
elif (
stage in (Stage.HANDLER, Stage.RESPONSE, Stage.FAILED)
and duration > self.response_timeout
):
logger.debug("Response Timeout. Closing connection.")
self._http.exception = ServiceUnavailable("Response Timeout")
else:
interval = (
min(
self.keep_alive_timeout,
self.request_timeout,
self.response_timeout,
)
/ 2
)
self.loop.call_later(max(0.1, interval), self.check_timeouts)
return
self._task.cancel()
except Exception:
error_logger.exception("protocol.check_timeouts")
async def send(self, data): # no cov
"""
Writes HTTP data with backpressure control.
"""
await self._can_write.wait()
if self.transport.is_closing():
raise CancelledError
await self.app.dispatch(
"http.lifecycle.send",
inline=True,
context={"data": data},
)
self.transport.write(data)
self._time = current_time()
def close_if_idle(self) -> bool:
"""
Close the connection if a request is not being sent or received
:return: boolean - True if closed, false if staying open
"""
if self._http is None or self._http.stage is Stage.IDLE:
self.close()
return True
return False
# -------------------------------------------- #
# Only asyncio.Protocol callbacks below this
# -------------------------------------------- #
def connection_made(self, transport):
"""
HTTP-protocol-specific new connection handler
"""
try:
# TODO: Benchmark to find suitable write buffer limits
transport.set_write_buffer_limits(low=16384, high=65536)
self.connections.add(self)
self.transport = transport
self._task = self.loop.create_task(self.connection_task())
self.recv_buffer = bytearray()
self.conn_info = ConnInfo(self.transport, unix=self._unix)
except Exception:
error_logger.exception("protocol.connect_made")
def data_received(self, data: bytes):
try:
self._time = current_time()
if not data:
return self.close()
self.recv_buffer += data
if (
len(self.recv_buffer) >= self.app.config.REQUEST_BUFFER_SIZE
and self.transport
):
self.transport.pause_reading()
if self._data_received:
self._data_received.set()
except Exception:
error_logger.exception("protocol.data_received")

287
sanic/server/runners.py Normal file
View File

@ -0,0 +1,287 @@
from __future__ import annotations
from ssl import SSLContext
from typing import TYPE_CHECKING, Dict, Optional, Type, Union
from sanic.config import Config
from sanic.server.events import trigger_events
if TYPE_CHECKING:
from sanic.app import Sanic
import asyncio
import multiprocessing
import os
import socket
from functools import partial
from signal import SIG_IGN, SIGINT, SIGTERM, Signals
from signal import signal as signal_func
from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows
from sanic.log import error_logger, logger
from sanic.models.server_types import Signal
from sanic.server.async_server import AsyncioServer
from sanic.server.protocols.http_protocol import HttpProtocol
from sanic.server.socket import (
bind_socket,
bind_unix_socket,
remove_unix_socket,
)
def serve(
host,
port,
app: Sanic,
ssl: Optional[SSLContext] = None,
sock: Optional[socket.socket] = None,
unix: Optional[str] = None,
reuse_port: bool = False,
loop=None,
protocol: Type[asyncio.Protocol] = HttpProtocol,
backlog: int = 100,
register_sys_signals: bool = True,
run_multiple: bool = False,
run_async: bool = False,
connections=None,
signal=Signal(),
state=None,
asyncio_server_kwargs=None,
):
"""Start asynchronous HTTP Server on an individual process.
:param host: Address to host on
:param port: Port to host on
:param before_start: function to be executed before the server starts
listening. Takes arguments `app` instance and `loop`
:param after_start: function to be executed after the server starts
listening. Takes arguments `app` instance and `loop`
:param before_stop: function to be executed when a stop signal is
received before it is respected. Takes arguments
`app` instance and `loop`
:param after_stop: function to be executed when a stop signal is
received after it is respected. Takes arguments
`app` instance and `loop`
:param ssl: SSLContext
:param sock: Socket for the server to accept connections from
:param unix: Unix socket to listen on instead of TCP port
:param reuse_port: `True` for multiple workers
:param loop: asyncio compatible event loop
:param run_async: bool: Do not create a new event loop for the server,
and return an AsyncServer object rather than running it
:param asyncio_server_kwargs: key-value args for asyncio/uvloop
create_server method
:return: Nothing
"""
if not run_async and not loop:
# create new event_loop after fork
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
if app.debug:
loop.set_debug(app.debug)
app.asgi = False
connections = connections if connections is not None else set()
protocol_kwargs = _build_protocol_kwargs(protocol, app.config)
server = partial(
protocol,
loop=loop,
connections=connections,
signal=signal,
app=app,
state=state,
unix=unix,
**protocol_kwargs,
)
asyncio_server_kwargs = (
asyncio_server_kwargs if asyncio_server_kwargs else {}
)
# UNIX sockets are always bound by us (to preserve semantics between modes)
if unix:
sock = bind_unix_socket(unix, backlog=backlog)
server_coroutine = loop.create_server(
server,
None if sock else host,
None if sock else port,
ssl=ssl,
reuse_port=reuse_port,
sock=sock,
backlog=backlog,
**asyncio_server_kwargs,
)
if run_async:
return AsyncioServer(
app=app,
loop=loop,
serve_coro=server_coroutine,
connections=connections,
)
loop.run_until_complete(app._startup())
loop.run_until_complete(app._server_event("init", "before"))
try:
http_server = loop.run_until_complete(server_coroutine)
except BaseException:
error_logger.exception("Unable to start server")
return
# Ignore SIGINT when run_multiple
if run_multiple:
signal_func(SIGINT, SIG_IGN)
# Register signals for graceful termination
if register_sys_signals:
if OS_IS_WINDOWS:
ctrlc_workaround_for_windows(app)
else:
for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]:
loop.add_signal_handler(_signal, app.stop)
loop.run_until_complete(app._server_event("init", "after"))
pid = os.getpid()
try:
logger.info("Starting worker [%s]", pid)
loop.run_forever()
finally:
logger.info("Stopping worker [%s]", pid)
# Run the on_stop function if provided
loop.run_until_complete(app._server_event("shutdown", "before"))
# Wait for event loop to finish and all connections to drain
http_server.close()
loop.run_until_complete(http_server.wait_closed())
# Complete all tasks on the loop
signal.stopped = True
for connection in connections:
connection.close_if_idle()
# Gracefully shutdown timeout.
# We should provide graceful_shutdown_timeout,
# instead of letting connection hangs forever.
# Let's roughly calcucate time.
graceful = app.config.GRACEFUL_SHUTDOWN_TIMEOUT
start_shutdown: float = 0
while connections and (start_shutdown < graceful):
loop.run_until_complete(asyncio.sleep(0.1))
start_shutdown = start_shutdown + 0.1
# Force close non-idle connection after waiting for
# graceful_shutdown_timeout
coros = []
for conn in connections:
if hasattr(conn, "websocket") and conn.websocket:
coros.append(conn.websocket.close_connection())
else:
conn.close()
_shutdown = asyncio.gather(*coros)
loop.run_until_complete(_shutdown)
loop.run_until_complete(app._server_event("shutdown", "after"))
remove_unix_socket(unix)
def serve_single(server_settings):
main_start = server_settings.pop("main_start", None)
main_stop = server_settings.pop("main_stop", None)
if not server_settings.get("run_async"):
# create new event_loop after fork
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server_settings["loop"] = loop
trigger_events(main_start, server_settings["loop"])
serve(**server_settings)
trigger_events(main_stop, server_settings["loop"])
server_settings["loop"].close()
def serve_multiple(server_settings, workers):
"""Start multiple server processes simultaneously. Stop on interrupt
and terminate signals, and drain connections when complete.
:param server_settings: kw arguments to be passed to the serve function
:param workers: number of workers to launch
:param stop_event: if provided, is used as a stop signal
:return:
"""
server_settings["reuse_port"] = True
server_settings["run_multiple"] = True
main_start = server_settings.pop("main_start", None)
main_stop = server_settings.pop("main_stop", None)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
trigger_events(main_start, loop)
# Create a listening socket or use the one in settings
sock = server_settings.get("sock")
unix = server_settings["unix"]
backlog = server_settings["backlog"]
if unix:
sock = bind_unix_socket(unix, backlog=backlog)
server_settings["unix"] = unix
if sock is None:
sock = bind_socket(
server_settings["host"], server_settings["port"], backlog=backlog
)
sock.set_inheritable(True)
server_settings["sock"] = sock
server_settings["host"] = None
server_settings["port"] = None
processes = []
def sig_handler(signal, frame):
logger.info("Received signal %s. Shutting down.", Signals(signal).name)
for process in processes:
os.kill(process.pid, SIGTERM)
signal_func(SIGINT, lambda s, f: sig_handler(s, f))
signal_func(SIGTERM, lambda s, f: sig_handler(s, f))
mp = multiprocessing.get_context("fork")
for _ in range(workers):
process = mp.Process(target=serve, kwargs=server_settings)
process.daemon = True
process.start()
processes.append(process)
for process in processes:
process.join()
# the above processes will block this until they're stopped
for process in processes:
process.terminate()
trigger_events(main_stop, loop)
sock.close()
loop.close()
remove_unix_socket(unix)
def _build_protocol_kwargs(
protocol: Type[asyncio.Protocol], config: Config
) -> Dict[str, Union[int, float]]:
if hasattr(protocol, "websocket_handshake"):
return {
"websocket_max_size": config.WEBSOCKET_MAX_SIZE,
"websocket_max_queue": config.WEBSOCKET_MAX_QUEUE,
"websocket_read_limit": config.WEBSOCKET_READ_LIMIT,
"websocket_write_limit": config.WEBSOCKET_WRITE_LIMIT,
"websocket_ping_timeout": config.WEBSOCKET_PING_TIMEOUT,
"websocket_ping_interval": config.WEBSOCKET_PING_INTERVAL,
}
return {}

87
sanic/server/socket.py Normal file
View File

@ -0,0 +1,87 @@
from __future__ import annotations
import os
import secrets
import socket
import stat
from ipaddress import ip_address
from typing import Optional
def bind_socket(host: str, port: int, *, backlog=100) -> socket.socket:
"""Create TCP server socket.
:param host: IPv4, IPv6 or hostname may be specified
:param port: TCP port number
:param backlog: Maximum number of connections to queue
:return: socket.socket object
"""
try: # IP address: family must be specified for IPv6 at least
ip = ip_address(host)
host = str(ip)
sock = socket.socket(
socket.AF_INET6 if ip.version == 6 else socket.AF_INET
)
except ValueError: # Hostname, may become AF_INET or AF_INET6
sock = socket.socket()
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((host, port))
sock.listen(backlog)
return sock
def bind_unix_socket(path: str, *, mode=0o666, backlog=100) -> socket.socket:
"""Create unix socket.
:param path: filesystem path
:param backlog: Maximum number of connections to queue
:return: socket.socket object
"""
"""Open or atomically replace existing socket with zero downtime."""
# Sanitise and pre-verify socket path
path = os.path.abspath(path)
folder = os.path.dirname(path)
if not os.path.isdir(folder):
raise FileNotFoundError(f"Socket folder does not exist: {folder}")
try:
if not stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode):
raise FileExistsError(f"Existing file is not a socket: {path}")
except FileNotFoundError:
pass
# Create new socket with a random temporary name
tmp_path = f"{path}.{secrets.token_urlsafe()}"
sock = socket.socket(socket.AF_UNIX)
try:
# Critical section begins (filename races)
sock.bind(tmp_path)
try:
os.chmod(tmp_path, mode)
# Start listening before rename to avoid connection failures
sock.listen(backlog)
os.rename(tmp_path, path)
except: # noqa: E722
try:
os.unlink(tmp_path)
finally:
raise
except: # noqa: E722
try:
sock.close()
finally:
raise
return sock
def remove_unix_socket(path: Optional[str]) -> None:
"""Remove dead unix socket during server exit."""
if not path:
return
try:
if stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode):
# Is it actually dead (doesn't belong to a new server instance)?
with socket.socket(socket.AF_UNIX) as testsock:
try:
testsock.connect(path)
except ConnectionRefusedError:
os.unlink(path)
except FileNotFoundError:
pass

View File

@ -116,7 +116,9 @@ def test_log_connection_lost(app, debug, monkeypatch):
stream = StringIO() stream = StringIO()
error = logging.getLogger("sanic.error") error = logging.getLogger("sanic.error")
error.addHandler(logging.StreamHandler(stream)) error.addHandler(logging.StreamHandler(stream))
monkeypatch.setattr(sanic.server, "error_logger", error) monkeypatch.setattr(
sanic.server.protocols.http_protocol, "error_logger", error
)
@app.route("/conn_lost") @app.route("/conn_lost")
async def conn_lost(request): async def conn_lost(request):