Allow fork in limited cases (#2624)

This commit is contained in:
Adam Hopkins 2022-12-15 11:49:26 +02:00 committed by GitHub
parent 064168f3c8
commit b276b91c21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 156 additions and 55 deletions

View File

@ -8,11 +8,6 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from sanic import Sanic
try:
from sanic_ext import Extend # type: ignore
except ImportError:
...
def setup_ext(app: Sanic, *, fail: bool = False, **kwargs):
if not app.config.AUTO_EXTEND:
@ -33,7 +28,7 @@ def setup_ext(app: Sanic, *, fail: bool = False, **kwargs):
return
if not getattr(app, "_ext", None):
Ext: Extend = getattr(sanic_ext, "Extend")
Ext = getattr(sanic_ext, "Extend")
app._ext = Ext(app, **kwargs)
return app.ext

View File

@ -3,10 +3,22 @@ import os
import signal
import sys
from typing import Awaitable
from contextlib import contextmanager
from typing import Awaitable, Union
from multidict import CIMultiDict # type: ignore
from sanic.helpers import Default
if sys.version_info < (3, 8): # no cov
StartMethod = Union[Default, str]
else: # no cov
from typing import Literal
StartMethod = Union[
Default, Literal["fork"], Literal["forkserver"], Literal["spawn"]
]
OS_IS_WINDOWS = os.name == "nt"
UVLOOP_INSTALLED = False
@ -19,6 +31,16 @@ except ImportError:
pass
@contextmanager
def use_context(method: StartMethod):
from sanic import Sanic
orig = Sanic.start_method
Sanic.start_method = method
yield
Sanic.start_method = orig
def enable_windows_color_support():
import ctypes

View File

@ -40,9 +40,9 @@ from sanic.application.logo import get_logo
from sanic.application.motd import MOTD
from sanic.application.state import ApplicationServerInfo, Mode, ServerStage
from sanic.base.meta import SanicMeta
from sanic.compat import OS_IS_WINDOWS, is_atty
from sanic.compat import OS_IS_WINDOWS, StartMethod, is_atty
from sanic.exceptions import ServerKilled
from sanic.helpers import Default
from sanic.helpers import Default, _default
from sanic.http.constants import HTTP
from sanic.http.tls import get_ssl_context, process_to_context
from sanic.http.tls.context import SanicSSLContext
@ -88,6 +88,7 @@ class StartupMixin(metaclass=SanicMeta):
state: ApplicationState
websocket_enabled: bool
multiplexer: WorkerMultiplexer
start_method: StartMethod = _default
def setup_loop(self):
if not self.asgi:
@ -692,12 +693,17 @@ class StartupMixin(metaclass=SanicMeta):
return any(app.state.auto_reload for app in cls._app_registry.values())
@classmethod
def _get_context(cls) -> BaseContext:
method = (
"spawn"
if "linux" not in sys.platform or cls.should_auto_reload()
else "fork"
def _get_startup_method(cls) -> str:
return (
cls.start_method
if not isinstance(cls.start_method, Default)
else "spawn"
)
@classmethod
def _get_context(cls) -> BaseContext:
method = cls._get_startup_method()
logger.debug("Creating multiprocessing context using '%s'", method)
return get_context(method)
@classmethod

View File

@ -8,7 +8,7 @@ import uuid
from contextlib import suppress
from logging import LogRecord
from typing import List, Tuple
from typing import Any, Dict, List, Tuple
from unittest.mock import MagicMock
import pytest
@ -54,7 +54,7 @@ TYPE_TO_GENERATOR_MAP = {
"uuid": lambda: str(uuid.uuid1()),
}
CACHE = {}
CACHE: Dict[str, Any] = {}
class RouteStringGenerator:
@ -147,6 +147,7 @@ def app(request):
for target, method_name in TouchUp._registry:
CACHE[method_name] = getattr(target, method_name)
app = Sanic(slugify.sub("-", request.node.name))
yield app
for target, method_name in TouchUp._registry:
setattr(target, method_name, CACHE[method_name])

View File

@ -353,7 +353,7 @@ def test_get_app_does_not_exist():
"if __name__ == '__main__' "
"block or by using an AppLoader.\nSee "
"https://sanic.dev/en/guide/deployment/app-loader.html"
" for more details."
" for more details.",
):
Sanic.get_app("does-not-exist")

View File

@ -117,7 +117,13 @@ def test_error_with_path_as_instance_without_simple_arg(caplog):
),
)
def test_tls_options(cmd: Tuple[str, ...], caplog):
command = ["fake.server.app", *cmd, "--port=9999", "--debug"]
command = [
"fake.server.app",
*cmd,
"--port=9999",
"--debug",
"--single-process",
]
lines = capture(command, caplog)
assert "Goin' Fast @ https://127.0.0.1:9999" in lines

View File

@ -39,7 +39,7 @@ def test_logo_true(app, caplog):
with patch("sys.stdout.isatty") as isatty:
isatty.return_value = True
with caplog.at_level(logging.DEBUG):
app.make_coffee()
app.make_coffee(single_process=True)
# Only in the regular logo
assert " ▄███ █████ ██ " not in caplog.text

View File

@ -2,7 +2,6 @@ import asyncio
import sys
from threading import Event
from unittest.mock import Mock
import pytest
@ -75,7 +74,7 @@ def test_create_named_task(app):
app.stop()
app.run()
app.run(single_process=True)
def test_named_task_called(app):

View File

@ -10,8 +10,7 @@ import pytest
import sanic
from sanic import Sanic
from sanic.log import Colors
from sanic.log import LOGGING_CONFIG_DEFAULTS, logger
from sanic.log import LOGGING_CONFIG_DEFAULTS, Colors, logger
from sanic.response import text
@ -254,11 +253,11 @@ def test_verbosity(app, caplog, app_verbosity, log_verbosity, exists):
def test_colors_enum_format():
assert f'{Colors.END}' == Colors.END.value
assert f'{Colors.BOLD}' == Colors.BOLD.value
assert f'{Colors.BLUE}' == Colors.BLUE.value
assert f'{Colors.GREEN}' == Colors.GREEN.value
assert f'{Colors.PURPLE}' == Colors.PURPLE.value
assert f'{Colors.RED}' == Colors.RED.value
assert f'{Colors.SANIC}' == Colors.SANIC.value
assert f'{Colors.YELLOW}' == Colors.YELLOW.value
assert f"{Colors.END}" == Colors.END.value
assert f"{Colors.BOLD}" == Colors.BOLD.value
assert f"{Colors.BLUE}" == Colors.BLUE.value
assert f"{Colors.GREEN}" == Colors.GREEN.value
assert f"{Colors.PURPLE}" == Colors.PURPLE.value
assert f"{Colors.RED}" == Colors.RED.value
assert f"{Colors.SANIC}" == Colors.SANIC.value
assert f"{Colors.YELLOW}" == Colors.YELLOW.value

View File

@ -3,6 +3,7 @@ import multiprocessing
import pickle
import random
import signal
import sys
from asyncio import sleep
@ -11,6 +12,7 @@ import pytest
from sanic_testing.testing import HOST, PORT
from sanic import Blueprint, text
from sanic.compat import use_context
from sanic.log import logger
from sanic.server.socket import configure_socket
@ -20,6 +22,10 @@ from sanic.server.socket import configure_socket
reason="SIGALRM is not implemented for this platform, we have to come "
"up with another timeout strategy to test these",
)
@pytest.mark.skipif(
sys.platform not in ("linux", "darwin"),
reason="This test requires fork context",
)
def test_multiprocessing(app):
"""Tests that the number of children we produce is correct"""
# Selects a number at random so we can spot check
@ -37,6 +43,7 @@ def test_multiprocessing(app):
signal.signal(signal.SIGALRM, stop_on_alarm)
signal.alarm(2)
with use_context("fork"):
app.run(HOST, 4120, workers=num_workers, debug=True)
assert len(process_list) == num_workers + 1
@ -136,6 +143,10 @@ def test_multiprocessing_legacy_unix(app):
not hasattr(signal, "SIGALRM"),
reason="SIGALRM is not implemented for this platform",
)
@pytest.mark.skipif(
sys.platform not in ("linux", "darwin"),
reason="This test requires fork context",
)
def test_multiprocessing_with_blueprint(app):
# Selects a number at random so we can spot check
num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1))
@ -155,6 +166,7 @@ def test_multiprocessing_with_blueprint(app):
bp = Blueprint("test_text")
app.blueprint(bp)
with use_context("fork"):
app.run(HOST, 4121, workers=num_workers, debug=True)
assert len(process_list) == num_workers + 1
@ -213,6 +225,10 @@ def test_pickle_app_with_static(app, protocol):
up_p_app.run(single_process=True)
@pytest.mark.skipif(
sys.platform not in ("linux", "darwin"),
reason="This test requires fork context",
)
def test_main_process_event(app, caplog):
# Selects a number at random so we can spot check
num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1))
@ -235,6 +251,7 @@ def test_main_process_event(app, caplog):
def main_process_stop2(app, loop):
logger.info("main_process_stop")
with use_context("fork"):
with caplog.at_level(logging.INFO):
app.run(HOST, PORT, workers=num_workers)

View File

@ -1,8 +1,5 @@
import asyncio
from contextlib import closing
from socket import socket
import pytest
from sanic import Sanic
@ -623,6 +620,4 @@ def test_streaming_echo():
res = await read_chunk()
assert res == None
# Use random port for tests
with closing(socket()) as sock:
app.run(access_log=False)
app.run(access_log=False, single_process=True)

View File

@ -2,6 +2,7 @@ import logging
import os
import ssl
import subprocess
import sys
from contextlib import contextmanager
from multiprocessing import Event
@ -17,6 +18,7 @@ import sanic.http.tls.creators
from sanic import Sanic
from sanic.application.constants import Mode
from sanic.compat import use_context
from sanic.constants import LocalCertCreator
from sanic.exceptions import SanicException
from sanic.helpers import _default
@ -426,7 +428,12 @@ def test_logger_vhosts(caplog):
app.stop()
with caplog.at_level(logging.INFO):
app.run(host="127.0.0.1", port=42102, ssl=[localhost_dir, sanic_dir])
app.run(
host="127.0.0.1",
port=42102,
ssl=[localhost_dir, sanic_dir],
single_process=True,
)
logmsg = [
m for s, l, m in caplog.record_tuples if m.startswith("Certificate")
@ -642,6 +649,10 @@ def test_sanic_ssl_context_create():
assert isinstance(sanic_context, SanicSSLContext)
@pytest.mark.skipif(
sys.platform not in ("linux", "darwin"),
reason="This test requires fork context",
)
def test_ssl_in_multiprocess_mode(app: Sanic, caplog):
ssl_dict = {"cert": localhost_cert, "key": localhost_key}
@ -657,6 +668,7 @@ def test_ssl_in_multiprocess_mode(app: Sanic, caplog):
app.stop()
assert not event.is_set()
with use_context("fork"):
with caplog.at_level(logging.INFO):
app.run(ssl=ssl_dict)
assert event.is_set()

View File

@ -1,6 +1,7 @@
# import asyncio
import logging
import os
import sys
from asyncio import AbstractEventLoop, sleep
from string import ascii_lowercase
@ -12,6 +13,7 @@ import pytest
from pytest import LogCaptureFixture
from sanic import Sanic
from sanic.compat import use_context
from sanic.request import Request
from sanic.response import text
@ -174,7 +176,9 @@ def handler(request: Request):
async def client(app: Sanic, loop: AbstractEventLoop):
try:
async with httpx.AsyncClient(uds=SOCKPATH) as client:
transport = httpx.AsyncHTTPTransport(uds=SOCKPATH)
async with httpx.AsyncClient(transport=transport) as client:
r = await client.get("http://myhost.invalid/")
assert r.status_code == 200
assert r.text == os.path.abspath(SOCKPATH)
@ -183,7 +187,12 @@ async def client(app: Sanic, loop: AbstractEventLoop):
app.stop()
@pytest.mark.skipif(
sys.platform not in ("linux", "darwin"),
reason="This test requires fork context",
)
def test_unix_connection_multiple_workers():
with use_context("fork"):
app_multi = Sanic(name="test")
app_multi.get("/")(handler)
app_multi.listener("after_server_start")(client)

View File

@ -1,13 +1,20 @@
from logging import ERROR, INFO
from signal import SIGINT, SIGKILL
from signal import SIGINT
from unittest.mock import Mock, call, patch
import pytest
from sanic.compat import OS_IS_WINDOWS
from sanic.exceptions import ServerKilled
from sanic.worker.manager import WorkerManager
if not OS_IS_WINDOWS:
from signal import SIGKILL
else:
SIGKILL = SIGINT
def fake_serve():
...

View File

@ -1,3 +1,5 @@
import sys
from multiprocessing import Event
from os import environ, getpid
from typing import Any, Dict, Type, Union
@ -6,6 +8,7 @@ from unittest.mock import Mock
import pytest
from sanic import Sanic
from sanic.compat import use_context
from sanic.worker.multiplexer import WorkerMultiplexer
from sanic.worker.state import WorkerState
@ -28,6 +31,10 @@ def m(monitor_publisher, worker_state):
del environ["SANIC_WORKER_NAME"]
@pytest.mark.skipif(
sys.platform not in ("linux", "darwin"),
reason="This test requires fork context",
)
def test_has_multiplexer_default(app: Sanic):
event = Event()
@ -41,6 +48,7 @@ def test_has_multiplexer_default(app: Sanic):
app.shared_ctx.event.set()
app.stop()
with use_context("fork"):
app.run()
assert event.is_set()

View File

@ -0,0 +1,25 @@
from unittest.mock import patch
import pytest
from sanic import Sanic
@pytest.mark.parametrize(
"start_method,platform,expected",
(
(None, "linux", "spawn"),
(None, "other", "spawn"),
("fork", "linux", "fork"),
("fork", "other", "fork"),
("forkserver", "linux", "forkserver"),
("forkserver", "other", "forkserver"),
("spawn", "linux", "spawn"),
("spawn", "other", "spawn"),
),
)
def test_get_context(start_method, platform, expected):
if start_method:
Sanic.start_method = start_method
with patch("sys.platform", platform):
assert Sanic._get_startup_method() == expected