HTTP/3 Support (#2378)
This commit is contained in:
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
@@ -25,6 +25,10 @@ class AsyncMock(Mock):
|
||||
def __await__(self):
|
||||
return self().__await__()
|
||||
|
||||
def reset_mock(self, *args, **kwargs):
|
||||
super().reset_mock(*args, **kwargs)
|
||||
self.await_count = 0
|
||||
|
||||
def assert_awaited_once(self):
|
||||
if not self.await_count == 1:
|
||||
msg = (
|
||||
@@ -32,3 +36,13 @@ class AsyncMock(Mock):
|
||||
f" Awaited {self.await_count} times."
|
||||
)
|
||||
raise AssertionError(msg)
|
||||
|
||||
def assert_awaited_once_with(self, *args, **kwargs):
|
||||
if not self.await_count == 1:
|
||||
msg = (
|
||||
f"Expected to have been awaited once."
|
||||
f" Awaited {self.await_count} times."
|
||||
)
|
||||
raise AssertionError(msg)
|
||||
self.assert_awaited_once()
|
||||
return self.assert_called_with(*args, **kwargs)
|
||||
|
||||
47
tests/client.py
Normal file
47
tests/client.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import asyncio
|
||||
|
||||
from textwrap import dedent
|
||||
from typing import AnyStr
|
||||
|
||||
|
||||
class RawClient:
|
||||
CRLF = b"\r\n"
|
||||
|
||||
def __init__(self, host: str, port: int):
|
||||
self.reader = None
|
||||
self.writer = None
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
async def connect(self):
|
||||
self.reader, self.writer = await asyncio.open_connection(
|
||||
self.host, self.port
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
self.writer.close()
|
||||
await self.writer.wait_closed()
|
||||
|
||||
async def send(self, message: AnyStr):
|
||||
if isinstance(message, str):
|
||||
msg = self._clean(message).encode("utf-8")
|
||||
else:
|
||||
msg = message
|
||||
await self._send(msg)
|
||||
|
||||
async def _send(self, message: bytes):
|
||||
if not self.writer:
|
||||
raise Exception("No open write stream")
|
||||
self.writer.write(message)
|
||||
|
||||
async def recv(self, nbytes: int = -1) -> bytes:
|
||||
if not self.reader:
|
||||
raise Exception("No open read stream")
|
||||
return await self.reader.read(nbytes)
|
||||
|
||||
def _clean(self, message: str) -> str:
|
||||
return (
|
||||
dedent(message)
|
||||
.lstrip("\n")
|
||||
.replace("\n", self.CRLF.decode("utf-8"))
|
||||
)
|
||||
@@ -150,6 +150,7 @@ def app(request):
|
||||
yield app
|
||||
for target, method_name in TouchUp._registry:
|
||||
setattr(target, method_name, CACHE[method_name])
|
||||
Sanic._app_registry.clear()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
|
||||
0
tests/http3/__init__.py
Normal file
0
tests/http3/__init__.py
Normal file
294
tests/http3/test_http_receiver.py
Normal file
294
tests/http3/test_http_receiver.py
Normal file
@@ -0,0 +1,294 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from aioquic.h3.connection import H3Connection
|
||||
from aioquic.h3.events import DataReceived, HeadersReceived
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
from aioquic.quic.connection import QuicConnection
|
||||
from aioquic.quic.events import ProtocolNegotiated
|
||||
|
||||
from sanic import Request, Sanic
|
||||
from sanic.compat import Header
|
||||
from sanic.config import DEFAULT_CONFIG
|
||||
from sanic.exceptions import PayloadTooLarge
|
||||
from sanic.http.constants import Stage
|
||||
from sanic.http.http3 import Http3, HTTPReceiver
|
||||
from sanic.models.server_types import ConnInfo
|
||||
from sanic.response import empty, json
|
||||
from sanic.server.protocols.http_protocol import Http3Protocol
|
||||
|
||||
|
||||
try:
|
||||
from unittest.mock import AsyncMock
|
||||
except ImportError:
|
||||
from tests.asyncmock import AsyncMock # type: ignore
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def setup(app: Sanic):
|
||||
@app.get("/")
|
||||
async def handler(*_):
|
||||
return empty()
|
||||
|
||||
app.router.finalize()
|
||||
app.signal_router.finalize()
|
||||
app.signal_router.allow_fail_builtin = False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def http_request(app):
|
||||
return Request(b"/", Header({}), "3", "GET", Mock(), app)
|
||||
|
||||
|
||||
def generate_protocol(app):
|
||||
connection = QuicConnection(configuration=QuicConfiguration())
|
||||
connection._ack_delay = 0
|
||||
connection._loss = Mock()
|
||||
connection._loss.spaces = []
|
||||
connection._loss.get_loss_detection_time = lambda: None
|
||||
connection.datagrams_to_send = Mock(return_value=[]) # type: ignore
|
||||
return Http3Protocol(
|
||||
connection,
|
||||
app=app,
|
||||
stream_handler=None,
|
||||
)
|
||||
|
||||
|
||||
def generate_http_receiver(app, http_request) -> HTTPReceiver:
|
||||
protocol = generate_protocol(app)
|
||||
receiver = HTTPReceiver(
|
||||
protocol.transmit,
|
||||
protocol,
|
||||
http_request,
|
||||
)
|
||||
http_request.stream = receiver
|
||||
return receiver
|
||||
|
||||
|
||||
def test_http_receiver_init(app: Sanic, http_request: Request):
|
||||
receiver = generate_http_receiver(app, http_request)
|
||||
assert receiver.request_body is None
|
||||
assert receiver.stage is Stage.IDLE
|
||||
assert receiver.headers_sent is False
|
||||
assert receiver.response is None
|
||||
assert receiver.request_max_size == DEFAULT_CONFIG["REQUEST_MAX_SIZE"]
|
||||
assert receiver.request_bytes == 0
|
||||
|
||||
|
||||
async def test_http_receiver_run_request(app: Sanic, http_request: Request):
|
||||
handler = AsyncMock()
|
||||
|
||||
class mock_handle(Sanic):
|
||||
handle_request = handler
|
||||
|
||||
app.__class__ = mock_handle
|
||||
receiver = generate_http_receiver(app, http_request)
|
||||
receiver.protocol.quic_event_received(
|
||||
ProtocolNegotiated(alpn_protocol="h3")
|
||||
)
|
||||
await receiver.run()
|
||||
handler.assert_awaited_once_with(receiver.request)
|
||||
|
||||
|
||||
async def test_http_receiver_run_exception(app: Sanic, http_request: Request):
|
||||
handler = AsyncMock()
|
||||
|
||||
class mock_handle(Sanic):
|
||||
handle_exception = handler
|
||||
|
||||
app.__class__ = mock_handle
|
||||
receiver = generate_http_receiver(app, http_request)
|
||||
receiver.protocol.quic_event_received(
|
||||
ProtocolNegotiated(alpn_protocol="h3")
|
||||
)
|
||||
exception = Exception("Oof")
|
||||
await receiver.run(exception)
|
||||
handler.assert_awaited_once_with(receiver.request, exception)
|
||||
|
||||
handler.reset_mock()
|
||||
receiver.stage = Stage.REQUEST
|
||||
await receiver.run(exception)
|
||||
handler.assert_awaited_once_with(receiver.request, exception)
|
||||
|
||||
|
||||
def test_http_receiver_respond(app: Sanic, http_request: Request):
|
||||
receiver = generate_http_receiver(app, http_request)
|
||||
response = empty()
|
||||
|
||||
receiver.stage = Stage.RESPONSE
|
||||
with pytest.raises(RuntimeError, match="Response already started"):
|
||||
receiver.respond(response)
|
||||
|
||||
receiver.stage = Stage.HANDLER
|
||||
receiver.response = Mock()
|
||||
resp = receiver.respond(response)
|
||||
|
||||
assert receiver.response is resp
|
||||
assert resp is response
|
||||
assert response.stream is receiver
|
||||
|
||||
|
||||
def test_http_receiver_receive_body(app: Sanic, http_request: Request):
|
||||
receiver = generate_http_receiver(app, http_request)
|
||||
receiver.request_max_size = 4
|
||||
|
||||
receiver.receive_body(b"..")
|
||||
assert receiver.request.body == b".."
|
||||
|
||||
receiver.receive_body(b"..")
|
||||
assert receiver.request.body == b"...."
|
||||
|
||||
with pytest.raises(
|
||||
PayloadTooLarge, match="Request body exceeds the size limit"
|
||||
):
|
||||
receiver.receive_body(b"..")
|
||||
|
||||
|
||||
def test_http3_events(app):
|
||||
protocol = generate_protocol(app)
|
||||
http3 = Http3(protocol, protocol.transmit)
|
||||
http3.http_event_received(
|
||||
HeadersReceived(
|
||||
[
|
||||
(b":method", b"GET"),
|
||||
(b":path", b"/location"),
|
||||
(b":scheme", b"https"),
|
||||
(b":authority", b"localhost:8443"),
|
||||
(b"foo", b"bar"),
|
||||
],
|
||||
1,
|
||||
False,
|
||||
)
|
||||
)
|
||||
http3.http_event_received(DataReceived(b"foobar", 1, False))
|
||||
receiver = http3.receivers[1]
|
||||
|
||||
assert len(http3.receivers) == 1
|
||||
assert receiver.request.stream_id == 1
|
||||
assert receiver.request.path == "/location"
|
||||
assert receiver.request.method == "GET"
|
||||
assert receiver.request.headers["foo"] == "bar"
|
||||
assert receiver.request.body == b"foobar"
|
||||
|
||||
|
||||
async def test_send_headers(app: Sanic, http_request: Request):
|
||||
send_headers_mock = Mock()
|
||||
existing_send_headers = H3Connection.send_headers
|
||||
receiver = generate_http_receiver(app, http_request)
|
||||
receiver.protocol.quic_event_received(
|
||||
ProtocolNegotiated(alpn_protocol="h3")
|
||||
)
|
||||
|
||||
http_request._protocol = receiver.protocol
|
||||
|
||||
def send_headers(*args, **kwargs):
|
||||
send_headers_mock(*args, **kwargs)
|
||||
return existing_send_headers(
|
||||
receiver.protocol.connection, *args, **kwargs
|
||||
)
|
||||
|
||||
receiver.protocol.connection.send_headers = send_headers
|
||||
receiver.head_only = False
|
||||
response = json({}, status=201, headers={"foo": "bar"})
|
||||
|
||||
with pytest.raises(RuntimeError, match="no response"):
|
||||
receiver.send_headers()
|
||||
|
||||
receiver.response = response
|
||||
receiver.send_headers()
|
||||
|
||||
assert receiver.headers_sent
|
||||
assert receiver.stage is Stage.RESPONSE
|
||||
send_headers_mock.assert_called_once_with(
|
||||
stream_id=0,
|
||||
headers=[
|
||||
(b":status", b"201"),
|
||||
(b"foo", b"bar"),
|
||||
(b"content-length", b"2"),
|
||||
(b"content-type", b"application/json"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_multiple_streams(app):
|
||||
protocol = generate_protocol(app)
|
||||
http3 = Http3(protocol, protocol.transmit)
|
||||
http3.http_event_received(
|
||||
HeadersReceived(
|
||||
[
|
||||
(b":method", b"GET"),
|
||||
(b":path", b"/location"),
|
||||
(b":scheme", b"https"),
|
||||
(b":authority", b"localhost:8443"),
|
||||
(b"foo", b"bar"),
|
||||
],
|
||||
1,
|
||||
False,
|
||||
)
|
||||
)
|
||||
http3.http_event_received(
|
||||
HeadersReceived(
|
||||
[
|
||||
(b":method", b"GET"),
|
||||
(b":path", b"/location"),
|
||||
(b":scheme", b"https"),
|
||||
(b":authority", b"localhost:8443"),
|
||||
(b"foo", b"bar"),
|
||||
],
|
||||
2,
|
||||
False,
|
||||
)
|
||||
)
|
||||
|
||||
receiver1 = http3.get_receiver_by_stream_id(1)
|
||||
receiver2 = http3.get_receiver_by_stream_id(2)
|
||||
assert len(http3.receivers) == 2
|
||||
assert isinstance(receiver1, HTTPReceiver)
|
||||
assert isinstance(receiver2, HTTPReceiver)
|
||||
assert receiver1 is not receiver2
|
||||
|
||||
|
||||
def test_request_stream_id(app):
|
||||
protocol = generate_protocol(app)
|
||||
http3 = Http3(protocol, protocol.transmit)
|
||||
http3.http_event_received(
|
||||
HeadersReceived(
|
||||
[
|
||||
(b":method", b"GET"),
|
||||
(b":path", b"/location"),
|
||||
(b":scheme", b"https"),
|
||||
(b":authority", b"localhost:8443"),
|
||||
(b"foo", b"bar"),
|
||||
],
|
||||
1,
|
||||
False,
|
||||
)
|
||||
)
|
||||
receiver = http3.get_receiver_by_stream_id(1)
|
||||
|
||||
assert isinstance(receiver.request, Request)
|
||||
assert receiver.request.stream_id == 1
|
||||
|
||||
|
||||
def test_request_conn_info(app):
|
||||
protocol = generate_protocol(app)
|
||||
http3 = Http3(protocol, protocol.transmit)
|
||||
http3.http_event_received(
|
||||
HeadersReceived(
|
||||
[
|
||||
(b":method", b"GET"),
|
||||
(b":path", b"/location"),
|
||||
(b":scheme", b"https"),
|
||||
(b":authority", b"localhost:8443"),
|
||||
(b"foo", b"bar"),
|
||||
],
|
||||
1,
|
||||
False,
|
||||
)
|
||||
)
|
||||
receiver = http3.get_receiver_by_stream_id(1)
|
||||
|
||||
assert isinstance(receiver.request.conn_info, ConnInfo)
|
||||
114
tests/http3/test_server.py
Normal file
114
tests/http3/test_server.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from asyncio import Event
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from sanic import Sanic
|
||||
from sanic.compat import UVLOOP_INSTALLED
|
||||
from sanic.http.constants import HTTP
|
||||
|
||||
|
||||
parent_dir = Path(__file__).parent.parent
|
||||
localhost_dir = parent_dir / "certs/localhost"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("version", (3, HTTP.VERSION_3))
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 8) and not UVLOOP_INSTALLED,
|
||||
reason="In 3.7 w/o uvloop the port is not always released",
|
||||
)
|
||||
def test_server_starts_http3(app: Sanic, version, caplog):
|
||||
ev = Event()
|
||||
|
||||
@app.after_server_start
|
||||
def shutdown(*_):
|
||||
ev.set()
|
||||
app.stop()
|
||||
|
||||
with caplog.at_level(logging.INFO):
|
||||
app.run(
|
||||
version=version,
|
||||
ssl={
|
||||
"cert": localhost_dir / "fullchain.pem",
|
||||
"key": localhost_dir / "privkey.pem",
|
||||
},
|
||||
)
|
||||
|
||||
assert ev.is_set()
|
||||
assert (
|
||||
"sanic.root",
|
||||
logging.INFO,
|
||||
"server: sanic, HTTP/3",
|
||||
) in caplog.record_tuples
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 8) and not UVLOOP_INSTALLED,
|
||||
reason="In 3.7 w/o uvloop the port is not always released",
|
||||
)
|
||||
def test_server_starts_http1_and_http3(app: Sanic, caplog):
|
||||
@app.after_server_start
|
||||
def shutdown(*_):
|
||||
app.stop()
|
||||
|
||||
app.prepare(
|
||||
version=3,
|
||||
ssl={
|
||||
"cert": localhost_dir / "fullchain.pem",
|
||||
"key": localhost_dir / "privkey.pem",
|
||||
},
|
||||
)
|
||||
app.prepare(
|
||||
version=1,
|
||||
ssl={
|
||||
"cert": localhost_dir / "fullchain.pem",
|
||||
"key": localhost_dir / "privkey.pem",
|
||||
},
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
Sanic.serve()
|
||||
|
||||
assert (
|
||||
"sanic.root",
|
||||
logging.INFO,
|
||||
"server: sanic, HTTP/1.1",
|
||||
) in caplog.record_tuples
|
||||
assert (
|
||||
"sanic.root",
|
||||
logging.INFO,
|
||||
"server: sanic, HTTP/3",
|
||||
) in caplog.record_tuples
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 8) and not UVLOOP_INSTALLED,
|
||||
reason="In 3.7 w/o uvloop the port is not always released",
|
||||
)
|
||||
def test_server_starts_http1_and_http3_bad_order(app: Sanic, caplog):
|
||||
@app.after_server_start
|
||||
def shutdown(*_):
|
||||
app.stop()
|
||||
|
||||
app.prepare(
|
||||
version=1,
|
||||
ssl={
|
||||
"cert": localhost_dir / "fullchain.pem",
|
||||
"key": localhost_dir / "privkey.pem",
|
||||
},
|
||||
)
|
||||
message = (
|
||||
"Serving HTTP/3 instances as a secondary server is not supported. "
|
||||
"There can only be a single HTTP/3 worker and it must be the first "
|
||||
"instance prepared."
|
||||
)
|
||||
with pytest.raises(RuntimeError, match=message):
|
||||
app.prepare(
|
||||
version=3,
|
||||
ssl={
|
||||
"cert": localhost_dir / "fullchain.pem",
|
||||
"key": localhost_dir / "privkey.pem",
|
||||
},
|
||||
)
|
||||
46
tests/http3/test_session_ticket_store.py
Normal file
46
tests/http3/test_session_ticket_store.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from datetime import datetime
|
||||
|
||||
from aioquic.tls import CipherSuite, SessionTicket
|
||||
|
||||
from sanic.http.http3 import SessionTicketStore
|
||||
|
||||
|
||||
def _generate_ticket(label):
|
||||
return SessionTicket(
|
||||
1,
|
||||
CipherSuite.AES_128_GCM_SHA256,
|
||||
datetime.now(),
|
||||
datetime.now(),
|
||||
label,
|
||||
label.decode(),
|
||||
label,
|
||||
None,
|
||||
[],
|
||||
)
|
||||
|
||||
|
||||
def test_session_ticket_store():
|
||||
store = SessionTicketStore()
|
||||
|
||||
assert len(store.tickets) == 0
|
||||
|
||||
ticket1 = _generate_ticket(b"foo")
|
||||
store.add(ticket1)
|
||||
|
||||
assert len(store.tickets) == 1
|
||||
|
||||
ticket2 = _generate_ticket(b"bar")
|
||||
store.add(ticket2)
|
||||
|
||||
assert len(store.tickets) == 2
|
||||
assert len(store.tickets) == 2
|
||||
|
||||
popped2 = store.pop(ticket2.ticket)
|
||||
|
||||
assert len(store.tickets) == 1
|
||||
assert popped2 is ticket2
|
||||
|
||||
popped1 = store.pop(ticket1.ticket)
|
||||
|
||||
assert len(store.tickets) == 0
|
||||
assert popped1 is ticket1
|
||||
@@ -417,7 +417,7 @@ async def test_request_class_custom():
|
||||
class MyCustomRequest(Request):
|
||||
pass
|
||||
|
||||
app = Sanic(name=__name__, request_class=MyCustomRequest)
|
||||
app = Sanic(name="Test", request_class=MyCustomRequest)
|
||||
|
||||
@app.get("/custom")
|
||||
def custom_request(request):
|
||||
|
||||
@@ -148,8 +148,7 @@ def test_tls_wrong_options(cmd: Tuple[str]):
|
||||
assert not out
|
||||
lines = err.decode().split("\n")
|
||||
|
||||
errmsg = lines[6]
|
||||
assert errmsg == "TLS certificates must be specified by either of:"
|
||||
assert "TLS certificates must be specified by either of:" in lines
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from contextlib import contextmanager
|
||||
from os import environ
|
||||
@@ -13,6 +14,7 @@ from pytest import MonkeyPatch
|
||||
|
||||
from sanic import Sanic
|
||||
from sanic.config import DEFAULT_CONFIG, Config
|
||||
from sanic.constants import LocalCertCreator
|
||||
from sanic.exceptions import PyFileError
|
||||
|
||||
|
||||
@@ -49,7 +51,7 @@ def test_load_from_object(app: Sanic):
|
||||
|
||||
|
||||
def test_load_from_object_string(app: Sanic):
|
||||
app.config.load("test_config.ConfigTest")
|
||||
app.config.load("tests.test_config.ConfigTest")
|
||||
assert "CONFIG_VALUE" in app.config
|
||||
assert app.config.CONFIG_VALUE == "should be used"
|
||||
assert "not_for_config" not in app.config
|
||||
@@ -71,14 +73,14 @@ def test_load_from_object_string_exception(app: Sanic):
|
||||
|
||||
def test_auto_env_prefix():
|
||||
environ["SANIC_TEST_ANSWER"] = "42"
|
||||
app = Sanic(name=__name__)
|
||||
app = Sanic(name="Test")
|
||||
assert app.config.TEST_ANSWER == 42
|
||||
del environ["SANIC_TEST_ANSWER"]
|
||||
|
||||
|
||||
def test_auto_bool_env_prefix():
|
||||
environ["SANIC_TEST_ANSWER"] = "True"
|
||||
app = Sanic(name=__name__)
|
||||
app = Sanic(name="Test")
|
||||
assert app.config.TEST_ANSWER is True
|
||||
del environ["SANIC_TEST_ANSWER"]
|
||||
|
||||
@@ -86,28 +88,28 @@ def test_auto_bool_env_prefix():
|
||||
@pytest.mark.parametrize("env_prefix", [None, ""])
|
||||
def test_empty_load_env_prefix(env_prefix):
|
||||
environ["SANIC_TEST_ANSWER"] = "42"
|
||||
app = Sanic(name=__name__, env_prefix=env_prefix)
|
||||
app = Sanic(name="Test", env_prefix=env_prefix)
|
||||
assert getattr(app.config, "TEST_ANSWER", None) is None
|
||||
del environ["SANIC_TEST_ANSWER"]
|
||||
|
||||
|
||||
def test_env_prefix():
|
||||
environ["MYAPP_TEST_ANSWER"] = "42"
|
||||
app = Sanic(name=__name__, env_prefix="MYAPP_")
|
||||
app = Sanic(name="Test", env_prefix="MYAPP_")
|
||||
assert app.config.TEST_ANSWER == 42
|
||||
del environ["MYAPP_TEST_ANSWER"]
|
||||
|
||||
|
||||
def test_env_prefix_float_values():
|
||||
environ["MYAPP_TEST_ROI"] = "2.3"
|
||||
app = Sanic(name=__name__, env_prefix="MYAPP_")
|
||||
app = Sanic(name="Test", env_prefix="MYAPP_")
|
||||
assert app.config.TEST_ROI == 2.3
|
||||
del environ["MYAPP_TEST_ROI"]
|
||||
|
||||
|
||||
def test_env_prefix_string_value():
|
||||
environ["MYAPP_TEST_TOKEN"] = "somerandomtesttoken"
|
||||
app = Sanic(name=__name__, env_prefix="MYAPP_")
|
||||
app = Sanic(name="Test", env_prefix="MYAPP_")
|
||||
assert app.config.TEST_TOKEN == "somerandomtesttoken"
|
||||
del environ["MYAPP_TEST_TOKEN"]
|
||||
|
||||
@@ -116,7 +118,7 @@ def test_env_w_custom_converter():
|
||||
environ["SANIC_TEST_ANSWER"] = "42"
|
||||
|
||||
config = Config(converters=[UltimateAnswer])
|
||||
app = Sanic(name=__name__, config=config)
|
||||
app = Sanic(name="Test", config=config)
|
||||
assert isinstance(app.config.TEST_ANSWER, UltimateAnswer)
|
||||
assert app.config.TEST_ANSWER.answer == 42
|
||||
del environ["SANIC_TEST_ANSWER"]
|
||||
@@ -125,7 +127,7 @@ def test_env_w_custom_converter():
|
||||
def test_env_lowercase():
|
||||
with pytest.warns(None) as record:
|
||||
environ["SANIC_test_answer"] = "42"
|
||||
app = Sanic(name=__name__)
|
||||
app = Sanic(name="Test")
|
||||
assert app.config.test_answer == 42
|
||||
assert str(record[0].message) == (
|
||||
"[DEPRECATION v22.9] Lowercase environment variables will not be "
|
||||
@@ -435,3 +437,21 @@ def test_negative_proxy_count(app: Sanic):
|
||||
)
|
||||
with pytest.raises(ValueError, match=message):
|
||||
app.prepare()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"passed,expected",
|
||||
(
|
||||
("auto", LocalCertCreator.AUTO),
|
||||
("mkcert", LocalCertCreator.MKCERT),
|
||||
("trustme", LocalCertCreator.TRUSTME),
|
||||
("AUTO", LocalCertCreator.AUTO),
|
||||
("MKCERT", LocalCertCreator.MKCERT),
|
||||
("TRUSTME", LocalCertCreator.TRUSTME),
|
||||
),
|
||||
)
|
||||
def test_convert_local_cert_creator(passed, expected):
|
||||
os.environ["SANIC_LOCAL_CERT_CREATOR"] = passed
|
||||
app = Sanic("Test")
|
||||
assert app.config.LOCAL_CERT_CREATOR is expected
|
||||
del os.environ["SANIC_LOCAL_CERT_CREATOR"]
|
||||
|
||||
@@ -17,7 +17,7 @@ class CustomRequest(Request):
|
||||
|
||||
|
||||
def test_custom_request():
|
||||
app = Sanic(name=__name__, request_class=CustomRequest)
|
||||
app = Sanic(name="Test", request_class=CustomRequest)
|
||||
|
||||
@app.route("/post", methods=["POST"])
|
||||
async def post_handler(request):
|
||||
|
||||
@@ -259,7 +259,7 @@ def test_custom_exception_default_message(exception_app):
|
||||
|
||||
|
||||
def test_exception_in_ws_logged(caplog):
|
||||
app = Sanic(__name__)
|
||||
app = Sanic("Test")
|
||||
|
||||
@app.websocket("/feed")
|
||||
async def feed(request, ws):
|
||||
@@ -279,7 +279,7 @@ def test_exception_in_ws_logged(caplog):
|
||||
|
||||
@pytest.mark.parametrize("debug", (True, False))
|
||||
def test_contextual_exception_context(debug):
|
||||
app = Sanic(__name__)
|
||||
app = Sanic("Test")
|
||||
|
||||
class TeapotError(SanicException):
|
||||
status_code = 418
|
||||
@@ -314,7 +314,7 @@ def test_contextual_exception_context(debug):
|
||||
|
||||
@pytest.mark.parametrize("debug", (True, False))
|
||||
def test_contextual_exception_extra(debug):
|
||||
app = Sanic(__name__)
|
||||
app = Sanic("Test")
|
||||
|
||||
class TeapotError(SanicException):
|
||||
status_code = 418
|
||||
@@ -361,7 +361,7 @@ def test_contextual_exception_extra(debug):
|
||||
|
||||
@pytest.mark.parametrize("override", (True, False))
|
||||
def test_contextual_exception_functional_message(override):
|
||||
app = Sanic(__name__)
|
||||
app = Sanic("Test")
|
||||
|
||||
class TeapotError(SanicException):
|
||||
status_code = 418
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import asyncio
|
||||
import json as stdjson
|
||||
|
||||
from collections import namedtuple
|
||||
from textwrap import dedent
|
||||
from typing import AnyStr
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -11,54 +9,15 @@ from sanic_testing.reusable import ReusableClient
|
||||
|
||||
from sanic import json, text
|
||||
from sanic.app import Sanic
|
||||
from tests.client import RawClient
|
||||
|
||||
|
||||
parent_dir = Path(__file__).parent
|
||||
localhost_dir = parent_dir / "certs/localhost"
|
||||
|
||||
PORT = 1234
|
||||
|
||||
|
||||
class RawClient:
|
||||
CRLF = b"\r\n"
|
||||
|
||||
def __init__(self, host: str, port: int):
|
||||
self.reader = None
|
||||
self.writer = None
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
async def connect(self):
|
||||
self.reader, self.writer = await asyncio.open_connection(
|
||||
self.host, self.port
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
self.writer.close()
|
||||
await self.writer.wait_closed()
|
||||
|
||||
async def send(self, message: AnyStr):
|
||||
if isinstance(message, str):
|
||||
msg = self._clean(message).encode("utf-8")
|
||||
else:
|
||||
msg = message
|
||||
await self._send(msg)
|
||||
|
||||
async def _send(self, message: bytes):
|
||||
if not self.writer:
|
||||
raise Exception("No open write stream")
|
||||
self.writer.write(message)
|
||||
|
||||
async def recv(self, nbytes: int = -1) -> bytes:
|
||||
if not self.reader:
|
||||
raise Exception("No open read stream")
|
||||
return await self.reader.read(nbytes)
|
||||
|
||||
def _clean(self, message: str) -> str:
|
||||
return (
|
||||
dedent(message)
|
||||
.lstrip("\n")
|
||||
.replace("\n", self.CRLF.decode("utf-8"))
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_app(app: Sanic):
|
||||
app.config.KEEP_ALIVE_TIMEOUT = 1
|
||||
@@ -115,7 +74,7 @@ def test_full_message(client):
|
||||
"""
|
||||
)
|
||||
response = client.recv()
|
||||
assert len(response) == 140
|
||||
assert len(response) == 151
|
||||
assert b"200 OK" in response
|
||||
|
||||
|
||||
|
||||
66
tests/test_http_alt_svc.py
Normal file
66
tests/test_http_alt_svc.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import sys
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from sanic.app import Sanic
|
||||
from sanic.response import empty
|
||||
from tests.client import RawClient
|
||||
|
||||
|
||||
parent_dir = Path(__file__).parent
|
||||
localhost_dir = parent_dir / "certs/localhost"
|
||||
|
||||
PORT = 12344
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.version_info < (3, 9), reason="Not supported in 3.7")
|
||||
def test_http1_response_has_alt_svc():
|
||||
Sanic._app_registry.clear()
|
||||
app = Sanic("TestAltSvc")
|
||||
app.config.TOUCHUP = True
|
||||
response = b""
|
||||
|
||||
@app.get("/")
|
||||
async def handler(*_):
|
||||
return empty()
|
||||
|
||||
@app.after_server_start
|
||||
async def do_request(*_):
|
||||
nonlocal response
|
||||
|
||||
app.router.reset()
|
||||
app.router.finalize()
|
||||
|
||||
client = RawClient(app.state.host, app.state.port)
|
||||
await client.connect()
|
||||
await client.send(
|
||||
"""
|
||||
GET / HTTP/1.1
|
||||
host: localhost:7777
|
||||
|
||||
"""
|
||||
)
|
||||
response = await client.recv()
|
||||
await client.close()
|
||||
|
||||
@app.after_server_start
|
||||
def shutdown(*_):
|
||||
app.stop()
|
||||
|
||||
app.prepare(
|
||||
version=3,
|
||||
ssl={
|
||||
"cert": localhost_dir / "fullchain.pem",
|
||||
"key": localhost_dir / "privkey.pem",
|
||||
},
|
||||
port=PORT,
|
||||
)
|
||||
app.prepare(
|
||||
version=1,
|
||||
port=PORT,
|
||||
)
|
||||
Sanic.serve()
|
||||
|
||||
assert f'alt-svc: h3=":{PORT}"\r\n'.encode() in response
|
||||
@@ -136,7 +136,7 @@ def test_log_connection_lost(app, debug, monkeypatch):
|
||||
async def test_logger(caplog):
|
||||
rand_string = str(uuid.uuid4())
|
||||
|
||||
app = Sanic(name=__name__)
|
||||
app = Sanic(name="Test")
|
||||
|
||||
@app.get("/")
|
||||
def log_info(request):
|
||||
@@ -163,7 +163,7 @@ def test_logging_modified_root_logger_config():
|
||||
|
||||
def test_access_log_client_ip_remote_addr(monkeypatch):
|
||||
access = Mock()
|
||||
monkeypatch.setattr(sanic.http, "access_logger", access)
|
||||
monkeypatch.setattr(sanic.http.http1, "access_logger", access)
|
||||
|
||||
app = Sanic("test_logging")
|
||||
app.config.PROXIES_COUNT = 2
|
||||
@@ -190,7 +190,7 @@ def test_access_log_client_ip_remote_addr(monkeypatch):
|
||||
|
||||
def test_access_log_client_ip_reqip(monkeypatch):
|
||||
access = Mock()
|
||||
monkeypatch.setattr(sanic.http, "access_logger", access)
|
||||
monkeypatch.setattr(sanic.http.http1, "access_logger", access)
|
||||
|
||||
app = Sanic("test_logging")
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ def test_motd_with_expected_info(app, run_startup):
|
||||
|
||||
assert logs[1][2] == f"Sanic v{__version__}"
|
||||
assert logs[3][2] == "mode: debug, single worker"
|
||||
assert logs[4][2] == "server: sanic"
|
||||
assert logs[4][2] == "server: sanic, HTTP/1.1"
|
||||
assert logs[5][2] == f"python: {platform.python_version()}"
|
||||
assert logs[6][2] == f"platform: {platform.platform()}"
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from sanic.touchup.schemes.ode import OptionalDispatchEvent
|
||||
try:
|
||||
from unittest.mock import AsyncMock
|
||||
except ImportError:
|
||||
from asyncmock import AsyncMock # type: ignore
|
||||
from tests.asyncmock import AsyncMock # type: ignore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -231,3 +231,15 @@ def test_get_current_request(app):
|
||||
|
||||
_, resp = app.test_client.get("/")
|
||||
assert resp.json["same"]
|
||||
|
||||
|
||||
def test_request_stream_id(app):
|
||||
@app.get("/")
|
||||
async def get(request):
|
||||
try:
|
||||
request.stream_id
|
||||
except Exception as e:
|
||||
return response.text(str(e))
|
||||
|
||||
_, resp = app.test_client.get("/")
|
||||
assert resp.text == "Stream ID is only a property of a HTTP/3 request"
|
||||
|
||||
@@ -552,7 +552,7 @@ def test_streaming_new_api(app):
|
||||
|
||||
def test_streaming_echo():
|
||||
"""2-way streaming chat between server and client."""
|
||||
app = Sanic(name=__name__)
|
||||
app = Sanic(name="Test")
|
||||
|
||||
@app.post("/echo", stream=True)
|
||||
async def handler(request):
|
||||
|
||||
@@ -2050,7 +2050,7 @@ async def test_request_form_invalid_content_type_asgi(app):
|
||||
|
||||
|
||||
def test_endpoint_basic():
|
||||
app = Sanic(name=__name__)
|
||||
app = Sanic(name="Test")
|
||||
|
||||
@app.route("/")
|
||||
def my_unique_handler(request):
|
||||
@@ -2058,12 +2058,12 @@ def test_endpoint_basic():
|
||||
|
||||
request, response = app.test_client.get("/")
|
||||
|
||||
assert request.endpoint == "test_requests.my_unique_handler"
|
||||
assert request.endpoint == "Test.my_unique_handler"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_endpoint_basic_asgi():
|
||||
app = Sanic(name=__name__)
|
||||
app = Sanic(name="Test")
|
||||
|
||||
@app.route("/")
|
||||
def my_unique_handler(request):
|
||||
@@ -2071,7 +2071,7 @@ async def test_endpoint_basic_asgi():
|
||||
|
||||
request, response = await app.asgi_client.get("/")
|
||||
|
||||
assert request.endpoint == "test_requests.my_unique_handler"
|
||||
assert request.endpoint == "Test.my_unique_handler"
|
||||
|
||||
|
||||
def test_endpoint_named_app():
|
||||
|
||||
@@ -101,11 +101,12 @@ def test_response_header(app):
|
||||
return json({"ok": True}, headers={"CONTENT-TYPE": "application/json"})
|
||||
|
||||
request, response = app.test_client.get("/")
|
||||
assert dict(response.headers) == {
|
||||
for key, value in {
|
||||
"connection": "keep-alive",
|
||||
"content-length": "11",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
}.items():
|
||||
assert response.headers[key] == value
|
||||
|
||||
|
||||
def test_response_content_length(app):
|
||||
|
||||
@@ -13,7 +13,7 @@ from sanic.response import empty
|
||||
try:
|
||||
from unittest.mock import AsyncMock
|
||||
except ImportError:
|
||||
from asyncmock import AsyncMock # type: ignore
|
||||
from tests.asyncmock import AsyncMock # type: ignore
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
@@ -1,18 +1,30 @@
|
||||
import logging
|
||||
import os
|
||||
import ssl
|
||||
import uuid
|
||||
import subprocess
|
||||
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pytest
|
||||
|
||||
from sanic_testing.testing import HOST, PORT, SanicTestClient
|
||||
from sanic_testing.testing import HOST, PORT
|
||||
|
||||
import sanic.http.tls.creators
|
||||
|
||||
from sanic import Sanic
|
||||
from sanic.compat import OS_IS_WINDOWS
|
||||
from sanic.log import logger
|
||||
from sanic.application.constants import Mode
|
||||
from sanic.constants import LocalCertCreator
|
||||
from sanic.exceptions import SanicException
|
||||
from sanic.helpers import _default
|
||||
from sanic.http.tls.context import SanicSSLContext
|
||||
from sanic.http.tls.creators import (
|
||||
MkcertCreator,
|
||||
TrustmeCreator,
|
||||
get_ssl_context,
|
||||
)
|
||||
from sanic.response import text
|
||||
|
||||
|
||||
@@ -26,9 +38,63 @@ sanic_cert = os.path.join(sanic_dir, "fullchain.pem")
|
||||
sanic_key = os.path.join(sanic_dir, "privkey.pem")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server_cert():
|
||||
return Mock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def issue_cert(server_cert):
|
||||
mock = Mock(return_value=server_cert)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ca(issue_cert):
|
||||
ca = Mock()
|
||||
ca.issue_cert = issue_cert
|
||||
return ca
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trustme(ca):
|
||||
module = Mock()
|
||||
module.CA = Mock(return_value=ca)
|
||||
return module
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def MockMkcertCreator():
|
||||
class Creator(MkcertCreator):
|
||||
SUPPORTED = True
|
||||
|
||||
def check_supported(self):
|
||||
if not self.SUPPORTED:
|
||||
raise SanicException("Nope")
|
||||
|
||||
generate_cert = Mock()
|
||||
|
||||
return Creator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def MockTrustmeCreator():
|
||||
class Creator(TrustmeCreator):
|
||||
SUPPORTED = True
|
||||
|
||||
def check_supported(self):
|
||||
if not self.SUPPORTED:
|
||||
raise SanicException("Nope")
|
||||
|
||||
generate_cert = Mock()
|
||||
|
||||
return Creator
|
||||
|
||||
|
||||
@contextmanager
|
||||
def replace_server_name(hostname):
|
||||
"""Temporarily replace the server name sent with all TLS requests with a fake hostname."""
|
||||
"""Temporarily replace the server name sent with all TLS requests with
|
||||
a fake hostname."""
|
||||
|
||||
def hack_wrap_bio(
|
||||
self,
|
||||
@@ -69,8 +135,7 @@ def test_url_attributes_with_ssl_context(app, path, query, expected_url):
|
||||
|
||||
app.add_route(handler, path)
|
||||
|
||||
port = app.test_client.port
|
||||
request, response = app.test_client.get(
|
||||
request, _ = app.test_client.get(
|
||||
f"https://{HOST}:{PORT}" + path + f"?{query}",
|
||||
server_kwargs={"ssl": context},
|
||||
)
|
||||
@@ -100,7 +165,7 @@ def test_url_attributes_with_ssl_dict(app, path, query, expected_url):
|
||||
|
||||
app.add_route(handler, path)
|
||||
|
||||
request, response = app.test_client.get(
|
||||
request, _ = app.test_client.get(
|
||||
f"https://{HOST}:{PORT}" + path + f"?{query}",
|
||||
server_kwargs={"ssl": ssl_dict},
|
||||
)
|
||||
@@ -116,22 +181,22 @@ def test_url_attributes_with_ssl_dict(app, path, query, expected_url):
|
||||
|
||||
def test_cert_sni_single(app):
|
||||
@app.get("/sni")
|
||||
async def handler(request):
|
||||
async def handler1(request):
|
||||
return text(request.conn_info.server_name)
|
||||
|
||||
@app.get("/commonname")
|
||||
async def handler(request):
|
||||
async def handler2(request):
|
||||
return text(request.conn_info.cert.get("commonName"))
|
||||
|
||||
port = app.test_client.port
|
||||
request, response = app.test_client.get(
|
||||
_, response = app.test_client.get(
|
||||
f"https://localhost:{port}/sni",
|
||||
server_kwargs={"ssl": localhost_dir},
|
||||
)
|
||||
assert response.status == 200
|
||||
assert response.text == "localhost"
|
||||
|
||||
request, response = app.test_client.get(
|
||||
_, response = app.test_client.get(
|
||||
f"https://localhost:{port}/commonname",
|
||||
server_kwargs={"ssl": localhost_dir},
|
||||
)
|
||||
@@ -143,16 +208,16 @@ def test_cert_sni_list(app):
|
||||
ssl_list = [sanic_dir, localhost_dir]
|
||||
|
||||
@app.get("/sni")
|
||||
async def handler(request):
|
||||
async def handler1(request):
|
||||
return text(request.conn_info.server_name)
|
||||
|
||||
@app.get("/commonname")
|
||||
async def handler(request):
|
||||
async def handler2(request):
|
||||
return text(request.conn_info.cert.get("commonName"))
|
||||
|
||||
# This test should match the localhost cert
|
||||
port = app.test_client.port
|
||||
request, response = app.test_client.get(
|
||||
_, response = app.test_client.get(
|
||||
f"https://localhost:{port}/sni",
|
||||
server_kwargs={"ssl": ssl_list},
|
||||
)
|
||||
@@ -168,14 +233,14 @@ def test_cert_sni_list(app):
|
||||
|
||||
# This part should use the sanic.example cert because it matches
|
||||
with replace_server_name("www.sanic.example"):
|
||||
request, response = app.test_client.get(
|
||||
_, response = app.test_client.get(
|
||||
f"https://127.0.0.1:{port}/sni",
|
||||
server_kwargs={"ssl": ssl_list},
|
||||
)
|
||||
assert response.status == 200
|
||||
assert response.text == "www.sanic.example"
|
||||
|
||||
request, response = app.test_client.get(
|
||||
_, response = app.test_client.get(
|
||||
f"https://127.0.0.1:{port}/commonname",
|
||||
server_kwargs={"ssl": ssl_list},
|
||||
)
|
||||
@@ -184,14 +249,14 @@ def test_cert_sni_list(app):
|
||||
|
||||
# This part should use the sanic.example cert, that being the first listed
|
||||
with replace_server_name("invalid.test"):
|
||||
request, response = app.test_client.get(
|
||||
_, response = app.test_client.get(
|
||||
f"https://127.0.0.1:{port}/sni",
|
||||
server_kwargs={"ssl": ssl_list},
|
||||
)
|
||||
assert response.status == 200
|
||||
assert response.text == "invalid.test"
|
||||
|
||||
request, response = app.test_client.get(
|
||||
_, response = app.test_client.get(
|
||||
f"https://127.0.0.1:{port}/commonname",
|
||||
server_kwargs={"ssl": ssl_list},
|
||||
)
|
||||
@@ -200,7 +265,8 @@ def test_cert_sni_list(app):
|
||||
|
||||
|
||||
def test_missing_sni(app):
|
||||
"""The sanic cert does not list 127.0.0.1 and httpx does not send IP as SNI anyway."""
|
||||
"""The sanic cert does not list 127.0.0.1 and httpx does not send
|
||||
IP as SNI anyway."""
|
||||
ssl_list = [None, sanic_dir]
|
||||
|
||||
@app.get("/sni")
|
||||
@@ -209,7 +275,7 @@ def test_missing_sni(app):
|
||||
|
||||
port = app.test_client.port
|
||||
with pytest.raises(Exception) as exc:
|
||||
request, response = app.test_client.get(
|
||||
app.test_client.get(
|
||||
f"https://127.0.0.1:{port}/sni",
|
||||
server_kwargs={"ssl": ssl_list},
|
||||
)
|
||||
@@ -217,7 +283,8 @@ def test_missing_sni(app):
|
||||
|
||||
|
||||
def test_no_matching_cert(app):
|
||||
"""The sanic cert does not list 127.0.0.1 and httpx does not send IP as SNI anyway."""
|
||||
"""The sanic cert does not list 127.0.0.1 and httpx does not send
|
||||
IP as SNI anyway."""
|
||||
ssl_list = [None, sanic_dir]
|
||||
|
||||
@app.get("/sni")
|
||||
@@ -227,7 +294,7 @@ def test_no_matching_cert(app):
|
||||
port = app.test_client.port
|
||||
with replace_server_name("invalid.test"):
|
||||
with pytest.raises(Exception) as exc:
|
||||
request, response = app.test_client.get(
|
||||
app.test_client.get(
|
||||
f"https://127.0.0.1:{port}/sni",
|
||||
server_kwargs={"ssl": ssl_list},
|
||||
)
|
||||
@@ -244,7 +311,7 @@ def test_wildcards(app):
|
||||
port = app.test_client.port
|
||||
|
||||
with replace_server_name("foo.sanic.test"):
|
||||
request, response = app.test_client.get(
|
||||
_, response = app.test_client.get(
|
||||
f"https://127.0.0.1:{port}/sni",
|
||||
server_kwargs={"ssl": ssl_list},
|
||||
)
|
||||
@@ -253,14 +320,14 @@ def test_wildcards(app):
|
||||
|
||||
with replace_server_name("sanic.test"):
|
||||
with pytest.raises(Exception) as exc:
|
||||
request, response = app.test_client.get(
|
||||
_, response = app.test_client.get(
|
||||
f"https://127.0.0.1:{port}/sni",
|
||||
server_kwargs={"ssl": ssl_list},
|
||||
)
|
||||
assert "Request and response object expected" in str(exc.value)
|
||||
with replace_server_name("sub.foo.sanic.test"):
|
||||
with pytest.raises(Exception) as exc:
|
||||
request, response = app.test_client.get(
|
||||
_, response = app.test_client.get(
|
||||
f"https://127.0.0.1:{port}/sni",
|
||||
server_kwargs={"ssl": ssl_list},
|
||||
)
|
||||
@@ -275,9 +342,7 @@ def test_invalid_ssl_dict(app):
|
||||
ssl_dict = {"cert": None, "key": None}
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
request, response = app.test_client.get(
|
||||
"/test", server_kwargs={"ssl": ssl_dict}
|
||||
)
|
||||
app.test_client.get("/test", server_kwargs={"ssl": ssl_dict})
|
||||
|
||||
assert str(excinfo.value) == "SSL dict needs filenames for cert and key."
|
||||
|
||||
@@ -288,9 +353,7 @@ def test_invalid_ssl_type(app):
|
||||
return text("ssl test")
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
request, response = app.test_client.get(
|
||||
"/test", server_kwargs={"ssl": False}
|
||||
)
|
||||
app.test_client.get("/test", server_kwargs={"ssl": False})
|
||||
|
||||
assert "Invalid ssl argument" in str(excinfo.value)
|
||||
|
||||
@@ -303,9 +366,7 @@ def test_cert_file_on_pathlist(app):
|
||||
ssl_list = [sanic_cert]
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
request, response = app.test_client.get(
|
||||
"/test", server_kwargs={"ssl": ssl_list}
|
||||
)
|
||||
app.test_client.get("/test", server_kwargs={"ssl": ssl_list})
|
||||
|
||||
assert "folder expected" in str(excinfo.value)
|
||||
assert sanic_cert in str(excinfo.value)
|
||||
@@ -319,9 +380,7 @@ def test_missing_cert_path(app):
|
||||
ssl_list = [invalid_dir]
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
request, response = app.test_client.get(
|
||||
"/test", server_kwargs={"ssl": ssl_list}
|
||||
)
|
||||
app.test_client.get("/test", server_kwargs={"ssl": ssl_list})
|
||||
|
||||
assert "not found" in str(excinfo.value)
|
||||
assert invalid_dir + "/privkey.pem" in str(excinfo.value)
|
||||
@@ -336,9 +395,7 @@ def test_missing_cert_file(app):
|
||||
ssl_list = [invalid2]
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
request, response = app.test_client.get(
|
||||
"/test", server_kwargs={"ssl": ssl_list}
|
||||
)
|
||||
app.test_client.get("/test", server_kwargs={"ssl": ssl_list})
|
||||
|
||||
assert "not found" in str(excinfo.value)
|
||||
assert invalid2 + "/fullchain.pem" in str(excinfo.value)
|
||||
@@ -352,15 +409,13 @@ def test_no_certs_on_list(app):
|
||||
ssl_list = [None]
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
request, response = app.test_client.get(
|
||||
"/test", server_kwargs={"ssl": ssl_list}
|
||||
)
|
||||
app.test_client.get("/test", server_kwargs={"ssl": ssl_list})
|
||||
|
||||
assert "No certificates" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_logger_vhosts(caplog):
|
||||
app = Sanic(name=__name__)
|
||||
app = Sanic(name="test_logger_vhosts")
|
||||
|
||||
@app.after_server_start
|
||||
def stop(*args):
|
||||
@@ -374,5 +429,210 @@ def test_logger_vhosts(caplog):
|
||||
][0]
|
||||
|
||||
assert logmsg == (
|
||||
"Certificate vhosts: localhost, 127.0.0.1, 0:0:0:0:0:0:0:1, sanic.example, www.sanic.example, *.sanic.test, 2001:DB8:0:0:0:0:0:541C"
|
||||
"Certificate vhosts: localhost, 127.0.0.1, 0:0:0:0:0:0:0:1, "
|
||||
"sanic.example, www.sanic.example, *.sanic.test, "
|
||||
"2001:DB8:0:0:0:0:0:541C"
|
||||
)
|
||||
|
||||
|
||||
def test_mk_cert_creator_default(app: Sanic):
|
||||
cert_creator = MkcertCreator(app, _default, _default)
|
||||
assert isinstance(cert_creator.tmpdir, Path)
|
||||
assert cert_creator.tmpdir.exists()
|
||||
|
||||
|
||||
def test_mk_cert_creator_is_supported(app):
|
||||
cert_creator = MkcertCreator(app, _default, _default)
|
||||
with patch("subprocess.run") as run:
|
||||
cert_creator.check_supported()
|
||||
run.assert_called_once_with(
|
||||
["mkcert", "-help"],
|
||||
check=True,
|
||||
stderr=subprocess.DEVNULL,
|
||||
stdout=subprocess.DEVNULL,
|
||||
)
|
||||
|
||||
|
||||
def test_mk_cert_creator_is_not_supported(app):
|
||||
cert_creator = MkcertCreator(app, _default, _default)
|
||||
with patch("subprocess.run") as run:
|
||||
run.side_effect = Exception("")
|
||||
with pytest.raises(
|
||||
SanicException, match="Sanic is attempting to use mkcert"
|
||||
):
|
||||
cert_creator.check_supported()
|
||||
|
||||
|
||||
def test_mk_cert_creator_generate_cert_default(app):
|
||||
cert_creator = MkcertCreator(app, _default, _default)
|
||||
with patch("subprocess.run") as run:
|
||||
with patch("sanic.http.tls.creators.CertSimple"):
|
||||
retval = Mock()
|
||||
retval.stdout = "foo"
|
||||
run.return_value = retval
|
||||
cert_creator.generate_cert("localhost")
|
||||
run.assert_called_once()
|
||||
|
||||
|
||||
def test_mk_cert_creator_generate_cert_localhost(app):
|
||||
cert_creator = MkcertCreator(app, localhost_key, localhost_cert)
|
||||
with patch("subprocess.run") as run:
|
||||
with patch("sanic.http.tls.creators.CertSimple"):
|
||||
cert_creator.generate_cert("localhost")
|
||||
run.assert_not_called()
|
||||
|
||||
|
||||
def test_trustme_creator_default(app: Sanic):
|
||||
cert_creator = TrustmeCreator(app, _default, _default)
|
||||
assert isinstance(cert_creator.tmpdir, Path)
|
||||
assert cert_creator.tmpdir.exists()
|
||||
|
||||
|
||||
def test_trustme_creator_is_supported(app, monkeypatch):
|
||||
monkeypatch.setattr(sanic.http.tls.creators, "TRUSTME_INSTALLED", True)
|
||||
cert_creator = TrustmeCreator(app, _default, _default)
|
||||
cert_creator.check_supported()
|
||||
|
||||
|
||||
def test_trustme_creator_is_not_supported(app, monkeypatch):
|
||||
monkeypatch.setattr(sanic.http.tls.creators, "TRUSTME_INSTALLED", False)
|
||||
cert_creator = TrustmeCreator(app, _default, _default)
|
||||
with pytest.raises(
|
||||
SanicException, match="Sanic is attempting to use trustme"
|
||||
):
|
||||
cert_creator.check_supported()
|
||||
|
||||
|
||||
def test_trustme_creator_generate_cert_default(
|
||||
app, monkeypatch, trustme, issue_cert, server_cert, ca
|
||||
):
|
||||
monkeypatch.setattr(sanic.http.tls.creators, "trustme", trustme)
|
||||
cert_creator = TrustmeCreator(app, _default, _default)
|
||||
cert = cert_creator.generate_cert("localhost")
|
||||
|
||||
assert isinstance(cert, SanicSSLContext)
|
||||
trustme.CA.assert_called_once_with()
|
||||
issue_cert.assert_called_once_with("localhost")
|
||||
server_cert.configure_cert.assert_called_once()
|
||||
ca.configure_trust.assert_called_once()
|
||||
ca.cert_pem.write_to_path.assert_called_once_with(str(cert.sanic["cert"]))
|
||||
write_to_path = server_cert.private_key_and_cert_chain_pem.write_to_path
|
||||
write_to_path.assert_called_once_with(str(cert.sanic["key"]))
|
||||
|
||||
|
||||
def test_trustme_creator_generate_cert_localhost(
|
||||
app, monkeypatch, trustme, server_cert, ca
|
||||
):
|
||||
monkeypatch.setattr(sanic.http.tls.creators, "trustme", trustme)
|
||||
cert_creator = TrustmeCreator(app, localhost_key, localhost_cert)
|
||||
cert_creator.generate_cert("localhost")
|
||||
|
||||
ca.cert_pem.write_to_path.assert_called_once_with(localhost_cert)
|
||||
write_to_path = server_cert.private_key_and_cert_chain_pem.write_to_path
|
||||
write_to_path.assert_called_once_with(localhost_key)
|
||||
|
||||
|
||||
def test_get_ssl_context_with_ssl_context(app):
|
||||
mock_context = Mock()
|
||||
context = get_ssl_context(app, mock_context)
|
||||
assert context is mock_context
|
||||
|
||||
|
||||
def test_get_ssl_context_in_production(app):
|
||||
app.state.mode = Mode.PRODUCTION
|
||||
with pytest.raises(
|
||||
SanicException,
|
||||
match="Cannot run Sanic as an HTTPS server in PRODUCTION mode",
|
||||
):
|
||||
get_ssl_context(app, None)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"requirement,mk_supported,trustme_supported,mk_called,trustme_called,err",
|
||||
(
|
||||
(LocalCertCreator.AUTO, True, False, True, False, None),
|
||||
(LocalCertCreator.AUTO, True, True, True, False, None),
|
||||
(LocalCertCreator.AUTO, False, True, False, True, None),
|
||||
(
|
||||
LocalCertCreator.AUTO,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
"Sanic could not find package to create a TLS certificate",
|
||||
),
|
||||
(LocalCertCreator.MKCERT, True, False, True, False, None),
|
||||
(LocalCertCreator.MKCERT, True, True, True, False, None),
|
||||
(LocalCertCreator.MKCERT, False, True, False, False, "Nope"),
|
||||
(LocalCertCreator.MKCERT, False, False, False, False, "Nope"),
|
||||
(LocalCertCreator.TRUSTME, True, False, False, False, "Nope"),
|
||||
(LocalCertCreator.TRUSTME, True, True, False, True, None),
|
||||
(LocalCertCreator.TRUSTME, False, True, False, True, None),
|
||||
(LocalCertCreator.TRUSTME, False, False, False, False, "Nope"),
|
||||
),
|
||||
)
|
||||
def test_get_ssl_context_only_mkcert(
|
||||
app,
|
||||
monkeypatch,
|
||||
MockMkcertCreator,
|
||||
MockTrustmeCreator,
|
||||
requirement,
|
||||
mk_supported,
|
||||
trustme_supported,
|
||||
mk_called,
|
||||
trustme_called,
|
||||
err,
|
||||
):
|
||||
app.state.mode = Mode.DEBUG
|
||||
app.config.LOCAL_CERT_CREATOR = requirement
|
||||
monkeypatch.setattr(
|
||||
sanic.http.tls.creators, "MkcertCreator", MockMkcertCreator
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
sanic.http.tls.creators, "TrustmeCreator", MockTrustmeCreator
|
||||
)
|
||||
MockMkcertCreator.SUPPORTED = mk_supported
|
||||
MockTrustmeCreator.SUPPORTED = trustme_supported
|
||||
|
||||
if err:
|
||||
with pytest.raises(SanicException, match=err):
|
||||
get_ssl_context(app, None)
|
||||
else:
|
||||
get_ssl_context(app, None)
|
||||
|
||||
if mk_called:
|
||||
MockMkcertCreator.generate_cert.assert_called_once_with("localhost")
|
||||
else:
|
||||
MockMkcertCreator.generate_cert.assert_not_called()
|
||||
if trustme_called:
|
||||
MockTrustmeCreator.generate_cert.assert_called_once_with("localhost")
|
||||
else:
|
||||
MockTrustmeCreator.generate_cert.assert_not_called()
|
||||
|
||||
|
||||
def test_no_http3_with_trustme(
|
||||
app,
|
||||
monkeypatch,
|
||||
MockTrustmeCreator,
|
||||
):
|
||||
monkeypatch.setattr(
|
||||
sanic.http.tls.creators, "TrustmeCreator", MockTrustmeCreator
|
||||
)
|
||||
MockTrustmeCreator.SUPPORTED = True
|
||||
app.config.LOCAL_CERT_CREATOR = "TRUSTME"
|
||||
with pytest.raises(
|
||||
SanicException,
|
||||
match=(
|
||||
"Sorry, you cannot currently use trustme as a local certificate "
|
||||
"generator for an HTTP/3 server"
|
||||
),
|
||||
):
|
||||
app.run(version=3, debug=True)
|
||||
|
||||
|
||||
def test_sanic_ssl_context_create():
|
||||
context = ssl.SSLContext()
|
||||
sanic_context = SanicSSLContext.create_from_ssl_context(context)
|
||||
|
||||
assert sanic_context is context
|
||||
assert isinstance(sanic_context, SanicSSLContext)
|
||||
|
||||
@@ -53,7 +53,7 @@ def test_unix_socket_creation(caplog):
|
||||
assert os.path.exists(SOCKPATH)
|
||||
ino = os.stat(SOCKPATH).st_ino
|
||||
|
||||
app = Sanic(name=__name__)
|
||||
app = Sanic(name="test")
|
||||
|
||||
@app.listener("after_server_start")
|
||||
def running(app, loop):
|
||||
@@ -74,7 +74,7 @@ def test_unix_socket_creation(caplog):
|
||||
|
||||
@pytest.mark.parametrize("path", (".", "no-such-directory/sanictest.sock"))
|
||||
def test_invalid_paths(path):
|
||||
app = Sanic(name=__name__)
|
||||
app = Sanic(name="test")
|
||||
|
||||
with pytest.raises((FileExistsError, FileNotFoundError)):
|
||||
app.run(unix=path)
|
||||
@@ -84,7 +84,7 @@ def test_dont_replace_file():
|
||||
with open(SOCKPATH, "w") as f:
|
||||
f.write("File, not socket")
|
||||
|
||||
app = Sanic(name=__name__)
|
||||
app = Sanic(name="test")
|
||||
|
||||
@app.listener("after_server_start")
|
||||
def stop(app, loop):
|
||||
@@ -101,7 +101,7 @@ def test_dont_follow_symlink():
|
||||
sock.bind(SOCKPATH2)
|
||||
os.symlink(SOCKPATH2, SOCKPATH)
|
||||
|
||||
app = Sanic(name=__name__)
|
||||
app = Sanic(name="test")
|
||||
|
||||
@app.listener("after_server_start")
|
||||
def stop(app, loop):
|
||||
@@ -112,7 +112,7 @@ def test_dont_follow_symlink():
|
||||
|
||||
|
||||
def test_socket_deleted_while_running():
|
||||
app = Sanic(name=__name__)
|
||||
app = Sanic(name="test")
|
||||
|
||||
@app.listener("after_server_start")
|
||||
async def hack(app, loop):
|
||||
@@ -123,7 +123,7 @@ def test_socket_deleted_while_running():
|
||||
|
||||
|
||||
def test_socket_replaced_with_file():
|
||||
app = Sanic(name=__name__)
|
||||
app = Sanic(name="test")
|
||||
|
||||
@app.listener("after_server_start")
|
||||
async def hack(app, loop):
|
||||
@@ -136,7 +136,7 @@ def test_socket_replaced_with_file():
|
||||
|
||||
|
||||
def test_unix_connection():
|
||||
app = Sanic(name=__name__)
|
||||
app = Sanic(name="test")
|
||||
|
||||
@app.get("/")
|
||||
def handler(request):
|
||||
@@ -159,7 +159,7 @@ def test_unix_connection():
|
||||
app.run(host="myhost.invalid", unix=SOCKPATH)
|
||||
|
||||
|
||||
app_multi = Sanic(name=__name__)
|
||||
app_multi = Sanic(name="test")
|
||||
|
||||
|
||||
def handler(request):
|
||||
|
||||
@@ -19,7 +19,7 @@ def test_route(app, handler):
|
||||
|
||||
|
||||
def test_bp(app, handler):
|
||||
bp = Blueprint(__name__, version=1)
|
||||
bp = Blueprint("Test", version=1)
|
||||
bp.route("/")(handler)
|
||||
app.blueprint(bp)
|
||||
|
||||
@@ -28,7 +28,7 @@ def test_bp(app, handler):
|
||||
|
||||
|
||||
def test_bp_use_route(app, handler):
|
||||
bp = Blueprint(__name__, version=1)
|
||||
bp = Blueprint("Test", version=1)
|
||||
bp.route("/", version=1.1)(handler)
|
||||
app.blueprint(bp)
|
||||
|
||||
@@ -37,7 +37,7 @@ def test_bp_use_route(app, handler):
|
||||
|
||||
|
||||
def test_bp_group(app, handler):
|
||||
bp = Blueprint(__name__)
|
||||
bp = Blueprint("Test")
|
||||
bp.route("/")(handler)
|
||||
group = Blueprint.group(bp, version=1)
|
||||
app.blueprint(group)
|
||||
@@ -47,7 +47,7 @@ def test_bp_group(app, handler):
|
||||
|
||||
|
||||
def test_bp_group_use_bp(app, handler):
|
||||
bp = Blueprint(__name__, version=1.1)
|
||||
bp = Blueprint("Test", version=1.1)
|
||||
bp.route("/")(handler)
|
||||
group = Blueprint.group(bp, version=1)
|
||||
app.blueprint(group)
|
||||
@@ -57,7 +57,7 @@ def test_bp_group_use_bp(app, handler):
|
||||
|
||||
|
||||
def test_bp_group_use_registration(app, handler):
|
||||
bp = Blueprint(__name__, version=1.1)
|
||||
bp = Blueprint("Test", version=1.1)
|
||||
bp.route("/")(handler)
|
||||
group = Blueprint.group(bp, version=1)
|
||||
app.blueprint(group, version=1.2)
|
||||
@@ -67,7 +67,7 @@ def test_bp_group_use_registration(app, handler):
|
||||
|
||||
|
||||
def test_bp_group_use_route(app, handler):
|
||||
bp = Blueprint(__name__, version=1.1)
|
||||
bp = Blueprint("Test", version=1.1)
|
||||
bp.route("/", version=1.3)(handler)
|
||||
group = Blueprint.group(bp, version=1)
|
||||
app.blueprint(group, version=1.2)
|
||||
@@ -84,7 +84,7 @@ def test_version_prefix_route(app, handler):
|
||||
|
||||
|
||||
def test_version_prefix_bp(app, handler):
|
||||
bp = Blueprint(__name__, version=1, version_prefix="/api/v")
|
||||
bp = Blueprint("Test", version=1, version_prefix="/api/v")
|
||||
bp.route("/")(handler)
|
||||
app.blueprint(bp)
|
||||
|
||||
@@ -93,7 +93,7 @@ def test_version_prefix_bp(app, handler):
|
||||
|
||||
|
||||
def test_version_prefix_bp_use_route(app, handler):
|
||||
bp = Blueprint(__name__, version=1, version_prefix="/ignore/v")
|
||||
bp = Blueprint("Test", version=1, version_prefix="/ignore/v")
|
||||
bp.route("/", version=1.1, version_prefix="/api/v")(handler)
|
||||
app.blueprint(bp)
|
||||
|
||||
@@ -102,7 +102,7 @@ def test_version_prefix_bp_use_route(app, handler):
|
||||
|
||||
|
||||
def test_version_prefix_bp_group(app, handler):
|
||||
bp = Blueprint(__name__)
|
||||
bp = Blueprint("Test")
|
||||
bp.route("/")(handler)
|
||||
group = Blueprint.group(bp, version=1, version_prefix="/api/v")
|
||||
app.blueprint(group)
|
||||
@@ -112,7 +112,7 @@ def test_version_prefix_bp_group(app, handler):
|
||||
|
||||
|
||||
def test_version_prefix_bp_group_use_bp(app, handler):
|
||||
bp = Blueprint(__name__, version=1.1, version_prefix="/api/v")
|
||||
bp = Blueprint("Test", version=1.1, version_prefix="/api/v")
|
||||
bp.route("/")(handler)
|
||||
group = Blueprint.group(bp, version=1, version_prefix="/ignore/v")
|
||||
app.blueprint(group)
|
||||
@@ -122,7 +122,7 @@ def test_version_prefix_bp_group_use_bp(app, handler):
|
||||
|
||||
|
||||
def test_version_prefix_bp_group_use_registration(app, handler):
|
||||
bp = Blueprint(__name__, version=1.1, version_prefix="/alsoignore/v")
|
||||
bp = Blueprint("Test", version=1.1, version_prefix="/alsoignore/v")
|
||||
bp.route("/")(handler)
|
||||
group = Blueprint.group(bp, version=1, version_prefix="/ignore/v")
|
||||
app.blueprint(group, version=1.2, version_prefix="/api/v")
|
||||
@@ -132,7 +132,7 @@ def test_version_prefix_bp_group_use_registration(app, handler):
|
||||
|
||||
|
||||
def test_version_prefix_bp_group_use_route(app, handler):
|
||||
bp = Blueprint(__name__, version=1.1, version_prefix="/alsoignore/v")
|
||||
bp = Blueprint("Test", version=1.1, version_prefix="/alsoignore/v")
|
||||
bp.route("/", version=1.3, version_prefix="/api/v")(handler)
|
||||
group = Blueprint.group(bp, version=1, version_prefix="/ignore/v")
|
||||
app.blueprint(group, version=1.2, version_prefix="/stillignoring/v")
|
||||
|
||||
@@ -14,7 +14,7 @@ from sanic.server.websockets.frame import WebsocketFrameAssembler
|
||||
try:
|
||||
from unittest.mock import AsyncMock
|
||||
except ImportError:
|
||||
from asyncmock import AsyncMock # type: ignore
|
||||
from tests.asyncmock import AsyncMock # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user