Implement restart ordering (#2632)

This commit is contained in:
Adam Hopkins 2022-12-18 14:09:17 +02:00 committed by GitHub
parent 518152d97e
commit f7040ccec8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 375 additions and 71 deletions

View File

@ -1534,6 +1534,7 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
self.state.is_started = True self.state.is_started = True
def ack(self):
if hasattr(self, "multiplexer"): if hasattr(self, "multiplexer"):
self.multiplexer.ack() self.multiplexer.ack()

View File

@ -20,7 +20,7 @@ class SanicHelpFormatter(RawTextHelpFormatter):
if not usage: if not usage:
usage = SUPPRESS usage = SUPPRESS
# Add one linebreak, but not two # Add one linebreak, but not two
self.add_text("\x1b[1A'") self.add_text("\x1b[1A")
super().add_usage(usage, actions, groups, prefix) super().add_usage(usage, actions, groups, prefix)

View File

@ -5,10 +5,26 @@ from sanic.cli.base import SanicHelpFormatter, SanicSubParsersAction
def _add_shared(parser: ArgumentParser) -> None: def _add_shared(parser: ArgumentParser) -> None:
parser.add_argument("--host", "-H", default="localhost") parser.add_argument(
parser.add_argument("--port", "-p", default=6457, type=int) "--host",
parser.add_argument("--secure", "-s", action="store_true") "-H",
parser.add_argument("--api-key", "-k") default="localhost",
help="Inspector host address [default 127.0.0.1]",
)
parser.add_argument(
"--port",
"-p",
default=6457,
type=int,
help="Inspector port [default 6457]",
)
parser.add_argument(
"--secure",
"-s",
action="store_true",
help="Whether to access the Inspector via TLS encryption",
)
parser.add_argument("--api-key", "-k", help="Inspector authentication key")
parser.add_argument( parser.add_argument(
"--raw", "--raw",
action="store_true", action="store_true",
@ -32,17 +48,25 @@ def make_inspector_parser(parser: ArgumentParser) -> None:
dest="action", dest="action",
description=( description=(
"Run one of the below subcommands. If you have created a custom " "Run one of the below subcommands. If you have created a custom "
"Inspector instance, then you can run custom commands.\nSee ___ " "Inspector instance, then you can run custom commands. See ___ "
"for more details." "for more details."
), ),
title="Required\n========\n Subcommands", title="Required\n========\n Subcommands",
parser_class=InspectorSubParser, parser_class=InspectorSubParser,
) )
subparsers.add_parser( reloader = subparsers.add_parser(
"reload", "reload",
help="Trigger a reload of the server workers", help="Trigger a reload of the server workers",
formatter_class=SanicHelpFormatter, formatter_class=SanicHelpFormatter,
) )
reloader.add_argument(
"--zero-downtime",
action="store_true",
help=(
"Whether to wait for the new process to be online before "
"terminating the old"
),
)
subparsers.add_parser( subparsers.add_parser(
"shutdown", "shutdown",
help="Shutdown the application and all processes", help="Shutdown the application and all processes",
@ -53,7 +77,11 @@ def make_inspector_parser(parser: ArgumentParser) -> None:
help="Scale the number of workers", help="Scale the number of workers",
formatter_class=SanicHelpFormatter, formatter_class=SanicHelpFormatter,
) )
scale.add_argument("replicas", type=int) scale.add_argument(
"replicas",
type=int,
help="Number of workers requested",
)
custom = subparsers.add_parser( custom = subparsers.add_parser(
"<custom>", "<custom>",

View File

@ -4,6 +4,7 @@ import signal
import sys import sys
from contextlib import contextmanager from contextlib import contextmanager
from enum import Enum
from typing import Awaitable, Union from typing import Awaitable, Union
from multidict import CIMultiDict # type: ignore from multidict import CIMultiDict # type: ignore
@ -30,6 +31,30 @@ try:
except ImportError: except ImportError:
pass pass
# Python 3.11 changed the way Enum formatting works for mixed-in types.
if sys.version_info < (3, 11, 0):
class StrEnum(str, Enum):
pass
else:
from enum import StrEnum # type: ignore # noqa
class UpperStrEnum(StrEnum):
def _generate_next_value_(name, start, count, last_values):
return name.upper()
def __eq__(self, value: object) -> bool:
value = str(value).upper()
return super().__eq__(value)
def __hash__(self) -> int:
return hash(self.value)
def __str__(self) -> str:
return self.value
@contextmanager @contextmanager
def use_context(method: StartMethod): def use_context(method: StartMethod):

View File

@ -1,19 +1,9 @@
from enum import Enum, auto from enum import auto
from sanic.compat import UpperStrEnum
class HTTPMethod(str, Enum): class HTTPMethod(UpperStrEnum):
def _generate_next_value_(name, start, count, last_values):
return name.upper()
def __eq__(self, value: object) -> bool:
value = str(value).upper()
return super().__eq__(value)
def __hash__(self) -> int:
return hash(self.value)
def __str__(self) -> str:
return self.value
GET = auto() GET = auto()
POST = auto() POST = auto()
@ -24,9 +14,7 @@ class HTTPMethod(str, Enum):
DELETE = auto() DELETE = auto()
class LocalCertCreator(str, Enum): class LocalCertCreator(UpperStrEnum):
def _generate_next_value_(name, start, count, last_values):
return name.upper()
AUTO = auto() AUTO = auto()
TRUSTME = auto() TRUSTME = auto()

View File

@ -20,7 +20,7 @@ class SignalMixin(metaclass=SanicMeta):
event: Union[str, Enum], event: Union[str, Enum],
*, *,
apply: bool = True, apply: bool = True,
condition: Dict[str, Any] = None, condition: Optional[Dict[str, Any]] = None,
exclusive: bool = True, exclusive: bool = True,
) -> Callable[[SignalHandler], SignalHandler]: ) -> Callable[[SignalHandler], SignalHandler]:
""" """
@ -64,7 +64,7 @@ class SignalMixin(metaclass=SanicMeta):
self, self,
handler: Optional[Callable[..., Any]], handler: Optional[Callable[..., Any]],
event: str, event: str,
condition: Dict[str, Any] = None, condition: Optional[Dict[str, Any]] = None,
exclusive: bool = True, exclusive: bool = True,
): ):
if not handler: if not handler:

View File

@ -851,7 +851,7 @@ class StartupMixin(metaclass=SanicMeta):
primary.config.INSPECTOR_TLS_KEY, primary.config.INSPECTOR_TLS_KEY,
primary.config.INSPECTOR_TLS_CERT, primary.config.INSPECTOR_TLS_CERT,
) )
manager.manage("Inspector", inspector, {}, transient=True) manager.manage("Inspector", inspector, {}, transient=False)
primary._inspector = inspector primary._inspector = inspector
primary._manager = manager primary._manager = manager

View File

@ -229,6 +229,7 @@ def _serve_http_1(
loop.run_until_complete(app._startup()) loop.run_until_complete(app._startup())
loop.run_until_complete(app._server_event("init", "before")) loop.run_until_complete(app._server_event("init", "before"))
app.ack()
try: try:
http_server = loop.run_until_complete(server_coroutine) http_server = loop.run_until_complete(server_coroutine)
@ -306,6 +307,7 @@ def _serve_http_3(
server = AsyncioServer(app, loop, coro, []) server = AsyncioServer(app, loop, coro, [])
loop.run_until_complete(server.startup()) loop.run_until_complete(server.startup())
loop.run_until_complete(server.before_start()) loop.run_until_complete(server.before_start())
app.ack()
loop.run_until_complete(server) loop.run_until_complete(server)
_setup_system_signals(app, run_multiple, register_sys_signals, loop) _setup_system_signals(app, run_multiple, register_sys_signals, loop)
loop.run_until_complete(server.after_start()) loop.run_until_complete(server.after_start())

18
sanic/worker/constants.py Normal file
View File

@ -0,0 +1,18 @@
from enum import IntEnum, auto
from sanic.compat import UpperStrEnum
class RestartOrder(UpperStrEnum):
SHUTDOWN_FIRST = auto()
STARTUP_FIRST = auto()
class ProcessState(IntEnum):
IDLE = auto()
RESTARTING = auto()
STARTING = auto()
STARTED = auto()
ACKED = auto()
JOINED = auto()
TERMINATED = auto()

View File

@ -101,8 +101,10 @@ class Inspector:
obj[key] = value.isoformat() obj[key] = value.isoformat()
return obj return obj
def reload(self) -> None: def reload(self, zero_downtime: bool = False) -> None:
message = "__ALL_PROCESSES__:" message = "__ALL_PROCESSES__:"
if zero_downtime:
message += ":STARTUP_FIRST"
self._publisher.send(message) self._publisher.send(message)
def scale(self, replicas) -> str: def scale(self, replicas) -> str:

View File

@ -10,6 +10,7 @@ from typing import Dict, List, Optional
from sanic.compat import OS_IS_WINDOWS from sanic.compat import OS_IS_WINDOWS
from sanic.exceptions import ServerKilled from sanic.exceptions import ServerKilled
from sanic.log import error_logger, logger from sanic.log import error_logger, logger
from sanic.worker.constants import RestartOrder
from sanic.worker.process import ProcessState, Worker, WorkerProcess from sanic.worker.process import ProcessState, Worker, WorkerProcess
@ -20,7 +21,8 @@ else:
class WorkerManager: class WorkerManager:
THRESHOLD = 300 # == 30 seconds THRESHOLD = WorkerProcess.THRESHOLD
MAIN_IDENT = "Sanic-Main"
def __init__( def __init__(
self, self,
@ -37,7 +39,7 @@ class WorkerManager:
self.durable: Dict[str, Worker] = {} self.durable: Dict[str, Worker] = {}
self.monitor_publisher, self.monitor_subscriber = monitor_pubsub self.monitor_publisher, self.monitor_subscriber = monitor_pubsub
self.worker_state = worker_state self.worker_state = worker_state
self.worker_state["Sanic-Main"] = {"pid": self.pid} self.worker_state[self.MAIN_IDENT] = {"pid": self.pid}
self.terminated = False self.terminated = False
self._serve = serve self._serve = serve
self._server_settings = server_settings self._server_settings = server_settings
@ -119,10 +121,15 @@ class WorkerManager:
process.terminate() process.terminate()
self.terminated = True self.terminated = True
def restart(self, process_names: Optional[List[str]] = None, **kwargs): def restart(
self,
process_names: Optional[List[str]] = None,
restart_order=RestartOrder.SHUTDOWN_FIRST,
**kwargs,
):
for process in self.transient_processes: for process in self.transient_processes:
if not process_names or process.name in process_names: if not process_names or process.name in process_names:
process.restart(**kwargs) process.restart(restart_order=restart_order, **kwargs)
def scale(self, num_worker: int): def scale(self, num_worker: int):
if num_worker <= 0: if num_worker <= 0:
@ -160,7 +167,12 @@ class WorkerManager:
elif message == "__TERMINATE__": elif message == "__TERMINATE__":
self.shutdown() self.shutdown()
break break
split_message = message.split(":", 1) logger.debug(
"Incoming monitor message: %s",
message,
extra={"verbosity": 1},
)
split_message = message.split(":", 2)
if message.startswith("__SCALE__"): if message.startswith("__SCALE__"):
self.scale(int(split_message[-1])) self.scale(int(split_message[-1]))
continue continue
@ -173,10 +185,17 @@ class WorkerManager:
] ]
if "__ALL_PROCESSES__" in process_names: if "__ALL_PROCESSES__" in process_names:
process_names = None process_names = None
order = (
RestartOrder.STARTUP_FIRST
if "STARTUP_FIRST" in split_message
else RestartOrder.SHUTDOWN_FIRST
)
self.restart( self.restart(
process_names=process_names, process_names=process_names,
reloaded_files=reloaded_files, reloaded_files=reloaded_files,
restart_order=order,
) )
self._sync_states()
except InterruptedError: except InterruptedError:
if not OS_IS_WINDOWS: if not OS_IS_WINDOWS:
raise raise
@ -263,3 +282,9 @@ class WorkerManager:
if worker_state.get("server") if worker_state.get("server")
] ]
return all(acked) and len(acked) == self.num_server return all(acked) and len(acked) == self.num_server
def _sync_states(self):
for process in self.processes:
state = self.worker_state[process.name].get("state")
if state and process.state.name != state:
process.set_state(ProcessState[state], True)

View File

@ -2,6 +2,7 @@ from multiprocessing.connection import Connection
from os import environ, getpid from os import environ, getpid
from typing import Any, Dict from typing import Any, Dict
from sanic.log import Colors, logger
from sanic.worker.process import ProcessState from sanic.worker.process import ProcessState
from sanic.worker.state import WorkerState from sanic.worker.state import WorkerState
@ -16,12 +17,23 @@ class WorkerMultiplexer:
self._state = WorkerState(worker_state, self.name) self._state = WorkerState(worker_state, self.name)
def ack(self): def ack(self):
logger.debug(
f"{Colors.BLUE}Process ack: {Colors.BOLD}{Colors.SANIC}"
f"%s {Colors.BLUE}[%s]{Colors.END}",
self.name,
self.pid,
)
self._state._state[self.name] = { self._state._state[self.name] = {
**self._state._state[self.name], **self._state._state[self.name],
"state": ProcessState.ACKED.name, "state": ProcessState.ACKED.name,
} }
def restart(self, name: str = "", all_workers: bool = False): def restart(
self,
name: str = "",
all_workers: bool = False,
zero_downtime: bool = False,
):
if name and all_workers: if name and all_workers:
raise ValueError( raise ValueError(
"Ambiguous restart with both a named process and" "Ambiguous restart with both a named process and"
@ -29,6 +41,10 @@ class WorkerMultiplexer:
) )
if not name: if not name:
name = "__ALL_PROCESSES__:" if all_workers else self.name name = "__ALL_PROCESSES__:" if all_workers else self.name
if not name.endswith(":"):
name += ":"
if zero_downtime:
name += ":STARTUP_FIRST"
self._monitor_publisher.send(name) self._monitor_publisher.send(name)
reload = restart # no cov reload = restart # no cov

View File

@ -1,12 +1,14 @@
import os import os
from datetime import datetime, timezone from datetime import datetime, timezone
from enum import IntEnum, auto
from multiprocessing.context import BaseContext from multiprocessing.context import BaseContext
from signal import SIGINT from signal import SIGINT
from threading import Thread
from time import sleep
from typing import Any, Dict, Set from typing import Any, Dict, Set
from sanic.log import Colors, logger from sanic.log import Colors, logger
from sanic.worker.constants import ProcessState, RestartOrder
def get_now(): def get_now():
@ -14,15 +16,8 @@ def get_now():
return now return now
class ProcessState(IntEnum):
IDLE = auto()
STARTED = auto()
ACKED = auto()
JOINED = auto()
TERMINATED = auto()
class WorkerProcess: class WorkerProcess:
THRESHOLD = 300 # == 30 seconds
SERVER_LABEL = "Server" SERVER_LABEL = "Server"
def __init__(self, factory, name, target, kwargs, worker_state): def __init__(self, factory, name, target, kwargs, worker_state):
@ -54,8 +49,9 @@ class WorkerProcess:
f"{Colors.SANIC}%s{Colors.END}", f"{Colors.SANIC}%s{Colors.END}",
self.name, self.name,
) )
self.set_state(ProcessState.STARTING)
self._current_process.start()
self.set_state(ProcessState.STARTED) self.set_state(ProcessState.STARTED)
self._process.start()
if not self.worker_state[self.name].get("starts"): if not self.worker_state[self.name].get("starts"):
self.worker_state[self.name] = { self.worker_state[self.name] = {
**self.worker_state[self.name], **self.worker_state[self.name],
@ -67,7 +63,7 @@ class WorkerProcess:
def join(self): def join(self):
self.set_state(ProcessState.JOINED) self.set_state(ProcessState.JOINED)
self._process.join() self._current_process.join()
def terminate(self): def terminate(self):
if self.state is not ProcessState.TERMINATED: if self.state is not ProcessState.TERMINATED:
@ -80,21 +76,23 @@ class WorkerProcess:
) )
self.set_state(ProcessState.TERMINATED, force=True) self.set_state(ProcessState.TERMINATED, force=True)
try: try:
# self._process.terminate()
os.kill(self.pid, SIGINT) os.kill(self.pid, SIGINT)
del self.worker_state[self.name] del self.worker_state[self.name]
except (KeyError, AttributeError, ProcessLookupError): except (KeyError, AttributeError, ProcessLookupError):
... ...
def restart(self, **kwargs): def restart(self, restart_order=RestartOrder.SHUTDOWN_FIRST, **kwargs):
logger.debug( logger.debug(
f"{Colors.BLUE}Restarting a process: {Colors.BOLD}{Colors.SANIC}" f"{Colors.BLUE}Restarting a process: {Colors.BOLD}{Colors.SANIC}"
f"%s {Colors.BLUE}[%s]{Colors.END}", f"%s {Colors.BLUE}[%s]{Colors.END}",
self.name, self.name,
self.pid, self.pid,
) )
self._process.terminate() self.set_state(ProcessState.RESTARTING, force=True)
self.set_state(ProcessState.IDLE, force=True) if restart_order is RestartOrder.SHUTDOWN_FIRST:
self._terminate_now()
else:
self._old_process = self._current_process
self.kwargs.update( self.kwargs.update(
{"config": {k.upper(): v for k, v in kwargs.items()}} {"config": {k.upper(): v for k, v in kwargs.items()}}
) )
@ -104,6 +102,9 @@ class WorkerProcess:
except AttributeError: except AttributeError:
raise RuntimeError("Restart failed") raise RuntimeError("Restart failed")
if restart_order is RestartOrder.STARTUP_FIRST:
self._terminate_soon()
self.worker_state[self.name] = { self.worker_state[self.name] = {
**self.worker_state[self.name], **self.worker_state[self.name],
"pid": self.pid, "pid": self.pid,
@ -113,14 +114,14 @@ class WorkerProcess:
def is_alive(self): def is_alive(self):
try: try:
return self._process.is_alive() return self._current_process.is_alive()
except AssertionError: except AssertionError:
return False return False
def spawn(self): def spawn(self):
if self.state is not ProcessState.IDLE: if self.state not in (ProcessState.IDLE, ProcessState.RESTARTING):
raise Exception("Cannot spawn a worker process until it is idle.") raise Exception("Cannot spawn a worker process until it is idle.")
self._process = self.factory( self._current_process = self.factory(
name=self.name, name=self.name,
target=self.target, target=self.target,
kwargs=self.kwargs, kwargs=self.kwargs,
@ -129,7 +130,56 @@ class WorkerProcess:
@property @property
def pid(self): def pid(self):
return self._process.pid return self._current_process.pid
def _terminate_now(self):
logger.debug(
f"{Colors.BLUE}Begin restart termination: "
f"{Colors.BOLD}{Colors.SANIC}"
f"%s {Colors.BLUE}[%s]{Colors.END}",
self.name,
self._current_process.pid,
)
self._current_process.terminate()
def _terminate_soon(self):
logger.debug(
f"{Colors.BLUE}Begin restart termination: "
f"{Colors.BOLD}{Colors.SANIC}"
f"%s {Colors.BLUE}[%s]{Colors.END}",
self.name,
self._current_process.pid,
)
termination_thread = Thread(target=self._wait_to_terminate)
termination_thread.start()
def _wait_to_terminate(self):
logger.debug(
f"{Colors.BLUE}Waiting for process to be acked: "
f"{Colors.BOLD}{Colors.SANIC}"
f"%s {Colors.BLUE}[%s]{Colors.END}",
self.name,
self._old_process.pid,
)
misses = 0
while self.state is not ProcessState.ACKED:
sleep(0.1)
misses += 1
if misses > self.THRESHOLD:
raise TimeoutError(
f"Worker {self.name} failed to come ack within "
f"{self.THRESHOLD / 10} seconds"
)
else:
logger.debug(
f"{Colors.BLUE}Process acked. Terminating: "
f"{Colors.BOLD}{Colors.SANIC}"
f"%s {Colors.BLUE}[%s]{Colors.END}",
self.name,
self._old_process.pid,
)
self._old_process.terminate()
delattr(self, "_old_process")
class Worker: class Worker:
@ -153,7 +203,11 @@ class Worker:
def create_process(self) -> WorkerProcess: def create_process(self) -> WorkerProcess:
process = WorkerProcess( process = WorkerProcess(
factory=self.context.Process, # Need to ignore this typing error - The problem is the
# BaseContext itself has no Process. But, all of its
# implementations do. We can safely ignore as it is a typing
# issue in the standard lib.
factory=self.context.Process, # type: ignore
name=f"{self.WORKER_PREFIX}{self.ident}-{len(self.processes)}", name=f"{self.WORKER_PREFIX}{self.ident}-{len(self.processes)}",
target=self.serve, target=self.serve,
kwargs={**self.server_settings}, kwargs={**self.server_settings},

View File

@ -321,7 +321,8 @@ def test_inspector_inspect(urlopen, caplog, capsys):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"command,params", "command,params",
( (
(["reload"], {}), (["reload"], {"zero_downtime": False}),
(["reload", "--zero-downtime"], {"zero_downtime": True}),
(["shutdown"], {}), (["shutdown"], {}),
(["scale", "9"], {"replicas": 9}), (["scale", "9"], {"replicas": 9}),
(["foo", "--bar=something"], {"bar": "something"}), (["foo", "--bar=something"], {"bar": "something"}),

View File

@ -88,6 +88,12 @@ def test_run_inspector_reload(publisher, http_client):
publisher.send.assert_called_once_with("__ALL_PROCESSES__:") publisher.send.assert_called_once_with("__ALL_PROCESSES__:")
def test_run_inspector_reload_zero_downtime(publisher, http_client):
_, response = http_client.post("/reload", json={"zero_downtime": True})
assert response.status == 200
publisher.send.assert_called_once_with("__ALL_PROCESSES__::STARTUP_FIRST")
def test_run_inspector_shutdown(publisher, http_client): def test_run_inspector_shutdown(publisher, http_client):
_, response = http_client.post("/shutdown") _, response = http_client.post("/shutdown")
assert response.status == 200 assert response.status == 200

View File

@ -6,6 +6,7 @@ import pytest
from sanic.compat import OS_IS_WINDOWS from sanic.compat import OS_IS_WINDOWS
from sanic.exceptions import ServerKilled from sanic.exceptions import ServerKilled
from sanic.worker.constants import RestartOrder
from sanic.worker.manager import WorkerManager from sanic.worker.manager import WorkerManager
@ -102,11 +103,17 @@ def test_restart_all():
) )
def test_monitor_all(): @pytest.mark.parametrize("zero_downtime", (False, True))
def test_monitor_all(zero_downtime):
p1 = Mock() p1 = Mock()
p2 = Mock() p2 = Mock()
sub = Mock() sub = Mock()
sub.recv.side_effect = ["__ALL_PROCESSES__:", ""] incoming = (
"__ALL_PROCESSES__::STARTUP_FIRST"
if zero_downtime
else "__ALL_PROCESSES__:"
)
sub.recv.side_effect = [incoming, ""]
context = Mock() context = Mock()
context.Process.side_effect = [p1, p2] context.Process.side_effect = [p1, p2]
manager = WorkerManager(2, fake_serve, {}, context, (Mock(), sub), {}) manager = WorkerManager(2, fake_serve, {}, context, (Mock(), sub), {})
@ -114,16 +121,29 @@ def test_monitor_all():
manager.wait_for_ack = Mock() # type: ignore manager.wait_for_ack = Mock() # type: ignore
manager.monitor() manager.monitor()
restart_order = (
RestartOrder.STARTUP_FIRST
if zero_downtime
else RestartOrder.SHUTDOWN_FIRST
)
manager.restart.assert_called_once_with( manager.restart.assert_called_once_with(
process_names=None, reloaded_files="" process_names=None,
reloaded_files="",
restart_order=restart_order,
) )
def test_monitor_all_with_files(): @pytest.mark.parametrize("zero_downtime", (False, True))
def test_monitor_all_with_files(zero_downtime):
p1 = Mock() p1 = Mock()
p2 = Mock() p2 = Mock()
sub = Mock() sub = Mock()
sub.recv.side_effect = ["__ALL_PROCESSES__:foo,bar", ""] incoming = (
"__ALL_PROCESSES__:foo,bar:STARTUP_FIRST"
if zero_downtime
else "__ALL_PROCESSES__:foo,bar"
)
sub.recv.side_effect = [incoming, ""]
context = Mock() context = Mock()
context.Process.side_effect = [p1, p2] context.Process.side_effect = [p1, p2]
manager = WorkerManager(2, fake_serve, {}, context, (Mock(), sub), {}) manager = WorkerManager(2, fake_serve, {}, context, (Mock(), sub), {})
@ -131,17 +151,30 @@ def test_monitor_all_with_files():
manager.wait_for_ack = Mock() # type: ignore manager.wait_for_ack = Mock() # type: ignore
manager.monitor() manager.monitor()
restart_order = (
RestartOrder.STARTUP_FIRST
if zero_downtime
else RestartOrder.SHUTDOWN_FIRST
)
manager.restart.assert_called_once_with( manager.restart.assert_called_once_with(
process_names=None, reloaded_files="foo,bar" process_names=None,
reloaded_files="foo,bar",
restart_order=restart_order,
) )
def test_monitor_one_process(): @pytest.mark.parametrize("zero_downtime", (False, True))
def test_monitor_one_process(zero_downtime):
p1 = Mock() p1 = Mock()
p1.name = "Testing" p1.name = "Testing"
p2 = Mock() p2 = Mock()
sub = Mock() sub = Mock()
sub.recv.side_effect = [f"{p1.name}:foo,bar", ""] incoming = (
f"{p1.name}:foo,bar:STARTUP_FIRST"
if zero_downtime
else f"{p1.name}:foo,bar"
)
sub.recv.side_effect = [incoming, ""]
context = Mock() context = Mock()
context.Process.side_effect = [p1, p2] context.Process.side_effect = [p1, p2]
manager = WorkerManager(2, fake_serve, {}, context, (Mock(), sub), {}) manager = WorkerManager(2, fake_serve, {}, context, (Mock(), sub), {})
@ -149,8 +182,15 @@ def test_monitor_one_process():
manager.wait_for_ack = Mock() # type: ignore manager.wait_for_ack = Mock() # type: ignore
manager.monitor() manager.monitor()
restart_order = (
RestartOrder.STARTUP_FIRST
if zero_downtime
else RestartOrder.SHUTDOWN_FIRST
)
manager.restart.assert_called_once_with( manager.restart.assert_called_once_with(
process_names=[p1.name], reloaded_files="foo,bar" process_names=[p1.name],
reloaded_files="foo,bar",
restart_order=restart_order,
) )

View File

@ -98,17 +98,17 @@ def test_ack(worker_state: Dict[str, Any], m: WorkerMultiplexer):
def test_restart_self(monitor_publisher: Mock, m: WorkerMultiplexer): def test_restart_self(monitor_publisher: Mock, m: WorkerMultiplexer):
m.restart() m.restart()
monitor_publisher.send.assert_called_once_with("Test") monitor_publisher.send.assert_called_once_with("Test:")
def test_restart_foo(monitor_publisher: Mock, m: WorkerMultiplexer): def test_restart_foo(monitor_publisher: Mock, m: WorkerMultiplexer):
m.restart("foo") m.restart("foo")
monitor_publisher.send.assert_called_once_with("foo") monitor_publisher.send.assert_called_once_with("foo:")
def test_reload_alias(monitor_publisher: Mock, m: WorkerMultiplexer): def test_reload_alias(monitor_publisher: Mock, m: WorkerMultiplexer):
m.reload() m.reload()
monitor_publisher.send.assert_called_once_with("Test") monitor_publisher.send.assert_called_once_with("Test:")
def test_terminate(monitor_publisher: Mock, m: WorkerMultiplexer): def test_terminate(monitor_publisher: Mock, m: WorkerMultiplexer):
@ -135,10 +135,20 @@ def test_properties(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"params,expected", "params,expected",
( (
({}, "Test"), ({}, "Test:"),
({"name": "foo"}, "foo"), ({"name": "foo"}, "foo:"),
({"all_workers": True}, "__ALL_PROCESSES__:"), ({"all_workers": True}, "__ALL_PROCESSES__:"),
({"zero_downtime": True}, "Test::STARTUP_FIRST"),
({"name": "foo", "all_workers": True}, ValueError), ({"name": "foo", "all_workers": True}, ValueError),
({"name": "foo", "zero_downtime": True}, "foo::STARTUP_FIRST"),
(
{"all_workers": True, "zero_downtime": True},
"__ALL_PROCESSES__::STARTUP_FIRST",
),
(
{"name": "foo", "all_workers": True, "zero_downtime": True},
ValueError,
),
), ),
) )
def test_restart_params( def test_restart_params(

View File

@ -1,13 +1,19 @@
import re
import signal import signal
import threading
from asyncio import Event from asyncio import Event
from logging import DEBUG
from pathlib import Path from pathlib import Path
from time import sleep
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from sanic.app import Sanic from sanic.app import Sanic
from sanic.worker.constants import ProcessState, RestartOrder
from sanic.worker.loader import AppLoader from sanic.worker.loader import AppLoader
from sanic.worker.process import WorkerProcess
from sanic.worker.reloader import Reloader from sanic.worker.reloader import Reloader
@ -67,6 +73,88 @@ def test_iter_files():
assert len_total_files == len_python_files + len_static_files assert len_total_files == len_python_files + len_static_files
@pytest.mark.parametrize(
"order,expected",
(
(
RestartOrder.SHUTDOWN_FIRST,
[
"Restarting a process",
"Begin restart termination",
"Starting a process",
],
),
(
RestartOrder.STARTUP_FIRST,
[
"Restarting a process",
"Starting a process",
"Begin restart termination",
"Waiting for process to be acked",
"Process acked. Terminating",
],
),
),
)
def test_default_reload_shutdown_order(monkeypatch, caplog, order, expected):
current_process = Mock()
worker_process = WorkerProcess(
lambda **_: current_process,
"Test",
lambda **_: ...,
{},
{},
)
def start(self):
worker_process.set_state(ProcessState.ACKED)
self._target()
orig = threading.Thread.start
monkeypatch.setattr(threading.Thread, "start", start)
with caplog.at_level(DEBUG):
worker_process.restart(restart_order=order)
ansi = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
def clean(msg: str):
msg, _ = ansi.sub("", msg).split(":", 1)
return msg
debug = [clean(record[2]) for record in caplog.record_tuples]
assert debug == expected
current_process.start.assert_called_once()
current_process.terminate.assert_called_once()
monkeypatch.setattr(threading.Thread, "start", orig)
def test_reload_delayed(monkeypatch):
WorkerProcess.THRESHOLD = 1
current_process = Mock()
worker_process = WorkerProcess(
lambda **_: current_process,
"Test",
lambda **_: ...,
{},
{},
)
def start(self):
sleep(0.2)
self._target()
orig = threading.Thread.start
monkeypatch.setattr(threading.Thread, "start", start)
message = "Worker Test failed to come ack within 0.1 seconds"
with pytest.raises(TimeoutError, match=message):
worker_process.restart(restart_order=RestartOrder.STARTUP_FIRST)
monkeypatch.setattr(threading.Thread, "start", orig)
def test_reloader_triggers_start_stop_listeners( def test_reloader_triggers_start_stop_listeners(
app: Sanic, app_loader: AppLoader app: Sanic, app_loader: AppLoader
): ):

View File

@ -118,7 +118,7 @@ def test_serve_with_inspector(
if config: if config:
Inspector.assert_called_once() Inspector.assert_called_once()
WorkerManager.manage.assert_called_once_with( WorkerManager.manage.assert_called_once_with(
"Inspector", inspector, {}, transient=True "Inspector", inspector, {}, transient=False
) )
else: else:
Inspector.assert_not_called() Inspector.assert_not_called()