From a937977ca2052921901e0a1640be78fcbd38b00d Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Mon, 27 Dec 2021 23:22:37 +0200 Subject: [PATCH] WIP --- sanic/app.py | 3 +- sanic/cli/app.py | 3 + sanic/cli/arguments.py | 24 +++ sanic/config.py | 8 +- sanic/constants.py | 2 + sanic/http/http3.py | 192 +++++++++++++++-------- sanic/request.py | 4 + sanic/response.py | 4 +- sanic/server/protocols/http_protocol.py | 6 +- sanic/server/runners.py | 31 +++- sanic/tls.py | 196 ------------------------ 11 files changed, 206 insertions(+), 267 deletions(-) delete mode 100644 sanic/tls.py diff --git a/sanic/app.py b/sanic/app.py index cf80912d..55c4920e 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -76,6 +76,7 @@ from sanic.handlers import ErrorHandler from sanic.helpers import _default from sanic.http import Stage from sanic.http.constants import HTTP +from sanic.http.tls import process_to_context from sanic.log import ( LOGGING_CONFIG_DEFAULTS, Colors, @@ -104,7 +105,6 @@ from sanic.server import serve, serve_multiple, serve_single, try_use_uvloop from sanic.server.protocols.websocket_protocol import WebSocketProtocol from sanic.server.websockets.impl import ConnectionClosed from sanic.signals import Signal, SignalRouter -from sanic.tls import process_to_context from sanic.touchup import TouchUp, TouchUpMeta @@ -1216,6 +1216,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): finally: self.is_running = False logger.info("Server Stopped") + print("END OF RUN") def stop(self): """ diff --git a/sanic/cli/app.py b/sanic/cli/app.py index 3001b6e1..9c5e55d2 100644 --- a/sanic/cli/app.py +++ b/sanic/cli/app.py @@ -11,6 +11,7 @@ from typing import Any, List, Union from sanic.app import Sanic from sanic.application.logo import get_logo from sanic.cli.arguments import Group +from sanic.http.constants import HTTP from sanic.log import error_logger from sanic.simple import create_simple_server @@ -160,6 +161,7 @@ Or, a path to a directory to run as a simple HTTP server: elif len(ssl) == 1 and ssl[0] is not None: # Use only one cert, no TLSSelector. ssl = ssl[0] + version = HTTP(self.args.http) kwargs = { "access_log": self.args.access_log, "debug": self.args.debug, @@ -172,6 +174,7 @@ Or, a path to a directory to run as a simple HTTP server: "unix": self.args.unix, "verbosity": self.args.verbosity or 0, "workers": self.args.workers, + "version": version, } if self.args.auto_reload: diff --git a/sanic/cli/arguments.py b/sanic/cli/arguments.py index 20644bdc..05e228f4 100644 --- a/sanic/cli/arguments.py +++ b/sanic/cli/arguments.py @@ -83,6 +83,30 @@ class ApplicationGroup(Group): ) +class HTTPVersionGroup(Group): + name = "HTTP version" + + def attach(self): + group = self.container.add_mutually_exclusive_group() + group.add_argument( + "--http", + dest="http", + type=int, + default=1, + help=( + "Which HTTP version to use: HTTP/1.1 or HTTP/3. Value should " + "be either 1 or 3 [default 1]" + ), + ) + group.add_argument( + "-3", + dest="http", + action="store_const", + const=3, + help=("Run Sanic server using HTTP/3"), + ) + + class SocketGroup(Group): name = "Socket binding" diff --git a/sanic/config.py b/sanic/config.py index 30c8627f..c99d9abe 100644 --- a/sanic/config.py +++ b/sanic/config.py @@ -26,6 +26,9 @@ DEFAULT_CONFIG = { "GRACEFUL_SHUTDOWN_TIMEOUT": 15.0, # 15 sec "KEEP_ALIVE_TIMEOUT": 5, # 5 seconds "KEEP_ALIVE": True, + "LOCAL_TLS_KEY": _default, + "LOCAL_TLS_CERT": _default, + "LOCALHOST": "localhost", "MOTD": True, "MOTD_DISPLAY": {}, "NOISY_EXCEPTIONS": False, @@ -68,9 +71,12 @@ class Config(dict, metaclass=DescriptorMeta): GRACEFUL_SHUTDOWN_TIMEOUT: float KEEP_ALIVE_TIMEOUT: int KEEP_ALIVE: bool - NOISY_EXCEPTIONS: bool + LOCAL_TLS_KEY: Union[Path, str, Default] + LOCAL_TLS_CERT: Union[Path, str, Default] + LOCALHOST: str MOTD: bool MOTD_DISPLAY: Dict[str, str] + NOISY_EXCEPTIONS: bool PROXIES_COUNT: Optional[int] REAL_IP_HEADER: Optional[str] REGISTER: bool diff --git a/sanic/constants.py b/sanic/constants.py index 80f1d2a9..0bc68f26 100644 --- a/sanic/constants.py +++ b/sanic/constants.py @@ -26,3 +26,5 @@ class HTTPMethod(str, Enum): HTTP_METHODS = tuple(HTTPMethod.__members__.values()) DEFAULT_HTTP_CONTENT_TYPE = "application/octet-stream" +DEFAULT_LOCAL_TLS_KEY = "key.pem" +DEFAULT_LOCAL_TLS_CERT = "cert.pem" diff --git a/sanic/http/http3.py b/sanic/http/http3.py index cc82a51d..ce0d0797 100644 --- a/sanic/http/http3.py +++ b/sanic/http/http3.py @@ -1,18 +1,21 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Dict, Optional, Union +import asyncio + +from abc import ABC +from ast import Mod +from ssl import SSLContext +from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union from aioquic.h0.connection import H0_ALPN, H0Connection from aioquic.h3.connection import H3_ALPN, H3Connection -from aioquic.h3.events import H3Event - -# from aioquic.h3.events import ( -# DatagramReceived, -# DataReceived, -# H3Event, -# HeadersReceived, -# WebTransportStreamDataReceived, -# ) +from aioquic.h3.events import ( + DatagramReceived, + DataReceived, + H3Event, + HeadersReceived, + WebTransportStreamDataReceived, +) from aioquic.quic.configuration import QuicConfiguration # from aioquic.quic.events import ( @@ -22,89 +25,144 @@ from aioquic.quic.configuration import QuicConfiguration # ) from aioquic.tls import SessionTicket +from sanic.compat import Header +from sanic.exceptions import SanicException +from sanic.http.tls import CertSimple + if TYPE_CHECKING: + from sanic import Sanic from sanic.request import Request + from sanic.response import BaseHTTPResponse + from sanic.server.protocols.http_protocol import Http3Protocol # from sanic.compat import Header +# from sanic.application.state import Mode from sanic.log import logger -from sanic.response import BaseHTTPResponse HttpConnection = Union[H0Connection, H3Connection] -async def handler(request: Request): - logger.info(f"Request received: {request}") - response = await request.app.handle_request(request) - logger.info(f"Build response: {response=}") - - class Transport: ... +class Receiver(ABC): + def __init__(self, transmit, protocol, request) -> None: + self.transmit = transmit + self.protocol = protocol + self.request = request + + +class HTTPReceiver(Receiver): + async def respond(self): + logger.info(f"Request received: {self.request}") + await self.protocol.app.handle_request(self.request) + + +class WebsocketReceiver(Receiver): + ... + + +class WebTransportReceiver(Receiver): + ... + + class Http3: + HANDLER_PROPERTY_MAPPING = { + DataReceived: "stream_id", + HeadersReceived: "stream_id", + DatagramReceived: "flow_id", + WebTransportStreamDataReceived: "session_id", + } + def __init__( self, - connection: HttpConnection, + protocol: Http3Protocol, transmit: Callable[[], None], ) -> None: self.request_body = None - self.connection = connection + self.request: Optional[Request] = None + self.protocol = protocol self.transmit = transmit + self.receivers: Dict[int, Receiver] = {} def http_event_received(self, event: H3Event) -> None: print("[http_event_received]:", event) - # if isinstance(event, HeadersReceived): - # method, path, *rem = event.headers - # headers = Header(((k.decode(), v.decode()) for k, v in rem)) - # method = method[1].decode() - # path = path[1] - # scheme = headers.pop(":scheme") - # authority = headers.pop(":authority") - # print(f"{headers=}") - # print(f"{method=}") - # print(f"{path=}") - # print(f"{scheme=}") - # print(f"{authority=}") - # if authority: - # headers["host"] = authority - - # request = Request( - # path, headers, "3", method, Transport(), app, b"" - # ) - # request.stream = Stream( - # connection=self._http, transmit=self.transmit - # ) - # print(f"{request=}") + receiver = self.get_or_make_receiver(event) + print(f"{receiver=}") # asyncio.ensure_future(handler(request)) + def get_or_make_receiver(self, event: H3Event) -> Receiver: + if ( + isinstance(event, HeadersReceived) + and event.stream_id not in self.receivers + ): + self.request = self._make_request(event) + receiver = HTTPReceiver(self.transmit, self.protocol, self.request) + self.receivers[event.stream_id] = receiver + asyncio.ensure_future(receiver.respond()) + else: + ident = getattr(event, self.HANDLER_PROPERTY_MAPPING[type(event)]) + return self.receivers[ident] + + def _make_request(self, event: HeadersReceived) -> Request: + method, path, *rem = event.headers + headers = Header(((k.decode(), v.decode()) for k, v in rem)) + method = method[1].decode() + path = path[1] + scheme = headers.pop(":scheme") + authority = headers.pop(":authority") + print(f"{headers=}") + print(f"{method=}") + print(f"{path=}") + print(f"{scheme=}") + print(f"{authority=}") + if authority: + headers["host"] = authority + + request = self.protocol.request_class( + path, headers, "3", method, Transport(), self.protocol.app, b"" + ) + request.stream = self + print(f"{request=}") + return request + + async def respond(self, response: BaseHTTPResponse) -> BaseHTTPResponse: + print(f"[respond]: {response=}") + response.headers.update({"foo": "bar"}) + self.response, response.stream = response, self + + # Need more appropriate place to send these + self.protocol.connection.send_headers( + stream_id=0, + headers=[ + (b":status", str(self.response.status).encode()), + *( + (k.encode(), v.encode()) + for k, v in self.response.headers.items() + ), + ], + ) + # TEMP + await self.drain(response) + + return response + + async def drain(self, response: BaseHTTPResponse) -> None: + await self.send(response.body, False) + async def send(self, data: bytes, end_stream: bool) -> None: print(f"[send]: {data=} {end_stream=}") print(self.response.headers) - # self.connection.send_headers( - # stream_id=0, - # headers=[ - # (b":status", str(self.response.status).encode()), - # *( - # (k.encode(), v.encode()) - # for k, v in self.response.headers.items() - # ), - # ], - # ) - # self.connection.send_data( - # stream_id=0, - # data=data, - # end_stream=end_stream, - # ) - # self.transmit() - - def respond(self, response: BaseHTTPResponse) -> BaseHTTPResponse: - print(f"[respond]: {response=}") - self.response, response.stream = response, self - return response + self.protocol.connection.send_data( + stream_id=0, + data=data, + end_stream=end_stream, + ) + self.transmit() class SessionTicketStore: @@ -122,13 +180,19 @@ class SessionTicketStore: return self.tickets.pop(label, None) -def get_config(): +def get_config(app: Sanic, ssl: SSLContext): + if not isinstance(ssl, CertSimple): + raise SanicException("SSLContext is not CertSimple") + config = QuicConfiguration( alpn_protocols=H3_ALPN + H0_ALPN + ["siduck"], is_client=False, max_datagram_frame_size=65536, ) - config.load_cert_chain("./cert.pem", "./key.pem", password="qqqqqqqq") + # TODO: + # - add password kwarg, read from config.TLS_CERT_PASSWORD + config.load_cert_chain(ssl.sanic["cert"], ssl.sanic["key"]) + return config diff --git a/sanic/request.py b/sanic/request.py index 97ab9982..ee535ac2 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -1,5 +1,6 @@ from __future__ import annotations +from inspect import isawaitable from typing import ( TYPE_CHECKING, Any, @@ -205,6 +206,9 @@ class Request: # Connect the response if isinstance(response, BaseHTTPResponse) and self.stream: response = self.stream.respond(response) + + if isawaitable(response): + response = await response # Run response middleware try: response = await self.app._run_response_middleware( diff --git a/sanic/response.py b/sanic/response.py index 8525d381..cd3bd952 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -25,6 +25,7 @@ from sanic.cookies import CookieJar from sanic.exceptions import SanicException, ServerError from sanic.helpers import has_message_body, remove_entity_headers from sanic.http import Http +from sanic.http.http3 import Http3 from sanic.models.protocol_types import HTMLProtocol, Range @@ -56,7 +57,7 @@ class BaseHTTPResponse: self.asgi: bool = False self.body: Optional[bytes] = None self.content_type: Optional[str] = None - self.stream: Optional[Union[Http, ASGIApp]] = None + self.stream: Optional[Union[Http, ASGIApp, Http3]] = None self.status: int = None self.headers = Header({}) self._cookies: Optional[CookieJar] = None @@ -121,6 +122,7 @@ class BaseHTTPResponse: :param data: str or bytes to be written :param end_stream: whether to close the stream after this block """ + print(f">>> BaseHTTPResponse: {data=} {end_stream=} {self.body=}") if data is None and end_stream is None: end_stream = True if self.stream is None: diff --git a/sanic/server/protocols/http_protocol.py b/sanic/server/protocols/http_protocol.py index 551603e0..df2ed699 100644 --- a/sanic/server/protocols/http_protocol.py +++ b/sanic/server/protocols/http_protocol.py @@ -262,7 +262,7 @@ class Http3Protocol(HttpProtocolMixin, QuicConnectionProtocol): self.app = app super().__init__(*args, **kwargs) self._setup() - self._connection = None + self._connection: Optional[H3Connection] = None def quic_event_received(self, event: QuicEvent) -> None: print("[quic_event_received]:", event) @@ -282,3 +282,7 @@ class Http3Protocol(HttpProtocolMixin, QuicConnectionProtocol): if self._connection is not None: for http_event in self._connection.handle_event(event): self._http.http_event_received(http_event) + + @property + def connection(self) -> Optional[H3Connection]: + return self._connection diff --git a/sanic/server/runners.py b/sanic/server/runners.py index aed22ffe..b5a9d9c2 100644 --- a/sanic/server/runners.py +++ b/sanic/server/runners.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Dict, Optional, Type, Union from sanic.config import Config from sanic.http.constants import HTTP +from sanic.http.tls import get_ssl_context from sanic.server.events import trigger_events @@ -94,7 +95,7 @@ def serve( app.asgi = False if version is HTTP.VERSION_3: - return serve_http_3(host, port, app, loop) + return serve_http_3(host, port, app, loop, ssl) connections = connections if connections is not None else set() protocol_kwargs = _build_protocol_kwargs(protocol, app.config) @@ -200,10 +201,19 @@ def serve( remove_unix_socket(unix) -def serve_http_3(host, port, app, loop): +def serve_http_3( + host, + port, + app, + loop, + ssl, + register_sys_signals: bool = True, + run_multiple: bool = False, +): protocol = partial(Http3Protocol, app=app) ticket_store = get_ticket_store() - config = get_config() + ssl_context = get_ssl_context(app, ssl) + config = get_config(app, ssl_context) coro = quic_serve( host, port, @@ -214,6 +224,21 @@ def serve_http_3(host, port, app, loop): ) server = AsyncioServer(app, loop, coro, []) loop.run_until_complete(server.startup()) + + # TODO: Cleanup the non-DRY code block + # Ignore SIGINT when run_multiple + if run_multiple: + signal_func(SIGINT, SIG_IGN) + os.environ["SANIC_WORKER_PROCESS"] = "true" + + # Register signals for graceful termination + if register_sys_signals: + if OS_IS_WINDOWS: + ctrlc_workaround_for_windows(app) + else: + for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]: + loop.add_signal_handler(_signal, app.stop) + loop.run_until_complete(server.before_start()) loop.run_until_complete(server) loop.run_until_complete(server.after_start()) diff --git a/sanic/tls.py b/sanic/tls.py deleted file mode 100644 index be30f4a2..00000000 --- a/sanic/tls.py +++ /dev/null @@ -1,196 +0,0 @@ -import os -import ssl - -from typing import Iterable, Optional, Union - -from sanic.log import logger - - -# Only allow secure ciphers, notably leaving out AES-CBC mode -# OpenSSL chooses ECDSA or RSA depending on the cert in use -CIPHERS_TLS12 = [ - "ECDHE-ECDSA-CHACHA20-POLY1305", - "ECDHE-ECDSA-AES256-GCM-SHA384", - "ECDHE-ECDSA-AES128-GCM-SHA256", - "ECDHE-RSA-CHACHA20-POLY1305", - "ECDHE-RSA-AES256-GCM-SHA384", - "ECDHE-RSA-AES128-GCM-SHA256", -] - - -def create_context( - certfile: Optional[str] = None, - keyfile: Optional[str] = None, - password: Optional[str] = None, -) -> ssl.SSLContext: - """Create a context with secure crypto and HTTP/1.1 in protocols.""" - context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH) - context.minimum_version = ssl.TLSVersion.TLSv1_2 - context.set_ciphers(":".join(CIPHERS_TLS12)) - context.set_alpn_protocols(["http/1.1"]) - context.sni_callback = server_name_callback - if certfile and keyfile: - context.load_cert_chain(certfile, keyfile, password) - return context - - -def shorthand_to_ctx( - ctxdef: Union[None, ssl.SSLContext, dict, str] -) -> Optional[ssl.SSLContext]: - """Convert an ssl argument shorthand to an SSLContext object.""" - if ctxdef is None or isinstance(ctxdef, ssl.SSLContext): - return ctxdef - if isinstance(ctxdef, str): - return load_cert_dir(ctxdef) - if isinstance(ctxdef, dict): - return CertSimple(**ctxdef) - raise ValueError( - f"Invalid ssl argument {type(ctxdef)}." - " Expecting a list of certdirs, a dict or an SSLContext." - ) - - -def process_to_context( - ssldef: Union[None, ssl.SSLContext, dict, str, list, tuple] -) -> Optional[ssl.SSLContext]: - """Process app.run ssl argument from easy formats to full SSLContext.""" - return ( - CertSelector(map(shorthand_to_ctx, ssldef)) - if isinstance(ssldef, (list, tuple)) - else shorthand_to_ctx(ssldef) - ) - - -def load_cert_dir(p: str) -> ssl.SSLContext: - if os.path.isfile(p): - raise ValueError(f"Certificate folder expected but {p} is a file.") - keyfile = os.path.join(p, "privkey.pem") - certfile = os.path.join(p, "fullchain.pem") - if not os.access(keyfile, os.R_OK): - raise ValueError( - f"Certificate not found or permission denied {keyfile}" - ) - if not os.access(certfile, os.R_OK): - raise ValueError( - f"Certificate not found or permission denied {certfile}" - ) - return CertSimple(certfile, keyfile) - - -class CertSimple(ssl.SSLContext): - """A wrapper for creating SSLContext with a sanic attribute.""" - - def __new__(cls, cert, key, **kw): - # try common aliases, rename to cert/key - certfile = kw["cert"] = kw.pop("certificate", None) or cert - keyfile = kw["key"] = kw.pop("keyfile", None) or key - password = kw.pop("password", None) - if not certfile or not keyfile: - raise ValueError("SSL dict needs filenames for cert and key.") - subject = {} - if "names" not in kw: - cert = ssl._ssl._test_decode_cert(certfile) # type: ignore - kw["names"] = [ - name - for t, name in cert["subjectAltName"] - if t in ["DNS", "IP Address"] - ] - subject = {k: v for item in cert["subject"] for k, v in item} - self = create_context(certfile, keyfile, password) - self.__class__ = cls - self.sanic = {**subject, **kw} - return self - - def __init__(self, cert, key, **kw): - pass # Do not call super().__init__ because it is already initialized - - -class CertSelector(ssl.SSLContext): - """Automatically select SSL certificate based on the hostname that the - client is trying to access, via SSL SNI. Paths to certificate folders - with privkey.pem and fullchain.pem in them should be provided, and - will be matched in the order given whenever there is a new connection. - """ - - def __new__(cls, ctxs): - return super().__new__(cls) - - def __init__(self, ctxs: Iterable[Optional[ssl.SSLContext]]): - super().__init__() - self.sni_callback = selector_sni_callback # type: ignore - self.sanic_select = [] - self.sanic_fallback = None - all_names = [] - for i, ctx in enumerate(ctxs): - if not ctx: - continue - names = dict(getattr(ctx, "sanic", {})).get("names", []) - all_names += names - self.sanic_select.append(ctx) - if i == 0: - self.sanic_fallback = ctx - if not all_names: - raise ValueError( - "No certificates with SubjectAlternativeNames found." - ) - logger.info(f"Certificate vhosts: {', '.join(all_names)}") - - -def find_cert(self: CertSelector, server_name: str): - """Find the first certificate that matches the given SNI. - - :raises ssl.CertificateError: No matching certificate found. - :return: A matching ssl.SSLContext object if found.""" - if not server_name: - if self.sanic_fallback: - return self.sanic_fallback - raise ValueError( - "The client provided no SNI to match for certificate." - ) - for ctx in self.sanic_select: - if match_hostname(ctx, server_name): - return ctx - if self.sanic_fallback: - return self.sanic_fallback - raise ValueError(f"No certificate found matching hostname {server_name!r}") - - -def match_hostname( - ctx: Union[ssl.SSLContext, CertSelector], hostname: str -) -> bool: - """Match names from CertSelector against a received hostname.""" - # Local certs are considered trusted, so this can be less pedantic - # and thus faster than the deprecated ssl.match_hostname function is. - names = dict(getattr(ctx, "sanic", {})).get("names", []) - hostname = hostname.lower() - for name in names: - if name.startswith("*."): - if hostname.split(".", 1)[-1] == name[2:]: - return True - elif name == hostname: - return True - return False - - -def selector_sni_callback( - sslobj: ssl.SSLObject, server_name: str, ctx: CertSelector -) -> Optional[int]: - """Select a certificate matching the SNI.""" - # Call server_name_callback to store the SNI on sslobj - server_name_callback(sslobj, server_name, ctx) - # Find a new context matching the hostname - try: - sslobj.context = find_cert(ctx, server_name) - except ValueError as e: - logger.warning(f"Rejecting TLS connection: {e}") - # This would show ERR_SSL_UNRECOGNIZED_NAME_ALERT on client side if - # asyncio/uvloop did proper SSL shutdown. They don't. - return ssl.ALERT_DESCRIPTION_UNRECOGNIZED_NAME - return None # mypy complains without explicit return - - -def server_name_callback( - sslobj: ssl.SSLObject, server_name: str, ctx: ssl.SSLContext -) -> None: - """Store the received SNI as sslobj.sanic_server_name.""" - sslobj.sanic_server_name = server_name # type: ignore