Scale workers (#2617)

Adam Hopkins 2022-12-13 09:28:23 +02:00 committed by GitHub
parent 13e9ab7ba9
commit db39e127bf
9 changed files with 238 additions and 103 deletions

View File

@@ -98,13 +98,29 @@ Or, a path to a directory to run as a simple HTTP server:
         except ValueError as e:
             error_logger.exception(f"Failed to run app: {e}")
         else:
-            if self.args.inspect or self.args.inspect_raw or self.args.trigger:
+            if (
+                self.args.inspect
+                or self.args.inspect_raw
+                or self.args.trigger
+                or self.args.scale is not None
+            ):
                 os.environ["SANIC_IGNORE_PRODUCTION_WARNING"] = "true"
             else:
                 for http_version in self.args.http:
                     app.prepare(**kwargs, version=http_version)

-            if self.args.inspect or self.args.inspect_raw or self.args.trigger:
-                action = self.args.trigger or (
-                    "raw" if self.args.inspect_raw else "pretty"
-                )
+            if (
+                self.args.inspect
+                or self.args.inspect_raw
+                or self.args.trigger
+                or self.args.scale is not None
+            ):
+                if self.args.scale is not None:
+                    if self.args.scale <= 0:
+                        error_logger.error("There must be at least 1 worker")
+                        sys.exit(1)
+                    action = f"scale={self.args.scale}"
+                else:
+                    action = self.args.trigger or (
+                        "raw" if self.args.inspect_raw else "pretty"
+                    )
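
For context, the CLI branch above only turns --scale into an action string that is handed to a running instance's inspector. A minimal, self-contained sketch of the same decision logic (the helper and argument names here are illustrative, not Sanic's API):

import sys
from typing import Optional


def build_inspector_action(
    inspect: bool = False,
    inspect_raw: bool = False,
    trigger: Optional[str] = None,
    scale: Optional[int] = None,
) -> Optional[str]:
    # --scale takes precedence; otherwise fall back to the trigger or
    # the raw/pretty inspect display mode, mirroring the diff above.
    if scale is not None:
        if scale <= 0:
            print("There must be at least 1 worker", file=sys.stderr)
            sys.exit(1)
        return f"scale={scale}"
    if inspect or inspect_raw or trigger:
        return trigger or ("raw" if inspect_raw else "pretty")
    return None


# build_inspector_action(scale=4) -> "scale=4"
# build_inspector_action(trigger="shutdown") -> "shutdown"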

View File

@@ -115,6 +115,12 @@ class ApplicationGroup(Group):
             const="shutdown",
             help=("Trigger all processes to shutdown"),
         )
+        group.add_argument(
+            "--scale",
+            dest="scale",
+            type=int,
+            help=("Scale number of workers"),
+        )


 class HTTPVersionGroup(Group):
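
For reference, the new flag behaves like any other argparse integer option; a standalone sketch with a hypothetical parser (not the ApplicationGroup machinery above):

from argparse import ArgumentParser

parser = ArgumentParser(prog="sanic")
parser.add_argument(
    "--scale",
    dest="scale",
    type=int,
    help="Scale number of workers",
)

args = parser.parse_args(["--scale", "3"])
assert args.scale == 3  # coerced to int by type=int
assert parser.parse_args([]).scale is None  # omitted flag stays None, so "is not None" checks work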

View File

@@ -55,17 +55,20 @@ class Inspector:
                 else:
                     action = conn.recv(64)
                     if action == b"reload":
-                        conn.send(b"\n")
                         self.reload()
                     elif action == b"shutdown":
-                        conn.send(b"\n")
                         self.shutdown()
+                    elif action.startswith(b"scale"):
+                        num_workers = int(action.split(b"=", 1)[-1])
+                        logger.info("Scaling to %s", num_workers)
+                        self.scale(num_workers)
                     else:
                         data = dumps(self.state_to_json())
                         conn.send(data.encode())
+                    conn.send(b"\n")
                     conn.close()
         finally:
-            logger.debug("Inspector closing")
+            logger.info("Inspector closing")
             sock.close()

     def stop(self, *_):
@@ -80,6 +83,10 @@ class Inspector:
         message = "__ALL_PROCESSES__:"
         self._publisher.send(message)

+    def scale(self, num_workers: int):
+        message = f"__SCALE__:{num_workers}"
+        self._publisher.send(message)
+
     def shutdown(self):
         message = "__TERMINATE__"
         self._publisher.send(message)
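
The inspector protocol visible above is a plain TCP exchange: the server reads at most 64 bytes, branches on the payload, and acknowledges with a trailing newline. A minimal client sketch (host and port are assumptions; in Sanic they come from INSPECTOR_HOST and INSPECTOR_PORT):

import socket

INSPECTOR_ADDR = ("localhost", 6457)  # assumed defaults

with socket.create_connection(INSPECTOR_ADDR) as sock:
    # "scale=5" fits well within the 64-byte read on the server side
    sock.sendall(b"scale=5")
    # every action is answered with b"\n"; unknown actions get the
    # JSON-encoded worker state first, then the newline
    print(sock.recv(1024))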

View File

@@ -1,8 +1,10 @@
 import os
+from itertools import count
+from random import choice
 from signal import SIGINT, SIGTERM, Signals
 from signal import signal as signal_func
-from typing import List, Optional
+from typing import Dict, List, Optional

 from sanic.compat import OS_IS_WINDOWS
 from sanic.exceptions import ServerKilled
@@ -30,33 +32,61 @@ class WorkerManager:
     ):
         self.num_server = number
         self.context = context
-        self.transient: List[Worker] = []
-        self.durable: List[Worker] = []
+        self.transient: Dict[str, Worker] = {}
+        self.durable: Dict[str, Worker] = {}
         self.monitor_publisher, self.monitor_subscriber = monitor_pubsub
         self.worker_state = worker_state
         self.worker_state["Sanic-Main"] = {"pid": self.pid}
         self.terminated = False
+        self._serve = serve
+        self._server_settings = server_settings
+        self._server_count = count()

         if number == 0:
             raise RuntimeError("Cannot serve with no workers")

-        for i in range(number):
-            self.manage(
-                f"{WorkerProcess.SERVER_LABEL}-{i}",
-                serve,
-                server_settings,
-                transient=True,
-            )
+        for _ in range(number):
+            self.create_server()

         signal_func(SIGINT, self.shutdown_signal)
         signal_func(SIGTERM, self.shutdown_signal)

-    def manage(self, ident, func, kwargs, transient=False):
+    def manage(self, ident, func, kwargs, transient=False) -> Worker:
         container = self.transient if transient else self.durable
-        container.append(
-            Worker(ident, func, kwargs, self.context, self.worker_state)
-        )
+        worker = Worker(ident, func, kwargs, self.context, self.worker_state)
+        container[worker.ident] = worker
+        return worker
+
+    def create_server(self) -> Worker:
+        server_number = next(self._server_count)
+        return self.manage(
+            f"{WorkerProcess.SERVER_LABEL}-{server_number}",
+            self._serve,
+            self._server_settings,
+            transient=True,
+        )
+
+    def shutdown_server(self, ident: Optional[str] = None) -> None:
+        if not ident:
+            servers = [
+                worker
+                for worker in self.transient.values()
+                if worker.ident.startswith(WorkerProcess.SERVER_LABEL)
+            ]
+            if not servers:
+                error_logger.error(
+                    "Server shutdown failed because a server was not found."
+                )
+                return
+            worker = choice(servers)  # nosec B311
+        else:
+            worker = self.transient[ident]
+
+        for process in worker.processes:
+            process.terminate()
+
+        del self.transient[worker.ident]

     def run(self):
         self.start()
         self.monitor()
@@ -94,6 +124,28 @@ class WorkerManager:
             if not process_names or process.name in process_names:
                 process.restart(**kwargs)

+    def scale(self, num_worker: int):
+        if num_worker <= 0:
+            raise ValueError("Cannot scale to 0 workers.")
+
+        change = num_worker - self.num_server
+        if change == 0:
+            logger.info(
+                f"No change needed. There are already {num_worker} workers."
+            )
+            return
+
+        logger.info(f"Scaling from {self.num_server} to {num_worker} workers")
+        if change > 0:
+            for _ in range(change):
+                worker = self.create_server()
+                for process in worker.processes:
+                    process.start()
+        else:
+            for _ in range(abs(change)):
+                self.shutdown_server()
+
+        self.num_server = num_worker
+
     def monitor(self):
         self.wait_for_ack()
         while True:
@@ -109,6 +161,9 @@ class WorkerManager:
                     self.shutdown()
                     break
                 split_message = message.split(":", 1)
+                if message.startswith("__SCALE__"):
+                    self.scale(int(split_message[-1]))
+                    continue
                 processes = split_message[0]
                 reloaded_files = (
                     split_message[1] if len(split_message) > 1 else None
@@ -161,8 +216,8 @@ class WorkerManager:
             self.kill()

     @property
-    def workers(self):
-        return self.transient + self.durable
+    def workers(self) -> List[Worker]:
+        return list(self.transient.values()) + list(self.durable.values())

     @property
     def processes(self):
@@ -172,7 +227,7 @@ class WorkerManager:
     @property
     def transient_processes(self):
-        for worker in self.transient:
+        for worker in self.transient.values():
             for process in worker.processes:
                 yield process
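
The scale() method above boils down to arithmetic on the worker count: spawn the difference when growing, shut down one transient server per step when shrinking, and do nothing when the target already matches. A standalone model of that decision, for illustration only:

from typing import List


def plan_scale(current: int, target: int) -> List[str]:
    """Return the operations a manager would perform to reach ``target``."""
    if target <= 0:
        raise ValueError("Cannot scale to 0 workers.")
    change = target - current
    if change == 0:
        return []
    if change > 0:
        return ["create_server"] * change  # spawn and start new workers
    return ["shutdown_server"] * abs(change)  # terminate one transient server per step


assert plan_scale(1, 3) == ["create_server", "create_server"]
assert plan_scale(3, 1) == ["shutdown_server", "shutdown_server"]
assert plan_scale(2, 2) == []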

View File

@@ -33,6 +33,10 @@ class WorkerMultiplexer:

     reload = restart  # no cov

+    def scale(self, num_workers: int):
+        message = f"__SCALE__:{num_workers}"
+        self._monitor_publisher.send(message)
+
     def terminate(self, early: bool = False):
         message = "__TERMINATE_EARLY__" if early else "__TERMINATE__"
         self._monitor_publisher.send(message)
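
Because the multiplexer publishes the same "__SCALE__:<n>" message the manager listens for, a worker process can request scaling at runtime. A hedged usage sketch, assuming the multiplexer is exposed on the application as app.m (as it is for restart/terminate in this Sanic generation):

from sanic import Sanic, Request
from sanic.response import json

app = Sanic("MyApp")


@app.post("/scale/<num:int>")
async def scale_handler(request: Request, num: int):
    # Publishes "__SCALE__:<num>"; the main process's WorkerManager
    # picks it up in monitor() and adjusts the worker count.
    request.app.m.scale(num)
    return json({"requested_workers": num})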

View File

@@ -133,6 +133,8 @@ class WorkerProcess:


 class Worker:
+    WORKER_PREFIX = "Sanic-"
+
     def __init__(
         self,
         ident: str,
@@ -152,7 +154,7 @@ class Worker:
     def create_process(self) -> WorkerProcess:
         process = WorkerProcess(
             factory=self.context.Process,
-            name=f"Sanic-{self.ident}-{len(self.processes)}",
+            name=f"{self.WORKER_PREFIX}{self.ident}-{len(self.processes)}",
             target=self.serve,
             kwargs={**self.server_settings},
             worker_state=self.worker_state,
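
The new WORKER_PREFIX constant only changes where the "Sanic-" prefix lives; worker idents such as "Server-1" stay prefix-free (which is how the manager tests below address them), while the spawned process names carry the prefix. For illustration:

WORKER_PREFIX = "Sanic-"
SERVER_LABEL = "Server"  # WorkerProcess.SERVER_LABEL in the diff above

ident = f"{SERVER_LABEL}-1"         # second server worker
name = f"{WORKER_PREFIX}{ident}-0"  # first process spawned for it
assert name == "Sanic-Server-1-0"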

View File

@@ -74,7 +74,9 @@ def test_send_inspect_conn_refused(socket: Mock, sys: Mock, caplog):


 @patch("sanic.worker.inspector.configure_socket")
-@pytest.mark.parametrize("action", (b"reload", b"shutdown", b"foo"))
+@pytest.mark.parametrize(
+    "action", (b"reload", b"shutdown", b"scale=5", b"foo")
+)
 def test_run_inspector(configure_socket: Mock, action: bytes):
     sock = Mock()
     conn = Mock()
@@ -83,6 +85,7 @@ def test_run_inspector(configure_socket: Mock, action: bytes):
     inspector = Inspector(Mock(), {}, {}, "localhost", 9999)
     inspector.reload = Mock()  # type: ignore
     inspector.shutdown = Mock()  # type: ignore
+    inspector.scale = Mock()  # type: ignore
     inspector.state_to_json = Mock(return_value="foo")  # type: ignore

     def accept():
@@ -98,20 +101,26 @@ def test_run_inspector(configure_socket: Mock, action: bytes):
     )

     conn.recv.assert_called_with(64)
+    conn.send.assert_called_with(b"\n")

     if action == b"reload":
-        conn.send.assert_called_with(b"\n")
         inspector.reload.assert_called()
         inspector.shutdown.assert_not_called()
+        inspector.scale.assert_not_called()
         inspector.state_to_json.assert_not_called()
     elif action == b"shutdown":
-        conn.send.assert_called_with(b"\n")
         inspector.reload.assert_not_called()
         inspector.shutdown.assert_called()
+        inspector.scale.assert_not_called()
         inspector.state_to_json.assert_not_called()
+    elif action.startswith(b"scale"):
+        inspector.reload.assert_not_called()
+        inspector.shutdown.assert_not_called()
+        inspector.scale.assert_called_once_with(5)
+        inspector.state_to_json.assert_not_called()
     else:
-        conn.send.assert_called_with(b'"foo"')
         inspector.reload.assert_not_called()
         inspector.shutdown.assert_not_called()
+        inspector.scale.assert_not_called()
         inspector.state_to_json.assert_called()
@@ -165,3 +174,11 @@ def test_shutdown():
     inspector.shutdown()

     publisher.send.assert_called_once_with("__TERMINATE__")
+
+
+def test_scale():
+    publisher = Mock()
+    inspector = Inspector(publisher, {}, {}, "", 0)
+
+    inspector.scale(3)
+    publisher.send.assert_called_once_with("__SCALE__:3")

View File

@@ -1,3 +1,4 @@
+from logging import ERROR, INFO
 from signal import SIGINT, SIGKILL
 from unittest.mock import Mock, call, patch
@@ -14,14 +15,7 @@ def fake_serve():
 def test_manager_no_workers():
     message = "Cannot serve with no workers"
     with pytest.raises(RuntimeError, match=message):
-        WorkerManager(
-            0,
-            fake_serve,
-            {},
-            Mock(),
-            (Mock(), Mock()),
-            {},
-        )
+        WorkerManager(0, fake_serve, {}, Mock(), (Mock(), Mock()), {})


 @patch("sanic.worker.process.os")
@@ -30,14 +24,7 @@ def test_terminate(os_mock: Mock):
     process.pid = 1234
     context = Mock()
     context.Process.return_value = process
-    manager = WorkerManager(
-        1,
-        fake_serve,
-        {},
-        context,
-        (Mock(), Mock()),
-        {},
-    )
+    manager = WorkerManager(1, fake_serve, {}, context, (Mock(), Mock()), {})
     assert manager.terminated is False
     manager.terminate()
     assert manager.terminated is True
@@ -51,14 +38,7 @@ def test_shutown(os_mock: Mock):
     process.is_alive.return_value = True
     context = Mock()
     context.Process.return_value = process
-    manager = WorkerManager(
-        1,
-        fake_serve,
-        {},
-        context,
-        (Mock(), Mock()),
-        {},
-    )
+    manager = WorkerManager(1, fake_serve, {}, context, (Mock(), Mock()), {})
     manager.shutdown()
     os_mock.kill.assert_called_once_with(1234, SIGINT)
@@ -69,14 +49,7 @@ def test_kill(os_mock: Mock):
     process.pid = 1234
     context = Mock()
     context.Process.return_value = process
-    manager = WorkerManager(
-        1,
-        fake_serve,
-        {},
-        context,
-        (Mock(), Mock()),
-        {},
-    )
+    manager = WorkerManager(1, fake_serve, {}, context, (Mock(), Mock()), {})
     with pytest.raises(ServerKilled):
         manager.kill()
     os_mock.kill.assert_called_once_with(1234, SIGKILL)
@@ -87,14 +60,7 @@ def test_restart_all():
     p2 = Mock()
     context = Mock()
     context.Process.side_effect = [p1, p2, p1, p2]
-    manager = WorkerManager(
-        2,
-        fake_serve,
-        {},
-        context,
-        (Mock(), Mock()),
-        {},
-    )
+    manager = WorkerManager(2, fake_serve, {}, context, (Mock(), Mock()), {})
     assert len(list(manager.transient_processes))
     manager.restart()
     p1.terminate.assert_called_once()
@@ -136,14 +102,7 @@ def test_monitor_all():
     sub.recv.side_effect = ["__ALL_PROCESSES__:", ""]
     context = Mock()
     context.Process.side_effect = [p1, p2]
-    manager = WorkerManager(
-        2,
-        fake_serve,
-        {},
-        context,
-        (Mock(), sub),
-        {},
-    )
+    manager = WorkerManager(2, fake_serve, {}, context, (Mock(), sub), {})
     manager.restart = Mock()  # type: ignore
     manager.wait_for_ack = Mock()  # type: ignore
     manager.monitor()
@@ -160,14 +119,7 @@ def test_monitor_all_with_files():
     sub.recv.side_effect = ["__ALL_PROCESSES__:foo,bar", ""]
     context = Mock()
     context.Process.side_effect = [p1, p2]
-    manager = WorkerManager(
-        2,
-        fake_serve,
-        {},
-        context,
-        (Mock(), sub),
-        {},
-    )
+    manager = WorkerManager(2, fake_serve, {}, context, (Mock(), sub), {})
     manager.restart = Mock()  # type: ignore
     manager.wait_for_ack = Mock()  # type: ignore
     manager.monitor()
@@ -185,14 +137,7 @@ def test_monitor_one_process():
     sub.recv.side_effect = [f"{p1.name}:foo,bar", ""]
     context = Mock()
     context.Process.side_effect = [p1, p2]
-    manager = WorkerManager(
-        2,
-        fake_serve,
-        {},
-        context,
-        (Mock(), sub),
-        {},
-    )
+    manager = WorkerManager(2, fake_serve, {}, context, (Mock(), sub), {})
     manager.restart = Mock()  # type: ignore
     manager.wait_for_ack = Mock()  # type: ignore
     manager.monitor()
@@ -204,16 +149,94 @@
 def test_shutdown_signal():
     pub = Mock()
-    manager = WorkerManager(
-        1,
-        fake_serve,
-        {},
-        Mock(),
-        (pub, Mock()),
-        {},
-    )
+    manager = WorkerManager(1, fake_serve, {}, Mock(), (pub, Mock()), {})
     manager.shutdown = Mock()  # type: ignore

     manager.shutdown_signal(SIGINT, None)

     pub.send.assert_called_with(None)
     manager.shutdown.assert_called_once_with()
+
+
+def test_shutdown_servers(caplog):
+    p1 = Mock()
+    p1.pid = 1234
+    context = Mock()
+    context.Process.side_effect = [p1]
+    pub = Mock()
+    manager = WorkerManager(1, fake_serve, {}, context, (pub, Mock()), {})
+
+    with patch("os.kill") as kill:
+        with caplog.at_level(ERROR):
+            manager.shutdown_server()
+
+            kill.assert_called_once_with(1234, SIGINT)
+            kill.reset_mock()
+
+            assert not caplog.record_tuples
+
+            manager.shutdown_server()
+
+            kill.assert_not_called()
+
+            assert (
+                "sanic.error",
+                ERROR,
+                "Server shutdown failed because a server was not found.",
+            ) in caplog.record_tuples
+
+
+def test_shutdown_servers_named():
+    p1 = Mock()
+    p1.pid = 1234
+    p2 = Mock()
+    p2.pid = 6543
+    context = Mock()
+    context.Process.side_effect = [p1, p2]
+    pub = Mock()
+    manager = WorkerManager(2, fake_serve, {}, context, (pub, Mock()), {})
+
+    with patch("os.kill") as kill:
+        with pytest.raises(KeyError):
+            manager.shutdown_server("foo")
+        manager.shutdown_server("Server-1")
+
+        kill.assert_called_once_with(6543, SIGINT)
+
+
+def test_scale(caplog):
+    p1 = Mock()
+    p1.pid = 1234
+    p2 = Mock()
+    p2.pid = 3456
+    p3 = Mock()
+    p3.pid = 5678
+    context = Mock()
+    context.Process.side_effect = [p1, p2, p3]
+    pub = Mock()
+    manager = WorkerManager(1, fake_serve, {}, context, (pub, Mock()), {})
+
+    assert len(manager.transient) == 1
+
+    manager.scale(3)
+    assert len(manager.transient) == 3
+
+    with patch("os.kill") as kill:
+        manager.scale(2)
+        assert len(manager.transient) == 2
+
+        manager.scale(1)
+        assert len(manager.transient) == 1
+
+        kill.call_count == 2
+
+    with caplog.at_level(INFO):
+        manager.scale(1)
+        assert (
+            "sanic.root",
+            INFO,
+            "No change needed. There are already 1 workers.",
+        ) in caplog.record_tuples
+
+    with pytest.raises(ValueError, match=r"Cannot scale to 0 workers\."):
+        manager.scale(0)

View File

@@ -108,6 +108,11 @@ def test_terminate(monitor_publisher: Mock, m: WorkerMultiplexer):
     monitor_publisher.send.assert_called_once_with("__TERMINATE__")


+def test_scale(monitor_publisher: Mock, m: WorkerMultiplexer):
+    m.scale(99)
+
+    monitor_publisher.send.assert_called_once_with("__SCALE__:99")
+
+
 def test_properties(
     monitor_publisher: Mock, worker_state: Dict[str, Any], m: WorkerMultiplexer
 ):