From a62c84a9546d47cac0b17d3d7ce522944355c19e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= <98187+Tronic@users.noreply.github.com> Date: Mon, 29 Jun 2020 08:55:32 +0300 Subject: [PATCH] Socket binding implemented properly for IPv6 and UNIX sockets. (#1641) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Socket binding implemented properly for IPv6 and UNIX sockets. - app.run("::1") for IPv6 - app.run("unix:/tmp/server.sock") for UNIX sockets - app.run("localhost") retains old functionality (randomly either IPv4 or IPv6) Do note that IPv6 and UNIX sockets are not fully supported by other Sanic facilities. In particular, request.server_name and request.server_port are currently unreliable. * Fix Windows compatibility by not referring to socket.AF_UNIX unless needed. * Compatibility fix. * Fix test of existing unix socket. * Cleaner unix socket removal. * Remove unix socket on exit also with workers=1. * More pedantic UNIX socket implementation. * Refactor app to take unix= argument instead of unix:-prefixed host. Goin' fast @ unix-socket fixed. * Linter * Proxy properties cleanup. Slight changes of semantics. SERVER_NAME now overrides everything. * Have server fill in connection info instead of request asking the socket. - Would be a good idea to remove request.transport entirely but I didn't dare to touch it yet. * Linter 💣🌟✊💀 * Fix typing issues. request.server_name returns empty string if host header is missing. * Fix tests * Tests were failing, fix connection info. * Linter nazi says you need that empty line. * Rename a to addr, leave client empty for unix sockets. * Add --unix support when sanic is run as module. * Remove remove_route, deprecated in 19.6. * Improved unix socket binding. * More robust creating and unlinking of sockets. Show proper and not temporary name in conn_info. * Add comprehensive tests for unix socket mode. * Hide some imports inside functions to avoid Windows failure. * Mention unix socket mode in deployment docs. * Fix merge commit. * Make test_unix_connection_multiple_workers pickleable for spawn mode multiprocessing. Co-authored-by: L. Kärkkäinen Co-authored-by: Adam Hopkins --- docs/sanic/deploying.rst | 1 + examples/delayed_response.py | 18 +++ sanic/__main__.py | 2 + sanic/app.py | 13 +- sanic/asgi.py | 3 +- sanic/request.py | 149 +++++++++++----------- sanic/server.py | 156 +++++++++++++++++++++-- tests/test_requests.py | 28 ++--- tests/test_unix_socket.py | 235 +++++++++++++++++++++++++++++++++++ 9 files changed, 504 insertions(+), 101 deletions(-) create mode 100644 examples/delayed_response.py create mode 100644 tests/test_unix_socket.py diff --git a/docs/sanic/deploying.rst b/docs/sanic/deploying.rst index 886abdf1..bc6e9987 100644 --- a/docs/sanic/deploying.rst +++ b/docs/sanic/deploying.rst @@ -16,6 +16,7 @@ keyword arguments: - `host` *(default `"127.0.0.1"`)*: Address to host the server on. - `port` *(default `8000`)*: Port to host the server on. +- `unix` *(default `None`)*: Unix socket name to host the server on (instead of TCP). - `debug` *(default `False`)*: Enables debug output (slows server). - `ssl` *(default `None`)*: `SSLContext` for SSL encryption of worker(s). - `sock` *(default `None`)*: Socket for the server to accept connections from. diff --git a/examples/delayed_response.py b/examples/delayed_response.py new file mode 100644 index 00000000..4105edba --- /dev/null +++ b/examples/delayed_response.py @@ -0,0 +1,18 @@ +from asyncio import sleep + +from sanic import Sanic, response + +app = Sanic(__name__, strict_slashes=True) + +@app.get("/") +async def handler(request): + return response.redirect("/sleep/3") + +@app.get("/sleep/") +async def handler2(request, t=0.3): + await sleep(t) + return response.text(f"Slept {t:.1f} seconds.\n") + + +if __name__ == '__main__': + app.run(host="0.0.0.0", port=8000) diff --git a/sanic/__main__.py b/sanic/__main__.py index 4f78fe5b..c9fa2e52 100644 --- a/sanic/__main__.py +++ b/sanic/__main__.py @@ -13,6 +13,7 @@ def main(): parser = ArgumentParser(prog="sanic") parser.add_argument("--host", dest="host", type=str, default="127.0.0.1") parser.add_argument("--port", dest="port", type=int, default=8000) + parser.add_argument("--unix", dest="unix", type=str, default="") parser.add_argument( "--cert", dest="cert", type=str, help="location of certificate for SSL" ) @@ -53,6 +54,7 @@ def main(): app.run( host=args.host, port=args.port, + unix=args.unix, workers=args.workers, debug=args.debug, ssl=ssl, diff --git a/sanic/app.py b/sanic/app.py index db41f021..5c805874 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -1033,6 +1033,7 @@ class Sanic: stop_event: Any = None, register_sys_signals: bool = True, access_log: Optional[bool] = None, + unix: Optional[str] = None, loop: None = None, ) -> None: """Run the HTTP Server and listen until keyboard interrupt or term @@ -1066,6 +1067,8 @@ class Sanic: :type register_sys_signals: bool :param access_log: Enables writing access logs (slows server) :type access_log: bool + :param unix: Unix socket to listen on instead of TCP port + :type unix: str :return: Nothing """ if loop is not None: @@ -1104,6 +1107,7 @@ class Sanic: debug=debug, ssl=ssl, sock=sock, + unix=unix, workers=workers, protocol=protocol, backlog=backlog, @@ -1151,6 +1155,7 @@ class Sanic: backlog: int = 100, stop_event: Any = None, access_log: Optional[bool] = None, + unix: Optional[str] = None, return_asyncio_server=False, asyncio_server_kwargs=None, ) -> Optional[AsyncioServer]: @@ -1220,6 +1225,7 @@ class Sanic: debug=debug, ssl=ssl, sock=sock, + unix=unix, loop=get_event_loop(), protocol=protocol, backlog=backlog, @@ -1285,6 +1291,7 @@ class Sanic: debug=False, ssl=None, sock=None, + unix=None, workers=1, loop=None, protocol=HttpProtocol, @@ -1326,6 +1333,7 @@ class Sanic: "host": host, "port": port, "sock": sock, + "unix": unix, "ssl": ssl, "app": self, "signal": Signal(), @@ -1372,7 +1380,10 @@ class Sanic: proto = "http" if ssl is not None: proto = "https" - logger.info(f"Goin' Fast @ {proto}://{host}:{port}") + if unix: + logger.info(f"Goin' Fast @ {unix} {proto}://...") + else: + logger.info(f"Goin' Fast @ {proto}://{host}:{port}") return server_settings diff --git a/sanic/asgi.py b/sanic/asgi.py index f08cc454..2ae6f369 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -22,7 +22,7 @@ from sanic.exceptions import InvalidUsage, ServerError from sanic.log import logger from sanic.request import Request from sanic.response import HTTPResponse, StreamingHTTPResponse -from sanic.server import StreamBuffer +from sanic.server import ConnInfo, StreamBuffer from sanic.websocket import WebSocketConnection @@ -255,6 +255,7 @@ class ASGIApp: instance.transport, sanic_app, ) + instance.request.conn_info = ConnInfo(instance.transport) if sanic_app.is_request_stream: is_stream_handler = sanic_app.router.is_stream_handler( diff --git a/sanic/request.py b/sanic/request.py index 6e1a3061..0330e121 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -87,6 +87,7 @@ class Request: "_socket", "app", "body", + "conn_info", "ctx", "endpoint", "headers", @@ -117,6 +118,7 @@ class Request: # Init but do not inhale self.body_init() + self.conn_info = None self.ctx = SimpleNamespace() self.parsed_forwarded = None self.parsed_json = None @@ -349,56 +351,55 @@ class Request: self._cookies = {} return self._cookies + @property + def content_type(self): + return self.headers.get("Content-Type", DEFAULT_HTTP_CONTENT_TYPE) + + @property + def match_info(self): + """return matched info after resolving route""" + return self.app.router.get(self)[2] + + # Transport properties (obtained from local interface only) + @property def ip(self): """ :return: peer ip of the socket """ - if not hasattr(self, "_socket"): - self._get_address() - return self._ip + return self.conn_info.client if self.conn_info else "" @property def port(self): """ :return: peer port of the socket """ - if not hasattr(self, "_socket"): - self._get_address() - return self._port + return self.conn_info.client_port if self.conn_info else 0 @property def socket(self): - if not hasattr(self, "_socket"): - self._get_address() - return self._socket - - def _get_address(self): - self._socket = self.transport.get_extra_info("peername") or ( - None, - None, - ) - self._ip = self._socket[0] - self._port = self._socket[1] + return self.conn_info.peername if self.conn_info else (None, None) @property - def server_name(self): - """ - Attempt to get the server's external hostname in this order: - `config.SERVER_NAME`, proxied or direct Host headers - :func:`Request.host` + def path(self) -> str: + """Path of the local HTTP request.""" + return self._parsed_url.path.decode("utf-8") - :return: the server name without port number - :rtype: str - """ - server_name = self.app.config.get("SERVER_NAME") - if server_name: - host = server_name.split("//", 1)[-1].split("/", 1)[0] - return parse_host(host)[0] - return parse_host(self.host)[0] + # Proxy properties (using SERVER_NAME/forwarded/request/transport info) @property def forwarded(self): + """ + Active proxy information obtained from request headers, as specified in + Sanic configuration. + + Field names by, for, proto, host, port and path are normalized. + - for and by IPv6 addresses are bracketed + - port (int) is only set by port headers, not from host. + - path is url-unencoded + + Additional values may be available from new style Forwarded headers. + """ if self.parsed_forwarded is None: self.parsed_forwarded = ( parse_forwarded(self.headers, self.app.config) @@ -408,50 +409,30 @@ class Request: return self.parsed_forwarded @property - def server_port(self): + def remote_addr(self) -> str: """ - Attempt to get the server's external port number in this order: - `config.SERVER_NAME`, proxied or direct Host headers - :func:`Request.host`, - actual port used by the transport layer socket. - :return: server port - :rtype: int - """ - if self.forwarded: - return self.forwarded.get("port") or ( - 80 if self.scheme in ("http", "ws") else 443 - ) - return ( - parse_host(self.host)[1] - or self.transport.get_extra_info("sockname")[1] - ) - - @property - def remote_addr(self): - """Attempt to return the original client ip based on `forwarded`, - `x-forwarded-for` or `x-real-ip`. If HTTP headers are unavailable or - untrusted, returns an empty string. - - :return: original client ip. + Client IP address, if available. + 1. proxied remote address `self.forwarded['for']` + 2. local remote address `self.ip` + :return: IPv4, bracketed IPv6, UNIX socket name or arbitrary string """ if not hasattr(self, "_remote_addr"): - self._remote_addr = self.forwarded.get("for", "") + self._remote_addr = self.forwarded.get("for", "") # or self.ip return self._remote_addr @property - def scheme(self): + def scheme(self) -> str: """ - Attempt to get the request scheme. - Seeking the value in this order: - `forwarded` header, `x-forwarded-proto` header, - `x-scheme` header, the sanic app itself. - + Determine request scheme. + 1. `config.SERVER_NAME` if in full URL format + 2. proxied proto/scheme + 3. local connection protocol :return: http|https|ws|wss or arbitrary value given by the headers. - :rtype: str """ - forwarded_proto = self.forwarded.get("proto") - if forwarded_proto: - return forwarded_proto + if "//" in self.app.config.get("SERVER_NAME", ""): + return self.app.config.SERVER_NAME.split("//")[0] + if "proto" in self.forwarded: + return self.forwarded["proto"] if ( self.app.websocket_enabled @@ -467,25 +448,41 @@ class Request: return scheme @property - def host(self): + def host(self) -> str: """ - :return: proxied or direct Host header. Hostname and port number may be - separated by sanic.headers.parse_host(request.host). + The currently effective server 'host' (hostname or hostname:port). + 1. `config.SERVER_NAME` overrides any client headers + 2. proxied host of original request + 3. request host header + hostname and port may be separated by + `sanic.headers.parse_host(request.host)`. + :return: the first matching host found, or empty string """ - return self.forwarded.get("host", self.headers.get("Host", "")) + server_name = self.app.config.get("SERVER_NAME") + if server_name: + return server_name.split("//", 1)[-1].split("/", 1)[0] + return self.forwarded.get("host") or self.headers.get("host", "") @property - def content_type(self): - return self.headers.get("Content-Type", DEFAULT_HTTP_CONTENT_TYPE) + def server_name(self) -> str: + """The hostname the client connected to, by `request.host`.""" + return parse_host(self.host)[0] or "" @property - def match_info(self): - """return matched info after resolving route""" - return self.app.router.get(self)[2] + def server_port(self) -> int: + """ + The port the client connected to, by forwarded `port` or + `request.host`. + + Default port is returned as 80 and 443 based on `request.scheme`. + """ + port = self.forwarded.get("port") or parse_host(self.host)[1] + return port or (80 if self.scheme in ("http", "ws") else 443) @property - def path(self): - return self._parsed_url.path.decode("utf-8") + def server_path(self) -> str: + """Full path of current URL. Uses proxied or local path.""" + return self.forwarded.get("path") or self.path @property def query_string(self): diff --git a/sanic/server.py b/sanic/server.py index d408eb06..6f64ddb8 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -1,15 +1,18 @@ import asyncio import multiprocessing import os +import secrets +import socket +import stat import sys import traceback from collections import deque from functools import partial from inspect import isawaitable +from ipaddress import ip_address from signal import SIG_IGN, SIGINT, SIGTERM, Signals from signal import signal as signal_func -from socket import SO_REUSEADDR, SOL_SOCKET, socket from time import time from httptools import HttpRequestParser # type: ignore @@ -44,6 +47,41 @@ class Signal: stopped = False +class ConnInfo: + """Local and remote addresses and SSL status info.""" + + __slots__ = ( + "sockname", + "peername", + "server", + "server_port", + "client", + "client_port", + "ssl", + ) + + def __init__(self, transport, unix=None): + self.ssl = bool(transport.get_extra_info("sslcontext")) + self.server = self.client = "" + self.server_port = self.client_port = 0 + self.peername = None + self.sockname = addr = transport.get_extra_info("sockname") + if isinstance(addr, str): # UNIX socket + self.server = unix or addr + return + # IPv4 (ip, port) or IPv6 (ip, port, flowinfo, scopeid) + if isinstance(addr, tuple): + self.server = addr[0] if len(addr) == 2 else f"[{addr[0]}]" + self.server_port = addr[1] + # self.server gets non-standard port appended + if addr[1] != (443 if self.ssl else 80): + self.server = f"{self.server}:{addr[1]}" + self.peername = addr = transport.get_extra_info("peername") + if isinstance(addr, tuple): + self.client = addr[0] if len(addr) == 2 else f"[{addr[0]}]" + self.client_port = addr[1] + + class HttpProtocol(asyncio.Protocol): """ This class provides a basic HTTP implementation of the sanic framework. @@ -57,6 +95,7 @@ class HttpProtocol(asyncio.Protocol): "transport", "connections", "signal", + "conn_info", # request params "parser", "request", @@ -88,6 +127,7 @@ class HttpProtocol(asyncio.Protocol): "_keep_alive", "_header_fragment", "state", + "_unix", "_body_chunks", ) @@ -99,6 +139,7 @@ class HttpProtocol(asyncio.Protocol): signal=Signal(), connections=None, state=None, + unix=None, **kwargs, ): asyncio.set_event_loop(loop) @@ -106,6 +147,7 @@ class HttpProtocol(asyncio.Protocol): deprecated_loop = self.loop if sys.version_info < (3, 7) else None self.app = app self.transport = None + self.conn_info = None self.request = None self.parser = None self.url = None @@ -139,6 +181,7 @@ class HttpProtocol(asyncio.Protocol): self.state = state if state else {} if "requests_count" not in self.state: self.state["requests_count"] = 0 + self._unix = unix self._not_paused.set() self._body_chunks = deque() @@ -167,6 +210,7 @@ class HttpProtocol(asyncio.Protocol): self.request_timeout, self.request_timeout_callback ) self.transport = transport + self.conn_info = ConnInfo(transport, unix=self._unix) self._last_request_time = time() def connection_lost(self, exc): @@ -304,6 +348,7 @@ class HttpProtocol(asyncio.Protocol): transport=self.transport, app=self.app, ) + self.request.conn_info = self.conn_info # Remove any existing KeepAlive handler here, # It will be recreated if required on the new request. if self._keep_alive_timeout_handler: @@ -750,6 +795,7 @@ def serve( after_stop=None, ssl=None, sock=None, + unix=None, reuse_port=False, loop=None, protocol=HttpProtocol, @@ -778,6 +824,7 @@ def serve( `app` instance and `loop` :param ssl: SSLContext :param sock: Socket for the server to accept connections from + :param unix: Unix socket to listen on instead of TCP port :param reuse_port: `True` for multiple workers :param loop: asyncio compatible event loop :param run_async: bool: Do not create a new event loop for the server, @@ -804,14 +851,18 @@ def serve( signal=signal, app=app, state=state, + unix=unix, ) asyncio_server_kwargs = ( asyncio_server_kwargs if asyncio_server_kwargs else {} ) + # UNIX sockets are always bound by us (to preserve semantics between modes) + if unix: + sock = bind_unix_socket(unix, backlog=backlog) server_coroutine = loop.create_server( server, - host, - port, + None if sock else host, + None if sock else port, ssl=ssl, reuse_port=reuse_port, sock=sock, @@ -894,6 +945,85 @@ def serve( trigger_events(after_stop, loop) loop.close() + remove_unix_socket(unix) + + +def bind_socket(host: str, port: int, *, backlog=100) -> socket.socket: + """Create TCP server socket. + :param host: IPv4, IPv6 or hostname may be specified + :param port: TCP port number + :param backlog: Maximum number of connections to queue + :return: socket.socket object + """ + try: # IP address: family must be specified for IPv6 at least + ip = ip_address(host) + host = str(ip) + sock = socket.socket( + socket.AF_INET6 if ip.version == 6 else socket.AF_INET + ) + except ValueError: # Hostname, may become AF_INET or AF_INET6 + sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((host, port)) + sock.listen(backlog) + return sock + + +def bind_unix_socket(path: str, *, mode=0o666, backlog=100) -> socket.socket: + """Create unix socket. + :param path: filesystem path + :param backlog: Maximum number of connections to queue + :return: socket.socket object + """ + """Open or atomically replace existing socket with zero downtime.""" + # Sanitise and pre-verify socket path + path = os.path.abspath(path) + folder = os.path.dirname(path) + if not os.path.isdir(folder): + raise FileNotFoundError(f"Socket folder does not exist: {folder}") + try: + if not stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode): + raise FileExistsError(f"Existing file is not a socket: {path}") + except FileNotFoundError: + pass + # Create new socket with a random temporary name + tmp_path = f"{path}.{secrets.token_urlsafe()}" + sock = socket.socket(socket.AF_UNIX) + try: + # Critical section begins (filename races) + sock.bind(tmp_path) + try: + os.chmod(tmp_path, mode) + # Start listening before rename to avoid connection failures + sock.listen(backlog) + os.rename(tmp_path, path) + except: # noqa: E722 + try: + os.unlink(tmp_path) + finally: + raise + except: # noqa: E722 + try: + sock.close() + finally: + raise + return sock + + +def remove_unix_socket(path: str) -> None: + """Remove dead unix socket during server exit.""" + if not path: + return + try: + if stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode): + # Is it actually dead (doesn't belong to a new server instance)? + with socket.socket(socket.AF_UNIX) as testsock: + try: + testsock.connect(path) + except ConnectionRefusedError: + os.unlink(path) + except FileNotFoundError: + pass def serve_multiple(server_settings, workers): @@ -908,11 +1038,17 @@ def serve_multiple(server_settings, workers): server_settings["reuse_port"] = True server_settings["run_multiple"] = True - # Handling when custom socket is not provided. - if server_settings.get("sock") is None: - sock = socket() - sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) - sock.bind((server_settings["host"], server_settings["port"])) + # Create a listening socket or use the one in settings + sock = server_settings.get("sock") + unix = server_settings["unix"] + backlog = server_settings["backlog"] + if unix: + sock = bind_unix_socket(unix, backlog=backlog) + server_settings["unix"] = unix + if sock is None: + sock = bind_socket( + server_settings["host"], server_settings["port"], backlog=backlog + ) sock.set_inheritable(True) server_settings["sock"] = sock server_settings["host"] = None @@ -941,4 +1077,6 @@ def serve_multiple(server_settings, workers): # the above processes will block this until they're stopped for process in processes: process.terminate() - server_settings.get("sock").close() + + sock.close() + remove_unix_socket(unix) diff --git a/tests/test_requests.py b/tests/test_requests.py index 12c06254..31883e37 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -454,11 +454,13 @@ def test_standard_forwarded(app): "X-Real-IP": "127.0.0.2", "X-Forwarded-For": "127.0.1.1", "X-Scheme": "ws", + "Host": "local.site", } request, response = app.test_client.get("/", headers=headers) assert response.json == {"for": "127.0.0.2", "proto": "ws"} assert request.remote_addr == "127.0.0.2" assert request.scheme == "ws" + assert request.server_name == "local.site" assert request.server_port == 80 app.config.FORWARDED_SECRET = "mySecret" @@ -1807,13 +1809,17 @@ def test_request_port(app): port = request.port assert isinstance(port, int) - delattr(request, "_socket") - delattr(request, "_port") + +@pytest.mark.asyncio +async def test_request_port_asgi(app): + @app.get("/") + def handler(request): + return text("OK") + + request, response = await app.asgi_client.get("/") port = request.port assert isinstance(port, int) - assert hasattr(request, "_socket") - assert hasattr(request, "_port") def test_request_socket(app): @@ -1832,12 +1838,6 @@ def test_request_socket(app): assert ip == request.ip assert port == request.port - delattr(request, "_socket") - - socket = request.socket - assert isinstance(socket, tuple) - assert hasattr(request, "_socket") - def test_request_server_name(app): @app.get("/") @@ -1866,7 +1866,7 @@ def test_request_server_name_in_host_header(app): request, response = app.test_client.get( "/", headers={"Host": "mal_formed"} ) - assert request.server_name == None # For now (later maybe 127.0.0.1) + assert request.server_name == "" def test_request_server_name_forwarded(app): @@ -1893,7 +1893,7 @@ def test_request_server_port(app): test_client = SanicTestClient(app) request, response = test_client.get("/", headers={"Host": "my-server"}) - assert request.server_port == test_client.port + assert request.server_port == 80 def test_request_server_port_in_host_header(app): @@ -1952,12 +1952,12 @@ def test_server_name_and_url_for(app): def handler(request): return text("ok") - app.config.SERVER_NAME = "my-server" + app.config.SERVER_NAME = "my-server" # This means default port assert app.url_for("handler", _external=True) == "http://my-server/foo" request, response = app.test_client.get("/foo") assert ( request.url_for("handler") - == f"http://my-server:{request.server_port}/foo" + == f"http://my-server/foo" ) app.config.SERVER_NAME = "https://my-server/path" diff --git a/tests/test_unix_socket.py b/tests/test_unix_socket.py new file mode 100644 index 00000000..bbffc890 --- /dev/null +++ b/tests/test_unix_socket.py @@ -0,0 +1,235 @@ +import asyncio +import logging +import os +import subprocess +import sys + +import httpx +import pytest + +from sanic import Sanic +from sanic.response import text + + +pytestmark = pytest.mark.skipif(os.name != "posix", reason="UNIX only") +SOCKPATH = "/tmp/sanictest.sock" +SOCKPATH2 = "/tmp/sanictest2.sock" + + +@pytest.fixture(autouse=True) +def socket_cleanup(): + try: + os.unlink(SOCKPATH) + except FileNotFoundError: + pass + try: + os.unlink(SOCKPATH2) + except FileNotFoundError: + pass + # Run test function + yield + try: + os.unlink(SOCKPATH2) + except FileNotFoundError: + pass + try: + os.unlink(SOCKPATH) + except FileNotFoundError: + pass + + +def test_unix_socket_creation(caplog): + from socket import AF_UNIX, socket + + with socket(AF_UNIX) as sock: + sock.bind(SOCKPATH) + assert os.path.exists(SOCKPATH) + ino = os.stat(SOCKPATH).st_ino + + app = Sanic(name=__name__) + + @app.listener("after_server_start") + def running(app, loop): + assert os.path.exists(SOCKPATH) + assert ino != os.stat(SOCKPATH).st_ino + app.stop() + + with caplog.at_level(logging.INFO): + app.run(unix=SOCKPATH) + + assert ( + "sanic.root", + logging.INFO, + f"Goin' Fast @ {SOCKPATH} http://...", + ) in caplog.record_tuples + assert not os.path.exists(SOCKPATH) + + +def test_invalid_paths(): + app = Sanic(name=__name__) + + with pytest.raises(FileExistsError): + app.run(unix=".") + + with pytest.raises(FileNotFoundError): + app.run(unix="no-such-directory/sanictest.sock") + + +def test_dont_replace_file(): + with open(SOCKPATH, "w") as f: + f.write("File, not socket") + + app = Sanic(name=__name__) + + @app.listener("after_server_start") + def stop(app, loop): + app.stop() + + with pytest.raises(FileExistsError): + app.run(unix=SOCKPATH) + + +def test_dont_follow_symlink(): + from socket import AF_UNIX, socket + + with socket(AF_UNIX) as sock: + sock.bind(SOCKPATH2) + os.symlink(SOCKPATH2, SOCKPATH) + + app = Sanic(name=__name__) + + @app.listener("after_server_start") + def stop(app, loop): + app.stop() + + with pytest.raises(FileExistsError): + app.run(unix=SOCKPATH) + + +def test_socket_deleted_while_running(): + app = Sanic(name=__name__) + + @app.listener("after_server_start") + async def hack(app, loop): + os.unlink(SOCKPATH) + app.stop() + + app.run(host="myhost.invalid", unix=SOCKPATH) + + +def test_socket_replaced_with_file(): + app = Sanic(name=__name__) + + @app.listener("after_server_start") + async def hack(app, loop): + os.unlink(SOCKPATH) + with open(SOCKPATH, "w") as f: + f.write("Not a socket") + app.stop() + + app.run(host="myhost.invalid", unix=SOCKPATH) + + +def test_unix_connection(): + app = Sanic(name=__name__) + + @app.get("/") + def handler(request): + return text(f"{request.conn_info.server}") + + @app.listener("after_server_start") + async def client(app, loop): + try: + async with httpx.AsyncClient(uds=SOCKPATH) as client: + r = await client.get("http://myhost.invalid/") + assert r.status_code == 200 + assert r.text == os.path.abspath(SOCKPATH) + finally: + app.stop() + + app.run(host="myhost.invalid", unix=SOCKPATH) + + +app_multi = Sanic(name=__name__) + + +def handler(request): + return text(f"{request.conn_info.server}") + + +async def client(app, loop): + try: + async with httpx.AsyncClient(uds=SOCKPATH) as client: + r = await client.get("http://myhost.invalid/") + assert r.status_code == 200 + assert r.text == os.path.abspath(SOCKPATH) + finally: + app.stop() + + +def test_unix_connection_multiple_workers(): + app_multi.get("/")(handler) + app_multi.listener("after_server_start")(client) + app_multi.run(host="myhost.invalid", unix=SOCKPATH, workers=2) + + +async def test_zero_downtime(): + """Graceful server termination and socket replacement on restarts""" + from signal import SIGINT + from time import monotonic as current_time + + async def client(): + for _ in range(40): + async with httpx.AsyncClient(uds=SOCKPATH) as client: + r = await client.get("http://localhost/sleep/0.1") + assert r.status_code == 200 + assert r.text == f"Slept 0.1 seconds.\n" + + def spawn(): + command = [ + sys.executable, + "-m", + "sanic", + "--unix", + SOCKPATH, + "examples.delayed_response.app", + ] + DN = subprocess.DEVNULL + return subprocess.Popen( + command, stdin=DN, stdout=DN, stderr=subprocess.PIPE + ) + + try: + processes = [spawn()] + while not os.path.exists(SOCKPATH): + if processes[0].poll() is not None: + raise Exception("Worker did not start properly") + await asyncio.sleep(0.0001) + ino = os.stat(SOCKPATH).st_ino + task = asyncio.get_event_loop().create_task(client()) + start_time = current_time() + while current_time() < start_time + 4: + # Start a new one and wait until the socket is replaced + processes.append(spawn()) + while ino == os.stat(SOCKPATH).st_ino: + await asyncio.sleep(0.001) + ino = os.stat(SOCKPATH).st_ino + # Graceful termination of the previous one + processes[-2].send_signal(SIGINT) + # Wait until client has completed all requests + await task + processes[-1].send_signal(SIGINT) + for worker in processes: + try: + worker.wait(1.0) + except subprocess.TimeoutExpired: + raise Exception( + f"Worker would not terminate:\n{worker.stderr}" + ) + finally: + for worker in processes: + worker.kill() + # Test for clean run and termination + assert len(processes) > 5 + assert [worker.poll() for worker in processes] == len(processes) * [0] + assert not os.path.exists(SOCKPATH)