Allow fork in limited cases (#2624)
This commit is contained in:
parent
064168f3c8
commit
b276b91c21
|
@ -8,11 +8,6 @@ from typing import TYPE_CHECKING
|
|||
if TYPE_CHECKING:
|
||||
from sanic import Sanic
|
||||
|
||||
try:
|
||||
from sanic_ext import Extend # type: ignore
|
||||
except ImportError:
|
||||
...
|
||||
|
||||
|
||||
def setup_ext(app: Sanic, *, fail: bool = False, **kwargs):
|
||||
if not app.config.AUTO_EXTEND:
|
||||
|
@ -33,7 +28,7 @@ def setup_ext(app: Sanic, *, fail: bool = False, **kwargs):
|
|||
return
|
||||
|
||||
if not getattr(app, "_ext", None):
|
||||
Ext: Extend = getattr(sanic_ext, "Extend")
|
||||
Ext = getattr(sanic_ext, "Extend")
|
||||
app._ext = Ext(app, **kwargs)
|
||||
|
||||
return app.ext
|
||||
|
|
|
@ -3,10 +3,22 @@ import os
|
|||
import signal
|
||||
import sys
|
||||
|
||||
from typing import Awaitable
|
||||
from contextlib import contextmanager
|
||||
from typing import Awaitable, Union
|
||||
|
||||
from multidict import CIMultiDict # type: ignore
|
||||
|
||||
from sanic.helpers import Default
|
||||
|
||||
|
||||
if sys.version_info < (3, 8): # no cov
|
||||
StartMethod = Union[Default, str]
|
||||
else: # no cov
|
||||
from typing import Literal
|
||||
|
||||
StartMethod = Union[
|
||||
Default, Literal["fork"], Literal["forkserver"], Literal["spawn"]
|
||||
]
|
||||
|
||||
OS_IS_WINDOWS = os.name == "nt"
|
||||
UVLOOP_INSTALLED = False
|
||||
|
@ -19,6 +31,16 @@ except ImportError:
|
|||
pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_context(method: StartMethod):
|
||||
from sanic import Sanic
|
||||
|
||||
orig = Sanic.start_method
|
||||
Sanic.start_method = method
|
||||
yield
|
||||
Sanic.start_method = orig
|
||||
|
||||
|
||||
def enable_windows_color_support():
|
||||
import ctypes
|
||||
|
||||
|
|
|
@ -40,9 +40,9 @@ from sanic.application.logo import get_logo
|
|||
from sanic.application.motd import MOTD
|
||||
from sanic.application.state import ApplicationServerInfo, Mode, ServerStage
|
||||
from sanic.base.meta import SanicMeta
|
||||
from sanic.compat import OS_IS_WINDOWS, is_atty
|
||||
from sanic.compat import OS_IS_WINDOWS, StartMethod, is_atty
|
||||
from sanic.exceptions import ServerKilled
|
||||
from sanic.helpers import Default
|
||||
from sanic.helpers import Default, _default
|
||||
from sanic.http.constants import HTTP
|
||||
from sanic.http.tls import get_ssl_context, process_to_context
|
||||
from sanic.http.tls.context import SanicSSLContext
|
||||
|
@ -88,6 +88,7 @@ class StartupMixin(metaclass=SanicMeta):
|
|||
state: ApplicationState
|
||||
websocket_enabled: bool
|
||||
multiplexer: WorkerMultiplexer
|
||||
start_method: StartMethod = _default
|
||||
|
||||
def setup_loop(self):
|
||||
if not self.asgi:
|
||||
|
@ -692,12 +693,17 @@ class StartupMixin(metaclass=SanicMeta):
|
|||
return any(app.state.auto_reload for app in cls._app_registry.values())
|
||||
|
||||
@classmethod
|
||||
def _get_context(cls) -> BaseContext:
|
||||
method = (
|
||||
"spawn"
|
||||
if "linux" not in sys.platform or cls.should_auto_reload()
|
||||
else "fork"
|
||||
def _get_startup_method(cls) -> str:
|
||||
return (
|
||||
cls.start_method
|
||||
if not isinstance(cls.start_method, Default)
|
||||
else "spawn"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_context(cls) -> BaseContext:
|
||||
method = cls._get_startup_method()
|
||||
logger.debug("Creating multiprocessing context using '%s'", method)
|
||||
return get_context(method)
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -8,7 +8,7 @@ import uuid
|
|||
|
||||
from contextlib import suppress
|
||||
from logging import LogRecord
|
||||
from typing import List, Tuple
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
@ -54,7 +54,7 @@ TYPE_TO_GENERATOR_MAP = {
|
|||
"uuid": lambda: str(uuid.uuid1()),
|
||||
}
|
||||
|
||||
CACHE = {}
|
||||
CACHE: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class RouteStringGenerator:
|
||||
|
@ -147,6 +147,7 @@ def app(request):
|
|||
for target, method_name in TouchUp._registry:
|
||||
CACHE[method_name] = getattr(target, method_name)
|
||||
app = Sanic(slugify.sub("-", request.node.name))
|
||||
|
||||
yield app
|
||||
for target, method_name in TouchUp._registry:
|
||||
setattr(target, method_name, CACHE[method_name])
|
||||
|
|
|
@ -353,7 +353,7 @@ def test_get_app_does_not_exist():
|
|||
"if __name__ == '__main__' "
|
||||
"block or by using an AppLoader.\nSee "
|
||||
"https://sanic.dev/en/guide/deployment/app-loader.html"
|
||||
" for more details."
|
||||
" for more details.",
|
||||
):
|
||||
Sanic.get_app("does-not-exist")
|
||||
|
||||
|
|
|
@ -117,7 +117,13 @@ def test_error_with_path_as_instance_without_simple_arg(caplog):
|
|||
),
|
||||
)
|
||||
def test_tls_options(cmd: Tuple[str, ...], caplog):
|
||||
command = ["fake.server.app", *cmd, "--port=9999", "--debug"]
|
||||
command = [
|
||||
"fake.server.app",
|
||||
*cmd,
|
||||
"--port=9999",
|
||||
"--debug",
|
||||
"--single-process",
|
||||
]
|
||||
lines = capture(command, caplog)
|
||||
assert "Goin' Fast @ https://127.0.0.1:9999" in lines
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@ def test_logo_true(app, caplog):
|
|||
with patch("sys.stdout.isatty") as isatty:
|
||||
isatty.return_value = True
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
app.make_coffee()
|
||||
app.make_coffee(single_process=True)
|
||||
|
||||
# Only in the regular logo
|
||||
assert " ▄███ █████ ██ " not in caplog.text
|
||||
|
|
|
@ -2,7 +2,6 @@ import asyncio
|
|||
import sys
|
||||
|
||||
from threading import Event
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -75,7 +74,7 @@ def test_create_named_task(app):
|
|||
|
||||
app.stop()
|
||||
|
||||
app.run()
|
||||
app.run(single_process=True)
|
||||
|
||||
|
||||
def test_named_task_called(app):
|
||||
|
|
|
@ -10,8 +10,7 @@ import pytest
|
|||
import sanic
|
||||
|
||||
from sanic import Sanic
|
||||
from sanic.log import Colors
|
||||
from sanic.log import LOGGING_CONFIG_DEFAULTS, logger
|
||||
from sanic.log import LOGGING_CONFIG_DEFAULTS, Colors, logger
|
||||
from sanic.response import text
|
||||
|
||||
|
||||
|
@ -254,11 +253,11 @@ def test_verbosity(app, caplog, app_verbosity, log_verbosity, exists):
|
|||
|
||||
|
||||
def test_colors_enum_format():
|
||||
assert f'{Colors.END}' == Colors.END.value
|
||||
assert f'{Colors.BOLD}' == Colors.BOLD.value
|
||||
assert f'{Colors.BLUE}' == Colors.BLUE.value
|
||||
assert f'{Colors.GREEN}' == Colors.GREEN.value
|
||||
assert f'{Colors.PURPLE}' == Colors.PURPLE.value
|
||||
assert f'{Colors.RED}' == Colors.RED.value
|
||||
assert f'{Colors.SANIC}' == Colors.SANIC.value
|
||||
assert f'{Colors.YELLOW}' == Colors.YELLOW.value
|
||||
assert f"{Colors.END}" == Colors.END.value
|
||||
assert f"{Colors.BOLD}" == Colors.BOLD.value
|
||||
assert f"{Colors.BLUE}" == Colors.BLUE.value
|
||||
assert f"{Colors.GREEN}" == Colors.GREEN.value
|
||||
assert f"{Colors.PURPLE}" == Colors.PURPLE.value
|
||||
assert f"{Colors.RED}" == Colors.RED.value
|
||||
assert f"{Colors.SANIC}" == Colors.SANIC.value
|
||||
assert f"{Colors.YELLOW}" == Colors.YELLOW.value
|
||||
|
|
|
@ -3,6 +3,7 @@ import multiprocessing
|
|||
import pickle
|
||||
import random
|
||||
import signal
|
||||
import sys
|
||||
|
||||
from asyncio import sleep
|
||||
|
||||
|
@ -11,6 +12,7 @@ import pytest
|
|||
from sanic_testing.testing import HOST, PORT
|
||||
|
||||
from sanic import Blueprint, text
|
||||
from sanic.compat import use_context
|
||||
from sanic.log import logger
|
||||
from sanic.server.socket import configure_socket
|
||||
|
||||
|
@ -20,6 +22,10 @@ from sanic.server.socket import configure_socket
|
|||
reason="SIGALRM is not implemented for this platform, we have to come "
|
||||
"up with another timeout strategy to test these",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
sys.platform not in ("linux", "darwin"),
|
||||
reason="This test requires fork context",
|
||||
)
|
||||
def test_multiprocessing(app):
|
||||
"""Tests that the number of children we produce is correct"""
|
||||
# Selects a number at random so we can spot check
|
||||
|
@ -37,6 +43,7 @@ def test_multiprocessing(app):
|
|||
|
||||
signal.signal(signal.SIGALRM, stop_on_alarm)
|
||||
signal.alarm(2)
|
||||
with use_context("fork"):
|
||||
app.run(HOST, 4120, workers=num_workers, debug=True)
|
||||
|
||||
assert len(process_list) == num_workers + 1
|
||||
|
@ -136,6 +143,10 @@ def test_multiprocessing_legacy_unix(app):
|
|||
not hasattr(signal, "SIGALRM"),
|
||||
reason="SIGALRM is not implemented for this platform",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
sys.platform not in ("linux", "darwin"),
|
||||
reason="This test requires fork context",
|
||||
)
|
||||
def test_multiprocessing_with_blueprint(app):
|
||||
# Selects a number at random so we can spot check
|
||||
num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1))
|
||||
|
@ -155,6 +166,7 @@ def test_multiprocessing_with_blueprint(app):
|
|||
|
||||
bp = Blueprint("test_text")
|
||||
app.blueprint(bp)
|
||||
with use_context("fork"):
|
||||
app.run(HOST, 4121, workers=num_workers, debug=True)
|
||||
|
||||
assert len(process_list) == num_workers + 1
|
||||
|
@ -213,6 +225,10 @@ def test_pickle_app_with_static(app, protocol):
|
|||
up_p_app.run(single_process=True)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform not in ("linux", "darwin"),
|
||||
reason="This test requires fork context",
|
||||
)
|
||||
def test_main_process_event(app, caplog):
|
||||
# Selects a number at random so we can spot check
|
||||
num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1))
|
||||
|
@ -235,6 +251,7 @@ def test_main_process_event(app, caplog):
|
|||
def main_process_stop2(app, loop):
|
||||
logger.info("main_process_stop")
|
||||
|
||||
with use_context("fork"):
|
||||
with caplog.at_level(logging.INFO):
|
||||
app.run(HOST, PORT, workers=num_workers)
|
||||
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
import asyncio
|
||||
|
||||
from contextlib import closing
|
||||
from socket import socket
|
||||
|
||||
import pytest
|
||||
|
||||
from sanic import Sanic
|
||||
|
@ -623,6 +620,4 @@ def test_streaming_echo():
|
|||
res = await read_chunk()
|
||||
assert res == None
|
||||
|
||||
# Use random port for tests
|
||||
with closing(socket()) as sock:
|
||||
app.run(access_log=False)
|
||||
app.run(access_log=False, single_process=True)
|
||||
|
|
|
@ -2,6 +2,7 @@ import logging
|
|||
import os
|
||||
import ssl
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from contextlib import contextmanager
|
||||
from multiprocessing import Event
|
||||
|
@ -17,6 +18,7 @@ import sanic.http.tls.creators
|
|||
|
||||
from sanic import Sanic
|
||||
from sanic.application.constants import Mode
|
||||
from sanic.compat import use_context
|
||||
from sanic.constants import LocalCertCreator
|
||||
from sanic.exceptions import SanicException
|
||||
from sanic.helpers import _default
|
||||
|
@ -426,7 +428,12 @@ def test_logger_vhosts(caplog):
|
|||
app.stop()
|
||||
|
||||
with caplog.at_level(logging.INFO):
|
||||
app.run(host="127.0.0.1", port=42102, ssl=[localhost_dir, sanic_dir])
|
||||
app.run(
|
||||
host="127.0.0.1",
|
||||
port=42102,
|
||||
ssl=[localhost_dir, sanic_dir],
|
||||
single_process=True,
|
||||
)
|
||||
|
||||
logmsg = [
|
||||
m for s, l, m in caplog.record_tuples if m.startswith("Certificate")
|
||||
|
@ -642,6 +649,10 @@ def test_sanic_ssl_context_create():
|
|||
assert isinstance(sanic_context, SanicSSLContext)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform not in ("linux", "darwin"),
|
||||
reason="This test requires fork context",
|
||||
)
|
||||
def test_ssl_in_multiprocess_mode(app: Sanic, caplog):
|
||||
|
||||
ssl_dict = {"cert": localhost_cert, "key": localhost_key}
|
||||
|
@ -657,6 +668,7 @@ def test_ssl_in_multiprocess_mode(app: Sanic, caplog):
|
|||
app.stop()
|
||||
|
||||
assert not event.is_set()
|
||||
with use_context("fork"):
|
||||
with caplog.at_level(logging.INFO):
|
||||
app.run(ssl=ssl_dict)
|
||||
assert event.is_set()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
from asyncio import AbstractEventLoop, sleep
|
||||
from string import ascii_lowercase
|
||||
|
@ -12,6 +13,7 @@ import pytest
|
|||
from pytest import LogCaptureFixture
|
||||
|
||||
from sanic import Sanic
|
||||
from sanic.compat import use_context
|
||||
from sanic.request import Request
|
||||
from sanic.response import text
|
||||
|
||||
|
@ -174,7 +176,9 @@ def handler(request: Request):
|
|||
|
||||
async def client(app: Sanic, loop: AbstractEventLoop):
|
||||
try:
|
||||
async with httpx.AsyncClient(uds=SOCKPATH) as client:
|
||||
|
||||
transport = httpx.AsyncHTTPTransport(uds=SOCKPATH)
|
||||
async with httpx.AsyncClient(transport=transport) as client:
|
||||
r = await client.get("http://myhost.invalid/")
|
||||
assert r.status_code == 200
|
||||
assert r.text == os.path.abspath(SOCKPATH)
|
||||
|
@ -183,7 +187,12 @@ async def client(app: Sanic, loop: AbstractEventLoop):
|
|||
app.stop()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform not in ("linux", "darwin"),
|
||||
reason="This test requires fork context",
|
||||
)
|
||||
def test_unix_connection_multiple_workers():
|
||||
with use_context("fork"):
|
||||
app_multi = Sanic(name="test")
|
||||
app_multi.get("/")(handler)
|
||||
app_multi.listener("after_server_start")(client)
|
||||
|
|
|
@ -1,13 +1,20 @@
|
|||
from logging import ERROR, INFO
|
||||
from signal import SIGINT, SIGKILL
|
||||
from signal import SIGINT
|
||||
from unittest.mock import Mock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from sanic.compat import OS_IS_WINDOWS
|
||||
from sanic.exceptions import ServerKilled
|
||||
from sanic.worker.manager import WorkerManager
|
||||
|
||||
|
||||
if not OS_IS_WINDOWS:
|
||||
from signal import SIGKILL
|
||||
else:
|
||||
SIGKILL = SIGINT
|
||||
|
||||
|
||||
def fake_serve():
|
||||
...
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import sys
|
||||
|
||||
from multiprocessing import Event
|
||||
from os import environ, getpid
|
||||
from typing import Any, Dict, Type, Union
|
||||
|
@ -6,6 +8,7 @@ from unittest.mock import Mock
|
|||
import pytest
|
||||
|
||||
from sanic import Sanic
|
||||
from sanic.compat import use_context
|
||||
from sanic.worker.multiplexer import WorkerMultiplexer
|
||||
from sanic.worker.state import WorkerState
|
||||
|
||||
|
@ -28,6 +31,10 @@ def m(monitor_publisher, worker_state):
|
|||
del environ["SANIC_WORKER_NAME"]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform not in ("linux", "darwin"),
|
||||
reason="This test requires fork context",
|
||||
)
|
||||
def test_has_multiplexer_default(app: Sanic):
|
||||
event = Event()
|
||||
|
||||
|
@ -41,6 +48,7 @@ def test_has_multiplexer_default(app: Sanic):
|
|||
app.shared_ctx.event.set()
|
||||
app.stop()
|
||||
|
||||
with use_context("fork"):
|
||||
app.run()
|
||||
|
||||
assert event.is_set()
|
||||
|
|
25
tests/worker/test_startup.py
Normal file
25
tests/worker/test_startup.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from sanic import Sanic
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"start_method,platform,expected",
|
||||
(
|
||||
(None, "linux", "spawn"),
|
||||
(None, "other", "spawn"),
|
||||
("fork", "linux", "fork"),
|
||||
("fork", "other", "fork"),
|
||||
("forkserver", "linux", "forkserver"),
|
||||
("forkserver", "other", "forkserver"),
|
||||
("spawn", "linux", "spawn"),
|
||||
("spawn", "other", "spawn"),
|
||||
),
|
||||
)
|
||||
def test_get_context(start_method, platform, expected):
|
||||
if start_method:
|
||||
Sanic.start_method = start_method
|
||||
with patch("sys.platform", platform):
|
||||
assert Sanic._get_startup_method() == expected
|
Loading…
Reference in New Issue
Block a user