Auto extend with Sanic Extensions (#2308)

This commit is contained in:
Adam Hopkins 2021-12-25 22:20:06 +02:00 committed by GitHub
parent b91ffed010
commit dc3ccba527
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 264 additions and 48 deletions

View File

@ -28,6 +28,7 @@ from ssl import SSLContext
from traceback import format_exc from traceback import format_exc
from types import SimpleNamespace from types import SimpleNamespace
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
AnyStr, AnyStr,
Awaitable, Awaitable,
@ -41,6 +42,7 @@ from typing import (
Set, Set,
Tuple, Tuple,
Type, Type,
TypeVar,
Union, Union,
) )
from urllib.parse import urlencode, urlunparse from urllib.parse import urlencode, urlunparse
@ -53,6 +55,7 @@ from sanic_routing.exceptions import ( # type: ignore
from sanic_routing.route import Route # type: ignore from sanic_routing.route import Route # type: ignore
from sanic import reloader_helpers from sanic import reloader_helpers
from sanic.application.ext import setup_ext
from sanic.application.logo import get_logo from sanic.application.logo import get_logo
from sanic.application.motd import MOTD from sanic.application.motd import MOTD
from sanic.application.state import ApplicationState, Mode from sanic.application.state import ApplicationState, Mode
@ -103,11 +106,21 @@ from sanic.tls import process_to_context
from sanic.touchup import TouchUp, TouchUpMeta from sanic.touchup import TouchUp, TouchUpMeta
if TYPE_CHECKING: # no cov
try:
from sanic_ext import Extend # type: ignore
from sanic_ext.extensions.base import Extension # type: ignore
except ImportError:
Extend = TypeVar("Extend") # type: ignore
if OS_IS_WINDOWS: if OS_IS_WINDOWS:
enable_windows_color_support() enable_windows_color_support()
filterwarnings("once", category=DeprecationWarning) filterwarnings("once", category=DeprecationWarning)
SANIC_PACKAGES = ("sanic-routing", "sanic-testing", "sanic-ext")
class Sanic(BaseSanic, metaclass=TouchUpMeta): class Sanic(BaseSanic, metaclass=TouchUpMeta):
""" """
@ -125,6 +138,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
"_asgi_client", "_asgi_client",
"_blueprint_order", "_blueprint_order",
"_delayed_tasks", "_delayed_tasks",
"_ext",
"_future_exceptions", "_future_exceptions",
"_future_listeners", "_future_listeners",
"_future_middleware", "_future_middleware",
@ -1421,26 +1435,15 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
"#proxy-configuration" "#proxy-configuration"
) )
ssl = process_to_context(ssl)
self.debug = debug self.debug = debug
self.state.host = host self.state.host = host
self.state.port = port self.state.port = port
self.state.workers = workers self.state.workers = workers
self.state.ssl = ssl
# Serve self.state.unix = unix
serve_location = "" self.state.sock = sock
proto = "http"
if ssl is not None:
proto = "https"
if unix:
serve_location = f"{unix} {proto}://..."
elif sock:
serve_location = f"{sock.getsockname()} {proto}://..."
elif host and port:
# colon(:) is legal for a host only in an ipv6 address
display_host = f"[{host}]" if ":" in host else host
serve_location = f"{proto}://{display_host}:{port}"
ssl = process_to_context(ssl)
server_settings = { server_settings = {
"protocol": protocol, "protocol": protocol,
@ -1456,7 +1459,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
"backlog": backlog, "backlog": backlog,
} }
self.motd(serve_location) self.motd(self.serve_location)
if sys.stdout.isatty() and not self.state.is_debug: if sys.stdout.isatty() and not self.state.is_debug:
error_logger.warning( error_logger.warning(
@ -1482,6 +1485,27 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
return server_settings return server_settings
@property
def serve_location(self) -> str:
serve_location = ""
proto = "http"
if self.state.ssl is not None:
proto = "https"
if self.state.unix:
serve_location = f"{self.state.unix} {proto}://..."
elif self.state.sock:
serve_location = f"{self.state.sock.getsockname()} {proto}://..."
elif self.state.host and self.state.port:
# colon(:) is legal for a host only in an ipv6 address
display_host = (
f"[{self.state.host}]"
if ":" in self.state.host
else self.state.host
)
serve_location = f"{proto}://{display_host}:{self.state.port}"
return serve_location
def _build_endpoint_name(self, *parts): def _build_endpoint_name(self, *parts):
parts = [self.name, *parts] parts = [self.name, *parts]
return ".".join(parts) return ".".join(parts)
@ -1790,11 +1814,8 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
display["auto-reload"] = reload_display display["auto-reload"] = reload_display
packages = [] packages = []
for package_name, module_name in { for package_name in SANIC_PACKAGES:
"sanic-routing": "sanic_routing", module_name = package_name.replace("-", "_")
"sanic-testing": "sanic_testing",
"sanic-ext": "sanic_ext",
}.items():
try: try:
module = import_module(module_name) module = import_module(module_name)
packages.append(f"{package_name}=={module.__version__}") packages.append(f"{package_name}=={module.__version__}")
@ -1814,6 +1835,41 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
) )
MOTD.output(logo, serve_location, display, extra) MOTD.output(logo, serve_location, display, extra)
@property
def ext(self) -> Extend:
if not hasattr(self, "_ext"):
setup_ext(self, fail=True)
if not hasattr(self, "_ext"):
raise RuntimeError(
"Sanic Extensions is not installed. You can add it to your "
"environment using:\n$ pip install sanic[ext]\nor\n$ pip "
"install sanic-ext"
)
return self._ext # type: ignore
def extend(
self,
*,
extensions: Optional[List[Type[Extension]]] = None,
built_in_extensions: bool = True,
config: Optional[Union[Config, Dict[str, Any]]] = None,
**kwargs,
) -> Extend:
if hasattr(self, "_ext"):
raise RuntimeError(
"Cannot extend Sanic after Sanic Extensions has been setup."
)
setup_ext(
self,
extensions=extensions,
built_in_extensions=built_in_extensions,
config=config,
fail=True,
**kwargs,
)
return self.ext
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
# Class methods # Class methods
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
@ -1875,6 +1931,14 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
async def _startup(self): async def _startup(self):
self._future_registry.clear() self._future_registry.clear()
# Startup Sanic Extensions
if not hasattr(self, "_ext"):
setup_ext(self)
if hasattr(self, "_ext"):
self.ext._display()
# Setup routers
self.signalize() self.signalize()
self.finalize() self.finalize()
@ -1890,8 +1954,10 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
) )
self.__class__._uvloop_setting = self.config.USE_UVLOOP self.__class__._uvloop_setting = self.config.USE_UVLOOP
# Startup time optimizations
ErrorHandler.finalize(self.error_handler, config=self.config) ErrorHandler.finalize(self.error_handler, config=self.config)
TouchUp.run(self) TouchUp.run(self)
self.state.is_started = True self.state.is_started = True
async def _server_event( async def _server_event(

39
sanic/application/ext.py Normal file
View File

@ -0,0 +1,39 @@
from __future__ import annotations
from contextlib import suppress
from importlib import import_module
from typing import TYPE_CHECKING
if TYPE_CHECKING: # no cov
from sanic import Sanic
try:
from sanic_ext import Extend # type: ignore
except ImportError:
...
def setup_ext(app: Sanic, *, fail: bool = False, **kwargs):
if not app.config.AUTO_EXTEND:
return
sanic_ext = None
with suppress(ModuleNotFoundError):
sanic_ext = import_module("sanic_ext")
if not sanic_ext:
if fail:
raise RuntimeError(
"Sanic Extensions is not installed. You can add it to your "
"environment using:\n$ pip install sanic[ext]\nor\n$ pip "
"install sanic-ext"
)
return
if not getattr(app, "_ext", None):
Ext: Extend = getattr(sanic_ext, "Extend")
app._ext = Ext(app, **kwargs)
return app.ext

View File

@ -41,9 +41,6 @@ class MOTD(ABC):
class MOTDBasic(MOTD): class MOTDBasic(MOTD):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
def display(self): def display(self):
if self.logo: if self.logo:
logger.debug(self.logo) logger.debug(self.logo)

View File

@ -5,7 +5,9 @@ import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum, auto from enum import Enum, auto
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Set, Union from socket import socket
from ssl import SSLContext
from typing import TYPE_CHECKING, Any, Optional, Set, Union
from sanic.log import logger from sanic.log import logger
@ -37,8 +39,11 @@ class ApplicationState:
coffee: bool = field(default=False) coffee: bool = field(default=False)
fast: bool = field(default=False) fast: bool = field(default=False)
host: str = field(default="") host: str = field(default="")
mode: Mode = field(default=Mode.PRODUCTION)
port: int = field(default=0) port: int = field(default=0)
ssl: Optional[SSLContext] = field(default=None)
sock: Optional[socket] = field(default=None)
unix: Optional[str] = field(default=None)
mode: Mode = field(default=Mode.PRODUCTION)
reload_dirs: Set[Path] = field(default_factory=set) reload_dirs: Set[Path] = field(default_factory=set)
server: Server = field(default=Server.SANIC) server: Server = field(default=Server.SANIC)
is_running: bool = field(default=False) is_running: bool = field(default=False)

View File

@ -5,7 +5,7 @@ from functools import partial
from typing import TYPE_CHECKING, List, Optional, Union from typing import TYPE_CHECKING, List, Optional, Union
if TYPE_CHECKING: if TYPE_CHECKING: # no cov
from sanic.blueprints import Blueprint from sanic.blueprints import Blueprint

View File

@ -36,8 +36,8 @@ from sanic.models.handler_types import (
) )
if TYPE_CHECKING: if TYPE_CHECKING: # no cov
from sanic import Sanic # noqa from sanic import Sanic
def lazy(func, as_decorator=True): def lazy(func, as_decorator=True):

View File

@ -18,6 +18,7 @@ SANIC_PREFIX = "SANIC_"
DEFAULT_CONFIG = { DEFAULT_CONFIG = {
"_FALLBACK_ERROR_FORMAT": _default, "_FALLBACK_ERROR_FORMAT": _default,
"ACCESS_LOG": True, "ACCESS_LOG": True,
"AUTO_EXTEND": True,
"AUTO_RELOAD": False, "AUTO_RELOAD": False,
"EVENT_AUTOREGISTER": False, "EVENT_AUTOREGISTER": False,
"FORWARDED_FOR_HEADER": "X-Forwarded-For", "FORWARDED_FOR_HEADER": "X-Forwarded-For",
@ -59,6 +60,7 @@ class DescriptorMeta(type):
class Config(dict, metaclass=DescriptorMeta): class Config(dict, metaclass=DescriptorMeta):
ACCESS_LOG: bool ACCESS_LOG: bool
AUTO_EXTEND: bool
AUTO_RELOAD: bool AUTO_RELOAD: bool
EVENT_AUTOREGISTER: bool EVENT_AUTOREGISTER: bool
FORWARDED_FOR_HEADER: str FORWARDED_FOR_HEADER: str

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING: if TYPE_CHECKING: # no cov
from sanic.request import Request from sanic.request import Request
from sanic.response import BaseHTTPResponse from sanic.response import BaseHTTPResponse

View File

@ -15,7 +15,7 @@ from typing import (
from sanic_routing.route import Route # type: ignore from sanic_routing.route import Route # type: ignore
if TYPE_CHECKING: if TYPE_CHECKING: # no cov
from sanic.server import ConnInfo from sanic.server import ConnInfo
from sanic.app import Sanic from sanic.app import Sanic

View File

@ -21,6 +21,7 @@ from functools import partial
from signal import SIG_IGN, SIGINT, SIGTERM, Signals from signal import SIG_IGN, SIGINT, SIGTERM, Signals
from signal import signal as signal_func from signal import signal as signal_func
from sanic.application.ext import setup_ext
from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows
from sanic.log import error_logger, logger from sanic.log import error_logger, logger
from sanic.models.server_types import Signal from sanic.models.server_types import Signal
@ -116,6 +117,7 @@ def serve(
**asyncio_server_kwargs, **asyncio_server_kwargs,
) )
setup_ext(app)
if run_async: if run_async:
return AsyncioServer( return AsyncioServer(
app=app, app=app,

View File

@ -13,7 +13,7 @@ from typing import (
from sanic.models.handler_types import RouteHandler from sanic.models.handler_types import RouteHandler
if TYPE_CHECKING: if TYPE_CHECKING: # no cov
from sanic import Sanic from sanic import Sanic
from sanic.blueprints import Blueprint from sanic.blueprints import Blueprint
@ -81,6 +81,8 @@ class HTTPMethodView:
def dispatch_request(self, request, *args, **kwargs): def dispatch_request(self, request, *args, **kwargs):
handler = getattr(self, request.method.lower(), None) handler = getattr(self, request.method.lower(), None)
if not handler and request.method == "HEAD":
handler = self.get
return handler(request, *args, **kwargs) return handler(request, *args, **kwargs)
@classmethod @classmethod

View File

@ -15,10 +15,10 @@ from sanic.server.protocols.websocket_protocol import WebSocketProtocol
try: try:
import ssl # type: ignore import ssl # type: ignore
except ImportError: except ImportError: # no cov
ssl = None # type: ignore ssl = None # type: ignore
if UVLOOP_INSTALLED: if UVLOOP_INSTALLED: # no cov
try_use_uvloop() try_use_uvloop()

View File

@ -147,6 +147,7 @@ extras_require = {
"dev": dev_require, "dev": dev_require,
"docs": docs_require, "docs": docs_require,
"all": all_require, "all": all_require,
"ext": ["sanic-ext"],
} }
setup_kwargs["install_requires"] = requirements setup_kwargs["install_requires"] = requirements

View File

@ -6,8 +6,10 @@ import string
import sys import sys
import uuid import uuid
from contextlib import suppress
from logging import LogRecord from logging import LogRecord
from typing import Callable, List, Tuple from typing import List, Tuple
from unittest.mock import MagicMock
import pytest import pytest
@ -184,3 +186,21 @@ def message_in_records():
return error_captured return error_captured
return msg_in_log return msg_in_log
@pytest.fixture
def ext_instance():
ext_instance = MagicMock()
ext_instance.injection = MagicMock()
return ext_instance
@pytest.fixture(autouse=True) # type: ignore
def sanic_ext(ext_instance): # noqa
sanic_ext = MagicMock(__version__="1.2.3")
sanic_ext.Extend = MagicMock()
sanic_ext.Extend.return_value = ext_instance
sys.modules["sanic_ext"] = sanic_ext
yield sanic_ext
with suppress(KeyError):
del sys.modules["sanic_ext"]

View File

@ -32,6 +32,12 @@ def starting_line(lines):
return 0 return 0
def read_app_info(lines):
for line in lines:
if line.startswith(b"{") and line.endswith(b"}"):
return json.loads(line)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"appname", "appname",
( (
@ -199,9 +205,7 @@ def test_debug(cmd):
command = ["sanic", "fake.server.app", cmd] command = ["sanic", "fake.server.app", cmd]
out, err, exitcode = capture(command) out, err, exitcode = capture(command)
lines = out.split(b"\n") lines = out.split(b"\n")
info = read_app_info(lines)
app_info = lines[starting_line(lines) + 9]
info = json.loads(app_info)
assert info["debug"] is True assert info["debug"] is True
assert info["auto_reload"] is True assert info["auto_reload"] is True
@ -212,9 +216,7 @@ def test_auto_reload(cmd):
command = ["sanic", "fake.server.app", cmd] command = ["sanic", "fake.server.app", cmd]
out, err, exitcode = capture(command) out, err, exitcode = capture(command)
lines = out.split(b"\n") lines = out.split(b"\n")
info = read_app_info(lines)
app_info = lines[starting_line(lines) + 9]
info = json.loads(app_info)
assert info["debug"] is False assert info["debug"] is False
assert info["auto_reload"] is True assert info["auto_reload"] is True
@ -227,9 +229,7 @@ def test_access_logs(cmd, expected):
command = ["sanic", "fake.server.app", cmd] command = ["sanic", "fake.server.app", cmd]
out, err, exitcode = capture(command) out, err, exitcode = capture(command)
lines = out.split(b"\n") lines = out.split(b"\n")
info = read_app_info(lines)
app_info = lines[starting_line(lines) + 8]
info = json.loads(app_info)
assert info["access_log"] is expected assert info["access_log"] is expected
@ -254,8 +254,6 @@ def test_noisy_exceptions(cmd, expected):
command = ["sanic", "fake.server.app", cmd] command = ["sanic", "fake.server.app", cmd]
out, err, exitcode = capture(command) out, err, exitcode = capture(command)
lines = out.split(b"\n") lines = out.split(b"\n")
info = read_app_info(lines)
app_info = lines[starting_line(lines) + 8]
info = json.loads(app_info)
assert info["noisy_exceptions"] is expected assert info["noisy_exceptions"] is expected

View File

@ -0,0 +1,84 @@
import sys
from unittest.mock import MagicMock
import pytest
from sanic import Sanic
try:
import sanic_ext
SANIC_EXT_IN_ENV = True
except ImportError:
SANIC_EXT_IN_ENV = False
@pytest.fixture
def stoppable_app(app):
@app.before_server_start
async def stop(*_):
app.stop()
return app
def test_ext_is_loaded(stoppable_app: Sanic, sanic_ext):
stoppable_app.run()
sanic_ext.Extend.assert_called_once_with(stoppable_app)
def test_ext_is_not_loaded(stoppable_app: Sanic, sanic_ext):
stoppable_app.config.AUTO_EXTEND = False
stoppable_app.run()
sanic_ext.Extend.assert_not_called()
def test_extend_with_args(stoppable_app: Sanic, sanic_ext):
stoppable_app.extend(built_in_extensions=False)
stoppable_app.run()
sanic_ext.Extend.assert_called_once_with(
stoppable_app, built_in_extensions=False, config=None, extensions=None
)
def test_access_object_sets_up_extension(app: Sanic, sanic_ext):
app.ext
sanic_ext.Extend.assert_called_once_with(app)
def test_extend_cannot_be_called_multiple_times(app: Sanic, sanic_ext):
app.extend()
message = "Cannot extend Sanic after Sanic Extensions has been setup."
with pytest.raises(RuntimeError, match=message):
app.extend()
sanic_ext.Extend.assert_called_once_with(
app, extensions=None, built_in_extensions=True, config=None
)
@pytest.mark.skipif(
SANIC_EXT_IN_ENV,
reason="Running tests with sanic_ext already in the environment",
)
def test_fail_if_not_loaded(app: Sanic):
del sys.modules["sanic_ext"]
with pytest.raises(
RuntimeError, match="Sanic Extensions is not installed.*"
):
app.extend(built_in_extensions=False)
def test_can_access_app_ext_while_running(app: Sanic, sanic_ext, ext_instance):
class IceCream:
flavor: str
@app.before_server_start
async def injections(*_):
app.ext.injection(IceCream)
app.stop()
app.run()
ext_instance.injection.assert_called_with(IceCream)