diff --git a/sanic/app.py b/sanic/app.py index 9fb7e027..9efc53ef 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -92,6 +92,7 @@ from sanic.signals import Signal, SignalRouter from sanic.touchup import TouchUp, TouchUpMeta from sanic.types.shared_ctx import SharedContext from sanic.worker.inspector import Inspector +from sanic.worker.loader import CertLoader from sanic.worker.manager import WorkerManager @@ -139,6 +140,7 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta): "_test_client", "_test_manager", "blueprints", + "certloader_class", "config", "configure_logging", "ctx", @@ -181,6 +183,7 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta): loads: Optional[Callable[..., Any]] = None, inspector: bool = False, inspector_class: Optional[Type[Inspector]] = None, + certloader_class: Optional[Type[CertLoader]] = None, ) -> None: super().__init__(name=name) # logging @@ -215,6 +218,9 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta): self.asgi = False self.auto_reload = False self.blueprints: Dict[str, Blueprint] = {} + self.certloader_class: Type[CertLoader] = ( + certloader_class or CertLoader + ) self.configure_logging: bool = configure_logging self.ctx: Any = ctx or SimpleNamespace() self.error_handler: ErrorHandler = error_handler or ErrorHandler() diff --git a/sanic/worker/loader.py b/sanic/worker/loader.py index 344593db..d29f4c68 100644 --- a/sanic/worker/loader.py +++ b/sanic/worker/loader.py @@ -5,6 +5,7 @@ import sys from importlib import import_module from pathlib import Path +from ssl import SSLContext from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, cast from sanic.http.tls.context import process_to_context @@ -103,8 +104,16 @@ class CertLoader: "trustme": TrustmeCreator, } - def __init__(self, ssl_data: Dict[str, Union[str, os.PathLike]]): + def __init__( + self, + ssl_data: Optional[ + Union[SSLContext, Dict[str, Union[str, os.PathLike]]] + ], + ): self._ssl_data = ssl_data + self._creator_class = None + if not ssl_data or not isinstance(ssl_data, dict): + return creator_name = cast(str, ssl_data.get("creator")) diff --git a/sanic/worker/serve.py b/sanic/worker/serve.py index 39c647b2..583d3eaf 100644 --- a/sanic/worker/serve.py +++ b/sanic/worker/serve.py @@ -73,8 +73,8 @@ def worker_serve( info.settings["app"] = a a.state.server_info.append(info) - if isinstance(ssl, dict): - cert_loader = CertLoader(ssl) + if isinstance(ssl, dict) or app.certloader_class is not CertLoader: + cert_loader = app.certloader_class(ssl or {}) ssl = cert_loader.load(app) for info in app.state.server_info: info.settings["ssl"] = ssl diff --git a/tests/test_tls.py b/tests/test_tls.py index d41784f0..6d2cb981 100644 --- a/tests/test_tls.py +++ b/tests/test_tls.py @@ -12,7 +12,7 @@ from urllib.parse import urlparse import pytest -from sanic_testing.testing import HOST, PORT +from sanic_testing.testing import HOST, PORT, SanicTestClient import sanic.http.tls.creators @@ -29,6 +29,7 @@ from sanic.http.tls.creators import ( get_ssl_context, ) from sanic.response import text +from sanic.worker.loader import CertLoader current_dir = os.path.dirname(os.path.realpath(__file__)) @@ -427,6 +428,29 @@ def test_no_certs_on_list(app): assert "No certificates" in str(excinfo.value) +def test_custom_cert_loader(): + class MyCertLoader(CertLoader): + def load(self, app: Sanic): + self._ssl_data = { + "key": localhost_key, + "cert": localhost_cert, + } + return super().load(app) + + app = Sanic("custom", certloader_class=MyCertLoader) + + @app.get("/test") + async def handler(request): + return text("ssl test") + + client = SanicTestClient(app, port=44556) + + request, response = client.get("https://localhost:44556/test") + assert request.scheme == "https" + assert response.status_code == 200 + assert response.text == "ssl test" + + def test_logger_vhosts(caplog): app = Sanic(name="test_logger_vhosts")