from unittest.mock import Mock

import pytest

from aioquic.h3.connection import H3Connection
from aioquic.h3.events import DataReceived, HeadersReceived
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.connection import QuicConnection
from aioquic.quic.events import ProtocolNegotiated

from sanic import Request, Sanic
from sanic.compat import Header
from sanic.config import DEFAULT_CONFIG
from sanic.exceptions import PayloadTooLarge
from sanic.http.constants import Stage
from sanic.http.http3 import Http3, HTTPReceiver
from sanic.models.server_types import ConnInfo
from sanic.response import empty, json
from sanic.server.protocols.http_protocol import Http3Protocol


try:
    from unittest.mock import AsyncMock
except ImportError:
    from tests.asyncmock import AsyncMock  # type: ignore

pytestmark = pytest.mark.asyncio


@pytest.fixture(autouse=True)
async def setup(app: Sanic):
    @app.get("/")
    async def handler(*_):
        return empty()

    app.router.finalize()
    app.signal_router.finalize()
    app.signal_router.allow_fail_builtin = False


@pytest.fixture
def http_request(app):
    return Request(b"/", Header({}), "3", "GET", Mock(), app)


def generate_protocol(app):
    connection = QuicConnection(configuration=QuicConfiguration())
    connection._ack_delay = 0
    connection._loss = Mock()
    connection._loss.spaces = []
    connection._loss.get_loss_detection_time = lambda: None
    connection.datagrams_to_send = Mock(return_value=[])  # type: ignore
    return Http3Protocol(
        connection,
        app=app,
        stream_handler=None,
    )


def generate_http_receiver(app, http_request) -> HTTPReceiver:
    protocol = generate_protocol(app)
    receiver = HTTPReceiver(
        protocol.transmit,
        protocol,
        http_request,
    )
    http_request.stream = receiver
    return receiver


def test_http_receiver_init(app: Sanic, http_request: Request):
    receiver = generate_http_receiver(app, http_request)
    assert receiver.request_body is None
    assert receiver.stage is Stage.IDLE
    assert receiver.headers_sent is False
    assert receiver.response is None
    assert receiver.request_max_size == DEFAULT_CONFIG["REQUEST_MAX_SIZE"]
    assert receiver.request_bytes == 0


async def test_http_receiver_run_request(app: Sanic, http_request: Request):
    handler = AsyncMock()

    class mock_handle(Sanic):
        handle_request = handler

    app.__class__ = mock_handle
    receiver = generate_http_receiver(app, http_request)
    receiver.protocol.quic_event_received(
        ProtocolNegotiated(alpn_protocol="h3")
    )
    await receiver.run()
    handler.assert_awaited_once_with(receiver.request)


async def test_http_receiver_run_exception(app: Sanic, http_request: Request):
    handler = AsyncMock()

    class mock_handle(Sanic):
        handle_exception = handler

    app.__class__ = mock_handle
    receiver = generate_http_receiver(app, http_request)
    receiver.protocol.quic_event_received(
        ProtocolNegotiated(alpn_protocol="h3")
    )
    exception = Exception("Oof")
    await receiver.run(exception)
    handler.assert_awaited_once_with(receiver.request, exception)

    handler.reset_mock()
    receiver.stage = Stage.REQUEST
    await receiver.run(exception)
    handler.assert_awaited_once_with(receiver.request, exception)


def test_http_receiver_respond(app: Sanic, http_request: Request):
    receiver = generate_http_receiver(app, http_request)
    response = empty()

    receiver.stage = Stage.RESPONSE
    with pytest.raises(RuntimeError, match="Response already started"):
        receiver.respond(response)

    receiver.stage = Stage.HANDLER
    receiver.response = Mock()
    resp = receiver.respond(response)

    assert receiver.response is resp
    assert resp is response
    assert response.stream is receiver


def test_http_receiver_receive_body(app: Sanic, http_request: Request):
    receiver = generate_http_receiver(app, http_request)
    receiver.request_max_size = 4

    receiver.receive_body(b"..")
    assert receiver.request.body == b".."

    receiver.receive_body(b"..")
    assert receiver.request.body == b"...."

    with pytest.raises(
        PayloadTooLarge, match="Request body exceeds the size limit"
    ):
        receiver.receive_body(b"..")


def test_http3_events(app):
    protocol = generate_protocol(app)
    http3 = Http3(protocol, protocol.transmit)
    http3.http_event_received(
        HeadersReceived(
            [
                (b":method", b"GET"),
                (b":path", b"/location"),
                (b":scheme", b"https"),
                (b":authority", b"localhost:8443"),
                (b"foo", b"bar"),
            ],
            1,
            False,
        )
    )
    http3.http_event_received(DataReceived(b"foobar", 1, False))
    receiver = http3.receivers[1]

    assert len(http3.receivers) == 1
    assert receiver.request.stream_id == 1
    assert receiver.request.path == "/location"
    assert receiver.request.method == "GET"
    assert receiver.request.headers["foo"] == "bar"
    assert receiver.request.body == b"foobar"


async def test_send_headers(app: Sanic, http_request: Request):
    send_headers_mock = Mock()
    existing_send_headers = H3Connection.send_headers
    receiver = generate_http_receiver(app, http_request)
    receiver.protocol.quic_event_received(
        ProtocolNegotiated(alpn_protocol="h3")
    )

    http_request._protocol = receiver.protocol

    def send_headers(*args, **kwargs):
        send_headers_mock(*args, **kwargs)
        return existing_send_headers(
            receiver.protocol.connection, *args, **kwargs
        )

    receiver.protocol.connection.send_headers = send_headers
    receiver.head_only = False
    response = json({}, status=201, headers={"foo": "bar"})

    with pytest.raises(RuntimeError, match="no response"):
        receiver.send_headers()

    receiver.response = response
    receiver.send_headers()

    assert receiver.headers_sent
    assert receiver.stage is Stage.RESPONSE
    send_headers_mock.assert_called_once_with(
        stream_id=0,
        headers=[
            (b":status", b"201"),
            (b"foo", b"bar"),
            (b"content-length", b"2"),
            (b"content-type", b"application/json"),
        ],
    )


def test_multiple_streams(app):
    protocol = generate_protocol(app)
    http3 = Http3(protocol, protocol.transmit)
    http3.http_event_received(
        HeadersReceived(
            [
                (b":method", b"GET"),
                (b":path", b"/location"),
                (b":scheme", b"https"),
                (b":authority", b"localhost:8443"),
                (b"foo", b"bar"),
            ],
            1,
            False,
        )
    )
    http3.http_event_received(
        HeadersReceived(
            [
                (b":method", b"GET"),
                (b":path", b"/location"),
                (b":scheme", b"https"),
                (b":authority", b"localhost:8443"),
                (b"foo", b"bar"),
            ],
            2,
            False,
        )
    )

    receiver1 = http3.get_receiver_by_stream_id(1)
    receiver2 = http3.get_receiver_by_stream_id(2)
    assert len(http3.receivers) == 2
    assert isinstance(receiver1, HTTPReceiver)
    assert isinstance(receiver2, HTTPReceiver)
    assert receiver1 is not receiver2


def test_request_stream_id(app):
    protocol = generate_protocol(app)
    http3 = Http3(protocol, protocol.transmit)
    http3.http_event_received(
        HeadersReceived(
            [
                (b":method", b"GET"),
                (b":path", b"/location"),
                (b":scheme", b"https"),
                (b":authority", b"localhost:8443"),
                (b"foo", b"bar"),
            ],
            1,
            False,
        )
    )
    receiver = http3.get_receiver_by_stream_id(1)

    assert isinstance(receiver.request, Request)
    assert receiver.request.stream_id == 1


def test_request_conn_info(app):
    protocol = generate_protocol(app)
    http3 = Http3(protocol, protocol.transmit)
    http3.http_event_received(
        HeadersReceived(
            [
                (b":method", b"GET"),
                (b":path", b"/location"),
                (b":scheme", b"https"),
                (b":authority", b"localhost:8443"),
                (b"foo", b"bar"),
            ],
            1,
            False,
        )
    )
    receiver = http3.get_receiver_by_stream_id(1)

    assert isinstance(receiver.request.conn_info, ConnInfo)