Compare commits

...

1 Commits

Author SHA1 Message Date
Adam Hopkins
fa864f0bab
Start and restart arbitrary processes 2023-07-09 13:53:14 +03:00
5 changed files with 239 additions and 57 deletions

View File

@ -5,7 +5,6 @@ import logging
import logging.config import logging.config
import re import re
import sys import sys
from asyncio import ( from asyncio import (
AbstractEventLoop, AbstractEventLoop,
CancelledError, CancelledError,
@ -55,12 +54,7 @@ from sanic.blueprint_group import BlueprintGroup
from sanic.blueprints import Blueprint from sanic.blueprints import Blueprint
from sanic.compat import OS_IS_WINDOWS, enable_windows_color_support from sanic.compat import OS_IS_WINDOWS, enable_windows_color_support
from sanic.config import SANIC_PREFIX, Config from sanic.config import SANIC_PREFIX, Config
from sanic.exceptions import ( from sanic.exceptions import BadRequest, SanicException, ServerError, URLBuildError
BadRequest,
SanicException,
ServerError,
URLBuildError,
)
from sanic.handlers import ErrorHandler from sanic.handlers import ErrorHandler
from sanic.helpers import Default, _default from sanic.helpers import Default, _default
from sanic.http import Stage from sanic.http import Stage
@ -90,7 +84,6 @@ from sanic.worker.inspector import Inspector
from sanic.worker.loader import CertLoader from sanic.worker.loader import CertLoader
from sanic.worker.manager import WorkerManager from sanic.worker.manager import WorkerManager
if TYPE_CHECKING: if TYPE_CHECKING:
try: try:
from sanic_ext import Extend # type: ignore from sanic_ext import Extend # type: ignore
@ -1676,7 +1669,10 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta):
def inspector(self): def inspector(self):
if environ.get("SANIC_WORKER_PROCESS") or not self._inspector: if environ.get("SANIC_WORKER_PROCESS") or not self._inspector:
raise SanicException( raise SanicException(
"Can only access the inspector from the main process" "Can only access the inspector from the main process "
"after main_process_start has run. For example, you most "
"likely want to use it inside the @app.main_process_ready "
"event listener."
) )
return self._inspector return self._inspector
@ -1684,6 +1680,9 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta):
def manager(self): def manager(self):
if environ.get("SANIC_WORKER_PROCESS") or not self._manager: if environ.get("SANIC_WORKER_PROCESS") or not self._manager:
raise SanicException( raise SanicException(
"Can only access the manager from the main process" "Can only access the manager from the main process "
"after main_process_start has run. For example, you most "
"likely want to use it inside the @app.main_process_ready "
"event listener."
) )
return self._manager return self._manager

View File

@ -16,3 +16,5 @@ class ProcessState(IntEnum):
ACKED = auto() ACKED = auto()
JOINED = auto() JOINED = auto()
TERMINATED = auto() TERMINATED = auto()
FAILED = auto()
COMPLETED = auto()

View File

@ -1,11 +1,11 @@
import os import os
from contextlib import suppress from contextlib import suppress
from itertools import count from enum import IntEnum, auto
from itertools import chain, count
from random import choice from random import choice
from signal import SIGINT, SIGTERM, Signals from signal import SIGINT, SIGTERM, Signals
from signal import signal as signal_func from signal import signal as signal_func
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from sanic.compat import OS_IS_WINDOWS from sanic.compat import OS_IS_WINDOWS
from sanic.exceptions import ServerKilled from sanic.exceptions import ServerKilled
@ -13,13 +13,17 @@ from sanic.log import error_logger, logger
from sanic.worker.constants import RestartOrder from sanic.worker.constants import RestartOrder
from sanic.worker.process import ProcessState, Worker, WorkerProcess from sanic.worker.process import ProcessState, Worker, WorkerProcess
if not OS_IS_WINDOWS: if not OS_IS_WINDOWS:
from signal import SIGKILL from signal import SIGKILL
else: else:
SIGKILL = SIGINT SIGKILL = SIGINT
class MonitorCycle(IntEnum):
    """
    Control signals produced while polling the monitor connection.

    ``WorkerManager._poll_monitor`` returns one of these members (or
    ``None``) and the monitor loop reacts to it.
    """

    # The monitor loop should stop.
    BREAK = auto()
    # The monitor loop should skip straight to its next iteration.
    CONTINUE = auto()
class WorkerManager: class WorkerManager:
THRESHOLD = WorkerProcess.THRESHOLD THRESHOLD = WorkerProcess.THRESHOLD
MAIN_IDENT = "Sanic-Main" MAIN_IDENT = "Sanic-Main"
@ -60,6 +64,8 @@ class WorkerManager:
func: Callable[..., Any], func: Callable[..., Any],
kwargs: Dict[str, Any], kwargs: Dict[str, Any],
transient: bool = False, transient: bool = False,
restartable: Optional[bool] = None,
tracked: bool = True,
workers: int = 1, workers: int = 1,
) -> Worker: ) -> Worker:
""" """
@ -75,14 +81,35 @@ class WorkerManager:
then the Worker Manager will restart the process along then the Worker Manager will restart the process along
with any global restart (ex: auto-reload), defaults to False with any global restart (ex: auto-reload), defaults to False
:type transient: bool, optional :type transient: bool, optional
:param restartable: Whether to mark the process as restartable. If
True then the Worker Manager will be able to restart the process
if prompted. If transient=True, this property will be implied
to be True, defaults to None
:type restartable: Optional[bool], optional
:param tracked: Whether to track the process after completion,
defaults to True
:param workers: The number of worker processes to run, defaults to 1 :param workers: The number of worker processes to run, defaults to 1
:type workers: int, optional :type workers: int, optional
:return: The Worker instance :return: The Worker instance
:rtype: Worker :rtype: Worker
""" """
if ident in self.transient or ident in self.durable:
raise ValueError(f"Worker {ident} already exists")
restartable = restartable if restartable is not None else transient
if transient and not restartable:
raise ValueError(
"Cannot create a transient worker that is not restartable"
)
container = self.transient if transient else self.durable container = self.transient if transient else self.durable
worker = Worker( worker = Worker(
ident, func, kwargs, self.context, self.worker_state, workers ident,
func,
kwargs,
self.context,
self.worker_state,
workers,
restartable,
tracked,
) )
container[worker.ident] = worker container[worker.ident] = worker
return worker return worker
@ -94,6 +121,7 @@ class WorkerManager:
self._serve, self._serve,
self._server_settings, self._server_settings,
transient=True, transient=True,
restartable=True,
) )
def shutdown_server(self, ident: Optional[str] = None) -> None: def shutdown_server(self, ident: Optional[str] = None) -> None:
@ -153,9 +181,32 @@ class WorkerManager:
restart_order=RestartOrder.SHUTDOWN_FIRST, restart_order=RestartOrder.SHUTDOWN_FIRST,
**kwargs, **kwargs,
): ):
restarted = set()
for process in self.transient_processes: for process in self.transient_processes:
if not process_names or process.name in process_names: if process.restartable and (
not process_names or process.name in process_names
):
process.restart(restart_order=restart_order, **kwargs) process.restart(restart_order=restart_order, **kwargs)
restarted.add(process.name)
if process_names:
for process in self.durable_processes:
if process.restartable and process.name in process_names:
if process.state not in (
ProcessState.COMPLETED,
ProcessState.FAILED,
):
error_logger.error(
f"Cannot restart process {process.name} because "
"it is not in a final state. Current state is: "
f"{process.state.name}."
)
continue
process.restart(restart_order=restart_order, **kwargs)
restarted.add(process.name)
if process_names and not restarted:
error_logger.error(
f"Failed to restart processes: {', '.join(process_names)}"
)
def scale(self, num_worker: int): def scale(self, num_worker: int):
if num_worker <= 0: if num_worker <= 0:
@ -183,45 +234,13 @@ class WorkerManager:
self.wait_for_ack() self.wait_for_ack()
while True: while True:
try: try:
if self.monitor_subscriber.poll(0.1): cycle = self._poll_monitor()
message = self.monitor_subscriber.recv() if cycle is MonitorCycle.BREAK:
logger.debug(
f"Monitor message: {message}", extra={"verbosity": 2}
)
if not message:
break break
elif message == "__TERMINATE__": elif cycle is MonitorCycle.CONTINUE:
self.shutdown()
break
logger.debug(
"Incoming monitor message: %s",
message,
extra={"verbosity": 1},
)
split_message = message.split(":", 2)
if message.startswith("__SCALE__"):
self.scale(int(split_message[-1]))
continue continue
processes = split_message[0]
reloaded_files = (
split_message[1] if len(split_message) > 1 else None
)
process_names = [
name.strip() for name in processes.split(",")
]
if "__ALL_PROCESSES__" in process_names:
process_names = None
order = (
RestartOrder.STARTUP_FIRST
if "STARTUP_FIRST" in split_message
else RestartOrder.SHUTDOWN_FIRST
)
self.restart(
process_names=process_names,
reloaded_files=reloaded_files,
restart_order=order,
)
self._sync_states() self._sync_states()
self._cleanup_non_tracked_workers()
except InterruptedError: except InterruptedError:
if not OS_IS_WINDOWS: if not OS_IS_WINDOWS:
raise raise
@ -264,6 +283,10 @@ class WorkerManager:
def workers(self) -> List[Worker]: def workers(self) -> List[Worker]:
return list(self.transient.values()) + list(self.durable.values()) return list(self.transient.values()) + list(self.durable.values())
@property
def all_workers(self) -> Iterable[Tuple[str, Worker]]:
    """
    Iterate over every managed worker together with its mapping key.

    Items from ``self.transient`` come first, followed by the items from
    ``self.durable``.

    :return: A lazy iterable of ``(key, worker)`` pairs
    :rtype: Iterable[Tuple[str, Worker]]
    """
    transient_items = self.transient.items()
    durable_items = self.durable.items()
    return chain(transient_items, durable_items)
@property @property
def processes(self): def processes(self):
for worker in self.workers: for worker in self.workers:
@ -276,6 +299,12 @@ class WorkerManager:
for process in worker.processes: for process in worker.processes:
yield process yield process
@property
def durable_processes(self):
    """
    Yield each process that belongs to a durable worker.

    Workers are visited in the iteration order of ``self.durable`` and each
    worker's ``processes`` collection is yielded from in turn.
    """
    for durable_worker in self.durable.values():
        yield from durable_worker.processes
def kill(self): def kill(self):
for process in self.processes: for process in self.processes:
logger.info("Killing %s [%s]", process.name, process.pid) logger.info("Killing %s [%s]", process.name, process.pid)
@ -298,6 +327,25 @@ class WorkerManager:
process.terminate() process.terminate()
self._shutting_down = True self._shutting_down = True
def remove_worker(self, worker: Worker) -> None:
    """
    Forget a worker and the state entries of its processes.

    Removal is refused, with an error logged and nothing changed, when the
    worker is tracked or when any of its processes is still alive.

    :param worker: The worker to remove from the manager
    :type worker: Worker
    """
    if worker.tracked:
        error_logger.error(
            f"Worker {worker.ident} is tracked and cannot be removed."
        )
        return
    if worker.has_alive_processes():
        error_logger.error(
            f"Worker {worker.ident} has alive processes and cannot be "
            "removed."
        )
        return

    # The worker lives in exactly one container, but popping from both with
    # a default avoids having to know which one.
    self.transient.pop(worker.ident, None)
    self.durable.pop(worker.ident, None)
    for process in worker.processes:
        self.worker_state.pop(process.name, None)
    logger.info("Removed worker %s", worker.ident)
    # NOTE: no ``del worker`` here. It would only unbind this local name and
    # has no effect on the caller's reference or the object's lifetime.
@property @property
def pid(self): def pid(self):
return os.getpid() return os.getpid()
@ -317,5 +365,97 @@ class WorkerManager:
except KeyError: except KeyError:
process.set_state(ProcessState.TERMINATED, True) process.set_state(ProcessState.TERMINATED, True)
continue continue
if not process.is_alive():
state = "FAILED" if process.exitcode else "COMPLETED"
if state and process.state.name != state: if state and process.state.name != state:
process.set_state(ProcessState[state], True) process.set_state(ProcessState[state], True)
def _cleanup_non_tracked_workers(self) -> None:
to_remove = [
worker
for worker in self.workers
if not worker.tracked and not worker.has_alive_processes()
]
for worker in to_remove:
self.remove_worker(worker)
def _poll_monitor(self) -> Optional[MonitorCycle]:
if self.monitor_subscriber.poll(0.1):
message = self.monitor_subscriber.recv()
logger.debug(f"Monitor message: {message}", extra={"verbosity": 2})
if not message:
return MonitorCycle.BREAK
elif message == "__TERMINATE__":
self._handle_terminate()
return MonitorCycle.BREAK
elif isinstance(message, tuple) and len(message) == 7:
self._handle_manage(*message)
return MonitorCycle.CONTINUE
elif not isinstance(message, str):
error_logger.error(
"Monitor received an invalid message: %s", message
)
return MonitorCycle.CONTINUE
return self._handle_message(message)
return None
def _handle_terminate(self) -> None:
self.shutdown()
def _handle_message(self, message: str) -> Optional[MonitorCycle]:
    """
    Act on a string message received from the monitor connection.

    The message is split on ``":"`` into at most three fields. A message
    starting with ``"__SCALE__"`` scales to the integer in its last field.
    Any other message is a restart request: the first field is a
    comma-separated list of process names (``"__ALL_PROCESSES__"`` selects
    every process), the optional second field is passed on as
    ``reloaded_files``, and a field equal to ``"STARTUP_FIRST"`` selects
    that restart order instead of the default ``SHUTDOWN_FIRST``.

    :param message: The raw monitor message
    :type message: str
    :return: ``MonitorCycle.CONTINUE`` after a scale request, ``None``
        after a restart request
    :rtype: Optional[MonitorCycle]
    """
    logger.debug(
        "Incoming monitor message: %s",
        message,
        extra={"verbosity": 1},
    )
    fields = message.split(":", 2)

    if message.startswith("__SCALE__"):
        self.scale(int(fields[-1]))
        return MonitorCycle.CONTINUE

    reloaded_files = fields[1] if len(fields) > 1 else None
    requested = [name.strip() for name in fields[0].split(",")]
    # None tells restart() not to filter by process name.
    process_names: Optional[List[str]] = (
        None if "__ALL_PROCESSES__" in requested else requested
    )

    if "STARTUP_FIRST" in fields:
        order = RestartOrder.STARTUP_FIRST
    else:
        order = RestartOrder.SHUTDOWN_FIRST

    self.restart(
        process_names=process_names,
        reloaded_files=reloaded_files,
        restart_order=order,
    )
    return None
def _handle_manage(
self,
ident: str,
func: Callable[..., Any],
kwargs: Dict[str, Any],
transient: bool,
restartable: Optional[bool],
tracked: bool,
workers: int,
) -> None:
try:
worker = self.manage(
ident,
func,
kwargs,
transient=transient,
restartable=restartable,
tracked=tracked,
workers=workers,
)
except Exception:
error_logger.exception("Failed to manage worker %s", ident)
else:
for process in worker.processes:
process.start()

View File

@ -1,6 +1,6 @@
from multiprocessing.connection import Connection from multiprocessing.connection import Connection
from os import environ, getpid from os import environ, getpid
from typing import Any, Dict from typing import Any, Callable, Dict, Optional
from sanic.log import Colors, logger from sanic.log import Colors, logger
from sanic.worker.process import ProcessState from sanic.worker.process import ProcessState
@ -28,6 +28,27 @@ class WorkerMultiplexer:
"state": ProcessState.ACKED.name, "state": ProcessState.ACKED.name,
} }
def manage(
    self,
    ident: str,
    func: Callable[..., Any],
    kwargs: Dict[str, Any],
    transient: bool = False,
    restartable: Optional[bool] = None,
    tracked: bool = False,
    workers: int = 1,
) -> None:
    """
    Send a request over the monitor publisher to manage a new worker.

    The arguments are packed, in signature order, into a 7-item tuple and
    sent through ``self._monitor_publisher``. Nothing is returned, so the
    caller gets no confirmation that the worker was created.

    :param ident: Identifier for the new worker
    :type ident: str
    :param func: The callable the worker processes will run
    :type func: Callable[..., Any]
    :param kwargs: Keyword arguments for ``func``
    :type kwargs: Dict[str, Any]
    :param transient: Whether the worker is transient, defaults to False
    :type transient: bool, optional
    :param restartable: Whether the worker may be restarted, defaults to
        None
    :type restartable: Optional[bool], optional
    :param tracked: Whether to keep tracking the worker after completion,
        defaults to False
    :type tracked: bool, optional
    :param workers: The number of worker processes to run, defaults to 1
    :type workers: int, optional
    """
    # NOTE(review): presumably func/kwargs must be serializable by the
    # publisher's send(); confirm against the connection type in use.
    self._monitor_publisher.send(
        (ident, func, kwargs, transient, restartable, tracked, workers)
    )
def restart( def restart(
self, self,
name: str = "", name: str = "",

View File

@ -1,5 +1,4 @@
import os import os
from datetime import datetime, timezone from datetime import datetime, timezone
from multiprocessing.context import BaseContext from multiprocessing.context import BaseContext
from signal import SIGINT from signal import SIGINT
@ -20,13 +19,22 @@ class WorkerProcess:
THRESHOLD = 300 # == 30 seconds THRESHOLD = 300 # == 30 seconds
SERVER_LABEL = "Server" SERVER_LABEL = "Server"
def __init__(self, factory, name, target, kwargs, worker_state): def __init__(
self,
factory,
name,
target,
kwargs,
worker_state,
restartable: bool = False,
):
self.state = ProcessState.IDLE self.state = ProcessState.IDLE
self.factory = factory self.factory = factory
self.name = name self.name = name
self.target = target self.target = target
self.kwargs = kwargs self.kwargs = kwargs
self.worker_state = worker_state self.worker_state = worker_state
self.restartable = restartable
if self.name not in self.worker_state: if self.name not in self.worker_state:
self.worker_state[self.name] = { self.worker_state[self.name] = {
"server": self.SERVER_LABEL in self.name "server": self.SERVER_LABEL in self.name
@ -132,6 +140,10 @@ class WorkerProcess:
def pid(self): def pid(self):
return self._current_process.pid return self._current_process.pid
@property
def exitcode(self):
    """Return the ``exitcode`` attribute of the current process object."""
    current = self._current_process
    return current.exitcode
def _terminate_now(self): def _terminate_now(self):
logger.debug( logger.debug(
f"{Colors.BLUE}Begin restart termination: " f"{Colors.BLUE}Begin restart termination: "
@ -193,6 +205,8 @@ class Worker:
context: BaseContext, context: BaseContext,
worker_state: Dict[str, Any], worker_state: Dict[str, Any],
num: int = 1, num: int = 1,
restartable: bool = False,
tracked: bool = True,
): ):
self.ident = ident self.ident = ident
self.num = num self.num = num
@ -201,6 +215,8 @@ class Worker:
self.server_settings = server_settings self.server_settings = server_settings
self.worker_state = worker_state self.worker_state = worker_state
self.processes: Set[WorkerProcess] = set() self.processes: Set[WorkerProcess] = set()
self.restartable = restartable
self.tracked = tracked
for _ in range(num): for _ in range(num):
self.create_process() self.create_process()
@ -215,6 +231,10 @@ class Worker:
target=self.serve, target=self.serve,
kwargs={**self.server_settings}, kwargs={**self.server_settings},
worker_state=self.worker_state, worker_state=self.worker_state,
restartable=self.restartable,
) )
self.processes.add(process) self.processes.add(process)
return process return process
def has_alive_processes(self) -> bool:
    """
    Report whether at least one of this worker's processes is alive.

    :return: True as soon as any process reports ``is_alive()``, False
        when none do or when there are no processes
    :rtype: bool
    """
    for process in self.processes:
        if process.is_alive():
            return True
    return False