Compare commits

...

1 Commits

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import os import os
import ssl import ssl
from pathlib import Path, PurePath
from typing import Any, Dict, Iterable, Optional, Union from typing import Any, Dict, Iterable, Optional, Union
from sanic.log import logger from sanic.log import logger
@ -39,23 +40,23 @@ def create_context(
def shorthand_to_ctx( def shorthand_to_ctx(
ctxdef: Union[None, ssl.SSLContext, dict, str] ctxdef: Union[None, ssl.SSLContext, dict, PurePath, str]
) -> Optional[ssl.SSLContext]: ) -> Optional[ssl.SSLContext]:
"""Convert an ssl argument shorthand to an SSLContext object.""" """Convert an ssl argument shorthand to an SSLContext object."""
if ctxdef is None or isinstance(ctxdef, ssl.SSLContext): if ctxdef is None or isinstance(ctxdef, ssl.SSLContext):
return ctxdef return ctxdef
if isinstance(ctxdef, str): if isinstance(ctxdef, (PurePath, str)):
return load_cert_dir(ctxdef) return load_cert_dir(Path(ctxdef))
if isinstance(ctxdef, dict): if isinstance(ctxdef, dict):
return CertSimple(**ctxdef) return CertSimple(**ctxdef)
raise ValueError( raise ValueError(
f"Invalid ssl argument {type(ctxdef)}." f"Invalid ssl argument {type(ctxdef)}."
" Expecting a list of certdirs, a dict or an SSLContext." " Expecting one/list of: certdir | dict | SSLContext"
) )
def process_to_context( def process_to_context(
ssldef: Union[None, ssl.SSLContext, dict, str, list, tuple] ssldef: Union[None, ssl.SSLContext, dict, PurePath, str, list, tuple]
) -> Optional[ssl.SSLContext]: ) -> Optional[ssl.SSLContext]:
"""Process app.run ssl argument from easy formats to full SSLContext.""" """Process app.run ssl argument from easy formats to full SSLContext."""
return ( return (
@ -65,11 +66,11 @@ def process_to_context(
) )
def load_cert_dir(p: str) -> ssl.SSLContext: def load_cert_dir(p: Path) -> ssl.SSLContext:
if os.path.isfile(p): if p.is_file():
raise ValueError(f"Certificate folder expected but {p} is a file.") raise ValueError(f"Certificate folder expected but {p} is a file.")
keyfile = os.path.join(p, "privkey.pem") keyfile = p / "privkey.pem"
certfile = os.path.join(p, "fullchain.pem") certfile = p / "fullchain.pem"
if not os.access(keyfile, os.R_OK): if not os.access(keyfile, os.R_OK):
raise ValueError( raise ValueError(
f"Certificate not found or permission denied {keyfile}" f"Certificate not found or permission denied {keyfile}"