Compare commits

...

1 Commits

Author SHA1 Message Date
Adam Hopkins
fa864f0bab
Start and restart arbitrary processes 2023-07-09 13:53:14 +03:00
5 changed files with 239 additions and 57 deletions

View File

@ -5,7 +5,6 @@ import logging
import logging.config
import re
import sys
from asyncio import (
AbstractEventLoop,
CancelledError,
@ -55,12 +54,7 @@ from sanic.blueprint_group import BlueprintGroup
from sanic.blueprints import Blueprint
from sanic.compat import OS_IS_WINDOWS, enable_windows_color_support
from sanic.config import SANIC_PREFIX, Config
from sanic.exceptions import (
BadRequest,
SanicException,
ServerError,
URLBuildError,
)
from sanic.exceptions import BadRequest, SanicException, ServerError, URLBuildError
from sanic.handlers import ErrorHandler
from sanic.helpers import Default, _default
from sanic.http import Stage
@ -90,7 +84,6 @@ from sanic.worker.inspector import Inspector
from sanic.worker.loader import CertLoader
from sanic.worker.manager import WorkerManager
if TYPE_CHECKING:
try:
from sanic_ext import Extend # type: ignore
@ -1676,7 +1669,10 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta):
def inspector(self):
if environ.get("SANIC_WORKER_PROCESS") or not self._inspector:
raise SanicException(
"Can only access the inspector from the main process"
"Can only access the inspector from the main process "
"after main_process_start has run. For example, you most "
"likely want to use it inside the @app.main_process_ready "
"event listener."
)
return self._inspector
@ -1684,6 +1680,9 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta):
def manager(self):
if environ.get("SANIC_WORKER_PROCESS") or not self._manager:
raise SanicException(
"Can only access the manager from the main process"
"Can only access the manager from the main process "
"after main_process_start has run. For example, you most "
"likely want to use it inside the @app.main_process_ready "
"event listener."
)
return self._manager

View File

@ -16,3 +16,5 @@ class ProcessState(IntEnum):
ACKED = auto()
JOINED = auto()
TERMINATED = auto()
FAILED = auto()
COMPLETED = auto()

View File

@ -1,11 +1,11 @@
import os
from contextlib import suppress
from itertools import count
from enum import IntEnum, auto
from itertools import chain, count
from random import choice
from signal import SIGINT, SIGTERM, Signals
from signal import signal as signal_func
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from sanic.compat import OS_IS_WINDOWS
from sanic.exceptions import ServerKilled
@ -13,13 +13,17 @@ from sanic.log import error_logger, logger
from sanic.worker.constants import RestartOrder
from sanic.worker.process import ProcessState, Worker, WorkerProcess
if not OS_IS_WINDOWS:
from signal import SIGKILL
else:
SIGKILL = SIGINT
class MonitorCycle(IntEnum):
BREAK = auto()
CONTINUE = auto()
class WorkerManager:
THRESHOLD = WorkerProcess.THRESHOLD
MAIN_IDENT = "Sanic-Main"
@ -60,6 +64,8 @@ class WorkerManager:
func: Callable[..., Any],
kwargs: Dict[str, Any],
transient: bool = False,
restartable: Optional[bool] = None,
tracked: bool = True,
workers: int = 1,
) -> Worker:
"""
@ -75,14 +81,35 @@ class WorkerManager:
then the Worker Manager will restart the process along
with any global restart (ex: auto-reload), defaults to False
:type transient: bool, optional
:param restartable: Whether to mark the process as restartable. If
True then the Worker Manager will be able to restart the process
if prompted. If transient=True, this property will be implied
to be True, defaults to None
:type restartable: Optional[bool], optional
:param tracked: Whether to track the process after completion,
defaults to True
:param workers: The number of worker processes to run, defaults to 1
:type workers: int, optional
:return: The Worker instance
:rtype: Worker
"""
if ident in self.transient or ident in self.durable:
raise ValueError(f"Worker {ident} already exists")
restartable = restartable if restartable is not None else transient
if transient and not restartable:
raise ValueError(
"Cannot create a transient worker that is not restartable"
)
container = self.transient if transient else self.durable
worker = Worker(
ident, func, kwargs, self.context, self.worker_state, workers
ident,
func,
kwargs,
self.context,
self.worker_state,
workers,
restartable,
tracked,
)
container[worker.ident] = worker
return worker
@ -94,6 +121,7 @@ class WorkerManager:
self._serve,
self._server_settings,
transient=True,
restartable=True,
)
def shutdown_server(self, ident: Optional[str] = None) -> None:
@ -153,9 +181,32 @@ class WorkerManager:
restart_order=RestartOrder.SHUTDOWN_FIRST,
**kwargs,
):
restarted = set()
for process in self.transient_processes:
if not process_names or process.name in process_names:
if process.restartable and (
not process_names or process.name in process_names
):
process.restart(restart_order=restart_order, **kwargs)
restarted.add(process.name)
if process_names:
for process in self.durable_processes:
if process.restartable and process.name in process_names:
if process.state not in (
ProcessState.COMPLETED,
ProcessState.FAILED,
):
error_logger.error(
f"Cannot restart process {process.name} because "
"it is not in a final state. Current state is: "
f"{process.state.name}."
)
continue
process.restart(restart_order=restart_order, **kwargs)
restarted.add(process.name)
if process_names and not restarted:
error_logger.error(
f"Failed to restart processes: {', '.join(process_names)}"
)
def scale(self, num_worker: int):
if num_worker <= 0:
@ -183,45 +234,13 @@ class WorkerManager:
self.wait_for_ack()
while True:
try:
if self.monitor_subscriber.poll(0.1):
message = self.monitor_subscriber.recv()
logger.debug(
f"Monitor message: {message}", extra={"verbosity": 2}
)
if not message:
cycle = self._poll_monitor()
if cycle is MonitorCycle.BREAK:
break
elif message == "__TERMINATE__":
self.shutdown()
break
logger.debug(
"Incoming monitor message: %s",
message,
extra={"verbosity": 1},
)
split_message = message.split(":", 2)
if message.startswith("__SCALE__"):
self.scale(int(split_message[-1]))
elif cycle is MonitorCycle.CONTINUE:
continue
processes = split_message[0]
reloaded_files = (
split_message[1] if len(split_message) > 1 else None
)
process_names = [
name.strip() for name in processes.split(",")
]
if "__ALL_PROCESSES__" in process_names:
process_names = None
order = (
RestartOrder.STARTUP_FIRST
if "STARTUP_FIRST" in split_message
else RestartOrder.SHUTDOWN_FIRST
)
self.restart(
process_names=process_names,
reloaded_files=reloaded_files,
restart_order=order,
)
self._sync_states()
self._cleanup_non_tracked_workers()
except InterruptedError:
if not OS_IS_WINDOWS:
raise
@ -264,6 +283,10 @@ class WorkerManager:
def workers(self) -> List[Worker]:
return list(self.transient.values()) + list(self.durable.values())
@property
def all_workers(self) -> Iterable[Tuple[str, Worker]]:
return chain(self.transient.items(), self.durable.items())
@property
def processes(self):
for worker in self.workers:
@ -276,6 +299,12 @@ class WorkerManager:
for process in worker.processes:
yield process
@property
def durable_processes(self):
for worker in self.durable.values():
for process in worker.processes:
yield process
def kill(self):
for process in self.processes:
logger.info("Killing %s [%s]", process.name, process.pid)
@ -298,6 +327,25 @@ class WorkerManager:
process.terminate()
self._shutting_down = True
def remove_worker(self, worker: Worker) -> None:
if worker.tracked:
error_logger.error(
f"Worker {worker.ident} is tracked and cannot be removed."
)
return
if worker.has_alive_processes():
error_logger.error(
f"Worker {worker.ident} has alive processes and cannot be "
"removed."
)
return
self.transient.pop(worker.ident, None)
self.durable.pop(worker.ident, None)
for process in worker.processes:
self.worker_state.pop(process.name, None)
logger.info("Removed worker %s", worker.ident)
del worker
@property
def pid(self):
return os.getpid()
@ -317,5 +365,97 @@ class WorkerManager:
except KeyError:
process.set_state(ProcessState.TERMINATED, True)
continue
if not process.is_alive():
state = "FAILED" if process.exitcode else "COMPLETED"
if state and process.state.name != state:
process.set_state(ProcessState[state], True)
def _cleanup_non_tracked_workers(self) -> None:
to_remove = [
worker
for worker in self.workers
if not worker.tracked and not worker.has_alive_processes()
]
for worker in to_remove:
self.remove_worker(worker)
def _poll_monitor(self) -> Optional[MonitorCycle]:
if self.monitor_subscriber.poll(0.1):
message = self.monitor_subscriber.recv()
logger.debug(f"Monitor message: {message}", extra={"verbosity": 2})
if not message:
return MonitorCycle.BREAK
elif message == "__TERMINATE__":
self._handle_terminate()
return MonitorCycle.BREAK
elif isinstance(message, tuple) and len(message) == 7:
self._handle_manage(*message)
return MonitorCycle.CONTINUE
elif not isinstance(message, str):
error_logger.error(
"Monitor received an invalid message: %s", message
)
return MonitorCycle.CONTINUE
return self._handle_message(message)
return None
def _handle_terminate(self) -> None:
self.shutdown()
def _handle_message(self, message: str) -> Optional[MonitorCycle]:
logger.debug(
"Incoming monitor message: %s",
message,
extra={"verbosity": 1},
)
split_message = message.split(":", 2)
if message.startswith("__SCALE__"):
self.scale(int(split_message[-1]))
return MonitorCycle.CONTINUE
processes = split_message[0]
reloaded_files = split_message[1] if len(split_message) > 1 else None
process_names: Optional[List[str]] = [
name.strip() for name in processes.split(",")
]
if process_names and "__ALL_PROCESSES__" in process_names:
process_names = None
order = (
RestartOrder.STARTUP_FIRST
if "STARTUP_FIRST" in split_message
else RestartOrder.SHUTDOWN_FIRST
)
self.restart(
process_names=process_names,
reloaded_files=reloaded_files,
restart_order=order,
)
return None
def _handle_manage(
self,
ident: str,
func: Callable[..., Any],
kwargs: Dict[str, Any],
transient: bool,
restartable: Optional[bool],
tracked: bool,
workers: int,
) -> None:
try:
worker = self.manage(
ident,
func,
kwargs,
transient=transient,
restartable=restartable,
tracked=tracked,
workers=workers,
)
except Exception:
error_logger.exception("Failed to manage worker %s", ident)
else:
for process in worker.processes:
process.start()

View File

@ -1,6 +1,6 @@
from multiprocessing.connection import Connection
from os import environ, getpid
from typing import Any, Dict
from typing import Any, Callable, Dict, Optional
from sanic.log import Colors, logger
from sanic.worker.process import ProcessState
@ -28,6 +28,27 @@ class WorkerMultiplexer:
"state": ProcessState.ACKED.name,
}
def manage(
self,
ident: str,
func: Callable[..., Any],
kwargs: Dict[str, Any],
transient: bool = False,
restartable: Optional[bool] = None,
tracked: bool = False,
workers: int = 1,
) -> None:
bundle = (
ident,
func,
kwargs,
transient,
restartable,
tracked,
workers,
)
self._monitor_publisher.send(bundle)
def restart(
self,
name: str = "",

View File

@ -1,5 +1,4 @@
import os
from datetime import datetime, timezone
from multiprocessing.context import BaseContext
from signal import SIGINT
@ -20,13 +19,22 @@ class WorkerProcess:
THRESHOLD = 300 # == 30 seconds
SERVER_LABEL = "Server"
def __init__(self, factory, name, target, kwargs, worker_state):
def __init__(
self,
factory,
name,
target,
kwargs,
worker_state,
restartable: bool = False,
):
self.state = ProcessState.IDLE
self.factory = factory
self.name = name
self.target = target
self.kwargs = kwargs
self.worker_state = worker_state
self.restartable = restartable
if self.name not in self.worker_state:
self.worker_state[self.name] = {
"server": self.SERVER_LABEL in self.name
@ -132,6 +140,10 @@ class WorkerProcess:
def pid(self):
return self._current_process.pid
@property
def exitcode(self):
return self._current_process.exitcode
def _terminate_now(self):
logger.debug(
f"{Colors.BLUE}Begin restart termination: "
@ -193,6 +205,8 @@ class Worker:
context: BaseContext,
worker_state: Dict[str, Any],
num: int = 1,
restartable: bool = False,
tracked: bool = True,
):
self.ident = ident
self.num = num
@ -201,6 +215,8 @@ class Worker:
self.server_settings = server_settings
self.worker_state = worker_state
self.processes: Set[WorkerProcess] = set()
self.restartable = restartable
self.tracked = tracked
for _ in range(num):
self.create_process()
@ -215,6 +231,10 @@ class Worker:
target=self.serve,
kwargs={**self.server_settings},
worker_state=self.worker_state,
restartable=self.restartable,
)
self.processes.add(process)
return process
def has_alive_processes(self) -> bool:
return any(process.is_alive() for process in self.processes)