379 lines
11 KiB
Python
379 lines
11 KiB
Python
|
import logging
|
||
|
import os
|
||
|
import ssl
|
||
|
import uuid
|
||
|
|
||
|
from contextlib import contextmanager
|
||
|
from urllib.parse import urlparse
|
||
|
|
||
|
import pytest
|
||
|
|
||
|
from sanic_testing.testing import HOST, PORT, SanicTestClient
|
||
|
|
||
|
from sanic import Sanic
|
||
|
from sanic.compat import OS_IS_WINDOWS
|
||
|
from sanic.log import logger
|
||
|
from sanic.response import text
|
||
|
|
||
|
|
||
|
# Absolute directory of this test module; every certificate fixture path
# below is resolved relative to it.
current_dir = os.path.dirname(os.path.realpath(__file__))

# Certificate fixture folders handed to Sanic's ``ssl`` argument. The tests
# below expect a usable folder to contain fullchain.pem and privkey.pem;
# "invalid.nonexist" is the deliberately unusable folder for negative tests.
localhost_dir, sanic_dir, invalid_dir = (
    os.path.join(current_dir, subdir)
    for subdir in (
        "certs/localhost",
        "certs/sanic.example",
        "certs/invalid.nonexist",
    )
)

# Certificate chain and private key inside each usable fixture folder.
localhost_cert, localhost_key = (
    os.path.join(localhost_dir, fname) for fname in ("fullchain.pem", "privkey.pem")
)
sanic_cert, sanic_key = (
    os.path.join(sanic_dir, fname) for fname in ("fullchain.pem", "privkey.pem")
)
|
||
|
|
||
|
|
||
|
@contextmanager
def replace_server_name(hostname):
    """Swap the TLS server name for ``hostname`` while the ``with`` block runs.

    ``ssl.SSLContext.wrap_bio`` is monkeypatched on the class so that any
    ``server_hostname`` a caller passes is discarded and ``hostname`` is used
    in its place. The original method is put back on exit, including when
    the body raises.
    """
    original_wrap_bio = ssl.SSLContext.wrap_bio

    def wrap_bio_with_fake_name(
        self,
        incoming,
        outgoing,
        server_side=False,
        server_hostname=None,
        session=None,
    ):
        # The caller's server_hostname is intentionally ignored.
        return original_wrap_bio(
            self, incoming, outgoing, server_side, hostname, session
        )

    ssl.SSLContext.wrap_bio = wrap_bio_with_fake_name
    try:
        yield
    finally:
        ssl.SSLContext.wrap_bio = original_wrap_bio
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "path,query,expected_url",
    [
        ("/foo", "", "https://{}:{}/foo"),
        ("/bar/baz", "", "https://{}:{}/bar/baz"),
        ("/moo/boo", "arg1=val1", "https://{}:{}/moo/boo?arg1=val1"),
    ],
)
def test_url_attributes_with_ssl_context(app, path, query, expected_url):
    """Check URL attributes of a request served over TLS from an SSLContext.

    The server receives a ready-made ``ssl.SSLContext`` loaded with the
    localhost certificate. ``request.url`` must equal ``expected_url``, and
    its parsed components must agree with the request's scheme, path, query
    string and host attributes.
    """
    # Purpose.CLIENT_AUTH builds a context meant for the server side.
    context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
    context.load_cert_chain(localhost_cert, localhost_key)

    async def handler(request):
        return text("OK")

    app.add_route(handler, path)

    # With an empty ``query`` the URL ends in a bare "?"; the expected URLs
    # for those cases carry no "?" at all.
    request, _ = app.test_client.get(
        f"https://{HOST}:{PORT}" + path + f"?{query}",
        server_kwargs={"ssl": context},
    )

    assert request.url == expected_url.format(HOST, request.server_port)

    parsed = urlparse(request.url)

    assert parsed.scheme == request.scheme
    assert parsed.path == request.path
    assert parsed.query == request.query_string
    assert parsed.netloc == request.host
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "path,query,expected_url",
    [
        ("/foo", "", "https://{}:{}/foo"),
        ("/bar/baz", "", "https://{}:{}/bar/baz"),
        ("/moo/boo", "arg1=val1", "https://{}:{}/moo/boo?arg1=val1"),
    ],
)
def test_url_attributes_with_ssl_dict(app, path, query, expected_url):
    """Check URL attributes of a request served over TLS from a cert/key dict.

    Instead of an SSLContext, Sanic is given the localhost certificate and
    key file paths as ``{"cert": ..., "key": ...}``. The resulting URL and
    its parsed components must match the request's own attributes.
    """

    async def handler(request):
        return text("OK")

    app.add_route(handler, path)

    tls_config = {"cert": localhost_cert, "key": localhost_key}
    target = f"https://{HOST}:{PORT}{path}?{query}"
    request, _ = app.test_client.get(target, server_kwargs={"ssl": tls_config})

    assert request.url == expected_url.format(HOST, request.server_port)

    parsed = urlparse(request.url)
    assert (parsed.scheme, parsed.path, parsed.query, parsed.netloc) == (
        request.scheme,
        request.path,
        request.query_string,
        request.host,
    )
|
||
|
|
||
|
|
||
|
def test_cert_sni_single(app):
    """A single certificate folder serves requests addressed to its hostname.

    Requesting ``https://localhost`` must report "localhost" both as the SNI
    server name and as the commonName of the certificate the server chose.
    """

    # Each route gets its own handler name. Reusing ``handler`` would shadow
    # the first function, and both routes would get the same default route
    # name, which Sanic derives from the function name.
    @app.get("/sni")
    async def sni_handler(request):
        return text(request.conn_info.server_name)

    @app.get("/commonname")
    async def commonname_handler(request):
        return text(request.conn_info.cert.get("commonName"))

    port = app.test_client.port
    _, response = app.test_client.get(
        f"https://localhost:{port}/sni",
        server_kwargs={"ssl": localhost_dir},
    )
    assert response.status == 200
    assert response.text == "localhost"

    _, response = app.test_client.get(
        f"https://localhost:{port}/commonname",
        server_kwargs={"ssl": localhost_dir},
    )
    assert response.status == 200
    assert response.text == "localhost"
|
||
|
|
||
|
|
||
|
def test_cert_sni_list(app):
    """Certificate selection by SNI from a list of certificate folders.

    * SNI "localhost" is served by the localhost certificate.
    * SNI "www.sanic.example" is served by the sanic.example certificate.
    * An SNI that matches no certificate is served by the first listed one
      (sanic.example in this list).
    """
    ssl_list = [sanic_dir, localhost_dir]

    # Each route gets its own handler name. Reusing ``handler`` would shadow
    # the first function, and both routes would get the same default route
    # name, which Sanic derives from the function name.
    @app.get("/sni")
    async def sni_handler(request):
        return text(request.conn_info.server_name)

    @app.get("/commonname")
    async def commonname_handler(request):
        return text(request.conn_info.cert.get("commonName"))

    port = app.test_client.port

    def fetch(host, route):
        """GET ``/route`` on ``host`` over TLS using ``ssl_list``."""
        return app.test_client.get(
            f"https://{host}:{port}/{route}",
            server_kwargs={"ssl": ssl_list},
        )

    # This test should match the localhost cert
    _, response = fetch("localhost", "sni")
    assert response.status == 200
    assert response.text == "localhost"

    _, response = fetch("localhost", "commonname")
    assert response.status == 200
    assert response.text == "localhost"

    # This part should use the sanic.example cert because it matches
    with replace_server_name("www.sanic.example"):
        _, response = fetch("127.0.0.1", "sni")
        assert response.status == 200
        assert response.text == "www.sanic.example"

        _, response = fetch("127.0.0.1", "commonname")
        assert response.status == 200
        assert response.text == "sanic.example"

    # This part should use the sanic.example cert, that being the first listed
    with replace_server_name("invalid.test"):
        _, response = fetch("127.0.0.1", "sni")
        assert response.status == 200
        assert response.text == "invalid.test"

        _, response = fetch("127.0.0.1", "commonname")
        assert response.status == 200
        assert response.text == "sanic.example"
|
||
|
|
||
|
|
||
|
def test_missing_sni(app):
    """Connecting by IP address with no default certificate makes the request fail.

    The sanic cert does not list 127.0.0.1 and httpx does not send IP as SNI
    anyway. With ``None`` in the first slot of the list, the test client is
    expected to get no response.
    """

    @app.get("/sni")
    async def handler(request):
        return text(request.conn_info.server_name)

    url = f"https://127.0.0.1:{app.test_client.port}/sni"
    with pytest.raises(Exception) as exc:
        app.test_client.get(url, server_kwargs={"ssl": [None, sanic_dir]})

    assert "Request and response object expected" in str(exc.value)
|
||
|
|
||
|
|
||
|
def test_no_matching_cert(app):
    """An SNI name that matches no certificate makes the request fail.

    The client is forced to send "invalid.test", which the sanic cert does
    not list. With ``None`` in the first slot of the list, the test client is
    expected to get no response. Compare test_cert_sni_list, where the first
    listed certificate serves unmatched names.
    """

    @app.get("/sni")
    async def handler(request):
        return text(request.conn_info.server_name)

    url = f"https://127.0.0.1:{app.test_client.port}/sni"
    with replace_server_name("invalid.test"):
        with pytest.raises(Exception) as exc:
            app.test_client.get(url, server_kwargs={"ssl": [None, sanic_dir]})

    assert "Request and response object expected" in str(exc.value)
|
||
|
|
||
|
|
||
|
def test_wildcards(app):
    """The ``*.sanic.test`` wildcard matches exactly one extra label.

    ``foo.sanic.test`` is served. The bare ``sanic.test`` and the two-level
    ``sub.foo.sanic.test`` match no certificate, so those requests fail.
    """
    ssl_list = [None, localhost_dir, sanic_dir]

    @app.get("/sni")
    async def handler(request):
        return text(request.conn_info.server_name)

    url = f"https://127.0.0.1:{app.test_client.port}/sni"

    with replace_server_name("foo.sanic.test"):
        _, response = app.test_client.get(url, server_kwargs={"ssl": ssl_list})
    assert response.status == 200
    assert response.text == "foo.sanic.test"

    for unmatched_name in ("sanic.test", "sub.foo.sanic.test"):
        with replace_server_name(unmatched_name):
            with pytest.raises(Exception) as exc:
                app.test_client.get(url, server_kwargs={"ssl": ssl_list})
        assert "Request and response object expected" in str(exc.value)
|
||
|
|
||
|
|
||
|
def test_invalid_ssl_dict(app):
    """An ssl dict with ``None`` for both cert and key raises ValueError."""

    @app.get("/test")
    async def handler(request):
        return text("ssl test")

    bad_config = {"cert": None, "key": None}
    with pytest.raises(ValueError) as excinfo:
        app.test_client.get("/test", server_kwargs={"ssl": bad_config})

    expected = "SSL dict needs filenames for cert and key."
    assert str(excinfo.value) == expected
|
||
|
|
||
|
|
||
|
def test_invalid_ssl_type(app):
    """Passing ``ssl=False`` is rejected with an "Invalid ssl argument" ValueError."""

    @app.get("/test")
    async def handler(request):
        return text("ssl test")

    with pytest.raises(ValueError) as excinfo:
        app.test_client.get("/test", server_kwargs={"ssl": False})

    message = str(excinfo.value)
    assert "Invalid ssl argument" in message
|
||
|
|
||
|
|
||
|
def test_cert_file_on_pathlist(app):
    """A certificate file, rather than a folder, in the ssl list is rejected.

    The ValueError message must say a folder was expected and name the
    offending path.
    """

    @app.get("/test")
    async def handler(request):
        return text("ssl test")

    with pytest.raises(ValueError) as excinfo:
        app.test_client.get("/test", server_kwargs={"ssl": [sanic_cert]})

    message = str(excinfo.value)
    assert "folder expected" in message
    assert sanic_cert in message
|
||
|
|
||
|
|
||
|
def test_missing_cert_path(app):
    """A certificate folder without its private key is reported as "not found".

    The error message must include the missing ``privkey.pem`` path.
    """

    @app.get("/test")
    async def handler(request):
        return text("ssl test")

    with pytest.raises(ValueError) as excinfo:
        app.test_client.get("/test", server_kwargs={"ssl": [invalid_dir]})

    message = str(excinfo.value)
    assert "not found" in message
    assert f"{invalid_dir}/privkey.pem" in message
|
||
|
|
||
|
|
||
|
def test_missing_cert_file(app):
    """A certificate folder without fullchain.pem is reported as "not found".

    The "certmissing" sibling of the invalid fixture folder is used, and the
    error message must include its missing ``fullchain.pem`` path.
    """

    @app.get("/test")
    async def handler(request):
        return text("ssl test")

    certless_dir = invalid_dir.replace("nonexist", "certmissing")

    with pytest.raises(ValueError) as excinfo:
        app.test_client.get("/test", server_kwargs={"ssl": [certless_dir]})

    message = str(excinfo.value)
    assert "not found" in message
    assert f"{certless_dir}/fullchain.pem" in message
|
||
|
|
||
|
|
||
|
def test_no_certs_on_list(app):
    """An ssl list holding only ``None`` raises a "No certificates" ValueError."""

    @app.get("/test")
    async def handler(request):
        return text("ssl test")

    with pytest.raises(ValueError) as excinfo:
        app.test_client.get("/test", server_kwargs={"ssl": [None]})

    message = str(excinfo.value)
    assert "No certificates" in message
|
||
|
|
||
|
|
||
|
def test_logger_vhosts(caplog):
    """Startup logs every hostname covered by the configured certificates.

    The app stops itself from ``after_server_start``, so ``app.run`` serves
    only the startup phase. The first captured log message starting with
    "Certificate" must list the localhost names followed by the
    sanic.example names.
    """
    app = Sanic(name=__name__)

    @app.after_server_start
    def stop(*args):
        app.stop()

    with caplog.at_level(logging.INFO):
        app.run(host="127.0.0.1", port=42102, ssl=[localhost_dir, sanic_dir])

    # Fail with a clear message, not an IndexError, if nothing was logged.
    logmsg = next(
        (
            message
            for _, _, message in caplog.record_tuples
            if message.startswith("Certificate")
        ),
        None,
    )
    assert logmsg is not None, "no log message starting with 'Certificate'"

    assert logmsg == (
        "Certificate vhosts: localhost, 127.0.0.1, 0:0:0:0:0:0:0:1, sanic.example, www.sanic.example, *.sanic.test, 2001:DB8:0:0:0:0:0:541C"
    )
|