Add GenericCreator for loading SSL certs in processes (#2578)
This commit is contained in:
		| @@ -5,18 +5,10 @@ import sys | |||||||
|  |  | ||||||
| from importlib import import_module | from importlib import import_module | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from typing import ( | from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, cast | ||||||
|     TYPE_CHECKING, |  | ||||||
|     Any, |  | ||||||
|     Callable, |  | ||||||
|     Dict, |  | ||||||
|     Optional, |  | ||||||
|     Type, |  | ||||||
|     Union, |  | ||||||
|     cast, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| from sanic.http.tls.creators import CertCreator, MkcertCreator, TrustmeCreator | from sanic.http.tls.context import process_to_context | ||||||
|  | from sanic.http.tls.creators import MkcertCreator, TrustmeCreator | ||||||
|  |  | ||||||
|  |  | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
| @@ -106,21 +98,30 @@ class AppLoader: | |||||||
|  |  | ||||||
|  |  | ||||||
| class CertLoader: | class CertLoader: | ||||||
|     _creator_class: Type[CertCreator] |     _creators = { | ||||||
|  |         "mkcert": MkcertCreator, | ||||||
|  |         "trustme": TrustmeCreator, | ||||||
|  |     } | ||||||
|  |  | ||||||
|     def __init__(self, ssl_data: Dict[str, Union[str, os.PathLike]]): |     def __init__(self, ssl_data: Dict[str, Union[str, os.PathLike]]): | ||||||
|         creator_name = ssl_data.get("creator") |         self._ssl_data = ssl_data | ||||||
|         if creator_name not in ("mkcert", "trustme"): |  | ||||||
|  |         creator_name = cast(str, ssl_data.get("creator")) | ||||||
|  |  | ||||||
|  |         self._creator_class = self._creators.get(creator_name) | ||||||
|  |         if not creator_name: | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         if not self._creator_class: | ||||||
|             raise RuntimeError(f"Unknown certificate creator: {creator_name}") |             raise RuntimeError(f"Unknown certificate creator: {creator_name}") | ||||||
|         elif creator_name == "mkcert": |  | ||||||
|             self._creator_class = MkcertCreator |  | ||||||
|         elif creator_name == "trustme": |  | ||||||
|             self._creator_class = TrustmeCreator |  | ||||||
|  |  | ||||||
|         self._key = ssl_data["key"] |         self._key = ssl_data["key"] | ||||||
|         self._cert = ssl_data["cert"] |         self._cert = ssl_data["cert"] | ||||||
|         self._localhost = cast(str, ssl_data["localhost"]) |         self._localhost = cast(str, ssl_data["localhost"]) | ||||||
|  |  | ||||||
|     def load(self, app: SanicApp): |     def load(self, app: SanicApp): | ||||||
|  |         if not self._creator_class: | ||||||
|  |             return process_to_context(self._ssl_data) | ||||||
|  |  | ||||||
|         creator = self._creator_class(app, self._key, self._cert) |         creator = self._creator_class(app, self._key, self._cert) | ||||||
|         return creator.generate_cert(self._localhost) |         return creator.generate_cert(self._localhost) | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ import ssl | |||||||
| import subprocess | import subprocess | ||||||
|  |  | ||||||
| from contextlib import contextmanager | from contextlib import contextmanager | ||||||
|  | from multiprocessing import Event | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from unittest.mock import Mock, patch | from unittest.mock import Mock, patch | ||||||
| from urllib.parse import urlparse | from urllib.parse import urlparse | ||||||
| @@ -636,3 +637,29 @@ def test_sanic_ssl_context_create(): | |||||||
|  |  | ||||||
|     assert sanic_context is context |     assert sanic_context is context | ||||||
|     assert isinstance(sanic_context, SanicSSLContext) |     assert isinstance(sanic_context, SanicSSLContext) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_ssl_in_multiprocess_mode(app: Sanic, caplog): | ||||||
|  |  | ||||||
|  |     ssl_dict = {"cert": localhost_cert, "key": localhost_key} | ||||||
|  |     event = Event() | ||||||
|  |  | ||||||
|  |     @app.main_process_start | ||||||
|  |     async def main_start(app: Sanic): | ||||||
|  |         app.shared_ctx.event = event | ||||||
|  |  | ||||||
|  |     @app.after_server_start | ||||||
|  |     async def shutdown(app): | ||||||
|  |         app.shared_ctx.event.set() | ||||||
|  |         app.stop() | ||||||
|  |  | ||||||
|  |     assert not event.is_set() | ||||||
|  |     with caplog.at_level(logging.INFO): | ||||||
|  |         app.run(ssl=ssl_dict) | ||||||
|  |     assert event.is_set() | ||||||
|  |  | ||||||
|  |     assert ( | ||||||
|  |         "sanic.root", | ||||||
|  |         logging.INFO, | ||||||
|  |         "Goin' Fast @ https://127.0.0.1:8000", | ||||||
|  |     ) in caplog.record_tuples | ||||||
|   | |||||||
| @@ -86,6 +86,10 @@ def test_input_is_module(): | |||||||
| @patch("sanic.worker.loader.TrustmeCreator") | @patch("sanic.worker.loader.TrustmeCreator") | ||||||
| @patch("sanic.worker.loader.MkcertCreator") | @patch("sanic.worker.loader.MkcertCreator") | ||||||
| def test_cert_loader(MkcertCreator: Mock, TrustmeCreator: Mock, creator: str): | def test_cert_loader(MkcertCreator: Mock, TrustmeCreator: Mock, creator: str): | ||||||
|  |     CertLoader._creators = { | ||||||
|  |         "mkcert": MkcertCreator, | ||||||
|  |         "trustme": TrustmeCreator, | ||||||
|  |     } | ||||||
|     MkcertCreator.return_value = MkcertCreator |     MkcertCreator.return_value = MkcertCreator | ||||||
|     TrustmeCreator.return_value = TrustmeCreator |     TrustmeCreator.return_value = TrustmeCreator | ||||||
|     data = { |     data = { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Adam Hopkins
					Adam Hopkins