Socket binding implemented properly for IPv6 and UNIX sockets. (#1641)
* Socket binding implemented properly for IPv6 and UNIX sockets. - app.run("::1") for IPv6 - app.run("unix:/tmp/server.sock") for UNIX sockets - app.run("localhost") retains old functionality (randomly either IPv4 or IPv6) Do note that IPv6 and UNIX sockets are not fully supported by other Sanic facilities. In particular, request.server_name and request.server_port are currently unreliable. * Fix Windows compatibility by not referring to socket.AF_UNIX unless needed. * Compatibility fix. * Fix test of existing unix socket. * Cleaner unix socket removal. * Remove unix socket on exit also with workers=1. * More pedantic UNIX socket implementation. * Refactor app to take unix= argument instead of unix:-prefixed host. Goin' fast @ unix-socket fixed. * Linter * Proxy properties cleanup. Slight changes of semantics. SERVER_NAME now overrides everything. * Have server fill in connection info instead of request asking the socket. - Would be a good idea to remove request.transport entirely but I didn't dare to touch it yet. * Linter 💣🌟✊💀 * Fix typing issues. request.server_name returns empty string if host header is missing. * Fix tests * Tests were failing, fix connection info. * Linter nazi says you need that empty line. * Rename a to addr, leave client empty for unix sockets. * Add --unix support when sanic is run as module. * Remove remove_route, deprecated in 19.6. * Improved unix socket binding. * More robust creating and unlinking of sockets. Show proper and not temporary name in conn_info. * Add comprehensive tests for unix socket mode. * Hide some imports inside functions to avoid Windows failure. * Mention unix socket mode in deployment docs. * Fix merge commit. * Make test_unix_connection_multiple_workers pickleable for spawn mode multiprocessing. Co-authored-by: L. Kärkkäinen <tronic@users.noreply.github.com> Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
This commit is contained in:
parent
4aba74d050
commit
a62c84a954
|
@ -16,6 +16,7 @@ keyword arguments:
|
||||||
|
|
||||||
- `host` *(default `"127.0.0.1"`)*: Address to host the server on.
|
- `host` *(default `"127.0.0.1"`)*: Address to host the server on.
|
||||||
- `port` *(default `8000`)*: Port to host the server on.
|
- `port` *(default `8000`)*: Port to host the server on.
|
||||||
|
- `unix` *(default `None`)*: Unix socket name to host the server on (instead of TCP).
|
||||||
- `debug` *(default `False`)*: Enables debug output (slows server).
|
- `debug` *(default `False`)*: Enables debug output (slows server).
|
||||||
- `ssl` *(default `None`)*: `SSLContext` for SSL encryption of worker(s).
|
- `ssl` *(default `None`)*: `SSLContext` for SSL encryption of worker(s).
|
||||||
- `sock` *(default `None`)*: Socket for the server to accept connections from.
|
- `sock` *(default `None`)*: Socket for the server to accept connections from.
|
||||||
|
|
18
examples/delayed_response.py
Normal file
18
examples/delayed_response.py
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
from asyncio import sleep
|
||||||
|
|
||||||
|
from sanic import Sanic, response
|
||||||
|
|
||||||
|
app = Sanic(__name__, strict_slashes=True)
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def handler(request):
|
||||||
|
return response.redirect("/sleep/3")
|
||||||
|
|
||||||
|
@app.get("/sleep/<t:number>")
|
||||||
|
async def handler2(request, t=0.3):
|
||||||
|
await sleep(t)
|
||||||
|
return response.text(f"Slept {t:.1f} seconds.\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
app.run(host="0.0.0.0", port=8000)
|
|
@ -13,6 +13,7 @@ def main():
|
||||||
parser = ArgumentParser(prog="sanic")
|
parser = ArgumentParser(prog="sanic")
|
||||||
parser.add_argument("--host", dest="host", type=str, default="127.0.0.1")
|
parser.add_argument("--host", dest="host", type=str, default="127.0.0.1")
|
||||||
parser.add_argument("--port", dest="port", type=int, default=8000)
|
parser.add_argument("--port", dest="port", type=int, default=8000)
|
||||||
|
parser.add_argument("--unix", dest="unix", type=str, default="")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--cert", dest="cert", type=str, help="location of certificate for SSL"
|
"--cert", dest="cert", type=str, help="location of certificate for SSL"
|
||||||
)
|
)
|
||||||
|
@ -53,6 +54,7 @@ def main():
|
||||||
app.run(
|
app.run(
|
||||||
host=args.host,
|
host=args.host,
|
||||||
port=args.port,
|
port=args.port,
|
||||||
|
unix=args.unix,
|
||||||
workers=args.workers,
|
workers=args.workers,
|
||||||
debug=args.debug,
|
debug=args.debug,
|
||||||
ssl=ssl,
|
ssl=ssl,
|
||||||
|
|
13
sanic/app.py
13
sanic/app.py
|
@ -1033,6 +1033,7 @@ class Sanic:
|
||||||
stop_event: Any = None,
|
stop_event: Any = None,
|
||||||
register_sys_signals: bool = True,
|
register_sys_signals: bool = True,
|
||||||
access_log: Optional[bool] = None,
|
access_log: Optional[bool] = None,
|
||||||
|
unix: Optional[str] = None,
|
||||||
loop: None = None,
|
loop: None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run the HTTP Server and listen until keyboard interrupt or term
|
"""Run the HTTP Server and listen until keyboard interrupt or term
|
||||||
|
@ -1066,6 +1067,8 @@ class Sanic:
|
||||||
:type register_sys_signals: bool
|
:type register_sys_signals: bool
|
||||||
:param access_log: Enables writing access logs (slows server)
|
:param access_log: Enables writing access logs (slows server)
|
||||||
:type access_log: bool
|
:type access_log: bool
|
||||||
|
:param unix: Unix socket to listen on instead of TCP port
|
||||||
|
:type unix: str
|
||||||
:return: Nothing
|
:return: Nothing
|
||||||
"""
|
"""
|
||||||
if loop is not None:
|
if loop is not None:
|
||||||
|
@ -1104,6 +1107,7 @@ class Sanic:
|
||||||
debug=debug,
|
debug=debug,
|
||||||
ssl=ssl,
|
ssl=ssl,
|
||||||
sock=sock,
|
sock=sock,
|
||||||
|
unix=unix,
|
||||||
workers=workers,
|
workers=workers,
|
||||||
protocol=protocol,
|
protocol=protocol,
|
||||||
backlog=backlog,
|
backlog=backlog,
|
||||||
|
@ -1151,6 +1155,7 @@ class Sanic:
|
||||||
backlog: int = 100,
|
backlog: int = 100,
|
||||||
stop_event: Any = None,
|
stop_event: Any = None,
|
||||||
access_log: Optional[bool] = None,
|
access_log: Optional[bool] = None,
|
||||||
|
unix: Optional[str] = None,
|
||||||
return_asyncio_server=False,
|
return_asyncio_server=False,
|
||||||
asyncio_server_kwargs=None,
|
asyncio_server_kwargs=None,
|
||||||
) -> Optional[AsyncioServer]:
|
) -> Optional[AsyncioServer]:
|
||||||
|
@ -1220,6 +1225,7 @@ class Sanic:
|
||||||
debug=debug,
|
debug=debug,
|
||||||
ssl=ssl,
|
ssl=ssl,
|
||||||
sock=sock,
|
sock=sock,
|
||||||
|
unix=unix,
|
||||||
loop=get_event_loop(),
|
loop=get_event_loop(),
|
||||||
protocol=protocol,
|
protocol=protocol,
|
||||||
backlog=backlog,
|
backlog=backlog,
|
||||||
|
@ -1285,6 +1291,7 @@ class Sanic:
|
||||||
debug=False,
|
debug=False,
|
||||||
ssl=None,
|
ssl=None,
|
||||||
sock=None,
|
sock=None,
|
||||||
|
unix=None,
|
||||||
workers=1,
|
workers=1,
|
||||||
loop=None,
|
loop=None,
|
||||||
protocol=HttpProtocol,
|
protocol=HttpProtocol,
|
||||||
|
@ -1326,6 +1333,7 @@ class Sanic:
|
||||||
"host": host,
|
"host": host,
|
||||||
"port": port,
|
"port": port,
|
||||||
"sock": sock,
|
"sock": sock,
|
||||||
|
"unix": unix,
|
||||||
"ssl": ssl,
|
"ssl": ssl,
|
||||||
"app": self,
|
"app": self,
|
||||||
"signal": Signal(),
|
"signal": Signal(),
|
||||||
|
@ -1372,7 +1380,10 @@ class Sanic:
|
||||||
proto = "http"
|
proto = "http"
|
||||||
if ssl is not None:
|
if ssl is not None:
|
||||||
proto = "https"
|
proto = "https"
|
||||||
logger.info(f"Goin' Fast @ {proto}://{host}:{port}")
|
if unix:
|
||||||
|
logger.info(f"Goin' Fast @ {unix} {proto}://...")
|
||||||
|
else:
|
||||||
|
logger.info(f"Goin' Fast @ {proto}://{host}:{port}")
|
||||||
|
|
||||||
return server_settings
|
return server_settings
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ from sanic.exceptions import InvalidUsage, ServerError
|
||||||
from sanic.log import logger
|
from sanic.log import logger
|
||||||
from sanic.request import Request
|
from sanic.request import Request
|
||||||
from sanic.response import HTTPResponse, StreamingHTTPResponse
|
from sanic.response import HTTPResponse, StreamingHTTPResponse
|
||||||
from sanic.server import StreamBuffer
|
from sanic.server import ConnInfo, StreamBuffer
|
||||||
from sanic.websocket import WebSocketConnection
|
from sanic.websocket import WebSocketConnection
|
||||||
|
|
||||||
|
|
||||||
|
@ -255,6 +255,7 @@ class ASGIApp:
|
||||||
instance.transport,
|
instance.transport,
|
||||||
sanic_app,
|
sanic_app,
|
||||||
)
|
)
|
||||||
|
instance.request.conn_info = ConnInfo(instance.transport)
|
||||||
|
|
||||||
if sanic_app.is_request_stream:
|
if sanic_app.is_request_stream:
|
||||||
is_stream_handler = sanic_app.router.is_stream_handler(
|
is_stream_handler = sanic_app.router.is_stream_handler(
|
||||||
|
|
149
sanic/request.py
149
sanic/request.py
|
@ -87,6 +87,7 @@ class Request:
|
||||||
"_socket",
|
"_socket",
|
||||||
"app",
|
"app",
|
||||||
"body",
|
"body",
|
||||||
|
"conn_info",
|
||||||
"ctx",
|
"ctx",
|
||||||
"endpoint",
|
"endpoint",
|
||||||
"headers",
|
"headers",
|
||||||
|
@ -117,6 +118,7 @@ class Request:
|
||||||
|
|
||||||
# Init but do not inhale
|
# Init but do not inhale
|
||||||
self.body_init()
|
self.body_init()
|
||||||
|
self.conn_info = None
|
||||||
self.ctx = SimpleNamespace()
|
self.ctx = SimpleNamespace()
|
||||||
self.parsed_forwarded = None
|
self.parsed_forwarded = None
|
||||||
self.parsed_json = None
|
self.parsed_json = None
|
||||||
|
@ -349,56 +351,55 @@ class Request:
|
||||||
self._cookies = {}
|
self._cookies = {}
|
||||||
return self._cookies
|
return self._cookies
|
||||||
|
|
||||||
|
@property
|
||||||
|
def content_type(self):
|
||||||
|
return self.headers.get("Content-Type", DEFAULT_HTTP_CONTENT_TYPE)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def match_info(self):
|
||||||
|
"""return matched info after resolving route"""
|
||||||
|
return self.app.router.get(self)[2]
|
||||||
|
|
||||||
|
# Transport properties (obtained from local interface only)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ip(self):
|
def ip(self):
|
||||||
"""
|
"""
|
||||||
:return: peer ip of the socket
|
:return: peer ip of the socket
|
||||||
"""
|
"""
|
||||||
if not hasattr(self, "_socket"):
|
return self.conn_info.client if self.conn_info else ""
|
||||||
self._get_address()
|
|
||||||
return self._ip
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def port(self):
|
def port(self):
|
||||||
"""
|
"""
|
||||||
:return: peer port of the socket
|
:return: peer port of the socket
|
||||||
"""
|
"""
|
||||||
if not hasattr(self, "_socket"):
|
return self.conn_info.client_port if self.conn_info else 0
|
||||||
self._get_address()
|
|
||||||
return self._port
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def socket(self):
|
def socket(self):
|
||||||
if not hasattr(self, "_socket"):
|
return self.conn_info.peername if self.conn_info else (None, None)
|
||||||
self._get_address()
|
|
||||||
return self._socket
|
|
||||||
|
|
||||||
def _get_address(self):
|
|
||||||
self._socket = self.transport.get_extra_info("peername") or (
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
self._ip = self._socket[0]
|
|
||||||
self._port = self._socket[1]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def server_name(self):
|
def path(self) -> str:
|
||||||
"""
|
"""Path of the local HTTP request."""
|
||||||
Attempt to get the server's external hostname in this order:
|
return self._parsed_url.path.decode("utf-8")
|
||||||
`config.SERVER_NAME`, proxied or direct Host headers
|
|
||||||
:func:`Request.host`
|
|
||||||
|
|
||||||
:return: the server name without port number
|
# Proxy properties (using SERVER_NAME/forwarded/request/transport info)
|
||||||
:rtype: str
|
|
||||||
"""
|
|
||||||
server_name = self.app.config.get("SERVER_NAME")
|
|
||||||
if server_name:
|
|
||||||
host = server_name.split("//", 1)[-1].split("/", 1)[0]
|
|
||||||
return parse_host(host)[0]
|
|
||||||
return parse_host(self.host)[0]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def forwarded(self):
|
def forwarded(self):
|
||||||
|
"""
|
||||||
|
Active proxy information obtained from request headers, as specified in
|
||||||
|
Sanic configuration.
|
||||||
|
|
||||||
|
Field names by, for, proto, host, port and path are normalized.
|
||||||
|
- for and by IPv6 addresses are bracketed
|
||||||
|
- port (int) is only set by port headers, not from host.
|
||||||
|
- path is url-unencoded
|
||||||
|
|
||||||
|
Additional values may be available from new style Forwarded headers.
|
||||||
|
"""
|
||||||
if self.parsed_forwarded is None:
|
if self.parsed_forwarded is None:
|
||||||
self.parsed_forwarded = (
|
self.parsed_forwarded = (
|
||||||
parse_forwarded(self.headers, self.app.config)
|
parse_forwarded(self.headers, self.app.config)
|
||||||
|
@ -408,50 +409,30 @@ class Request:
|
||||||
return self.parsed_forwarded
|
return self.parsed_forwarded
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def server_port(self):
|
def remote_addr(self) -> str:
|
||||||
"""
|
"""
|
||||||
Attempt to get the server's external port number in this order:
|
Client IP address, if available.
|
||||||
`config.SERVER_NAME`, proxied or direct Host headers
|
1. proxied remote address `self.forwarded['for']`
|
||||||
:func:`Request.host`,
|
2. local remote address `self.ip`
|
||||||
actual port used by the transport layer socket.
|
:return: IPv4, bracketed IPv6, UNIX socket name or arbitrary string
|
||||||
:return: server port
|
|
||||||
:rtype: int
|
|
||||||
"""
|
|
||||||
if self.forwarded:
|
|
||||||
return self.forwarded.get("port") or (
|
|
||||||
80 if self.scheme in ("http", "ws") else 443
|
|
||||||
)
|
|
||||||
return (
|
|
||||||
parse_host(self.host)[1]
|
|
||||||
or self.transport.get_extra_info("sockname")[1]
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def remote_addr(self):
|
|
||||||
"""Attempt to return the original client ip based on `forwarded`,
|
|
||||||
`x-forwarded-for` or `x-real-ip`. If HTTP headers are unavailable or
|
|
||||||
untrusted, returns an empty string.
|
|
||||||
|
|
||||||
:return: original client ip.
|
|
||||||
"""
|
"""
|
||||||
if not hasattr(self, "_remote_addr"):
|
if not hasattr(self, "_remote_addr"):
|
||||||
self._remote_addr = self.forwarded.get("for", "")
|
self._remote_addr = self.forwarded.get("for", "") # or self.ip
|
||||||
return self._remote_addr
|
return self._remote_addr
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def scheme(self):
|
def scheme(self) -> str:
|
||||||
"""
|
"""
|
||||||
Attempt to get the request scheme.
|
Determine request scheme.
|
||||||
Seeking the value in this order:
|
1. `config.SERVER_NAME` if in full URL format
|
||||||
`forwarded` header, `x-forwarded-proto` header,
|
2. proxied proto/scheme
|
||||||
`x-scheme` header, the sanic app itself.
|
3. local connection protocol
|
||||||
|
|
||||||
:return: http|https|ws|wss or arbitrary value given by the headers.
|
:return: http|https|ws|wss or arbitrary value given by the headers.
|
||||||
:rtype: str
|
|
||||||
"""
|
"""
|
||||||
forwarded_proto = self.forwarded.get("proto")
|
if "//" in self.app.config.get("SERVER_NAME", ""):
|
||||||
if forwarded_proto:
|
return self.app.config.SERVER_NAME.split("//")[0]
|
||||||
return forwarded_proto
|
if "proto" in self.forwarded:
|
||||||
|
return self.forwarded["proto"]
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.app.websocket_enabled
|
self.app.websocket_enabled
|
||||||
|
@ -467,25 +448,41 @@ class Request:
|
||||||
return scheme
|
return scheme
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def host(self):
|
def host(self) -> str:
|
||||||
"""
|
"""
|
||||||
:return: proxied or direct Host header. Hostname and port number may be
|
The currently effective server 'host' (hostname or hostname:port).
|
||||||
separated by sanic.headers.parse_host(request.host).
|
1. `config.SERVER_NAME` overrides any client headers
|
||||||
|
2. proxied host of original request
|
||||||
|
3. request host header
|
||||||
|
hostname and port may be separated by
|
||||||
|
`sanic.headers.parse_host(request.host)`.
|
||||||
|
:return: the first matching host found, or empty string
|
||||||
"""
|
"""
|
||||||
return self.forwarded.get("host", self.headers.get("Host", ""))
|
server_name = self.app.config.get("SERVER_NAME")
|
||||||
|
if server_name:
|
||||||
|
return server_name.split("//", 1)[-1].split("/", 1)[0]
|
||||||
|
return self.forwarded.get("host") or self.headers.get("host", "")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def content_type(self):
|
def server_name(self) -> str:
|
||||||
return self.headers.get("Content-Type", DEFAULT_HTTP_CONTENT_TYPE)
|
"""The hostname the client connected to, by `request.host`."""
|
||||||
|
return parse_host(self.host)[0] or ""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def match_info(self):
|
def server_port(self) -> int:
|
||||||
"""return matched info after resolving route"""
|
"""
|
||||||
return self.app.router.get(self)[2]
|
The port the client connected to, by forwarded `port` or
|
||||||
|
`request.host`.
|
||||||
|
|
||||||
|
Default port is returned as 80 and 443 based on `request.scheme`.
|
||||||
|
"""
|
||||||
|
port = self.forwarded.get("port") or parse_host(self.host)[1]
|
||||||
|
return port or (80 if self.scheme in ("http", "ws") else 443)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def path(self):
|
def server_path(self) -> str:
|
||||||
return self._parsed_url.path.decode("utf-8")
|
"""Full path of current URL. Uses proxied or local path."""
|
||||||
|
return self.forwarded.get("path") or self.path
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def query_string(self):
|
def query_string(self):
|
||||||
|
|
156
sanic/server.py
156
sanic/server.py
|
@ -1,15 +1,18 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
|
import secrets
|
||||||
|
import socket
|
||||||
|
import stat
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from inspect import isawaitable
|
from inspect import isawaitable
|
||||||
|
from ipaddress import ip_address
|
||||||
from signal import SIG_IGN, SIGINT, SIGTERM, Signals
|
from signal import SIG_IGN, SIGINT, SIGTERM, Signals
|
||||||
from signal import signal as signal_func
|
from signal import signal as signal_func
|
||||||
from socket import SO_REUSEADDR, SOL_SOCKET, socket
|
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
from httptools import HttpRequestParser # type: ignore
|
from httptools import HttpRequestParser # type: ignore
|
||||||
|
@ -44,6 +47,41 @@ class Signal:
|
||||||
stopped = False
|
stopped = False
|
||||||
|
|
||||||
|
|
||||||
|
class ConnInfo:
|
||||||
|
"""Local and remote addresses and SSL status info."""
|
||||||
|
|
||||||
|
__slots__ = (
|
||||||
|
"sockname",
|
||||||
|
"peername",
|
||||||
|
"server",
|
||||||
|
"server_port",
|
||||||
|
"client",
|
||||||
|
"client_port",
|
||||||
|
"ssl",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, transport, unix=None):
|
||||||
|
self.ssl = bool(transport.get_extra_info("sslcontext"))
|
||||||
|
self.server = self.client = ""
|
||||||
|
self.server_port = self.client_port = 0
|
||||||
|
self.peername = None
|
||||||
|
self.sockname = addr = transport.get_extra_info("sockname")
|
||||||
|
if isinstance(addr, str): # UNIX socket
|
||||||
|
self.server = unix or addr
|
||||||
|
return
|
||||||
|
# IPv4 (ip, port) or IPv6 (ip, port, flowinfo, scopeid)
|
||||||
|
if isinstance(addr, tuple):
|
||||||
|
self.server = addr[0] if len(addr) == 2 else f"[{addr[0]}]"
|
||||||
|
self.server_port = addr[1]
|
||||||
|
# self.server gets non-standard port appended
|
||||||
|
if addr[1] != (443 if self.ssl else 80):
|
||||||
|
self.server = f"{self.server}:{addr[1]}"
|
||||||
|
self.peername = addr = transport.get_extra_info("peername")
|
||||||
|
if isinstance(addr, tuple):
|
||||||
|
self.client = addr[0] if len(addr) == 2 else f"[{addr[0]}]"
|
||||||
|
self.client_port = addr[1]
|
||||||
|
|
||||||
|
|
||||||
class HttpProtocol(asyncio.Protocol):
|
class HttpProtocol(asyncio.Protocol):
|
||||||
"""
|
"""
|
||||||
This class provides a basic HTTP implementation of the sanic framework.
|
This class provides a basic HTTP implementation of the sanic framework.
|
||||||
|
@ -57,6 +95,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||||
"transport",
|
"transport",
|
||||||
"connections",
|
"connections",
|
||||||
"signal",
|
"signal",
|
||||||
|
"conn_info",
|
||||||
# request params
|
# request params
|
||||||
"parser",
|
"parser",
|
||||||
"request",
|
"request",
|
||||||
|
@ -88,6 +127,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||||
"_keep_alive",
|
"_keep_alive",
|
||||||
"_header_fragment",
|
"_header_fragment",
|
||||||
"state",
|
"state",
|
||||||
|
"_unix",
|
||||||
"_body_chunks",
|
"_body_chunks",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -99,6 +139,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||||
signal=Signal(),
|
signal=Signal(),
|
||||||
connections=None,
|
connections=None,
|
||||||
state=None,
|
state=None,
|
||||||
|
unix=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
|
@ -106,6 +147,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||||
deprecated_loop = self.loop if sys.version_info < (3, 7) else None
|
deprecated_loop = self.loop if sys.version_info < (3, 7) else None
|
||||||
self.app = app
|
self.app = app
|
||||||
self.transport = None
|
self.transport = None
|
||||||
|
self.conn_info = None
|
||||||
self.request = None
|
self.request = None
|
||||||
self.parser = None
|
self.parser = None
|
||||||
self.url = None
|
self.url = None
|
||||||
|
@ -139,6 +181,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||||
self.state = state if state else {}
|
self.state = state if state else {}
|
||||||
if "requests_count" not in self.state:
|
if "requests_count" not in self.state:
|
||||||
self.state["requests_count"] = 0
|
self.state["requests_count"] = 0
|
||||||
|
self._unix = unix
|
||||||
self._not_paused.set()
|
self._not_paused.set()
|
||||||
self._body_chunks = deque()
|
self._body_chunks = deque()
|
||||||
|
|
||||||
|
@ -167,6 +210,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||||
self.request_timeout, self.request_timeout_callback
|
self.request_timeout, self.request_timeout_callback
|
||||||
)
|
)
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
|
self.conn_info = ConnInfo(transport, unix=self._unix)
|
||||||
self._last_request_time = time()
|
self._last_request_time = time()
|
||||||
|
|
||||||
def connection_lost(self, exc):
|
def connection_lost(self, exc):
|
||||||
|
@ -304,6 +348,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||||
transport=self.transport,
|
transport=self.transport,
|
||||||
app=self.app,
|
app=self.app,
|
||||||
)
|
)
|
||||||
|
self.request.conn_info = self.conn_info
|
||||||
# Remove any existing KeepAlive handler here,
|
# Remove any existing KeepAlive handler here,
|
||||||
# It will be recreated if required on the new request.
|
# It will be recreated if required on the new request.
|
||||||
if self._keep_alive_timeout_handler:
|
if self._keep_alive_timeout_handler:
|
||||||
|
@ -750,6 +795,7 @@ def serve(
|
||||||
after_stop=None,
|
after_stop=None,
|
||||||
ssl=None,
|
ssl=None,
|
||||||
sock=None,
|
sock=None,
|
||||||
|
unix=None,
|
||||||
reuse_port=False,
|
reuse_port=False,
|
||||||
loop=None,
|
loop=None,
|
||||||
protocol=HttpProtocol,
|
protocol=HttpProtocol,
|
||||||
|
@ -778,6 +824,7 @@ def serve(
|
||||||
`app` instance and `loop`
|
`app` instance and `loop`
|
||||||
:param ssl: SSLContext
|
:param ssl: SSLContext
|
||||||
:param sock: Socket for the server to accept connections from
|
:param sock: Socket for the server to accept connections from
|
||||||
|
:param unix: Unix socket to listen on instead of TCP port
|
||||||
:param reuse_port: `True` for multiple workers
|
:param reuse_port: `True` for multiple workers
|
||||||
:param loop: asyncio compatible event loop
|
:param loop: asyncio compatible event loop
|
||||||
:param run_async: bool: Do not create a new event loop for the server,
|
:param run_async: bool: Do not create a new event loop for the server,
|
||||||
|
@ -804,14 +851,18 @@ def serve(
|
||||||
signal=signal,
|
signal=signal,
|
||||||
app=app,
|
app=app,
|
||||||
state=state,
|
state=state,
|
||||||
|
unix=unix,
|
||||||
)
|
)
|
||||||
asyncio_server_kwargs = (
|
asyncio_server_kwargs = (
|
||||||
asyncio_server_kwargs if asyncio_server_kwargs else {}
|
asyncio_server_kwargs if asyncio_server_kwargs else {}
|
||||||
)
|
)
|
||||||
|
# UNIX sockets are always bound by us (to preserve semantics between modes)
|
||||||
|
if unix:
|
||||||
|
sock = bind_unix_socket(unix, backlog=backlog)
|
||||||
server_coroutine = loop.create_server(
|
server_coroutine = loop.create_server(
|
||||||
server,
|
server,
|
||||||
host,
|
None if sock else host,
|
||||||
port,
|
None if sock else port,
|
||||||
ssl=ssl,
|
ssl=ssl,
|
||||||
reuse_port=reuse_port,
|
reuse_port=reuse_port,
|
||||||
sock=sock,
|
sock=sock,
|
||||||
|
@ -894,6 +945,85 @@ def serve(
|
||||||
trigger_events(after_stop, loop)
|
trigger_events(after_stop, loop)
|
||||||
|
|
||||||
loop.close()
|
loop.close()
|
||||||
|
remove_unix_socket(unix)
|
||||||
|
|
||||||
|
|
||||||
|
def bind_socket(host: str, port: int, *, backlog=100) -> socket.socket:
|
||||||
|
"""Create TCP server socket.
|
||||||
|
:param host: IPv4, IPv6 or hostname may be specified
|
||||||
|
:param port: TCP port number
|
||||||
|
:param backlog: Maximum number of connections to queue
|
||||||
|
:return: socket.socket object
|
||||||
|
"""
|
||||||
|
try: # IP address: family must be specified for IPv6 at least
|
||||||
|
ip = ip_address(host)
|
||||||
|
host = str(ip)
|
||||||
|
sock = socket.socket(
|
||||||
|
socket.AF_INET6 if ip.version == 6 else socket.AF_INET
|
||||||
|
)
|
||||||
|
except ValueError: # Hostname, may become AF_INET or AF_INET6
|
||||||
|
sock = socket.socket()
|
||||||
|
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
|
sock.bind((host, port))
|
||||||
|
sock.listen(backlog)
|
||||||
|
return sock
|
||||||
|
|
||||||
|
|
||||||
|
def bind_unix_socket(path: str, *, mode=0o666, backlog=100) -> socket.socket:
|
||||||
|
"""Create unix socket.
|
||||||
|
:param path: filesystem path
|
||||||
|
:param backlog: Maximum number of connections to queue
|
||||||
|
:return: socket.socket object
|
||||||
|
"""
|
||||||
|
"""Open or atomically replace existing socket with zero downtime."""
|
||||||
|
# Sanitise and pre-verify socket path
|
||||||
|
path = os.path.abspath(path)
|
||||||
|
folder = os.path.dirname(path)
|
||||||
|
if not os.path.isdir(folder):
|
||||||
|
raise FileNotFoundError(f"Socket folder does not exist: {folder}")
|
||||||
|
try:
|
||||||
|
if not stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode):
|
||||||
|
raise FileExistsError(f"Existing file is not a socket: {path}")
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
# Create new socket with a random temporary name
|
||||||
|
tmp_path = f"{path}.{secrets.token_urlsafe()}"
|
||||||
|
sock = socket.socket(socket.AF_UNIX)
|
||||||
|
try:
|
||||||
|
# Critical section begins (filename races)
|
||||||
|
sock.bind(tmp_path)
|
||||||
|
try:
|
||||||
|
os.chmod(tmp_path, mode)
|
||||||
|
# Start listening before rename to avoid connection failures
|
||||||
|
sock.listen(backlog)
|
||||||
|
os.rename(tmp_path, path)
|
||||||
|
except: # noqa: E722
|
||||||
|
try:
|
||||||
|
os.unlink(tmp_path)
|
||||||
|
finally:
|
||||||
|
raise
|
||||||
|
except: # noqa: E722
|
||||||
|
try:
|
||||||
|
sock.close()
|
||||||
|
finally:
|
||||||
|
raise
|
||||||
|
return sock
|
||||||
|
|
||||||
|
|
||||||
|
def remove_unix_socket(path: str) -> None:
|
||||||
|
"""Remove dead unix socket during server exit."""
|
||||||
|
if not path:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
if stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode):
|
||||||
|
# Is it actually dead (doesn't belong to a new server instance)?
|
||||||
|
with socket.socket(socket.AF_UNIX) as testsock:
|
||||||
|
try:
|
||||||
|
testsock.connect(path)
|
||||||
|
except ConnectionRefusedError:
|
||||||
|
os.unlink(path)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def serve_multiple(server_settings, workers):
|
def serve_multiple(server_settings, workers):
|
||||||
|
@ -908,11 +1038,17 @@ def serve_multiple(server_settings, workers):
|
||||||
server_settings["reuse_port"] = True
|
server_settings["reuse_port"] = True
|
||||||
server_settings["run_multiple"] = True
|
server_settings["run_multiple"] = True
|
||||||
|
|
||||||
# Handling when custom socket is not provided.
|
# Create a listening socket or use the one in settings
|
||||||
if server_settings.get("sock") is None:
|
sock = server_settings.get("sock")
|
||||||
sock = socket()
|
unix = server_settings["unix"]
|
||||||
sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
|
backlog = server_settings["backlog"]
|
||||||
sock.bind((server_settings["host"], server_settings["port"]))
|
if unix:
|
||||||
|
sock = bind_unix_socket(unix, backlog=backlog)
|
||||||
|
server_settings["unix"] = unix
|
||||||
|
if sock is None:
|
||||||
|
sock = bind_socket(
|
||||||
|
server_settings["host"], server_settings["port"], backlog=backlog
|
||||||
|
)
|
||||||
sock.set_inheritable(True)
|
sock.set_inheritable(True)
|
||||||
server_settings["sock"] = sock
|
server_settings["sock"] = sock
|
||||||
server_settings["host"] = None
|
server_settings["host"] = None
|
||||||
|
@ -941,4 +1077,6 @@ def serve_multiple(server_settings, workers):
|
||||||
# the above processes will block this until they're stopped
|
# the above processes will block this until they're stopped
|
||||||
for process in processes:
|
for process in processes:
|
||||||
process.terminate()
|
process.terminate()
|
||||||
server_settings.get("sock").close()
|
|
||||||
|
sock.close()
|
||||||
|
remove_unix_socket(unix)
|
||||||
|
|
|
@ -454,11 +454,13 @@ def test_standard_forwarded(app):
|
||||||
"X-Real-IP": "127.0.0.2",
|
"X-Real-IP": "127.0.0.2",
|
||||||
"X-Forwarded-For": "127.0.1.1",
|
"X-Forwarded-For": "127.0.1.1",
|
||||||
"X-Scheme": "ws",
|
"X-Scheme": "ws",
|
||||||
|
"Host": "local.site",
|
||||||
}
|
}
|
||||||
request, response = app.test_client.get("/", headers=headers)
|
request, response = app.test_client.get("/", headers=headers)
|
||||||
assert response.json == {"for": "127.0.0.2", "proto": "ws"}
|
assert response.json == {"for": "127.0.0.2", "proto": "ws"}
|
||||||
assert request.remote_addr == "127.0.0.2"
|
assert request.remote_addr == "127.0.0.2"
|
||||||
assert request.scheme == "ws"
|
assert request.scheme == "ws"
|
||||||
|
assert request.server_name == "local.site"
|
||||||
assert request.server_port == 80
|
assert request.server_port == 80
|
||||||
|
|
||||||
app.config.FORWARDED_SECRET = "mySecret"
|
app.config.FORWARDED_SECRET = "mySecret"
|
||||||
|
@ -1807,13 +1809,17 @@ def test_request_port(app):
|
||||||
port = request.port
|
port = request.port
|
||||||
assert isinstance(port, int)
|
assert isinstance(port, int)
|
||||||
|
|
||||||
delattr(request, "_socket")
|
|
||||||
delattr(request, "_port")
|
@pytest.mark.asyncio
|
||||||
|
async def test_request_port_asgi(app):
|
||||||
|
@app.get("/")
|
||||||
|
def handler(request):
|
||||||
|
return text("OK")
|
||||||
|
|
||||||
|
request, response = await app.asgi_client.get("/")
|
||||||
|
|
||||||
port = request.port
|
port = request.port
|
||||||
assert isinstance(port, int)
|
assert isinstance(port, int)
|
||||||
assert hasattr(request, "_socket")
|
|
||||||
assert hasattr(request, "_port")
|
|
||||||
|
|
||||||
|
|
||||||
def test_request_socket(app):
|
def test_request_socket(app):
|
||||||
|
@ -1832,12 +1838,6 @@ def test_request_socket(app):
|
||||||
assert ip == request.ip
|
assert ip == request.ip
|
||||||
assert port == request.port
|
assert port == request.port
|
||||||
|
|
||||||
delattr(request, "_socket")
|
|
||||||
|
|
||||||
socket = request.socket
|
|
||||||
assert isinstance(socket, tuple)
|
|
||||||
assert hasattr(request, "_socket")
|
|
||||||
|
|
||||||
|
|
||||||
def test_request_server_name(app):
|
def test_request_server_name(app):
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
|
@ -1866,7 +1866,7 @@ def test_request_server_name_in_host_header(app):
|
||||||
request, response = app.test_client.get(
|
request, response = app.test_client.get(
|
||||||
"/", headers={"Host": "mal_formed"}
|
"/", headers={"Host": "mal_formed"}
|
||||||
)
|
)
|
||||||
assert request.server_name == None # For now (later maybe 127.0.0.1)
|
assert request.server_name == ""
|
||||||
|
|
||||||
|
|
||||||
def test_request_server_name_forwarded(app):
|
def test_request_server_name_forwarded(app):
|
||||||
|
@ -1893,7 +1893,7 @@ def test_request_server_port(app):
|
||||||
|
|
||||||
test_client = SanicTestClient(app)
|
test_client = SanicTestClient(app)
|
||||||
request, response = test_client.get("/", headers={"Host": "my-server"})
|
request, response = test_client.get("/", headers={"Host": "my-server"})
|
||||||
assert request.server_port == test_client.port
|
assert request.server_port == 80
|
||||||
|
|
||||||
|
|
||||||
def test_request_server_port_in_host_header(app):
|
def test_request_server_port_in_host_header(app):
|
||||||
|
@ -1952,12 +1952,12 @@ def test_server_name_and_url_for(app):
|
||||||
def handler(request):
|
def handler(request):
|
||||||
return text("ok")
|
return text("ok")
|
||||||
|
|
||||||
app.config.SERVER_NAME = "my-server"
|
app.config.SERVER_NAME = "my-server" # This means default port
|
||||||
assert app.url_for("handler", _external=True) == "http://my-server/foo"
|
assert app.url_for("handler", _external=True) == "http://my-server/foo"
|
||||||
request, response = app.test_client.get("/foo")
|
request, response = app.test_client.get("/foo")
|
||||||
assert (
|
assert (
|
||||||
request.url_for("handler")
|
request.url_for("handler")
|
||||||
== f"http://my-server:{request.server_port}/foo"
|
== f"http://my-server/foo"
|
||||||
)
|
)
|
||||||
|
|
||||||
app.config.SERVER_NAME = "https://my-server/path"
|
app.config.SERVER_NAME = "https://my-server/path"
|
||||||
|
|
235
tests/test_unix_socket.py
Normal file
235
tests/test_unix_socket.py
Normal file
|
@ -0,0 +1,235 @@
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from sanic import Sanic
|
||||||
|
from sanic.response import text
|
||||||
|
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.skipif(os.name != "posix", reason="UNIX only")
|
||||||
|
SOCKPATH = "/tmp/sanictest.sock"
|
||||||
|
SOCKPATH2 = "/tmp/sanictest2.sock"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def socket_cleanup():
|
||||||
|
try:
|
||||||
|
os.unlink(SOCKPATH)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
os.unlink(SOCKPATH2)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
# Run test function
|
||||||
|
yield
|
||||||
|
try:
|
||||||
|
os.unlink(SOCKPATH2)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
os.unlink(SOCKPATH)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_unix_socket_creation(caplog):
|
||||||
|
from socket import AF_UNIX, socket
|
||||||
|
|
||||||
|
with socket(AF_UNIX) as sock:
|
||||||
|
sock.bind(SOCKPATH)
|
||||||
|
assert os.path.exists(SOCKPATH)
|
||||||
|
ino = os.stat(SOCKPATH).st_ino
|
||||||
|
|
||||||
|
app = Sanic(name=__name__)
|
||||||
|
|
||||||
|
@app.listener("after_server_start")
|
||||||
|
def running(app, loop):
|
||||||
|
assert os.path.exists(SOCKPATH)
|
||||||
|
assert ino != os.stat(SOCKPATH).st_ino
|
||||||
|
app.stop()
|
||||||
|
|
||||||
|
with caplog.at_level(logging.INFO):
|
||||||
|
app.run(unix=SOCKPATH)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
"sanic.root",
|
||||||
|
logging.INFO,
|
||||||
|
f"Goin' Fast @ {SOCKPATH} http://...",
|
||||||
|
) in caplog.record_tuples
|
||||||
|
assert not os.path.exists(SOCKPATH)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_paths():
|
||||||
|
app = Sanic(name=__name__)
|
||||||
|
|
||||||
|
with pytest.raises(FileExistsError):
|
||||||
|
app.run(unix=".")
|
||||||
|
|
||||||
|
with pytest.raises(FileNotFoundError):
|
||||||
|
app.run(unix="no-such-directory/sanictest.sock")
|
||||||
|
|
||||||
|
|
||||||
|
def test_dont_replace_file():
|
||||||
|
with open(SOCKPATH, "w") as f:
|
||||||
|
f.write("File, not socket")
|
||||||
|
|
||||||
|
app = Sanic(name=__name__)
|
||||||
|
|
||||||
|
@app.listener("after_server_start")
|
||||||
|
def stop(app, loop):
|
||||||
|
app.stop()
|
||||||
|
|
||||||
|
with pytest.raises(FileExistsError):
|
||||||
|
app.run(unix=SOCKPATH)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dont_follow_symlink():
|
||||||
|
from socket import AF_UNIX, socket
|
||||||
|
|
||||||
|
with socket(AF_UNIX) as sock:
|
||||||
|
sock.bind(SOCKPATH2)
|
||||||
|
os.symlink(SOCKPATH2, SOCKPATH)
|
||||||
|
|
||||||
|
app = Sanic(name=__name__)
|
||||||
|
|
||||||
|
@app.listener("after_server_start")
|
||||||
|
def stop(app, loop):
|
||||||
|
app.stop()
|
||||||
|
|
||||||
|
with pytest.raises(FileExistsError):
|
||||||
|
app.run(unix=SOCKPATH)
|
||||||
|
|
||||||
|
|
||||||
|
def test_socket_deleted_while_running():
|
||||||
|
app = Sanic(name=__name__)
|
||||||
|
|
||||||
|
@app.listener("after_server_start")
|
||||||
|
async def hack(app, loop):
|
||||||
|
os.unlink(SOCKPATH)
|
||||||
|
app.stop()
|
||||||
|
|
||||||
|
app.run(host="myhost.invalid", unix=SOCKPATH)
|
||||||
|
|
||||||
|
|
||||||
|
def test_socket_replaced_with_file():
|
||||||
|
app = Sanic(name=__name__)
|
||||||
|
|
||||||
|
@app.listener("after_server_start")
|
||||||
|
async def hack(app, loop):
|
||||||
|
os.unlink(SOCKPATH)
|
||||||
|
with open(SOCKPATH, "w") as f:
|
||||||
|
f.write("Not a socket")
|
||||||
|
app.stop()
|
||||||
|
|
||||||
|
app.run(host="myhost.invalid", unix=SOCKPATH)
|
||||||
|
|
||||||
|
|
||||||
|
def test_unix_connection():
|
||||||
|
app = Sanic(name=__name__)
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
def handler(request):
|
||||||
|
return text(f"{request.conn_info.server}")
|
||||||
|
|
||||||
|
@app.listener("after_server_start")
|
||||||
|
async def client(app, loop):
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(uds=SOCKPATH) as client:
|
||||||
|
r = await client.get("http://myhost.invalid/")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.text == os.path.abspath(SOCKPATH)
|
||||||
|
finally:
|
||||||
|
app.stop()
|
||||||
|
|
||||||
|
app.run(host="myhost.invalid", unix=SOCKPATH)
|
||||||
|
|
||||||
|
|
||||||
|
app_multi = Sanic(name=__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def handler(request):
|
||||||
|
return text(f"{request.conn_info.server}")
|
||||||
|
|
||||||
|
|
||||||
|
async def client(app, loop):
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(uds=SOCKPATH) as client:
|
||||||
|
r = await client.get("http://myhost.invalid/")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.text == os.path.abspath(SOCKPATH)
|
||||||
|
finally:
|
||||||
|
app.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_unix_connection_multiple_workers():
|
||||||
|
app_multi.get("/")(handler)
|
||||||
|
app_multi.listener("after_server_start")(client)
|
||||||
|
app_multi.run(host="myhost.invalid", unix=SOCKPATH, workers=2)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_zero_downtime():
|
||||||
|
"""Graceful server termination and socket replacement on restarts"""
|
||||||
|
from signal import SIGINT
|
||||||
|
from time import monotonic as current_time
|
||||||
|
|
||||||
|
async def client():
|
||||||
|
for _ in range(40):
|
||||||
|
async with httpx.AsyncClient(uds=SOCKPATH) as client:
|
||||||
|
r = await client.get("http://localhost/sleep/0.1")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.text == f"Slept 0.1 seconds.\n"
|
||||||
|
|
||||||
|
def spawn():
|
||||||
|
command = [
|
||||||
|
sys.executable,
|
||||||
|
"-m",
|
||||||
|
"sanic",
|
||||||
|
"--unix",
|
||||||
|
SOCKPATH,
|
||||||
|
"examples.delayed_response.app",
|
||||||
|
]
|
||||||
|
DN = subprocess.DEVNULL
|
||||||
|
return subprocess.Popen(
|
||||||
|
command, stdin=DN, stdout=DN, stderr=subprocess.PIPE
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
processes = [spawn()]
|
||||||
|
while not os.path.exists(SOCKPATH):
|
||||||
|
if processes[0].poll() is not None:
|
||||||
|
raise Exception("Worker did not start properly")
|
||||||
|
await asyncio.sleep(0.0001)
|
||||||
|
ino = os.stat(SOCKPATH).st_ino
|
||||||
|
task = asyncio.get_event_loop().create_task(client())
|
||||||
|
start_time = current_time()
|
||||||
|
while current_time() < start_time + 4:
|
||||||
|
# Start a new one and wait until the socket is replaced
|
||||||
|
processes.append(spawn())
|
||||||
|
while ino == os.stat(SOCKPATH).st_ino:
|
||||||
|
await asyncio.sleep(0.001)
|
||||||
|
ino = os.stat(SOCKPATH).st_ino
|
||||||
|
# Graceful termination of the previous one
|
||||||
|
processes[-2].send_signal(SIGINT)
|
||||||
|
# Wait until client has completed all requests
|
||||||
|
await task
|
||||||
|
processes[-1].send_signal(SIGINT)
|
||||||
|
for worker in processes:
|
||||||
|
try:
|
||||||
|
worker.wait(1.0)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
raise Exception(
|
||||||
|
f"Worker would not terminate:\n{worker.stderr}"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
for worker in processes:
|
||||||
|
worker.kill()
|
||||||
|
# Test for clean run and termination
|
||||||
|
assert len(processes) > 5
|
||||||
|
assert [worker.poll() for worker in processes] == len(processes) * [0]
|
||||||
|
assert not os.path.exists(SOCKPATH)
|
Loading…
Reference in New Issue
Block a user