Add CertLoader as application option (#2722)

This commit is contained in:
Adam Hopkins 2023-03-20 14:05:21 +02:00 committed by GitHub
parent a245ab3773
commit 89188f5fc6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 43 additions and 4 deletions

View File

@ -92,6 +92,7 @@ from sanic.signals import Signal, SignalRouter
from sanic.touchup import TouchUp, TouchUpMeta
from sanic.types.shared_ctx import SharedContext
from sanic.worker.inspector import Inspector
from sanic.worker.loader import CertLoader
from sanic.worker.manager import WorkerManager
@ -139,6 +140,7 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta):
"_test_client",
"_test_manager",
"blueprints",
"certloader_class",
"config",
"configure_logging",
"ctx",
@ -181,6 +183,7 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta):
loads: Optional[Callable[..., Any]] = None,
inspector: bool = False,
inspector_class: Optional[Type[Inspector]] = None,
certloader_class: Optional[Type[CertLoader]] = None,
) -> None:
super().__init__(name=name)
# logging
@ -215,6 +218,9 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta):
self.asgi = False
self.auto_reload = False
self.blueprints: Dict[str, Blueprint] = {}
self.certloader_class: Type[CertLoader] = (
certloader_class or CertLoader
)
self.configure_logging: bool = configure_logging
self.ctx: Any = ctx or SimpleNamespace()
self.error_handler: ErrorHandler = error_handler or ErrorHandler()

View File

@ -5,6 +5,7 @@ import sys
from importlib import import_module
from pathlib import Path
from ssl import SSLContext
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, cast
from sanic.http.tls.context import process_to_context
@ -103,8 +104,16 @@ class CertLoader:
"trustme": TrustmeCreator,
}
def __init__(self, ssl_data: Dict[str, Union[str, os.PathLike]]):
def __init__(
self,
ssl_data: Optional[
Union[SSLContext, Dict[str, Union[str, os.PathLike]]]
],
):
self._ssl_data = ssl_data
self._creator_class = None
if not ssl_data or not isinstance(ssl_data, dict):
return
creator_name = cast(str, ssl_data.get("creator"))

View File

@ -73,8 +73,8 @@ def worker_serve(
info.settings["app"] = a
a.state.server_info.append(info)
if isinstance(ssl, dict):
cert_loader = CertLoader(ssl)
if isinstance(ssl, dict) or app.certloader_class is not CertLoader:
cert_loader = app.certloader_class(ssl or {})
ssl = cert_loader.load(app)
for info in app.state.server_info:
info.settings["ssl"] = ssl

View File

@ -12,7 +12,7 @@ from urllib.parse import urlparse
import pytest
from sanic_testing.testing import HOST, PORT
from sanic_testing.testing import HOST, PORT, SanicTestClient
import sanic.http.tls.creators
@ -29,6 +29,7 @@ from sanic.http.tls.creators import (
get_ssl_context,
)
from sanic.response import text
from sanic.worker.loader import CertLoader
current_dir = os.path.dirname(os.path.realpath(__file__))
@ -427,6 +428,29 @@ def test_no_certs_on_list(app):
assert "No certificates" in str(excinfo.value)
def test_custom_cert_loader():
class MyCertLoader(CertLoader):
def load(self, app: Sanic):
self._ssl_data = {
"key": localhost_key,
"cert": localhost_cert,
}
return super().load(app)
app = Sanic("custom", certloader_class=MyCertLoader)
@app.get("/test")
async def handler(request):
return text("ssl test")
client = SanicTestClient(app, port=44556)
request, response = client.get("https://localhost:44556/test")
assert request.scheme == "https"
assert response.status_code == 200
assert response.text == "ssl test"
def test_logger_vhosts(caplog):
app = Sanic(name="test_logger_vhosts")