Add GenericCreator for loading SSL certs in processes (#2578)
This commit is contained in:
parent
3f4663b9f8
commit
d70636ba2e
|
@ -5,18 +5,10 @@ import sys
|
|||
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, cast
|
||||
|
||||
from sanic.http.tls.creators import CertCreator, MkcertCreator, TrustmeCreator
|
||||
from sanic.http.tls.context import process_to_context
|
||||
from sanic.http.tls.creators import MkcertCreator, TrustmeCreator
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -106,21 +98,30 @@ class AppLoader:
|
|||
|
||||
|
||||
class CertLoader:
|
||||
_creator_class: Type[CertCreator]
|
||||
_creators = {
|
||||
"mkcert": MkcertCreator,
|
||||
"trustme": TrustmeCreator,
|
||||
}
|
||||
|
||||
def __init__(self, ssl_data: Dict[str, Union[str, os.PathLike]]):
|
||||
creator_name = ssl_data.get("creator")
|
||||
if creator_name not in ("mkcert", "trustme"):
|
||||
self._ssl_data = ssl_data
|
||||
|
||||
creator_name = cast(str, ssl_data.get("creator"))
|
||||
|
||||
self._creator_class = self._creators.get(creator_name)
|
||||
if not creator_name:
|
||||
return
|
||||
|
||||
if not self._creator_class:
|
||||
raise RuntimeError(f"Unknown certificate creator: {creator_name}")
|
||||
elif creator_name == "mkcert":
|
||||
self._creator_class = MkcertCreator
|
||||
elif creator_name == "trustme":
|
||||
self._creator_class = TrustmeCreator
|
||||
|
||||
self._key = ssl_data["key"]
|
||||
self._cert = ssl_data["cert"]
|
||||
self._localhost = cast(str, ssl_data["localhost"])
|
||||
|
||||
def load(self, app: SanicApp):
|
||||
if not self._creator_class:
|
||||
return process_to_context(self._ssl_data)
|
||||
|
||||
creator = self._creator_class(app, self._key, self._cert)
|
||||
return creator.generate_cert(self._localhost)
|
||||
|
|
|
@ -4,6 +4,7 @@ import ssl
|
|||
import subprocess
|
||||
|
||||
from contextlib import contextmanager
|
||||
from multiprocessing import Event
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
from urllib.parse import urlparse
|
||||
|
@ -636,3 +637,29 @@ def test_sanic_ssl_context_create():
|
|||
|
||||
assert sanic_context is context
|
||||
assert isinstance(sanic_context, SanicSSLContext)
|
||||
|
||||
|
||||
def test_ssl_in_multiprocess_mode(app: Sanic, caplog):
|
||||
|
||||
ssl_dict = {"cert": localhost_cert, "key": localhost_key}
|
||||
event = Event()
|
||||
|
||||
@app.main_process_start
|
||||
async def main_start(app: Sanic):
|
||||
app.shared_ctx.event = event
|
||||
|
||||
@app.after_server_start
|
||||
async def shutdown(app):
|
||||
app.shared_ctx.event.set()
|
||||
app.stop()
|
||||
|
||||
assert not event.is_set()
|
||||
with caplog.at_level(logging.INFO):
|
||||
app.run(ssl=ssl_dict)
|
||||
assert event.is_set()
|
||||
|
||||
assert (
|
||||
"sanic.root",
|
||||
logging.INFO,
|
||||
"Goin' Fast @ https://127.0.0.1:8000",
|
||||
) in caplog.record_tuples
|
||||
|
|
|
@ -86,6 +86,10 @@ def test_input_is_module():
|
|||
@patch("sanic.worker.loader.TrustmeCreator")
|
||||
@patch("sanic.worker.loader.MkcertCreator")
|
||||
def test_cert_loader(MkcertCreator: Mock, TrustmeCreator: Mock, creator: str):
|
||||
CertLoader._creators = {
|
||||
"mkcert": MkcertCreator,
|
||||
"trustme": TrustmeCreator,
|
||||
}
|
||||
MkcertCreator.return_value = MkcertCreator
|
||||
TrustmeCreator.return_value = TrustmeCreator
|
||||
data = {
|
||||
|
|
Loading…
Reference in New Issue
Block a user