WIP
This commit is contained in:
parent
34fe43772e
commit
a937977ca2
@ -76,6 +76,7 @@ from sanic.handlers import ErrorHandler
|
|||||||
from sanic.helpers import _default
|
from sanic.helpers import _default
|
||||||
from sanic.http import Stage
|
from sanic.http import Stage
|
||||||
from sanic.http.constants import HTTP
|
from sanic.http.constants import HTTP
|
||||||
|
from sanic.http.tls import process_to_context
|
||||||
from sanic.log import (
|
from sanic.log import (
|
||||||
LOGGING_CONFIG_DEFAULTS,
|
LOGGING_CONFIG_DEFAULTS,
|
||||||
Colors,
|
Colors,
|
||||||
@ -104,7 +105,6 @@ from sanic.server import serve, serve_multiple, serve_single, try_use_uvloop
|
|||||||
from sanic.server.protocols.websocket_protocol import WebSocketProtocol
|
from sanic.server.protocols.websocket_protocol import WebSocketProtocol
|
||||||
from sanic.server.websockets.impl import ConnectionClosed
|
from sanic.server.websockets.impl import ConnectionClosed
|
||||||
from sanic.signals import Signal, SignalRouter
|
from sanic.signals import Signal, SignalRouter
|
||||||
from sanic.tls import process_to_context
|
|
||||||
from sanic.touchup import TouchUp, TouchUpMeta
|
from sanic.touchup import TouchUp, TouchUpMeta
|
||||||
|
|
||||||
|
|
||||||
@ -1216,6 +1216,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
|
|||||||
finally:
|
finally:
|
||||||
self.is_running = False
|
self.is_running = False
|
||||||
logger.info("Server Stopped")
|
logger.info("Server Stopped")
|
||||||
|
print("END OF RUN")
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""
|
"""
|
||||||
|
@ -11,6 +11,7 @@ from typing import Any, List, Union
|
|||||||
from sanic.app import Sanic
|
from sanic.app import Sanic
|
||||||
from sanic.application.logo import get_logo
|
from sanic.application.logo import get_logo
|
||||||
from sanic.cli.arguments import Group
|
from sanic.cli.arguments import Group
|
||||||
|
from sanic.http.constants import HTTP
|
||||||
from sanic.log import error_logger
|
from sanic.log import error_logger
|
||||||
from sanic.simple import create_simple_server
|
from sanic.simple import create_simple_server
|
||||||
|
|
||||||
@ -160,6 +161,7 @@ Or, a path to a directory to run as a simple HTTP server:
|
|||||||
elif len(ssl) == 1 and ssl[0] is not None:
|
elif len(ssl) == 1 and ssl[0] is not None:
|
||||||
# Use only one cert, no TLSSelector.
|
# Use only one cert, no TLSSelector.
|
||||||
ssl = ssl[0]
|
ssl = ssl[0]
|
||||||
|
version = HTTP(self.args.http)
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"access_log": self.args.access_log,
|
"access_log": self.args.access_log,
|
||||||
"debug": self.args.debug,
|
"debug": self.args.debug,
|
||||||
@ -172,6 +174,7 @@ Or, a path to a directory to run as a simple HTTP server:
|
|||||||
"unix": self.args.unix,
|
"unix": self.args.unix,
|
||||||
"verbosity": self.args.verbosity or 0,
|
"verbosity": self.args.verbosity or 0,
|
||||||
"workers": self.args.workers,
|
"workers": self.args.workers,
|
||||||
|
"version": version,
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.args.auto_reload:
|
if self.args.auto_reload:
|
||||||
|
@ -83,6 +83,30 @@ class ApplicationGroup(Group):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPVersionGroup(Group):
|
||||||
|
name = "HTTP version"
|
||||||
|
|
||||||
|
def attach(self):
|
||||||
|
group = self.container.add_mutually_exclusive_group()
|
||||||
|
group.add_argument(
|
||||||
|
"--http",
|
||||||
|
dest="http",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help=(
|
||||||
|
"Which HTTP version to use: HTTP/1.1 or HTTP/3. Value should "
|
||||||
|
"be either 1 or 3 [default 1]"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"-3",
|
||||||
|
dest="http",
|
||||||
|
action="store_const",
|
||||||
|
const=3,
|
||||||
|
help=("Run Sanic server using HTTP/3"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SocketGroup(Group):
|
class SocketGroup(Group):
|
||||||
name = "Socket binding"
|
name = "Socket binding"
|
||||||
|
|
||||||
|
@ -26,6 +26,9 @@ DEFAULT_CONFIG = {
|
|||||||
"GRACEFUL_SHUTDOWN_TIMEOUT": 15.0, # 15 sec
|
"GRACEFUL_SHUTDOWN_TIMEOUT": 15.0, # 15 sec
|
||||||
"KEEP_ALIVE_TIMEOUT": 5, # 5 seconds
|
"KEEP_ALIVE_TIMEOUT": 5, # 5 seconds
|
||||||
"KEEP_ALIVE": True,
|
"KEEP_ALIVE": True,
|
||||||
|
"LOCAL_TLS_KEY": _default,
|
||||||
|
"LOCAL_TLS_CERT": _default,
|
||||||
|
"LOCALHOST": "localhost",
|
||||||
"MOTD": True,
|
"MOTD": True,
|
||||||
"MOTD_DISPLAY": {},
|
"MOTD_DISPLAY": {},
|
||||||
"NOISY_EXCEPTIONS": False,
|
"NOISY_EXCEPTIONS": False,
|
||||||
@ -68,9 +71,12 @@ class Config(dict, metaclass=DescriptorMeta):
|
|||||||
GRACEFUL_SHUTDOWN_TIMEOUT: float
|
GRACEFUL_SHUTDOWN_TIMEOUT: float
|
||||||
KEEP_ALIVE_TIMEOUT: int
|
KEEP_ALIVE_TIMEOUT: int
|
||||||
KEEP_ALIVE: bool
|
KEEP_ALIVE: bool
|
||||||
NOISY_EXCEPTIONS: bool
|
LOCAL_TLS_KEY: Union[Path, str, Default]
|
||||||
|
LOCAL_TLS_CERT: Union[Path, str, Default]
|
||||||
|
LOCALHOST: str
|
||||||
MOTD: bool
|
MOTD: bool
|
||||||
MOTD_DISPLAY: Dict[str, str]
|
MOTD_DISPLAY: Dict[str, str]
|
||||||
|
NOISY_EXCEPTIONS: bool
|
||||||
PROXIES_COUNT: Optional[int]
|
PROXIES_COUNT: Optional[int]
|
||||||
REAL_IP_HEADER: Optional[str]
|
REAL_IP_HEADER: Optional[str]
|
||||||
REGISTER: bool
|
REGISTER: bool
|
||||||
|
@ -26,3 +26,5 @@ class HTTPMethod(str, Enum):
|
|||||||
|
|
||||||
HTTP_METHODS = tuple(HTTPMethod.__members__.values())
|
HTTP_METHODS = tuple(HTTPMethod.__members__.values())
|
||||||
DEFAULT_HTTP_CONTENT_TYPE = "application/octet-stream"
|
DEFAULT_HTTP_CONTENT_TYPE = "application/octet-stream"
|
||||||
|
DEFAULT_LOCAL_TLS_KEY = "key.pem"
|
||||||
|
DEFAULT_LOCAL_TLS_CERT = "cert.pem"
|
||||||
|
@ -1,18 +1,21 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
|
import asyncio
|
||||||
|
|
||||||
|
from abc import ABC
|
||||||
|
from ast import Mod
|
||||||
|
from ssl import SSLContext
|
||||||
|
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union
|
||||||
|
|
||||||
from aioquic.h0.connection import H0_ALPN, H0Connection
|
from aioquic.h0.connection import H0_ALPN, H0Connection
|
||||||
from aioquic.h3.connection import H3_ALPN, H3Connection
|
from aioquic.h3.connection import H3_ALPN, H3Connection
|
||||||
from aioquic.h3.events import H3Event
|
from aioquic.h3.events import (
|
||||||
|
DatagramReceived,
|
||||||
# from aioquic.h3.events import (
|
DataReceived,
|
||||||
# DatagramReceived,
|
H3Event,
|
||||||
# DataReceived,
|
HeadersReceived,
|
||||||
# H3Event,
|
WebTransportStreamDataReceived,
|
||||||
# HeadersReceived,
|
)
|
||||||
# WebTransportStreamDataReceived,
|
|
||||||
# )
|
|
||||||
from aioquic.quic.configuration import QuicConfiguration
|
from aioquic.quic.configuration import QuicConfiguration
|
||||||
|
|
||||||
# from aioquic.quic.events import (
|
# from aioquic.quic.events import (
|
||||||
@ -22,89 +25,144 @@ from aioquic.quic.configuration import QuicConfiguration
|
|||||||
# )
|
# )
|
||||||
from aioquic.tls import SessionTicket
|
from aioquic.tls import SessionTicket
|
||||||
|
|
||||||
|
from sanic.compat import Header
|
||||||
|
from sanic.exceptions import SanicException
|
||||||
|
from sanic.http.tls import CertSimple
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from sanic import Sanic
|
||||||
from sanic.request import Request
|
from sanic.request import Request
|
||||||
|
from sanic.response import BaseHTTPResponse
|
||||||
|
from sanic.server.protocols.http_protocol import Http3Protocol
|
||||||
|
|
||||||
# from sanic.compat import Header
|
# from sanic.compat import Header
|
||||||
|
# from sanic.application.state import Mode
|
||||||
from sanic.log import logger
|
from sanic.log import logger
|
||||||
from sanic.response import BaseHTTPResponse
|
|
||||||
|
|
||||||
|
|
||||||
HttpConnection = Union[H0Connection, H3Connection]
|
HttpConnection = Union[H0Connection, H3Connection]
|
||||||
|
|
||||||
|
|
||||||
async def handler(request: Request):
|
|
||||||
logger.info(f"Request received: {request}")
|
|
||||||
response = await request.app.handle_request(request)
|
|
||||||
logger.info(f"Build response: {response=}")
|
|
||||||
|
|
||||||
|
|
||||||
class Transport:
|
class Transport:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class Receiver(ABC):
|
||||||
|
def __init__(self, transmit, protocol, request) -> None:
|
||||||
|
self.transmit = transmit
|
||||||
|
self.protocol = protocol
|
||||||
|
self.request = request
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPReceiver(Receiver):
|
||||||
|
async def respond(self):
|
||||||
|
logger.info(f"Request received: {self.request}")
|
||||||
|
await self.protocol.app.handle_request(self.request)
|
||||||
|
|
||||||
|
|
||||||
|
class WebsocketReceiver(Receiver):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class WebTransportReceiver(Receiver):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class Http3:
|
class Http3:
|
||||||
|
HANDLER_PROPERTY_MAPPING = {
|
||||||
|
DataReceived: "stream_id",
|
||||||
|
HeadersReceived: "stream_id",
|
||||||
|
DatagramReceived: "flow_id",
|
||||||
|
WebTransportStreamDataReceived: "session_id",
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
connection: HttpConnection,
|
protocol: Http3Protocol,
|
||||||
transmit: Callable[[], None],
|
transmit: Callable[[], None],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.request_body = None
|
self.request_body = None
|
||||||
self.connection = connection
|
self.request: Optional[Request] = None
|
||||||
|
self.protocol = protocol
|
||||||
self.transmit = transmit
|
self.transmit = transmit
|
||||||
|
self.receivers: Dict[int, Receiver] = {}
|
||||||
|
|
||||||
def http_event_received(self, event: H3Event) -> None:
|
def http_event_received(self, event: H3Event) -> None:
|
||||||
print("[http_event_received]:", event)
|
print("[http_event_received]:", event)
|
||||||
# if isinstance(event, HeadersReceived):
|
receiver = self.get_or_make_receiver(event)
|
||||||
# method, path, *rem = event.headers
|
print(f"{receiver=}")
|
||||||
# headers = Header(((k.decode(), v.decode()) for k, v in rem))
|
|
||||||
# method = method[1].decode()
|
|
||||||
# path = path[1]
|
|
||||||
# scheme = headers.pop(":scheme")
|
|
||||||
# authority = headers.pop(":authority")
|
|
||||||
# print(f"{headers=}")
|
|
||||||
# print(f"{method=}")
|
|
||||||
# print(f"{path=}")
|
|
||||||
# print(f"{scheme=}")
|
|
||||||
# print(f"{authority=}")
|
|
||||||
# if authority:
|
|
||||||
# headers["host"] = authority
|
|
||||||
|
|
||||||
# request = Request(
|
|
||||||
# path, headers, "3", method, Transport(), app, b""
|
|
||||||
# )
|
|
||||||
# request.stream = Stream(
|
|
||||||
# connection=self._http, transmit=self.transmit
|
|
||||||
# )
|
|
||||||
# print(f"{request=}")
|
|
||||||
|
|
||||||
# asyncio.ensure_future(handler(request))
|
# asyncio.ensure_future(handler(request))
|
||||||
|
|
||||||
|
def get_or_make_receiver(self, event: H3Event) -> Receiver:
|
||||||
|
if (
|
||||||
|
isinstance(event, HeadersReceived)
|
||||||
|
and event.stream_id not in self.receivers
|
||||||
|
):
|
||||||
|
self.request = self._make_request(event)
|
||||||
|
receiver = HTTPReceiver(self.transmit, self.protocol, self.request)
|
||||||
|
self.receivers[event.stream_id] = receiver
|
||||||
|
asyncio.ensure_future(receiver.respond())
|
||||||
|
else:
|
||||||
|
ident = getattr(event, self.HANDLER_PROPERTY_MAPPING[type(event)])
|
||||||
|
return self.receivers[ident]
|
||||||
|
|
||||||
|
def _make_request(self, event: HeadersReceived) -> Request:
|
||||||
|
method, path, *rem = event.headers
|
||||||
|
headers = Header(((k.decode(), v.decode()) for k, v in rem))
|
||||||
|
method = method[1].decode()
|
||||||
|
path = path[1]
|
||||||
|
scheme = headers.pop(":scheme")
|
||||||
|
authority = headers.pop(":authority")
|
||||||
|
print(f"{headers=}")
|
||||||
|
print(f"{method=}")
|
||||||
|
print(f"{path=}")
|
||||||
|
print(f"{scheme=}")
|
||||||
|
print(f"{authority=}")
|
||||||
|
if authority:
|
||||||
|
headers["host"] = authority
|
||||||
|
|
||||||
|
request = self.protocol.request_class(
|
||||||
|
path, headers, "3", method, Transport(), self.protocol.app, b""
|
||||||
|
)
|
||||||
|
request.stream = self
|
||||||
|
print(f"{request=}")
|
||||||
|
return request
|
||||||
|
|
||||||
|
async def respond(self, response: BaseHTTPResponse) -> BaseHTTPResponse:
|
||||||
|
print(f"[respond]: {response=}")
|
||||||
|
response.headers.update({"foo": "bar"})
|
||||||
|
self.response, response.stream = response, self
|
||||||
|
|
||||||
|
# Need more appropriate place to send these
|
||||||
|
self.protocol.connection.send_headers(
|
||||||
|
stream_id=0,
|
||||||
|
headers=[
|
||||||
|
(b":status", str(self.response.status).encode()),
|
||||||
|
*(
|
||||||
|
(k.encode(), v.encode())
|
||||||
|
for k, v in self.response.headers.items()
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
# TEMP
|
||||||
|
await self.drain(response)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def drain(self, response: BaseHTTPResponse) -> None:
|
||||||
|
await self.send(response.body, False)
|
||||||
|
|
||||||
async def send(self, data: bytes, end_stream: bool) -> None:
|
async def send(self, data: bytes, end_stream: bool) -> None:
|
||||||
print(f"[send]: {data=} {end_stream=}")
|
print(f"[send]: {data=} {end_stream=}")
|
||||||
print(self.response.headers)
|
print(self.response.headers)
|
||||||
# self.connection.send_headers(
|
self.protocol.connection.send_data(
|
||||||
# stream_id=0,
|
stream_id=0,
|
||||||
# headers=[
|
data=data,
|
||||||
# (b":status", str(self.response.status).encode()),
|
end_stream=end_stream,
|
||||||
# *(
|
)
|
||||||
# (k.encode(), v.encode())
|
self.transmit()
|
||||||
# for k, v in self.response.headers.items()
|
|
||||||
# ),
|
|
||||||
# ],
|
|
||||||
# )
|
|
||||||
# self.connection.send_data(
|
|
||||||
# stream_id=0,
|
|
||||||
# data=data,
|
|
||||||
# end_stream=end_stream,
|
|
||||||
# )
|
|
||||||
# self.transmit()
|
|
||||||
|
|
||||||
def respond(self, response: BaseHTTPResponse) -> BaseHTTPResponse:
|
|
||||||
print(f"[respond]: {response=}")
|
|
||||||
self.response, response.stream = response, self
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
class SessionTicketStore:
|
class SessionTicketStore:
|
||||||
@ -122,13 +180,19 @@ class SessionTicketStore:
|
|||||||
return self.tickets.pop(label, None)
|
return self.tickets.pop(label, None)
|
||||||
|
|
||||||
|
|
||||||
def get_config():
|
def get_config(app: Sanic, ssl: SSLContext):
|
||||||
|
if not isinstance(ssl, CertSimple):
|
||||||
|
raise SanicException("SSLContext is not CertSimple")
|
||||||
|
|
||||||
config = QuicConfiguration(
|
config = QuicConfiguration(
|
||||||
alpn_protocols=H3_ALPN + H0_ALPN + ["siduck"],
|
alpn_protocols=H3_ALPN + H0_ALPN + ["siduck"],
|
||||||
is_client=False,
|
is_client=False,
|
||||||
max_datagram_frame_size=65536,
|
max_datagram_frame_size=65536,
|
||||||
)
|
)
|
||||||
config.load_cert_chain("./cert.pem", "./key.pem", password="qqqqqqqq")
|
# TODO:
|
||||||
|
# - add password kwarg, read from config.TLS_CERT_PASSWORD
|
||||||
|
config.load_cert_chain(ssl.sanic["cert"], ssl.sanic["key"])
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from inspect import isawaitable
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
@ -205,6 +206,9 @@ class Request:
|
|||||||
# Connect the response
|
# Connect the response
|
||||||
if isinstance(response, BaseHTTPResponse) and self.stream:
|
if isinstance(response, BaseHTTPResponse) and self.stream:
|
||||||
response = self.stream.respond(response)
|
response = self.stream.respond(response)
|
||||||
|
|
||||||
|
if isawaitable(response):
|
||||||
|
response = await response
|
||||||
# Run response middleware
|
# Run response middleware
|
||||||
try:
|
try:
|
||||||
response = await self.app._run_response_middleware(
|
response = await self.app._run_response_middleware(
|
||||||
|
@ -25,6 +25,7 @@ from sanic.cookies import CookieJar
|
|||||||
from sanic.exceptions import SanicException, ServerError
|
from sanic.exceptions import SanicException, ServerError
|
||||||
from sanic.helpers import has_message_body, remove_entity_headers
|
from sanic.helpers import has_message_body, remove_entity_headers
|
||||||
from sanic.http import Http
|
from sanic.http import Http
|
||||||
|
from sanic.http.http3 import Http3
|
||||||
from sanic.models.protocol_types import HTMLProtocol, Range
|
from sanic.models.protocol_types import HTMLProtocol, Range
|
||||||
|
|
||||||
|
|
||||||
@ -56,7 +57,7 @@ class BaseHTTPResponse:
|
|||||||
self.asgi: bool = False
|
self.asgi: bool = False
|
||||||
self.body: Optional[bytes] = None
|
self.body: Optional[bytes] = None
|
||||||
self.content_type: Optional[str] = None
|
self.content_type: Optional[str] = None
|
||||||
self.stream: Optional[Union[Http, ASGIApp]] = None
|
self.stream: Optional[Union[Http, ASGIApp, Http3]] = None
|
||||||
self.status: int = None
|
self.status: int = None
|
||||||
self.headers = Header({})
|
self.headers = Header({})
|
||||||
self._cookies: Optional[CookieJar] = None
|
self._cookies: Optional[CookieJar] = None
|
||||||
@ -121,6 +122,7 @@ class BaseHTTPResponse:
|
|||||||
:param data: str or bytes to be written
|
:param data: str or bytes to be written
|
||||||
:param end_stream: whether to close the stream after this block
|
:param end_stream: whether to close the stream after this block
|
||||||
"""
|
"""
|
||||||
|
print(f">>> BaseHTTPResponse: {data=} {end_stream=} {self.body=}")
|
||||||
if data is None and end_stream is None:
|
if data is None and end_stream is None:
|
||||||
end_stream = True
|
end_stream = True
|
||||||
if self.stream is None:
|
if self.stream is None:
|
||||||
|
@ -262,7 +262,7 @@ class Http3Protocol(HttpProtocolMixin, QuicConnectionProtocol):
|
|||||||
self.app = app
|
self.app = app
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._setup()
|
self._setup()
|
||||||
self._connection = None
|
self._connection: Optional[H3Connection] = None
|
||||||
|
|
||||||
def quic_event_received(self, event: QuicEvent) -> None:
|
def quic_event_received(self, event: QuicEvent) -> None:
|
||||||
print("[quic_event_received]:", event)
|
print("[quic_event_received]:", event)
|
||||||
@ -282,3 +282,7 @@ class Http3Protocol(HttpProtocolMixin, QuicConnectionProtocol):
|
|||||||
if self._connection is not None:
|
if self._connection is not None:
|
||||||
for http_event in self._connection.handle_event(event):
|
for http_event in self._connection.handle_event(event):
|
||||||
self._http.http_event_received(http_event)
|
self._http.http_event_received(http_event)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def connection(self) -> Optional[H3Connection]:
|
||||||
|
return self._connection
|
||||||
|
@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Dict, Optional, Type, Union
|
|||||||
|
|
||||||
from sanic.config import Config
|
from sanic.config import Config
|
||||||
from sanic.http.constants import HTTP
|
from sanic.http.constants import HTTP
|
||||||
|
from sanic.http.tls import get_ssl_context
|
||||||
from sanic.server.events import trigger_events
|
from sanic.server.events import trigger_events
|
||||||
|
|
||||||
|
|
||||||
@ -94,7 +95,7 @@ def serve(
|
|||||||
app.asgi = False
|
app.asgi = False
|
||||||
|
|
||||||
if version is HTTP.VERSION_3:
|
if version is HTTP.VERSION_3:
|
||||||
return serve_http_3(host, port, app, loop)
|
return serve_http_3(host, port, app, loop, ssl)
|
||||||
|
|
||||||
connections = connections if connections is not None else set()
|
connections = connections if connections is not None else set()
|
||||||
protocol_kwargs = _build_protocol_kwargs(protocol, app.config)
|
protocol_kwargs = _build_protocol_kwargs(protocol, app.config)
|
||||||
@ -200,10 +201,19 @@ def serve(
|
|||||||
remove_unix_socket(unix)
|
remove_unix_socket(unix)
|
||||||
|
|
||||||
|
|
||||||
def serve_http_3(host, port, app, loop):
|
def serve_http_3(
|
||||||
|
host,
|
||||||
|
port,
|
||||||
|
app,
|
||||||
|
loop,
|
||||||
|
ssl,
|
||||||
|
register_sys_signals: bool = True,
|
||||||
|
run_multiple: bool = False,
|
||||||
|
):
|
||||||
protocol = partial(Http3Protocol, app=app)
|
protocol = partial(Http3Protocol, app=app)
|
||||||
ticket_store = get_ticket_store()
|
ticket_store = get_ticket_store()
|
||||||
config = get_config()
|
ssl_context = get_ssl_context(app, ssl)
|
||||||
|
config = get_config(app, ssl_context)
|
||||||
coro = quic_serve(
|
coro = quic_serve(
|
||||||
host,
|
host,
|
||||||
port,
|
port,
|
||||||
@ -214,6 +224,21 @@ def serve_http_3(host, port, app, loop):
|
|||||||
)
|
)
|
||||||
server = AsyncioServer(app, loop, coro, [])
|
server = AsyncioServer(app, loop, coro, [])
|
||||||
loop.run_until_complete(server.startup())
|
loop.run_until_complete(server.startup())
|
||||||
|
|
||||||
|
# TODO: Cleanup the non-DRY code block
|
||||||
|
# Ignore SIGINT when run_multiple
|
||||||
|
if run_multiple:
|
||||||
|
signal_func(SIGINT, SIG_IGN)
|
||||||
|
os.environ["SANIC_WORKER_PROCESS"] = "true"
|
||||||
|
|
||||||
|
# Register signals for graceful termination
|
||||||
|
if register_sys_signals:
|
||||||
|
if OS_IS_WINDOWS:
|
||||||
|
ctrlc_workaround_for_windows(app)
|
||||||
|
else:
|
||||||
|
for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]:
|
||||||
|
loop.add_signal_handler(_signal, app.stop)
|
||||||
|
|
||||||
loop.run_until_complete(server.before_start())
|
loop.run_until_complete(server.before_start())
|
||||||
loop.run_until_complete(server)
|
loop.run_until_complete(server)
|
||||||
loop.run_until_complete(server.after_start())
|
loop.run_until_complete(server.after_start())
|
||||||
|
196
sanic/tls.py
196
sanic/tls.py
@ -1,196 +0,0 @@
|
|||||||
import os
|
|
||||||
import ssl
|
|
||||||
|
|
||||||
from typing import Iterable, Optional, Union
|
|
||||||
|
|
||||||
from sanic.log import logger
|
|
||||||
|
|
||||||
|
|
||||||
# Only allow secure ciphers, notably leaving out AES-CBC mode
|
|
||||||
# OpenSSL chooses ECDSA or RSA depending on the cert in use
|
|
||||||
CIPHERS_TLS12 = [
|
|
||||||
"ECDHE-ECDSA-CHACHA20-POLY1305",
|
|
||||||
"ECDHE-ECDSA-AES256-GCM-SHA384",
|
|
||||||
"ECDHE-ECDSA-AES128-GCM-SHA256",
|
|
||||||
"ECDHE-RSA-CHACHA20-POLY1305",
|
|
||||||
"ECDHE-RSA-AES256-GCM-SHA384",
|
|
||||||
"ECDHE-RSA-AES128-GCM-SHA256",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def create_context(
|
|
||||||
certfile: Optional[str] = None,
|
|
||||||
keyfile: Optional[str] = None,
|
|
||||||
password: Optional[str] = None,
|
|
||||||
) -> ssl.SSLContext:
|
|
||||||
"""Create a context with secure crypto and HTTP/1.1 in protocols."""
|
|
||||||
context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
|
|
||||||
context.minimum_version = ssl.TLSVersion.TLSv1_2
|
|
||||||
context.set_ciphers(":".join(CIPHERS_TLS12))
|
|
||||||
context.set_alpn_protocols(["http/1.1"])
|
|
||||||
context.sni_callback = server_name_callback
|
|
||||||
if certfile and keyfile:
|
|
||||||
context.load_cert_chain(certfile, keyfile, password)
|
|
||||||
return context
|
|
||||||
|
|
||||||
|
|
||||||
def shorthand_to_ctx(
|
|
||||||
ctxdef: Union[None, ssl.SSLContext, dict, str]
|
|
||||||
) -> Optional[ssl.SSLContext]:
|
|
||||||
"""Convert an ssl argument shorthand to an SSLContext object."""
|
|
||||||
if ctxdef is None or isinstance(ctxdef, ssl.SSLContext):
|
|
||||||
return ctxdef
|
|
||||||
if isinstance(ctxdef, str):
|
|
||||||
return load_cert_dir(ctxdef)
|
|
||||||
if isinstance(ctxdef, dict):
|
|
||||||
return CertSimple(**ctxdef)
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid ssl argument {type(ctxdef)}."
|
|
||||||
" Expecting a list of certdirs, a dict or an SSLContext."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def process_to_context(
|
|
||||||
ssldef: Union[None, ssl.SSLContext, dict, str, list, tuple]
|
|
||||||
) -> Optional[ssl.SSLContext]:
|
|
||||||
"""Process app.run ssl argument from easy formats to full SSLContext."""
|
|
||||||
return (
|
|
||||||
CertSelector(map(shorthand_to_ctx, ssldef))
|
|
||||||
if isinstance(ssldef, (list, tuple))
|
|
||||||
else shorthand_to_ctx(ssldef)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_cert_dir(p: str) -> ssl.SSLContext:
|
|
||||||
if os.path.isfile(p):
|
|
||||||
raise ValueError(f"Certificate folder expected but {p} is a file.")
|
|
||||||
keyfile = os.path.join(p, "privkey.pem")
|
|
||||||
certfile = os.path.join(p, "fullchain.pem")
|
|
||||||
if not os.access(keyfile, os.R_OK):
|
|
||||||
raise ValueError(
|
|
||||||
f"Certificate not found or permission denied {keyfile}"
|
|
||||||
)
|
|
||||||
if not os.access(certfile, os.R_OK):
|
|
||||||
raise ValueError(
|
|
||||||
f"Certificate not found or permission denied {certfile}"
|
|
||||||
)
|
|
||||||
return CertSimple(certfile, keyfile)
|
|
||||||
|
|
||||||
|
|
||||||
class CertSimple(ssl.SSLContext):
|
|
||||||
"""A wrapper for creating SSLContext with a sanic attribute."""
|
|
||||||
|
|
||||||
def __new__(cls, cert, key, **kw):
|
|
||||||
# try common aliases, rename to cert/key
|
|
||||||
certfile = kw["cert"] = kw.pop("certificate", None) or cert
|
|
||||||
keyfile = kw["key"] = kw.pop("keyfile", None) or key
|
|
||||||
password = kw.pop("password", None)
|
|
||||||
if not certfile or not keyfile:
|
|
||||||
raise ValueError("SSL dict needs filenames for cert and key.")
|
|
||||||
subject = {}
|
|
||||||
if "names" not in kw:
|
|
||||||
cert = ssl._ssl._test_decode_cert(certfile) # type: ignore
|
|
||||||
kw["names"] = [
|
|
||||||
name
|
|
||||||
for t, name in cert["subjectAltName"]
|
|
||||||
if t in ["DNS", "IP Address"]
|
|
||||||
]
|
|
||||||
subject = {k: v for item in cert["subject"] for k, v in item}
|
|
||||||
self = create_context(certfile, keyfile, password)
|
|
||||||
self.__class__ = cls
|
|
||||||
self.sanic = {**subject, **kw}
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __init__(self, cert, key, **kw):
|
|
||||||
pass # Do not call super().__init__ because it is already initialized
|
|
||||||
|
|
||||||
|
|
||||||
class CertSelector(ssl.SSLContext):
|
|
||||||
"""Automatically select SSL certificate based on the hostname that the
|
|
||||||
client is trying to access, via SSL SNI. Paths to certificate folders
|
|
||||||
with privkey.pem and fullchain.pem in them should be provided, and
|
|
||||||
will be matched in the order given whenever there is a new connection.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __new__(cls, ctxs):
|
|
||||||
return super().__new__(cls)
|
|
||||||
|
|
||||||
def __init__(self, ctxs: Iterable[Optional[ssl.SSLContext]]):
|
|
||||||
super().__init__()
|
|
||||||
self.sni_callback = selector_sni_callback # type: ignore
|
|
||||||
self.sanic_select = []
|
|
||||||
self.sanic_fallback = None
|
|
||||||
all_names = []
|
|
||||||
for i, ctx in enumerate(ctxs):
|
|
||||||
if not ctx:
|
|
||||||
continue
|
|
||||||
names = dict(getattr(ctx, "sanic", {})).get("names", [])
|
|
||||||
all_names += names
|
|
||||||
self.sanic_select.append(ctx)
|
|
||||||
if i == 0:
|
|
||||||
self.sanic_fallback = ctx
|
|
||||||
if not all_names:
|
|
||||||
raise ValueError(
|
|
||||||
"No certificates with SubjectAlternativeNames found."
|
|
||||||
)
|
|
||||||
logger.info(f"Certificate vhosts: {', '.join(all_names)}")
|
|
||||||
|
|
||||||
|
|
||||||
def find_cert(self: CertSelector, server_name: str):
|
|
||||||
"""Find the first certificate that matches the given SNI.
|
|
||||||
|
|
||||||
:raises ssl.CertificateError: No matching certificate found.
|
|
||||||
:return: A matching ssl.SSLContext object if found."""
|
|
||||||
if not server_name:
|
|
||||||
if self.sanic_fallback:
|
|
||||||
return self.sanic_fallback
|
|
||||||
raise ValueError(
|
|
||||||
"The client provided no SNI to match for certificate."
|
|
||||||
)
|
|
||||||
for ctx in self.sanic_select:
|
|
||||||
if match_hostname(ctx, server_name):
|
|
||||||
return ctx
|
|
||||||
if self.sanic_fallback:
|
|
||||||
return self.sanic_fallback
|
|
||||||
raise ValueError(f"No certificate found matching hostname {server_name!r}")
|
|
||||||
|
|
||||||
|
|
||||||
def match_hostname(
|
|
||||||
ctx: Union[ssl.SSLContext, CertSelector], hostname: str
|
|
||||||
) -> bool:
|
|
||||||
"""Match names from CertSelector against a received hostname."""
|
|
||||||
# Local certs are considered trusted, so this can be less pedantic
|
|
||||||
# and thus faster than the deprecated ssl.match_hostname function is.
|
|
||||||
names = dict(getattr(ctx, "sanic", {})).get("names", [])
|
|
||||||
hostname = hostname.lower()
|
|
||||||
for name in names:
|
|
||||||
if name.startswith("*."):
|
|
||||||
if hostname.split(".", 1)[-1] == name[2:]:
|
|
||||||
return True
|
|
||||||
elif name == hostname:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def selector_sni_callback(
|
|
||||||
sslobj: ssl.SSLObject, server_name: str, ctx: CertSelector
|
|
||||||
) -> Optional[int]:
|
|
||||||
"""Select a certificate matching the SNI."""
|
|
||||||
# Call server_name_callback to store the SNI on sslobj
|
|
||||||
server_name_callback(sslobj, server_name, ctx)
|
|
||||||
# Find a new context matching the hostname
|
|
||||||
try:
|
|
||||||
sslobj.context = find_cert(ctx, server_name)
|
|
||||||
except ValueError as e:
|
|
||||||
logger.warning(f"Rejecting TLS connection: {e}")
|
|
||||||
# This would show ERR_SSL_UNRECOGNIZED_NAME_ALERT on client side if
|
|
||||||
# asyncio/uvloop did proper SSL shutdown. They don't.
|
|
||||||
return ssl.ALERT_DESCRIPTION_UNRECOGNIZED_NAME
|
|
||||||
return None # mypy complains without explicit return
|
|
||||||
|
|
||||||
|
|
||||||
def server_name_callback(
|
|
||||||
sslobj: ssl.SSLObject, server_name: str, ctx: ssl.SSLContext
|
|
||||||
) -> None:
|
|
||||||
"""Store the received SNI as sslobj.sanic_server_name."""
|
|
||||||
sslobj.sanic_server_name = server_name # type: ignore
|
|
Loading…
x
Reference in New Issue
Block a user