Merge pull request #2380 from sanic-org/http3-startup
This commit is contained in:
commit
e1937149c9
88
sanic/application/spinner.py
Normal file
88
sanic/application/spinner.py
Normal file
@ -0,0 +1,88 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
from contextlib import contextmanager
|
||||
from curses.ascii import SP
|
||||
from queue import Queue
|
||||
from threading import Thread
|
||||
|
||||
|
||||
if os.name == "nt":
|
||||
import ctypes
|
||||
import msvcrt
|
||||
|
||||
class _CursorInfo(ctypes.Structure):
|
||||
_fields_ = [("size", ctypes.c_int), ("visible", ctypes.c_byte)]
|
||||
|
||||
|
||||
class Spinner:
|
||||
def __init__(self, message: str) -> None:
|
||||
self.message = message
|
||||
self.queue: Queue[int] = Queue()
|
||||
self.spinner = self.cursor()
|
||||
self.thread = Thread(target=self.run)
|
||||
|
||||
def start(self):
|
||||
self.queue.put(1)
|
||||
self.thread.start()
|
||||
self.hide()
|
||||
|
||||
def run(self):
|
||||
while self.queue.get():
|
||||
output = f"\r{self.message} [{next(self.spinner)}]"
|
||||
sys.stdout.write(output)
|
||||
sys.stdout.flush()
|
||||
time.sleep(0.1)
|
||||
self.queue.put(1)
|
||||
|
||||
def stop(self):
|
||||
self.queue.put(0)
|
||||
self.thread.join()
|
||||
self.show()
|
||||
|
||||
@staticmethod
|
||||
def cursor():
|
||||
while True:
|
||||
for cursor in "|/-\\":
|
||||
yield cursor
|
||||
|
||||
@staticmethod
|
||||
def hide():
|
||||
if os.name == "nt":
|
||||
ci = _CursorInfo()
|
||||
handle = ctypes.windll.kernel32.GetStdHandle(-11)
|
||||
ctypes.windll.kernel32.GetConsoleCursorInfo(
|
||||
handle, ctypes.byref(ci)
|
||||
)
|
||||
ci.visible = False
|
||||
ctypes.windll.kernel32.SetConsoleCursorInfo(
|
||||
handle, ctypes.byref(ci)
|
||||
)
|
||||
elif os.name == "posix":
|
||||
sys.stdout.write("\033[?25l")
|
||||
sys.stdout.flush()
|
||||
|
||||
@staticmethod
|
||||
def show():
|
||||
if os.name == "nt":
|
||||
ci = _CursorInfo()
|
||||
handle = ctypes.windll.kernel32.GetStdHandle(-11)
|
||||
ctypes.windll.kernel32.GetConsoleCursorInfo(
|
||||
handle, ctypes.byref(ci)
|
||||
)
|
||||
ci.visible = True
|
||||
ctypes.windll.kernel32.SetConsoleCursorInfo(
|
||||
handle, ctypes.byref(ci)
|
||||
)
|
||||
elif os.name == "posix":
|
||||
sys.stdout.write("\033[?25h")
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def loading(message: str = "Loading"):
|
||||
spinner = Spinner(message)
|
||||
spinner.start()
|
||||
yield
|
||||
spinner.stop()
|
@ -1,14 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
import sanic.app # noqa
|
||||
|
||||
from sanic.compat import Header
|
||||
from sanic.exceptions import ServerError
|
||||
from sanic.helpers import _default
|
||||
from sanic.http import Stage
|
||||
from sanic.log import logger
|
||||
from sanic.models.asgi import ASGIReceive, ASGIScope, ASGISend, MockTransport
|
||||
from sanic.request import Request
|
||||
from sanic.response import BaseHTTPResponse
|
||||
@ -16,30 +17,35 @@ from sanic.server import ConnInfo
|
||||
from sanic.server.websockets.connection import WebSocketConnection
|
||||
|
||||
|
||||
if TYPE_CHECKING: # no cov
|
||||
from sanic import Sanic
|
||||
|
||||
|
||||
class Lifespan:
|
||||
def __init__(self, asgi_app: "ASGIApp") -> None:
|
||||
def __init__(self, asgi_app: ASGIApp) -> None:
|
||||
self.asgi_app = asgi_app
|
||||
|
||||
if (
|
||||
"server.init.before"
|
||||
in self.asgi_app.sanic_app.signal_router.name_index
|
||||
):
|
||||
warnings.warn(
|
||||
'You have set a listener for "before_server_start" '
|
||||
"in ASGI mode. "
|
||||
"It will be executed as early as possible, but not before "
|
||||
"the ASGI server is started."
|
||||
)
|
||||
if (
|
||||
"server.shutdown.after"
|
||||
in self.asgi_app.sanic_app.signal_router.name_index
|
||||
):
|
||||
warnings.warn(
|
||||
'You have set a listener for "after_server_stop" '
|
||||
"in ASGI mode. "
|
||||
"It will be executed as late as possible, but not after "
|
||||
"the ASGI server is stopped."
|
||||
)
|
||||
if self.asgi_app.sanic_app.state.verbosity > 0:
|
||||
if (
|
||||
"server.init.before"
|
||||
in self.asgi_app.sanic_app.signal_router.name_index
|
||||
):
|
||||
logger.debug(
|
||||
'You have set a listener for "before_server_start" '
|
||||
"in ASGI mode. "
|
||||
"It will be executed as early as possible, but not before "
|
||||
"the ASGI server is started."
|
||||
)
|
||||
if (
|
||||
"server.shutdown.after"
|
||||
in self.asgi_app.sanic_app.signal_router.name_index
|
||||
):
|
||||
logger.debug(
|
||||
'You have set a listener for "after_server_stop" '
|
||||
"in ASGI mode. "
|
||||
"It will be executed as late as possible, but not after "
|
||||
"the ASGI server is stopped."
|
||||
)
|
||||
|
||||
async def startup(self) -> None:
|
||||
"""
|
||||
@ -88,7 +94,7 @@ class Lifespan:
|
||||
|
||||
|
||||
class ASGIApp:
|
||||
sanic_app: "sanic.app.Sanic"
|
||||
sanic_app: Sanic
|
||||
request: Request
|
||||
transport: MockTransport
|
||||
lifespan: Lifespan
|
||||
|
@ -11,7 +11,6 @@ from typing import Any, List, Union
|
||||
from sanic.app import Sanic
|
||||
from sanic.application.logo import get_logo
|
||||
from sanic.cli.arguments import Group
|
||||
from sanic.http.constants import HTTP
|
||||
from sanic.log import error_logger
|
||||
from sanic.simple import create_simple_server
|
||||
|
||||
@ -59,10 +58,13 @@ Or, a path to a directory to run as a simple HTTP server:
|
||||
os.environ.get("SANIC_RELOADER_PROCESS", "") != "true"
|
||||
)
|
||||
self.args: List[Any] = []
|
||||
self.groups: List[Group] = []
|
||||
|
||||
def attach(self):
|
||||
for group in Group._registry:
|
||||
group.create(self.parser).attach()
|
||||
instance = group.create(self.parser)
|
||||
instance.attach()
|
||||
self.groups.append(instance)
|
||||
|
||||
def run(self):
|
||||
# This is to provide backwards compat -v to display version
|
||||
@ -75,9 +77,13 @@ Or, a path to a directory to run as a simple HTTP server:
|
||||
try:
|
||||
app = self._get_app()
|
||||
kwargs = self._build_run_kwargs()
|
||||
app.run(**kwargs)
|
||||
except ValueError:
|
||||
error_logger.exception("Failed to run app")
|
||||
else:
|
||||
for http_version in self.args.http:
|
||||
app.prepare(**kwargs, version=http_version)
|
||||
|
||||
Sanic.serve()
|
||||
|
||||
def _precheck(self):
|
||||
# # Custom TLS mismatch handling for better diagnostics
|
||||
@ -137,11 +143,14 @@ Or, a path to a directory to run as a simple HTTP server:
|
||||
" Example File: project/sanic_server.py -> app\n"
|
||||
" Example Module: project.sanic_server.app"
|
||||
)
|
||||
sys.exit(1)
|
||||
else:
|
||||
raise e
|
||||
return app
|
||||
|
||||
def _build_run_kwargs(self):
|
||||
for group in self.groups:
|
||||
group.prepare(self.args)
|
||||
ssl: Union[None, dict, str, list] = []
|
||||
if self.args.tlshost:
|
||||
ssl.append(None)
|
||||
@ -154,7 +163,6 @@ Or, a path to a directory to run as a simple HTTP server:
|
||||
elif len(ssl) == 1 and ssl[0] is not None:
|
||||
# Use only one cert, no TLSSelector.
|
||||
ssl = ssl[0]
|
||||
version = HTTP(self.args.http)
|
||||
kwargs = {
|
||||
"access_log": self.args.access_log,
|
||||
"debug": self.args.debug,
|
||||
@ -167,7 +175,7 @@ Or, a path to a directory to run as a simple HTTP server:
|
||||
"unix": self.args.unix,
|
||||
"verbosity": self.args.verbosity or 0,
|
||||
"workers": self.args.workers,
|
||||
"version": version,
|
||||
"auto_cert": self.args.auto_cert,
|
||||
}
|
||||
|
||||
for maybe_arg in ("auto_reload", "dev"):
|
||||
@ -177,4 +185,5 @@ Or, a path to a directory to run as a simple HTTP server:
|
||||
if self.args.path:
|
||||
kwargs["auto_reload"] = True
|
||||
kwargs["reload_dir"] = self.args.path
|
||||
|
||||
return kwargs
|
||||
|
@ -6,6 +6,7 @@ from typing import List, Optional, Type, Union
|
||||
from sanic_routing import __version__ as __routing_version__ # type: ignore
|
||||
|
||||
from sanic import __version__
|
||||
from sanic.http.constants import HTTP
|
||||
|
||||
|
||||
class Group:
|
||||
@ -38,6 +39,9 @@ class Group:
|
||||
"--no-" + args[0][2:], *args[1:], action="store_false", **kwargs
|
||||
)
|
||||
|
||||
def prepare(self, args) -> None:
|
||||
...
|
||||
|
||||
|
||||
class GeneralGroup(Group):
|
||||
name = None
|
||||
@ -87,25 +91,39 @@ class HTTPVersionGroup(Group):
|
||||
name = "HTTP version"
|
||||
|
||||
def attach(self):
|
||||
group = self.container.add_mutually_exclusive_group()
|
||||
group.add_argument(
|
||||
http_values = [http.value for http in HTTP.__members__.values()]
|
||||
|
||||
self.container.add_argument(
|
||||
"--http",
|
||||
dest="http",
|
||||
action="append",
|
||||
choices=http_values,
|
||||
type=int,
|
||||
default=1,
|
||||
help=(
|
||||
"Which HTTP version to use: HTTP/1.1 or HTTP/3. Value should "
|
||||
"be either 1 or 3 [default 1]"
|
||||
"Which HTTP version to use: HTTP/1.1 or HTTP/3. Value should\n"
|
||||
"be either 1, or 3. [default 1]"
|
||||
),
|
||||
)
|
||||
group.add_argument(
|
||||
self.container.add_argument(
|
||||
"-1",
|
||||
dest="http",
|
||||
action="append_const",
|
||||
const=1,
|
||||
help=("Run Sanic server using HTTP/1.1"),
|
||||
)
|
||||
self.container.add_argument(
|
||||
"-3",
|
||||
dest="http",
|
||||
action="store_const",
|
||||
action="append_const",
|
||||
const=3,
|
||||
help=("Run Sanic server using HTTP/3"),
|
||||
)
|
||||
|
||||
def prepare(self, args):
|
||||
if not args.http:
|
||||
args.http = [1]
|
||||
args.http = tuple(sorted(set(map(HTTP, args.http)), reverse=True))
|
||||
|
||||
|
||||
class SocketGroup(Group):
|
||||
name = "Socket binding"
|
||||
@ -116,7 +134,6 @@ class SocketGroup(Group):
|
||||
"--host",
|
||||
dest="host",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
help="Host address [default 127.0.0.1]",
|
||||
)
|
||||
self.container.add_argument(
|
||||
@ -124,7 +141,6 @@ class SocketGroup(Group):
|
||||
"--port",
|
||||
dest="port",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Port to serve on [default 8000]",
|
||||
)
|
||||
self.container.add_argument(
|
||||
@ -233,7 +249,16 @@ class DevelopmentGroup(Group):
|
||||
"--dev",
|
||||
dest="dev",
|
||||
action="store_true",
|
||||
help=("debug + auto reload."),
|
||||
help=("debug + auto reload"),
|
||||
)
|
||||
self.container.add_argument(
|
||||
"--auto-cert",
|
||||
dest="auto_cert",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Create a temporary TLS certificate for local development "
|
||||
"(requires mkcert)"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from enum import Enum
|
||||
from enum import Enum, IntEnum
|
||||
|
||||
|
||||
class Stage(Enum):
|
||||
@ -20,6 +20,12 @@ class Stage(Enum):
|
||||
FAILED = 100 # Unrecoverable state (error while sending response)
|
||||
|
||||
|
||||
class HTTP(Enum):
|
||||
class HTTP(IntEnum):
|
||||
VERSION_1 = 1
|
||||
VERSION_3 = 3
|
||||
|
||||
def display(self) -> str:
|
||||
value = str(self.value)
|
||||
if value == 1:
|
||||
value = "1.1"
|
||||
return f"HTTP/{value}"
|
||||
|
@ -333,6 +333,12 @@ class Http(metaclass=TouchUpMeta):
|
||||
self.response_func = self.head_response_ignored
|
||||
|
||||
headers["connection"] = "keep-alive" if self.keep_alive else "close"
|
||||
|
||||
# This header may be removed or modified by the AltSvcCheck Touchup
|
||||
# service. At server start, we either remove this header from ever
|
||||
# being assigned, or we change the value as required.
|
||||
headers["alt-svc"] = ""
|
||||
|
||||
ret = format_http1_response(status, res.processed_headers)
|
||||
if data:
|
||||
ret += data
|
||||
|
@ -3,15 +3,16 @@ from __future__ import annotations
|
||||
import os
|
||||
import ssl
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from contextlib import suppress
|
||||
from inspect import currentframe, getframeinfo
|
||||
from pathlib import Path
|
||||
from ssl import SSLContext
|
||||
from tempfile import mkdtemp
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
|
||||
|
||||
from sanic.application.constants import Mode
|
||||
from sanic.application.spinner import loading
|
||||
from sanic.constants import DEFAULT_LOCAL_TLS_CERT, DEFAULT_LOCAL_TLS_KEY
|
||||
from sanic.exceptions import SanicException
|
||||
from sanic.helpers import Default
|
||||
@ -233,7 +234,7 @@ def get_ssl_context(app: Sanic, ssl: Optional[SSLContext]) -> SSLContext:
|
||||
|
||||
if app.state.mode is Mode.PRODUCTION:
|
||||
raise SanicException(
|
||||
"Cannot run Sanic as an HTTP/3 server in PRODUCTION mode "
|
||||
"Cannot run Sanic as an HTTPS server in PRODUCTION mode "
|
||||
"without passing a TLS certificate. If you are developing "
|
||||
"locally, please enable DEVELOPMENT mode and Sanic will "
|
||||
"generate a localhost TLS certificate. For more information "
|
||||
@ -283,15 +284,32 @@ def generate_local_certificate(
|
||||
):
|
||||
check_mkcert()
|
||||
|
||||
cmd = [
|
||||
"mkcert",
|
||||
"-key-file",
|
||||
str(key_path),
|
||||
"-cert-file",
|
||||
str(cert_path),
|
||||
localhost,
|
||||
]
|
||||
subprocess.run(cmd, check=True)
|
||||
if not key_path.parent.exists() or not cert_path.parent.exists():
|
||||
raise SanicException(
|
||||
f"Cannot generate certificate at [{key_path}, {cert_path}]. One "
|
||||
"or more of the directories does not exist."
|
||||
)
|
||||
|
||||
message = "Generating TLS certificate"
|
||||
with loading(message):
|
||||
cmd = [
|
||||
"mkcert",
|
||||
"-key-file",
|
||||
str(key_path),
|
||||
"-cert-file",
|
||||
str(cert_path),
|
||||
localhost,
|
||||
]
|
||||
resp = subprocess.run(
|
||||
cmd,
|
||||
check=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
)
|
||||
sys.stdout.write("\r" + " " * (len(message) + 4))
|
||||
sys.stdout.flush()
|
||||
sys.stdout.write(resp.stdout)
|
||||
|
||||
|
||||
def check_mkcert():
|
||||
|
@ -26,8 +26,10 @@ from typing import (
|
||||
Literal,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from sanic import reloader_helpers
|
||||
@ -38,8 +40,8 @@ from sanic.base.meta import SanicMeta
|
||||
from sanic.compat import OS_IS_WINDOWS
|
||||
from sanic.helpers import _default
|
||||
from sanic.http.constants import HTTP
|
||||
from sanic.http.tls import process_to_context
|
||||
from sanic.log import Colors, error_logger, logger
|
||||
from sanic.http.tls import get_ssl_context, process_to_context
|
||||
from sanic.log import Colors, deprecation, error_logger, logger
|
||||
from sanic.models.handler_types import ListenerType
|
||||
from sanic.server import Signal as ServerSignal
|
||||
from sanic.server import try_use_uvloop
|
||||
@ -93,6 +95,7 @@ class RunnerMixin(metaclass=SanicMeta):
|
||||
fast: bool = False,
|
||||
verbosity: int = 0,
|
||||
motd_display: Optional[Dict[str, str]] = None,
|
||||
auto_cert: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Run the HTTP Server and listen until keyboard interrupt or term
|
||||
@ -152,6 +155,7 @@ class RunnerMixin(metaclass=SanicMeta):
|
||||
fast=fast,
|
||||
verbosity=verbosity,
|
||||
motd_display=motd_display,
|
||||
auto_cert=auto_cert,
|
||||
)
|
||||
|
||||
self.__class__.serve(primary=self) # type: ignore
|
||||
@ -180,7 +184,15 @@ class RunnerMixin(metaclass=SanicMeta):
|
||||
fast: bool = False,
|
||||
verbosity: int = 0,
|
||||
motd_display: Optional[Dict[str, str]] = None,
|
||||
auto_cert: bool = False,
|
||||
) -> None:
|
||||
if version == 3 and self.state.server_info:
|
||||
raise RuntimeError(
|
||||
"Serving HTTP/3 instances as a secondary server is "
|
||||
"not supported. There can only be a single HTTP/3 worker "
|
||||
"and it must be prepared first."
|
||||
)
|
||||
|
||||
if dev:
|
||||
debug = True
|
||||
auto_reload = True
|
||||
@ -222,7 +234,7 @@ class RunnerMixin(metaclass=SanicMeta):
|
||||
return
|
||||
|
||||
if sock is None:
|
||||
host, port = host or "127.0.0.1", port or 8000
|
||||
host, port = self.get_address(host, port, version)
|
||||
|
||||
if protocol is None:
|
||||
protocol = (
|
||||
@ -258,6 +270,7 @@ class RunnerMixin(metaclass=SanicMeta):
|
||||
protocol=protocol,
|
||||
backlog=backlog,
|
||||
register_sys_signals=register_sys_signals,
|
||||
auto_cert=auto_cert,
|
||||
)
|
||||
self.state.server_info.append(
|
||||
ApplicationServerInfo(settings=server_settings)
|
||||
@ -327,7 +340,7 @@ class RunnerMixin(metaclass=SanicMeta):
|
||||
"""
|
||||
|
||||
if sock is None:
|
||||
host, port = host or "127.0.0.1", port or 8000
|
||||
host, port = host, port = self.get_address(host, port)
|
||||
|
||||
if protocol is None:
|
||||
protocol = (
|
||||
@ -402,6 +415,7 @@ class RunnerMixin(metaclass=SanicMeta):
|
||||
backlog: int = 100,
|
||||
register_sys_signals: bool = True,
|
||||
run_async: bool = False,
|
||||
auto_cert: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Helper function used by `run` and `create_server`."""
|
||||
if self.config.PROXIES_COUNT and self.config.PROXIES_COUNT < 0:
|
||||
@ -411,13 +425,17 @@ class RunnerMixin(metaclass=SanicMeta):
|
||||
"#proxy-configuration"
|
||||
)
|
||||
|
||||
if not self.state.is_debug:
|
||||
self.state.mode = Mode.DEBUG if debug else Mode.PRODUCTION
|
||||
|
||||
if isinstance(version, int):
|
||||
version = HTTP(version)
|
||||
|
||||
ssl = process_to_context(ssl)
|
||||
|
||||
if not self.state.is_debug:
|
||||
self.state.mode = Mode.DEBUG if debug else Mode.PRODUCTION
|
||||
if version is HTTP.VERSION_3 or auto_cert:
|
||||
if TYPE_CHECKING:
|
||||
self = cast(Sanic, self)
|
||||
ssl = get_ssl_context(self, ssl)
|
||||
|
||||
self.state.host = host or ""
|
||||
self.state.port = port or 0
|
||||
@ -441,7 +459,7 @@ class RunnerMixin(metaclass=SanicMeta):
|
||||
"backlog": backlog,
|
||||
}
|
||||
|
||||
self.motd(self.serve_location)
|
||||
self.motd(server_settings=server_settings)
|
||||
|
||||
if sys.stdout.isatty() and not self.state.is_debug:
|
||||
error_logger.warning(
|
||||
@ -467,7 +485,19 @@ class RunnerMixin(metaclass=SanicMeta):
|
||||
|
||||
return server_settings
|
||||
|
||||
def motd(self, serve_location):
|
||||
def motd(
|
||||
self,
|
||||
serve_location: str = "",
|
||||
server_settings: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
if serve_location:
|
||||
deprecation(
|
||||
"Specifying a serve_location in the MOTD is deprecated and "
|
||||
"will be removed.",
|
||||
22.9,
|
||||
)
|
||||
else:
|
||||
serve_location = self.get_server_location(server_settings)
|
||||
if self.config.MOTD:
|
||||
mode = [f"{self.state.mode},"]
|
||||
if self.state.fast:
|
||||
@ -480,9 +510,16 @@ class RunnerMixin(metaclass=SanicMeta):
|
||||
else:
|
||||
mode.append(f"w/ {self.state.workers} workers")
|
||||
|
||||
server = ", ".join(
|
||||
(
|
||||
self.state.server,
|
||||
server_settings["version"].display(), # type: ignore
|
||||
)
|
||||
)
|
||||
|
||||
display = {
|
||||
"mode": " ".join(mode),
|
||||
"server": self.state.server,
|
||||
"server": server,
|
||||
"python": platform.python_version(),
|
||||
"platform": platform.platform(),
|
||||
}
|
||||
@ -506,7 +543,9 @@ class RunnerMixin(metaclass=SanicMeta):
|
||||
module_name = package_name.replace("-", "_")
|
||||
try:
|
||||
module = import_module(module_name)
|
||||
packages.append(f"{package_name}=={module.__version__}")
|
||||
packages.append(
|
||||
f"{package_name}=={module.__version__}" # type: ignore
|
||||
)
|
||||
except ImportError:
|
||||
...
|
||||
|
||||
@ -526,25 +565,49 @@ class RunnerMixin(metaclass=SanicMeta):
|
||||
|
||||
@property
|
||||
def serve_location(self) -> str:
|
||||
server_settings = self.state.server_info[0].settings
|
||||
return self.get_server_location(server_settings)
|
||||
|
||||
@staticmethod
|
||||
def get_server_location(
|
||||
server_settings: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
serve_location = ""
|
||||
proto = "http"
|
||||
if self.state.ssl is not None:
|
||||
if not server_settings:
|
||||
return serve_location
|
||||
|
||||
if server_settings["ssl"] is not None:
|
||||
proto = "https"
|
||||
if self.state.unix:
|
||||
serve_location = f"{self.state.unix} {proto}://..."
|
||||
elif self.state.sock:
|
||||
serve_location = f"{self.state.sock.getsockname()} {proto}://..."
|
||||
elif self.state.host and self.state.port:
|
||||
if server_settings["unix"]:
|
||||
serve_location = f'{server_settings["unix"]} {proto}://...'
|
||||
elif server_settings["sock"]:
|
||||
serve_location = (
|
||||
f'{server_settings["sock"].getsockname()} {proto}://...'
|
||||
)
|
||||
elif server_settings["host"] and server_settings["port"]:
|
||||
# colon(:) is legal for a host only in an ipv6 address
|
||||
display_host = (
|
||||
f"[{self.state.host}]"
|
||||
if ":" in self.state.host
|
||||
else self.state.host
|
||||
f'[{server_settings["host"]}]'
|
||||
if ":" in server_settings["host"]
|
||||
else server_settings["host"]
|
||||
)
|
||||
serve_location = (
|
||||
f'{proto}://{display_host}:{server_settings["port"]}'
|
||||
)
|
||||
serve_location = f"{proto}://{display_host}:{self.state.port}"
|
||||
|
||||
return serve_location
|
||||
|
||||
@staticmethod
|
||||
def get_address(
|
||||
host: Optional[str],
|
||||
port: Optional[int],
|
||||
version: HTTPVersion = HTTP.VERSION_1,
|
||||
) -> Tuple[str, int]:
|
||||
host = host or "127.0.0.1"
|
||||
port = port or (8443 if version == 3 else 8000)
|
||||
return host, port
|
||||
|
||||
@classmethod
|
||||
def should_auto_reload(cls) -> bool:
|
||||
return any(app.state.auto_reload for app in cls._app_registry.values())
|
||||
|
@ -95,8 +95,47 @@ def serve(
|
||||
app.asgi = False
|
||||
|
||||
if version is HTTP.VERSION_3:
|
||||
return serve_http_3(host, port, app, loop, ssl)
|
||||
return _serve_http_3(host, port, app, loop, ssl)
|
||||
return _serve_http_1(
|
||||
host,
|
||||
port,
|
||||
app,
|
||||
ssl,
|
||||
sock,
|
||||
unix,
|
||||
reuse_port,
|
||||
loop,
|
||||
protocol,
|
||||
backlog,
|
||||
register_sys_signals,
|
||||
run_multiple,
|
||||
run_async,
|
||||
connections,
|
||||
signal,
|
||||
state,
|
||||
asyncio_server_kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _serve_http_1(
|
||||
host,
|
||||
port,
|
||||
app,
|
||||
ssl,
|
||||
sock,
|
||||
unix,
|
||||
reuse_port,
|
||||
loop,
|
||||
protocol,
|
||||
backlog,
|
||||
register_sys_signals,
|
||||
run_multiple,
|
||||
run_async,
|
||||
connections,
|
||||
signal,
|
||||
state,
|
||||
asyncio_server_kwargs,
|
||||
):
|
||||
connections = connections if connections is not None else set()
|
||||
protocol_kwargs = _build_protocol_kwargs(protocol, app.config)
|
||||
server = partial(
|
||||
@ -201,7 +240,7 @@ def serve(
|
||||
remove_unix_socket(unix)
|
||||
|
||||
|
||||
def serve_http_3(
|
||||
def _serve_http_3(
|
||||
host,
|
||||
port,
|
||||
app,
|
||||
|
@ -518,8 +518,12 @@ class WebsocketImplProtocol:
|
||||
)
|
||||
try:
|
||||
self.recv_cancel = asyncio.Future()
|
||||
tasks = (
|
||||
self.recv_cancel,
|
||||
asyncio.ensure_future(self.assembler.get(timeout)),
|
||||
)
|
||||
done, pending = await asyncio.wait(
|
||||
(self.recv_cancel, self.assembler.get(timeout)),
|
||||
tasks,
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
done_task = next(iter(done))
|
||||
@ -570,8 +574,12 @@ class WebsocketImplProtocol:
|
||||
self.can_pause = False
|
||||
self.recv_cancel = asyncio.Future()
|
||||
while True:
|
||||
tasks = (
|
||||
self.recv_cancel,
|
||||
asyncio.ensure_future(self.assembler.get(timeout=0)),
|
||||
)
|
||||
done, pending = await asyncio.wait(
|
||||
(self.recv_cancel, self.assembler.get(timeout=0)),
|
||||
tasks,
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
done_task = next(iter(done))
|
||||
|
@ -1,3 +1,4 @@
|
||||
from .altsvc import AltSvcCheck # noqa
|
||||
from .base import BaseScheme
|
||||
from .ode import OptionalDispatchEvent # noqa
|
||||
|
||||
|
56
sanic/touchup/schemes/altsvc.py
Normal file
56
sanic/touchup/schemes/altsvc.py
Normal file
@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ast import Assign, Constant, NodeTransformer, Subscript
|
||||
from typing import TYPE_CHECKING, Any, List
|
||||
|
||||
from sanic.http.constants import HTTP
|
||||
|
||||
from .base import BaseScheme
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sanic import Sanic
|
||||
|
||||
|
||||
class AltSvcCheck(BaseScheme):
|
||||
ident = "ALTSVC"
|
||||
|
||||
def visitors(self) -> List[NodeTransformer]:
|
||||
return [RemoveAltSvc(self.app, self.app.state.verbosity)]
|
||||
|
||||
|
||||
class RemoveAltSvc(NodeTransformer):
|
||||
def __init__(self, app: Sanic, verbosity: int = 0) -> None:
|
||||
self._app = app
|
||||
self._verbosity = verbosity
|
||||
self._versions = {
|
||||
info.settings["version"] for info in app.state.server_info
|
||||
}
|
||||
|
||||
def visit_Assign(self, node: Assign) -> Any:
|
||||
if any(self._matches(target) for target in node.targets):
|
||||
if self._should_remove():
|
||||
return None
|
||||
assert isinstance(node.value, Constant)
|
||||
node.value.value = self.value()
|
||||
return node
|
||||
|
||||
def _should_remove(self) -> bool:
|
||||
return len(self._versions) == 1
|
||||
|
||||
@staticmethod
|
||||
def _matches(node) -> bool:
|
||||
return (
|
||||
isinstance(node, Subscript)
|
||||
and isinstance(node.slice, Constant)
|
||||
and node.slice.value == "alt-svc"
|
||||
)
|
||||
|
||||
def value(self):
|
||||
values = []
|
||||
for info in self._app.state.server_info:
|
||||
port = info.settings["port"]
|
||||
version = info.settings["version"]
|
||||
if version is HTTP.VERSION_3:
|
||||
values.append(f'h3=":{port}"')
|
||||
return ", ".join(values)
|
@ -1,5 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Set, Type
|
||||
from ast import NodeTransformer, parse
|
||||
from inspect import getsource
|
||||
from textwrap import dedent
|
||||
from typing import Any, Dict, List, Set, Type
|
||||
|
||||
|
||||
class BaseScheme(ABC):
|
||||
@ -10,11 +13,26 @@ class BaseScheme(ABC):
|
||||
self.app = app
|
||||
|
||||
@abstractmethod
|
||||
def run(self, method, module_globals) -> None:
|
||||
def visitors(self) -> List[NodeTransformer]:
|
||||
...
|
||||
|
||||
def __init_subclass__(cls):
|
||||
BaseScheme._registry.add(cls)
|
||||
|
||||
def __call__(self, method, module_globals):
|
||||
return self.run(method, module_globals)
|
||||
def __call__(self):
|
||||
return self.visitors()
|
||||
|
||||
@classmethod
|
||||
def build(cls, method, module_globals, app):
|
||||
raw_source = getsource(method)
|
||||
src = dedent(raw_source)
|
||||
node = parse(src)
|
||||
|
||||
for scheme in cls._registry:
|
||||
for visitor in scheme(app)():
|
||||
node = visitor.visit(node)
|
||||
|
||||
compiled_src = compile(node, method.__name__, "exec")
|
||||
exec_locals: Dict[str, Any] = {}
|
||||
exec(compiled_src, module_globals, exec_locals) # nosec
|
||||
return exec_locals[method.__name__]
|
||||
|
@ -1,7 +1,5 @@
|
||||
from ast import Attribute, Await, Dict, Expr, NodeTransformer, parse
|
||||
from inspect import getsource
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
from ast import Attribute, Await, Expr, NodeTransformer
|
||||
from typing import Any, List
|
||||
|
||||
from sanic.log import logger
|
||||
|
||||
@ -20,18 +18,10 @@ class OptionalDispatchEvent(BaseScheme):
|
||||
signal.name for signal in app.signal_router.routes
|
||||
]
|
||||
|
||||
def run(self, method, module_globals):
|
||||
raw_source = getsource(method)
|
||||
src = dedent(raw_source)
|
||||
tree = parse(src)
|
||||
node = RemoveDispatch(
|
||||
self._registered_events, self.app.state.verbosity
|
||||
).visit(tree)
|
||||
compiled_src = compile(node, method.__name__, "exec")
|
||||
exec_locals: Dict[str, Any] = {}
|
||||
exec(compiled_src, module_globals, exec_locals) # nosec
|
||||
|
||||
return exec_locals[method.__name__]
|
||||
def visitors(self) -> List[NodeTransformer]:
|
||||
return [
|
||||
RemoveDispatch(self._registered_events, self.app.state.verbosity)
|
||||
]
|
||||
|
||||
def _sync_events(self):
|
||||
all_events = set()
|
||||
|
@ -21,10 +21,8 @@ class TouchUp:
|
||||
|
||||
module = getmodule(target)
|
||||
module_globals = dict(getmembers(module))
|
||||
|
||||
for scheme in BaseScheme._registry:
|
||||
modified = scheme(app)(method, module_globals)
|
||||
setattr(target, method_name, modified)
|
||||
modified = BaseScheme.build(method, module_globals, app)
|
||||
setattr(target, method_name, modified)
|
||||
|
||||
target.__touched__ = True
|
||||
|
||||
|
1
setup.py
1
setup.py
@ -148,6 +148,7 @@ extras_require = {
|
||||
"docs": docs_require,
|
||||
"all": all_require,
|
||||
"ext": ["sanic-ext"],
|
||||
"http3": ["aioquic"],
|
||||
}
|
||||
|
||||
setup_kwargs["install_requires"] = requirements
|
||||
|
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from collections import deque, namedtuple
|
||||
|
||||
@ -6,6 +7,7 @@ import pytest
|
||||
import uvicorn
|
||||
|
||||
from sanic import Sanic
|
||||
from sanic.application.state import Mode
|
||||
from sanic.asgi import MockTransport
|
||||
from sanic.exceptions import Forbidden, InvalidUsage, ServiceUnavailable
|
||||
from sanic.request import Request
|
||||
@ -44,7 +46,7 @@ def protocol(transport):
|
||||
return transport.get_protocol()
|
||||
|
||||
|
||||
def test_listeners_triggered():
|
||||
def test_listeners_triggered(caplog):
|
||||
app = Sanic("app")
|
||||
before_server_start = False
|
||||
after_server_start = False
|
||||
@ -82,9 +84,31 @@ def test_listeners_triggered():
|
||||
config = uvicorn.Config(app=app, loop="asyncio", limit_max_requests=0)
|
||||
server = CustomServer(config=config)
|
||||
|
||||
with pytest.warns(UserWarning):
|
||||
start_message = (
|
||||
'You have set a listener for "before_server_start" in ASGI mode. '
|
||||
"It will be executed as early as possible, but not before the ASGI "
|
||||
"server is started."
|
||||
)
|
||||
stop_message = (
|
||||
'You have set a listener for "after_server_stop" in ASGI mode. '
|
||||
"It will be executed as late as possible, but not after the ASGI "
|
||||
"server is stopped."
|
||||
)
|
||||
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
server.run()
|
||||
|
||||
assert (
|
||||
"sanic.root",
|
||||
logging.DEBUG,
|
||||
start_message,
|
||||
) not in caplog.record_tuples
|
||||
assert (
|
||||
"sanic.root",
|
||||
logging.DEBUG,
|
||||
stop_message,
|
||||
) not in caplog.record_tuples
|
||||
|
||||
all_tasks = asyncio.all_tasks(asyncio.get_event_loop())
|
||||
for task in all_tasks:
|
||||
task.cancel()
|
||||
@ -94,8 +118,38 @@ def test_listeners_triggered():
|
||||
assert before_server_stop
|
||||
assert after_server_stop
|
||||
|
||||
app.state.mode = Mode.DEBUG
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
server.run()
|
||||
|
||||
def test_listeners_triggered_async(app):
|
||||
assert (
|
||||
"sanic.root",
|
||||
logging.DEBUG,
|
||||
start_message,
|
||||
) not in caplog.record_tuples
|
||||
assert (
|
||||
"sanic.root",
|
||||
logging.DEBUG,
|
||||
stop_message,
|
||||
) not in caplog.record_tuples
|
||||
|
||||
app.state.verbosity = 2
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
server.run()
|
||||
|
||||
assert (
|
||||
"sanic.root",
|
||||
logging.DEBUG,
|
||||
start_message,
|
||||
) in caplog.record_tuples
|
||||
assert (
|
||||
"sanic.root",
|
||||
logging.DEBUG,
|
||||
stop_message,
|
||||
) in caplog.record_tuples
|
||||
|
||||
|
||||
def test_listeners_triggered_async(app, caplog):
|
||||
before_server_start = False
|
||||
after_server_start = False
|
||||
before_server_stop = False
|
||||
@ -132,9 +186,31 @@ def test_listeners_triggered_async(app):
|
||||
config = uvicorn.Config(app=app, loop="asyncio", limit_max_requests=0)
|
||||
server = CustomServer(config=config)
|
||||
|
||||
with pytest.warns(UserWarning):
|
||||
start_message = (
|
||||
'You have set a listener for "before_server_start" in ASGI mode. '
|
||||
"It will be executed as early as possible, but not before the ASGI "
|
||||
"server is started."
|
||||
)
|
||||
stop_message = (
|
||||
'You have set a listener for "after_server_stop" in ASGI mode. '
|
||||
"It will be executed as late as possible, but not after the ASGI "
|
||||
"server is stopped."
|
||||
)
|
||||
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
server.run()
|
||||
|
||||
assert (
|
||||
"sanic.root",
|
||||
logging.DEBUG,
|
||||
start_message,
|
||||
) not in caplog.record_tuples
|
||||
assert (
|
||||
"sanic.root",
|
||||
logging.DEBUG,
|
||||
stop_message,
|
||||
) not in caplog.record_tuples
|
||||
|
||||
all_tasks = asyncio.all_tasks(asyncio.get_event_loop())
|
||||
for task in all_tasks:
|
||||
task.cancel()
|
||||
@ -144,6 +220,36 @@ def test_listeners_triggered_async(app):
|
||||
assert before_server_stop
|
||||
assert after_server_stop
|
||||
|
||||
app.state.mode = Mode.DEBUG
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
server.run()
|
||||
|
||||
assert (
|
||||
"sanic.root",
|
||||
logging.DEBUG,
|
||||
start_message,
|
||||
) not in caplog.record_tuples
|
||||
assert (
|
||||
"sanic.root",
|
||||
logging.DEBUG,
|
||||
stop_message,
|
||||
) not in caplog.record_tuples
|
||||
|
||||
app.state.verbosity = 2
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
server.run()
|
||||
|
||||
assert (
|
||||
"sanic.root",
|
||||
logging.DEBUG,
|
||||
start_message,
|
||||
) in caplog.record_tuples
|
||||
assert (
|
||||
"sanic.root",
|
||||
logging.DEBUG,
|
||||
stop_message,
|
||||
) in caplog.record_tuples
|
||||
|
||||
|
||||
def test_non_default_uvloop_config_raises_warning(app):
|
||||
app.config.USE_UVLOOP = True
|
||||
|
Loading…
x
Reference in New Issue
Block a user