From f7040ccec8ad6e90953c3364ca08c76348224fce Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Sun, 18 Dec 2022 14:09:17 +0200 Subject: [PATCH] Implement restart ordering (#2632) --- sanic/app.py | 1 + sanic/cli/base.py | 2 +- sanic/cli/inspector.py | 42 +++++++++++--- sanic/compat.py | 25 ++++++++ sanic/constants.py | 22 ++------ sanic/mixins/signals.py | 4 +- sanic/mixins/startup.py | 2 +- sanic/server/runners.py | 2 + sanic/worker/constants.py | 18 ++++++ sanic/worker/inspector.py | 4 +- sanic/worker/manager.py | 35 ++++++++++-- sanic/worker/multiplexer.py | 18 +++++- sanic/worker/process.py | 94 ++++++++++++++++++++++++------- tests/test_cli.py | 3 +- tests/worker/test_inspector.py | 6 ++ tests/worker/test_manager.py | 58 ++++++++++++++++--- tests/worker/test_multiplexer.py | 20 +++++-- tests/worker/test_reloader.py | 88 +++++++++++++++++++++++++++++ tests/worker/test_worker_serve.py | 2 +- 19 files changed, 375 insertions(+), 71 deletions(-) create mode 100644 sanic/worker/constants.py diff --git a/sanic/app.py b/sanic/app.py index c951009b..8d098663 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -1534,6 +1534,7 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta): self.state.is_started = True + def ack(self): if hasattr(self, "multiplexer"): self.multiplexer.ack() diff --git a/sanic/cli/base.py b/sanic/cli/base.py index 270c4118..c9335475 100644 --- a/sanic/cli/base.py +++ b/sanic/cli/base.py @@ -20,7 +20,7 @@ class SanicHelpFormatter(RawTextHelpFormatter): if not usage: usage = SUPPRESS # Add one linebreak, but not two - self.add_text("\x1b[1A'") + self.add_text("\x1b[1A") super().add_usage(usage, actions, groups, prefix) diff --git a/sanic/cli/inspector.py b/sanic/cli/inspector.py index 5a1719b4..8bb8e90b 100644 --- a/sanic/cli/inspector.py +++ b/sanic/cli/inspector.py @@ -5,10 +5,26 @@ from sanic.cli.base import SanicHelpFormatter, SanicSubParsersAction def _add_shared(parser: ArgumentParser) -> None: - parser.add_argument("--host", "-H", default="localhost") - parser.add_argument("--port", "-p", default=6457, type=int) - parser.add_argument("--secure", "-s", action="store_true") - parser.add_argument("--api-key", "-k") + parser.add_argument( + "--host", + "-H", + default="localhost", + help="Inspector host address [default 127.0.0.1]", + ) + parser.add_argument( + "--port", + "-p", + default=6457, + type=int, + help="Inspector port [default 6457]", + ) + parser.add_argument( + "--secure", + "-s", + action="store_true", + help="Whether to access the Inspector via TLS encryption", + ) + parser.add_argument("--api-key", "-k", help="Inspector authentication key") parser.add_argument( "--raw", action="store_true", @@ -32,17 +48,25 @@ def make_inspector_parser(parser: ArgumentParser) -> None: dest="action", description=( "Run one of the below subcommands. If you have created a custom " - "Inspector instance, then you can run custom commands.\nSee ___ " + "Inspector instance, then you can run custom commands. See ___ " "for more details." ), title="Required\n========\n Subcommands", parser_class=InspectorSubParser, ) - subparsers.add_parser( + reloader = subparsers.add_parser( "reload", help="Trigger a reload of the server workers", formatter_class=SanicHelpFormatter, ) + reloader.add_argument( + "--zero-downtime", + action="store_true", + help=( + "Whether to wait for the new process to be online before " + "terminating the old" + ), + ) subparsers.add_parser( "shutdown", help="Shutdown the application and all processes", @@ -53,7 +77,11 @@ def make_inspector_parser(parser: ArgumentParser) -> None: help="Scale the number of workers", formatter_class=SanicHelpFormatter, ) - scale.add_argument("replicas", type=int) + scale.add_argument( + "replicas", + type=int, + help="Number of workers requested", + ) custom = subparsers.add_parser( "", diff --git a/sanic/compat.py b/sanic/compat.py index 26b0bcde..35769876 100644 --- a/sanic/compat.py +++ b/sanic/compat.py @@ -4,6 +4,7 @@ import signal import sys from contextlib import contextmanager +from enum import Enum from typing import Awaitable, Union from multidict import CIMultiDict # type: ignore @@ -30,6 +31,30 @@ try: except ImportError: pass +# Python 3.11 changed the way Enum formatting works for mixed-in types. +if sys.version_info < (3, 11, 0): + + class StrEnum(str, Enum): + pass + +else: + from enum import StrEnum # type: ignore # noqa + + +class UpperStrEnum(StrEnum): + def _generate_next_value_(name, start, count, last_values): + return name.upper() + + def __eq__(self, value: object) -> bool: + value = str(value).upper() + return super().__eq__(value) + + def __hash__(self) -> int: + return hash(self.value) + + def __str__(self) -> str: + return self.value + @contextmanager def use_context(method: StartMethod): diff --git a/sanic/constants.py b/sanic/constants.py index 988d8bae..ab23c86d 100644 --- a/sanic/constants.py +++ b/sanic/constants.py @@ -1,19 +1,9 @@ -from enum import Enum, auto +from enum import auto + +from sanic.compat import UpperStrEnum -class HTTPMethod(str, Enum): - def _generate_next_value_(name, start, count, last_values): - return name.upper() - - def __eq__(self, value: object) -> bool: - value = str(value).upper() - return super().__eq__(value) - - def __hash__(self) -> int: - return hash(self.value) - - def __str__(self) -> str: - return self.value +class HTTPMethod(UpperStrEnum): GET = auto() POST = auto() @@ -24,9 +14,7 @@ class HTTPMethod(str, Enum): DELETE = auto() -class LocalCertCreator(str, Enum): - def _generate_next_value_(name, start, count, last_values): - return name.upper() +class LocalCertCreator(UpperStrEnum): AUTO = auto() TRUSTME = auto() diff --git a/sanic/mixins/signals.py b/sanic/mixins/signals.py index 601c4b18..d3d67ac6 100644 --- a/sanic/mixins/signals.py +++ b/sanic/mixins/signals.py @@ -20,7 +20,7 @@ class SignalMixin(metaclass=SanicMeta): event: Union[str, Enum], *, apply: bool = True, - condition: Dict[str, Any] = None, + condition: Optional[Dict[str, Any]] = None, exclusive: bool = True, ) -> Callable[[SignalHandler], SignalHandler]: """ @@ -64,7 +64,7 @@ class SignalMixin(metaclass=SanicMeta): self, handler: Optional[Callable[..., Any]], event: str, - condition: Dict[str, Any] = None, + condition: Optional[Dict[str, Any]] = None, exclusive: bool = True, ): if not handler: diff --git a/sanic/mixins/startup.py b/sanic/mixins/startup.py index 79517257..a6126433 100644 --- a/sanic/mixins/startup.py +++ b/sanic/mixins/startup.py @@ -851,7 +851,7 @@ class StartupMixin(metaclass=SanicMeta): primary.config.INSPECTOR_TLS_KEY, primary.config.INSPECTOR_TLS_CERT, ) - manager.manage("Inspector", inspector, {}, transient=True) + manager.manage("Inspector", inspector, {}, transient=False) primary._inspector = inspector primary._manager = manager diff --git a/sanic/server/runners.py b/sanic/server/runners.py index 92be4b31..73df3baf 100644 --- a/sanic/server/runners.py +++ b/sanic/server/runners.py @@ -229,6 +229,7 @@ def _serve_http_1( loop.run_until_complete(app._startup()) loop.run_until_complete(app._server_event("init", "before")) + app.ack() try: http_server = loop.run_until_complete(server_coroutine) @@ -306,6 +307,7 @@ def _serve_http_3( server = AsyncioServer(app, loop, coro, []) loop.run_until_complete(server.startup()) loop.run_until_complete(server.before_start()) + app.ack() loop.run_until_complete(server) _setup_system_signals(app, run_multiple, register_sys_signals, loop) loop.run_until_complete(server.after_start()) diff --git a/sanic/worker/constants.py b/sanic/worker/constants.py new file mode 100644 index 00000000..15997c29 --- /dev/null +++ b/sanic/worker/constants.py @@ -0,0 +1,18 @@ +from enum import IntEnum, auto + +from sanic.compat import UpperStrEnum + + +class RestartOrder(UpperStrEnum): + SHUTDOWN_FIRST = auto() + STARTUP_FIRST = auto() + + +class ProcessState(IntEnum): + IDLE = auto() + RESTARTING = auto() + STARTING = auto() + STARTED = auto() + ACKED = auto() + JOINED = auto() + TERMINATED = auto() diff --git a/sanic/worker/inspector.py b/sanic/worker/inspector.py index d60c5ae3..487ef3c5 100644 --- a/sanic/worker/inspector.py +++ b/sanic/worker/inspector.py @@ -101,8 +101,10 @@ class Inspector: obj[key] = value.isoformat() return obj - def reload(self) -> None: + def reload(self, zero_downtime: bool = False) -> None: message = "__ALL_PROCESSES__:" + if zero_downtime: + message += ":STARTUP_FIRST" self._publisher.send(message) def scale(self, replicas) -> str: diff --git a/sanic/worker/manager.py b/sanic/worker/manager.py index 2f09b818..e4deaf1e 100644 --- a/sanic/worker/manager.py +++ b/sanic/worker/manager.py @@ -10,6 +10,7 @@ from typing import Dict, List, Optional from sanic.compat import OS_IS_WINDOWS from sanic.exceptions import ServerKilled from sanic.log import error_logger, logger +from sanic.worker.constants import RestartOrder from sanic.worker.process import ProcessState, Worker, WorkerProcess @@ -20,7 +21,8 @@ else: class WorkerManager: - THRESHOLD = 300 # == 30 seconds + THRESHOLD = WorkerProcess.THRESHOLD + MAIN_IDENT = "Sanic-Main" def __init__( self, @@ -37,7 +39,7 @@ class WorkerManager: self.durable: Dict[str, Worker] = {} self.monitor_publisher, self.monitor_subscriber = monitor_pubsub self.worker_state = worker_state - self.worker_state["Sanic-Main"] = {"pid": self.pid} + self.worker_state[self.MAIN_IDENT] = {"pid": self.pid} self.terminated = False self._serve = serve self._server_settings = server_settings @@ -119,10 +121,15 @@ class WorkerManager: process.terminate() self.terminated = True - def restart(self, process_names: Optional[List[str]] = None, **kwargs): + def restart( + self, + process_names: Optional[List[str]] = None, + restart_order=RestartOrder.SHUTDOWN_FIRST, + **kwargs, + ): for process in self.transient_processes: if not process_names or process.name in process_names: - process.restart(**kwargs) + process.restart(restart_order=restart_order, **kwargs) def scale(self, num_worker: int): if num_worker <= 0: @@ -160,7 +167,12 @@ class WorkerManager: elif message == "__TERMINATE__": self.shutdown() break - split_message = message.split(":", 1) + logger.debug( + "Incoming monitor message: %s", + message, + extra={"verbosity": 1}, + ) + split_message = message.split(":", 2) if message.startswith("__SCALE__"): self.scale(int(split_message[-1])) continue @@ -173,10 +185,17 @@ class WorkerManager: ] if "__ALL_PROCESSES__" in process_names: process_names = None + order = ( + RestartOrder.STARTUP_FIRST + if "STARTUP_FIRST" in split_message + else RestartOrder.SHUTDOWN_FIRST + ) self.restart( process_names=process_names, reloaded_files=reloaded_files, + restart_order=order, ) + self._sync_states() except InterruptedError: if not OS_IS_WINDOWS: raise @@ -263,3 +282,9 @@ class WorkerManager: if worker_state.get("server") ] return all(acked) and len(acked) == self.num_server + + def _sync_states(self): + for process in self.processes: + state = self.worker_state[process.name].get("state") + if state and process.state.name != state: + process.set_state(ProcessState[state], True) diff --git a/sanic/worker/multiplexer.py b/sanic/worker/multiplexer.py index 0fb4085f..7bd956e4 100644 --- a/sanic/worker/multiplexer.py +++ b/sanic/worker/multiplexer.py @@ -2,6 +2,7 @@ from multiprocessing.connection import Connection from os import environ, getpid from typing import Any, Dict +from sanic.log import Colors, logger from sanic.worker.process import ProcessState from sanic.worker.state import WorkerState @@ -16,12 +17,23 @@ class WorkerMultiplexer: self._state = WorkerState(worker_state, self.name) def ack(self): + logger.debug( + f"{Colors.BLUE}Process ack: {Colors.BOLD}{Colors.SANIC}" + f"%s {Colors.BLUE}[%s]{Colors.END}", + self.name, + self.pid, + ) self._state._state[self.name] = { **self._state._state[self.name], "state": ProcessState.ACKED.name, } - def restart(self, name: str = "", all_workers: bool = False): + def restart( + self, + name: str = "", + all_workers: bool = False, + zero_downtime: bool = False, + ): if name and all_workers: raise ValueError( "Ambiguous restart with both a named process and" @@ -29,6 +41,10 @@ class WorkerMultiplexer: ) if not name: name = "__ALL_PROCESSES__:" if all_workers else self.name + if not name.endswith(":"): + name += ":" + if zero_downtime: + name += ":STARTUP_FIRST" self._monitor_publisher.send(name) reload = restart # no cov diff --git a/sanic/worker/process.py b/sanic/worker/process.py index 7301ca2c..58d28c1f 100644 --- a/sanic/worker/process.py +++ b/sanic/worker/process.py @@ -1,12 +1,14 @@ import os from datetime import datetime, timezone -from enum import IntEnum, auto from multiprocessing.context import BaseContext from signal import SIGINT +from threading import Thread +from time import sleep from typing import Any, Dict, Set from sanic.log import Colors, logger +from sanic.worker.constants import ProcessState, RestartOrder def get_now(): @@ -14,15 +16,8 @@ def get_now(): return now -class ProcessState(IntEnum): - IDLE = auto() - STARTED = auto() - ACKED = auto() - JOINED = auto() - TERMINATED = auto() - - class WorkerProcess: + THRESHOLD = 300 # == 30 seconds SERVER_LABEL = "Server" def __init__(self, factory, name, target, kwargs, worker_state): @@ -54,8 +49,9 @@ class WorkerProcess: f"{Colors.SANIC}%s{Colors.END}", self.name, ) + self.set_state(ProcessState.STARTING) + self._current_process.start() self.set_state(ProcessState.STARTED) - self._process.start() if not self.worker_state[self.name].get("starts"): self.worker_state[self.name] = { **self.worker_state[self.name], @@ -67,7 +63,7 @@ class WorkerProcess: def join(self): self.set_state(ProcessState.JOINED) - self._process.join() + self._current_process.join() def terminate(self): if self.state is not ProcessState.TERMINATED: @@ -80,21 +76,23 @@ class WorkerProcess: ) self.set_state(ProcessState.TERMINATED, force=True) try: - # self._process.terminate() os.kill(self.pid, SIGINT) del self.worker_state[self.name] except (KeyError, AttributeError, ProcessLookupError): ... - def restart(self, **kwargs): + def restart(self, restart_order=RestartOrder.SHUTDOWN_FIRST, **kwargs): logger.debug( f"{Colors.BLUE}Restarting a process: {Colors.BOLD}{Colors.SANIC}" f"%s {Colors.BLUE}[%s]{Colors.END}", self.name, self.pid, ) - self._process.terminate() - self.set_state(ProcessState.IDLE, force=True) + self.set_state(ProcessState.RESTARTING, force=True) + if restart_order is RestartOrder.SHUTDOWN_FIRST: + self._terminate_now() + else: + self._old_process = self._current_process self.kwargs.update( {"config": {k.upper(): v for k, v in kwargs.items()}} ) @@ -104,6 +102,9 @@ class WorkerProcess: except AttributeError: raise RuntimeError("Restart failed") + if restart_order is RestartOrder.STARTUP_FIRST: + self._terminate_soon() + self.worker_state[self.name] = { **self.worker_state[self.name], "pid": self.pid, @@ -113,14 +114,14 @@ class WorkerProcess: def is_alive(self): try: - return self._process.is_alive() + return self._current_process.is_alive() except AssertionError: return False def spawn(self): - if self.state is not ProcessState.IDLE: + if self.state not in (ProcessState.IDLE, ProcessState.RESTARTING): raise Exception("Cannot spawn a worker process until it is idle.") - self._process = self.factory( + self._current_process = self.factory( name=self.name, target=self.target, kwargs=self.kwargs, @@ -129,7 +130,56 @@ class WorkerProcess: @property def pid(self): - return self._process.pid + return self._current_process.pid + + def _terminate_now(self): + logger.debug( + f"{Colors.BLUE}Begin restart termination: " + f"{Colors.BOLD}{Colors.SANIC}" + f"%s {Colors.BLUE}[%s]{Colors.END}", + self.name, + self._current_process.pid, + ) + self._current_process.terminate() + + def _terminate_soon(self): + logger.debug( + f"{Colors.BLUE}Begin restart termination: " + f"{Colors.BOLD}{Colors.SANIC}" + f"%s {Colors.BLUE}[%s]{Colors.END}", + self.name, + self._current_process.pid, + ) + termination_thread = Thread(target=self._wait_to_terminate) + termination_thread.start() + + def _wait_to_terminate(self): + logger.debug( + f"{Colors.BLUE}Waiting for process to be acked: " + f"{Colors.BOLD}{Colors.SANIC}" + f"%s {Colors.BLUE}[%s]{Colors.END}", + self.name, + self._old_process.pid, + ) + misses = 0 + while self.state is not ProcessState.ACKED: + sleep(0.1) + misses += 1 + if misses > self.THRESHOLD: + raise TimeoutError( + f"Worker {self.name} failed to come ack within " + f"{self.THRESHOLD / 10} seconds" + ) + else: + logger.debug( + f"{Colors.BLUE}Process acked. Terminating: " + f"{Colors.BOLD}{Colors.SANIC}" + f"%s {Colors.BLUE}[%s]{Colors.END}", + self.name, + self._old_process.pid, + ) + self._old_process.terminate() + delattr(self, "_old_process") class Worker: @@ -153,7 +203,11 @@ class Worker: def create_process(self) -> WorkerProcess: process = WorkerProcess( - factory=self.context.Process, + # Need to ignore this typing error - The problem is the + # BaseContext itself has no Process. But, all of its + # implementations do. We can safely ignore as it is a typing + # issue in the standard lib. + factory=self.context.Process, # type: ignore name=f"{self.WORKER_PREFIX}{self.ident}-{len(self.processes)}", target=self.serve, kwargs={**self.server_settings}, diff --git a/tests/test_cli.py b/tests/test_cli.py index c8e0c79b..eff61a18 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -321,7 +321,8 @@ def test_inspector_inspect(urlopen, caplog, capsys): @pytest.mark.parametrize( "command,params", ( - (["reload"], {}), + (["reload"], {"zero_downtime": False}), + (["reload", "--zero-downtime"], {"zero_downtime": True}), (["shutdown"], {}), (["scale", "9"], {"replicas": 9}), (["foo", "--bar=something"], {"bar": "something"}), diff --git a/tests/worker/test_inspector.py b/tests/worker/test_inspector.py index bc9dd2d5..b1773203 100644 --- a/tests/worker/test_inspector.py +++ b/tests/worker/test_inspector.py @@ -88,6 +88,12 @@ def test_run_inspector_reload(publisher, http_client): publisher.send.assert_called_once_with("__ALL_PROCESSES__:") +def test_run_inspector_reload_zero_downtime(publisher, http_client): + _, response = http_client.post("/reload", json={"zero_downtime": True}) + assert response.status == 200 + publisher.send.assert_called_once_with("__ALL_PROCESSES__::STARTUP_FIRST") + + def test_run_inspector_shutdown(publisher, http_client): _, response = http_client.post("/shutdown") assert response.status == 200 diff --git a/tests/worker/test_manager.py b/tests/worker/test_manager.py index 85f673dd..6940c83f 100644 --- a/tests/worker/test_manager.py +++ b/tests/worker/test_manager.py @@ -6,6 +6,7 @@ import pytest from sanic.compat import OS_IS_WINDOWS from sanic.exceptions import ServerKilled +from sanic.worker.constants import RestartOrder from sanic.worker.manager import WorkerManager @@ -102,11 +103,17 @@ def test_restart_all(): ) -def test_monitor_all(): +@pytest.mark.parametrize("zero_downtime", (False, True)) +def test_monitor_all(zero_downtime): p1 = Mock() p2 = Mock() sub = Mock() - sub.recv.side_effect = ["__ALL_PROCESSES__:", ""] + incoming = ( + "__ALL_PROCESSES__::STARTUP_FIRST" + if zero_downtime + else "__ALL_PROCESSES__:" + ) + sub.recv.side_effect = [incoming, ""] context = Mock() context.Process.side_effect = [p1, p2] manager = WorkerManager(2, fake_serve, {}, context, (Mock(), sub), {}) @@ -114,16 +121,29 @@ def test_monitor_all(): manager.wait_for_ack = Mock() # type: ignore manager.monitor() + restart_order = ( + RestartOrder.STARTUP_FIRST + if zero_downtime + else RestartOrder.SHUTDOWN_FIRST + ) manager.restart.assert_called_once_with( - process_names=None, reloaded_files="" + process_names=None, + reloaded_files="", + restart_order=restart_order, ) -def test_monitor_all_with_files(): +@pytest.mark.parametrize("zero_downtime", (False, True)) +def test_monitor_all_with_files(zero_downtime): p1 = Mock() p2 = Mock() sub = Mock() - sub.recv.side_effect = ["__ALL_PROCESSES__:foo,bar", ""] + incoming = ( + "__ALL_PROCESSES__:foo,bar:STARTUP_FIRST" + if zero_downtime + else "__ALL_PROCESSES__:foo,bar" + ) + sub.recv.side_effect = [incoming, ""] context = Mock() context.Process.side_effect = [p1, p2] manager = WorkerManager(2, fake_serve, {}, context, (Mock(), sub), {}) @@ -131,17 +151,30 @@ def test_monitor_all_with_files(): manager.wait_for_ack = Mock() # type: ignore manager.monitor() + restart_order = ( + RestartOrder.STARTUP_FIRST + if zero_downtime + else RestartOrder.SHUTDOWN_FIRST + ) manager.restart.assert_called_once_with( - process_names=None, reloaded_files="foo,bar" + process_names=None, + reloaded_files="foo,bar", + restart_order=restart_order, ) -def test_monitor_one_process(): +@pytest.mark.parametrize("zero_downtime", (False, True)) +def test_monitor_one_process(zero_downtime): p1 = Mock() p1.name = "Testing" p2 = Mock() sub = Mock() - sub.recv.side_effect = [f"{p1.name}:foo,bar", ""] + incoming = ( + f"{p1.name}:foo,bar:STARTUP_FIRST" + if zero_downtime + else f"{p1.name}:foo,bar" + ) + sub.recv.side_effect = [incoming, ""] context = Mock() context.Process.side_effect = [p1, p2] manager = WorkerManager(2, fake_serve, {}, context, (Mock(), sub), {}) @@ -149,8 +182,15 @@ def test_monitor_one_process(): manager.wait_for_ack = Mock() # type: ignore manager.monitor() + restart_order = ( + RestartOrder.STARTUP_FIRST + if zero_downtime + else RestartOrder.SHUTDOWN_FIRST + ) manager.restart.assert_called_once_with( - process_names=[p1.name], reloaded_files="foo,bar" + process_names=[p1.name], + reloaded_files="foo,bar", + restart_order=restart_order, ) diff --git a/tests/worker/test_multiplexer.py b/tests/worker/test_multiplexer.py index 075a677c..88072cb7 100644 --- a/tests/worker/test_multiplexer.py +++ b/tests/worker/test_multiplexer.py @@ -98,17 +98,17 @@ def test_ack(worker_state: Dict[str, Any], m: WorkerMultiplexer): def test_restart_self(monitor_publisher: Mock, m: WorkerMultiplexer): m.restart() - monitor_publisher.send.assert_called_once_with("Test") + monitor_publisher.send.assert_called_once_with("Test:") def test_restart_foo(monitor_publisher: Mock, m: WorkerMultiplexer): m.restart("foo") - monitor_publisher.send.assert_called_once_with("foo") + monitor_publisher.send.assert_called_once_with("foo:") def test_reload_alias(monitor_publisher: Mock, m: WorkerMultiplexer): m.reload() - monitor_publisher.send.assert_called_once_with("Test") + monitor_publisher.send.assert_called_once_with("Test:") def test_terminate(monitor_publisher: Mock, m: WorkerMultiplexer): @@ -135,10 +135,20 @@ def test_properties( @pytest.mark.parametrize( "params,expected", ( - ({}, "Test"), - ({"name": "foo"}, "foo"), + ({}, "Test:"), + ({"name": "foo"}, "foo:"), ({"all_workers": True}, "__ALL_PROCESSES__:"), + ({"zero_downtime": True}, "Test::STARTUP_FIRST"), ({"name": "foo", "all_workers": True}, ValueError), + ({"name": "foo", "zero_downtime": True}, "foo::STARTUP_FIRST"), + ( + {"all_workers": True, "zero_downtime": True}, + "__ALL_PROCESSES__::STARTUP_FIRST", + ), + ( + {"name": "foo", "all_workers": True, "zero_downtime": True}, + ValueError, + ), ), ) def test_restart_params( diff --git a/tests/worker/test_reloader.py b/tests/worker/test_reloader.py index 38daea74..6530c770 100644 --- a/tests/worker/test_reloader.py +++ b/tests/worker/test_reloader.py @@ -1,13 +1,19 @@ +import re import signal +import threading from asyncio import Event +from logging import DEBUG from pathlib import Path +from time import sleep from unittest.mock import Mock import pytest from sanic.app import Sanic +from sanic.worker.constants import ProcessState, RestartOrder from sanic.worker.loader import AppLoader +from sanic.worker.process import WorkerProcess from sanic.worker.reloader import Reloader @@ -67,6 +73,88 @@ def test_iter_files(): assert len_total_files == len_python_files + len_static_files +@pytest.mark.parametrize( + "order,expected", + ( + ( + RestartOrder.SHUTDOWN_FIRST, + [ + "Restarting a process", + "Begin restart termination", + "Starting a process", + ], + ), + ( + RestartOrder.STARTUP_FIRST, + [ + "Restarting a process", + "Starting a process", + "Begin restart termination", + "Waiting for process to be acked", + "Process acked. Terminating", + ], + ), + ), +) +def test_default_reload_shutdown_order(monkeypatch, caplog, order, expected): + current_process = Mock() + worker_process = WorkerProcess( + lambda **_: current_process, + "Test", + lambda **_: ..., + {}, + {}, + ) + + def start(self): + worker_process.set_state(ProcessState.ACKED) + self._target() + + orig = threading.Thread.start + monkeypatch.setattr(threading.Thread, "start", start) + + with caplog.at_level(DEBUG): + worker_process.restart(restart_order=order) + + ansi = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") + + def clean(msg: str): + msg, _ = ansi.sub("", msg).split(":", 1) + return msg + + debug = [clean(record[2]) for record in caplog.record_tuples] + assert debug == expected + current_process.start.assert_called_once() + current_process.terminate.assert_called_once() + monkeypatch.setattr(threading.Thread, "start", orig) + + +def test_reload_delayed(monkeypatch): + WorkerProcess.THRESHOLD = 1 + + current_process = Mock() + worker_process = WorkerProcess( + lambda **_: current_process, + "Test", + lambda **_: ..., + {}, + {}, + ) + + def start(self): + sleep(0.2) + self._target() + + orig = threading.Thread.start + monkeypatch.setattr(threading.Thread, "start", start) + + message = "Worker Test failed to come ack within 0.1 seconds" + with pytest.raises(TimeoutError, match=message): + worker_process.restart(restart_order=RestartOrder.STARTUP_FIRST) + + monkeypatch.setattr(threading.Thread, "start", orig) + + def test_reloader_triggers_start_stop_listeners( app: Sanic, app_loader: AppLoader ): diff --git a/tests/worker/test_worker_serve.py b/tests/worker/test_worker_serve.py index a33e3cac..54ff9d65 100644 --- a/tests/worker/test_worker_serve.py +++ b/tests/worker/test_worker_serve.py @@ -118,7 +118,7 @@ def test_serve_with_inspector( if config: Inspector.assert_called_once() WorkerManager.manage.assert_called_once_with( - "Inspector", inspector, {}, transient=True + "Inspector", inspector, {}, transient=False ) else: Inspector.assert_not_called()