Vhost support using multiple TLS certificates (#2270)
* Initial support for using multiple SSL certificates. * Also list IP address subjectAltNames on log. * Use Python 3.7+ way of specifying TLSv1.2 as the minimum version. Linter fixes. * isort * Cleanup, store server name for later use. Add RSA ciphers. Log rejected SNIs. * Cleanup, linter. * Alter the order of initial log messages and handling. In particular, enable debug mode early so that debug messages during init can be shown. * Store server name (SNI) to conn_info. * Update test with new error message. * Refactor for readability. * Cleanup * Replace old expired test cert with new ones and a script for regenerating them as needed. * Refactor TLS tests to a separate file. * Add cryptography to dev deps for rebuilding TLS certs. * Minor adjustment to messages. * Tests added for new TLS code. * Find the correct log row before testing for message. The order was different on CI. * More log message order fixup. The tests do not account for the logo being printed first. * Another attempt at log message indexing fixup. * Major TLS refactoring. CertSelector now allows dicts and SSLContext within its list. Server names are stored even when no list is used. SSLContext.sanic now contains a dict with any setting passed and information extracted from cert. That information is available on request.conn_info.cert. Type annotations added. More tests incl. a handler for faking hostname in tests. * Remove a problematic logger test that apparently was not adding any coverage or value to anything. * Revert accidental commit of uvloop disable. * Typing fixes / refactoring. * Additional test for cert selection. Certs recreated without DNS:localhost on sanic.example cert. * Add tests for single certificate path shorthand and SNI information. * Move TLS dict processing to CertSimple, make the names field optional and use names from the cert if absent. * Sanic CLI options --tls and --tls-strict-host to use the new features. * SSL argument typing updated * Use ValueError for internal message passing to avoid CertificateError's odd message formatting. * Linter * Test CLI TLS options. * Maybe the right codeclimate option now... * Improved TLS argument help, removed support for combining --cert/--key with --tls. * Removed support for strict checking without any certs, black forced fscked up formatting. * Update CLI tests for stricter TLS options. Co-authored-by: L. Karkkainen <tronic@users.noreply.github.com> Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
This commit is contained in:
@@ -4,7 +4,7 @@ import sys
|
||||
from argparse import ArgumentParser, RawTextHelpFormatter
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Union
|
||||
|
||||
from sanic_routing import __version__ as __routing_version__ # type: ignore
|
||||
|
||||
@@ -79,10 +79,30 @@ def main():
|
||||
help="location of unix socket\n ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cert", dest="cert", type=str, help="Location of certificate for SSL"
|
||||
"--cert",
|
||||
dest="cert",
|
||||
type=str,
|
||||
help="Location of fullchain.pem, bundle.crt or equivalent",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--key", dest="key", type=str, help="location of keyfile for SSL\n "
|
||||
"--key",
|
||||
dest="key",
|
||||
type=str,
|
||||
help="Location of privkey.pem or equivalent .key file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tls",
|
||||
metavar="DIR",
|
||||
type=str,
|
||||
action="append",
|
||||
help="TLS certificate folder with fullchain.pem and privkey.pem\n"
|
||||
"May be specified multiple times to choose of multiple certificates",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tls-strict-host",
|
||||
dest="tlshost",
|
||||
action="store_true",
|
||||
help="Only allow clients that send an SNI matching server certs\n ",
|
||||
)
|
||||
parser.add_bool_arguments(
|
||||
"--access-logs", dest="access_log", help="display access logs"
|
||||
@@ -126,6 +146,26 @@ def main():
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Custom TLS mismatch handling for better diagnostics
|
||||
if (
|
||||
# one of cert/key missing
|
||||
bool(args.cert) != bool(args.key)
|
||||
# new and old style args used together
|
||||
or args.tls
|
||||
and args.cert
|
||||
# strict host checking without certs would always fail
|
||||
or args.tlshost
|
||||
and not args.tls
|
||||
and not args.cert
|
||||
):
|
||||
parser.print_usage(sys.stderr)
|
||||
error_logger.error(
|
||||
"sanic: error: TLS certificates must be specified by either of:\n"
|
||||
" --cert certdir/fullchain.pem --key certdir/privkey.pem\n"
|
||||
" --tls certdir (equivalent to the above)"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
module_path = os.path.abspath(os.getcwd())
|
||||
if module_path not in sys.path:
|
||||
@@ -155,14 +195,18 @@ def main():
|
||||
f"Perhaps you meant {args.module}.app?"
|
||||
)
|
||||
|
||||
ssl: Union[None, dict, str, list] = []
|
||||
if args.tlshost:
|
||||
ssl.append(None)
|
||||
if args.cert is not None or args.key is not None:
|
||||
ssl: Optional[Dict[str, Any]] = {
|
||||
"cert": args.cert,
|
||||
"key": args.key,
|
||||
}
|
||||
else:
|
||||
ssl.append(dict(cert=args.cert, key=args.key))
|
||||
if args.tls:
|
||||
ssl += args.tls
|
||||
if not ssl:
|
||||
ssl = None
|
||||
|
||||
elif len(ssl) == 1 and ssl[0] is not None:
|
||||
# Use only one cert, no TLSSelector.
|
||||
ssl = ssl[0]
|
||||
kwargs = {
|
||||
"host": args.host,
|
||||
"port": args.port,
|
||||
|
||||
78
sanic/app.py
78
sanic/app.py
@@ -19,7 +19,7 @@ from functools import partial
|
||||
from inspect import isawaitable
|
||||
from pathlib import Path
|
||||
from socket import socket
|
||||
from ssl import Purpose, SSLContext, create_default_context
|
||||
from ssl import SSLContext
|
||||
from traceback import format_exc
|
||||
from types import SimpleNamespace
|
||||
from typing import (
|
||||
@@ -78,6 +78,7 @@ from sanic.server import serve, serve_multiple, serve_single
|
||||
from sanic.server.protocols.websocket_protocol import WebSocketProtocol
|
||||
from sanic.server.websockets.impl import ConnectionClosed
|
||||
from sanic.signals import Signal, SignalRouter
|
||||
from sanic.tls import process_to_context
|
||||
from sanic.touchup import TouchUp, TouchUpMeta
|
||||
|
||||
|
||||
@@ -952,7 +953,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
|
||||
*,
|
||||
debug: bool = False,
|
||||
auto_reload: Optional[bool] = None,
|
||||
ssl: Union[Dict[str, str], SSLContext, None] = None,
|
||||
ssl: Union[None, SSLContext, dict, str, list, tuple] = None,
|
||||
sock: Optional[socket] = None,
|
||||
workers: int = 1,
|
||||
protocol: Optional[Type[Protocol]] = None,
|
||||
@@ -979,7 +980,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
|
||||
:type auto_relaod: bool
|
||||
:param ssl: SSLContext, or location of certificate and key
|
||||
for SSL encryption of worker(s)
|
||||
:type ssl: SSLContext or dict
|
||||
:type ssl: str, dict, SSLContext or list
|
||||
:param sock: Socket for the server to accept connections from
|
||||
:type sock: socket
|
||||
:param workers: Number of processes received before it is respected
|
||||
@@ -1089,7 +1090,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
|
||||
port: Optional[int] = None,
|
||||
*,
|
||||
debug: bool = False,
|
||||
ssl: Union[Dict[str, str], SSLContext, None] = None,
|
||||
ssl: Union[None, SSLContext, dict, str, list, tuple] = None,
|
||||
sock: Optional[socket] = None,
|
||||
protocol: Type[Protocol] = None,
|
||||
backlog: int = 100,
|
||||
@@ -1281,16 +1282,6 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
|
||||
auto_reload=False,
|
||||
):
|
||||
"""Helper function used by `run` and `create_server`."""
|
||||
|
||||
if isinstance(ssl, dict):
|
||||
# try common aliaseses
|
||||
cert = ssl.get("cert") or ssl.get("certificate")
|
||||
key = ssl.get("key") or ssl.get("keyfile")
|
||||
if cert is None or key is None:
|
||||
raise ValueError("SSLContext or certificate and key required.")
|
||||
context = create_default_context(purpose=Purpose.CLIENT_AUTH)
|
||||
context.load_cert_chain(cert, keyfile=key)
|
||||
ssl = context
|
||||
if self.config.PROXIES_COUNT and self.config.PROXIES_COUNT < 0:
|
||||
raise ValueError(
|
||||
"PROXIES_COUNT cannot be negative. "
|
||||
@@ -1300,6 +1291,35 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
|
||||
|
||||
self.error_handler.debug = debug
|
||||
self.debug = debug
|
||||
if self.configure_logging and debug:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
if (
|
||||
self.config.LOGO
|
||||
and os.environ.get("SANIC_SERVER_RUNNING") != "true"
|
||||
):
|
||||
logger.debug(
|
||||
self.config.LOGO
|
||||
if isinstance(self.config.LOGO, str)
|
||||
else BASE_LOGO
|
||||
)
|
||||
# Serve
|
||||
if host and port:
|
||||
proto = "http"
|
||||
if ssl is not None:
|
||||
proto = "https"
|
||||
if unix:
|
||||
logger.info(f"Goin' Fast @ {unix} {proto}://...")
|
||||
else:
|
||||
# colon(:) is legal for a host only in an ipv6 address
|
||||
display_host = f"[{host}]" if ":" in host else host
|
||||
logger.info(f"Goin' Fast @ {proto}://{display_host}:{port}")
|
||||
|
||||
debug_mode = "enabled" if self.debug else "disabled"
|
||||
reload_mode = "enabled" if auto_reload else "disabled"
|
||||
logger.debug(f"Sanic auto-reload: {reload_mode}")
|
||||
logger.debug(f"Sanic debug mode: {debug_mode}")
|
||||
|
||||
ssl = process_to_context(ssl)
|
||||
|
||||
server_settings = {
|
||||
"protocol": protocol,
|
||||
@@ -1328,39 +1348,9 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
|
||||
listeners = [partial(listener, self) for listener in listeners]
|
||||
server_settings[settings_name] = listeners
|
||||
|
||||
if self.configure_logging and debug:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
if (
|
||||
self.config.LOGO
|
||||
and os.environ.get("SANIC_SERVER_RUNNING") != "true"
|
||||
):
|
||||
logger.debug(
|
||||
self.config.LOGO
|
||||
if isinstance(self.config.LOGO, str)
|
||||
else BASE_LOGO
|
||||
)
|
||||
|
||||
if run_async:
|
||||
server_settings["run_async"] = True
|
||||
|
||||
# Serve
|
||||
if host and port:
|
||||
proto = "http"
|
||||
if ssl is not None:
|
||||
proto = "https"
|
||||
if unix:
|
||||
logger.info(f"Goin' Fast @ {unix} {proto}://...")
|
||||
else:
|
||||
# colon(:) is legal for a host only in an ipv6 address
|
||||
display_host = f"[{host}]" if ":" in host else host
|
||||
logger.info(f"Goin' Fast @ {proto}://{display_host}:{port}")
|
||||
|
||||
debug_mode = "enabled" if self.debug else "disabled"
|
||||
reload_mode = "enabled" if auto_reload else "disabled"
|
||||
logger.debug(f"Sanic auto-reload: {reload_mode}")
|
||||
logger.debug(f"Sanic debug mode: {debug_mode}")
|
||||
|
||||
return server_settings
|
||||
|
||||
def _build_endpoint_name(self, *parts):
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from ssl import SSLObject
|
||||
from types import SimpleNamespace
|
||||
from typing import Optional
|
||||
|
||||
from sanic.models.protocol_types import TransportProtocol
|
||||
|
||||
@@ -20,8 +22,10 @@ class ConnInfo:
|
||||
"peername",
|
||||
"server_port",
|
||||
"server",
|
||||
"server_name",
|
||||
"sockname",
|
||||
"ssl",
|
||||
"cert",
|
||||
)
|
||||
|
||||
def __init__(self, transport: TransportProtocol, unix=None):
|
||||
@@ -31,8 +35,16 @@ class ConnInfo:
|
||||
self.server_port = self.client_port = 0
|
||||
self.client_ip = ""
|
||||
self.sockname = addr = transport.get_extra_info("sockname")
|
||||
self.ssl: bool = bool(transport.get_extra_info("sslcontext"))
|
||||
|
||||
self.ssl = False
|
||||
self.server_name = ""
|
||||
self.cert = {}
|
||||
sslobj: Optional[SSLObject] = transport.get_extra_info(
|
||||
"ssl_object"
|
||||
) # type: ignore
|
||||
if sslobj:
|
||||
self.ssl = True
|
||||
self.server_name = getattr(sslobj, "sanic_server_name", None) or ""
|
||||
self.cert = getattr(sslobj.context, "sanic", {})
|
||||
if isinstance(addr, str): # UNIX socket
|
||||
self.server = unix or addr
|
||||
return
|
||||
|
||||
196
sanic/tls.py
Normal file
196
sanic/tls.py
Normal file
@@ -0,0 +1,196 @@
|
||||
import os
|
||||
import ssl
|
||||
|
||||
from typing import Iterable, Optional, Union
|
||||
|
||||
from sanic.log import logger
|
||||
|
||||
|
||||
# Only allow secure ciphers, notably leaving out AES-CBC mode
|
||||
# OpenSSL chooses ECDSA or RSA depending on the cert in use
|
||||
CIPHERS_TLS12 = [
|
||||
"ECDHE-ECDSA-CHACHA20-POLY1305",
|
||||
"ECDHE-ECDSA-AES256-GCM-SHA384",
|
||||
"ECDHE-ECDSA-AES128-GCM-SHA256",
|
||||
"ECDHE-RSA-CHACHA20-POLY1305",
|
||||
"ECDHE-RSA-AES256-GCM-SHA384",
|
||||
"ECDHE-RSA-AES128-GCM-SHA256",
|
||||
]
|
||||
|
||||
|
||||
def create_context(
|
||||
certfile: Optional[str] = None,
|
||||
keyfile: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
) -> ssl.SSLContext:
|
||||
"""Create a context with secure crypto and HTTP/1.1 in protocols."""
|
||||
context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
|
||||
context.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
context.set_ciphers(":".join(CIPHERS_TLS12))
|
||||
context.set_alpn_protocols(["http/1.1"])
|
||||
context.sni_callback = server_name_callback
|
||||
if certfile and keyfile:
|
||||
context.load_cert_chain(certfile, keyfile, password)
|
||||
return context
|
||||
|
||||
|
||||
def shorthand_to_ctx(
|
||||
ctxdef: Union[None, ssl.SSLContext, dict, str]
|
||||
) -> Optional[ssl.SSLContext]:
|
||||
"""Convert an ssl argument shorthand to an SSLContext object."""
|
||||
if ctxdef is None or isinstance(ctxdef, ssl.SSLContext):
|
||||
return ctxdef
|
||||
if isinstance(ctxdef, str):
|
||||
return load_cert_dir(ctxdef)
|
||||
if isinstance(ctxdef, dict):
|
||||
return CertSimple(**ctxdef)
|
||||
raise ValueError(
|
||||
f"Invalid ssl argument {type(ctxdef)}."
|
||||
" Expecting a list of certdirs, a dict or an SSLContext."
|
||||
)
|
||||
|
||||
|
||||
def process_to_context(
|
||||
ssldef: Union[None, ssl.SSLContext, dict, str, list, tuple]
|
||||
) -> Optional[ssl.SSLContext]:
|
||||
"""Process app.run ssl argument from easy formats to full SSLContext."""
|
||||
return (
|
||||
CertSelector(map(shorthand_to_ctx, ssldef))
|
||||
if isinstance(ssldef, (list, tuple))
|
||||
else shorthand_to_ctx(ssldef)
|
||||
)
|
||||
|
||||
|
||||
def load_cert_dir(p: str) -> ssl.SSLContext:
|
||||
if os.path.isfile(p):
|
||||
raise ValueError(f"Certificate folder expected but {p} is a file.")
|
||||
keyfile = os.path.join(p, "privkey.pem")
|
||||
certfile = os.path.join(p, "fullchain.pem")
|
||||
if not os.access(keyfile, os.R_OK):
|
||||
raise ValueError(
|
||||
f"Certificate not found or permission denied {keyfile}"
|
||||
)
|
||||
if not os.access(certfile, os.R_OK):
|
||||
raise ValueError(
|
||||
f"Certificate not found or permission denied {certfile}"
|
||||
)
|
||||
return CertSimple(certfile, keyfile)
|
||||
|
||||
|
||||
class CertSimple(ssl.SSLContext):
|
||||
"""A wrapper for creating SSLContext with a sanic attribute."""
|
||||
|
||||
def __new__(cls, cert, key, **kw):
|
||||
# try common aliases, rename to cert/key
|
||||
certfile = kw["cert"] = kw.pop("certificate", None) or cert
|
||||
keyfile = kw["key"] = kw.pop("keyfile", None) or key
|
||||
password = kw.pop("password", None)
|
||||
if not certfile or not keyfile:
|
||||
raise ValueError("SSL dict needs filenames for cert and key.")
|
||||
subject = {}
|
||||
if "names" not in kw:
|
||||
cert = ssl._ssl._test_decode_cert(certfile) # type: ignore
|
||||
kw["names"] = [
|
||||
name
|
||||
for t, name in cert["subjectAltName"]
|
||||
if t in ["DNS", "IP Address"]
|
||||
]
|
||||
subject = {k: v for item in cert["subject"] for k, v in item}
|
||||
self = create_context(certfile, keyfile, password)
|
||||
self.__class__ = cls
|
||||
self.sanic = {**subject, **kw}
|
||||
return self
|
||||
|
||||
def __init__(self, cert, key, **kw):
|
||||
pass # Do not call super().__init__ because it is already initialized
|
||||
|
||||
|
||||
class CertSelector(ssl.SSLContext):
|
||||
"""Automatically select SSL certificate based on the hostname that the
|
||||
client is trying to access, via SSL SNI. Paths to certificate folders
|
||||
with privkey.pem and fullchain.pem in them should be provided, and
|
||||
will be matched in the order given whenever there is a new connection.
|
||||
"""
|
||||
|
||||
def __new__(cls, ctxs):
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, ctxs: Iterable[Optional[ssl.SSLContext]]):
|
||||
super().__init__()
|
||||
self.sni_callback = selector_sni_callback # type: ignore
|
||||
self.sanic_select = []
|
||||
self.sanic_fallback = None
|
||||
all_names = []
|
||||
for i, ctx in enumerate(ctxs):
|
||||
if not ctx:
|
||||
continue
|
||||
names = getattr(ctx, "sanic", {}).get("names", [])
|
||||
all_names += names
|
||||
self.sanic_select.append(ctx)
|
||||
if i == 0:
|
||||
self.sanic_fallback = ctx
|
||||
if not all_names:
|
||||
raise ValueError(
|
||||
"No certificates with SubjectAlternativeNames found."
|
||||
)
|
||||
logger.info(f"Certificate vhosts: {', '.join(all_names)}")
|
||||
|
||||
|
||||
def find_cert(self: CertSelector, server_name: str):
|
||||
"""Find the first certificate that matches the given SNI.
|
||||
|
||||
:raises ssl.CertificateError: No matching certificate found.
|
||||
:return: A matching ssl.SSLContext object if found."""
|
||||
if not server_name:
|
||||
if self.sanic_fallback:
|
||||
return self.sanic_fallback
|
||||
raise ValueError(
|
||||
"The client provided no SNI to match for certificate."
|
||||
)
|
||||
for ctx in self.sanic_select:
|
||||
if match_hostname(ctx, server_name):
|
||||
return ctx
|
||||
if self.sanic_fallback:
|
||||
return self.sanic_fallback
|
||||
raise ValueError(f"No certificate found matching hostname {server_name!r}")
|
||||
|
||||
|
||||
def match_hostname(
|
||||
ctx: Union[ssl.SSLContext, CertSelector], hostname: str
|
||||
) -> bool:
|
||||
"""Match names from CertSelector against a received hostname."""
|
||||
# Local certs are considered trusted, so this can be less pedantic
|
||||
# and thus faster than the deprecated ssl.match_hostname function is.
|
||||
names = getattr(ctx, "sanic", {}).get("names", [])
|
||||
hostname = hostname.lower()
|
||||
for name in names:
|
||||
if name.startswith("*."):
|
||||
if hostname.split(".", 1)[-1] == name[2:]:
|
||||
return True
|
||||
elif name == hostname:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def selector_sni_callback(
|
||||
sslobj: ssl.SSLObject, server_name: str, ctx: CertSelector
|
||||
) -> Optional[int]:
|
||||
"""Select a certificate mathing the SNI."""
|
||||
# Call server_name_callback to store the SNI on sslobj
|
||||
server_name_callback(sslobj, server_name, ctx)
|
||||
# Find a new context matching the hostname
|
||||
try:
|
||||
sslobj.context = find_cert(ctx, server_name)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Rejecting TLS connection: {e}")
|
||||
# This would show ERR_SSL_UNRECOGNIZED_NAME_ALERT on client side if
|
||||
# asyncio/uvloop did proper SSL shutdown. They don't.
|
||||
return ssl.ALERT_DESCRIPTION_UNRECOGNIZED_NAME
|
||||
return None # mypy complains without explicit return
|
||||
|
||||
|
||||
def server_name_callback(
|
||||
sslobj: ssl.SSLObject, server_name: str, ctx: ssl.SSLContext
|
||||
) -> None:
|
||||
"""Store the received SNI as sslobj.sanic_server_name."""
|
||||
sslobj.sanic_server_name = server_name # type: ignore
|
||||
Reference in New Issue
Block a user