HTTP/3 Support (#2378)

This commit is contained in:
Adam Hopkins 2022-06-27 11:19:26 +03:00 committed by GitHub
parent 70382f21ba
commit b59da498cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
72 changed files with 2567 additions and 437 deletions

View File

@ -20,6 +20,7 @@ exclude_lines =
noqa
NOQA
pragma: no cover
TYPE_CHECKING
omit =
site-packages
sanic/__main__.py

View File

@ -16,3 +16,10 @@ lines_after_imports = 2
lines_between_types = 1
multi_line_output = 3
profile = "black"
[[tool.mypy.overrides]]
module = [
"trustme.*",
"sanic_routing.*",
]
ignore_missing_imports = true

View File

@ -43,11 +43,8 @@ from typing import (
from urllib.parse import urlencode, urlunparse
from warnings import filterwarnings
from sanic_routing.exceptions import ( # type: ignore
FinalizationError,
NotFound,
)
from sanic_routing.route import Route # type: ignore
from sanic_routing.exceptions import FinalizationError, NotFound
from sanic_routing.route import Route
from sanic.application.ext import setup_ext
from sanic.application.state import ApplicationState, Mode, ServerStage
@ -64,6 +61,7 @@ from sanic.exceptions import (
URLBuildError,
)
from sanic.handlers import ErrorHandler
from sanic.helpers import _default
from sanic.http import Stage
from sanic.log import (
LOGGING_CONFIG_DEFAULTS,
@ -92,7 +90,7 @@ from sanic.signals import Signal, SignalRouter
from sanic.touchup import TouchUp, TouchUpMeta
if TYPE_CHECKING: # no cov
if TYPE_CHECKING:
try:
from sanic_ext import Extend # type: ignore
from sanic_ext.extensions.base import Extension # type: ignore
@ -949,6 +947,7 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
"response": response,
},
)
...
await response.send(end_stream=True)
elif isinstance(response, ResponseStream):
resp = await response(request)
@ -1532,8 +1531,10 @@ class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta):
if hasattr(self, "_ext"):
self.ext._display()
if self.state.is_debug:
if self.state.is_debug and self.config.TOUCHUP is not True:
self.config.TOUCHUP = False
elif self.config.TOUCHUP is _default:
self.config.TOUCHUP = True
# Setup routers
self.signalize(self.config.TOUCHUP)

View File

@ -0,0 +1,23 @@
from enum import Enum, IntEnum, auto
class StrEnum(str, Enum):
def _generate_next_value_(name: str, *args) -> str: # type: ignore
return name.lower()
class Server(StrEnum):
SANIC = auto()
ASGI = auto()
GUNICORN = auto()
class Mode(StrEnum):
PRODUCTION = auto()
DEBUG = auto()
class ServerStage(IntEnum):
STOPPED = auto()
PARTIAL = auto()
SERVING = auto()

View File

@ -5,7 +5,7 @@ from importlib import import_module
from typing import TYPE_CHECKING
if TYPE_CHECKING: # no cov
if TYPE_CHECKING:
from sanic import Sanic
try:

View File

@ -0,0 +1,86 @@
import os
import sys
import time
from contextlib import contextmanager
from queue import Queue
from threading import Thread
if os.name == "nt": # noqa
import ctypes # noqa
class _CursorInfo(ctypes.Structure):
_fields_ = [("size", ctypes.c_int), ("visible", ctypes.c_byte)]
class Spinner: # noqa
def __init__(self, message: str) -> None:
self.message = message
self.queue: Queue[int] = Queue()
self.spinner = self.cursor()
self.thread = Thread(target=self.run)
def start(self):
self.queue.put(1)
self.thread.start()
self.hide()
def run(self):
while self.queue.get():
output = f"\r{self.message} [{next(self.spinner)}]"
sys.stdout.write(output)
sys.stdout.flush()
time.sleep(0.1)
self.queue.put(1)
def stop(self):
self.queue.put(0)
self.thread.join()
self.show()
@staticmethod
def cursor():
while True:
for cursor in "|/-\\":
yield cursor
@staticmethod
def hide():
if os.name == "nt":
ci = _CursorInfo()
handle = ctypes.windll.kernel32.GetStdHandle(-11)
ctypes.windll.kernel32.GetConsoleCursorInfo(
handle, ctypes.byref(ci)
)
ci.visible = False
ctypes.windll.kernel32.SetConsoleCursorInfo(
handle, ctypes.byref(ci)
)
elif os.name == "posix":
sys.stdout.write("\033[?25l")
sys.stdout.flush()
@staticmethod
def show():
if os.name == "nt":
ci = _CursorInfo()
handle = ctypes.windll.kernel32.GetStdHandle(-11)
ctypes.windll.kernel32.GetConsoleCursorInfo(
handle, ctypes.byref(ci)
)
ci.visible = True
ctypes.windll.kernel32.SetConsoleCursorInfo(
handle, ctypes.byref(ci)
)
elif os.name == "posix":
sys.stdout.write("\033[?25h")
sys.stdout.flush()
@contextmanager
def loading(message: str = "Loading"): # noqa
spinner = Spinner(message)
spinner.start()
yield
spinner.stop()

View File

@ -3,42 +3,20 @@ from __future__ import annotations
import logging
from dataclasses import dataclass, field
from enum import Enum, IntEnum, auto
from pathlib import Path
from socket import socket
from ssl import SSLContext
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
from sanic.application.constants import Mode, Server, ServerStage
from sanic.log import VerbosityFilter, logger
from sanic.server.async_server import AsyncioServer
if TYPE_CHECKING: # no cov
if TYPE_CHECKING:
from sanic import Sanic
class StrEnum(str, Enum):
def _generate_next_value_(name: str, *args) -> str: # type: ignore
return name.lower()
class Server(StrEnum):
SANIC = auto()
ASGI = auto()
GUNICORN = auto()
class Mode(StrEnum):
PRODUCTION = auto()
DEBUG = auto()
class ServerStage(IntEnum):
STOPPED = auto()
PARTIAL = auto()
SERVING = auto()
@dataclass
class ApplicationServerInfo:
settings: Dict[str, Any]

View File

@ -17,7 +17,7 @@ from sanic.server import ConnInfo
from sanic.server.websockets.connection import WebSocketConnection
if TYPE_CHECKING: # no cov
if TYPE_CHECKING:
from sanic import Sanic

View File

@ -5,7 +5,7 @@ from functools import partial
from typing import TYPE_CHECKING, List, Optional, Union
if TYPE_CHECKING: # no cov
if TYPE_CHECKING:
from sanic.blueprints import Blueprint

View File

@ -21,8 +21,8 @@ from typing import (
Union,
)
from sanic_routing.exceptions import NotFound # type: ignore
from sanic_routing.route import Route # type: ignore
from sanic_routing.exceptions import NotFound
from sanic_routing.route import Route
from sanic.base.root import BaseSanic
from sanic.blueprint_group import BlueprintGroup
@ -36,7 +36,7 @@ from sanic.models.handler_types import (
)
if TYPE_CHECKING: # no cov
if TYPE_CHECKING:
from sanic import Sanic

View File

@ -58,10 +58,13 @@ Or, a path to a directory to run as a simple HTTP server:
os.environ.get("SANIC_RELOADER_PROCESS", "") != "true"
)
self.args: List[Any] = []
self.groups: List[Group] = []
def attach(self):
for group in Group._registry:
group.create(self.parser).attach()
instance = group.create(self.parser)
instance.attach()
self.groups.append(instance)
def run(self):
# This is to provide backwards compat -v to display version
@ -81,9 +84,13 @@ Or, a path to a directory to run as a simple HTTP server:
try:
app = self._get_app()
kwargs = self._build_run_kwargs()
app.run(**kwargs)
except ValueError:
error_logger.exception("Failed to run app")
else:
for http_version in self.args.http:
app.prepare(**kwargs, version=http_version)
Sanic.serve()
def _precheck(self):
# # Custom TLS mismatch handling for better diagnostics
@ -163,11 +170,14 @@ Or, a path to a directory to run as a simple HTTP server:
" Example File: project/sanic_server.py -> app\n"
" Example Module: project.sanic_server.app"
)
sys.exit(1)
else:
raise e
return app
def _build_run_kwargs(self):
for group in self.groups:
group.prepare(self.args)
ssl: Union[None, dict, str, list] = []
if self.args.tlshost:
ssl.append(None)
@ -192,6 +202,7 @@ Or, a path to a directory to run as a simple HTTP server:
"unix": self.args.unix,
"verbosity": self.args.verbosity or 0,
"workers": self.args.workers,
"auto_tls": self.args.auto_tls,
}
for maybe_arg in ("auto_reload", "dev"):
@ -201,4 +212,5 @@ Or, a path to a directory to run as a simple HTTP server:
if self.args.path:
kwargs["auto_reload"] = True
kwargs["reload_dir"] = self.args.path
return kwargs

View File

@ -3,9 +3,10 @@ from __future__ import annotations
from argparse import ArgumentParser, _ArgumentGroup
from typing import List, Optional, Type, Union
from sanic_routing import __version__ as __routing_version__ # type: ignore
from sanic_routing import __version__ as __routing_version__
from sanic import __version__
from sanic.http.constants import HTTP
class Group:
@ -38,6 +39,9 @@ class Group:
"--no-" + args[0][2:], *args[1:], action="store_false", **kwargs
)
def prepare(self, args) -> None:
...
class GeneralGroup(Group):
name = None
@ -83,6 +87,44 @@ class ApplicationGroup(Group):
)
class HTTPVersionGroup(Group):
name = "HTTP version"
def attach(self):
http_values = [http.value for http in HTTP.__members__.values()]
self.container.add_argument(
"--http",
dest="http",
action="append",
choices=http_values,
type=int,
help=(
"Which HTTP version to use: HTTP/1.1 or HTTP/3. Value should\n"
"be either 1, or 3. [default 1]"
),
)
self.container.add_argument(
"-1",
dest="http",
action="append_const",
const=1,
help=("Run Sanic server using HTTP/1.1"),
)
self.container.add_argument(
"-3",
dest="http",
action="append_const",
const=3,
help=("Run Sanic server using HTTP/3"),
)
def prepare(self, args):
if not args.http:
args.http = [1]
args.http = tuple(sorted(set(map(HTTP, args.http)), reverse=True))
class SocketGroup(Group):
name = "Socket binding"
@ -92,7 +134,6 @@ class SocketGroup(Group):
"--host",
dest="host",
type=str,
default="127.0.0.1",
help="Host address [default 127.0.0.1]",
)
self.container.add_argument(
@ -100,7 +141,6 @@ class SocketGroup(Group):
"--port",
dest="port",
type=int,
default=8000,
help="Port to serve on [default 8000]",
)
self.container.add_argument(
@ -180,11 +220,7 @@ class DevelopmentGroup(Group):
"--debug",
dest="debug",
action="store_true",
help=(
"Run the server in DEBUG mode. It includes DEBUG logging,\n"
"additional context on exceptions, and other settings\n"
"not-safe for PRODUCTION, but helpful for debugging problems."
),
help="Run the server in debug mode",
)
self.container.add_argument(
"-r",
@ -209,7 +245,16 @@ class DevelopmentGroup(Group):
"--dev",
dest="dev",
action="store_true",
help=("debug + auto reload."),
help=("debug + auto reload"),
)
self.container.add_argument(
"--auto-tls",
dest="auto_tls",
action="store_true",
help=(
"Create a temporary TLS certificate for local development "
"(requires mkcert or trustme)"
),
)

View File

@ -5,6 +5,7 @@ from os import environ
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Sequence, Union
from sanic.constants import LocalCertCreator
from sanic.errorpages import DEFAULT_FORMAT, check_error_format
from sanic.helpers import Default, _default
from sanic.http import Http
@ -26,6 +27,10 @@ DEFAULT_CONFIG = {
"GRACEFUL_SHUTDOWN_TIMEOUT": 15.0, # 15 sec
"KEEP_ALIVE_TIMEOUT": 5, # 5 seconds
"KEEP_ALIVE": True,
"LOCAL_CERT_CREATOR": LocalCertCreator.AUTO,
"LOCAL_TLS_KEY": _default,
"LOCAL_TLS_CERT": _default,
"LOCALHOST": "localhost",
"MOTD": True,
"MOTD_DISPLAY": {},
"NOISY_EXCEPTIONS": False,
@ -38,7 +43,8 @@ DEFAULT_CONFIG = {
"REQUEST_MAX_SIZE": 100000000, # 100 megabytes
"REQUEST_TIMEOUT": 60, # 60 seconds
"RESPONSE_TIMEOUT": 60, # 60 seconds
"TOUCHUP": True,
"TLS_CERT_PASSWORD": "",
"TOUCHUP": _default,
"USE_UVLOOP": _default,
"WEBSOCKET_MAX_SIZE": 2**20, # 1 megabyte
"WEBSOCKET_PING_INTERVAL": 20,
@ -69,9 +75,13 @@ class Config(dict, metaclass=DescriptorMeta):
GRACEFUL_SHUTDOWN_TIMEOUT: float
KEEP_ALIVE_TIMEOUT: int
KEEP_ALIVE: bool
NOISY_EXCEPTIONS: bool
LOCAL_CERT_CREATOR: Union[str, LocalCertCreator]
LOCAL_TLS_KEY: Union[Path, str, Default]
LOCAL_TLS_CERT: Union[Path, str, Default]
LOCALHOST: str
MOTD: bool
MOTD_DISPLAY: Dict[str, str]
NOISY_EXCEPTIONS: bool
PROXIES_COUNT: Optional[int]
REAL_IP_HEADER: Optional[str]
REGISTER: bool
@ -82,7 +92,8 @@ class Config(dict, metaclass=DescriptorMeta):
REQUEST_TIMEOUT: int
RESPONSE_TIMEOUT: int
SERVER_NAME: str
TOUCHUP: bool
TLS_CERT_PASSWORD: str
TOUCHUP: Union[Default, bool]
USE_UVLOOP: Union[Default, bool]
WEBSOCKET_MAX_SIZE: int
WEBSOCKET_PING_INTERVAL: int
@ -157,13 +168,19 @@ class Config(dict, metaclass=DescriptorMeta):
"REQUEST_MAX_SIZE",
):
self._configure_header_size()
elif attr == "LOGO":
if attr == "LOGO":
self._LOGO = value
deprecation(
"Setting the config.LOGO is deprecated and will no longer "
"be supported starting in v22.6.",
22.6,
)
elif attr == "LOCAL_CERT_CREATOR" and not isinstance(
self.LOCAL_CERT_CREATOR, LocalCertCreator
):
self.LOCAL_CERT_CREATOR = LocalCertCreator[
self.LOCAL_CERT_CREATOR.upper()
]
@property
def LOGO(self):

View File

@ -24,5 +24,16 @@ class HTTPMethod(str, Enum):
DELETE = auto()
class LocalCertCreator(str, Enum):
def _generate_next_value_(name, start, count, last_values):
return name.upper()
AUTO = auto()
TRUSTME = auto()
MKCERT = auto()
HTTP_METHODS = tuple(HTTPMethod.__members__.values())
DEFAULT_HTTP_CONTENT_TYPE = "application/octet-stream"
DEFAULT_LOCAL_TLS_KEY = "key.pem"
DEFAULT_LOCAL_TLS_CERT = "cert.pem"

5
sanic/http/__init__.py Normal file
View File

@ -0,0 +1,5 @@
from .constants import Stage
from .http1 import Http
__all__ = ("Http", "Stage")

29
sanic/http/constants.py Normal file
View File

@ -0,0 +1,29 @@
from enum import Enum, IntEnum
class Stage(Enum):
"""
Enum for representing the stage of the request/response cycle
| ``IDLE`` Waiting for request
| ``REQUEST`` Request headers being received
| ``HANDLER`` Headers done, handler running
| ``RESPONSE`` Response headers sent, body in progress
| ``FAILED`` Unrecoverable state (error while sending response)
|
"""
IDLE = 0 # Waiting for request
REQUEST = 1 # Request headers being received
HANDLER = 3 # Headers done, handler running
RESPONSE = 4 # Response headers sent, body in progress
FAILED = 100 # Unrecoverable state (error while sending response)
class HTTP(IntEnum):
VERSION_1 = 1
VERSION_3 = 3
def display(self) -> str:
value = 1.1 if self.value == 1 else self.value
return f"HTTP/{value}"

View File

@ -3,12 +3,11 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING: # no cov
if TYPE_CHECKING:
from sanic.request import Request
from sanic.response import BaseHTTPResponse
from asyncio import CancelledError, sleep
from enum import Enum
from sanic.compat import Header
from sanic.exceptions import (
@ -20,33 +19,16 @@ from sanic.exceptions import (
)
from sanic.headers import format_http1_response
from sanic.helpers import has_message_body
from sanic.http.constants import Stage
from sanic.http.stream import Stream
from sanic.log import access_logger, error_logger, logger
from sanic.touchup import TouchUpMeta
class Stage(Enum):
"""
Enum for representing the stage of the request/response cycle
| ``IDLE`` Waiting for request
| ``REQUEST`` Request headers being received
| ``HANDLER`` Headers done, handler running
| ``RESPONSE`` Response headers sent, body in progress
| ``FAILED`` Unrecoverable state (error while sending response)
|
"""
IDLE = 0 # Waiting for request
REQUEST = 1 # Request headers being received
HANDLER = 3 # Headers done, handler running
RESPONSE = 4 # Response headers sent, body in progress
FAILED = 100 # Unrecoverable state (error while sending response)
HTTP_CONTINUE = b"HTTP/1.1 100 Continue\r\n\r\n"
class Http(metaclass=TouchUpMeta):
class Http(Stream, metaclass=TouchUpMeta):
"""
Internal helper for managing the HTTP request/response cycle
@ -67,7 +49,6 @@ class Http(metaclass=TouchUpMeta):
HEADER_CEILING = 16_384
HEADER_MAX_SIZE = 0
__touchup__ = (
"http1_request_header",
"http1_response_header",
@ -353,6 +334,12 @@ class Http(metaclass=TouchUpMeta):
self.response_func = self.head_response_ignored
headers["connection"] = "keep-alive" if self.keep_alive else "close"
# This header may be removed or modified by the AltSvcCheck Touchup
# service. At server start, we either remove this header from ever
# being assigned, or we change the value as required.
headers["alt-svc"] = ""
ret = format_http1_response(status, res.processed_headers)
if data:
ret += data

397
sanic/http/http3.py Normal file
View File

@ -0,0 +1,397 @@
from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from ssl import SSLContext
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Union,
cast,
)
from aioquic.h0.connection import H0_ALPN, H0Connection
from aioquic.h3.connection import H3_ALPN, H3Connection
from aioquic.h3.events import (
DatagramReceived,
DataReceived,
H3Event,
HeadersReceived,
WebTransportStreamDataReceived,
)
from aioquic.quic.configuration import QuicConfiguration
from aioquic.tls import SessionTicket
from sanic.compat import Header
from sanic.constants import LocalCertCreator
from sanic.exceptions import PayloadTooLarge, SanicException, ServerError
from sanic.helpers import has_message_body
from sanic.http.constants import Stage
from sanic.http.stream import Stream
from sanic.http.tls.context import CertSelector, CertSimple, SanicSSLContext
from sanic.log import Colors, logger
from sanic.models.protocol_types import TransportProtocol
from sanic.models.server_types import ConnInfo
if TYPE_CHECKING:
from sanic import Sanic
from sanic.request import Request
from sanic.response import BaseHTTPResponse
from sanic.server.protocols.http_protocol import Http3Protocol
HttpConnection = Union[H0Connection, H3Connection]
class HTTP3Transport(TransportProtocol):
__slots__ = ("_protocol",)
def __init__(self, protocol: Http3Protocol):
self._protocol = protocol
def get_protocol(self) -> Http3Protocol:
return self._protocol
def get_extra_info(self, info: str, default: Any = None) -> Any:
if (
info in ("socket", "sockname", "peername")
and self._protocol._transport
):
return self._protocol._transport.get_extra_info(info, default)
elif info == "network_paths":
return self._protocol._quic._network_paths
elif info == "ssl_context":
return self._protocol.app.state.ssl
return default
class Receiver(ABC):
future: asyncio.Future
def __init__(self, transmit, protocol, request: Request) -> None:
self.transmit = transmit
self.protocol = protocol
self.request = request
@abstractmethod
async def run(self): # no cov
...
class HTTPReceiver(Receiver, Stream):
stage: Stage
request: Request
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.request_body = None
self.stage = Stage.IDLE
self.headers_sent = False
self.response: Optional[BaseHTTPResponse] = None
self.request_max_size = self.protocol.request_max_size
self.request_bytes = 0
async def run(self, exception: Optional[Exception] = None):
self.stage = Stage.HANDLER
self.head_only = self.request.method.upper() == "HEAD"
if exception:
logger.info( # no cov
f"{Colors.BLUE}[exception]: "
f"{Colors.RED}{exception}{Colors.END}",
exc_info=True,
extra={"verbosity": 1},
)
await self.error_response(exception)
else:
try:
logger.info( # no cov
f"{Colors.BLUE}[request]:{Colors.END} {self.request}",
extra={"verbosity": 1},
)
await self.protocol.request_handler(self.request)
except Exception as e: # no cov
# This should largely be handled within the request handler.
# But, just in case...
await self.run(e)
self.stage = Stage.IDLE
async def error_response(self, exception: Exception) -> None:
"""
Handle response when exception encountered
"""
# From request and handler states we can respond, otherwise be silent
app = self.protocol.app
await app.handle_exception(self.request, exception)
def _prepare_headers(
self, response: BaseHTTPResponse
) -> List[Tuple[bytes, bytes]]:
size = len(response.body) if response.body else 0
headers = response.headers
status = response.status
if not has_message_body(status) and (
size
or "content-length" in headers
or "transfer-encoding" in headers
):
headers.pop("content-length", None)
headers.pop("transfer-encoding", None)
logger.warning( # no cov
f"Message body set in response on {self.request.path}. "
f"A {status} response may only have headers, no body."
)
elif "content-length" not in headers:
if size:
headers["content-length"] = size
else:
headers["transfer-encoding"] = "chunked"
headers = [
(b":status", str(response.status).encode()),
*response.processed_headers,
]
return headers
def send_headers(self) -> None:
logger.debug( # no cov
f"{Colors.BLUE}[send]: {Colors.GREEN}HEADERS{Colors.END}",
extra={"verbosity": 2},
)
if not self.response:
raise RuntimeError("no response")
response = self.response
headers = self._prepare_headers(response)
self.protocol.connection.send_headers(
stream_id=self.request.stream_id,
headers=headers,
)
self.headers_sent = True
self.stage = Stage.RESPONSE
if self.response.body and not self.head_only:
self._send(self.response.body, False)
elif self.head_only:
self.future.cancel()
def respond(self, response: BaseHTTPResponse) -> BaseHTTPResponse:
logger.debug( # no cov
f"{Colors.BLUE}[respond]:{Colors.END} {response}",
extra={"verbosity": 2},
)
if self.stage is not Stage.HANDLER:
self.stage = Stage.FAILED
raise RuntimeError("Response already started")
# Disconnect any earlier but unused response object
if self.response is not None:
self.response.stream = None
self.response, response.stream = response, self
return response
def receive_body(self, data: bytes) -> None:
self.request_bytes += len(data)
if self.request_bytes > self.request_max_size:
raise PayloadTooLarge("Request body exceeds the size limit")
self.request.body += data
async def send(self, data: bytes, end_stream: bool) -> None:
logger.debug( # no cov
f"{Colors.BLUE}[send]: {Colors.GREEN}data={data.decode()} "
f"end_stream={end_stream}{Colors.END}",
extra={"verbosity": 2},
)
self._send(data, end_stream)
def _send(self, data: bytes, end_stream: bool) -> None:
if not self.headers_sent:
self.send_headers()
if self.stage is not Stage.RESPONSE:
raise ServerError(f"not ready to send: {self.stage}")
# Chunked
if (
self.response
and self.response.headers.get("transfer-encoding") == "chunked"
):
size = len(data)
if end_stream:
data = (
b"%x\r\n%b\r\n0\r\n\r\n" % (size, data)
if size
else b"0\r\n\r\n"
)
elif size:
data = b"%x\r\n%b\r\n" % (size, data)
logger.debug( # no cov
f"{Colors.BLUE}[transmitting]{Colors.END}",
extra={"verbosity": 2},
)
self.protocol.connection.send_data(
stream_id=self.request.stream_id,
data=data,
end_stream=end_stream,
)
self.transmit()
if end_stream:
self.stage = Stage.IDLE
class WebsocketReceiver(Receiver): # noqa
async def run(self):
...
class WebTransportReceiver(Receiver): # noqa
async def run(self):
...
class Http3:
HANDLER_PROPERTY_MAPPING = {
DataReceived: "stream_id",
HeadersReceived: "stream_id",
DatagramReceived: "flow_id",
WebTransportStreamDataReceived: "session_id",
}
def __init__(
self,
protocol: Http3Protocol,
transmit: Callable[[], None],
) -> None:
self.protocol = protocol
self.transmit = transmit
self.receivers: Dict[int, Receiver] = {}
def http_event_received(self, event: H3Event) -> None:
logger.debug( # no cov
f"{Colors.BLUE}[http_event_received]: "
f"{Colors.YELLOW}{event}{Colors.END}",
extra={"verbosity": 2},
)
receiver, created_new = self.get_or_make_receiver(event)
receiver = cast(HTTPReceiver, receiver)
if isinstance(event, HeadersReceived) and created_new:
receiver.future = asyncio.ensure_future(receiver.run())
elif isinstance(event, DataReceived):
try:
receiver.receive_body(event.data)
except Exception as e:
receiver.future.cancel()
receiver.future = asyncio.ensure_future(receiver.run(e))
else:
... # Intentionally here to help out Touchup
logger.debug( # no cov
f"{Colors.RED}DOING NOTHING{Colors.END}",
extra={"verbosity": 2},
)
def get_or_make_receiver(self, event: H3Event) -> Tuple[Receiver, bool]:
if (
isinstance(event, HeadersReceived)
and event.stream_id not in self.receivers
):
request = self._make_request(event)
receiver = HTTPReceiver(self.transmit, self.protocol, request)
request.stream = receiver
self.receivers[event.stream_id] = receiver
return receiver, True
else:
ident = getattr(event, self.HANDLER_PROPERTY_MAPPING[type(event)])
return self.receivers[ident], False
def get_receiver_by_stream_id(self, stream_id: int) -> Receiver:
return self.receivers[stream_id]
def _make_request(self, event: HeadersReceived) -> Request:
headers = Header(((k.decode(), v.decode()) for k, v in event.headers))
method = headers[":method"]
path = headers[":path"]
scheme = headers.pop(":scheme", "")
authority = headers.pop(":authority", "")
if authority:
headers["host"] = authority
transport = HTTP3Transport(self.protocol)
request = self.protocol.request_class(
path.encode(),
headers,
"3",
method,
transport,
self.protocol.app,
b"",
)
request.conn_info = ConnInfo(transport)
request._stream_id = event.stream_id
request._scheme = scheme
return request
class SessionTicketStore:
"""
Simple in-memory store for session tickets.
"""
def __init__(self) -> None:
self.tickets: Dict[bytes, SessionTicket] = {}
def add(self, ticket: SessionTicket) -> None:
self.tickets[ticket.ticket] = ticket
def pop(self, label: bytes) -> Optional[SessionTicket]:
return self.tickets.pop(label, None)
def get_config(
app: Sanic, ssl: Union[SanicSSLContext, CertSelector, SSLContext]
):
# TODO:
# - proper selection needed if servince with multiple certs insted of
# just taking the first
if isinstance(ssl, CertSelector):
ssl = cast(SanicSSLContext, ssl.sanic_select[0])
if app.config.LOCAL_CERT_CREATOR is LocalCertCreator.TRUSTME:
raise SanicException(
"Sorry, you cannot currently use trustme as a local certificate "
"generator for an HTTP/3 server. This is not yet supported. You "
"should be able to use mkcert instead. For more information, see: "
"https://github.com/aiortc/aioquic/issues/295."
)
if not isinstance(ssl, CertSimple):
raise SanicException("SSLContext is not CertSimple")
config = QuicConfiguration(
alpn_protocols=H3_ALPN + H0_ALPN + ["siduck"],
is_client=False,
max_datagram_frame_size=65536,
)
password = app.config.TLS_CERT_PASSWORD or None
config.load_cert_chain(
ssl.sanic["cert"], ssl.sanic["key"], password=password
)
return config

27
sanic/http/stream.py Normal file
View File

@ -0,0 +1,27 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Tuple, Union
from sanic.http.constants import Stage
if TYPE_CHECKING:
from sanic.response import BaseHTTPResponse
from sanic.server.protocols.http_protocol import HttpProtocol
class Stream:
stage: Stage
response: Optional[BaseHTTPResponse]
protocol: HttpProtocol
url: Optional[str]
request_body: Optional[bytes]
request_max_size: Union[int, float]
__touchup__: Tuple[str, ...] = tuple()
__slots__ = ()
def respond(
self, response: BaseHTTPResponse
) -> BaseHTTPResponse: # no cov
raise NotImplementedError("Not implemented")

View File

@ -0,0 +1,5 @@
from .context import process_to_context
from .creators import get_ssl_context
__all__ = ("get_ssl_context", "process_to_context")

View File

@ -1,7 +1,9 @@
from __future__ import annotations
import os
import ssl
from typing import Iterable, Optional, Union
from typing import Any, Dict, Iterable, Optional, Union
from sanic.log import logger
@ -77,65 +79,6 @@ def load_cert_dir(p: str) -> ssl.SSLContext:
return CertSimple(certfile, keyfile)
class CertSimple(ssl.SSLContext):
"""A wrapper for creating SSLContext with a sanic attribute."""
def __new__(cls, cert, key, **kw):
# try common aliases, rename to cert/key
certfile = kw["cert"] = kw.pop("certificate", None) or cert
keyfile = kw["key"] = kw.pop("keyfile", None) or key
password = kw.pop("password", None)
if not certfile or not keyfile:
raise ValueError("SSL dict needs filenames for cert and key.")
subject = {}
if "names" not in kw:
cert = ssl._ssl._test_decode_cert(certfile) # type: ignore
kw["names"] = [
name
for t, name in cert["subjectAltName"]
if t in ["DNS", "IP Address"]
]
subject = {k: v for item in cert["subject"] for k, v in item}
self = create_context(certfile, keyfile, password)
self.__class__ = cls
self.sanic = {**subject, **kw}
return self
def __init__(self, cert, key, **kw):
pass # Do not call super().__init__ because it is already initialized
class CertSelector(ssl.SSLContext):
"""Automatically select SSL certificate based on the hostname that the
client is trying to access, via SSL SNI. Paths to certificate folders
with privkey.pem and fullchain.pem in them should be provided, and
will be matched in the order given whenever there is a new connection.
"""
def __new__(cls, ctxs):
return super().__new__(cls)
def __init__(self, ctxs: Iterable[Optional[ssl.SSLContext]]):
super().__init__()
self.sni_callback = selector_sni_callback # type: ignore
self.sanic_select = []
self.sanic_fallback = None
all_names = []
for i, ctx in enumerate(ctxs):
if not ctx:
continue
names = dict(getattr(ctx, "sanic", {})).get("names", [])
all_names += names
self.sanic_select.append(ctx)
if i == 0:
self.sanic_fallback = ctx
if not all_names:
raise ValueError(
"No certificates with SubjectAlternativeNames found."
)
logger.info(f"Certificate vhosts: {', '.join(all_names)}")
def find_cert(self: CertSelector, server_name: str):
"""Find the first certificate that matches the given SNI.
@ -194,3 +137,73 @@ def server_name_callback(
) -> None:
"""Store the received SNI as sslobj.sanic_server_name."""
sslobj.sanic_server_name = server_name # type: ignore
class SanicSSLContext(ssl.SSLContext):
sanic: Dict[str, os.PathLike]
@classmethod
def create_from_ssl_context(cls, context: ssl.SSLContext):
context.__class__ = cls
return context
class CertSimple(SanicSSLContext):
"""A wrapper for creating SSLContext with a sanic attribute."""
sanic: Dict[str, Any]
def __new__(cls, cert, key, **kw):
# try common aliases, rename to cert/key
certfile = kw["cert"] = kw.pop("certificate", None) or cert
keyfile = kw["key"] = kw.pop("keyfile", None) or key
password = kw.pop("password", None)
if not certfile or not keyfile:
raise ValueError("SSL dict needs filenames for cert and key.")
subject = {}
if "names" not in kw:
cert = ssl._ssl._test_decode_cert(certfile) # type: ignore
kw["names"] = [
name
for t, name in cert["subjectAltName"]
if t in ["DNS", "IP Address"]
]
subject = {k: v for item in cert["subject"] for k, v in item}
self = create_context(certfile, keyfile, password)
self.__class__ = cls
self.sanic = {**subject, **kw}
return self
def __init__(self, cert, key, **kw):
pass # Do not call super().__init__ because it is already initialized
class CertSelector(ssl.SSLContext):
"""Automatically select SSL certificate based on the hostname that the
client is trying to access, via SSL SNI. Paths to certificate folders
with privkey.pem and fullchain.pem in them should be provided, and
will be matched in the order given whenever there is a new connection.
"""
def __new__(cls, ctxs):
return super().__new__(cls)
def __init__(self, ctxs: Iterable[Optional[ssl.SSLContext]]):
super().__init__()
self.sni_callback = selector_sni_callback # type: ignore
self.sanic_select = []
self.sanic_fallback = None
all_names = []
for i, ctx in enumerate(ctxs):
if not ctx:
continue
names = dict(getattr(ctx, "sanic", {})).get("names", [])
all_names += names
self.sanic_select.append(ctx)
if i == 0:
self.sanic_fallback = ctx
if not all_names:
raise ValueError(
"No certificates with SubjectAlternativeNames found."
)
logger.info(f"Certificate vhosts: {', '.join(all_names)}")

278
sanic/http/tls/creators.py Normal file
View File

@ -0,0 +1,278 @@
from __future__ import annotations
import ssl
import subprocess
import sys
from abc import ABC, abstractmethod
from contextlib import suppress
from pathlib import Path
from tempfile import mkdtemp
from types import ModuleType
from typing import TYPE_CHECKING, Optional, Tuple, Type, Union, cast
from sanic.application.constants import Mode
from sanic.application.spinner import loading
from sanic.constants import (
DEFAULT_LOCAL_TLS_CERT,
DEFAULT_LOCAL_TLS_KEY,
LocalCertCreator,
)
from sanic.exceptions import SanicException
from sanic.helpers import Default
from sanic.http.tls.context import CertSimple, SanicSSLContext
try:
import trustme
TRUSTME_INSTALLED = True
except (ImportError, ModuleNotFoundError):
trustme = ModuleType("trustme")
TRUSTME_INSTALLED = False
if TYPE_CHECKING:
from sanic import Sanic
# Only allow secure ciphers, notably leaving out AES-CBC mode
# OpenSSL chooses ECDSA or RSA depending on the cert in use
CIPHERS_TLS12 = [
"ECDHE-ECDSA-CHACHA20-POLY1305",
"ECDHE-ECDSA-AES256-GCM-SHA384",
"ECDHE-ECDSA-AES128-GCM-SHA256",
"ECDHE-RSA-CHACHA20-POLY1305",
"ECDHE-RSA-AES256-GCM-SHA384",
"ECDHE-RSA-AES128-GCM-SHA256",
]
def _make_path(maybe_path: Union[Path, str], tmpdir: Optional[Path]) -> Path:
if isinstance(maybe_path, Path):
return maybe_path
else:
path = Path(maybe_path)
if not path.exists():
if not tmpdir:
raise RuntimeError("Reached an unknown state. No tmpdir.")
return tmpdir / maybe_path
return path
def get_ssl_context(
app: Sanic, ssl: Optional[ssl.SSLContext]
) -> ssl.SSLContext:
if ssl:
return ssl
if app.state.mode is Mode.PRODUCTION:
raise SanicException(
"Cannot run Sanic as an HTTPS server in PRODUCTION mode "
"without passing a TLS certificate. If you are developing "
"locally, please enable DEVELOPMENT mode and Sanic will "
"generate a localhost TLS certificate. For more information "
"please see: ___."
)
creator = CertCreator.select(
app,
cast(LocalCertCreator, app.config.LOCAL_CERT_CREATOR),
app.config.LOCAL_TLS_KEY,
app.config.LOCAL_TLS_CERT,
)
context = creator.generate_cert(app.config.LOCALHOST)
return context
class CertCreator(ABC):
def __init__(self, app, key, cert) -> None:
self.app = app
self.key = key
self.cert = cert
self.tmpdir = None
if isinstance(self.key, Default) or isinstance(self.cert, Default):
self.tmpdir = Path(mkdtemp())
key = (
DEFAULT_LOCAL_TLS_KEY
if isinstance(self.key, Default)
else self.key
)
cert = (
DEFAULT_LOCAL_TLS_CERT
if isinstance(self.cert, Default)
else self.cert
)
self.key_path = _make_path(key, self.tmpdir)
self.cert_path = _make_path(cert, self.tmpdir)
@abstractmethod
def check_supported(self) -> None: # no cov
...
@abstractmethod
def generate_cert(self, localhost: str) -> ssl.SSLContext: # no cov
...
@classmethod
def select(
cls,
app: Sanic,
cert_creator: LocalCertCreator,
local_tls_key,
local_tls_cert,
) -> CertCreator:
creator: Optional[CertCreator] = None
cert_creator_options: Tuple[
Tuple[Type[CertCreator], LocalCertCreator], ...
] = (
(MkcertCreator, LocalCertCreator.MKCERT),
(TrustmeCreator, LocalCertCreator.TRUSTME),
)
for creator_class, local_creator in cert_creator_options:
creator = cls._try_select(
app,
creator,
creator_class,
local_creator,
cert_creator,
local_tls_key,
local_tls_cert,
)
if creator:
break
if not creator:
raise SanicException(
"Sanic could not find package to create a TLS certificate. "
"You must have either mkcert or trustme installed. See "
"_____ for more details."
)
return creator
@staticmethod
def _try_select(
app: Sanic,
creator: Optional[CertCreator],
creator_class: Type[CertCreator],
creator_requirement: LocalCertCreator,
creator_requested: LocalCertCreator,
local_tls_key,
local_tls_cert,
):
if creator or (
creator_requested is not LocalCertCreator.AUTO
and creator_requested is not creator_requirement
):
return creator
instance = creator_class(app, local_tls_key, local_tls_cert)
try:
instance.check_supported()
except SanicException:
if creator_requested is creator_requirement:
raise
else:
return None
return instance
class MkcertCreator(CertCreator):
def check_supported(self) -> None:
try:
subprocess.run( # nosec B603 B607
["mkcert", "-help"],
check=True,
stderr=subprocess.DEVNULL,
stdout=subprocess.DEVNULL,
)
except Exception as e:
raise SanicException(
"Sanic is attempting to use mkcert to generate local TLS "
"certificates since you did not supply a certificate, but "
"one is required. Sanic cannot proceed since mkcert does not "
"appear to be installed. Alternatively, you can use trustme. "
"Please install mkcert, trustme, or supply TLS certificates "
"to proceed. Installation instructions can be found here: "
"https://github.com/FiloSottile/mkcert.\n"
"Find out more information about your options here: "
"_____"
) from e
def generate_cert(self, localhost: str) -> ssl.SSLContext:
try:
if not self.cert_path.exists():
message = "Generating TLS certificate"
# TODO: Validate input for security
with loading(message):
cmd = [
"mkcert",
"-key-file",
str(self.key_path),
"-cert-file",
str(self.cert_path),
localhost,
]
resp = subprocess.run( # nosec B603
cmd,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
sys.stdout.write("\r" + " " * (len(message) + 4))
sys.stdout.flush()
sys.stdout.write(resp.stdout)
finally:
@self.app.main_process_stop
async def cleanup(*_): # no cov
if self.tmpdir:
with suppress(FileNotFoundError):
self.key_path.unlink()
self.cert_path.unlink()
self.tmpdir.rmdir()
return CertSimple(self.cert_path, self.key_path)
class TrustmeCreator(CertCreator):
def check_supported(self) -> None:
if not TRUSTME_INSTALLED:
raise SanicException(
"Sanic is attempting to use trustme to generate local TLS "
"certificates since you did not supply a certificate, but "
"one is required. Sanic cannot proceed since trustme does not "
"appear to be installed. Alternatively, you can use mkcert. "
"Please install mkcert, trustme, or supply TLS certificates "
"to proceed. Installation instructions can be found here: "
"https://github.com/python-trio/trustme.\n"
"Find out more information about your options here: "
"_____"
)
def generate_cert(self, localhost: str) -> ssl.SSLContext:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sanic_context = SanicSSLContext.create_from_ssl_context(context)
sanic_context.sanic = {
"cert": self.cert_path.absolute(),
"key": self.key_path.absolute(),
}
ca = trustme.CA()
server_cert = ca.issue_cert(localhost)
server_cert.configure_cert(sanic_context)
ca.configure_trust(context)
ca.cert_pem.write_to_path(str(self.cert_path.absolute()))
server_cert.private_key_and_cert_chain_pem.write_to_path(
str(self.key_path.absolute())
)
return context

View File

@ -21,7 +21,7 @@ from typing import (
)
from urllib.parse import unquote
from sanic_routing.route import Route # type: ignore
from sanic_routing.route import Route
from sanic.base.meta import SanicMeta
from sanic.compat import stat_async

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import os
import platform
import sys
from asyncio import (
AbstractEventLoop,
@ -18,7 +19,18 @@ from importlib import import_module
from pathlib import Path
from socket import socket
from ssl import SSLContext
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Type, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Set,
Tuple,
Type,
Union,
cast,
)
from sanic import reloader_helpers
from sanic.application.logo import get_logo
@ -27,7 +39,9 @@ from sanic.application.state import ApplicationServerInfo, Mode, ServerStage
from sanic.base.meta import SanicMeta
from sanic.compat import OS_IS_WINDOWS, is_atty
from sanic.helpers import _default
from sanic.log import Colors, error_logger, logger
from sanic.http.constants import HTTP
from sanic.http.tls import get_ssl_context, process_to_context
from sanic.log import Colors, deprecation, error_logger, logger
from sanic.models.handler_types import ListenerType
from sanic.server import Signal as ServerSignal
from sanic.server import try_use_uvloop
@ -36,16 +50,22 @@ from sanic.server.events import trigger_events
from sanic.server.protocols.http_protocol import HttpProtocol
from sanic.server.protocols.websocket_protocol import WebSocketProtocol
from sanic.server.runners import serve, serve_multiple, serve_single
from sanic.tls import process_to_context
if TYPE_CHECKING: # no cov
if TYPE_CHECKING:
from sanic import Sanic
from sanic.application.state import ApplicationState
from sanic.config import Config
SANIC_PACKAGES = ("sanic-routing", "sanic-testing", "sanic-ext")
if sys.version_info < (3, 8):
HTTPVersion = Union[HTTP, int]
else:
from typing import Literal
HTTPVersion = Union[HTTP, Literal[1], Literal[3]]
class RunnerMixin(metaclass=SanicMeta):
_app_registry: Dict[str, Sanic]
@ -66,6 +86,7 @@ class RunnerMixin(metaclass=SanicMeta):
dev: bool = False,
debug: bool = False,
auto_reload: Optional[bool] = None,
version: HTTPVersion = HTTP.VERSION_1,
ssl: Union[None, SSLContext, dict, str, list, tuple] = None,
sock: Optional[socket] = None,
workers: int = 1,
@ -81,6 +102,7 @@ class RunnerMixin(metaclass=SanicMeta):
fast: bool = False,
verbosity: int = 0,
motd_display: Optional[Dict[str, str]] = None,
auto_tls: bool = False,
) -> None:
"""
Run the HTTP Server and listen until keyboard interrupt or term
@ -124,6 +146,7 @@ class RunnerMixin(metaclass=SanicMeta):
dev=dev,
debug=debug,
auto_reload=auto_reload,
version=version,
ssl=ssl,
sock=sock,
workers=workers,
@ -139,6 +162,7 @@ class RunnerMixin(metaclass=SanicMeta):
fast=fast,
verbosity=verbosity,
motd_display=motd_display,
auto_tls=auto_tls,
)
self.__class__.serve(primary=self) # type: ignore
@ -151,6 +175,7 @@ class RunnerMixin(metaclass=SanicMeta):
dev: bool = False,
debug: bool = False,
auto_reload: Optional[bool] = None,
version: HTTPVersion = HTTP.VERSION_1,
ssl: Union[None, SSLContext, dict, str, list, tuple] = None,
sock: Optional[socket] = None,
workers: int = 1,
@ -166,7 +191,15 @@ class RunnerMixin(metaclass=SanicMeta):
fast: bool = False,
verbosity: int = 0,
motd_display: Optional[Dict[str, str]] = None,
auto_tls: bool = False,
) -> None:
if version == 3 and self.state.server_info:
raise RuntimeError(
"Serving HTTP/3 instances as a secondary server is "
"not supported. There can only be a single HTTP/3 worker "
"and it must be the first instance prepared."
)
if dev:
debug = True
auto_reload = True
@ -208,7 +241,7 @@ class RunnerMixin(metaclass=SanicMeta):
return
if sock is None:
host, port = host or "127.0.0.1", port or 8000
host, port = self.get_address(host, port, version, auto_tls)
if protocol is None:
protocol = (
@ -236,6 +269,7 @@ class RunnerMixin(metaclass=SanicMeta):
host=host,
port=port,
debug=debug,
version=version,
ssl=ssl,
sock=sock,
unix=unix,
@ -243,6 +277,7 @@ class RunnerMixin(metaclass=SanicMeta):
protocol=protocol,
backlog=backlog,
register_sys_signals=register_sys_signals,
auto_tls=auto_tls,
)
self.state.server_info.append(
ApplicationServerInfo(settings=server_settings)
@ -312,7 +347,7 @@ class RunnerMixin(metaclass=SanicMeta):
"""
if sock is None:
host, port = host or "127.0.0.1", port or 8000
host, port = host, port = self.get_address(host, port)
if protocol is None:
protocol = (
@ -377,6 +412,7 @@ class RunnerMixin(metaclass=SanicMeta):
host: Optional[str] = None,
port: Optional[int] = None,
debug: bool = False,
version: HTTPVersion = HTTP.VERSION_1,
ssl: Union[None, SSLContext, dict, str, list, tuple] = None,
sock: Optional[socket] = None,
unix: Optional[str] = None,
@ -386,6 +422,7 @@ class RunnerMixin(metaclass=SanicMeta):
backlog: int = 100,
register_sys_signals: bool = True,
run_async: bool = False,
auto_tls: bool = False,
) -> Dict[str, Any]:
"""Helper function used by `run` and `create_server`."""
if self.config.PROXIES_COUNT and self.config.PROXIES_COUNT < 0:
@ -395,11 +432,18 @@ class RunnerMixin(metaclass=SanicMeta):
"#proxy-configuration"
)
ssl = process_to_context(ssl)
if not self.state.is_debug:
self.state.mode = Mode.DEBUG if debug else Mode.PRODUCTION
if isinstance(version, int):
version = HTTP(version)
ssl = process_to_context(ssl)
if version is HTTP.VERSION_3 or auto_tls:
if TYPE_CHECKING:
self = cast(Sanic, self)
ssl = get_ssl_context(self, ssl)
self.state.host = host or ""
self.state.port = port or 0
self.state.workers = workers
@ -411,6 +455,7 @@ class RunnerMixin(metaclass=SanicMeta):
"protocol": protocol,
"host": host,
"port": port,
"version": version,
"sock": sock,
"unix": unix,
"ssl": ssl,
@ -421,7 +466,7 @@ class RunnerMixin(metaclass=SanicMeta):
"backlog": backlog,
}
self.motd(self.serve_location)
self.motd(server_settings=server_settings)
if is_atty() and not self.state.is_debug:
error_logger.warning(
@ -447,7 +492,19 @@ class RunnerMixin(metaclass=SanicMeta):
return server_settings
def motd(self, serve_location):
def motd(
self,
serve_location: str = "",
server_settings: Optional[Dict[str, Any]] = None,
):
if serve_location:
deprecation(
"Specifying a serve_location in the MOTD is deprecated and "
"will be removed.",
22.9,
)
else:
serve_location = self.get_server_location(server_settings)
if self.config.MOTD:
mode = [f"{self.state.mode},"]
if self.state.fast:
@ -460,9 +517,19 @@ class RunnerMixin(metaclass=SanicMeta):
else:
mode.append(f"w/ {self.state.workers} workers")
if server_settings:
server = ", ".join(
(
self.state.server,
server_settings["version"].display(), # type: ignore
)
)
else:
server = ""
display = {
"mode": " ".join(mode),
"server": self.state.server,
"server": server,
"python": platform.python_version(),
"platform": platform.platform(),
}
@ -486,7 +553,9 @@ class RunnerMixin(metaclass=SanicMeta):
module_name = package_name.replace("-", "_")
try:
module = import_module(module_name)
packages.append(f"{package_name}=={module.__version__}")
packages.append(
f"{package_name}=={module.__version__}" # type: ignore
)
except ImportError:
...
@ -506,25 +575,50 @@ class RunnerMixin(metaclass=SanicMeta):
@property
def serve_location(self) -> str:
server_settings = self.state.server_info[0].settings
return self.get_server_location(server_settings)
@staticmethod
def get_server_location(
server_settings: Optional[Dict[str, Any]] = None
) -> str:
serve_location = ""
proto = "http"
if self.state.ssl is not None:
if not server_settings:
return serve_location
if server_settings["ssl"] is not None:
proto = "https"
if self.state.unix:
serve_location = f"{self.state.unix} {proto}://..."
elif self.state.sock:
serve_location = f"{self.state.sock.getsockname()} {proto}://..."
elif self.state.host and self.state.port:
if server_settings["unix"]:
serve_location = f'{server_settings["unix"]} {proto}://...'
elif server_settings["sock"]:
serve_location = (
f'{server_settings["sock"].getsockname()} {proto}://...'
)
elif server_settings["host"] and server_settings["port"]:
# colon(:) is legal for a host only in an ipv6 address
display_host = (
f"[{self.state.host}]"
if ":" in self.state.host
else self.state.host
f'[{server_settings["host"]}]'
if ":" in server_settings["host"]
else server_settings["host"]
)
serve_location = (
f'{proto}://{display_host}:{server_settings["port"]}'
)
serve_location = f"{proto}://{display_host}:{self.state.port}"
return serve_location
@staticmethod
def get_address(
host: Optional[str],
port: Optional[int],
version: HTTPVersion = HTTP.VERSION_1,
auto_tls: bool = False,
) -> Tuple[str, int]:
host = host or "127.0.0.1"
port = port or (8443 if (version == 3 or auto_tls) else 8000)
return host, port
@classmethod
def should_auto_reload(cls) -> bool:
return any(app.state.auto_reload for app in cls._app_registry.values())

View File

@ -4,6 +4,7 @@ import sys
from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union
from sanic.exceptions import BadRequest
from sanic.models.protocol_types import TransportProtocol
from sanic.server.websockets.connection import WebSocketConnection
@ -56,7 +57,7 @@ class MockProtocol: # no cov
await self._not_paused.wait()
class MockTransport: # no cov
class MockTransport(TransportProtocol): # no cov
_protocol: Optional[MockProtocol]
def __init__(
@ -68,17 +69,19 @@ class MockTransport: # no cov
self._protocol = None
self.loop = None
def get_protocol(self) -> MockProtocol:
def get_protocol(self) -> MockProtocol: # type: ignore
if not self._protocol:
self._protocol = MockProtocol(self, self.loop)
return self._protocol
def get_extra_info(self, info: str) -> Union[str, bool, None]:
def get_extra_info(
self, info: str, default=None
) -> Optional[Union[str, bool]]:
if info == "peername":
return self.scope.get("client")
elif info == "sslcontext":
return self.scope.get("scheme") in ["https", "wss"]
return None
return default
def get_websocket_connection(self) -> WebSocketConnection:
try:

View File

@ -1,32 +1,22 @@
from __future__ import annotations
import sys
from typing import Any, AnyStr, TypeVar, Union
from asyncio import BaseTransport
from typing import TYPE_CHECKING, Any, AnyStr
from sanic.models.asgi import ASGIScope
if TYPE_CHECKING:
from sanic.models.asgi import ASGIScope
if sys.version_info < (3, 8):
from asyncio import BaseTransport
# from sanic.models.asgi import MockTransport
MockTransport = TypeVar("MockTransport")
TransportProtocol = Union[MockTransport, BaseTransport]
Range = Any
HTMLProtocol = Any
else:
# Protocol is a 3.8+ feature
from typing import Protocol
class TransportProtocol(Protocol):
scope: ASGIScope
def get_protocol(self):
...
def get_extra_info(self, info: str) -> Union[str, bool, None]:
...
class HTMLProtocol(Protocol):
def __html__(self) -> AnyStr:
...
@ -46,3 +36,8 @@ else:
def total(self) -> int:
...
class TransportProtocol(BaseTransport):
scope: ASGIScope
__slots__ = ()

View File

@ -1,8 +1,8 @@
from __future__ import annotations
from ssl import SSLObject
from ssl import SSLContext, SSLObject
from types import SimpleNamespace
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
from sanic.models.protocol_types import TransportProtocol
@ -28,6 +28,7 @@ class ConnInfo:
"sockname",
"ssl",
"cert",
"network_paths",
)
def __init__(self, transport: TransportProtocol, unix=None):
@ -40,17 +41,22 @@ class ConnInfo:
self.ssl = False
self.server_name = ""
self.cert: Dict[str, Any] = {}
self.network_paths: List[Any] = []
sslobj: Optional[SSLObject] = transport.get_extra_info(
"ssl_object"
) # type: ignore
sslctx: Optional[SSLContext] = transport.get_extra_info(
"ssl_context"
) # type: ignore
if sslobj:
self.ssl = True
self.server_name = getattr(sslobj, "sanic_server_name", None) or ""
self.cert = dict(getattr(sslobj.context, "sanic", {}))
if sslctx and not self.cert:
self.cert = dict(getattr(sslctx, "sanic", {}))
if isinstance(addr, str): # UNIX socket
self.server = unix or addr
return
# IPv4 (ip, port) or IPv6 (ip, port, flowinfo, scopeid)
if isinstance(addr, tuple):
self.server = addr[0] if len(addr) == 2 else f"[{addr[0]}]"
@ -59,6 +65,9 @@ class ConnInfo:
if addr[1] != (443 if self.ssl else 80):
self.server = f"{self.server}:{addr[1]}"
self.peername = addr = transport.get_extra_info("peername")
self.network_paths = transport.get_extra_info( # type: ignore
"network_paths"
)
if isinstance(addr, tuple):
self.client = addr[0] if len(addr) == 2 else f"[{addr[0]}]"

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from contextvars import ContextVar
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
Any,
@ -13,13 +14,15 @@ from typing import (
Union,
)
from sanic_routing.route import Route # type: ignore
from sanic_routing.route import Route
from sanic.http.constants import HTTP # type: ignore
from sanic.http.stream import Stream
from sanic.models.asgi import ASGIScope
from sanic.models.http_types import Credentials
if TYPE_CHECKING: # no cov
if TYPE_CHECKING:
from sanic.server import ConnInfo
from sanic.app import Sanic
@ -47,7 +50,7 @@ from sanic.headers import (
parse_host,
parse_xforwarded,
)
from sanic.http import Http, Stage
from sanic.http import Stage
from sanic.log import error_logger, logger
from sanic.models.protocol_types import TransportProtocol
from sanic.response import BaseHTTPResponse, HTTPResponse
@ -94,7 +97,9 @@ class Request:
"_port",
"_protocol",
"_remote_addr",
"_scheme",
"_socket",
"_stream_id",
"_match_info",
"_name",
"app",
@ -131,6 +136,7 @@ class Request:
transport: TransportProtocol,
app: Sanic,
head: bytes = b"",
stream_id: int = 0,
):
self.raw_url = url_bytes
@ -140,6 +146,7 @@ class Request:
raise BadURL(f"Bad URL: {url_bytes.decode()}")
self._id: Optional[Union[uuid.UUID, str, int]] = None
self._name: Optional[str] = None
self._stream_id = stream_id
self.app = app
self.headers = Header(headers)
@ -166,12 +173,12 @@ class Request:
Tuple[bool, bool, str, str], List[Tuple[str, str]]
] = defaultdict(list)
self.request_middleware_started = False
self.responded: bool = False
self.route: Optional[Route] = None
self.stream: Optional[Stream] = None
self._cookies: Optional[Dict[str, str]] = None
self._match_info: Dict[str, Any] = {}
self.stream: Optional[Http] = None
self.route: Optional[Route] = None
self._protocol = None
self.responded: bool = False
def __repr__(self):
class_name = self.__class__.__name__
@ -188,6 +195,14 @@ class Request:
def generate_id(*_):
return uuid.uuid4()
@property
def stream_id(self):
if self.protocol.version is not HTTP.VERSION_3:
raise ServerError(
"Stream ID is only a property of a HTTP/3 request"
)
return self._stream_id
def reset_response(self):
try:
if (
@ -274,6 +289,9 @@ class Request:
# Connect the response
if isinstance(response, BaseHTTPResponse) and self.stream:
response = self.stream.respond(response)
if isawaitable(response):
response = await response # type: ignore
# Run response middleware
try:
response = await self.app._run_response_middleware(
@ -668,6 +686,10 @@ class Request:
"""
return self._parsed_url.path.decode("utf-8")
@property
def network_paths(self):
return self.conn_info.network_paths
# Proxy properties (using SERVER_NAME/forwarded/request/transport info)
@property
@ -721,6 +743,7 @@ class Request:
:return: http|https|ws|wss or arbitrary value given by the headers.
:rtype: str
"""
if not hasattr(self, "_scheme"):
if "//" in self.app.config.get("SERVER_NAME", ""):
return self.app.config.SERVER_NAME.split("//")[0]
if "proto" in self.forwarded:
@ -736,8 +759,9 @@ class Request:
if self.transport.get_extra_info("sslcontext"):
scheme += "s"
self._scheme = scheme
return scheme
return self._scheme
@property
def host(self) -> str:

View File

@ -38,6 +38,7 @@ from sanic.models.protocol_types import HTMLProtocol, Range
if TYPE_CHECKING:
from sanic.asgi import ASGIApp
from sanic.http.http3 import HTTPReceiver
from sanic.request import Request
else:
Request = TypeVar("Request")
@ -74,11 +75,15 @@ class BaseHTTPResponse:
self.asgi: bool = False
self.body: Optional[bytes] = None
self.content_type: Optional[str] = None
self.stream: Optional[Union[Http, ASGIApp]] = None
self.stream: Optional[Union[Http, ASGIApp, HTTPReceiver]] = None
self.status: int = None
self.headers = Header({})
self._cookies: Optional[CookieJar] = None
def __repr__(self):
class_name = self.__class__.__name__
return f"<{class_name}: {self.status} {self.content_type}>"
def _encode_body(self, data: Optional[AnyStr]):
if data is None:
return b""
@ -157,7 +162,10 @@ class BaseHTTPResponse:
if hasattr(data, "encode")
else data or b""
)
await self.stream.send(data, end_stream=end_stream)
await self.stream.send(
data, # type: ignore
end_stream=end_stream or False,
)
class HTTPResponse(BaseHTTPResponse):

View File

@ -5,12 +5,10 @@ from inspect import signature
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from uuid import UUID
from sanic_routing import BaseRouter # type: ignore
from sanic_routing.exceptions import NoMethod # type: ignore
from sanic_routing.exceptions import (
NotFound as RoutingNotFound, # type: ignore
)
from sanic_routing.route import Route # type: ignore
from sanic_routing import BaseRouter
from sanic_routing.exceptions import NoMethod
from sanic_routing.exceptions import NotFound as RoutingNotFound
from sanic_routing.route import Route
from sanic.constants import HTTP_METHODS
from sanic.errorpages import check_error_format

View File

@ -4,7 +4,7 @@ from inspect import isawaitable
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional
if TYPE_CHECKING: # no cov
if TYPE_CHECKING:
from sanic import Sanic

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING: # no cov
if TYPE_CHECKING:
from sanic.app import Sanic
import asyncio

View File

@ -2,10 +2,14 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from aioquic.h3.connection import H3_ALPN, H3Connection
from sanic.http.constants import HTTP
from sanic.http.http3 import Http3
from sanic.touchup.meta import TouchUpMeta
if TYPE_CHECKING: # no cov
if TYPE_CHECKING:
from sanic.app import Sanic
import sys
@ -13,24 +17,68 @@ import sys
from asyncio import CancelledError
from time import monotonic as current_time
from aioquic.asyncio import QuicConnectionProtocol
from aioquic.quic.events import (
DatagramFrameReceived,
ProtocolNegotiated,
QuicEvent,
)
from sanic.exceptions import RequestTimeout, ServiceUnavailable
from sanic.http import Http, Stage
from sanic.log import error_logger, logger
from sanic.log import Colors, error_logger, logger
from sanic.models.server_types import ConnInfo
from sanic.request import Request
from sanic.server.protocols.base_protocol import SanicProtocol
class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta):
class HttpProtocolMixin:
__slots__ = ()
__version__: HTTP
def _setup_connection(self, *args, **kwargs):
self._http = self.HTTP_CLASS(self, *args, **kwargs)
self._time = current_time()
try:
self.check_timeouts()
except AttributeError:
...
def _setup(self):
self.request: Optional[Request] = None
self.access_log = self.app.config.ACCESS_LOG
self.request_handler = self.app.handle_request
self.error_handler = self.app.error_handler
self.request_timeout = self.app.config.REQUEST_TIMEOUT
self.response_timeout = self.app.config.RESPONSE_TIMEOUT
self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT
self.request_max_size = self.app.config.REQUEST_MAX_SIZE
self.request_class = self.app.request_class or Request
@property
def http(self):
if not hasattr(self, "_http"):
return None
return self._http
@property
def version(self):
return self.__class__.__version__
class HttpProtocol(HttpProtocolMixin, SanicProtocol, metaclass=TouchUpMeta):
"""
This class provides implements the HTTP 1.1 protocol on top of our
Sanic Server transport
"""
HTTP_CLASS = Http
__touchup__ = (
"send",
"connection_task",
)
__version__ = HTTP.VERSION_1
__slots__ = (
# request params
"request",
@ -72,25 +120,12 @@ class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta):
unix=unix,
)
self.url = None
self.request: Optional[Request] = None
self.access_log = self.app.config.ACCESS_LOG
self.request_handler = self.app.handle_request
self.error_handler = self.app.error_handler
self.request_timeout = self.app.config.REQUEST_TIMEOUT
self.response_timeout = self.app.config.RESPONSE_TIMEOUT
self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT
self.request_max_size = self.app.config.REQUEST_MAX_SIZE
self.request_class = self.app.request_class or Request
self.state = state if state else {}
self._setup()
if "requests_count" not in self.state:
self.state["requests_count"] = 0
self._exception = None
def _setup_connection(self):
self._http = Http(self)
self._time = current_time()
self.check_timeouts()
async def connection_task(self): # no cov
"""
Run a HTTP connection.
@ -241,3 +276,39 @@ class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta):
self._data_received.set()
except Exception:
error_logger.exception("protocol.data_received")
class Http3Protocol(HttpProtocolMixin, QuicConnectionProtocol):
HTTP_CLASS = Http3
__version__ = HTTP.VERSION_3
def __init__(self, *args, app: Sanic, **kwargs) -> None:
self.app = app
super().__init__(*args, **kwargs)
self._setup()
self._connection: Optional[H3Connection] = None
def quic_event_received(self, event: QuicEvent) -> None:
logger.debug(
f"{Colors.BLUE}[quic_event_received]: "
f"{Colors.PURPLE}{event}{Colors.END}",
extra={"verbosity": 2},
)
if isinstance(event, ProtocolNegotiated):
self._setup_connection(transmit=self.transmit)
if event.alpn_protocol in H3_ALPN:
self._connection = H3Connection(
self._quic, enable_webtransport=True
)
elif isinstance(event, DatagramFrameReceived):
if event.data == b"quack":
self._quic.send_datagram_frame(b"quack-ack")
# pass event to the HTTP layer
if self._connection is not None:
for http_event in self._connection.handle_event(event):
self._http.http_event_received(http_event)
@property
def connection(self) -> Optional[H3Connection]:
return self._connection

View File

@ -11,7 +11,7 @@ from sanic.server import HttpProtocol
from ..websockets.impl import WebsocketImplProtocol
if TYPE_CHECKING: # no cov
if TYPE_CHECKING:
from websockets import http11

View File

@ -6,6 +6,8 @@ from ssl import SSLContext
from typing import TYPE_CHECKING, Dict, Optional, Type, Union
from sanic.config import Config
from sanic.http.constants import HTTP
from sanic.http.tls import get_ssl_context
from sanic.server.events import trigger_events
@ -21,12 +23,15 @@ from functools import partial
from signal import SIG_IGN, SIGINT, SIGTERM, Signals
from signal import signal as signal_func
from aioquic.asyncio import serve as quic_serve
from sanic.application.ext import setup_ext
from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows
from sanic.http.http3 import SessionTicketStore, get_config
from sanic.log import error_logger, logger
from sanic.models.server_types import Signal
from sanic.server.async_server import AsyncioServer
from sanic.server.protocols.http_protocol import HttpProtocol
from sanic.server.protocols.http_protocol import Http3Protocol, HttpProtocol
from sanic.server.socket import (
bind_socket,
bind_unix_socket,
@ -52,6 +57,7 @@ def serve(
signal=Signal(),
state=None,
asyncio_server_kwargs=None,
version=HTTP.VERSION_1,
):
"""Start asynchronous HTTP Server on an individual process.
@ -88,6 +94,87 @@ def serve(
app.asgi = False
if version is HTTP.VERSION_3:
return _serve_http_3(host, port, app, loop, ssl)
return _serve_http_1(
host,
port,
app,
ssl,
sock,
unix,
reuse_port,
loop,
protocol,
backlog,
register_sys_signals,
run_multiple,
run_async,
connections,
signal,
state,
asyncio_server_kwargs,
)
def _setup_system_signals(
app: Sanic,
run_multiple: bool,
register_sys_signals: bool,
loop: asyncio.AbstractEventLoop,
) -> None:
# Ignore SIGINT when run_multiple
if run_multiple:
signal_func(SIGINT, SIG_IGN)
os.environ["SANIC_WORKER_PROCESS"] = "true"
# Register signals for graceful termination
if register_sys_signals:
if OS_IS_WINDOWS:
ctrlc_workaround_for_windows(app)
else:
for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]:
loop.add_signal_handler(_signal, app.stop)
def _run_server_forever(loop, before_stop, after_stop, cleanup, unix):
pid = os.getpid()
try:
logger.info("Starting worker [%s]", pid)
loop.run_forever()
except KeyboardInterrupt:
pass
finally:
logger.info("Stopping worker [%s]", pid)
loop.run_until_complete(before_stop())
if cleanup:
cleanup()
loop.run_until_complete(after_stop())
remove_unix_socket(unix)
def _serve_http_1(
host,
port,
app,
ssl,
sock,
unix,
reuse_port,
loop,
protocol,
backlog,
register_sys_signals,
run_multiple,
run_async,
connections,
signal,
state,
asyncio_server_kwargs,
):
connections = connections if connections is not None else set()
protocol_kwargs = _build_protocol_kwargs(protocol, app.config)
server = partial(
@ -135,30 +222,7 @@ def serve(
error_logger.exception("Unable to start server", exc_info=True)
return
# Ignore SIGINT when run_multiple
if run_multiple:
signal_func(SIGINT, SIG_IGN)
os.environ["SANIC_WORKER_PROCESS"] = "true"
# Register signals for graceful termination
if register_sys_signals:
if OS_IS_WINDOWS:
ctrlc_workaround_for_windows(app)
else:
for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]:
loop.add_signal_handler(_signal, app.stop)
loop.run_until_complete(app._server_event("init", "after"))
pid = os.getpid()
try:
logger.info("Starting worker [%s]", pid)
loop.run_forever()
finally:
logger.info("Stopping worker [%s]", pid)
# Run the on_stop function if provided
loop.run_until_complete(app._server_event("shutdown", "before"))
def _cleanup():
# Wait for event loop to finish and all connections to drain
http_server.close()
loop.run_until_complete(http_server.wait_closed())
@ -188,8 +252,51 @@ def serve(
conn.websocket.fail_connection(code=1001)
else:
conn.abort()
loop.run_until_complete(app._server_event("shutdown", "after"))
remove_unix_socket(unix)
_setup_system_signals(app, run_multiple, register_sys_signals, loop)
loop.run_until_complete(app._server_event("init", "after"))
_run_server_forever(
loop,
partial(app._server_event, "shutdown", "before"),
partial(app._server_event, "shutdown", "after"),
_cleanup,
unix,
)
def _serve_http_3(
host,
port,
app,
loop,
ssl,
register_sys_signals: bool = True,
run_multiple: bool = False,
):
protocol = partial(Http3Protocol, app=app)
ticket_store = SessionTicketStore()
ssl_context = get_ssl_context(app, ssl)
config = get_config(app, ssl_context)
coro = quic_serve(
host,
port,
configuration=config,
create_protocol=protocol,
session_ticket_fetcher=ticket_store.pop,
session_ticket_handler=ticket_store.add,
)
server = AsyncioServer(app, loop, coro, [])
loop.run_until_complete(server.startup())
loop.run_until_complete(server.before_start())
loop.run_until_complete(server)
_setup_system_signals(app, run_multiple, register_sys_signals, loop)
loop.run_until_complete(server.after_start())
# TODO: Create connection cleanup and graceful shutdown
cleanup = None
_run_server_forever(
loop, server.before_stop, server.after_stop, cleanup, None
)
def serve_single(server_settings):

View File

@ -9,7 +9,7 @@ from websockets.typing import Data
from sanic.exceptions import ServerError
if TYPE_CHECKING: # no cov
if TYPE_CHECKING:
from .impl import WebsocketImplProtocol
UTF8Decoder = codecs.getincrementaldecoder("utf-8")
@ -37,7 +37,7 @@ class WebsocketFrameAssembler:
"get_id",
"put_id",
)
if TYPE_CHECKING: # no cov
if TYPE_CHECKING:
protocol: "WebsocketImplProtocol"
read_mutex: asyncio.Lock
write_mutex: asyncio.Lock

View File

@ -6,9 +6,9 @@ from enum import Enum
from inspect import isawaitable
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from sanic_routing import BaseRouter, Route, RouteGroup # type: ignore
from sanic_routing.exceptions import NotFound # type: ignore
from sanic_routing.utils import path_to_parts # type: ignore
from sanic_routing import BaseRouter, Route, RouteGroup
from sanic_routing.exceptions import NotFound
from sanic_routing.utils import path_to_parts
from sanic.exceptions import InvalidSignal
from sanic.log import error_logger, logger

View File

@ -1,3 +1,4 @@
from .altsvc import AltSvcCheck # noqa
from .base import BaseScheme
from .ode import OptionalDispatchEvent # noqa

View File

@ -0,0 +1,56 @@
from __future__ import annotations
from ast import Assign, Constant, NodeTransformer, Subscript
from typing import TYPE_CHECKING, Any, List
from sanic.http.constants import HTTP
from .base import BaseScheme
if TYPE_CHECKING:
from sanic import Sanic
class AltSvcCheck(BaseScheme):
ident = "ALTSVC"
def visitors(self) -> List[NodeTransformer]:
return [RemoveAltSvc(self.app, self.app.state.verbosity)]
class RemoveAltSvc(NodeTransformer):
def __init__(self, app: Sanic, verbosity: int = 0) -> None:
self._app = app
self._verbosity = verbosity
self._versions = {
info.settings["version"] for info in app.state.server_info
}
def visit_Assign(self, node: Assign) -> Any:
if any(self._matches(target) for target in node.targets):
if self._should_remove():
return None
assert isinstance(node.value, Constant)
node.value.value = self.value()
return node
def _should_remove(self) -> bool:
return len(self._versions) == 1
@staticmethod
def _matches(node) -> bool:
return (
isinstance(node, Subscript)
and isinstance(node.slice, Constant)
and node.slice.value == "alt-svc"
)
def value(self):
values = []
for info in self._app.state.server_info:
port = info.settings["port"]
version = info.settings["version"]
if version is HTTP.VERSION_3:
values.append(f'h3=":{port}"')
return ", ".join(values)

View File

@ -1,5 +1,8 @@
from abc import ABC, abstractmethod
from typing import Set, Type
from ast import NodeTransformer, parse
from inspect import getsource
from textwrap import dedent
from typing import Any, Dict, List, Set, Type
class BaseScheme(ABC):
@ -10,11 +13,26 @@ class BaseScheme(ABC):
self.app = app
@abstractmethod
def run(self, method, module_globals) -> None:
def visitors(self) -> List[NodeTransformer]:
...
def __init_subclass__(cls):
BaseScheme._registry.add(cls)
def __call__(self, method, module_globals):
return self.run(method, module_globals)
def __call__(self):
return self.visitors()
@classmethod
def build(cls, method, module_globals, app):
raw_source = getsource(method)
src = dedent(raw_source)
node = parse(src)
for scheme in cls._registry:
for visitor in scheme(app)():
node = visitor.visit(node)
compiled_src = compile(node, method.__name__, "exec")
exec_locals: Dict[str, Any] = {}
exec(compiled_src, module_globals, exec_locals) # nosec
return exec_locals[method.__name__]

View File

@ -1,7 +1,5 @@
from ast import Attribute, Await, Dict, Expr, NodeTransformer, parse
from inspect import getsource
from textwrap import dedent
from typing import Any
from ast import Attribute, Await, Expr, NodeTransformer
from typing import Any, List
from sanic.log import logger
@ -20,16 +18,8 @@ class OptionalDispatchEvent(BaseScheme):
signal.name for signal in app.signal_router.routes
]
def run(self, method, module_globals):
raw_source = getsource(method)
src = dedent(raw_source)
tree = parse(src)
node = RemoveDispatch(self._registered_events).visit(tree)
compiled_src = compile(node, method.__name__, "exec")
exec_locals: Dict[str, Any] = {}
exec(compiled_src, module_globals, exec_locals) # nosec
return exec_locals[method.__name__]
def visitors(self) -> List[NodeTransformer]:
return [RemoveDispatch(self._registered_events)]
def _sync_events(self):
all_events = set()

View File

@ -21,9 +21,7 @@ class TouchUp:
module = getmodule(target)
module_globals = dict(getmembers(module))
for scheme in BaseScheme._registry:
modified = scheme(app)(method, module_globals)
modified = BaseScheme.build(method, module_globals, app)
setattr(target, method_name, modified)
target.__touched__ = True

View File

@ -13,7 +13,7 @@ from typing import (
from sanic.models.handler_types import RouteHandler
if TYPE_CHECKING: # no cov
if TYPE_CHECKING:
from sanic import Sanic
from sanic.blueprints import Blueprint

View File

@ -149,6 +149,7 @@ extras_require = {
"docs": docs_require,
"all": all_require,
"ext": ["sanic-ext"],
"http3": ["aioquic"],
}
setup_kwargs["install_requires"] = requirements

0
tests/__init__.py Normal file
View File

View File

@ -25,6 +25,10 @@ class AsyncMock(Mock):
def __await__(self):
return self().__await__()
def reset_mock(self, *args, **kwargs):
super().reset_mock(*args, **kwargs)
self.await_count = 0
def assert_awaited_once(self):
if not self.await_count == 1:
msg = (
@ -32,3 +36,13 @@ class AsyncMock(Mock):
f" Awaited {self.await_count} times."
)
raise AssertionError(msg)
def assert_awaited_once_with(self, *args, **kwargs):
if not self.await_count == 1:
msg = (
f"Expected to have been awaited once."
f" Awaited {self.await_count} times."
)
raise AssertionError(msg)
self.assert_awaited_once()
return self.assert_called_with(*args, **kwargs)

47
tests/client.py Normal file
View File

@ -0,0 +1,47 @@
import asyncio
from textwrap import dedent
from typing import AnyStr
class RawClient:
CRLF = b"\r\n"
def __init__(self, host: str, port: int):
self.reader = None
self.writer = None
self.host = host
self.port = port
async def connect(self):
self.reader, self.writer = await asyncio.open_connection(
self.host, self.port
)
async def close(self):
self.writer.close()
await self.writer.wait_closed()
async def send(self, message: AnyStr):
if isinstance(message, str):
msg = self._clean(message).encode("utf-8")
else:
msg = message
await self._send(msg)
async def _send(self, message: bytes):
if not self.writer:
raise Exception("No open write stream")
self.writer.write(message)
async def recv(self, nbytes: int = -1) -> bytes:
if not self.reader:
raise Exception("No open read stream")
return await self.reader.read(nbytes)
def _clean(self, message: str) -> str:
return (
dedent(message)
.lstrip("\n")
.replace("\n", self.CRLF.decode("utf-8"))
)

View File

@ -150,6 +150,7 @@ def app(request):
yield app
for target, method_name in TouchUp._registry:
setattr(target, method_name, CACHE[method_name])
Sanic._app_registry.clear()
@pytest.fixture(scope="function")

0
tests/http3/__init__.py Normal file
View File

View File

@ -0,0 +1,294 @@
from unittest.mock import Mock
import pytest
from aioquic.h3.connection import H3Connection
from aioquic.h3.events import DataReceived, HeadersReceived
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.connection import QuicConnection
from aioquic.quic.events import ProtocolNegotiated
from sanic import Request, Sanic
from sanic.compat import Header
from sanic.config import DEFAULT_CONFIG
from sanic.exceptions import PayloadTooLarge
from sanic.http.constants import Stage
from sanic.http.http3 import Http3, HTTPReceiver
from sanic.models.server_types import ConnInfo
from sanic.response import empty, json
from sanic.server.protocols.http_protocol import Http3Protocol
try:
from unittest.mock import AsyncMock
except ImportError:
from tests.asyncmock import AsyncMock # type: ignore
pytestmark = pytest.mark.asyncio
@pytest.fixture(autouse=True)
async def setup(app: Sanic):
@app.get("/")
async def handler(*_):
return empty()
app.router.finalize()
app.signal_router.finalize()
app.signal_router.allow_fail_builtin = False
@pytest.fixture
def http_request(app):
return Request(b"/", Header({}), "3", "GET", Mock(), app)
def generate_protocol(app):
connection = QuicConnection(configuration=QuicConfiguration())
connection._ack_delay = 0
connection._loss = Mock()
connection._loss.spaces = []
connection._loss.get_loss_detection_time = lambda: None
connection.datagrams_to_send = Mock(return_value=[]) # type: ignore
return Http3Protocol(
connection,
app=app,
stream_handler=None,
)
def generate_http_receiver(app, http_request) -> HTTPReceiver:
protocol = generate_protocol(app)
receiver = HTTPReceiver(
protocol.transmit,
protocol,
http_request,
)
http_request.stream = receiver
return receiver
def test_http_receiver_init(app: Sanic, http_request: Request):
receiver = generate_http_receiver(app, http_request)
assert receiver.request_body is None
assert receiver.stage is Stage.IDLE
assert receiver.headers_sent is False
assert receiver.response is None
assert receiver.request_max_size == DEFAULT_CONFIG["REQUEST_MAX_SIZE"]
assert receiver.request_bytes == 0
async def test_http_receiver_run_request(app: Sanic, http_request: Request):
handler = AsyncMock()
class mock_handle(Sanic):
handle_request = handler
app.__class__ = mock_handle
receiver = generate_http_receiver(app, http_request)
receiver.protocol.quic_event_received(
ProtocolNegotiated(alpn_protocol="h3")
)
await receiver.run()
handler.assert_awaited_once_with(receiver.request)
async def test_http_receiver_run_exception(app: Sanic, http_request: Request):
handler = AsyncMock()
class mock_handle(Sanic):
handle_exception = handler
app.__class__ = mock_handle
receiver = generate_http_receiver(app, http_request)
receiver.protocol.quic_event_received(
ProtocolNegotiated(alpn_protocol="h3")
)
exception = Exception("Oof")
await receiver.run(exception)
handler.assert_awaited_once_with(receiver.request, exception)
handler.reset_mock()
receiver.stage = Stage.REQUEST
await receiver.run(exception)
handler.assert_awaited_once_with(receiver.request, exception)
def test_http_receiver_respond(app: Sanic, http_request: Request):
receiver = generate_http_receiver(app, http_request)
response = empty()
receiver.stage = Stage.RESPONSE
with pytest.raises(RuntimeError, match="Response already started"):
receiver.respond(response)
receiver.stage = Stage.HANDLER
receiver.response = Mock()
resp = receiver.respond(response)
assert receiver.response is resp
assert resp is response
assert response.stream is receiver
def test_http_receiver_receive_body(app: Sanic, http_request: Request):
receiver = generate_http_receiver(app, http_request)
receiver.request_max_size = 4
receiver.receive_body(b"..")
assert receiver.request.body == b".."
receiver.receive_body(b"..")
assert receiver.request.body == b"...."
with pytest.raises(
PayloadTooLarge, match="Request body exceeds the size limit"
):
receiver.receive_body(b"..")
def test_http3_events(app):
protocol = generate_protocol(app)
http3 = Http3(protocol, protocol.transmit)
http3.http_event_received(
HeadersReceived(
[
(b":method", b"GET"),
(b":path", b"/location"),
(b":scheme", b"https"),
(b":authority", b"localhost:8443"),
(b"foo", b"bar"),
],
1,
False,
)
)
http3.http_event_received(DataReceived(b"foobar", 1, False))
receiver = http3.receivers[1]
assert len(http3.receivers) == 1
assert receiver.request.stream_id == 1
assert receiver.request.path == "/location"
assert receiver.request.method == "GET"
assert receiver.request.headers["foo"] == "bar"
assert receiver.request.body == b"foobar"
async def test_send_headers(app: Sanic, http_request: Request):
send_headers_mock = Mock()
existing_send_headers = H3Connection.send_headers
receiver = generate_http_receiver(app, http_request)
receiver.protocol.quic_event_received(
ProtocolNegotiated(alpn_protocol="h3")
)
http_request._protocol = receiver.protocol
def send_headers(*args, **kwargs):
send_headers_mock(*args, **kwargs)
return existing_send_headers(
receiver.protocol.connection, *args, **kwargs
)
receiver.protocol.connection.send_headers = send_headers
receiver.head_only = False
response = json({}, status=201, headers={"foo": "bar"})
with pytest.raises(RuntimeError, match="no response"):
receiver.send_headers()
receiver.response = response
receiver.send_headers()
assert receiver.headers_sent
assert receiver.stage is Stage.RESPONSE
send_headers_mock.assert_called_once_with(
stream_id=0,
headers=[
(b":status", b"201"),
(b"foo", b"bar"),
(b"content-length", b"2"),
(b"content-type", b"application/json"),
],
)
def test_multiple_streams(app):
protocol = generate_protocol(app)
http3 = Http3(protocol, protocol.transmit)
http3.http_event_received(
HeadersReceived(
[
(b":method", b"GET"),
(b":path", b"/location"),
(b":scheme", b"https"),
(b":authority", b"localhost:8443"),
(b"foo", b"bar"),
],
1,
False,
)
)
http3.http_event_received(
HeadersReceived(
[
(b":method", b"GET"),
(b":path", b"/location"),
(b":scheme", b"https"),
(b":authority", b"localhost:8443"),
(b"foo", b"bar"),
],
2,
False,
)
)
receiver1 = http3.get_receiver_by_stream_id(1)
receiver2 = http3.get_receiver_by_stream_id(2)
assert len(http3.receivers) == 2
assert isinstance(receiver1, HTTPReceiver)
assert isinstance(receiver2, HTTPReceiver)
assert receiver1 is not receiver2
def test_request_stream_id(app):
protocol = generate_protocol(app)
http3 = Http3(protocol, protocol.transmit)
http3.http_event_received(
HeadersReceived(
[
(b":method", b"GET"),
(b":path", b"/location"),
(b":scheme", b"https"),
(b":authority", b"localhost:8443"),
(b"foo", b"bar"),
],
1,
False,
)
)
receiver = http3.get_receiver_by_stream_id(1)
assert isinstance(receiver.request, Request)
assert receiver.request.stream_id == 1
def test_request_conn_info(app):
protocol = generate_protocol(app)
http3 = Http3(protocol, protocol.transmit)
http3.http_event_received(
HeadersReceived(
[
(b":method", b"GET"),
(b":path", b"/location"),
(b":scheme", b"https"),
(b":authority", b"localhost:8443"),
(b"foo", b"bar"),
],
1,
False,
)
)
receiver = http3.get_receiver_by_stream_id(1)
assert isinstance(receiver.request.conn_info, ConnInfo)

114
tests/http3/test_server.py Normal file
View File

@ -0,0 +1,114 @@
import logging
import sys
from asyncio import Event
from pathlib import Path
import pytest
from sanic import Sanic
from sanic.compat import UVLOOP_INSTALLED
from sanic.http.constants import HTTP
parent_dir = Path(__file__).parent.parent
localhost_dir = parent_dir / "certs/localhost"
@pytest.mark.parametrize("version", (3, HTTP.VERSION_3))
@pytest.mark.skipif(
sys.version_info < (3, 8) and not UVLOOP_INSTALLED,
reason="In 3.7 w/o uvloop the port is not always released",
)
def test_server_starts_http3(app: Sanic, version, caplog):
ev = Event()
@app.after_server_start
def shutdown(*_):
ev.set()
app.stop()
with caplog.at_level(logging.INFO):
app.run(
version=version,
ssl={
"cert": localhost_dir / "fullchain.pem",
"key": localhost_dir / "privkey.pem",
},
)
assert ev.is_set()
assert (
"sanic.root",
logging.INFO,
"server: sanic, HTTP/3",
) in caplog.record_tuples
@pytest.mark.skipif(
sys.version_info < (3, 8) and not UVLOOP_INSTALLED,
reason="In 3.7 w/o uvloop the port is not always released",
)
def test_server_starts_http1_and_http3(app: Sanic, caplog):
@app.after_server_start
def shutdown(*_):
app.stop()
app.prepare(
version=3,
ssl={
"cert": localhost_dir / "fullchain.pem",
"key": localhost_dir / "privkey.pem",
},
)
app.prepare(
version=1,
ssl={
"cert": localhost_dir / "fullchain.pem",
"key": localhost_dir / "privkey.pem",
},
)
with caplog.at_level(logging.INFO):
Sanic.serve()
assert (
"sanic.root",
logging.INFO,
"server: sanic, HTTP/1.1",
) in caplog.record_tuples
assert (
"sanic.root",
logging.INFO,
"server: sanic, HTTP/3",
) in caplog.record_tuples
@pytest.mark.skipif(
sys.version_info < (3, 8) and not UVLOOP_INSTALLED,
reason="In 3.7 w/o uvloop the port is not always released",
)
def test_server_starts_http1_and_http3_bad_order(app: Sanic, caplog):
@app.after_server_start
def shutdown(*_):
app.stop()
app.prepare(
version=1,
ssl={
"cert": localhost_dir / "fullchain.pem",
"key": localhost_dir / "privkey.pem",
},
)
message = (
"Serving HTTP/3 instances as a secondary server is not supported. "
"There can only be a single HTTP/3 worker and it must be the first "
"instance prepared."
)
with pytest.raises(RuntimeError, match=message):
app.prepare(
version=3,
ssl={
"cert": localhost_dir / "fullchain.pem",
"key": localhost_dir / "privkey.pem",
},
)

View File

@ -0,0 +1,46 @@
from datetime import datetime
from aioquic.tls import CipherSuite, SessionTicket
from sanic.http.http3 import SessionTicketStore
def _generate_ticket(label):
return SessionTicket(
1,
CipherSuite.AES_128_GCM_SHA256,
datetime.now(),
datetime.now(),
label,
label.decode(),
label,
None,
[],
)
def test_session_ticket_store():
store = SessionTicketStore()
assert len(store.tickets) == 0
ticket1 = _generate_ticket(b"foo")
store.add(ticket1)
assert len(store.tickets) == 1
ticket2 = _generate_ticket(b"bar")
store.add(ticket2)
assert len(store.tickets) == 2
assert len(store.tickets) == 2
popped2 = store.pop(ticket2.ticket)
assert len(store.tickets) == 1
assert popped2 is ticket2
popped1 = store.pop(ticket1.ticket)
assert len(store.tickets) == 0
assert popped1 is ticket1

View File

@ -417,7 +417,7 @@ async def test_request_class_custom():
class MyCustomRequest(Request):
pass
app = Sanic(name=__name__, request_class=MyCustomRequest)
app = Sanic(name="Test", request_class=MyCustomRequest)
@app.get("/custom")
def custom_request(request):

View File

@ -148,8 +148,7 @@ def test_tls_wrong_options(cmd: Tuple[str]):
assert not out
lines = err.decode().split("\n")
errmsg = lines[6]
assert errmsg == "TLS certificates must be specified by either of:"
assert "TLS certificates must be specified by either of:" in lines
@pytest.mark.parametrize(

View File

@ -1,4 +1,5 @@
import logging
import os
from contextlib import contextmanager
from os import environ
@ -13,6 +14,7 @@ from pytest import MonkeyPatch
from sanic import Sanic
from sanic.config import DEFAULT_CONFIG, Config
from sanic.constants import LocalCertCreator
from sanic.exceptions import PyFileError
@ -49,7 +51,7 @@ def test_load_from_object(app: Sanic):
def test_load_from_object_string(app: Sanic):
app.config.load("test_config.ConfigTest")
app.config.load("tests.test_config.ConfigTest")
assert "CONFIG_VALUE" in app.config
assert app.config.CONFIG_VALUE == "should be used"
assert "not_for_config" not in app.config
@ -71,14 +73,14 @@ def test_load_from_object_string_exception(app: Sanic):
def test_auto_env_prefix():
environ["SANIC_TEST_ANSWER"] = "42"
app = Sanic(name=__name__)
app = Sanic(name="Test")
assert app.config.TEST_ANSWER == 42
del environ["SANIC_TEST_ANSWER"]
def test_auto_bool_env_prefix():
environ["SANIC_TEST_ANSWER"] = "True"
app = Sanic(name=__name__)
app = Sanic(name="Test")
assert app.config.TEST_ANSWER is True
del environ["SANIC_TEST_ANSWER"]
@ -86,28 +88,28 @@ def test_auto_bool_env_prefix():
@pytest.mark.parametrize("env_prefix", [None, ""])
def test_empty_load_env_prefix(env_prefix):
environ["SANIC_TEST_ANSWER"] = "42"
app = Sanic(name=__name__, env_prefix=env_prefix)
app = Sanic(name="Test", env_prefix=env_prefix)
assert getattr(app.config, "TEST_ANSWER", None) is None
del environ["SANIC_TEST_ANSWER"]
def test_env_prefix():
environ["MYAPP_TEST_ANSWER"] = "42"
app = Sanic(name=__name__, env_prefix="MYAPP_")
app = Sanic(name="Test", env_prefix="MYAPP_")
assert app.config.TEST_ANSWER == 42
del environ["MYAPP_TEST_ANSWER"]
def test_env_prefix_float_values():
environ["MYAPP_TEST_ROI"] = "2.3"
app = Sanic(name=__name__, env_prefix="MYAPP_")
app = Sanic(name="Test", env_prefix="MYAPP_")
assert app.config.TEST_ROI == 2.3
del environ["MYAPP_TEST_ROI"]
def test_env_prefix_string_value():
environ["MYAPP_TEST_TOKEN"] = "somerandomtesttoken"
app = Sanic(name=__name__, env_prefix="MYAPP_")
app = Sanic(name="Test", env_prefix="MYAPP_")
assert app.config.TEST_TOKEN == "somerandomtesttoken"
del environ["MYAPP_TEST_TOKEN"]
@ -116,7 +118,7 @@ def test_env_w_custom_converter():
environ["SANIC_TEST_ANSWER"] = "42"
config = Config(converters=[UltimateAnswer])
app = Sanic(name=__name__, config=config)
app = Sanic(name="Test", config=config)
assert isinstance(app.config.TEST_ANSWER, UltimateAnswer)
assert app.config.TEST_ANSWER.answer == 42
del environ["SANIC_TEST_ANSWER"]
@ -125,7 +127,7 @@ def test_env_w_custom_converter():
def test_env_lowercase():
with pytest.warns(None) as record:
environ["SANIC_test_answer"] = "42"
app = Sanic(name=__name__)
app = Sanic(name="Test")
assert app.config.test_answer == 42
assert str(record[0].message) == (
"[DEPRECATION v22.9] Lowercase environment variables will not be "
@ -435,3 +437,21 @@ def test_negative_proxy_count(app: Sanic):
)
with pytest.raises(ValueError, match=message):
app.prepare()
@pytest.mark.parametrize(
"passed,expected",
(
("auto", LocalCertCreator.AUTO),
("mkcert", LocalCertCreator.MKCERT),
("trustme", LocalCertCreator.TRUSTME),
("AUTO", LocalCertCreator.AUTO),
("MKCERT", LocalCertCreator.MKCERT),
("TRUSTME", LocalCertCreator.TRUSTME),
),
)
def test_convert_local_cert_creator(passed, expected):
os.environ["SANIC_LOCAL_CERT_CREATOR"] = passed
app = Sanic("Test")
assert app.config.LOCAL_CERT_CREATOR is expected
del os.environ["SANIC_LOCAL_CERT_CREATOR"]

View File

@ -17,7 +17,7 @@ class CustomRequest(Request):
def test_custom_request():
app = Sanic(name=__name__, request_class=CustomRequest)
app = Sanic(name="Test", request_class=CustomRequest)
@app.route("/post", methods=["POST"])
async def post_handler(request):

View File

@ -259,7 +259,7 @@ def test_custom_exception_default_message(exception_app):
def test_exception_in_ws_logged(caplog):
app = Sanic(__name__)
app = Sanic("Test")
@app.websocket("/feed")
async def feed(request, ws):
@ -279,7 +279,7 @@ def test_exception_in_ws_logged(caplog):
@pytest.mark.parametrize("debug", (True, False))
def test_contextual_exception_context(debug):
app = Sanic(__name__)
app = Sanic("Test")
class TeapotError(SanicException):
status_code = 418
@ -314,7 +314,7 @@ def test_contextual_exception_context(debug):
@pytest.mark.parametrize("debug", (True, False))
def test_contextual_exception_extra(debug):
app = Sanic(__name__)
app = Sanic("Test")
class TeapotError(SanicException):
status_code = 418
@ -361,7 +361,7 @@ def test_contextual_exception_extra(debug):
@pytest.mark.parametrize("override", (True, False))
def test_contextual_exception_functional_message(override):
app = Sanic(__name__)
app = Sanic("Test")
class TeapotError(SanicException):
status_code = 418

View File

@ -1,9 +1,7 @@
import asyncio
import json as stdjson
from collections import namedtuple
from textwrap import dedent
from typing import AnyStr
from pathlib import Path
import pytest
@ -11,54 +9,15 @@ from sanic_testing.reusable import ReusableClient
from sanic import json, text
from sanic.app import Sanic
from tests.client import RawClient
parent_dir = Path(__file__).parent
localhost_dir = parent_dir / "certs/localhost"
PORT = 1234
class RawClient:
CRLF = b"\r\n"
def __init__(self, host: str, port: int):
self.reader = None
self.writer = None
self.host = host
self.port = port
async def connect(self):
self.reader, self.writer = await asyncio.open_connection(
self.host, self.port
)
async def close(self):
self.writer.close()
await self.writer.wait_closed()
async def send(self, message: AnyStr):
if isinstance(message, str):
msg = self._clean(message).encode("utf-8")
else:
msg = message
await self._send(msg)
async def _send(self, message: bytes):
if not self.writer:
raise Exception("No open write stream")
self.writer.write(message)
async def recv(self, nbytes: int = -1) -> bytes:
if not self.reader:
raise Exception("No open read stream")
return await self.reader.read(nbytes)
def _clean(self, message: str) -> str:
return (
dedent(message)
.lstrip("\n")
.replace("\n", self.CRLF.decode("utf-8"))
)
@pytest.fixture
def test_app(app: Sanic):
app.config.KEEP_ALIVE_TIMEOUT = 1
@ -115,7 +74,7 @@ def test_full_message(client):
"""
)
response = client.recv()
assert len(response) == 140
assert len(response) == 151
assert b"200 OK" in response

View File

@ -0,0 +1,66 @@
import sys
from pathlib import Path
import pytest
from sanic.app import Sanic
from sanic.response import empty
from tests.client import RawClient
parent_dir = Path(__file__).parent
localhost_dir = parent_dir / "certs/localhost"
PORT = 12344
@pytest.mark.skipif(sys.version_info < (3, 9), reason="Not supported in 3.7")
def test_http1_response_has_alt_svc():
Sanic._app_registry.clear()
app = Sanic("TestAltSvc")
app.config.TOUCHUP = True
response = b""
@app.get("/")
async def handler(*_):
return empty()
@app.after_server_start
async def do_request(*_):
nonlocal response
app.router.reset()
app.router.finalize()
client = RawClient(app.state.host, app.state.port)
await client.connect()
await client.send(
"""
GET / HTTP/1.1
host: localhost:7777
"""
)
response = await client.recv()
await client.close()
@app.after_server_start
def shutdown(*_):
app.stop()
app.prepare(
version=3,
ssl={
"cert": localhost_dir / "fullchain.pem",
"key": localhost_dir / "privkey.pem",
},
port=PORT,
)
app.prepare(
version=1,
port=PORT,
)
Sanic.serve()
assert f'alt-svc: h3=":{PORT}"\r\n'.encode() in response

View File

@ -136,7 +136,7 @@ def test_log_connection_lost(app, debug, monkeypatch):
async def test_logger(caplog):
rand_string = str(uuid.uuid4())
app = Sanic(name=__name__)
app = Sanic(name="Test")
@app.get("/")
def log_info(request):
@ -163,7 +163,7 @@ def test_logging_modified_root_logger_config():
def test_access_log_client_ip_remote_addr(monkeypatch):
access = Mock()
monkeypatch.setattr(sanic.http, "access_logger", access)
monkeypatch.setattr(sanic.http.http1, "access_logger", access)
app = Sanic("test_logging")
app.config.PROXIES_COUNT = 2
@ -190,7 +190,7 @@ def test_access_log_client_ip_remote_addr(monkeypatch):
def test_access_log_client_ip_reqip(monkeypatch):
access = Mock()
monkeypatch.setattr(sanic.http, "access_logger", access)
monkeypatch.setattr(sanic.http.http1, "access_logger", access)
app = Sanic("test_logging")

View File

@ -53,7 +53,7 @@ def test_motd_with_expected_info(app, run_startup):
assert logs[1][2] == f"Sanic v{__version__}"
assert logs[3][2] == "mode: debug, single worker"
assert logs[4][2] == "server: sanic"
assert logs[4][2] == "server: sanic, HTTP/1.1"
assert logs[5][2] == f"python: {platform.python_version()}"
assert logs[6][2] == f"platform: {platform.platform()}"

View File

@ -14,7 +14,7 @@ from sanic.touchup.schemes.ode import OptionalDispatchEvent
try:
from unittest.mock import AsyncMock
except ImportError:
from asyncmock import AsyncMock # type: ignore
from tests.asyncmock import AsyncMock # type: ignore
@pytest.fixture

View File

@ -231,3 +231,15 @@ def test_get_current_request(app):
_, resp = app.test_client.get("/")
assert resp.json["same"]
def test_request_stream_id(app):
@app.get("/")
async def get(request):
try:
request.stream_id
except Exception as e:
return response.text(str(e))
_, resp = app.test_client.get("/")
assert resp.text == "Stream ID is only a property of a HTTP/3 request"

View File

@ -552,7 +552,7 @@ def test_streaming_new_api(app):
def test_streaming_echo():
"""2-way streaming chat between server and client."""
app = Sanic(name=__name__)
app = Sanic(name="Test")
@app.post("/echo", stream=True)
async def handler(request):

View File

@ -2050,7 +2050,7 @@ async def test_request_form_invalid_content_type_asgi(app):
def test_endpoint_basic():
app = Sanic(name=__name__)
app = Sanic(name="Test")
@app.route("/")
def my_unique_handler(request):
@ -2058,12 +2058,12 @@ def test_endpoint_basic():
request, response = app.test_client.get("/")
assert request.endpoint == "test_requests.my_unique_handler"
assert request.endpoint == "Test.my_unique_handler"
@pytest.mark.asyncio
async def test_endpoint_basic_asgi():
app = Sanic(name=__name__)
app = Sanic(name="Test")
@app.route("/")
def my_unique_handler(request):
@ -2071,7 +2071,7 @@ async def test_endpoint_basic_asgi():
request, response = await app.asgi_client.get("/")
assert request.endpoint == "test_requests.my_unique_handler"
assert request.endpoint == "Test.my_unique_handler"
def test_endpoint_named_app():

View File

@ -101,11 +101,12 @@ def test_response_header(app):
return json({"ok": True}, headers={"CONTENT-TYPE": "application/json"})
request, response = app.test_client.get("/")
assert dict(response.headers) == {
for key, value in {
"connection": "keep-alive",
"content-length": "11",
"content-type": "application/json",
}
}.items():
assert response.headers[key] == value
def test_response_content_length(app):

View File

@ -13,7 +13,7 @@ from sanic.response import empty
try:
from unittest.mock import AsyncMock
except ImportError:
from asyncmock import AsyncMock # type: ignore
from tests.asyncmock import AsyncMock # type: ignore
pytestmark = pytest.mark.asyncio

View File

@ -1,18 +1,30 @@
import logging
import os
import ssl
import uuid
import subprocess
from contextlib import contextmanager
from pathlib import Path
from unittest.mock import Mock, patch
from urllib.parse import urlparse
import pytest
from sanic_testing.testing import HOST, PORT, SanicTestClient
from sanic_testing.testing import HOST, PORT
import sanic.http.tls.creators
from sanic import Sanic
from sanic.compat import OS_IS_WINDOWS
from sanic.log import logger
from sanic.application.constants import Mode
from sanic.constants import LocalCertCreator
from sanic.exceptions import SanicException
from sanic.helpers import _default
from sanic.http.tls.context import SanicSSLContext
from sanic.http.tls.creators import (
MkcertCreator,
TrustmeCreator,
get_ssl_context,
)
from sanic.response import text
@ -26,9 +38,63 @@ sanic_cert = os.path.join(sanic_dir, "fullchain.pem")
sanic_key = os.path.join(sanic_dir, "privkey.pem")
@pytest.fixture
def server_cert():
return Mock()
@pytest.fixture
def issue_cert(server_cert):
mock = Mock(return_value=server_cert)
return mock
@pytest.fixture
def ca(issue_cert):
ca = Mock()
ca.issue_cert = issue_cert
return ca
@pytest.fixture
def trustme(ca):
module = Mock()
module.CA = Mock(return_value=ca)
return module
@pytest.fixture
def MockMkcertCreator():
class Creator(MkcertCreator):
SUPPORTED = True
def check_supported(self):
if not self.SUPPORTED:
raise SanicException("Nope")
generate_cert = Mock()
return Creator
@pytest.fixture
def MockTrustmeCreator():
class Creator(TrustmeCreator):
SUPPORTED = True
def check_supported(self):
if not self.SUPPORTED:
raise SanicException("Nope")
generate_cert = Mock()
return Creator
@contextmanager
def replace_server_name(hostname):
"""Temporarily replace the server name sent with all TLS requests with a fake hostname."""
"""Temporarily replace the server name sent with all TLS requests with
a fake hostname."""
def hack_wrap_bio(
self,
@ -69,8 +135,7 @@ def test_url_attributes_with_ssl_context(app, path, query, expected_url):
app.add_route(handler, path)
port = app.test_client.port
request, response = app.test_client.get(
request, _ = app.test_client.get(
f"https://{HOST}:{PORT}" + path + f"?{query}",
server_kwargs={"ssl": context},
)
@ -100,7 +165,7 @@ def test_url_attributes_with_ssl_dict(app, path, query, expected_url):
app.add_route(handler, path)
request, response = app.test_client.get(
request, _ = app.test_client.get(
f"https://{HOST}:{PORT}" + path + f"?{query}",
server_kwargs={"ssl": ssl_dict},
)
@ -116,22 +181,22 @@ def test_url_attributes_with_ssl_dict(app, path, query, expected_url):
def test_cert_sni_single(app):
@app.get("/sni")
async def handler(request):
async def handler1(request):
return text(request.conn_info.server_name)
@app.get("/commonname")
async def handler(request):
async def handler2(request):
return text(request.conn_info.cert.get("commonName"))
port = app.test_client.port
request, response = app.test_client.get(
_, response = app.test_client.get(
f"https://localhost:{port}/sni",
server_kwargs={"ssl": localhost_dir},
)
assert response.status == 200
assert response.text == "localhost"
request, response = app.test_client.get(
_, response = app.test_client.get(
f"https://localhost:{port}/commonname",
server_kwargs={"ssl": localhost_dir},
)
@ -143,16 +208,16 @@ def test_cert_sni_list(app):
ssl_list = [sanic_dir, localhost_dir]
@app.get("/sni")
async def handler(request):
async def handler1(request):
return text(request.conn_info.server_name)
@app.get("/commonname")
async def handler(request):
async def handler2(request):
return text(request.conn_info.cert.get("commonName"))
# This test should match the localhost cert
port = app.test_client.port
request, response = app.test_client.get(
_, response = app.test_client.get(
f"https://localhost:{port}/sni",
server_kwargs={"ssl": ssl_list},
)
@ -168,14 +233,14 @@ def test_cert_sni_list(app):
# This part should use the sanic.example cert because it matches
with replace_server_name("www.sanic.example"):
request, response = app.test_client.get(
_, response = app.test_client.get(
f"https://127.0.0.1:{port}/sni",
server_kwargs={"ssl": ssl_list},
)
assert response.status == 200
assert response.text == "www.sanic.example"
request, response = app.test_client.get(
_, response = app.test_client.get(
f"https://127.0.0.1:{port}/commonname",
server_kwargs={"ssl": ssl_list},
)
@ -184,14 +249,14 @@ def test_cert_sni_list(app):
# This part should use the sanic.example cert, that being the first listed
with replace_server_name("invalid.test"):
request, response = app.test_client.get(
_, response = app.test_client.get(
f"https://127.0.0.1:{port}/sni",
server_kwargs={"ssl": ssl_list},
)
assert response.status == 200
assert response.text == "invalid.test"
request, response = app.test_client.get(
_, response = app.test_client.get(
f"https://127.0.0.1:{port}/commonname",
server_kwargs={"ssl": ssl_list},
)
@ -200,7 +265,8 @@ def test_cert_sni_list(app):
def test_missing_sni(app):
"""The sanic cert does not list 127.0.0.1 and httpx does not send IP as SNI anyway."""
"""The sanic cert does not list 127.0.0.1 and httpx does not send
IP as SNI anyway."""
ssl_list = [None, sanic_dir]
@app.get("/sni")
@ -209,7 +275,7 @@ def test_missing_sni(app):
port = app.test_client.port
with pytest.raises(Exception) as exc:
request, response = app.test_client.get(
app.test_client.get(
f"https://127.0.0.1:{port}/sni",
server_kwargs={"ssl": ssl_list},
)
@ -217,7 +283,8 @@ def test_missing_sni(app):
def test_no_matching_cert(app):
"""The sanic cert does not list 127.0.0.1 and httpx does not send IP as SNI anyway."""
"""The sanic cert does not list 127.0.0.1 and httpx does not send
IP as SNI anyway."""
ssl_list = [None, sanic_dir]
@app.get("/sni")
@ -227,7 +294,7 @@ def test_no_matching_cert(app):
port = app.test_client.port
with replace_server_name("invalid.test"):
with pytest.raises(Exception) as exc:
request, response = app.test_client.get(
app.test_client.get(
f"https://127.0.0.1:{port}/sni",
server_kwargs={"ssl": ssl_list},
)
@ -244,7 +311,7 @@ def test_wildcards(app):
port = app.test_client.port
with replace_server_name("foo.sanic.test"):
request, response = app.test_client.get(
_, response = app.test_client.get(
f"https://127.0.0.1:{port}/sni",
server_kwargs={"ssl": ssl_list},
)
@ -253,14 +320,14 @@ def test_wildcards(app):
with replace_server_name("sanic.test"):
with pytest.raises(Exception) as exc:
request, response = app.test_client.get(
_, response = app.test_client.get(
f"https://127.0.0.1:{port}/sni",
server_kwargs={"ssl": ssl_list},
)
assert "Request and response object expected" in str(exc.value)
with replace_server_name("sub.foo.sanic.test"):
with pytest.raises(Exception) as exc:
request, response = app.test_client.get(
_, response = app.test_client.get(
f"https://127.0.0.1:{port}/sni",
server_kwargs={"ssl": ssl_list},
)
@ -275,9 +342,7 @@ def test_invalid_ssl_dict(app):
ssl_dict = {"cert": None, "key": None}
with pytest.raises(ValueError) as excinfo:
request, response = app.test_client.get(
"/test", server_kwargs={"ssl": ssl_dict}
)
app.test_client.get("/test", server_kwargs={"ssl": ssl_dict})
assert str(excinfo.value) == "SSL dict needs filenames for cert and key."
@ -288,9 +353,7 @@ def test_invalid_ssl_type(app):
return text("ssl test")
with pytest.raises(ValueError) as excinfo:
request, response = app.test_client.get(
"/test", server_kwargs={"ssl": False}
)
app.test_client.get("/test", server_kwargs={"ssl": False})
assert "Invalid ssl argument" in str(excinfo.value)
@ -303,9 +366,7 @@ def test_cert_file_on_pathlist(app):
ssl_list = [sanic_cert]
with pytest.raises(ValueError) as excinfo:
request, response = app.test_client.get(
"/test", server_kwargs={"ssl": ssl_list}
)
app.test_client.get("/test", server_kwargs={"ssl": ssl_list})
assert "folder expected" in str(excinfo.value)
assert sanic_cert in str(excinfo.value)
@ -319,9 +380,7 @@ def test_missing_cert_path(app):
ssl_list = [invalid_dir]
with pytest.raises(ValueError) as excinfo:
request, response = app.test_client.get(
"/test", server_kwargs={"ssl": ssl_list}
)
app.test_client.get("/test", server_kwargs={"ssl": ssl_list})
assert "not found" in str(excinfo.value)
assert invalid_dir + "/privkey.pem" in str(excinfo.value)
@ -336,9 +395,7 @@ def test_missing_cert_file(app):
ssl_list = [invalid2]
with pytest.raises(ValueError) as excinfo:
request, response = app.test_client.get(
"/test", server_kwargs={"ssl": ssl_list}
)
app.test_client.get("/test", server_kwargs={"ssl": ssl_list})
assert "not found" in str(excinfo.value)
assert invalid2 + "/fullchain.pem" in str(excinfo.value)
@ -352,15 +409,13 @@ def test_no_certs_on_list(app):
ssl_list = [None]
with pytest.raises(ValueError) as excinfo:
request, response = app.test_client.get(
"/test", server_kwargs={"ssl": ssl_list}
)
app.test_client.get("/test", server_kwargs={"ssl": ssl_list})
assert "No certificates" in str(excinfo.value)
def test_logger_vhosts(caplog):
app = Sanic(name=__name__)
app = Sanic(name="test_logger_vhosts")
@app.after_server_start
def stop(*args):
@ -374,5 +429,210 @@ def test_logger_vhosts(caplog):
][0]
assert logmsg == (
"Certificate vhosts: localhost, 127.0.0.1, 0:0:0:0:0:0:0:1, sanic.example, www.sanic.example, *.sanic.test, 2001:DB8:0:0:0:0:0:541C"
"Certificate vhosts: localhost, 127.0.0.1, 0:0:0:0:0:0:0:1, "
"sanic.example, www.sanic.example, *.sanic.test, "
"2001:DB8:0:0:0:0:0:541C"
)
def test_mk_cert_creator_default(app: Sanic):
cert_creator = MkcertCreator(app, _default, _default)
assert isinstance(cert_creator.tmpdir, Path)
assert cert_creator.tmpdir.exists()
def test_mk_cert_creator_is_supported(app):
cert_creator = MkcertCreator(app, _default, _default)
with patch("subprocess.run") as run:
cert_creator.check_supported()
run.assert_called_once_with(
["mkcert", "-help"],
check=True,
stderr=subprocess.DEVNULL,
stdout=subprocess.DEVNULL,
)
def test_mk_cert_creator_is_not_supported(app):
cert_creator = MkcertCreator(app, _default, _default)
with patch("subprocess.run") as run:
run.side_effect = Exception("")
with pytest.raises(
SanicException, match="Sanic is attempting to use mkcert"
):
cert_creator.check_supported()
def test_mk_cert_creator_generate_cert_default(app):
cert_creator = MkcertCreator(app, _default, _default)
with patch("subprocess.run") as run:
with patch("sanic.http.tls.creators.CertSimple"):
retval = Mock()
retval.stdout = "foo"
run.return_value = retval
cert_creator.generate_cert("localhost")
run.assert_called_once()
def test_mk_cert_creator_generate_cert_localhost(app):
cert_creator = MkcertCreator(app, localhost_key, localhost_cert)
with patch("subprocess.run") as run:
with patch("sanic.http.tls.creators.CertSimple"):
cert_creator.generate_cert("localhost")
run.assert_not_called()
def test_trustme_creator_default(app: Sanic):
cert_creator = TrustmeCreator(app, _default, _default)
assert isinstance(cert_creator.tmpdir, Path)
assert cert_creator.tmpdir.exists()
def test_trustme_creator_is_supported(app, monkeypatch):
monkeypatch.setattr(sanic.http.tls.creators, "TRUSTME_INSTALLED", True)
cert_creator = TrustmeCreator(app, _default, _default)
cert_creator.check_supported()
def test_trustme_creator_is_not_supported(app, monkeypatch):
monkeypatch.setattr(sanic.http.tls.creators, "TRUSTME_INSTALLED", False)
cert_creator = TrustmeCreator(app, _default, _default)
with pytest.raises(
SanicException, match="Sanic is attempting to use trustme"
):
cert_creator.check_supported()
def test_trustme_creator_generate_cert_default(
app, monkeypatch, trustme, issue_cert, server_cert, ca
):
monkeypatch.setattr(sanic.http.tls.creators, "trustme", trustme)
cert_creator = TrustmeCreator(app, _default, _default)
cert = cert_creator.generate_cert("localhost")
assert isinstance(cert, SanicSSLContext)
trustme.CA.assert_called_once_with()
issue_cert.assert_called_once_with("localhost")
server_cert.configure_cert.assert_called_once()
ca.configure_trust.assert_called_once()
ca.cert_pem.write_to_path.assert_called_once_with(str(cert.sanic["cert"]))
write_to_path = server_cert.private_key_and_cert_chain_pem.write_to_path
write_to_path.assert_called_once_with(str(cert.sanic["key"]))
def test_trustme_creator_generate_cert_localhost(
app, monkeypatch, trustme, server_cert, ca
):
monkeypatch.setattr(sanic.http.tls.creators, "trustme", trustme)
cert_creator = TrustmeCreator(app, localhost_key, localhost_cert)
cert_creator.generate_cert("localhost")
ca.cert_pem.write_to_path.assert_called_once_with(localhost_cert)
write_to_path = server_cert.private_key_and_cert_chain_pem.write_to_path
write_to_path.assert_called_once_with(localhost_key)
def test_get_ssl_context_with_ssl_context(app):
mock_context = Mock()
context = get_ssl_context(app, mock_context)
assert context is mock_context
def test_get_ssl_context_in_production(app):
app.state.mode = Mode.PRODUCTION
with pytest.raises(
SanicException,
match="Cannot run Sanic as an HTTPS server in PRODUCTION mode",
):
get_ssl_context(app, None)
@pytest.mark.parametrize(
"requirement,mk_supported,trustme_supported,mk_called,trustme_called,err",
(
(LocalCertCreator.AUTO, True, False, True, False, None),
(LocalCertCreator.AUTO, True, True, True, False, None),
(LocalCertCreator.AUTO, False, True, False, True, None),
(
LocalCertCreator.AUTO,
False,
False,
False,
False,
"Sanic could not find package to create a TLS certificate",
),
(LocalCertCreator.MKCERT, True, False, True, False, None),
(LocalCertCreator.MKCERT, True, True, True, False, None),
(LocalCertCreator.MKCERT, False, True, False, False, "Nope"),
(LocalCertCreator.MKCERT, False, False, False, False, "Nope"),
(LocalCertCreator.TRUSTME, True, False, False, False, "Nope"),
(LocalCertCreator.TRUSTME, True, True, False, True, None),
(LocalCertCreator.TRUSTME, False, True, False, True, None),
(LocalCertCreator.TRUSTME, False, False, False, False, "Nope"),
),
)
def test_get_ssl_context_only_mkcert(
app,
monkeypatch,
MockMkcertCreator,
MockTrustmeCreator,
requirement,
mk_supported,
trustme_supported,
mk_called,
trustme_called,
err,
):
app.state.mode = Mode.DEBUG
app.config.LOCAL_CERT_CREATOR = requirement
monkeypatch.setattr(
sanic.http.tls.creators, "MkcertCreator", MockMkcertCreator
)
monkeypatch.setattr(
sanic.http.tls.creators, "TrustmeCreator", MockTrustmeCreator
)
MockMkcertCreator.SUPPORTED = mk_supported
MockTrustmeCreator.SUPPORTED = trustme_supported
if err:
with pytest.raises(SanicException, match=err):
get_ssl_context(app, None)
else:
get_ssl_context(app, None)
if mk_called:
MockMkcertCreator.generate_cert.assert_called_once_with("localhost")
else:
MockMkcertCreator.generate_cert.assert_not_called()
if trustme_called:
MockTrustmeCreator.generate_cert.assert_called_once_with("localhost")
else:
MockTrustmeCreator.generate_cert.assert_not_called()
def test_no_http3_with_trustme(
app,
monkeypatch,
MockTrustmeCreator,
):
monkeypatch.setattr(
sanic.http.tls.creators, "TrustmeCreator", MockTrustmeCreator
)
MockTrustmeCreator.SUPPORTED = True
app.config.LOCAL_CERT_CREATOR = "TRUSTME"
with pytest.raises(
SanicException,
match=(
"Sorry, you cannot currently use trustme as a local certificate "
"generator for an HTTP/3 server"
),
):
app.run(version=3, debug=True)
def test_sanic_ssl_context_create():
context = ssl.SSLContext()
sanic_context = SanicSSLContext.create_from_ssl_context(context)
assert sanic_context is context
assert isinstance(sanic_context, SanicSSLContext)

View File

@ -53,7 +53,7 @@ def test_unix_socket_creation(caplog):
assert os.path.exists(SOCKPATH)
ino = os.stat(SOCKPATH).st_ino
app = Sanic(name=__name__)
app = Sanic(name="test")
@app.listener("after_server_start")
def running(app, loop):
@ -74,7 +74,7 @@ def test_unix_socket_creation(caplog):
@pytest.mark.parametrize("path", (".", "no-such-directory/sanictest.sock"))
def test_invalid_paths(path):
app = Sanic(name=__name__)
app = Sanic(name="test")
with pytest.raises((FileExistsError, FileNotFoundError)):
app.run(unix=path)
@ -84,7 +84,7 @@ def test_dont_replace_file():
with open(SOCKPATH, "w") as f:
f.write("File, not socket")
app = Sanic(name=__name__)
app = Sanic(name="test")
@app.listener("after_server_start")
def stop(app, loop):
@ -101,7 +101,7 @@ def test_dont_follow_symlink():
sock.bind(SOCKPATH2)
os.symlink(SOCKPATH2, SOCKPATH)
app = Sanic(name=__name__)
app = Sanic(name="test")
@app.listener("after_server_start")
def stop(app, loop):
@ -112,7 +112,7 @@ def test_dont_follow_symlink():
def test_socket_deleted_while_running():
app = Sanic(name=__name__)
app = Sanic(name="test")
@app.listener("after_server_start")
async def hack(app, loop):
@ -123,7 +123,7 @@ def test_socket_deleted_while_running():
def test_socket_replaced_with_file():
app = Sanic(name=__name__)
app = Sanic(name="test")
@app.listener("after_server_start")
async def hack(app, loop):
@ -136,7 +136,7 @@ def test_socket_replaced_with_file():
def test_unix_connection():
app = Sanic(name=__name__)
app = Sanic(name="test")
@app.get("/")
def handler(request):
@ -159,7 +159,7 @@ def test_unix_connection():
app.run(host="myhost.invalid", unix=SOCKPATH)
app_multi = Sanic(name=__name__)
app_multi = Sanic(name="test")
def handler(request):

View File

@ -19,7 +19,7 @@ def test_route(app, handler):
def test_bp(app, handler):
bp = Blueprint(__name__, version=1)
bp = Blueprint("Test", version=1)
bp.route("/")(handler)
app.blueprint(bp)
@ -28,7 +28,7 @@ def test_bp(app, handler):
def test_bp_use_route(app, handler):
bp = Blueprint(__name__, version=1)
bp = Blueprint("Test", version=1)
bp.route("/", version=1.1)(handler)
app.blueprint(bp)
@ -37,7 +37,7 @@ def test_bp_use_route(app, handler):
def test_bp_group(app, handler):
bp = Blueprint(__name__)
bp = Blueprint("Test")
bp.route("/")(handler)
group = Blueprint.group(bp, version=1)
app.blueprint(group)
@ -47,7 +47,7 @@ def test_bp_group(app, handler):
def test_bp_group_use_bp(app, handler):
bp = Blueprint(__name__, version=1.1)
bp = Blueprint("Test", version=1.1)
bp.route("/")(handler)
group = Blueprint.group(bp, version=1)
app.blueprint(group)
@ -57,7 +57,7 @@ def test_bp_group_use_bp(app, handler):
def test_bp_group_use_registration(app, handler):
bp = Blueprint(__name__, version=1.1)
bp = Blueprint("Test", version=1.1)
bp.route("/")(handler)
group = Blueprint.group(bp, version=1)
app.blueprint(group, version=1.2)
@ -67,7 +67,7 @@ def test_bp_group_use_registration(app, handler):
def test_bp_group_use_route(app, handler):
bp = Blueprint(__name__, version=1.1)
bp = Blueprint("Test", version=1.1)
bp.route("/", version=1.3)(handler)
group = Blueprint.group(bp, version=1)
app.blueprint(group, version=1.2)
@ -84,7 +84,7 @@ def test_version_prefix_route(app, handler):
def test_version_prefix_bp(app, handler):
bp = Blueprint(__name__, version=1, version_prefix="/api/v")
bp = Blueprint("Test", version=1, version_prefix="/api/v")
bp.route("/")(handler)
app.blueprint(bp)
@ -93,7 +93,7 @@ def test_version_prefix_bp(app, handler):
def test_version_prefix_bp_use_route(app, handler):
bp = Blueprint(__name__, version=1, version_prefix="/ignore/v")
bp = Blueprint("Test", version=1, version_prefix="/ignore/v")
bp.route("/", version=1.1, version_prefix="/api/v")(handler)
app.blueprint(bp)
@ -102,7 +102,7 @@ def test_version_prefix_bp_use_route(app, handler):
def test_version_prefix_bp_group(app, handler):
bp = Blueprint(__name__)
bp = Blueprint("Test")
bp.route("/")(handler)
group = Blueprint.group(bp, version=1, version_prefix="/api/v")
app.blueprint(group)
@ -112,7 +112,7 @@ def test_version_prefix_bp_group(app, handler):
def test_version_prefix_bp_group_use_bp(app, handler):
bp = Blueprint(__name__, version=1.1, version_prefix="/api/v")
bp = Blueprint("Test", version=1.1, version_prefix="/api/v")
bp.route("/")(handler)
group = Blueprint.group(bp, version=1, version_prefix="/ignore/v")
app.blueprint(group)
@ -122,7 +122,7 @@ def test_version_prefix_bp_group_use_bp(app, handler):
def test_version_prefix_bp_group_use_registration(app, handler):
bp = Blueprint(__name__, version=1.1, version_prefix="/alsoignore/v")
bp = Blueprint("Test", version=1.1, version_prefix="/alsoignore/v")
bp.route("/")(handler)
group = Blueprint.group(bp, version=1, version_prefix="/ignore/v")
app.blueprint(group, version=1.2, version_prefix="/api/v")
@ -132,7 +132,7 @@ def test_version_prefix_bp_group_use_registration(app, handler):
def test_version_prefix_bp_group_use_route(app, handler):
bp = Blueprint(__name__, version=1.1, version_prefix="/alsoignore/v")
bp = Blueprint("Test", version=1.1, version_prefix="/alsoignore/v")
bp.route("/", version=1.3, version_prefix="/api/v")(handler)
group = Blueprint.group(bp, version=1, version_prefix="/ignore/v")
app.blueprint(group, version=1.2, version_prefix="/stillignoring/v")

View File

@ -14,7 +14,7 @@ from sanic.server.websockets.frame import WebsocketFrameAssembler
try:
from unittest.mock import AsyncMock
except ImportError:
from asyncmock import AsyncMock # type: ignore
from tests.asyncmock import AsyncMock # type: ignore
@pytest.mark.asyncio

View File

@ -6,7 +6,9 @@ usedevelop = true
setenv =
{py37,py38,py39,py310,pyNightly}-no-ext: SANIC_NO_UJSON=1
{py37,py38,py39,py310,pyNightly}-no-ext: SANIC_NO_UVLOOP=1
extras = test
extras = test, http3
deps =
httpx==0.23
allowlist_externals =
pytest
coverage
@ -46,7 +48,7 @@ commands =
[testenv:docs]
platform = linux|linux2|darwin
allowlist_externals = make
extras = docs
extras = docs, http3
commands =
make docs-test