Allow fork in limited cases (#2624)

This commit is contained in:
Adam Hopkins 2022-12-15 11:49:26 +02:00 committed by GitHub
parent 064168f3c8
commit b276b91c21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 156 additions and 55 deletions

View File

@ -8,11 +8,6 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from sanic import Sanic from sanic import Sanic
try:
from sanic_ext import Extend # type: ignore
except ImportError:
...
def setup_ext(app: Sanic, *, fail: bool = False, **kwargs): def setup_ext(app: Sanic, *, fail: bool = False, **kwargs):
if not app.config.AUTO_EXTEND: if not app.config.AUTO_EXTEND:
@ -33,7 +28,7 @@ def setup_ext(app: Sanic, *, fail: bool = False, **kwargs):
return return
if not getattr(app, "_ext", None): if not getattr(app, "_ext", None):
Ext: Extend = getattr(sanic_ext, "Extend") Ext = getattr(sanic_ext, "Extend")
app._ext = Ext(app, **kwargs) app._ext = Ext(app, **kwargs)
return app.ext return app.ext

View File

@ -3,10 +3,22 @@ import os
import signal import signal
import sys import sys
from typing import Awaitable from contextlib import contextmanager
from typing import Awaitable, Union
from multidict import CIMultiDict # type: ignore from multidict import CIMultiDict # type: ignore
from sanic.helpers import Default
if sys.version_info < (3, 8): # no cov
StartMethod = Union[Default, str]
else: # no cov
from typing import Literal
StartMethod = Union[
Default, Literal["fork"], Literal["forkserver"], Literal["spawn"]
]
OS_IS_WINDOWS = os.name == "nt" OS_IS_WINDOWS = os.name == "nt"
UVLOOP_INSTALLED = False UVLOOP_INSTALLED = False
@ -19,6 +31,16 @@ except ImportError:
pass pass
@contextmanager
def use_context(method: StartMethod):
from sanic import Sanic
orig = Sanic.start_method
Sanic.start_method = method
yield
Sanic.start_method = orig
def enable_windows_color_support(): def enable_windows_color_support():
import ctypes import ctypes

View File

@ -40,9 +40,9 @@ from sanic.application.logo import get_logo
from sanic.application.motd import MOTD from sanic.application.motd import MOTD
from sanic.application.state import ApplicationServerInfo, Mode, ServerStage from sanic.application.state import ApplicationServerInfo, Mode, ServerStage
from sanic.base.meta import SanicMeta from sanic.base.meta import SanicMeta
from sanic.compat import OS_IS_WINDOWS, is_atty from sanic.compat import OS_IS_WINDOWS, StartMethod, is_atty
from sanic.exceptions import ServerKilled from sanic.exceptions import ServerKilled
from sanic.helpers import Default from sanic.helpers import Default, _default
from sanic.http.constants import HTTP from sanic.http.constants import HTTP
from sanic.http.tls import get_ssl_context, process_to_context from sanic.http.tls import get_ssl_context, process_to_context
from sanic.http.tls.context import SanicSSLContext from sanic.http.tls.context import SanicSSLContext
@ -88,6 +88,7 @@ class StartupMixin(metaclass=SanicMeta):
state: ApplicationState state: ApplicationState
websocket_enabled: bool websocket_enabled: bool
multiplexer: WorkerMultiplexer multiplexer: WorkerMultiplexer
start_method: StartMethod = _default
def setup_loop(self): def setup_loop(self):
if not self.asgi: if not self.asgi:
@ -692,12 +693,17 @@ class StartupMixin(metaclass=SanicMeta):
return any(app.state.auto_reload for app in cls._app_registry.values()) return any(app.state.auto_reload for app in cls._app_registry.values())
@classmethod @classmethod
def _get_context(cls) -> BaseContext: def _get_startup_method(cls) -> str:
method = ( return (
"spawn" cls.start_method
if "linux" not in sys.platform or cls.should_auto_reload() if not isinstance(cls.start_method, Default)
else "fork" else "spawn"
) )
@classmethod
def _get_context(cls) -> BaseContext:
method = cls._get_startup_method()
logger.debug("Creating multiprocessing context using '%s'", method)
return get_context(method) return get_context(method)
@classmethod @classmethod

View File

@ -8,7 +8,7 @@ import uuid
from contextlib import suppress from contextlib import suppress
from logging import LogRecord from logging import LogRecord
from typing import List, Tuple from typing import Any, Dict, List, Tuple
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
@ -54,7 +54,7 @@ TYPE_TO_GENERATOR_MAP = {
"uuid": lambda: str(uuid.uuid1()), "uuid": lambda: str(uuid.uuid1()),
} }
CACHE = {} CACHE: Dict[str, Any] = {}
class RouteStringGenerator: class RouteStringGenerator:
@ -147,6 +147,7 @@ def app(request):
for target, method_name in TouchUp._registry: for target, method_name in TouchUp._registry:
CACHE[method_name] = getattr(target, method_name) CACHE[method_name] = getattr(target, method_name)
app = Sanic(slugify.sub("-", request.node.name)) app = Sanic(slugify.sub("-", request.node.name))
yield app yield app
for target, method_name in TouchUp._registry: for target, method_name in TouchUp._registry:
setattr(target, method_name, CACHE[method_name]) setattr(target, method_name, CACHE[method_name])

View File

@ -349,11 +349,11 @@ def test_get_app_does_not_exist():
with pytest.raises( with pytest.raises(
SanicException, SanicException,
match="Sanic app name 'does-not-exist' not found.\n" match="Sanic app name 'does-not-exist' not found.\n"
"App instantiation must occur outside " "App instantiation must occur outside "
"if __name__ == '__main__' " "if __name__ == '__main__' "
"block or by using an AppLoader.\nSee " "block or by using an AppLoader.\nSee "
"https://sanic.dev/en/guide/deployment/app-loader.html" "https://sanic.dev/en/guide/deployment/app-loader.html"
" for more details." " for more details.",
): ):
Sanic.get_app("does-not-exist") Sanic.get_app("does-not-exist")

View File

@ -117,7 +117,13 @@ def test_error_with_path_as_instance_without_simple_arg(caplog):
), ),
) )
def test_tls_options(cmd: Tuple[str, ...], caplog): def test_tls_options(cmd: Tuple[str, ...], caplog):
command = ["fake.server.app", *cmd, "--port=9999", "--debug"] command = [
"fake.server.app",
*cmd,
"--port=9999",
"--debug",
"--single-process",
]
lines = capture(command, caplog) lines = capture(command, caplog)
assert "Goin' Fast @ https://127.0.0.1:9999" in lines assert "Goin' Fast @ https://127.0.0.1:9999" in lines

View File

@ -39,7 +39,7 @@ def test_logo_true(app, caplog):
with patch("sys.stdout.isatty") as isatty: with patch("sys.stdout.isatty") as isatty:
isatty.return_value = True isatty.return_value = True
with caplog.at_level(logging.DEBUG): with caplog.at_level(logging.DEBUG):
app.make_coffee() app.make_coffee(single_process=True)
# Only in the regular logo # Only in the regular logo
assert " ▄███ █████ ██ " not in caplog.text assert " ▄███ █████ ██ " not in caplog.text

View File

@ -2,7 +2,6 @@ import asyncio
import sys import sys
from threading import Event from threading import Event
from unittest.mock import Mock
import pytest import pytest
@ -75,7 +74,7 @@ def test_create_named_task(app):
app.stop() app.stop()
app.run() app.run(single_process=True)
def test_named_task_called(app): def test_named_task_called(app):

View File

@ -10,8 +10,7 @@ import pytest
import sanic import sanic
from sanic import Sanic from sanic import Sanic
from sanic.log import Colors from sanic.log import LOGGING_CONFIG_DEFAULTS, Colors, logger
from sanic.log import LOGGING_CONFIG_DEFAULTS, logger
from sanic.response import text from sanic.response import text
@ -254,11 +253,11 @@ def test_verbosity(app, caplog, app_verbosity, log_verbosity, exists):
def test_colors_enum_format(): def test_colors_enum_format():
assert f'{Colors.END}' == Colors.END.value assert f"{Colors.END}" == Colors.END.value
assert f'{Colors.BOLD}' == Colors.BOLD.value assert f"{Colors.BOLD}" == Colors.BOLD.value
assert f'{Colors.BLUE}' == Colors.BLUE.value assert f"{Colors.BLUE}" == Colors.BLUE.value
assert f'{Colors.GREEN}' == Colors.GREEN.value assert f"{Colors.GREEN}" == Colors.GREEN.value
assert f'{Colors.PURPLE}' == Colors.PURPLE.value assert f"{Colors.PURPLE}" == Colors.PURPLE.value
assert f'{Colors.RED}' == Colors.RED.value assert f"{Colors.RED}" == Colors.RED.value
assert f'{Colors.SANIC}' == Colors.SANIC.value assert f"{Colors.SANIC}" == Colors.SANIC.value
assert f'{Colors.YELLOW}' == Colors.YELLOW.value assert f"{Colors.YELLOW}" == Colors.YELLOW.value

View File

@ -3,6 +3,7 @@ import multiprocessing
import pickle import pickle
import random import random
import signal import signal
import sys
from asyncio import sleep from asyncio import sleep
@ -11,6 +12,7 @@ import pytest
from sanic_testing.testing import HOST, PORT from sanic_testing.testing import HOST, PORT
from sanic import Blueprint, text from sanic import Blueprint, text
from sanic.compat import use_context
from sanic.log import logger from sanic.log import logger
from sanic.server.socket import configure_socket from sanic.server.socket import configure_socket
@ -20,6 +22,10 @@ from sanic.server.socket import configure_socket
reason="SIGALRM is not implemented for this platform, we have to come " reason="SIGALRM is not implemented for this platform, we have to come "
"up with another timeout strategy to test these", "up with another timeout strategy to test these",
) )
@pytest.mark.skipif(
sys.platform not in ("linux", "darwin"),
reason="This test requires fork context",
)
def test_multiprocessing(app): def test_multiprocessing(app):
"""Tests that the number of children we produce is correct""" """Tests that the number of children we produce is correct"""
# Selects a number at random so we can spot check # Selects a number at random so we can spot check
@ -37,7 +43,8 @@ def test_multiprocessing(app):
signal.signal(signal.SIGALRM, stop_on_alarm) signal.signal(signal.SIGALRM, stop_on_alarm)
signal.alarm(2) signal.alarm(2)
app.run(HOST, 4120, workers=num_workers, debug=True) with use_context("fork"):
app.run(HOST, 4120, workers=num_workers, debug=True)
assert len(process_list) == num_workers + 1 assert len(process_list) == num_workers + 1
@ -136,6 +143,10 @@ def test_multiprocessing_legacy_unix(app):
not hasattr(signal, "SIGALRM"), not hasattr(signal, "SIGALRM"),
reason="SIGALRM is not implemented for this platform", reason="SIGALRM is not implemented for this platform",
) )
@pytest.mark.skipif(
sys.platform not in ("linux", "darwin"),
reason="This test requires fork context",
)
def test_multiprocessing_with_blueprint(app): def test_multiprocessing_with_blueprint(app):
# Selects a number at random so we can spot check # Selects a number at random so we can spot check
num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1)) num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1))
@ -155,7 +166,8 @@ def test_multiprocessing_with_blueprint(app):
bp = Blueprint("test_text") bp = Blueprint("test_text")
app.blueprint(bp) app.blueprint(bp)
app.run(HOST, 4121, workers=num_workers, debug=True) with use_context("fork"):
app.run(HOST, 4121, workers=num_workers, debug=True)
assert len(process_list) == num_workers + 1 assert len(process_list) == num_workers + 1
@ -213,6 +225,10 @@ def test_pickle_app_with_static(app, protocol):
up_p_app.run(single_process=True) up_p_app.run(single_process=True)
@pytest.mark.skipif(
sys.platform not in ("linux", "darwin"),
reason="This test requires fork context",
)
def test_main_process_event(app, caplog): def test_main_process_event(app, caplog):
# Selects a number at random so we can spot check # Selects a number at random so we can spot check
num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1)) num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1))
@ -235,8 +251,9 @@ def test_main_process_event(app, caplog):
def main_process_stop2(app, loop): def main_process_stop2(app, loop):
logger.info("main_process_stop") logger.info("main_process_stop")
with caplog.at_level(logging.INFO): with use_context("fork"):
app.run(HOST, PORT, workers=num_workers) with caplog.at_level(logging.INFO):
app.run(HOST, PORT, workers=num_workers)
assert ( assert (
caplog.record_tuples.count(("sanic.root", 20, "main_process_start")) caplog.record_tuples.count(("sanic.root", 20, "main_process_start"))

View File

@ -1,8 +1,5 @@
import asyncio import asyncio
from contextlib import closing
from socket import socket
import pytest import pytest
from sanic import Sanic from sanic import Sanic
@ -623,6 +620,4 @@ def test_streaming_echo():
res = await read_chunk() res = await read_chunk()
assert res == None assert res == None
# Use random port for tests app.run(access_log=False, single_process=True)
with closing(socket()) as sock:
app.run(access_log=False)

View File

@ -2,6 +2,7 @@ import logging
import os import os
import ssl import ssl
import subprocess import subprocess
import sys
from contextlib import contextmanager from contextlib import contextmanager
from multiprocessing import Event from multiprocessing import Event
@ -17,6 +18,7 @@ import sanic.http.tls.creators
from sanic import Sanic from sanic import Sanic
from sanic.application.constants import Mode from sanic.application.constants import Mode
from sanic.compat import use_context
from sanic.constants import LocalCertCreator from sanic.constants import LocalCertCreator
from sanic.exceptions import SanicException from sanic.exceptions import SanicException
from sanic.helpers import _default from sanic.helpers import _default
@ -426,7 +428,12 @@ def test_logger_vhosts(caplog):
app.stop() app.stop()
with caplog.at_level(logging.INFO): with caplog.at_level(logging.INFO):
app.run(host="127.0.0.1", port=42102, ssl=[localhost_dir, sanic_dir]) app.run(
host="127.0.0.1",
port=42102,
ssl=[localhost_dir, sanic_dir],
single_process=True,
)
logmsg = [ logmsg = [
m for s, l, m in caplog.record_tuples if m.startswith("Certificate") m for s, l, m in caplog.record_tuples if m.startswith("Certificate")
@ -642,6 +649,10 @@ def test_sanic_ssl_context_create():
assert isinstance(sanic_context, SanicSSLContext) assert isinstance(sanic_context, SanicSSLContext)
@pytest.mark.skipif(
sys.platform not in ("linux", "darwin"),
reason="This test requires fork context",
)
def test_ssl_in_multiprocess_mode(app: Sanic, caplog): def test_ssl_in_multiprocess_mode(app: Sanic, caplog):
ssl_dict = {"cert": localhost_cert, "key": localhost_key} ssl_dict = {"cert": localhost_cert, "key": localhost_key}
@ -657,8 +668,9 @@ def test_ssl_in_multiprocess_mode(app: Sanic, caplog):
app.stop() app.stop()
assert not event.is_set() assert not event.is_set()
with caplog.at_level(logging.INFO): with use_context("fork"):
app.run(ssl=ssl_dict) with caplog.at_level(logging.INFO):
app.run(ssl=ssl_dict)
assert event.is_set() assert event.is_set()
assert ( assert (

View File

@ -1,6 +1,7 @@
# import asyncio # import asyncio
import logging import logging
import os import os
import sys
from asyncio import AbstractEventLoop, sleep from asyncio import AbstractEventLoop, sleep
from string import ascii_lowercase from string import ascii_lowercase
@ -12,6 +13,7 @@ import pytest
from pytest import LogCaptureFixture from pytest import LogCaptureFixture
from sanic import Sanic from sanic import Sanic
from sanic.compat import use_context
from sanic.request import Request from sanic.request import Request
from sanic.response import text from sanic.response import text
@ -174,7 +176,9 @@ def handler(request: Request):
async def client(app: Sanic, loop: AbstractEventLoop): async def client(app: Sanic, loop: AbstractEventLoop):
try: try:
async with httpx.AsyncClient(uds=SOCKPATH) as client:
transport = httpx.AsyncHTTPTransport(uds=SOCKPATH)
async with httpx.AsyncClient(transport=transport) as client:
r = await client.get("http://myhost.invalid/") r = await client.get("http://myhost.invalid/")
assert r.status_code == 200 assert r.status_code == 200
assert r.text == os.path.abspath(SOCKPATH) assert r.text == os.path.abspath(SOCKPATH)
@ -183,11 +187,16 @@ async def client(app: Sanic, loop: AbstractEventLoop):
app.stop() app.stop()
@pytest.mark.skipif(
sys.platform not in ("linux", "darwin"),
reason="This test requires fork context",
)
def test_unix_connection_multiple_workers(): def test_unix_connection_multiple_workers():
app_multi = Sanic(name="test") with use_context("fork"):
app_multi.get("/")(handler) app_multi = Sanic(name="test")
app_multi.listener("after_server_start")(client) app_multi.get("/")(handler)
app_multi.run(host="myhost.invalid", unix=SOCKPATH, workers=2) app_multi.listener("after_server_start")(client)
app_multi.run(host="myhost.invalid", unix=SOCKPATH, workers=2)
# @pytest.mark.xfail( # @pytest.mark.xfail(

View File

@ -1,13 +1,20 @@
from logging import ERROR, INFO from logging import ERROR, INFO
from signal import SIGINT, SIGKILL from signal import SIGINT
from unittest.mock import Mock, call, patch from unittest.mock import Mock, call, patch
import pytest import pytest
from sanic.compat import OS_IS_WINDOWS
from sanic.exceptions import ServerKilled from sanic.exceptions import ServerKilled
from sanic.worker.manager import WorkerManager from sanic.worker.manager import WorkerManager
if not OS_IS_WINDOWS:
from signal import SIGKILL
else:
SIGKILL = SIGINT
def fake_serve(): def fake_serve():
... ...

View File

@ -1,3 +1,5 @@
import sys
from multiprocessing import Event from multiprocessing import Event
from os import environ, getpid from os import environ, getpid
from typing import Any, Dict, Type, Union from typing import Any, Dict, Type, Union
@ -6,6 +8,7 @@ from unittest.mock import Mock
import pytest import pytest
from sanic import Sanic from sanic import Sanic
from sanic.compat import use_context
from sanic.worker.multiplexer import WorkerMultiplexer from sanic.worker.multiplexer import WorkerMultiplexer
from sanic.worker.state import WorkerState from sanic.worker.state import WorkerState
@ -28,6 +31,10 @@ def m(monitor_publisher, worker_state):
del environ["SANIC_WORKER_NAME"] del environ["SANIC_WORKER_NAME"]
@pytest.mark.skipif(
sys.platform not in ("linux", "darwin"),
reason="This test requires fork context",
)
def test_has_multiplexer_default(app: Sanic): def test_has_multiplexer_default(app: Sanic):
event = Event() event = Event()
@ -41,7 +48,8 @@ def test_has_multiplexer_default(app: Sanic):
app.shared_ctx.event.set() app.shared_ctx.event.set()
app.stop() app.stop()
app.run() with use_context("fork"):
app.run()
assert event.is_set() assert event.is_set()

View File

@ -0,0 +1,25 @@
from unittest.mock import patch
import pytest
from sanic import Sanic
@pytest.mark.parametrize(
"start_method,platform,expected",
(
(None, "linux", "spawn"),
(None, "other", "spawn"),
("fork", "linux", "fork"),
("fork", "other", "fork"),
("forkserver", "linux", "forkserver"),
("forkserver", "other", "forkserver"),
("spawn", "linux", "spawn"),
("spawn", "other", "spawn"),
),
)
def test_get_context(start_method, platform, expected):
if start_method:
Sanic.start_method = start_method
with patch("sys.platform", platform):
assert Sanic._get_startup_method() == expected