Implement restart ordering (#2632)
This commit is contained in:
parent
518152d97e
commit
f7040ccec8
|
@ -1534,6 +1534,7 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta):
|
|||
|
||||
self.state.is_started = True
|
||||
|
||||
def ack(self):
|
||||
if hasattr(self, "multiplexer"):
|
||||
self.multiplexer.ack()
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ class SanicHelpFormatter(RawTextHelpFormatter):
|
|||
if not usage:
|
||||
usage = SUPPRESS
|
||||
# Add one linebreak, but not two
|
||||
self.add_text("\x1b[1A'")
|
||||
self.add_text("\x1b[1A")
|
||||
super().add_usage(usage, actions, groups, prefix)
|
||||
|
||||
|
||||
|
|
|
@ -5,10 +5,26 @@ from sanic.cli.base import SanicHelpFormatter, SanicSubParsersAction
|
|||
|
||||
|
||||
def _add_shared(parser: ArgumentParser) -> None:
|
||||
parser.add_argument("--host", "-H", default="localhost")
|
||||
parser.add_argument("--port", "-p", default=6457, type=int)
|
||||
parser.add_argument("--secure", "-s", action="store_true")
|
||||
parser.add_argument("--api-key", "-k")
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
"-H",
|
||||
default="localhost",
|
||||
help="Inspector host address [default 127.0.0.1]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
"-p",
|
||||
default=6457,
|
||||
type=int,
|
||||
help="Inspector port [default 6457]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--secure",
|
||||
"-s",
|
||||
action="store_true",
|
||||
help="Whether to access the Inspector via TLS encryption",
|
||||
)
|
||||
parser.add_argument("--api-key", "-k", help="Inspector authentication key")
|
||||
parser.add_argument(
|
||||
"--raw",
|
||||
action="store_true",
|
||||
|
@ -32,17 +48,25 @@ def make_inspector_parser(parser: ArgumentParser) -> None:
|
|||
dest="action",
|
||||
description=(
|
||||
"Run one of the below subcommands. If you have created a custom "
|
||||
"Inspector instance, then you can run custom commands.\nSee ___ "
|
||||
"Inspector instance, then you can run custom commands. See ___ "
|
||||
"for more details."
|
||||
),
|
||||
title="Required\n========\n Subcommands",
|
||||
parser_class=InspectorSubParser,
|
||||
)
|
||||
subparsers.add_parser(
|
||||
reloader = subparsers.add_parser(
|
||||
"reload",
|
||||
help="Trigger a reload of the server workers",
|
||||
formatter_class=SanicHelpFormatter,
|
||||
)
|
||||
reloader.add_argument(
|
||||
"--zero-downtime",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Whether to wait for the new process to be online before "
|
||||
"terminating the old"
|
||||
),
|
||||
)
|
||||
subparsers.add_parser(
|
||||
"shutdown",
|
||||
help="Shutdown the application and all processes",
|
||||
|
@ -53,7 +77,11 @@ def make_inspector_parser(parser: ArgumentParser) -> None:
|
|||
help="Scale the number of workers",
|
||||
formatter_class=SanicHelpFormatter,
|
||||
)
|
||||
scale.add_argument("replicas", type=int)
|
||||
scale.add_argument(
|
||||
"replicas",
|
||||
type=int,
|
||||
help="Number of workers requested",
|
||||
)
|
||||
|
||||
custom = subparsers.add_parser(
|
||||
"<custom>",
|
||||
|
|
|
@ -4,6 +4,7 @@ import signal
|
|||
import sys
|
||||
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from typing import Awaitable, Union
|
||||
|
||||
from multidict import CIMultiDict # type: ignore
|
||||
|
@ -30,6 +31,30 @@ try:
|
|||
except ImportError:
|
||||
pass
|
||||
|
||||
# Python 3.11 changed the way Enum formatting works for mixed-in types.
|
||||
if sys.version_info < (3, 11, 0):
|
||||
|
||||
class StrEnum(str, Enum):
|
||||
pass
|
||||
|
||||
else:
|
||||
from enum import StrEnum # type: ignore # noqa
|
||||
|
||||
|
||||
class UpperStrEnum(StrEnum):
|
||||
def _generate_next_value_(name, start, count, last_values):
|
||||
return name.upper()
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
value = str(value).upper()
|
||||
return super().__eq__(value)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.value)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_context(method: StartMethod):
|
||||
|
|
|
@ -1,19 +1,9 @@
|
|||
from enum import Enum, auto
|
||||
from enum import auto
|
||||
|
||||
from sanic.compat import UpperStrEnum
|
||||
|
||||
|
||||
class HTTPMethod(str, Enum):
|
||||
def _generate_next_value_(name, start, count, last_values):
|
||||
return name.upper()
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
value = str(value).upper()
|
||||
return super().__eq__(value)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.value)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
class HTTPMethod(UpperStrEnum):
|
||||
|
||||
GET = auto()
|
||||
POST = auto()
|
||||
|
@ -24,9 +14,7 @@ class HTTPMethod(str, Enum):
|
|||
DELETE = auto()
|
||||
|
||||
|
||||
class LocalCertCreator(str, Enum):
|
||||
def _generate_next_value_(name, start, count, last_values):
|
||||
return name.upper()
|
||||
class LocalCertCreator(UpperStrEnum):
|
||||
|
||||
AUTO = auto()
|
||||
TRUSTME = auto()
|
||||
|
|
|
@ -20,7 +20,7 @@ class SignalMixin(metaclass=SanicMeta):
|
|||
event: Union[str, Enum],
|
||||
*,
|
||||
apply: bool = True,
|
||||
condition: Dict[str, Any] = None,
|
||||
condition: Optional[Dict[str, Any]] = None,
|
||||
exclusive: bool = True,
|
||||
) -> Callable[[SignalHandler], SignalHandler]:
|
||||
"""
|
||||
|
@ -64,7 +64,7 @@ class SignalMixin(metaclass=SanicMeta):
|
|||
self,
|
||||
handler: Optional[Callable[..., Any]],
|
||||
event: str,
|
||||
condition: Dict[str, Any] = None,
|
||||
condition: Optional[Dict[str, Any]] = None,
|
||||
exclusive: bool = True,
|
||||
):
|
||||
if not handler:
|
||||
|
|
|
@ -851,7 +851,7 @@ class StartupMixin(metaclass=SanicMeta):
|
|||
primary.config.INSPECTOR_TLS_KEY,
|
||||
primary.config.INSPECTOR_TLS_CERT,
|
||||
)
|
||||
manager.manage("Inspector", inspector, {}, transient=True)
|
||||
manager.manage("Inspector", inspector, {}, transient=False)
|
||||
|
||||
primary._inspector = inspector
|
||||
primary._manager = manager
|
||||
|
|
|
@ -229,6 +229,7 @@ def _serve_http_1(
|
|||
|
||||
loop.run_until_complete(app._startup())
|
||||
loop.run_until_complete(app._server_event("init", "before"))
|
||||
app.ack()
|
||||
|
||||
try:
|
||||
http_server = loop.run_until_complete(server_coroutine)
|
||||
|
@ -306,6 +307,7 @@ def _serve_http_3(
|
|||
server = AsyncioServer(app, loop, coro, [])
|
||||
loop.run_until_complete(server.startup())
|
||||
loop.run_until_complete(server.before_start())
|
||||
app.ack()
|
||||
loop.run_until_complete(server)
|
||||
_setup_system_signals(app, run_multiple, register_sys_signals, loop)
|
||||
loop.run_until_complete(server.after_start())
|
||||
|
|
18
sanic/worker/constants.py
Normal file
18
sanic/worker/constants.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
from enum import IntEnum, auto
|
||||
|
||||
from sanic.compat import UpperStrEnum
|
||||
|
||||
|
||||
class RestartOrder(UpperStrEnum):
|
||||
SHUTDOWN_FIRST = auto()
|
||||
STARTUP_FIRST = auto()
|
||||
|
||||
|
||||
class ProcessState(IntEnum):
|
||||
IDLE = auto()
|
||||
RESTARTING = auto()
|
||||
STARTING = auto()
|
||||
STARTED = auto()
|
||||
ACKED = auto()
|
||||
JOINED = auto()
|
||||
TERMINATED = auto()
|
|
@ -101,8 +101,10 @@ class Inspector:
|
|||
obj[key] = value.isoformat()
|
||||
return obj
|
||||
|
||||
def reload(self) -> None:
|
||||
def reload(self, zero_downtime: bool = False) -> None:
|
||||
message = "__ALL_PROCESSES__:"
|
||||
if zero_downtime:
|
||||
message += ":STARTUP_FIRST"
|
||||
self._publisher.send(message)
|
||||
|
||||
def scale(self, replicas) -> str:
|
||||
|
|
|
@ -10,6 +10,7 @@ from typing import Dict, List, Optional
|
|||
from sanic.compat import OS_IS_WINDOWS
|
||||
from sanic.exceptions import ServerKilled
|
||||
from sanic.log import error_logger, logger
|
||||
from sanic.worker.constants import RestartOrder
|
||||
from sanic.worker.process import ProcessState, Worker, WorkerProcess
|
||||
|
||||
|
||||
|
@ -20,7 +21,8 @@ else:
|
|||
|
||||
|
||||
class WorkerManager:
|
||||
THRESHOLD = 300 # == 30 seconds
|
||||
THRESHOLD = WorkerProcess.THRESHOLD
|
||||
MAIN_IDENT = "Sanic-Main"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -37,7 +39,7 @@ class WorkerManager:
|
|||
self.durable: Dict[str, Worker] = {}
|
||||
self.monitor_publisher, self.monitor_subscriber = monitor_pubsub
|
||||
self.worker_state = worker_state
|
||||
self.worker_state["Sanic-Main"] = {"pid": self.pid}
|
||||
self.worker_state[self.MAIN_IDENT] = {"pid": self.pid}
|
||||
self.terminated = False
|
||||
self._serve = serve
|
||||
self._server_settings = server_settings
|
||||
|
@ -119,10 +121,15 @@ class WorkerManager:
|
|||
process.terminate()
|
||||
self.terminated = True
|
||||
|
||||
def restart(self, process_names: Optional[List[str]] = None, **kwargs):
|
||||
def restart(
|
||||
self,
|
||||
process_names: Optional[List[str]] = None,
|
||||
restart_order=RestartOrder.SHUTDOWN_FIRST,
|
||||
**kwargs,
|
||||
):
|
||||
for process in self.transient_processes:
|
||||
if not process_names or process.name in process_names:
|
||||
process.restart(**kwargs)
|
||||
process.restart(restart_order=restart_order, **kwargs)
|
||||
|
||||
def scale(self, num_worker: int):
|
||||
if num_worker <= 0:
|
||||
|
@ -160,7 +167,12 @@ class WorkerManager:
|
|||
elif message == "__TERMINATE__":
|
||||
self.shutdown()
|
||||
break
|
||||
split_message = message.split(":", 1)
|
||||
logger.debug(
|
||||
"Incoming monitor message: %s",
|
||||
message,
|
||||
extra={"verbosity": 1},
|
||||
)
|
||||
split_message = message.split(":", 2)
|
||||
if message.startswith("__SCALE__"):
|
||||
self.scale(int(split_message[-1]))
|
||||
continue
|
||||
|
@ -173,10 +185,17 @@ class WorkerManager:
|
|||
]
|
||||
if "__ALL_PROCESSES__" in process_names:
|
||||
process_names = None
|
||||
order = (
|
||||
RestartOrder.STARTUP_FIRST
|
||||
if "STARTUP_FIRST" in split_message
|
||||
else RestartOrder.SHUTDOWN_FIRST
|
||||
)
|
||||
self.restart(
|
||||
process_names=process_names,
|
||||
reloaded_files=reloaded_files,
|
||||
restart_order=order,
|
||||
)
|
||||
self._sync_states()
|
||||
except InterruptedError:
|
||||
if not OS_IS_WINDOWS:
|
||||
raise
|
||||
|
@ -263,3 +282,9 @@ class WorkerManager:
|
|||
if worker_state.get("server")
|
||||
]
|
||||
return all(acked) and len(acked) == self.num_server
|
||||
|
||||
def _sync_states(self):
|
||||
for process in self.processes:
|
||||
state = self.worker_state[process.name].get("state")
|
||||
if state and process.state.name != state:
|
||||
process.set_state(ProcessState[state], True)
|
||||
|
|
|
@ -2,6 +2,7 @@ from multiprocessing.connection import Connection
|
|||
from os import environ, getpid
|
||||
from typing import Any, Dict
|
||||
|
||||
from sanic.log import Colors, logger
|
||||
from sanic.worker.process import ProcessState
|
||||
from sanic.worker.state import WorkerState
|
||||
|
||||
|
@ -16,12 +17,23 @@ class WorkerMultiplexer:
|
|||
self._state = WorkerState(worker_state, self.name)
|
||||
|
||||
def ack(self):
|
||||
logger.debug(
|
||||
f"{Colors.BLUE}Process ack: {Colors.BOLD}{Colors.SANIC}"
|
||||
f"%s {Colors.BLUE}[%s]{Colors.END}",
|
||||
self.name,
|
||||
self.pid,
|
||||
)
|
||||
self._state._state[self.name] = {
|
||||
**self._state._state[self.name],
|
||||
"state": ProcessState.ACKED.name,
|
||||
}
|
||||
|
||||
def restart(self, name: str = "", all_workers: bool = False):
|
||||
def restart(
|
||||
self,
|
||||
name: str = "",
|
||||
all_workers: bool = False,
|
||||
zero_downtime: bool = False,
|
||||
):
|
||||
if name and all_workers:
|
||||
raise ValueError(
|
||||
"Ambiguous restart with both a named process and"
|
||||
|
@ -29,6 +41,10 @@ class WorkerMultiplexer:
|
|||
)
|
||||
if not name:
|
||||
name = "__ALL_PROCESSES__:" if all_workers else self.name
|
||||
if not name.endswith(":"):
|
||||
name += ":"
|
||||
if zero_downtime:
|
||||
name += ":STARTUP_FIRST"
|
||||
self._monitor_publisher.send(name)
|
||||
|
||||
reload = restart # no cov
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
import os
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from enum import IntEnum, auto
|
||||
from multiprocessing.context import BaseContext
|
||||
from signal import SIGINT
|
||||
from threading import Thread
|
||||
from time import sleep
|
||||
from typing import Any, Dict, Set
|
||||
|
||||
from sanic.log import Colors, logger
|
||||
from sanic.worker.constants import ProcessState, RestartOrder
|
||||
|
||||
|
||||
def get_now():
|
||||
|
@ -14,15 +16,8 @@ def get_now():
|
|||
return now
|
||||
|
||||
|
||||
class ProcessState(IntEnum):
|
||||
IDLE = auto()
|
||||
STARTED = auto()
|
||||
ACKED = auto()
|
||||
JOINED = auto()
|
||||
TERMINATED = auto()
|
||||
|
||||
|
||||
class WorkerProcess:
|
||||
THRESHOLD = 300 # == 30 seconds
|
||||
SERVER_LABEL = "Server"
|
||||
|
||||
def __init__(self, factory, name, target, kwargs, worker_state):
|
||||
|
@ -54,8 +49,9 @@ class WorkerProcess:
|
|||
f"{Colors.SANIC}%s{Colors.END}",
|
||||
self.name,
|
||||
)
|
||||
self.set_state(ProcessState.STARTING)
|
||||
self._current_process.start()
|
||||
self.set_state(ProcessState.STARTED)
|
||||
self._process.start()
|
||||
if not self.worker_state[self.name].get("starts"):
|
||||
self.worker_state[self.name] = {
|
||||
**self.worker_state[self.name],
|
||||
|
@ -67,7 +63,7 @@ class WorkerProcess:
|
|||
|
||||
def join(self):
|
||||
self.set_state(ProcessState.JOINED)
|
||||
self._process.join()
|
||||
self._current_process.join()
|
||||
|
||||
def terminate(self):
|
||||
if self.state is not ProcessState.TERMINATED:
|
||||
|
@ -80,21 +76,23 @@ class WorkerProcess:
|
|||
)
|
||||
self.set_state(ProcessState.TERMINATED, force=True)
|
||||
try:
|
||||
# self._process.terminate()
|
||||
os.kill(self.pid, SIGINT)
|
||||
del self.worker_state[self.name]
|
||||
except (KeyError, AttributeError, ProcessLookupError):
|
||||
...
|
||||
|
||||
def restart(self, **kwargs):
|
||||
def restart(self, restart_order=RestartOrder.SHUTDOWN_FIRST, **kwargs):
|
||||
logger.debug(
|
||||
f"{Colors.BLUE}Restarting a process: {Colors.BOLD}{Colors.SANIC}"
|
||||
f"%s {Colors.BLUE}[%s]{Colors.END}",
|
||||
self.name,
|
||||
self.pid,
|
||||
)
|
||||
self._process.terminate()
|
||||
self.set_state(ProcessState.IDLE, force=True)
|
||||
self.set_state(ProcessState.RESTARTING, force=True)
|
||||
if restart_order is RestartOrder.SHUTDOWN_FIRST:
|
||||
self._terminate_now()
|
||||
else:
|
||||
self._old_process = self._current_process
|
||||
self.kwargs.update(
|
||||
{"config": {k.upper(): v for k, v in kwargs.items()}}
|
||||
)
|
||||
|
@ -104,6 +102,9 @@ class WorkerProcess:
|
|||
except AttributeError:
|
||||
raise RuntimeError("Restart failed")
|
||||
|
||||
if restart_order is RestartOrder.STARTUP_FIRST:
|
||||
self._terminate_soon()
|
||||
|
||||
self.worker_state[self.name] = {
|
||||
**self.worker_state[self.name],
|
||||
"pid": self.pid,
|
||||
|
@ -113,14 +114,14 @@ class WorkerProcess:
|
|||
|
||||
def is_alive(self):
|
||||
try:
|
||||
return self._process.is_alive()
|
||||
return self._current_process.is_alive()
|
||||
except AssertionError:
|
||||
return False
|
||||
|
||||
def spawn(self):
|
||||
if self.state is not ProcessState.IDLE:
|
||||
if self.state not in (ProcessState.IDLE, ProcessState.RESTARTING):
|
||||
raise Exception("Cannot spawn a worker process until it is idle.")
|
||||
self._process = self.factory(
|
||||
self._current_process = self.factory(
|
||||
name=self.name,
|
||||
target=self.target,
|
||||
kwargs=self.kwargs,
|
||||
|
@ -129,7 +130,56 @@ class WorkerProcess:
|
|||
|
||||
@property
|
||||
def pid(self):
|
||||
return self._process.pid
|
||||
return self._current_process.pid
|
||||
|
||||
def _terminate_now(self):
|
||||
logger.debug(
|
||||
f"{Colors.BLUE}Begin restart termination: "
|
||||
f"{Colors.BOLD}{Colors.SANIC}"
|
||||
f"%s {Colors.BLUE}[%s]{Colors.END}",
|
||||
self.name,
|
||||
self._current_process.pid,
|
||||
)
|
||||
self._current_process.terminate()
|
||||
|
||||
def _terminate_soon(self):
|
||||
logger.debug(
|
||||
f"{Colors.BLUE}Begin restart termination: "
|
||||
f"{Colors.BOLD}{Colors.SANIC}"
|
||||
f"%s {Colors.BLUE}[%s]{Colors.END}",
|
||||
self.name,
|
||||
self._current_process.pid,
|
||||
)
|
||||
termination_thread = Thread(target=self._wait_to_terminate)
|
||||
termination_thread.start()
|
||||
|
||||
def _wait_to_terminate(self):
|
||||
logger.debug(
|
||||
f"{Colors.BLUE}Waiting for process to be acked: "
|
||||
f"{Colors.BOLD}{Colors.SANIC}"
|
||||
f"%s {Colors.BLUE}[%s]{Colors.END}",
|
||||
self.name,
|
||||
self._old_process.pid,
|
||||
)
|
||||
misses = 0
|
||||
while self.state is not ProcessState.ACKED:
|
||||
sleep(0.1)
|
||||
misses += 1
|
||||
if misses > self.THRESHOLD:
|
||||
raise TimeoutError(
|
||||
f"Worker {self.name} failed to come ack within "
|
||||
f"{self.THRESHOLD / 10} seconds"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"{Colors.BLUE}Process acked. Terminating: "
|
||||
f"{Colors.BOLD}{Colors.SANIC}"
|
||||
f"%s {Colors.BLUE}[%s]{Colors.END}",
|
||||
self.name,
|
||||
self._old_process.pid,
|
||||
)
|
||||
self._old_process.terminate()
|
||||
delattr(self, "_old_process")
|
||||
|
||||
|
||||
class Worker:
|
||||
|
@ -153,7 +203,11 @@ class Worker:
|
|||
|
||||
def create_process(self) -> WorkerProcess:
|
||||
process = WorkerProcess(
|
||||
factory=self.context.Process,
|
||||
# Need to ignore this typing error - The problem is the
|
||||
# BaseContext itself has no Process. But, all of its
|
||||
# implementations do. We can safely ignore as it is a typing
|
||||
# issue in the standard lib.
|
||||
factory=self.context.Process, # type: ignore
|
||||
name=f"{self.WORKER_PREFIX}{self.ident}-{len(self.processes)}",
|
||||
target=self.serve,
|
||||
kwargs={**self.server_settings},
|
||||
|
|
|
@ -321,7 +321,8 @@ def test_inspector_inspect(urlopen, caplog, capsys):
|
|||
@pytest.mark.parametrize(
|
||||
"command,params",
|
||||
(
|
||||
(["reload"], {}),
|
||||
(["reload"], {"zero_downtime": False}),
|
||||
(["reload", "--zero-downtime"], {"zero_downtime": True}),
|
||||
(["shutdown"], {}),
|
||||
(["scale", "9"], {"replicas": 9}),
|
||||
(["foo", "--bar=something"], {"bar": "something"}),
|
||||
|
|
|
@ -88,6 +88,12 @@ def test_run_inspector_reload(publisher, http_client):
|
|||
publisher.send.assert_called_once_with("__ALL_PROCESSES__:")
|
||||
|
||||
|
||||
def test_run_inspector_reload_zero_downtime(publisher, http_client):
|
||||
_, response = http_client.post("/reload", json={"zero_downtime": True})
|
||||
assert response.status == 200
|
||||
publisher.send.assert_called_once_with("__ALL_PROCESSES__::STARTUP_FIRST")
|
||||
|
||||
|
||||
def test_run_inspector_shutdown(publisher, http_client):
|
||||
_, response = http_client.post("/shutdown")
|
||||
assert response.status == 200
|
||||
|
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
|
||||
from sanic.compat import OS_IS_WINDOWS
|
||||
from sanic.exceptions import ServerKilled
|
||||
from sanic.worker.constants import RestartOrder
|
||||
from sanic.worker.manager import WorkerManager
|
||||
|
||||
|
||||
|
@ -102,11 +103,17 @@ def test_restart_all():
|
|||
)
|
||||
|
||||
|
||||
def test_monitor_all():
|
||||
@pytest.mark.parametrize("zero_downtime", (False, True))
|
||||
def test_monitor_all(zero_downtime):
|
||||
p1 = Mock()
|
||||
p2 = Mock()
|
||||
sub = Mock()
|
||||
sub.recv.side_effect = ["__ALL_PROCESSES__:", ""]
|
||||
incoming = (
|
||||
"__ALL_PROCESSES__::STARTUP_FIRST"
|
||||
if zero_downtime
|
||||
else "__ALL_PROCESSES__:"
|
||||
)
|
||||
sub.recv.side_effect = [incoming, ""]
|
||||
context = Mock()
|
||||
context.Process.side_effect = [p1, p2]
|
||||
manager = WorkerManager(2, fake_serve, {}, context, (Mock(), sub), {})
|
||||
|
@ -114,16 +121,29 @@ def test_monitor_all():
|
|||
manager.wait_for_ack = Mock() # type: ignore
|
||||
manager.monitor()
|
||||
|
||||
restart_order = (
|
||||
RestartOrder.STARTUP_FIRST
|
||||
if zero_downtime
|
||||
else RestartOrder.SHUTDOWN_FIRST
|
||||
)
|
||||
manager.restart.assert_called_once_with(
|
||||
process_names=None, reloaded_files=""
|
||||
process_names=None,
|
||||
reloaded_files="",
|
||||
restart_order=restart_order,
|
||||
)
|
||||
|
||||
|
||||
def test_monitor_all_with_files():
|
||||
@pytest.mark.parametrize("zero_downtime", (False, True))
|
||||
def test_monitor_all_with_files(zero_downtime):
|
||||
p1 = Mock()
|
||||
p2 = Mock()
|
||||
sub = Mock()
|
||||
sub.recv.side_effect = ["__ALL_PROCESSES__:foo,bar", ""]
|
||||
incoming = (
|
||||
"__ALL_PROCESSES__:foo,bar:STARTUP_FIRST"
|
||||
if zero_downtime
|
||||
else "__ALL_PROCESSES__:foo,bar"
|
||||
)
|
||||
sub.recv.side_effect = [incoming, ""]
|
||||
context = Mock()
|
||||
context.Process.side_effect = [p1, p2]
|
||||
manager = WorkerManager(2, fake_serve, {}, context, (Mock(), sub), {})
|
||||
|
@ -131,17 +151,30 @@ def test_monitor_all_with_files():
|
|||
manager.wait_for_ack = Mock() # type: ignore
|
||||
manager.monitor()
|
||||
|
||||
restart_order = (
|
||||
RestartOrder.STARTUP_FIRST
|
||||
if zero_downtime
|
||||
else RestartOrder.SHUTDOWN_FIRST
|
||||
)
|
||||
manager.restart.assert_called_once_with(
|
||||
process_names=None, reloaded_files="foo,bar"
|
||||
process_names=None,
|
||||
reloaded_files="foo,bar",
|
||||
restart_order=restart_order,
|
||||
)
|
||||
|
||||
|
||||
def test_monitor_one_process():
|
||||
@pytest.mark.parametrize("zero_downtime", (False, True))
|
||||
def test_monitor_one_process(zero_downtime):
|
||||
p1 = Mock()
|
||||
p1.name = "Testing"
|
||||
p2 = Mock()
|
||||
sub = Mock()
|
||||
sub.recv.side_effect = [f"{p1.name}:foo,bar", ""]
|
||||
incoming = (
|
||||
f"{p1.name}:foo,bar:STARTUP_FIRST"
|
||||
if zero_downtime
|
||||
else f"{p1.name}:foo,bar"
|
||||
)
|
||||
sub.recv.side_effect = [incoming, ""]
|
||||
context = Mock()
|
||||
context.Process.side_effect = [p1, p2]
|
||||
manager = WorkerManager(2, fake_serve, {}, context, (Mock(), sub), {})
|
||||
|
@ -149,8 +182,15 @@ def test_monitor_one_process():
|
|||
manager.wait_for_ack = Mock() # type: ignore
|
||||
manager.monitor()
|
||||
|
||||
restart_order = (
|
||||
RestartOrder.STARTUP_FIRST
|
||||
if zero_downtime
|
||||
else RestartOrder.SHUTDOWN_FIRST
|
||||
)
|
||||
manager.restart.assert_called_once_with(
|
||||
process_names=[p1.name], reloaded_files="foo,bar"
|
||||
process_names=[p1.name],
|
||||
reloaded_files="foo,bar",
|
||||
restart_order=restart_order,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -98,17 +98,17 @@ def test_ack(worker_state: Dict[str, Any], m: WorkerMultiplexer):
|
|||
|
||||
def test_restart_self(monitor_publisher: Mock, m: WorkerMultiplexer):
|
||||
m.restart()
|
||||
monitor_publisher.send.assert_called_once_with("Test")
|
||||
monitor_publisher.send.assert_called_once_with("Test:")
|
||||
|
||||
|
||||
def test_restart_foo(monitor_publisher: Mock, m: WorkerMultiplexer):
|
||||
m.restart("foo")
|
||||
monitor_publisher.send.assert_called_once_with("foo")
|
||||
monitor_publisher.send.assert_called_once_with("foo:")
|
||||
|
||||
|
||||
def test_reload_alias(monitor_publisher: Mock, m: WorkerMultiplexer):
|
||||
m.reload()
|
||||
monitor_publisher.send.assert_called_once_with("Test")
|
||||
monitor_publisher.send.assert_called_once_with("Test:")
|
||||
|
||||
|
||||
def test_terminate(monitor_publisher: Mock, m: WorkerMultiplexer):
|
||||
|
@ -135,10 +135,20 @@ def test_properties(
|
|||
@pytest.mark.parametrize(
|
||||
"params,expected",
|
||||
(
|
||||
({}, "Test"),
|
||||
({"name": "foo"}, "foo"),
|
||||
({}, "Test:"),
|
||||
({"name": "foo"}, "foo:"),
|
||||
({"all_workers": True}, "__ALL_PROCESSES__:"),
|
||||
({"zero_downtime": True}, "Test::STARTUP_FIRST"),
|
||||
({"name": "foo", "all_workers": True}, ValueError),
|
||||
({"name": "foo", "zero_downtime": True}, "foo::STARTUP_FIRST"),
|
||||
(
|
||||
{"all_workers": True, "zero_downtime": True},
|
||||
"__ALL_PROCESSES__::STARTUP_FIRST",
|
||||
),
|
||||
(
|
||||
{"name": "foo", "all_workers": True, "zero_downtime": True},
|
||||
ValueError,
|
||||
),
|
||||
),
|
||||
)
|
||||
def test_restart_params(
|
||||
|
|
|
@ -1,13 +1,19 @@
|
|||
import re
|
||||
import signal
|
||||
import threading
|
||||
|
||||
from asyncio import Event
|
||||
from logging import DEBUG
|
||||
from pathlib import Path
|
||||
from time import sleep
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from sanic.app import Sanic
|
||||
from sanic.worker.constants import ProcessState, RestartOrder
|
||||
from sanic.worker.loader import AppLoader
|
||||
from sanic.worker.process import WorkerProcess
|
||||
from sanic.worker.reloader import Reloader
|
||||
|
||||
|
||||
|
@ -67,6 +73,88 @@ def test_iter_files():
|
|||
assert len_total_files == len_python_files + len_static_files
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"order,expected",
|
||||
(
|
||||
(
|
||||
RestartOrder.SHUTDOWN_FIRST,
|
||||
[
|
||||
"Restarting a process",
|
||||
"Begin restart termination",
|
||||
"Starting a process",
|
||||
],
|
||||
),
|
||||
(
|
||||
RestartOrder.STARTUP_FIRST,
|
||||
[
|
||||
"Restarting a process",
|
||||
"Starting a process",
|
||||
"Begin restart termination",
|
||||
"Waiting for process to be acked",
|
||||
"Process acked. Terminating",
|
||||
],
|
||||
),
|
||||
),
|
||||
)
|
||||
def test_default_reload_shutdown_order(monkeypatch, caplog, order, expected):
|
||||
current_process = Mock()
|
||||
worker_process = WorkerProcess(
|
||||
lambda **_: current_process,
|
||||
"Test",
|
||||
lambda **_: ...,
|
||||
{},
|
||||
{},
|
||||
)
|
||||
|
||||
def start(self):
|
||||
worker_process.set_state(ProcessState.ACKED)
|
||||
self._target()
|
||||
|
||||
orig = threading.Thread.start
|
||||
monkeypatch.setattr(threading.Thread, "start", start)
|
||||
|
||||
with caplog.at_level(DEBUG):
|
||||
worker_process.restart(restart_order=order)
|
||||
|
||||
ansi = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
|
||||
|
||||
def clean(msg: str):
|
||||
msg, _ = ansi.sub("", msg).split(":", 1)
|
||||
return msg
|
||||
|
||||
debug = [clean(record[2]) for record in caplog.record_tuples]
|
||||
assert debug == expected
|
||||
current_process.start.assert_called_once()
|
||||
current_process.terminate.assert_called_once()
|
||||
monkeypatch.setattr(threading.Thread, "start", orig)
|
||||
|
||||
|
||||
def test_reload_delayed(monkeypatch):
|
||||
WorkerProcess.THRESHOLD = 1
|
||||
|
||||
current_process = Mock()
|
||||
worker_process = WorkerProcess(
|
||||
lambda **_: current_process,
|
||||
"Test",
|
||||
lambda **_: ...,
|
||||
{},
|
||||
{},
|
||||
)
|
||||
|
||||
def start(self):
|
||||
sleep(0.2)
|
||||
self._target()
|
||||
|
||||
orig = threading.Thread.start
|
||||
monkeypatch.setattr(threading.Thread, "start", start)
|
||||
|
||||
message = "Worker Test failed to come ack within 0.1 seconds"
|
||||
with pytest.raises(TimeoutError, match=message):
|
||||
worker_process.restart(restart_order=RestartOrder.STARTUP_FIRST)
|
||||
|
||||
monkeypatch.setattr(threading.Thread, "start", orig)
|
||||
|
||||
|
||||
def test_reloader_triggers_start_stop_listeners(
|
||||
app: Sanic, app_loader: AppLoader
|
||||
):
|
||||
|
|
|
@ -118,7 +118,7 @@ def test_serve_with_inspector(
|
|||
if config:
|
||||
Inspector.assert_called_once()
|
||||
WorkerManager.manage.assert_called_once_with(
|
||||
"Inspector", inspector, {}, transient=True
|
||||
"Inspector", inspector, {}, transient=False
|
||||
)
|
||||
else:
|
||||
Inspector.assert_not_called()
|
||||
|
|
Loading…
Reference in New Issue
Block a user