Merge pull request #2380 from sanic-org/http3-startup

This commit is contained in:
Adam Hopkins 2022-02-22 09:16:28 +02:00 committed by GitHub
commit e1937149c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 544 additions and 106 deletions

View File

@ -0,0 +1,88 @@
import os
import sys
import time
from contextlib import contextmanager
from curses.ascii import SP
from queue import Queue
from threading import Thread
if os.name == "nt":
import ctypes
import msvcrt
class _CursorInfo(ctypes.Structure):
_fields_ = [("size", ctypes.c_int), ("visible", ctypes.c_byte)]
class Spinner:
def __init__(self, message: str) -> None:
self.message = message
self.queue: Queue[int] = Queue()
self.spinner = self.cursor()
self.thread = Thread(target=self.run)
def start(self):
self.queue.put(1)
self.thread.start()
self.hide()
def run(self):
while self.queue.get():
output = f"\r{self.message} [{next(self.spinner)}]"
sys.stdout.write(output)
sys.stdout.flush()
time.sleep(0.1)
self.queue.put(1)
def stop(self):
self.queue.put(0)
self.thread.join()
self.show()
@staticmethod
def cursor():
while True:
for cursor in "|/-\\":
yield cursor
@staticmethod
def hide():
if os.name == "nt":
ci = _CursorInfo()
handle = ctypes.windll.kernel32.GetStdHandle(-11)
ctypes.windll.kernel32.GetConsoleCursorInfo(
handle, ctypes.byref(ci)
)
ci.visible = False
ctypes.windll.kernel32.SetConsoleCursorInfo(
handle, ctypes.byref(ci)
)
elif os.name == "posix":
sys.stdout.write("\033[?25l")
sys.stdout.flush()
@staticmethod
def show():
if os.name == "nt":
ci = _CursorInfo()
handle = ctypes.windll.kernel32.GetStdHandle(-11)
ctypes.windll.kernel32.GetConsoleCursorInfo(
handle, ctypes.byref(ci)
)
ci.visible = True
ctypes.windll.kernel32.SetConsoleCursorInfo(
handle, ctypes.byref(ci)
)
elif os.name == "posix":
sys.stdout.write("\033[?25h")
sys.stdout.flush()
@contextmanager
def loading(message: str = "Loading"):
spinner = Spinner(message)
spinner.start()
yield
spinner.stop()

View File

@ -1,14 +1,15 @@
from __future__ import annotations
import warnings import warnings
from typing import Optional from typing import TYPE_CHECKING, Optional
from urllib.parse import quote from urllib.parse import quote
import sanic.app # noqa
from sanic.compat import Header from sanic.compat import Header
from sanic.exceptions import ServerError from sanic.exceptions import ServerError
from sanic.helpers import _default from sanic.helpers import _default
from sanic.http import Stage from sanic.http import Stage
from sanic.log import logger
from sanic.models.asgi import ASGIReceive, ASGIScope, ASGISend, MockTransport from sanic.models.asgi import ASGIReceive, ASGIScope, ASGISend, MockTransport
from sanic.request import Request from sanic.request import Request
from sanic.response import BaseHTTPResponse from sanic.response import BaseHTTPResponse
@ -16,30 +17,35 @@ from sanic.server import ConnInfo
from sanic.server.websockets.connection import WebSocketConnection from sanic.server.websockets.connection import WebSocketConnection
if TYPE_CHECKING: # no cov
from sanic import Sanic
class Lifespan: class Lifespan:
def __init__(self, asgi_app: "ASGIApp") -> None: def __init__(self, asgi_app: ASGIApp) -> None:
self.asgi_app = asgi_app self.asgi_app = asgi_app
if ( if self.asgi_app.sanic_app.state.verbosity > 0:
"server.init.before" if (
in self.asgi_app.sanic_app.signal_router.name_index "server.init.before"
): in self.asgi_app.sanic_app.signal_router.name_index
warnings.warn( ):
'You have set a listener for "before_server_start" ' logger.debug(
"in ASGI mode. " 'You have set a listener for "before_server_start" '
"It will be executed as early as possible, but not before " "in ASGI mode. "
"the ASGI server is started." "It will be executed as early as possible, but not before "
) "the ASGI server is started."
if ( )
"server.shutdown.after" if (
in self.asgi_app.sanic_app.signal_router.name_index "server.shutdown.after"
): in self.asgi_app.sanic_app.signal_router.name_index
warnings.warn( ):
'You have set a listener for "after_server_stop" ' logger.debug(
"in ASGI mode. " 'You have set a listener for "after_server_stop" '
"It will be executed as late as possible, but not after " "in ASGI mode. "
"the ASGI server is stopped." "It will be executed as late as possible, but not after "
) "the ASGI server is stopped."
)
async def startup(self) -> None: async def startup(self) -> None:
""" """
@ -88,7 +94,7 @@ class Lifespan:
class ASGIApp: class ASGIApp:
sanic_app: "sanic.app.Sanic" sanic_app: Sanic
request: Request request: Request
transport: MockTransport transport: MockTransport
lifespan: Lifespan lifespan: Lifespan

View File

@ -11,7 +11,6 @@ from typing import Any, List, Union
from sanic.app import Sanic from sanic.app import Sanic
from sanic.application.logo import get_logo from sanic.application.logo import get_logo
from sanic.cli.arguments import Group from sanic.cli.arguments import Group
from sanic.http.constants import HTTP
from sanic.log import error_logger from sanic.log import error_logger
from sanic.simple import create_simple_server from sanic.simple import create_simple_server
@ -59,10 +58,13 @@ Or, a path to a directory to run as a simple HTTP server:
os.environ.get("SANIC_RELOADER_PROCESS", "") != "true" os.environ.get("SANIC_RELOADER_PROCESS", "") != "true"
) )
self.args: List[Any] = [] self.args: List[Any] = []
self.groups: List[Group] = []
def attach(self): def attach(self):
for group in Group._registry: for group in Group._registry:
group.create(self.parser).attach() instance = group.create(self.parser)
instance.attach()
self.groups.append(instance)
def run(self): def run(self):
# This is to provide backwards compat -v to display version # This is to provide backwards compat -v to display version
@ -75,9 +77,13 @@ Or, a path to a directory to run as a simple HTTP server:
try: try:
app = self._get_app() app = self._get_app()
kwargs = self._build_run_kwargs() kwargs = self._build_run_kwargs()
app.run(**kwargs)
except ValueError: except ValueError:
error_logger.exception("Failed to run app") error_logger.exception("Failed to run app")
else:
for http_version in self.args.http:
app.prepare(**kwargs, version=http_version)
Sanic.serve()
def _precheck(self): def _precheck(self):
# # Custom TLS mismatch handling for better diagnostics # # Custom TLS mismatch handling for better diagnostics
@ -137,11 +143,14 @@ Or, a path to a directory to run as a simple HTTP server:
" Example File: project/sanic_server.py -> app\n" " Example File: project/sanic_server.py -> app\n"
" Example Module: project.sanic_server.app" " Example Module: project.sanic_server.app"
) )
sys.exit(1)
else: else:
raise e raise e
return app return app
def _build_run_kwargs(self): def _build_run_kwargs(self):
for group in self.groups:
group.prepare(self.args)
ssl: Union[None, dict, str, list] = [] ssl: Union[None, dict, str, list] = []
if self.args.tlshost: if self.args.tlshost:
ssl.append(None) ssl.append(None)
@ -154,7 +163,6 @@ Or, a path to a directory to run as a simple HTTP server:
elif len(ssl) == 1 and ssl[0] is not None: elif len(ssl) == 1 and ssl[0] is not None:
# Use only one cert, no TLSSelector. # Use only one cert, no TLSSelector.
ssl = ssl[0] ssl = ssl[0]
version = HTTP(self.args.http)
kwargs = { kwargs = {
"access_log": self.args.access_log, "access_log": self.args.access_log,
"debug": self.args.debug, "debug": self.args.debug,
@ -167,7 +175,7 @@ Or, a path to a directory to run as a simple HTTP server:
"unix": self.args.unix, "unix": self.args.unix,
"verbosity": self.args.verbosity or 0, "verbosity": self.args.verbosity or 0,
"workers": self.args.workers, "workers": self.args.workers,
"version": version, "auto_cert": self.args.auto_cert,
} }
for maybe_arg in ("auto_reload", "dev"): for maybe_arg in ("auto_reload", "dev"):
@ -177,4 +185,5 @@ Or, a path to a directory to run as a simple HTTP server:
if self.args.path: if self.args.path:
kwargs["auto_reload"] = True kwargs["auto_reload"] = True
kwargs["reload_dir"] = self.args.path kwargs["reload_dir"] = self.args.path
return kwargs return kwargs

View File

@ -6,6 +6,7 @@ from typing import List, Optional, Type, Union
from sanic_routing import __version__ as __routing_version__ # type: ignore from sanic_routing import __version__ as __routing_version__ # type: ignore
from sanic import __version__ from sanic import __version__
from sanic.http.constants import HTTP
class Group: class Group:
@ -38,6 +39,9 @@ class Group:
"--no-" + args[0][2:], *args[1:], action="store_false", **kwargs "--no-" + args[0][2:], *args[1:], action="store_false", **kwargs
) )
def prepare(self, args) -> None:
...
class GeneralGroup(Group): class GeneralGroup(Group):
name = None name = None
@ -87,25 +91,39 @@ class HTTPVersionGroup(Group):
name = "HTTP version" name = "HTTP version"
def attach(self): def attach(self):
group = self.container.add_mutually_exclusive_group() http_values = [http.value for http in HTTP.__members__.values()]
group.add_argument(
self.container.add_argument(
"--http", "--http",
dest="http", dest="http",
action="append",
choices=http_values,
type=int, type=int,
default=1,
help=( help=(
"Which HTTP version to use: HTTP/1.1 or HTTP/3. Value should " "Which HTTP version to use: HTTP/1.1 or HTTP/3. Value should\n"
"be either 1 or 3 [default 1]" "be either 1, or 3. [default 1]"
), ),
) )
group.add_argument( self.container.add_argument(
"-1",
dest="http",
action="append_const",
const=1,
help=("Run Sanic server using HTTP/1.1"),
)
self.container.add_argument(
"-3", "-3",
dest="http", dest="http",
action="store_const", action="append_const",
const=3, const=3,
help=("Run Sanic server using HTTP/3"), help=("Run Sanic server using HTTP/3"),
) )
def prepare(self, args):
if not args.http:
args.http = [1]
args.http = tuple(sorted(set(map(HTTP, args.http)), reverse=True))
class SocketGroup(Group): class SocketGroup(Group):
name = "Socket binding" name = "Socket binding"
@ -116,7 +134,6 @@ class SocketGroup(Group):
"--host", "--host",
dest="host", dest="host",
type=str, type=str,
default="127.0.0.1",
help="Host address [default 127.0.0.1]", help="Host address [default 127.0.0.1]",
) )
self.container.add_argument( self.container.add_argument(
@ -124,7 +141,6 @@ class SocketGroup(Group):
"--port", "--port",
dest="port", dest="port",
type=int, type=int,
default=8000,
help="Port to serve on [default 8000]", help="Port to serve on [default 8000]",
) )
self.container.add_argument( self.container.add_argument(
@ -233,7 +249,16 @@ class DevelopmentGroup(Group):
"--dev", "--dev",
dest="dev", dest="dev",
action="store_true", action="store_true",
help=("debug + auto reload."), help=("debug + auto reload"),
)
self.container.add_argument(
"--auto-cert",
dest="auto_cert",
action="store_true",
help=(
"Create a temporary TLS certificate for local development "
"(requires mkcert)"
),
) )

View File

@ -1,4 +1,4 @@
from enum import Enum from enum import Enum, IntEnum
class Stage(Enum): class Stage(Enum):
@ -20,6 +20,12 @@ class Stage(Enum):
FAILED = 100 # Unrecoverable state (error while sending response) FAILED = 100 # Unrecoverable state (error while sending response)
class HTTP(Enum): class HTTP(IntEnum):
VERSION_1 = 1 VERSION_1 = 1
VERSION_3 = 3 VERSION_3 = 3
def display(self) -> str:
value = str(self.value)
if value == 1:
value = "1.1"
return f"HTTP/{value}"

View File

@ -333,6 +333,12 @@ class Http(metaclass=TouchUpMeta):
self.response_func = self.head_response_ignored self.response_func = self.head_response_ignored
headers["connection"] = "keep-alive" if self.keep_alive else "close" headers["connection"] = "keep-alive" if self.keep_alive else "close"
# This header may be removed or modified by the AltSvcCheck Touchup
# service. At server start, we either remove this header from ever
# being assigned, or we change the value as required.
headers["alt-svc"] = ""
ret = format_http1_response(status, res.processed_headers) ret = format_http1_response(status, res.processed_headers)
if data: if data:
ret += data ret += data

View File

@ -3,15 +3,16 @@ from __future__ import annotations
import os import os
import ssl import ssl
import subprocess import subprocess
import sys
from contextlib import suppress from contextlib import suppress
from inspect import currentframe, getframeinfo
from pathlib import Path from pathlib import Path
from ssl import SSLContext from ssl import SSLContext
from tempfile import mkdtemp from tempfile import mkdtemp
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
from sanic.application.constants import Mode from sanic.application.constants import Mode
from sanic.application.spinner import loading
from sanic.constants import DEFAULT_LOCAL_TLS_CERT, DEFAULT_LOCAL_TLS_KEY from sanic.constants import DEFAULT_LOCAL_TLS_CERT, DEFAULT_LOCAL_TLS_KEY
from sanic.exceptions import SanicException from sanic.exceptions import SanicException
from sanic.helpers import Default from sanic.helpers import Default
@ -233,7 +234,7 @@ def get_ssl_context(app: Sanic, ssl: Optional[SSLContext]) -> SSLContext:
if app.state.mode is Mode.PRODUCTION: if app.state.mode is Mode.PRODUCTION:
raise SanicException( raise SanicException(
"Cannot run Sanic as an HTTP/3 server in PRODUCTION mode " "Cannot run Sanic as an HTTPS server in PRODUCTION mode "
"without passing a TLS certificate. If you are developing " "without passing a TLS certificate. If you are developing "
"locally, please enable DEVELOPMENT mode and Sanic will " "locally, please enable DEVELOPMENT mode and Sanic will "
"generate a localhost TLS certificate. For more information " "generate a localhost TLS certificate. For more information "
@ -283,15 +284,32 @@ def generate_local_certificate(
): ):
check_mkcert() check_mkcert()
cmd = [ if not key_path.parent.exists() or not cert_path.parent.exists():
"mkcert", raise SanicException(
"-key-file", f"Cannot generate certificate at [{key_path}, {cert_path}]. One "
str(key_path), "or more of the directories does not exist."
"-cert-file", )
str(cert_path),
localhost, message = "Generating TLS certificate"
] with loading(message):
subprocess.run(cmd, check=True) cmd = [
"mkcert",
"-key-file",
str(key_path),
"-cert-file",
str(cert_path),
localhost,
]
resp = subprocess.run(
cmd,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
sys.stdout.write("\r" + " " * (len(message) + 4))
sys.stdout.flush()
sys.stdout.write(resp.stdout)
def check_mkcert(): def check_mkcert():

View File

@ -26,8 +26,10 @@ from typing import (
Literal, Literal,
Optional, Optional,
Set, Set,
Tuple,
Type, Type,
Union, Union,
cast,
) )
from sanic import reloader_helpers from sanic import reloader_helpers
@ -38,8 +40,8 @@ from sanic.base.meta import SanicMeta
from sanic.compat import OS_IS_WINDOWS from sanic.compat import OS_IS_WINDOWS
from sanic.helpers import _default from sanic.helpers import _default
from sanic.http.constants import HTTP from sanic.http.constants import HTTP
from sanic.http.tls import process_to_context from sanic.http.tls import get_ssl_context, process_to_context
from sanic.log import Colors, error_logger, logger from sanic.log import Colors, deprecation, error_logger, logger
from sanic.models.handler_types import ListenerType from sanic.models.handler_types import ListenerType
from sanic.server import Signal as ServerSignal from sanic.server import Signal as ServerSignal
from sanic.server import try_use_uvloop from sanic.server import try_use_uvloop
@ -93,6 +95,7 @@ class RunnerMixin(metaclass=SanicMeta):
fast: bool = False, fast: bool = False,
verbosity: int = 0, verbosity: int = 0,
motd_display: Optional[Dict[str, str]] = None, motd_display: Optional[Dict[str, str]] = None,
auto_cert: bool = False,
) -> None: ) -> None:
""" """
Run the HTTP Server and listen until keyboard interrupt or term Run the HTTP Server and listen until keyboard interrupt or term
@ -152,6 +155,7 @@ class RunnerMixin(metaclass=SanicMeta):
fast=fast, fast=fast,
verbosity=verbosity, verbosity=verbosity,
motd_display=motd_display, motd_display=motd_display,
auto_cert=auto_cert,
) )
self.__class__.serve(primary=self) # type: ignore self.__class__.serve(primary=self) # type: ignore
@ -180,7 +184,15 @@ class RunnerMixin(metaclass=SanicMeta):
fast: bool = False, fast: bool = False,
verbosity: int = 0, verbosity: int = 0,
motd_display: Optional[Dict[str, str]] = None, motd_display: Optional[Dict[str, str]] = None,
auto_cert: bool = False,
) -> None: ) -> None:
if version == 3 and self.state.server_info:
raise RuntimeError(
"Serving HTTP/3 instances as a secondary server is "
"not supported. There can only be a single HTTP/3 worker "
"and it must be prepared first."
)
if dev: if dev:
debug = True debug = True
auto_reload = True auto_reload = True
@ -222,7 +234,7 @@ class RunnerMixin(metaclass=SanicMeta):
return return
if sock is None: if sock is None:
host, port = host or "127.0.0.1", port or 8000 host, port = self.get_address(host, port, version)
if protocol is None: if protocol is None:
protocol = ( protocol = (
@ -258,6 +270,7 @@ class RunnerMixin(metaclass=SanicMeta):
protocol=protocol, protocol=protocol,
backlog=backlog, backlog=backlog,
register_sys_signals=register_sys_signals, register_sys_signals=register_sys_signals,
auto_cert=auto_cert,
) )
self.state.server_info.append( self.state.server_info.append(
ApplicationServerInfo(settings=server_settings) ApplicationServerInfo(settings=server_settings)
@ -327,7 +340,7 @@ class RunnerMixin(metaclass=SanicMeta):
""" """
if sock is None: if sock is None:
host, port = host or "127.0.0.1", port or 8000 host, port = host, port = self.get_address(host, port)
if protocol is None: if protocol is None:
protocol = ( protocol = (
@ -402,6 +415,7 @@ class RunnerMixin(metaclass=SanicMeta):
backlog: int = 100, backlog: int = 100,
register_sys_signals: bool = True, register_sys_signals: bool = True,
run_async: bool = False, run_async: bool = False,
auto_cert: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Helper function used by `run` and `create_server`.""" """Helper function used by `run` and `create_server`."""
if self.config.PROXIES_COUNT and self.config.PROXIES_COUNT < 0: if self.config.PROXIES_COUNT and self.config.PROXIES_COUNT < 0:
@ -411,13 +425,17 @@ class RunnerMixin(metaclass=SanicMeta):
"#proxy-configuration" "#proxy-configuration"
) )
if not self.state.is_debug:
self.state.mode = Mode.DEBUG if debug else Mode.PRODUCTION
if isinstance(version, int): if isinstance(version, int):
version = HTTP(version) version = HTTP(version)
ssl = process_to_context(ssl) ssl = process_to_context(ssl)
if version is HTTP.VERSION_3 or auto_cert:
if not self.state.is_debug: if TYPE_CHECKING:
self.state.mode = Mode.DEBUG if debug else Mode.PRODUCTION self = cast(Sanic, self)
ssl = get_ssl_context(self, ssl)
self.state.host = host or "" self.state.host = host or ""
self.state.port = port or 0 self.state.port = port or 0
@ -441,7 +459,7 @@ class RunnerMixin(metaclass=SanicMeta):
"backlog": backlog, "backlog": backlog,
} }
self.motd(self.serve_location) self.motd(server_settings=server_settings)
if sys.stdout.isatty() and not self.state.is_debug: if sys.stdout.isatty() and not self.state.is_debug:
error_logger.warning( error_logger.warning(
@ -467,7 +485,19 @@ class RunnerMixin(metaclass=SanicMeta):
return server_settings return server_settings
def motd(self, serve_location): def motd(
self,
serve_location: str = "",
server_settings: Optional[Dict[str, Any]] = None,
):
if serve_location:
deprecation(
"Specifying a serve_location in the MOTD is deprecated and "
"will be removed.",
22.9,
)
else:
serve_location = self.get_server_location(server_settings)
if self.config.MOTD: if self.config.MOTD:
mode = [f"{self.state.mode},"] mode = [f"{self.state.mode},"]
if self.state.fast: if self.state.fast:
@ -480,9 +510,16 @@ class RunnerMixin(metaclass=SanicMeta):
else: else:
mode.append(f"w/ {self.state.workers} workers") mode.append(f"w/ {self.state.workers} workers")
server = ", ".join(
(
self.state.server,
server_settings["version"].display(), # type: ignore
)
)
display = { display = {
"mode": " ".join(mode), "mode": " ".join(mode),
"server": self.state.server, "server": server,
"python": platform.python_version(), "python": platform.python_version(),
"platform": platform.platform(), "platform": platform.platform(),
} }
@ -506,7 +543,9 @@ class RunnerMixin(metaclass=SanicMeta):
module_name = package_name.replace("-", "_") module_name = package_name.replace("-", "_")
try: try:
module = import_module(module_name) module = import_module(module_name)
packages.append(f"{package_name}=={module.__version__}") packages.append(
f"{package_name}=={module.__version__}" # type: ignore
)
except ImportError: except ImportError:
... ...
@ -526,25 +565,49 @@ class RunnerMixin(metaclass=SanicMeta):
@property @property
def serve_location(self) -> str: def serve_location(self) -> str:
server_settings = self.state.server_info[0].settings
return self.get_server_location(server_settings)
@staticmethod
def get_server_location(
server_settings: Optional[Dict[str, Any]] = None
) -> str:
serve_location = "" serve_location = ""
proto = "http" proto = "http"
if self.state.ssl is not None: if not server_settings:
return serve_location
if server_settings["ssl"] is not None:
proto = "https" proto = "https"
if self.state.unix: if server_settings["unix"]:
serve_location = f"{self.state.unix} {proto}://..." serve_location = f'{server_settings["unix"]} {proto}://...'
elif self.state.sock: elif server_settings["sock"]:
serve_location = f"{self.state.sock.getsockname()} {proto}://..." serve_location = (
elif self.state.host and self.state.port: f'{server_settings["sock"].getsockname()} {proto}://...'
)
elif server_settings["host"] and server_settings["port"]:
# colon(:) is legal for a host only in an ipv6 address # colon(:) is legal for a host only in an ipv6 address
display_host = ( display_host = (
f"[{self.state.host}]" f'[{server_settings["host"]}]'
if ":" in self.state.host if ":" in server_settings["host"]
else self.state.host else server_settings["host"]
)
serve_location = (
f'{proto}://{display_host}:{server_settings["port"]}'
) )
serve_location = f"{proto}://{display_host}:{self.state.port}"
return serve_location return serve_location
@staticmethod
def get_address(
host: Optional[str],
port: Optional[int],
version: HTTPVersion = HTTP.VERSION_1,
) -> Tuple[str, int]:
host = host or "127.0.0.1"
port = port or (8443 if version == 3 else 8000)
return host, port
@classmethod @classmethod
def should_auto_reload(cls) -> bool: def should_auto_reload(cls) -> bool:
return any(app.state.auto_reload for app in cls._app_registry.values()) return any(app.state.auto_reload for app in cls._app_registry.values())

View File

@ -95,8 +95,47 @@ def serve(
app.asgi = False app.asgi = False
if version is HTTP.VERSION_3: if version is HTTP.VERSION_3:
return serve_http_3(host, port, app, loop, ssl) return _serve_http_3(host, port, app, loop, ssl)
return _serve_http_1(
host,
port,
app,
ssl,
sock,
unix,
reuse_port,
loop,
protocol,
backlog,
register_sys_signals,
run_multiple,
run_async,
connections,
signal,
state,
asyncio_server_kwargs,
)
def _serve_http_1(
host,
port,
app,
ssl,
sock,
unix,
reuse_port,
loop,
protocol,
backlog,
register_sys_signals,
run_multiple,
run_async,
connections,
signal,
state,
asyncio_server_kwargs,
):
connections = connections if connections is not None else set() connections = connections if connections is not None else set()
protocol_kwargs = _build_protocol_kwargs(protocol, app.config) protocol_kwargs = _build_protocol_kwargs(protocol, app.config)
server = partial( server = partial(
@ -201,7 +240,7 @@ def serve(
remove_unix_socket(unix) remove_unix_socket(unix)
def serve_http_3( def _serve_http_3(
host, host,
port, port,
app, app,

View File

@ -518,8 +518,12 @@ class WebsocketImplProtocol:
) )
try: try:
self.recv_cancel = asyncio.Future() self.recv_cancel = asyncio.Future()
tasks = (
self.recv_cancel,
asyncio.ensure_future(self.assembler.get(timeout)),
)
done, pending = await asyncio.wait( done, pending = await asyncio.wait(
(self.recv_cancel, self.assembler.get(timeout)), tasks,
return_when=asyncio.FIRST_COMPLETED, return_when=asyncio.FIRST_COMPLETED,
) )
done_task = next(iter(done)) done_task = next(iter(done))
@ -570,8 +574,12 @@ class WebsocketImplProtocol:
self.can_pause = False self.can_pause = False
self.recv_cancel = asyncio.Future() self.recv_cancel = asyncio.Future()
while True: while True:
tasks = (
self.recv_cancel,
asyncio.ensure_future(self.assembler.get(timeout=0)),
)
done, pending = await asyncio.wait( done, pending = await asyncio.wait(
(self.recv_cancel, self.assembler.get(timeout=0)), tasks,
return_when=asyncio.FIRST_COMPLETED, return_when=asyncio.FIRST_COMPLETED,
) )
done_task = next(iter(done)) done_task = next(iter(done))

View File

@ -1,3 +1,4 @@
from .altsvc import AltSvcCheck # noqa
from .base import BaseScheme from .base import BaseScheme
from .ode import OptionalDispatchEvent # noqa from .ode import OptionalDispatchEvent # noqa

View File

@ -0,0 +1,56 @@
from __future__ import annotations
from ast import Assign, Constant, NodeTransformer, Subscript
from typing import TYPE_CHECKING, Any, List
from sanic.http.constants import HTTP
from .base import BaseScheme
if TYPE_CHECKING:
from sanic import Sanic
class AltSvcCheck(BaseScheme):
ident = "ALTSVC"
def visitors(self) -> List[NodeTransformer]:
return [RemoveAltSvc(self.app, self.app.state.verbosity)]
class RemoveAltSvc(NodeTransformer):
def __init__(self, app: Sanic, verbosity: int = 0) -> None:
self._app = app
self._verbosity = verbosity
self._versions = {
info.settings["version"] for info in app.state.server_info
}
def visit_Assign(self, node: Assign) -> Any:
if any(self._matches(target) for target in node.targets):
if self._should_remove():
return None
assert isinstance(node.value, Constant)
node.value.value = self.value()
return node
def _should_remove(self) -> bool:
return len(self._versions) == 1
@staticmethod
def _matches(node) -> bool:
return (
isinstance(node, Subscript)
and isinstance(node.slice, Constant)
and node.slice.value == "alt-svc"
)
def value(self):
values = []
for info in self._app.state.server_info:
port = info.settings["port"]
version = info.settings["version"]
if version is HTTP.VERSION_3:
values.append(f'h3=":{port}"')
return ", ".join(values)

View File

@ -1,5 +1,8 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Set, Type from ast import NodeTransformer, parse
from inspect import getsource
from textwrap import dedent
from typing import Any, Dict, List, Set, Type
class BaseScheme(ABC): class BaseScheme(ABC):
@ -10,11 +13,26 @@ class BaseScheme(ABC):
self.app = app self.app = app
@abstractmethod @abstractmethod
def run(self, method, module_globals) -> None: def visitors(self) -> List[NodeTransformer]:
... ...
def __init_subclass__(cls): def __init_subclass__(cls):
BaseScheme._registry.add(cls) BaseScheme._registry.add(cls)
def __call__(self, method, module_globals): def __call__(self):
return self.run(method, module_globals) return self.visitors()
@classmethod
def build(cls, method, module_globals, app):
raw_source = getsource(method)
src = dedent(raw_source)
node = parse(src)
for scheme in cls._registry:
for visitor in scheme(app)():
node = visitor.visit(node)
compiled_src = compile(node, method.__name__, "exec")
exec_locals: Dict[str, Any] = {}
exec(compiled_src, module_globals, exec_locals) # nosec
return exec_locals[method.__name__]

View File

@ -1,7 +1,5 @@
from ast import Attribute, Await, Dict, Expr, NodeTransformer, parse from ast import Attribute, Await, Expr, NodeTransformer
from inspect import getsource from typing import Any, List
from textwrap import dedent
from typing import Any
from sanic.log import logger from sanic.log import logger
@ -20,18 +18,10 @@ class OptionalDispatchEvent(BaseScheme):
signal.name for signal in app.signal_router.routes signal.name for signal in app.signal_router.routes
] ]
def run(self, method, module_globals): def visitors(self) -> List[NodeTransformer]:
raw_source = getsource(method) return [
src = dedent(raw_source) RemoveDispatch(self._registered_events, self.app.state.verbosity)
tree = parse(src) ]
node = RemoveDispatch(
self._registered_events, self.app.state.verbosity
).visit(tree)
compiled_src = compile(node, method.__name__, "exec")
exec_locals: Dict[str, Any] = {}
exec(compiled_src, module_globals, exec_locals) # nosec
return exec_locals[method.__name__]
def _sync_events(self): def _sync_events(self):
all_events = set() all_events = set()

View File

@ -21,10 +21,8 @@ class TouchUp:
module = getmodule(target) module = getmodule(target)
module_globals = dict(getmembers(module)) module_globals = dict(getmembers(module))
modified = BaseScheme.build(method, module_globals, app)
for scheme in BaseScheme._registry: setattr(target, method_name, modified)
modified = scheme(app)(method, module_globals)
setattr(target, method_name, modified)
target.__touched__ = True target.__touched__ = True

View File

@ -148,6 +148,7 @@ extras_require = {
"docs": docs_require, "docs": docs_require,
"all": all_require, "all": all_require,
"ext": ["sanic-ext"], "ext": ["sanic-ext"],
"http3": ["aioquic"],
} }
setup_kwargs["install_requires"] = requirements setup_kwargs["install_requires"] = requirements

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import logging
from collections import deque, namedtuple from collections import deque, namedtuple
@ -6,6 +7,7 @@ import pytest
import uvicorn import uvicorn
from sanic import Sanic from sanic import Sanic
from sanic.application.state import Mode
from sanic.asgi import MockTransport from sanic.asgi import MockTransport
from sanic.exceptions import Forbidden, InvalidUsage, ServiceUnavailable from sanic.exceptions import Forbidden, InvalidUsage, ServiceUnavailable
from sanic.request import Request from sanic.request import Request
@ -44,7 +46,7 @@ def protocol(transport):
return transport.get_protocol() return transport.get_protocol()
def test_listeners_triggered(): def test_listeners_triggered(caplog):
app = Sanic("app") app = Sanic("app")
before_server_start = False before_server_start = False
after_server_start = False after_server_start = False
@ -82,9 +84,31 @@ def test_listeners_triggered():
config = uvicorn.Config(app=app, loop="asyncio", limit_max_requests=0) config = uvicorn.Config(app=app, loop="asyncio", limit_max_requests=0)
server = CustomServer(config=config) server = CustomServer(config=config)
with pytest.warns(UserWarning): start_message = (
'You have set a listener for "before_server_start" in ASGI mode. '
"It will be executed as early as possible, but not before the ASGI "
"server is started."
)
stop_message = (
'You have set a listener for "after_server_stop" in ASGI mode. '
"It will be executed as late as possible, but not after the ASGI "
"server is stopped."
)
with caplog.at_level(logging.DEBUG):
server.run() server.run()
assert (
"sanic.root",
logging.DEBUG,
start_message,
) not in caplog.record_tuples
assert (
"sanic.root",
logging.DEBUG,
stop_message,
) not in caplog.record_tuples
all_tasks = asyncio.all_tasks(asyncio.get_event_loop()) all_tasks = asyncio.all_tasks(asyncio.get_event_loop())
for task in all_tasks: for task in all_tasks:
task.cancel() task.cancel()
@ -94,8 +118,38 @@ def test_listeners_triggered():
assert before_server_stop assert before_server_stop
assert after_server_stop assert after_server_stop
app.state.mode = Mode.DEBUG
with caplog.at_level(logging.DEBUG):
server.run()
def test_listeners_triggered_async(app): assert (
"sanic.root",
logging.DEBUG,
start_message,
) not in caplog.record_tuples
assert (
"sanic.root",
logging.DEBUG,
stop_message,
) not in caplog.record_tuples
app.state.verbosity = 2
with caplog.at_level(logging.DEBUG):
server.run()
assert (
"sanic.root",
logging.DEBUG,
start_message,
) in caplog.record_tuples
assert (
"sanic.root",
logging.DEBUG,
stop_message,
) in caplog.record_tuples
def test_listeners_triggered_async(app, caplog):
before_server_start = False before_server_start = False
after_server_start = False after_server_start = False
before_server_stop = False before_server_stop = False
@ -132,9 +186,31 @@ def test_listeners_triggered_async(app):
config = uvicorn.Config(app=app, loop="asyncio", limit_max_requests=0) config = uvicorn.Config(app=app, loop="asyncio", limit_max_requests=0)
server = CustomServer(config=config) server = CustomServer(config=config)
with pytest.warns(UserWarning): start_message = (
'You have set a listener for "before_server_start" in ASGI mode. '
"It will be executed as early as possible, but not before the ASGI "
"server is started."
)
stop_message = (
'You have set a listener for "after_server_stop" in ASGI mode. '
"It will be executed as late as possible, but not after the ASGI "
"server is stopped."
)
with caplog.at_level(logging.DEBUG):
server.run() server.run()
assert (
"sanic.root",
logging.DEBUG,
start_message,
) not in caplog.record_tuples
assert (
"sanic.root",
logging.DEBUG,
stop_message,
) not in caplog.record_tuples
all_tasks = asyncio.all_tasks(asyncio.get_event_loop()) all_tasks = asyncio.all_tasks(asyncio.get_event_loop())
for task in all_tasks: for task in all_tasks:
task.cancel() task.cancel()
@ -144,6 +220,36 @@ def test_listeners_triggered_async(app):
assert before_server_stop assert before_server_stop
assert after_server_stop assert after_server_stop
app.state.mode = Mode.DEBUG
with caplog.at_level(logging.DEBUG):
server.run()
assert (
"sanic.root",
logging.DEBUG,
start_message,
) not in caplog.record_tuples
assert (
"sanic.root",
logging.DEBUG,
stop_message,
) not in caplog.record_tuples
app.state.verbosity = 2
with caplog.at_level(logging.DEBUG):
server.run()
assert (
"sanic.root",
logging.DEBUG,
start_message,
) in caplog.record_tuples
assert (
"sanic.root",
logging.DEBUG,
stop_message,
) in caplog.record_tuples
def test_non_default_uvloop_config_raises_warning(app): def test_non_default_uvloop_config_raises_warning(app):
app.config.USE_UVLOOP = True app.config.USE_UVLOOP = True