Socket binding implemented properly for IPv6 and UNIX sockets. (#1641)

* Socket binding implemented properly for IPv6 and UNIX sockets.

- app.run("::1") for IPv6
- app.run("unix:/tmp/server.sock") for UNIX sockets
- app.run("localhost") retains old functionality (randomly either IPv4 or IPv6)

Do note that IPv6 and UNIX sockets are not fully supported by other Sanic facilities.
In particular, request.server_name and request.server_port are currently unreliable.

* Fix Windows compatibility by not referring to socket.AF_UNIX unless needed.

* Compatibility fix.

* Fix test of existing unix socket.

* Cleaner unix socket removal.

* Remove unix socket on exit also with workers=1.

* More pedantic UNIX socket implementation.

* Refactor app to take unix= argument instead of unix:-prefixed host. Goin' fast @ unix-socket fixed.

* Linter

* Proxy properties cleanup. Slight changes of semantics. SERVER_NAME now overrides everything.

* Have server fill in connection info instead of request asking the socket.

- Would be a good idea to remove request.transport entirely but I didn't dare to touch it yet.

* Linter 💣🌟💀

* Fix typing issues. request.server_name returns empty string if host header is missing.

* Fix tests

* Tests were failing, fix connection info.

* Linter nazi says you need that empty line.

* Rename a to addr, leave client empty for unix sockets.

* Add --unix support when sanic is run as module.

* Remove remove_route, deprecated in 19.6.

* Improved unix socket binding.

* More robust creating and unlinking of sockets. Show proper and not temporary name in conn_info.

* Add comprehensive tests for unix socket mode.

* Hide some imports inside functions to avoid Windows failure.

* Mention unix socket mode in deployment docs.

* Fix merge commit.

* Make test_unix_connection_multiple_workers pickleable for spawn mode multiprocessing.

Co-authored-by: L. Kärkkäinen <tronic@users.noreply.github.com>
Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
This commit is contained in:
L. Kärkkäinen 2020-06-29 08:55:32 +03:00 committed by GitHub
parent 4aba74d050
commit a62c84a954
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 504 additions and 101 deletions

View File

@ -16,6 +16,7 @@ keyword arguments:
- `host` *(default `"127.0.0.1"`)*: Address to host the server on. - `host` *(default `"127.0.0.1"`)*: Address to host the server on.
- `port` *(default `8000`)*: Port to host the server on. - `port` *(default `8000`)*: Port to host the server on.
- `unix` *(default `None`)*: Unix socket name to host the server on (instead of TCP).
- `debug` *(default `False`)*: Enables debug output (slows server). - `debug` *(default `False`)*: Enables debug output (slows server).
- `ssl` *(default `None`)*: `SSLContext` for SSL encryption of worker(s). - `ssl` *(default `None`)*: `SSLContext` for SSL encryption of worker(s).
- `sock` *(default `None`)*: Socket for the server to accept connections from. - `sock` *(default `None`)*: Socket for the server to accept connections from.

View File

@ -0,0 +1,18 @@
from asyncio import sleep
from sanic import Sanic, response
app = Sanic(__name__, strict_slashes=True)
@app.get("/")
async def handler(request):
return response.redirect("/sleep/3")
@app.get("/sleep/<t:number>")
async def handler2(request, t=0.3):
await sleep(t)
return response.text(f"Slept {t:.1f} seconds.\n")
if __name__ == '__main__':
app.run(host="0.0.0.0", port=8000)

View File

@ -13,6 +13,7 @@ def main():
parser = ArgumentParser(prog="sanic") parser = ArgumentParser(prog="sanic")
parser.add_argument("--host", dest="host", type=str, default="127.0.0.1") parser.add_argument("--host", dest="host", type=str, default="127.0.0.1")
parser.add_argument("--port", dest="port", type=int, default=8000) parser.add_argument("--port", dest="port", type=int, default=8000)
parser.add_argument("--unix", dest="unix", type=str, default="")
parser.add_argument( parser.add_argument(
"--cert", dest="cert", type=str, help="location of certificate for SSL" "--cert", dest="cert", type=str, help="location of certificate for SSL"
) )
@ -53,6 +54,7 @@ def main():
app.run( app.run(
host=args.host, host=args.host,
port=args.port, port=args.port,
unix=args.unix,
workers=args.workers, workers=args.workers,
debug=args.debug, debug=args.debug,
ssl=ssl, ssl=ssl,

View File

@ -1033,6 +1033,7 @@ class Sanic:
stop_event: Any = None, stop_event: Any = None,
register_sys_signals: bool = True, register_sys_signals: bool = True,
access_log: Optional[bool] = None, access_log: Optional[bool] = None,
unix: Optional[str] = None,
loop: None = None, loop: None = None,
) -> None: ) -> None:
"""Run the HTTP Server and listen until keyboard interrupt or term """Run the HTTP Server and listen until keyboard interrupt or term
@ -1066,6 +1067,8 @@ class Sanic:
:type register_sys_signals: bool :type register_sys_signals: bool
:param access_log: Enables writing access logs (slows server) :param access_log: Enables writing access logs (slows server)
:type access_log: bool :type access_log: bool
:param unix: Unix socket to listen on instead of TCP port
:type unix: str
:return: Nothing :return: Nothing
""" """
if loop is not None: if loop is not None:
@ -1104,6 +1107,7 @@ class Sanic:
debug=debug, debug=debug,
ssl=ssl, ssl=ssl,
sock=sock, sock=sock,
unix=unix,
workers=workers, workers=workers,
protocol=protocol, protocol=protocol,
backlog=backlog, backlog=backlog,
@ -1151,6 +1155,7 @@ class Sanic:
backlog: int = 100, backlog: int = 100,
stop_event: Any = None, stop_event: Any = None,
access_log: Optional[bool] = None, access_log: Optional[bool] = None,
unix: Optional[str] = None,
return_asyncio_server=False, return_asyncio_server=False,
asyncio_server_kwargs=None, asyncio_server_kwargs=None,
) -> Optional[AsyncioServer]: ) -> Optional[AsyncioServer]:
@ -1220,6 +1225,7 @@ class Sanic:
debug=debug, debug=debug,
ssl=ssl, ssl=ssl,
sock=sock, sock=sock,
unix=unix,
loop=get_event_loop(), loop=get_event_loop(),
protocol=protocol, protocol=protocol,
backlog=backlog, backlog=backlog,
@ -1285,6 +1291,7 @@ class Sanic:
debug=False, debug=False,
ssl=None, ssl=None,
sock=None, sock=None,
unix=None,
workers=1, workers=1,
loop=None, loop=None,
protocol=HttpProtocol, protocol=HttpProtocol,
@ -1326,6 +1333,7 @@ class Sanic:
"host": host, "host": host,
"port": port, "port": port,
"sock": sock, "sock": sock,
"unix": unix,
"ssl": ssl, "ssl": ssl,
"app": self, "app": self,
"signal": Signal(), "signal": Signal(),
@ -1372,7 +1380,10 @@ class Sanic:
proto = "http" proto = "http"
if ssl is not None: if ssl is not None:
proto = "https" proto = "https"
logger.info(f"Goin' Fast @ {proto}://{host}:{port}") if unix:
logger.info(f"Goin' Fast @ {unix} {proto}://...")
else:
logger.info(f"Goin' Fast @ {proto}://{host}:{port}")
return server_settings return server_settings

View File

@ -22,7 +22,7 @@ from sanic.exceptions import InvalidUsage, ServerError
from sanic.log import logger from sanic.log import logger
from sanic.request import Request from sanic.request import Request
from sanic.response import HTTPResponse, StreamingHTTPResponse from sanic.response import HTTPResponse, StreamingHTTPResponse
from sanic.server import StreamBuffer from sanic.server import ConnInfo, StreamBuffer
from sanic.websocket import WebSocketConnection from sanic.websocket import WebSocketConnection
@ -255,6 +255,7 @@ class ASGIApp:
instance.transport, instance.transport,
sanic_app, sanic_app,
) )
instance.request.conn_info = ConnInfo(instance.transport)
if sanic_app.is_request_stream: if sanic_app.is_request_stream:
is_stream_handler = sanic_app.router.is_stream_handler( is_stream_handler = sanic_app.router.is_stream_handler(

View File

@ -87,6 +87,7 @@ class Request:
"_socket", "_socket",
"app", "app",
"body", "body",
"conn_info",
"ctx", "ctx",
"endpoint", "endpoint",
"headers", "headers",
@ -117,6 +118,7 @@ class Request:
# Init but do not inhale # Init but do not inhale
self.body_init() self.body_init()
self.conn_info = None
self.ctx = SimpleNamespace() self.ctx = SimpleNamespace()
self.parsed_forwarded = None self.parsed_forwarded = None
self.parsed_json = None self.parsed_json = None
@ -349,56 +351,55 @@ class Request:
self._cookies = {} self._cookies = {}
return self._cookies return self._cookies
@property
def content_type(self):
return self.headers.get("Content-Type", DEFAULT_HTTP_CONTENT_TYPE)
@property
def match_info(self):
"""return matched info after resolving route"""
return self.app.router.get(self)[2]
# Transport properties (obtained from local interface only)
@property @property
def ip(self): def ip(self):
""" """
:return: peer ip of the socket :return: peer ip of the socket
""" """
if not hasattr(self, "_socket"): return self.conn_info.client if self.conn_info else ""
self._get_address()
return self._ip
@property @property
def port(self): def port(self):
""" """
:return: peer port of the socket :return: peer port of the socket
""" """
if not hasattr(self, "_socket"): return self.conn_info.client_port if self.conn_info else 0
self._get_address()
return self._port
@property @property
def socket(self): def socket(self):
if not hasattr(self, "_socket"): return self.conn_info.peername if self.conn_info else (None, None)
self._get_address()
return self._socket
def _get_address(self):
self._socket = self.transport.get_extra_info("peername") or (
None,
None,
)
self._ip = self._socket[0]
self._port = self._socket[1]
@property @property
def server_name(self): def path(self) -> str:
""" """Path of the local HTTP request."""
Attempt to get the server's external hostname in this order: return self._parsed_url.path.decode("utf-8")
`config.SERVER_NAME`, proxied or direct Host headers
:func:`Request.host`
:return: the server name without port number # Proxy properties (using SERVER_NAME/forwarded/request/transport info)
:rtype: str
"""
server_name = self.app.config.get("SERVER_NAME")
if server_name:
host = server_name.split("//", 1)[-1].split("/", 1)[0]
return parse_host(host)[0]
return parse_host(self.host)[0]
@property @property
def forwarded(self): def forwarded(self):
"""
Active proxy information obtained from request headers, as specified in
Sanic configuration.
Field names by, for, proto, host, port and path are normalized.
- for and by IPv6 addresses are bracketed
- port (int) is only set by port headers, not from host.
- path is url-unencoded
Additional values may be available from new style Forwarded headers.
"""
if self.parsed_forwarded is None: if self.parsed_forwarded is None:
self.parsed_forwarded = ( self.parsed_forwarded = (
parse_forwarded(self.headers, self.app.config) parse_forwarded(self.headers, self.app.config)
@ -408,50 +409,30 @@ class Request:
return self.parsed_forwarded return self.parsed_forwarded
@property @property
def server_port(self): def remote_addr(self) -> str:
""" """
Attempt to get the server's external port number in this order: Client IP address, if available.
`config.SERVER_NAME`, proxied or direct Host headers 1. proxied remote address `self.forwarded['for']`
:func:`Request.host`, 2. local remote address `self.ip`
actual port used by the transport layer socket. :return: IPv4, bracketed IPv6, UNIX socket name or arbitrary string
:return: server port
:rtype: int
"""
if self.forwarded:
return self.forwarded.get("port") or (
80 if self.scheme in ("http", "ws") else 443
)
return (
parse_host(self.host)[1]
or self.transport.get_extra_info("sockname")[1]
)
@property
def remote_addr(self):
"""Attempt to return the original client ip based on `forwarded`,
`x-forwarded-for` or `x-real-ip`. If HTTP headers are unavailable or
untrusted, returns an empty string.
:return: original client ip.
""" """
if not hasattr(self, "_remote_addr"): if not hasattr(self, "_remote_addr"):
self._remote_addr = self.forwarded.get("for", "") self._remote_addr = self.forwarded.get("for", "") # or self.ip
return self._remote_addr return self._remote_addr
@property @property
def scheme(self): def scheme(self) -> str:
""" """
Attempt to get the request scheme. Determine request scheme.
Seeking the value in this order: 1. `config.SERVER_NAME` if in full URL format
`forwarded` header, `x-forwarded-proto` header, 2. proxied proto/scheme
`x-scheme` header, the sanic app itself. 3. local connection protocol
:return: http|https|ws|wss or arbitrary value given by the headers. :return: http|https|ws|wss or arbitrary value given by the headers.
:rtype: str
""" """
forwarded_proto = self.forwarded.get("proto") if "//" in self.app.config.get("SERVER_NAME", ""):
if forwarded_proto: return self.app.config.SERVER_NAME.split("//")[0]
return forwarded_proto if "proto" in self.forwarded:
return self.forwarded["proto"]
if ( if (
self.app.websocket_enabled self.app.websocket_enabled
@ -467,25 +448,41 @@ class Request:
return scheme return scheme
@property @property
def host(self): def host(self) -> str:
""" """
:return: proxied or direct Host header. Hostname and port number may be The currently effective server 'host' (hostname or hostname:port).
separated by sanic.headers.parse_host(request.host). 1. `config.SERVER_NAME` overrides any client headers
2. proxied host of original request
3. request host header
hostname and port may be separated by
`sanic.headers.parse_host(request.host)`.
:return: the first matching host found, or empty string
""" """
return self.forwarded.get("host", self.headers.get("Host", "")) server_name = self.app.config.get("SERVER_NAME")
if server_name:
return server_name.split("//", 1)[-1].split("/", 1)[0]
return self.forwarded.get("host") or self.headers.get("host", "")
@property @property
def content_type(self): def server_name(self) -> str:
return self.headers.get("Content-Type", DEFAULT_HTTP_CONTENT_TYPE) """The hostname the client connected to, by `request.host`."""
return parse_host(self.host)[0] or ""
@property @property
def match_info(self): def server_port(self) -> int:
"""return matched info after resolving route""" """
return self.app.router.get(self)[2] The port the client connected to, by forwarded `port` or
`request.host`.
Default port is returned as 80 and 443 based on `request.scheme`.
"""
port = self.forwarded.get("port") or parse_host(self.host)[1]
return port or (80 if self.scheme in ("http", "ws") else 443)
@property @property
def path(self): def server_path(self) -> str:
return self._parsed_url.path.decode("utf-8") """Full path of current URL. Uses proxied or local path."""
return self.forwarded.get("path") or self.path
@property @property
def query_string(self): def query_string(self):

View File

@ -1,15 +1,18 @@
import asyncio import asyncio
import multiprocessing import multiprocessing
import os import os
import secrets
import socket
import stat
import sys import sys
import traceback import traceback
from collections import deque from collections import deque
from functools import partial from functools import partial
from inspect import isawaitable from inspect import isawaitable
from ipaddress import ip_address
from signal import SIG_IGN, SIGINT, SIGTERM, Signals from signal import SIG_IGN, SIGINT, SIGTERM, Signals
from signal import signal as signal_func from signal import signal as signal_func
from socket import SO_REUSEADDR, SOL_SOCKET, socket
from time import time from time import time
from httptools import HttpRequestParser # type: ignore from httptools import HttpRequestParser # type: ignore
@ -44,6 +47,41 @@ class Signal:
stopped = False stopped = False
class ConnInfo:
"""Local and remote addresses and SSL status info."""
__slots__ = (
"sockname",
"peername",
"server",
"server_port",
"client",
"client_port",
"ssl",
)
def __init__(self, transport, unix=None):
self.ssl = bool(transport.get_extra_info("sslcontext"))
self.server = self.client = ""
self.server_port = self.client_port = 0
self.peername = None
self.sockname = addr = transport.get_extra_info("sockname")
if isinstance(addr, str): # UNIX socket
self.server = unix or addr
return
# IPv4 (ip, port) or IPv6 (ip, port, flowinfo, scopeid)
if isinstance(addr, tuple):
self.server = addr[0] if len(addr) == 2 else f"[{addr[0]}]"
self.server_port = addr[1]
# self.server gets non-standard port appended
if addr[1] != (443 if self.ssl else 80):
self.server = f"{self.server}:{addr[1]}"
self.peername = addr = transport.get_extra_info("peername")
if isinstance(addr, tuple):
self.client = addr[0] if len(addr) == 2 else f"[{addr[0]}]"
self.client_port = addr[1]
class HttpProtocol(asyncio.Protocol): class HttpProtocol(asyncio.Protocol):
""" """
This class provides a basic HTTP implementation of the sanic framework. This class provides a basic HTTP implementation of the sanic framework.
@ -57,6 +95,7 @@ class HttpProtocol(asyncio.Protocol):
"transport", "transport",
"connections", "connections",
"signal", "signal",
"conn_info",
# request params # request params
"parser", "parser",
"request", "request",
@ -88,6 +127,7 @@ class HttpProtocol(asyncio.Protocol):
"_keep_alive", "_keep_alive",
"_header_fragment", "_header_fragment",
"state", "state",
"_unix",
"_body_chunks", "_body_chunks",
) )
@ -99,6 +139,7 @@ class HttpProtocol(asyncio.Protocol):
signal=Signal(), signal=Signal(),
connections=None, connections=None,
state=None, state=None,
unix=None,
**kwargs, **kwargs,
): ):
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
@ -106,6 +147,7 @@ class HttpProtocol(asyncio.Protocol):
deprecated_loop = self.loop if sys.version_info < (3, 7) else None deprecated_loop = self.loop if sys.version_info < (3, 7) else None
self.app = app self.app = app
self.transport = None self.transport = None
self.conn_info = None
self.request = None self.request = None
self.parser = None self.parser = None
self.url = None self.url = None
@ -139,6 +181,7 @@ class HttpProtocol(asyncio.Protocol):
self.state = state if state else {} self.state = state if state else {}
if "requests_count" not in self.state: if "requests_count" not in self.state:
self.state["requests_count"] = 0 self.state["requests_count"] = 0
self._unix = unix
self._not_paused.set() self._not_paused.set()
self._body_chunks = deque() self._body_chunks = deque()
@ -167,6 +210,7 @@ class HttpProtocol(asyncio.Protocol):
self.request_timeout, self.request_timeout_callback self.request_timeout, self.request_timeout_callback
) )
self.transport = transport self.transport = transport
self.conn_info = ConnInfo(transport, unix=self._unix)
self._last_request_time = time() self._last_request_time = time()
def connection_lost(self, exc): def connection_lost(self, exc):
@ -304,6 +348,7 @@ class HttpProtocol(asyncio.Protocol):
transport=self.transport, transport=self.transport,
app=self.app, app=self.app,
) )
self.request.conn_info = self.conn_info
# Remove any existing KeepAlive handler here, # Remove any existing KeepAlive handler here,
# It will be recreated if required on the new request. # It will be recreated if required on the new request.
if self._keep_alive_timeout_handler: if self._keep_alive_timeout_handler:
@ -750,6 +795,7 @@ def serve(
after_stop=None, after_stop=None,
ssl=None, ssl=None,
sock=None, sock=None,
unix=None,
reuse_port=False, reuse_port=False,
loop=None, loop=None,
protocol=HttpProtocol, protocol=HttpProtocol,
@ -778,6 +824,7 @@ def serve(
`app` instance and `loop` `app` instance and `loop`
:param ssl: SSLContext :param ssl: SSLContext
:param sock: Socket for the server to accept connections from :param sock: Socket for the server to accept connections from
:param unix: Unix socket to listen on instead of TCP port
:param reuse_port: `True` for multiple workers :param reuse_port: `True` for multiple workers
:param loop: asyncio compatible event loop :param loop: asyncio compatible event loop
:param run_async: bool: Do not create a new event loop for the server, :param run_async: bool: Do not create a new event loop for the server,
@ -804,14 +851,18 @@ def serve(
signal=signal, signal=signal,
app=app, app=app,
state=state, state=state,
unix=unix,
) )
asyncio_server_kwargs = ( asyncio_server_kwargs = (
asyncio_server_kwargs if asyncio_server_kwargs else {} asyncio_server_kwargs if asyncio_server_kwargs else {}
) )
# UNIX sockets are always bound by us (to preserve semantics between modes)
if unix:
sock = bind_unix_socket(unix, backlog=backlog)
server_coroutine = loop.create_server( server_coroutine = loop.create_server(
server, server,
host, None if sock else host,
port, None if sock else port,
ssl=ssl, ssl=ssl,
reuse_port=reuse_port, reuse_port=reuse_port,
sock=sock, sock=sock,
@ -894,6 +945,85 @@ def serve(
trigger_events(after_stop, loop) trigger_events(after_stop, loop)
loop.close() loop.close()
remove_unix_socket(unix)
def bind_socket(host: str, port: int, *, backlog=100) -> socket.socket:
"""Create TCP server socket.
:param host: IPv4, IPv6 or hostname may be specified
:param port: TCP port number
:param backlog: Maximum number of connections to queue
:return: socket.socket object
"""
try: # IP address: family must be specified for IPv6 at least
ip = ip_address(host)
host = str(ip)
sock = socket.socket(
socket.AF_INET6 if ip.version == 6 else socket.AF_INET
)
except ValueError: # Hostname, may become AF_INET or AF_INET6
sock = socket.socket()
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((host, port))
sock.listen(backlog)
return sock
def bind_unix_socket(path: str, *, mode=0o666, backlog=100) -> socket.socket:
"""Create unix socket.
:param path: filesystem path
:param backlog: Maximum number of connections to queue
:return: socket.socket object
"""
"""Open or atomically replace existing socket with zero downtime."""
# Sanitise and pre-verify socket path
path = os.path.abspath(path)
folder = os.path.dirname(path)
if not os.path.isdir(folder):
raise FileNotFoundError(f"Socket folder does not exist: {folder}")
try:
if not stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode):
raise FileExistsError(f"Existing file is not a socket: {path}")
except FileNotFoundError:
pass
# Create new socket with a random temporary name
tmp_path = f"{path}.{secrets.token_urlsafe()}"
sock = socket.socket(socket.AF_UNIX)
try:
# Critical section begins (filename races)
sock.bind(tmp_path)
try:
os.chmod(tmp_path, mode)
# Start listening before rename to avoid connection failures
sock.listen(backlog)
os.rename(tmp_path, path)
except: # noqa: E722
try:
os.unlink(tmp_path)
finally:
raise
except: # noqa: E722
try:
sock.close()
finally:
raise
return sock
def remove_unix_socket(path: str) -> None:
"""Remove dead unix socket during server exit."""
if not path:
return
try:
if stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode):
# Is it actually dead (doesn't belong to a new server instance)?
with socket.socket(socket.AF_UNIX) as testsock:
try:
testsock.connect(path)
except ConnectionRefusedError:
os.unlink(path)
except FileNotFoundError:
pass
def serve_multiple(server_settings, workers): def serve_multiple(server_settings, workers):
@ -908,11 +1038,17 @@ def serve_multiple(server_settings, workers):
server_settings["reuse_port"] = True server_settings["reuse_port"] = True
server_settings["run_multiple"] = True server_settings["run_multiple"] = True
# Handling when custom socket is not provided. # Create a listening socket or use the one in settings
if server_settings.get("sock") is None: sock = server_settings.get("sock")
sock = socket() unix = server_settings["unix"]
sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) backlog = server_settings["backlog"]
sock.bind((server_settings["host"], server_settings["port"])) if unix:
sock = bind_unix_socket(unix, backlog=backlog)
server_settings["unix"] = unix
if sock is None:
sock = bind_socket(
server_settings["host"], server_settings["port"], backlog=backlog
)
sock.set_inheritable(True) sock.set_inheritable(True)
server_settings["sock"] = sock server_settings["sock"] = sock
server_settings["host"] = None server_settings["host"] = None
@ -941,4 +1077,6 @@ def serve_multiple(server_settings, workers):
# the above processes will block this until they're stopped # the above processes will block this until they're stopped
for process in processes: for process in processes:
process.terminate() process.terminate()
server_settings.get("sock").close()
sock.close()
remove_unix_socket(unix)

View File

@ -454,11 +454,13 @@ def test_standard_forwarded(app):
"X-Real-IP": "127.0.0.2", "X-Real-IP": "127.0.0.2",
"X-Forwarded-For": "127.0.1.1", "X-Forwarded-For": "127.0.1.1",
"X-Scheme": "ws", "X-Scheme": "ws",
"Host": "local.site",
} }
request, response = app.test_client.get("/", headers=headers) request, response = app.test_client.get("/", headers=headers)
assert response.json == {"for": "127.0.0.2", "proto": "ws"} assert response.json == {"for": "127.0.0.2", "proto": "ws"}
assert request.remote_addr == "127.0.0.2" assert request.remote_addr == "127.0.0.2"
assert request.scheme == "ws" assert request.scheme == "ws"
assert request.server_name == "local.site"
assert request.server_port == 80 assert request.server_port == 80
app.config.FORWARDED_SECRET = "mySecret" app.config.FORWARDED_SECRET = "mySecret"
@ -1807,13 +1809,17 @@ def test_request_port(app):
port = request.port port = request.port
assert isinstance(port, int) assert isinstance(port, int)
delattr(request, "_socket")
delattr(request, "_port") @pytest.mark.asyncio
async def test_request_port_asgi(app):
@app.get("/")
def handler(request):
return text("OK")
request, response = await app.asgi_client.get("/")
port = request.port port = request.port
assert isinstance(port, int) assert isinstance(port, int)
assert hasattr(request, "_socket")
assert hasattr(request, "_port")
def test_request_socket(app): def test_request_socket(app):
@ -1832,12 +1838,6 @@ def test_request_socket(app):
assert ip == request.ip assert ip == request.ip
assert port == request.port assert port == request.port
delattr(request, "_socket")
socket = request.socket
assert isinstance(socket, tuple)
assert hasattr(request, "_socket")
def test_request_server_name(app): def test_request_server_name(app):
@app.get("/") @app.get("/")
@ -1866,7 +1866,7 @@ def test_request_server_name_in_host_header(app):
request, response = app.test_client.get( request, response = app.test_client.get(
"/", headers={"Host": "mal_formed"} "/", headers={"Host": "mal_formed"}
) )
assert request.server_name == None # For now (later maybe 127.0.0.1) assert request.server_name == ""
def test_request_server_name_forwarded(app): def test_request_server_name_forwarded(app):
@ -1893,7 +1893,7 @@ def test_request_server_port(app):
test_client = SanicTestClient(app) test_client = SanicTestClient(app)
request, response = test_client.get("/", headers={"Host": "my-server"}) request, response = test_client.get("/", headers={"Host": "my-server"})
assert request.server_port == test_client.port assert request.server_port == 80
def test_request_server_port_in_host_header(app): def test_request_server_port_in_host_header(app):
@ -1952,12 +1952,12 @@ def test_server_name_and_url_for(app):
def handler(request): def handler(request):
return text("ok") return text("ok")
app.config.SERVER_NAME = "my-server" app.config.SERVER_NAME = "my-server" # This means default port
assert app.url_for("handler", _external=True) == "http://my-server/foo" assert app.url_for("handler", _external=True) == "http://my-server/foo"
request, response = app.test_client.get("/foo") request, response = app.test_client.get("/foo")
assert ( assert (
request.url_for("handler") request.url_for("handler")
== f"http://my-server:{request.server_port}/foo" == f"http://my-server/foo"
) )
app.config.SERVER_NAME = "https://my-server/path" app.config.SERVER_NAME = "https://my-server/path"

235
tests/test_unix_socket.py Normal file
View File

@ -0,0 +1,235 @@
import asyncio
import logging
import os
import subprocess
import sys
import httpx
import pytest
from sanic import Sanic
from sanic.response import text
pytestmark = pytest.mark.skipif(os.name != "posix", reason="UNIX only")
SOCKPATH = "/tmp/sanictest.sock"
SOCKPATH2 = "/tmp/sanictest2.sock"
@pytest.fixture(autouse=True)
def socket_cleanup():
try:
os.unlink(SOCKPATH)
except FileNotFoundError:
pass
try:
os.unlink(SOCKPATH2)
except FileNotFoundError:
pass
# Run test function
yield
try:
os.unlink(SOCKPATH2)
except FileNotFoundError:
pass
try:
os.unlink(SOCKPATH)
except FileNotFoundError:
pass
def test_unix_socket_creation(caplog):
from socket import AF_UNIX, socket
with socket(AF_UNIX) as sock:
sock.bind(SOCKPATH)
assert os.path.exists(SOCKPATH)
ino = os.stat(SOCKPATH).st_ino
app = Sanic(name=__name__)
@app.listener("after_server_start")
def running(app, loop):
assert os.path.exists(SOCKPATH)
assert ino != os.stat(SOCKPATH).st_ino
app.stop()
with caplog.at_level(logging.INFO):
app.run(unix=SOCKPATH)
assert (
"sanic.root",
logging.INFO,
f"Goin' Fast @ {SOCKPATH} http://...",
) in caplog.record_tuples
assert not os.path.exists(SOCKPATH)
def test_invalid_paths():
app = Sanic(name=__name__)
with pytest.raises(FileExistsError):
app.run(unix=".")
with pytest.raises(FileNotFoundError):
app.run(unix="no-such-directory/sanictest.sock")
def test_dont_replace_file():
with open(SOCKPATH, "w") as f:
f.write("File, not socket")
app = Sanic(name=__name__)
@app.listener("after_server_start")
def stop(app, loop):
app.stop()
with pytest.raises(FileExistsError):
app.run(unix=SOCKPATH)
def test_dont_follow_symlink():
from socket import AF_UNIX, socket
with socket(AF_UNIX) as sock:
sock.bind(SOCKPATH2)
os.symlink(SOCKPATH2, SOCKPATH)
app = Sanic(name=__name__)
@app.listener("after_server_start")
def stop(app, loop):
app.stop()
with pytest.raises(FileExistsError):
app.run(unix=SOCKPATH)
def test_socket_deleted_while_running():
app = Sanic(name=__name__)
@app.listener("after_server_start")
async def hack(app, loop):
os.unlink(SOCKPATH)
app.stop()
app.run(host="myhost.invalid", unix=SOCKPATH)
def test_socket_replaced_with_file():
app = Sanic(name=__name__)
@app.listener("after_server_start")
async def hack(app, loop):
os.unlink(SOCKPATH)
with open(SOCKPATH, "w") as f:
f.write("Not a socket")
app.stop()
app.run(host="myhost.invalid", unix=SOCKPATH)
def test_unix_connection():
app = Sanic(name=__name__)
@app.get("/")
def handler(request):
return text(f"{request.conn_info.server}")
@app.listener("after_server_start")
async def client(app, loop):
try:
async with httpx.AsyncClient(uds=SOCKPATH) as client:
r = await client.get("http://myhost.invalid/")
assert r.status_code == 200
assert r.text == os.path.abspath(SOCKPATH)
finally:
app.stop()
app.run(host="myhost.invalid", unix=SOCKPATH)
app_multi = Sanic(name=__name__)
def handler(request):
return text(f"{request.conn_info.server}")
async def client(app, loop):
try:
async with httpx.AsyncClient(uds=SOCKPATH) as client:
r = await client.get("http://myhost.invalid/")
assert r.status_code == 200
assert r.text == os.path.abspath(SOCKPATH)
finally:
app.stop()
def test_unix_connection_multiple_workers():
app_multi.get("/")(handler)
app_multi.listener("after_server_start")(client)
app_multi.run(host="myhost.invalid", unix=SOCKPATH, workers=2)
async def test_zero_downtime():
"""Graceful server termination and socket replacement on restarts"""
from signal import SIGINT
from time import monotonic as current_time
async def client():
for _ in range(40):
async with httpx.AsyncClient(uds=SOCKPATH) as client:
r = await client.get("http://localhost/sleep/0.1")
assert r.status_code == 200
assert r.text == f"Slept 0.1 seconds.\n"
def spawn():
command = [
sys.executable,
"-m",
"sanic",
"--unix",
SOCKPATH,
"examples.delayed_response.app",
]
DN = subprocess.DEVNULL
return subprocess.Popen(
command, stdin=DN, stdout=DN, stderr=subprocess.PIPE
)
try:
processes = [spawn()]
while not os.path.exists(SOCKPATH):
if processes[0].poll() is not None:
raise Exception("Worker did not start properly")
await asyncio.sleep(0.0001)
ino = os.stat(SOCKPATH).st_ino
task = asyncio.get_event_loop().create_task(client())
start_time = current_time()
while current_time() < start_time + 4:
# Start a new one and wait until the socket is replaced
processes.append(spawn())
while ino == os.stat(SOCKPATH).st_ino:
await asyncio.sleep(0.001)
ino = os.stat(SOCKPATH).st_ino
# Graceful termination of the previous one
processes[-2].send_signal(SIGINT)
# Wait until client has completed all requests
await task
processes[-1].send_signal(SIGINT)
for worker in processes:
try:
worker.wait(1.0)
except subprocess.TimeoutExpired:
raise Exception(
f"Worker would not terminate:\n{worker.stderr}"
)
finally:
for worker in processes:
worker.kill()
# Test for clean run and termination
assert len(processes) > 5
assert [worker.poll() for worker in processes] == len(processes) * [0]
assert not os.path.exists(SOCKPATH)