sanic/tests/http3/test_http_receiver.py

325 lines
9.7 KiB
Python
Raw Normal View History

2022-06-27 09:19:26 +01:00
from unittest.mock import Mock
import pytest
from aioquic.h3.connection import H3Connection
from aioquic.h3.events import DataReceived, HeadersReceived
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.connection import QuicConnection
from aioquic.quic.events import ProtocolNegotiated
from sanic import Request, Sanic
from sanic.compat import Header
from sanic.config import DEFAULT_CONFIG
from sanic.exceptions import BadRequest, PayloadTooLarge
2022-06-27 09:19:26 +01:00
from sanic.http.constants import Stage
from sanic.http.http3 import Http3, HTTPReceiver
from sanic.models.server_types import ConnInfo
from sanic.response import empty, json
from sanic.server.protocols.http_protocol import Http3Protocol
# Prefer the standard-library AsyncMock; when unittest.mock does not provide
# it, fall back to the copy in tests.asyncmock.
try:
    from unittest.mock import AsyncMock
except ImportError:
    from tests.asyncmock import AsyncMock  # type: ignore


# Module-level marker: applies pytest.mark.asyncio to the tests in this file.
pytestmark = pytest.mark.asyncio
@pytest.fixture(autouse=True)
async def setup(app: Sanic):
    """Prepare ``app`` for the tests in this module.

    Registers a ``GET /`` handler that returns an empty response, finalizes
    the app's router and signal router, and sets ``allow_fail_builtin`` to
    ``False`` on the signal router. The fixture is ``autouse``, so tests do
    not have to request it explicitly.
    """

    @app.get("/")
    async def handler(*_):
        return empty()

    app.router.finalize()

    signals = app.signal_router
    signals.finalize()
    signals.allow_fail_builtin = False
@pytest.fixture
def http_request(app):
    """Return a ``GET`` request for ``b"/"`` bound to ``app``.

    The request carries an empty ``Header`` mapping and a ``Mock`` in the
    transport position.
    """
    empty_headers = Header({})
    fake_transport = Mock()
    return Request(b"/", empty_headers, "3", "GET", fake_transport, app)
def generate_protocol(app):
    """Build an ``Http3Protocol`` for ``app`` around a stubbed QUIC connection.

    The connection's ``_ack_delay`` is zeroed, its ``_loss`` object is
    replaced by a mock with no spaces and no loss-detection time, and
    ``datagrams_to_send`` is replaced by a mock returning an empty list.
    """
    quic = QuicConnection(configuration=QuicConfiguration())
    quic._ack_delay = 0

    # Stand-in for the connection's loss object.
    loss_stub = Mock()
    loss_stub.spaces = []
    loss_stub.get_loss_detection_time = lambda: None
    quic._loss = loss_stub

    quic.datagrams_to_send = Mock(return_value=[])  # type: ignore

    return Http3Protocol(quic, app=app, stream_handler=None)
def generate_http_receiver(app, http_request) -> HTTPReceiver:
    """Create an ``HTTPReceiver`` for ``http_request`` on a fresh protocol.

    The protocol comes from ``generate_protocol(app)``. The new receiver is
    also stored on ``http_request.stream`` before being returned.
    """
    proto = generate_protocol(app)
    http_receiver = HTTPReceiver(proto.transmit, proto, http_request)
    http_request.stream = http_receiver
    return http_receiver
def test_http_receiver_init(app: Sanic, http_request: Request):
    """A newly created receiver exposes the expected initial attributes."""
    expected_max_size = DEFAULT_CONFIG["REQUEST_MAX_SIZE"]

    receiver = generate_http_receiver(app, http_request)

    # Nothing received or sent yet.
    assert receiver.request_body is None
    assert receiver.stage is Stage.IDLE
    assert receiver.headers_sent is False
    assert receiver.response is None

    # Size accounting starts at the configured limit with zero bytes read.
    assert receiver.request_max_size == expected_max_size
    assert receiver.request_bytes == 0
async def test_http_receiver_run_request(app: Sanic, http_request: Request):
    """``run()`` with no exception awaits ``handle_request`` exactly once."""
    handler = AsyncMock()

    class mock_handle(Sanic):
        handle_request = handler

    # Re-class the app so that ``handle_request`` resolves to the mock.
    app.__class__ = mock_handle

    receiver = generate_http_receiver(app, http_request)
    receiver.protocol.quic_event_received(
        ProtocolNegotiated(alpn_protocol="h3")
    )

    await receiver.run()

    handler.assert_awaited_once_with(receiver.request)
async def test_http_receiver_run_exception(app: Sanic, http_request: Request):
    """``run(exception)`` awaits ``handle_exception`` with that exception.

    Checked twice: once with the receiver as first created, and once after
    forcing its stage to ``Stage.REQUEST``.
    """
    handler = AsyncMock()

    class mock_handle(Sanic):
        handle_exception = handler

    # Re-class the app so that ``handle_exception`` resolves to the mock.
    app.__class__ = mock_handle

    receiver = generate_http_receiver(app, http_request)
    receiver.protocol.quic_event_received(
        ProtocolNegotiated(alpn_protocol="h3")
    )
    exception = Exception("Oof")

    await receiver.run(exception)
    handler.assert_awaited_once_with(receiver.request, exception)

    # Clear the recorded await so the second run is counted on its own.
    handler.reset_mock()
    receiver.stage = Stage.REQUEST
    await receiver.run(exception)
    handler.assert_awaited_once_with(receiver.request, exception)
def test_http_receiver_respond(app: Sanic, http_request: Request):
    """``respond()`` raises in ``Stage.RESPONSE`` and succeeds in ``HANDLER``.

    In the successful case the returned object is the response passed in, it
    replaces ``receiver.response``, and its ``stream`` is the receiver.
    """
    receiver = generate_http_receiver(app, http_request)
    new_response = empty()

    # Responding once the stage is RESPONSE must be rejected.
    receiver.stage = Stage.RESPONSE
    with pytest.raises(RuntimeError, match="Response already started"):
        receiver.respond(new_response)

    # In the HANDLER stage a previously set response gets replaced.
    receiver.stage = Stage.HANDLER
    receiver.response = Mock()
    returned = receiver.respond(new_response)

    assert receiver.response is returned
    assert returned is new_response
    assert new_response.stream is receiver
def test_http_receiver_receive_body(app: Sanic, http_request: Request):
    """``receive_body()`` appends chunks until ``request_max_size`` is passed.

    With the limit set to 4 bytes, two 2-byte chunks are accepted and
    accumulated on ``request.body``; a third chunk raises ``PayloadTooLarge``.
    """
    receiver = generate_http_receiver(app, http_request)
    receiver.request_max_size = 4

    receiver.receive_body(b"..")
    assert receiver.request.body == b".."

    receiver.receive_body(b"..")
    assert receiver.request.body == b"...."

    with pytest.raises(
        PayloadTooLarge, match="Request body exceeds the size limit"
    ):
        receiver.receive_body(b"..")
def test_http3_events(app):
    """Headers and data events for stream 1 populate a single receiver.

    After a ``HeadersReceived`` and a ``DataReceived`` event, the receiver
    stored under key 1 holds a request with the sent path, method, custom
    header and body.
    """
    protocol = generate_protocol(app)
    http3 = Http3(protocol, protocol.transmit)

    request_headers = [
        (b":method", b"GET"),
        (b":path", b"/location"),
        (b":scheme", b"https"),
        (b":authority", b"localhost:8443"),
        (b"foo", b"bar"),
    ]
    headers_event = HeadersReceived(request_headers, 1, False)
    http3.http_event_received(headers_event)

    data_event = DataReceived(b"foobar", 1, False)
    http3.http_event_received(data_event)

    receiver = http3.receivers[1]
    assert len(http3.receivers) == 1

    request = receiver.request
    assert request.stream_id == 1
    assert request.path == "/location"
    assert request.method == "GET"
    assert request.headers["foo"] == "bar"
    assert request.body == b"foobar"
async def test_send_headers(app: Sanic, http_request: Request):
    """``send_headers()`` needs a response and forwards the encoded headers.

    Without a response it raises ``RuntimeError`` matching "no response".
    With one set, it marks ``headers_sent``, moves the stage to
    ``Stage.RESPONSE`` and calls the connection's ``send_headers`` once with
    ``stream_id=0`` and the expected header list.
    """
    send_headers_mock = Mock()
    existing_send_headers = H3Connection.send_headers

    receiver = generate_http_receiver(app, http_request)
    receiver.protocol.quic_event_received(
        ProtocolNegotiated(alpn_protocol="h3")
    )
    http_request._protocol = receiver.protocol

    def send_headers(*args, **kwargs):
        # Record the call, then delegate to the real H3Connection method so
        # the connection still processes the headers.
        send_headers_mock(*args, **kwargs)
        return existing_send_headers(
            receiver.protocol.connection, *args, **kwargs
        )

    receiver.protocol.connection.send_headers = send_headers
    receiver.head_only = False
    response = json({}, status=201, headers={"foo": "bar"})

    with pytest.raises(RuntimeError, match="no response"):
        receiver.send_headers()

    receiver.response = response
    receiver.send_headers()

    assert receiver.headers_sent
    assert receiver.stage is Stage.RESPONSE
    send_headers_mock.assert_called_once_with(
        stream_id=0,
        headers=[
            (b":status", b"201"),
            (b"foo", b"bar"),
            (b"content-length", b"2"),
            (b"content-type", b"application/json"),
        ],
    )
def test_multiple_streams(app):
    """Header events on streams 1 and 2 yield two distinct receivers."""
    protocol = generate_protocol(app)
    http3 = Http3(protocol, protocol.transmit)

    # Send the same request headers on stream 1, then on stream 2. A fresh
    # header list is built for each event.
    for stream_id in (1, 2):
        request_headers = [
            (b":method", b"GET"),
            (b":path", b"/location"),
            (b":scheme", b"https"),
            (b":authority", b"localhost:8443"),
            (b"foo", b"bar"),
        ]
        http3.http_event_received(
            HeadersReceived(request_headers, stream_id, False)
        )

    receiver1 = http3.get_receiver_by_stream_id(1)
    receiver2 = http3.get_receiver_by_stream_id(2)

    assert len(http3.receivers) == 2
    assert isinstance(receiver1, HTTPReceiver)
    assert isinstance(receiver2, HTTPReceiver)
    assert receiver1 is not receiver2
def test_request_stream_id(app):
    """The request created for stream 1 reports ``stream_id == 1``."""
    protocol = generate_protocol(app)
    http3 = Http3(protocol, protocol.transmit)

    request_headers = [
        (b":method", b"GET"),
        (b":path", b"/location"),
        (b":scheme", b"https"),
        (b":authority", b"localhost:8443"),
        (b"foo", b"bar"),
    ]
    http3.http_event_received(HeadersReceived(request_headers, 1, False))

    receiver = http3.get_receiver_by_stream_id(1)
    request = receiver.request

    assert isinstance(request, Request)
    assert request.stream_id == 1
def test_request_conn_info(app):
    """The request created from a headers event carries a ``ConnInfo``."""
    protocol = generate_protocol(app)
    http3 = Http3(protocol, protocol.transmit)

    request_headers = [
        (b":method", b"GET"),
        (b":path", b"/location"),
        (b":scheme", b"https"),
        (b":authority", b"localhost:8443"),
        (b"foo", b"bar"),
    ]
    http3.http_event_received(HeadersReceived(request_headers, 1, False))

    receiver = http3.get_receiver_by_stream_id(1)
    conn_info = receiver.request.conn_info

    assert isinstance(conn_info, ConnInfo)
def test_request_header_encoding(app):
    """A header name containing U+00A0 is rejected with a 400 ``BadRequest``."""
    protocol = generate_protocol(app)
    http3 = Http3(protocol, protocol.transmit)

    # The last header's name ends in a non-breaking space.
    request_headers = [
        (b":method", b"GET"),
        (b":path", b"/location"),
        (b":scheme", b"https"),
        (b":authority", b"localhost:8443"),
        ("foo\u00A0".encode(), b"bar"),
    ]
    event = HeadersReceived(request_headers, 1, False)

    with pytest.raises(BadRequest) as exc_info:
        http3.http_event_received(event)

    error = exc_info.value
    assert error.status_code == 400
    assert (
        str(error) == "Header names may only contain US-ASCII characters."
    )
def test_request_url_encoding(app):
    """A ``:path`` containing byte 0xA0 is rejected with a 400 ``BadRequest``."""
    protocol = generate_protocol(app)
    http3 = Http3(protocol, protocol.transmit)

    # The path value ends in the single byte 0xA0.
    request_headers = [
        (b":method", b"GET"),
        (b":path", b"/location\xA0"),
        (b":scheme", b"https"),
        (b":authority", b"localhost:8443"),
        (b"foo", b"bar"),
    ]
    event = HeadersReceived(request_headers, 1, False)

    with pytest.raises(BadRequest) as exc_info:
        http3.http_event_received(event)

    error = exc_info.value
    assert error.status_code == 400
    assert str(error) == "URL may only contain US-ASCII characters."