Add CertLoader as application option (#2722)
This commit is contained in:
parent
a245ab3773
commit
89188f5fc6
|
@ -92,6 +92,7 @@ from sanic.signals import Signal, SignalRouter
|
|||
from sanic.touchup import TouchUp, TouchUpMeta
|
||||
from sanic.types.shared_ctx import SharedContext
|
||||
from sanic.worker.inspector import Inspector
|
||||
from sanic.worker.loader import CertLoader
|
||||
from sanic.worker.manager import WorkerManager
|
||||
|
||||
|
||||
|
@ -139,6 +140,7 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta):
|
|||
"_test_client",
|
||||
"_test_manager",
|
||||
"blueprints",
|
||||
"certloader_class",
|
||||
"config",
|
||||
"configure_logging",
|
||||
"ctx",
|
||||
|
@ -181,6 +183,7 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta):
|
|||
loads: Optional[Callable[..., Any]] = None,
|
||||
inspector: bool = False,
|
||||
inspector_class: Optional[Type[Inspector]] = None,
|
||||
certloader_class: Optional[Type[CertLoader]] = None,
|
||||
) -> None:
|
||||
super().__init__(name=name)
|
||||
# logging
|
||||
|
@ -215,6 +218,9 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta):
|
|||
self.asgi = False
|
||||
self.auto_reload = False
|
||||
self.blueprints: Dict[str, Blueprint] = {}
|
||||
self.certloader_class: Type[CertLoader] = (
|
||||
certloader_class or CertLoader
|
||||
)
|
||||
self.configure_logging: bool = configure_logging
|
||||
self.ctx: Any = ctx or SimpleNamespace()
|
||||
self.error_handler: ErrorHandler = error_handler or ErrorHandler()
|
||||
|
|
|
@ -5,6 +5,7 @@ import sys
|
|||
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from ssl import SSLContext
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, cast
|
||||
|
||||
from sanic.http.tls.context import process_to_context
|
||||
|
@ -103,8 +104,16 @@ class CertLoader:
|
|||
"trustme": TrustmeCreator,
|
||||
}
|
||||
|
||||
def __init__(self, ssl_data: Dict[str, Union[str, os.PathLike]]):
|
||||
def __init__(
|
||||
self,
|
||||
ssl_data: Optional[
|
||||
Union[SSLContext, Dict[str, Union[str, os.PathLike]]]
|
||||
],
|
||||
):
|
||||
self._ssl_data = ssl_data
|
||||
self._creator_class = None
|
||||
if not ssl_data or not isinstance(ssl_data, dict):
|
||||
return
|
||||
|
||||
creator_name = cast(str, ssl_data.get("creator"))
|
||||
|
||||
|
|
|
@ -73,8 +73,8 @@ def worker_serve(
|
|||
info.settings["app"] = a
|
||||
a.state.server_info.append(info)
|
||||
|
||||
if isinstance(ssl, dict):
|
||||
cert_loader = CertLoader(ssl)
|
||||
if isinstance(ssl, dict) or app.certloader_class is not CertLoader:
|
||||
cert_loader = app.certloader_class(ssl or {})
|
||||
ssl = cert_loader.load(app)
|
||||
for info in app.state.server_info:
|
||||
info.settings["ssl"] = ssl
|
||||
|
|
|
@ -12,7 +12,7 @@ from urllib.parse import urlparse
|
|||
|
||||
import pytest
|
||||
|
||||
from sanic_testing.testing import HOST, PORT
|
||||
from sanic_testing.testing import HOST, PORT, SanicTestClient
|
||||
|
||||
import sanic.http.tls.creators
|
||||
|
||||
|
@ -29,6 +29,7 @@ from sanic.http.tls.creators import (
|
|||
get_ssl_context,
|
||||
)
|
||||
from sanic.response import text
|
||||
from sanic.worker.loader import CertLoader
|
||||
|
||||
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
|
@ -427,6 +428,29 @@ def test_no_certs_on_list(app):
|
|||
assert "No certificates" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_custom_cert_loader():
|
||||
class MyCertLoader(CertLoader):
|
||||
def load(self, app: Sanic):
|
||||
self._ssl_data = {
|
||||
"key": localhost_key,
|
||||
"cert": localhost_cert,
|
||||
}
|
||||
return super().load(app)
|
||||
|
||||
app = Sanic("custom", certloader_class=MyCertLoader)
|
||||
|
||||
@app.get("/test")
|
||||
async def handler(request):
|
||||
return text("ssl test")
|
||||
|
||||
client = SanicTestClient(app, port=44556)
|
||||
|
||||
request, response = client.get("https://localhost:44556/test")
|
||||
assert request.scheme == "https"
|
||||
assert response.status_code == 200
|
||||
assert response.text == "ssl test"
|
||||
|
||||
|
||||
def test_logger_vhosts(caplog):
|
||||
app = Sanic(name="test_logger_vhosts")
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user