import logging
import os
import ssl
import subprocess
import sys

from contextlib import contextmanager
from multiprocessing import Event
from pathlib import Path
from unittest.mock import Mock, patch
from urllib.parse import urlparse

import pytest

from sanic_testing.testing import HOST, PORT

import sanic.http.tls.creators

from sanic import Sanic
from sanic.application.constants import Mode
from sanic.compat import use_context
from sanic.constants import LocalCertCreator
from sanic.exceptions import SanicException
from sanic.helpers import _default
from sanic.http.tls.context import SanicSSLContext
from sanic.http.tls.creators import (
    MkcertCreator,
    TrustmeCreator,
    get_ssl_context,
)
from sanic.response import text


# Directory containing this test module; the TLS fixtures live in its
# ``certs/`` sub-folder.
current_dir = os.path.dirname(os.path.realpath(__file__))

# Certificate folders. ``invalid_dir`` is used by the error-path tests to
# provoke "not found" failures.
localhost_dir, sanic_dir, invalid_dir = (
    os.path.join(current_dir, "certs/" + host)
    for host in ("localhost", "sanic.example", "invalid.nonexist")
)

# Full certificate chain and private key within each usable folder.
localhost_cert, localhost_key = (
    os.path.join(localhost_dir, filename)
    for filename in ("fullchain.pem", "privkey.pem")
)
sanic_cert, sanic_key = (
    os.path.join(sanic_dir, filename)
    for filename in ("fullchain.pem", "privkey.pem")
)


@pytest.fixture
def server_cert():
    """Stand-in for the leaf certificate returned by ``CA.issue_cert``."""
    leaf = Mock()
    return leaf


@pytest.fixture
def issue_cert(server_cert):
    """Mock ``issue_cert`` callable that always returns ``server_cert``."""
    return Mock(return_value=server_cert)


@pytest.fixture
def ca(issue_cert):
    """Mock trustme CA whose ``issue_cert`` is the ``issue_cert`` fixture."""
    fake_ca = Mock()
    fake_ca.issue_cert = issue_cert
    return fake_ca


@pytest.fixture
def trustme(ca):
    """Mock ``trustme`` module whose ``CA()`` constructor yields ``ca``."""
    fake_module = Mock()
    fake_module.CA = Mock(return_value=ca)
    return fake_module


@pytest.fixture
def MockMkcertCreator():
    """Provide an ``MkcertCreator`` subclass with a switchable support check.

    Set ``SUPPORTED`` to False to make ``check_supported`` raise
    ``SanicException("Nope")``. ``generate_cert`` is a Mock so tests can
    inspect its calls. The class (and its Mock) is rebuilt on every fixture
    invocation.
    """

    class Creator(MkcertCreator):
        SUPPORTED = True
        generate_cert = Mock()

        def check_supported(self):
            if self.SUPPORTED:
                return
            raise SanicException("Nope")

    return Creator


@pytest.fixture
def MockTrustmeCreator():
    """Provide a ``TrustmeCreator`` subclass with a switchable support check.

    Set ``SUPPORTED`` to False to make ``check_supported`` raise
    ``SanicException("Nope")``. ``generate_cert`` is a Mock so tests can
    inspect its calls. The class (and its Mock) is rebuilt on every fixture
    invocation.
    """

    class Creator(TrustmeCreator):
        SUPPORTED = True
        generate_cert = Mock()

        def check_supported(self):
            if self.SUPPORTED:
                return
            raise SanicException("Nope")

    return Creator


@contextmanager
def replace_server_name(hostname):
    """Make every ``ssl.SSLContext.wrap_bio`` call inside the block use
    ``hostname`` as the server name, ignoring the ``server_hostname`` the
    caller passed.

    This works by temporarily patching ``ssl.SSLContext.wrap_bio``. The
    original method is put back on exit, even if the block raises.
    """
    real_wrap_bio = ssl.SSLContext.wrap_bio

    def wrap_bio_with_fake_name(
        self,
        incoming,
        outgoing,
        server_side=False,
        server_hostname=None,
        session=None,
    ):
        # The caller's server_hostname is deliberately discarded.
        return real_wrap_bio(
            self, incoming, outgoing, server_side, hostname, session
        )

    ssl.SSLContext.wrap_bio = wrap_bio_with_fake_name
    try:
        yield
    finally:
        ssl.SSLContext.wrap_bio = real_wrap_bio


@pytest.mark.parametrize(
    "path,query,expected_url",
    [
        ("/foo", "", "https://{}:{}/foo"),
        ("/bar/baz", "", "https://{}:{}/bar/baz"),
        ("/moo/boo", "arg1=val1", "https://{}:{}/moo/boo?arg1=val1"),
    ],
)
def test_url_attributes_with_ssl_context(app, path, query, expected_url):
    """URL-related request attributes stay consistent when TLS is set up
    from a pre-built ``ssl.SSLContext``."""
    ssl_context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
    ssl_context.load_cert_chain(localhost_cert, localhost_key)

    async def handler(request):
        return text("OK")

    app.add_route(handler, path)

    url = f"https://{HOST}:{PORT}{path}?{query}"
    request, _ = app.test_client.get(url, server_kwargs={"ssl": ssl_context})
    assert request.url == expected_url.format(HOST, request.server_port)

    # Each component parsed from request.url must agree with its attribute.
    parsed = urlparse(request.url)
    assert (parsed.scheme, parsed.path, parsed.query, parsed.netloc) == (
        request.scheme,
        request.path,
        request.query_string,
        request.host,
    )


@pytest.mark.parametrize(
    "path,query,expected_url",
    [
        ("/foo", "", "https://{}:{}/foo"),
        ("/bar/baz", "", "https://{}:{}/bar/baz"),
        ("/moo/boo", "arg1=val1", "https://{}:{}/moo/boo?arg1=val1"),
    ],
)
def test_url_attributes_with_ssl_dict(app, path, query, expected_url):
    """URL-related request attributes stay consistent when TLS is set up
    from a ``{"cert": ..., "key": ...}`` dict."""

    async def handler(request):
        return text("OK")

    app.add_route(handler, path)

    url = f"https://{HOST}:{PORT}{path}?{query}"
    request, _ = app.test_client.get(
        url,
        server_kwargs={"ssl": {"cert": localhost_cert, "key": localhost_key}},
    )
    assert request.url == expected_url.format(HOST, request.server_port)

    # Each component parsed from request.url must agree with its attribute.
    parsed = urlparse(request.url)
    assert (parsed.scheme, parsed.path, parsed.query, parsed.netloc) == (
        request.scheme,
        request.path,
        request.query_string,
        request.host,
    )


def test_cert_sni_single(app):
    """With one certificate folder, both the SNI name and the certificate's
    commonName seen by the handler are ``localhost``."""

    @app.get("/sni")
    async def handler1(request):
        return text(request.conn_info.server_name)

    @app.get("/commonname")
    async def handler2(request):
        return text(request.conn_info.cert.get("commonName"))

    port = app.test_client.port
    for endpoint in ("sni", "commonname"):
        _, response = app.test_client.get(
            f"https://localhost:{port}/{endpoint}",
            server_kwargs={"ssl": localhost_dir},
        )
        assert response.status == 200
        assert response.text == "localhost"


def test_cert_sni_list(app):
    """With several certificate folders, the SNI name picks the certificate."""
    ssl_list = [sanic_dir, localhost_dir]

    @app.get("/sni")
    async def handler1(request):
        return text(request.conn_info.server_name)

    @app.get("/commonname")
    async def handler2(request):
        return text(request.conn_info.cert.get("commonName"))

    port = app.test_client.port

    def fetch(host, endpoint):
        # One TLS request; it must succeed, and its body is returned.
        _, response = app.test_client.get(
            f"https://{host}:{port}/{endpoint}",
            server_kwargs={"ssl": ssl_list},
        )
        assert response.status == 200
        return response.text

    # The localhost name matches the localhost certificate.
    assert fetch("localhost", "sni") == "localhost"
    assert fetch("localhost", "commonname") == "localhost"

    # A name covered by the sanic.example certificate selects that one.
    with replace_server_name("www.sanic.example"):
        assert fetch("127.0.0.1", "sni") == "www.sanic.example"
        assert fetch("127.0.0.1", "commonname") == "sanic.example"

    # An unmatched name falls back to the first listed certificate
    # (sanic.example).
    with replace_server_name("invalid.test"):
        assert fetch("127.0.0.1", "sni") == "invalid.test"
        assert fetch("127.0.0.1", "commonname") == "sanic.example"


@pytest.mark.xfail
def test_missing_sni(app):
    """Requesting by bare IP should fail: the sanic cert does not list
    127.0.0.1 and httpx does not send an IP as SNI anyway."""
    certs = [None, sanic_dir]

    @app.get("/sni")
    async def handler(request):
        return text(request.conn_info.server_name)

    url = f"https://127.0.0.1:{app.test_client.port}/sni"
    with pytest.raises(Exception) as exc:
        app.test_client.get(url, server_kwargs={"ssl": certs})
    assert "Request and response object expected" in str(exc.value)


@pytest.mark.xfail
def test_no_matching_cert(app):
    """An SNI name that no configured certificate covers should fail.

    The sanic cert does not list 127.0.0.1 and httpx does not send an IP
    as SNI anyway.
    """
    certs = [None, sanic_dir]

    @app.get("/sni")
    async def handler(request):
        return text(request.conn_info.server_name)

    url = f"https://127.0.0.1:{app.test_client.port}/sni"
    with replace_server_name("invalid.test"), pytest.raises(
        Exception
    ) as exc:
        app.test_client.get(url, server_kwargs={"ssl": certs})
    assert "Request and response object expected" in str(exc.value)


@pytest.mark.xfail
def test_wildcards(app):
    """Expect ``*.sanic.test`` to match exactly one extra label.

    ``foo.sanic.test`` is served, while ``sanic.test`` and
    ``sub.foo.sanic.test`` are rejected.
    """
    ssl_list = [None, localhost_dir, sanic_dir]

    @app.get("/sni")
    async def handler(request):
        return text(request.conn_info.server_name)

    url = f"https://127.0.0.1:{app.test_client.port}/sni"

    with replace_server_name("foo.sanic.test"):
        _, response = app.test_client.get(
            url, server_kwargs={"ssl": ssl_list}
        )
        assert response.status == 200
        assert response.text == "foo.sanic.test"

    for rejected_name in ("sanic.test", "sub.foo.sanic.test"):
        with replace_server_name(rejected_name):
            with pytest.raises(Exception) as exc:
                app.test_client.get(url, server_kwargs={"ssl": ssl_list})
            assert "Request and response object expected" in str(exc.value)


def test_invalid_ssl_dict(app):
    """An ssl dict whose cert/key are not filenames raises ValueError."""

    @app.get("/test")
    async def handler(request):
        return text("ssl test")

    with pytest.raises(ValueError) as excinfo:
        app.test_client.get(
            "/test", server_kwargs={"ssl": {"cert": None, "key": None}}
        )

    assert str(excinfo.value) == "SSL dict needs filenames for cert and key."


def test_invalid_ssl_type(app):
    """An ssl argument of an unsupported type (``False``) is rejected."""

    @app.get("/test")
    async def handler(request):
        return text("ssl test")

    # The pattern has no regex metacharacters, so re.search == substring.
    with pytest.raises(ValueError, match="Invalid ssl argument"):
        app.test_client.get("/test", server_kwargs={"ssl": False})


def test_cert_file_on_pathlist(app):
    """Listing a certificate file where a folder is expected is an error
    that names the offending path."""

    @app.get("/test")
    async def handler(request):
        return text("ssl test")

    with pytest.raises(ValueError) as excinfo:
        app.test_client.get("/test", server_kwargs={"ssl": [sanic_cert]})

    message = str(excinfo.value)
    assert "folder expected" in message
    assert sanic_cert in message


def test_missing_cert_path(app):
    """A certificate folder that does not exist is reported as not found.

    The error message must name the missing ``privkey.pem`` in that folder.
    """

    @app.get("/test")
    async def handler(request):
        return text("ssl test")

    ssl_list = [invalid_dir]

    with pytest.raises(ValueError) as excinfo:
        app.test_client.get("/test", server_kwargs={"ssl": ssl_list})

    assert "not found" in str(excinfo.value)
    # Build the expected path with os.path.join, the same way invalid_dir is
    # built, instead of a literal "/", so the separator matches the platform.
    assert os.path.join(invalid_dir, "privkey.pem") in str(excinfo.value)


def test_missing_cert_file(app):
    """Loading the ``invalid.certmissing`` fixture folder must fail and name
    its missing ``fullchain.pem``."""

    @app.get("/test")
    async def handler(request):
        return text("ssl test")

    # Build the path directly. Calling str.replace() on invalid_dir would also
    # rewrite any "nonexist" that happens to occur in the checkout path.
    invalid2 = os.path.join(current_dir, "certs/invalid.certmissing")
    ssl_list = [invalid2]

    with pytest.raises(ValueError) as excinfo:
        app.test_client.get("/test", server_kwargs={"ssl": ssl_list})

    assert "not found" in str(excinfo.value)
    # Use os.path.join rather than a literal "/" so the separator matches
    # the platform.
    assert os.path.join(invalid2, "fullchain.pem") in str(excinfo.value)


def test_no_certs_on_list(app):
    """A certificate list containing only ``None`` is rejected."""

    @app.get("/test")
    async def handler(request):
        return text("ssl test")

    # The pattern has no regex metacharacters, so re.search == substring.
    with pytest.raises(ValueError, match="No certificates"):
        app.test_client.get("/test", server_kwargs={"ssl": [None]})


def test_logger_vhosts(caplog):
    """Starting with several certificate folders logs, at INFO level, every
    hostname/IP those certificates cover, in folder order."""
    app = Sanic(name="test_logger_vhosts")

    @app.after_server_start
    def stop(*args):
        app.stop()

    with caplog.at_level(logging.INFO):
        app.run(
            host="127.0.0.1",
            port=42102,
            ssl=[localhost_dir, sanic_dir],
            single_process=True,
        )

    # Take the first "Certificate..." record. If none was logged, fail with a
    # readable assertion message rather than an IndexError.
    logmsg = next(
        (
            message
            for _logger_name, _level, message in caplog.record_tuples
            if message.startswith("Certificate")
        ),
        None,
    )
    assert logmsg is not None, "no 'Certificate vhosts' log record emitted"

    assert logmsg == (
        "Certificate vhosts: localhost, 127.0.0.1, 0:0:0:0:0:0:0:1, "
        "sanic.example, www.sanic.example, *.sanic.test, "
        "2001:DB8:0:0:0:0:0:541C"
    )


def test_mk_cert_creator_default(app: Sanic):
    """With default key/cert arguments, MkcertCreator exposes an existing
    temporary directory as a ``Path``."""
    tmpdir = MkcertCreator(app, _default, _default).tmpdir
    assert isinstance(tmpdir, Path)
    assert tmpdir.exists()


def test_mk_cert_creator_is_supported(app):
    """``check_supported`` probes ``mkcert -help`` once, silently and with
    ``check=True``."""
    creator = MkcertCreator(app, _default, _default)
    with patch("subprocess.run") as mocked_run:
        creator.check_supported()

    # The mock keeps its call record after the patch is undone.
    mocked_run.assert_called_once_with(
        ["mkcert", "-help"],
        check=True,
        stderr=subprocess.DEVNULL,
        stdout=subprocess.DEVNULL,
    )


def test_mk_cert_creator_is_not_supported(app):
    """Any failure of the mkcert probe is turned into a SanicException."""
    creator = MkcertCreator(app, _default, _default)
    expected = pytest.raises(
        SanicException, match="Sanic is attempting to use mkcert"
    )
    with patch("subprocess.run", side_effect=Exception("")), expected:
        creator.check_supported()


def test_mk_cert_creator_generate_cert_default(app):
    """With default paths, ``generate_cert`` shells out exactly once."""
    creator = MkcertCreator(app, _default, _default)
    completed = Mock(stdout="foo")
    with patch(
        "subprocess.run", return_value=completed
    ) as mocked_run, patch("sanic.http.tls.creators.CertSimple"):
        creator.generate_cert("localhost")
        mocked_run.assert_called_once()


def test_mk_cert_creator_generate_cert_localhost(app):
    """Given the localhost key/cert paths, ``generate_cert`` does not run
    any subprocess."""
    creator = MkcertCreator(app, localhost_key, localhost_cert)
    with patch("subprocess.run") as mocked_run, patch(
        "sanic.http.tls.creators.CertSimple"
    ):
        creator.generate_cert("localhost")
        mocked_run.assert_not_called()


def test_trustme_creator_default(app: Sanic):
    """With default key/cert arguments, TrustmeCreator exposes an existing
    temporary directory as a ``Path``."""
    tmpdir = TrustmeCreator(app, _default, _default).tmpdir
    assert isinstance(tmpdir, Path)
    assert tmpdir.exists()


def test_trustme_creator_is_supported(app, monkeypatch):
    """``check_supported`` does not raise when trustme is flagged installed."""
    monkeypatch.setattr(sanic.http.tls.creators, "TRUSTME_INSTALLED", True)
    TrustmeCreator(app, _default, _default).check_supported()


def test_trustme_creator_is_not_supported(app, monkeypatch):
    """``check_supported`` raises when trustme is flagged as missing."""
    monkeypatch.setattr(sanic.http.tls.creators, "TRUSTME_INSTALLED", False)
    creator = TrustmeCreator(app, _default, _default)
    expected_error = "Sanic is attempting to use trustme"
    with pytest.raises(SanicException, match=expected_error):
        creator.check_supported()


def test_trustme_creator_generate_cert_default(
    app, monkeypatch, trustme, issue_cert, server_cert, ca
):
    """With default paths, ``generate_cert`` builds a SanicSSLContext from a
    fresh trustme CA. It writes the CA cert and the server key/chain to the
    paths recorded in ``context.sanic``."""
    monkeypatch.setattr(sanic.http.tls.creators, "trustme", trustme)
    context = TrustmeCreator(app, _default, _default).generate_cert(
        "localhost"
    )

    assert isinstance(context, SanicSSLContext)
    trustme.CA.assert_called_once_with()
    issue_cert.assert_called_once_with("localhost")
    server_cert.configure_cert.assert_called_once()
    ca.configure_trust.assert_called_once()

    cert_path = str(context.sanic["cert"])
    key_path = str(context.sanic["key"])
    ca.cert_pem.write_to_path.assert_called_once_with(cert_path)
    chain_pem = server_cert.private_key_and_cert_chain_pem
    chain_pem.write_to_path.assert_called_once_with(key_path)


def test_trustme_creator_generate_cert_localhost(
    app, monkeypatch, trustme, server_cert, ca
):
    """Explicit key/cert paths are used verbatim as the write targets."""
    monkeypatch.setattr(sanic.http.tls.creators, "trustme", trustme)
    TrustmeCreator(app, localhost_key, localhost_cert).generate_cert(
        "localhost"
    )

    ca.cert_pem.write_to_path.assert_called_once_with(localhost_cert)
    chain_pem = server_cert.private_key_and_cert_chain_pem
    chain_pem.write_to_path.assert_called_once_with(localhost_key)


def test_get_ssl_context_with_ssl_context(app):
    """An explicitly supplied context is returned as-is."""
    supplied = Mock()
    assert get_ssl_context(app, supplied) is supplied


def test_get_ssl_context_in_production(app):
    """Without a context, ``get_ssl_context`` raises in PRODUCTION mode."""
    app.state.mode = Mode.PRODUCTION
    expected_error = "Cannot run Sanic as an HTTPS server in PRODUCTION mode"
    with pytest.raises(SanicException, match=expected_error):
        get_ssl_context(app, None)


@pytest.mark.parametrize(
    "requirement,mk_supported,trustme_supported,mk_called,trustme_called,err",
    (
        (LocalCertCreator.AUTO, True, False, True, False, None),
        (LocalCertCreator.AUTO, True, True, True, False, None),
        (LocalCertCreator.AUTO, False, True, False, True, None),
        (
            LocalCertCreator.AUTO,
            False,
            False,
            False,
            False,
            "Sanic could not find package to create a TLS certificate",
        ),
        (LocalCertCreator.MKCERT, True, False, True, False, None),
        (LocalCertCreator.MKCERT, True, True, True, False, None),
        (LocalCertCreator.MKCERT, False, True, False, False, "Nope"),
        (LocalCertCreator.MKCERT, False, False, False, False, "Nope"),
        (LocalCertCreator.TRUSTME, True, False, False, False, "Nope"),
        (LocalCertCreator.TRUSTME, True, True, False, True, None),
        (LocalCertCreator.TRUSTME, False, True, False, True, None),
        (LocalCertCreator.TRUSTME, False, False, False, False, "Nope"),
    ),
)
def test_get_ssl_context_only_mkcert(
    app,
    monkeypatch,
    MockMkcertCreator,
    MockTrustmeCreator,
    requirement,
    mk_supported,
    trustme_supported,
    mk_called,
    trustme_called,
    err,
):
    """Check which local certificate creator ``get_ssl_context`` invokes.

    Despite its name, this covers the AUTO, MKCERT and TRUSTME settings
    against every combination of creator availability. A non-None ``err``
    is the message the resulting SanicException must match.
    """
    # PRODUCTION mode raises before any creator is consulted; use DEBUG.
    app.state.mode = Mode.DEBUG
    app.config.LOCAL_CERT_CREATOR = requirement

    # Swap the real creator classes for the configurable fakes.
    for attr_name, fake in (
        ("MkcertCreator", MockMkcertCreator),
        ("TrustmeCreator", MockTrustmeCreator),
    ):
        monkeypatch.setattr(sanic.http.tls.creators, attr_name, fake)
    MockMkcertCreator.SUPPORTED = mk_supported
    MockTrustmeCreator.SUPPORTED = trustme_supported

    if not err:
        get_ssl_context(app, None)
    else:
        with pytest.raises(SanicException, match=err):
            get_ssl_context(app, None)

    for creator, expect_call in (
        (MockMkcertCreator, mk_called),
        (MockTrustmeCreator, trustme_called),
    ):
        if expect_call:
            creator.generate_cert.assert_called_once_with("localhost")
        else:
            creator.generate_cert.assert_not_called()


# def test_no_http3_with_trustme(
#     app,
#     monkeypatch,
#     MockTrustmeCreator,
# ):
#     monkeypatch.setattr(
#         sanic.http.tls.creators, "TrustmeCreator", MockTrustmeCreator
#     )
#     MockTrustmeCreator.SUPPORTED = True
#     app.config.LOCAL_CERT_CREATOR = "TRUSTME"
#     with pytest.raises(
#         SanicException,
#         match=(
#             "Sorry, you cannot currently use trustme as a local certificate "
#             "generator for an HTTP/3 server"
#         ),
#     ):
#         app.run(version=3, debug=True)


def test_sanic_ssl_context_create():
    """``create_from_ssl_context`` upgrades the given context in place into
    a SanicSSLContext and returns the same object."""
    # Calling ssl.SSLContext() with no protocol argument has been deprecated
    # since Python 3.10 and emits a DeprecationWarning, so pass one explicitly.
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    sanic_context = SanicSSLContext.create_from_ssl_context(context)

    assert sanic_context is context
    assert isinstance(sanic_context, SanicSSLContext)


@pytest.mark.skipif(
    sys.platform not in ("linux", "darwin"),
    reason="This test requires fork context",
)
def test_ssl_in_multiprocess_mode(app: Sanic, caplog):
    """TLS serving works when workers are started with the fork context.

    A multiprocessing Event is shared through ``shared_ctx`` and set from
    ``after_server_start``, which proves a server started. The startup log
    line must advertise the https URL.
    """
    ssl_dict = {"cert": localhost_cert, "key": localhost_key}
    started = Event()

    @app.main_process_start
    async def main_start(app: Sanic):
        app.shared_ctx.event = started

    @app.after_server_start
    async def shutdown(app):
        app.shared_ctx.event.set()
        app.stop()

    assert not started.is_set()
    with use_context("fork"), caplog.at_level(logging.INFO):
        app.run(ssl=ssl_dict)
    assert started.is_set()

    expected_record = (
        "sanic.root",
        logging.INFO,
        "Goin' Fast @ https://127.0.0.1:8000",
    )
    assert expected_record in caplog.record_tuples