sanic/sanic/tls.py
Kian Meng Ang ab35121864
Fix typos (#2329)
* Fix typos

* Downgrade mistune version

* Fix blueprint host param

Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
Co-authored-by: Adam Hopkins <adam@amhopkins.com>
2021-12-06 09:17:01 +02:00

197 lines
6.9 KiB
Python

import os
import ssl
from typing import Iterable, Optional, Union
from sanic.log import logger
# Only allow secure ciphers, notably leaving out AES-CBC mode
# OpenSSL chooses ECDSA or RSA depending on the cert in use
CIPHERS_TLS12 = [
"ECDHE-ECDSA-CHACHA20-POLY1305",
"ECDHE-ECDSA-AES256-GCM-SHA384",
"ECDHE-ECDSA-AES128-GCM-SHA256",
"ECDHE-RSA-CHACHA20-POLY1305",
"ECDHE-RSA-AES256-GCM-SHA384",
"ECDHE-RSA-AES128-GCM-SHA256",
]
def create_context(
certfile: Optional[str] = None,
keyfile: Optional[str] = None,
password: Optional[str] = None,
) -> ssl.SSLContext:
"""Create a context with secure crypto and HTTP/1.1 in protocols."""
context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
context.minimum_version = ssl.TLSVersion.TLSv1_2
context.set_ciphers(":".join(CIPHERS_TLS12))
context.set_alpn_protocols(["http/1.1"])
context.sni_callback = server_name_callback
if certfile and keyfile:
context.load_cert_chain(certfile, keyfile, password)
return context
def shorthand_to_ctx(
ctxdef: Union[None, ssl.SSLContext, dict, str]
) -> Optional[ssl.SSLContext]:
"""Convert an ssl argument shorthand to an SSLContext object."""
if ctxdef is None or isinstance(ctxdef, ssl.SSLContext):
return ctxdef
if isinstance(ctxdef, str):
return load_cert_dir(ctxdef)
if isinstance(ctxdef, dict):
return CertSimple(**ctxdef)
raise ValueError(
f"Invalid ssl argument {type(ctxdef)}."
" Expecting a list of certdirs, a dict or an SSLContext."
)
def process_to_context(
ssldef: Union[None, ssl.SSLContext, dict, str, list, tuple]
) -> Optional[ssl.SSLContext]:
"""Process app.run ssl argument from easy formats to full SSLContext."""
return (
CertSelector(map(shorthand_to_ctx, ssldef))
if isinstance(ssldef, (list, tuple))
else shorthand_to_ctx(ssldef)
)
def load_cert_dir(p: str) -> ssl.SSLContext:
if os.path.isfile(p):
raise ValueError(f"Certificate folder expected but {p} is a file.")
keyfile = os.path.join(p, "privkey.pem")
certfile = os.path.join(p, "fullchain.pem")
if not os.access(keyfile, os.R_OK):
raise ValueError(
f"Certificate not found or permission denied {keyfile}"
)
if not os.access(certfile, os.R_OK):
raise ValueError(
f"Certificate not found or permission denied {certfile}"
)
return CertSimple(certfile, keyfile)
class CertSimple(ssl.SSLContext):
"""A wrapper for creating SSLContext with a sanic attribute."""
def __new__(cls, cert, key, **kw):
# try common aliases, rename to cert/key
certfile = kw["cert"] = kw.pop("certificate", None) or cert
keyfile = kw["key"] = kw.pop("keyfile", None) or key
password = kw.pop("password", None)
if not certfile or not keyfile:
raise ValueError("SSL dict needs filenames for cert and key.")
subject = {}
if "names" not in kw:
cert = ssl._ssl._test_decode_cert(certfile) # type: ignore
kw["names"] = [
name
for t, name in cert["subjectAltName"]
if t in ["DNS", "IP Address"]
]
subject = {k: v for item in cert["subject"] for k, v in item}
self = create_context(certfile, keyfile, password)
self.__class__ = cls
self.sanic = {**subject, **kw}
return self
def __init__(self, cert, key, **kw):
pass # Do not call super().__init__ because it is already initialized
class CertSelector(ssl.SSLContext):
"""Automatically select SSL certificate based on the hostname that the
client is trying to access, via SSL SNI. Paths to certificate folders
with privkey.pem and fullchain.pem in them should be provided, and
will be matched in the order given whenever there is a new connection.
"""
def __new__(cls, ctxs):
return super().__new__(cls)
def __init__(self, ctxs: Iterable[Optional[ssl.SSLContext]]):
super().__init__()
self.sni_callback = selector_sni_callback # type: ignore
self.sanic_select = []
self.sanic_fallback = None
all_names = []
for i, ctx in enumerate(ctxs):
if not ctx:
continue
names = dict(getattr(ctx, "sanic", {})).get("names", [])
all_names += names
self.sanic_select.append(ctx)
if i == 0:
self.sanic_fallback = ctx
if not all_names:
raise ValueError(
"No certificates with SubjectAlternativeNames found."
)
logger.info(f"Certificate vhosts: {', '.join(all_names)}")
def find_cert(self: CertSelector, server_name: str):
"""Find the first certificate that matches the given SNI.
:raises ssl.CertificateError: No matching certificate found.
:return: A matching ssl.SSLContext object if found."""
if not server_name:
if self.sanic_fallback:
return self.sanic_fallback
raise ValueError(
"The client provided no SNI to match for certificate."
)
for ctx in self.sanic_select:
if match_hostname(ctx, server_name):
return ctx
if self.sanic_fallback:
return self.sanic_fallback
raise ValueError(f"No certificate found matching hostname {server_name!r}")
def match_hostname(
ctx: Union[ssl.SSLContext, CertSelector], hostname: str
) -> bool:
"""Match names from CertSelector against a received hostname."""
# Local certs are considered trusted, so this can be less pedantic
# and thus faster than the deprecated ssl.match_hostname function is.
names = dict(getattr(ctx, "sanic", {})).get("names", [])
hostname = hostname.lower()
for name in names:
if name.startswith("*."):
if hostname.split(".", 1)[-1] == name[2:]:
return True
elif name == hostname:
return True
return False
def selector_sni_callback(
sslobj: ssl.SSLObject, server_name: str, ctx: CertSelector
) -> Optional[int]:
"""Select a certificate matching the SNI."""
# Call server_name_callback to store the SNI on sslobj
server_name_callback(sslobj, server_name, ctx)
# Find a new context matching the hostname
try:
sslobj.context = find_cert(ctx, server_name)
except ValueError as e:
logger.warning(f"Rejecting TLS connection: {e}")
# This would show ERR_SSL_UNRECOGNIZED_NAME_ALERT on client side if
# asyncio/uvloop did proper SSL shutdown. They don't.
return ssl.ALERT_DESCRIPTION_UNRECOGNIZED_NAME
return None # mypy complains without explicit return
def server_name_callback(
sslobj: ssl.SSLObject, server_name: str, ctx: ssl.SSLContext
) -> None:
"""Store the received SNI as sslobj.sanic_server_name."""
sslobj.sanic_server_name = server_name # type: ignore