From dc3ccba52748dbabd0369fd34334c5619b7ef945 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Sat, 25 Dec 2021 22:20:06 +0200 Subject: [PATCH] Auto extend with Sanic Extensions (#2308) --- sanic/app.py | 110 +++++++++++++++++++++++++++------- sanic/application/ext.py | 39 ++++++++++++ sanic/application/motd.py | 3 - sanic/application/state.py | 9 ++- sanic/blueprint_group.py | 2 +- sanic/blueprints.py | 4 +- sanic/config.py | 2 + sanic/http.py | 2 +- sanic/request.py | 2 +- sanic/server/runners.py | 2 + sanic/views.py | 4 +- sanic/worker.py | 4 +- setup.py | 1 + tests/conftest.py | 22 ++++++- tests/test_cli.py | 22 ++++--- tests/test_ext_integration.py | 84 ++++++++++++++++++++++++++ 16 files changed, 264 insertions(+), 48 deletions(-) create mode 100644 sanic/application/ext.py create mode 100644 tests/test_ext_integration.py diff --git a/sanic/app.py b/sanic/app.py index 9f488359..c393c715 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -28,6 +28,7 @@ from ssl import SSLContext from traceback import format_exc from types import SimpleNamespace from typing import ( + TYPE_CHECKING, Any, AnyStr, Awaitable, @@ -41,6 +42,7 @@ from typing import ( Set, Tuple, Type, + TypeVar, Union, ) from urllib.parse import urlencode, urlunparse @@ -53,6 +55,7 @@ from sanic_routing.exceptions import ( # type: ignore from sanic_routing.route import Route # type: ignore from sanic import reloader_helpers +from sanic.application.ext import setup_ext from sanic.application.logo import get_logo from sanic.application.motd import MOTD from sanic.application.state import ApplicationState, Mode @@ -103,11 +106,21 @@ from sanic.tls import process_to_context from sanic.touchup import TouchUp, TouchUpMeta +if TYPE_CHECKING: # no cov + try: + from sanic_ext import Extend # type: ignore + from sanic_ext.extensions.base import Extension # type: ignore + except ImportError: + Extend = TypeVar("Extend") # type: ignore + + if OS_IS_WINDOWS: enable_windows_color_support() filterwarnings("once", category=DeprecationWarning) +SANIC_PACKAGES = ("sanic-routing", "sanic-testing", "sanic-ext") + class Sanic(BaseSanic, metaclass=TouchUpMeta): """ @@ -125,6 +138,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): "_asgi_client", "_blueprint_order", "_delayed_tasks", + "_ext", "_future_exceptions", "_future_listeners", "_future_middleware", @@ -1421,26 +1435,15 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): "#proxy-configuration" ) + ssl = process_to_context(ssl) + self.debug = debug self.state.host = host self.state.port = port self.state.workers = workers - - # Serve - serve_location = "" - proto = "http" - if ssl is not None: - proto = "https" - if unix: - serve_location = f"{unix} {proto}://..." - elif sock: - serve_location = f"{sock.getsockname()} {proto}://..." - elif host and port: - # colon(:) is legal for a host only in an ipv6 address - display_host = f"[{host}]" if ":" in host else host - serve_location = f"{proto}://{display_host}:{port}" - - ssl = process_to_context(ssl) + self.state.ssl = ssl + self.state.unix = unix + self.state.sock = sock server_settings = { "protocol": protocol, @@ -1456,7 +1459,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): "backlog": backlog, } - self.motd(serve_location) + self.motd(self.serve_location) if sys.stdout.isatty() and not self.state.is_debug: error_logger.warning( @@ -1482,6 +1485,27 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): return server_settings + @property + def serve_location(self) -> str: + serve_location = "" + proto = "http" + if self.state.ssl is not None: + proto = "https" + if self.state.unix: + serve_location = f"{self.state.unix} {proto}://..." + elif self.state.sock: + serve_location = f"{self.state.sock.getsockname()} {proto}://..." + elif self.state.host and self.state.port: + # colon(:) is legal for a host only in an ipv6 address + display_host = ( + f"[{self.state.host}]" + if ":" in self.state.host + else self.state.host + ) + serve_location = f"{proto}://{display_host}:{self.state.port}" + + return serve_location + def _build_endpoint_name(self, *parts): parts = [self.name, *parts] return ".".join(parts) @@ -1790,11 +1814,8 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): display["auto-reload"] = reload_display packages = [] - for package_name, module_name in { - "sanic-routing": "sanic_routing", - "sanic-testing": "sanic_testing", - "sanic-ext": "sanic_ext", - }.items(): + for package_name in SANIC_PACKAGES: + module_name = package_name.replace("-", "_") try: module = import_module(module_name) packages.append(f"{package_name}=={module.__version__}") @@ -1814,6 +1835,41 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): ) MOTD.output(logo, serve_location, display, extra) + @property + def ext(self) -> Extend: + if not hasattr(self, "_ext"): + setup_ext(self, fail=True) + + if not hasattr(self, "_ext"): + raise RuntimeError( + "Sanic Extensions is not installed. You can add it to your " + "environment using:\n$ pip install sanic[ext]\nor\n$ pip " + "install sanic-ext" + ) + return self._ext # type: ignore + + def extend( + self, + *, + extensions: Optional[List[Type[Extension]]] = None, + built_in_extensions: bool = True, + config: Optional[Union[Config, Dict[str, Any]]] = None, + **kwargs, + ) -> Extend: + if hasattr(self, "_ext"): + raise RuntimeError( + "Cannot extend Sanic after Sanic Extensions has been setup." + ) + setup_ext( + self, + extensions=extensions, + built_in_extensions=built_in_extensions, + config=config, + fail=True, + **kwargs, + ) + return self.ext + # -------------------------------------------------------------------- # # Class methods # -------------------------------------------------------------------- # @@ -1875,6 +1931,14 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): async def _startup(self): self._future_registry.clear() + + # Startup Sanic Extensions + if not hasattr(self, "_ext"): + setup_ext(self) + if hasattr(self, "_ext"): + self.ext._display() + + # Setup routers self.signalize() self.finalize() @@ -1890,8 +1954,10 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): ) self.__class__._uvloop_setting = self.config.USE_UVLOOP + # Startup time optimizations ErrorHandler.finalize(self.error_handler, config=self.config) TouchUp.run(self) + self.state.is_started = True async def _server_event( diff --git a/sanic/application/ext.py b/sanic/application/ext.py new file mode 100644 index 00000000..deb7c5d4 --- /dev/null +++ b/sanic/application/ext.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from contextlib import suppress +from importlib import import_module +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: # no cov + from sanic import Sanic + + try: + from sanic_ext import Extend # type: ignore + except ImportError: + ... + + +def setup_ext(app: Sanic, *, fail: bool = False, **kwargs): + if not app.config.AUTO_EXTEND: + return + + sanic_ext = None + with suppress(ModuleNotFoundError): + sanic_ext = import_module("sanic_ext") + + if not sanic_ext: + if fail: + raise RuntimeError( + "Sanic Extensions is not installed. You can add it to your " + "environment using:\n$ pip install sanic[ext]\nor\n$ pip " + "install sanic-ext" + ) + + return + + if not getattr(app, "_ext", None): + Ext: Extend = getattr(sanic_ext, "Extend") + app._ext = Ext(app, **kwargs) + + return app.ext diff --git a/sanic/application/motd.py b/sanic/application/motd.py index 32825b12..4de046a5 100644 --- a/sanic/application/motd.py +++ b/sanic/application/motd.py @@ -41,9 +41,6 @@ class MOTD(ABC): class MOTDBasic(MOTD): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - def display(self): if self.logo: logger.debug(self.logo) diff --git a/sanic/application/state.py b/sanic/application/state.py index f5ff4fe4..0345dad5 100644 --- a/sanic/application/state.py +++ b/sanic/application/state.py @@ -5,7 +5,9 @@ import logging from dataclasses import dataclass, field from enum import Enum, auto from pathlib import Path -from typing import TYPE_CHECKING, Any, Set, Union +from socket import socket +from ssl import SSLContext +from typing import TYPE_CHECKING, Any, Optional, Set, Union from sanic.log import logger @@ -37,8 +39,11 @@ class ApplicationState: coffee: bool = field(default=False) fast: bool = field(default=False) host: str = field(default="") - mode: Mode = field(default=Mode.PRODUCTION) port: int = field(default=0) + ssl: Optional[SSLContext] = field(default=None) + sock: Optional[socket] = field(default=None) + unix: Optional[str] = field(default=None) + mode: Mode = field(default=Mode.PRODUCTION) reload_dirs: Set[Path] = field(default_factory=set) server: Server = field(default=Server.SANIC) is_running: bool = field(default=False) diff --git a/sanic/blueprint_group.py b/sanic/blueprint_group.py index a9b51410..b16d8c58 100644 --- a/sanic/blueprint_group.py +++ b/sanic/blueprint_group.py @@ -5,7 +5,7 @@ from functools import partial from typing import TYPE_CHECKING, List, Optional, Union -if TYPE_CHECKING: +if TYPE_CHECKING: # no cov from sanic.blueprints import Blueprint diff --git a/sanic/blueprints.py b/sanic/blueprints.py index 521f77f6..df4501dd 100644 --- a/sanic/blueprints.py +++ b/sanic/blueprints.py @@ -36,8 +36,8 @@ from sanic.models.handler_types import ( ) -if TYPE_CHECKING: - from sanic import Sanic # noqa +if TYPE_CHECKING: # no cov + from sanic import Sanic def lazy(func, as_decorator=True): diff --git a/sanic/config.py b/sanic/config.py index d0b3c620..133efcdc 100644 --- a/sanic/config.py +++ b/sanic/config.py @@ -18,6 +18,7 @@ SANIC_PREFIX = "SANIC_" DEFAULT_CONFIG = { "_FALLBACK_ERROR_FORMAT": _default, "ACCESS_LOG": True, + "AUTO_EXTEND": True, "AUTO_RELOAD": False, "EVENT_AUTOREGISTER": False, "FORWARDED_FOR_HEADER": "X-Forwarded-For", @@ -59,6 +60,7 @@ class DescriptorMeta(type): class Config(dict, metaclass=DescriptorMeta): ACCESS_LOG: bool + AUTO_EXTEND: bool AUTO_RELOAD: bool EVENT_AUTOREGISTER: bool FORWARDED_FOR_HEADER: str diff --git a/sanic/http.py b/sanic/http.py index 86f23fe3..c7523f33 100644 --- a/sanic/http.py +++ b/sanic/http.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Optional -if TYPE_CHECKING: +if TYPE_CHECKING: # no cov from sanic.request import Request from sanic.response import BaseHTTPResponse diff --git a/sanic/request.py b/sanic/request.py index ddec6e82..97ab9982 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -15,7 +15,7 @@ from typing import ( from sanic_routing.route import Route # type: ignore -if TYPE_CHECKING: +if TYPE_CHECKING: # no cov from sanic.server import ConnInfo from sanic.app import Sanic diff --git a/sanic/server/runners.py b/sanic/server/runners.py index 21c7cdaa..db33057b 100644 --- a/sanic/server/runners.py +++ b/sanic/server/runners.py @@ -21,6 +21,7 @@ from functools import partial from signal import SIG_IGN, SIGINT, SIGTERM, Signals from signal import signal as signal_func +from sanic.application.ext import setup_ext from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows from sanic.log import error_logger, logger from sanic.models.server_types import Signal @@ -116,6 +117,7 @@ def serve( **asyncio_server_kwargs, ) + setup_ext(app) if run_async: return AsyncioServer( app=app, diff --git a/sanic/views.py b/sanic/views.py index da44f8b7..23cd110d 100644 --- a/sanic/views.py +++ b/sanic/views.py @@ -13,7 +13,7 @@ from typing import ( from sanic.models.handler_types import RouteHandler -if TYPE_CHECKING: +if TYPE_CHECKING: # no cov from sanic import Sanic from sanic.blueprints import Blueprint @@ -81,6 +81,8 @@ class HTTPMethodView: def dispatch_request(self, request, *args, **kwargs): handler = getattr(self, request.method.lower(), None) + if not handler and request.method == "HEAD": + handler = self.get return handler(request, *args, **kwargs) @classmethod diff --git a/sanic/worker.py b/sanic/worker.py index cdc238e2..befe8d78 100644 --- a/sanic/worker.py +++ b/sanic/worker.py @@ -15,10 +15,10 @@ from sanic.server.protocols.websocket_protocol import WebSocketProtocol try: import ssl # type: ignore -except ImportError: +except ImportError: # no cov ssl = None # type: ignore -if UVLOOP_INSTALLED: +if UVLOOP_INSTALLED: # no cov try_use_uvloop() diff --git a/setup.py b/setup.py index 36de0c4f..ea6f285a 100644 --- a/setup.py +++ b/setup.py @@ -147,6 +147,7 @@ extras_require = { "dev": dev_require, "docs": docs_require, "all": all_require, + "ext": ["sanic-ext"], } setup_kwargs["install_requires"] = requirements diff --git a/tests/conftest.py b/tests/conftest.py index 292914cd..22decde5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,8 +6,10 @@ import string import sys import uuid +from contextlib import suppress from logging import LogRecord -from typing import Callable, List, Tuple +from typing import List, Tuple +from unittest.mock import MagicMock import pytest @@ -184,3 +186,21 @@ def message_in_records(): return error_captured return msg_in_log + + +@pytest.fixture +def ext_instance(): + ext_instance = MagicMock() + ext_instance.injection = MagicMock() + return ext_instance + + +@pytest.fixture(autouse=True) # type: ignore +def sanic_ext(ext_instance): # noqa + sanic_ext = MagicMock(__version__="1.2.3") + sanic_ext.Extend = MagicMock() + sanic_ext.Extend.return_value = ext_instance + sys.modules["sanic_ext"] = sanic_ext + yield sanic_ext + with suppress(KeyError): + del sys.modules["sanic_ext"] diff --git a/tests/test_cli.py b/tests/test_cli.py index 86daa36f..254c91d4 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -32,6 +32,12 @@ def starting_line(lines): return 0 +def read_app_info(lines): + for line in lines: + if line.startswith(b"{") and line.endswith(b"}"): + return json.loads(line) + + @pytest.mark.parametrize( "appname", ( @@ -199,9 +205,7 @@ def test_debug(cmd): command = ["sanic", "fake.server.app", cmd] out, err, exitcode = capture(command) lines = out.split(b"\n") - - app_info = lines[starting_line(lines) + 9] - info = json.loads(app_info) + info = read_app_info(lines) assert info["debug"] is True assert info["auto_reload"] is True @@ -212,9 +216,7 @@ def test_auto_reload(cmd): command = ["sanic", "fake.server.app", cmd] out, err, exitcode = capture(command) lines = out.split(b"\n") - - app_info = lines[starting_line(lines) + 9] - info = json.loads(app_info) + info = read_app_info(lines) assert info["debug"] is False assert info["auto_reload"] is True @@ -227,9 +229,7 @@ def test_access_logs(cmd, expected): command = ["sanic", "fake.server.app", cmd] out, err, exitcode = capture(command) lines = out.split(b"\n") - - app_info = lines[starting_line(lines) + 8] - info = json.loads(app_info) + info = read_app_info(lines) assert info["access_log"] is expected @@ -254,8 +254,6 @@ def test_noisy_exceptions(cmd, expected): command = ["sanic", "fake.server.app", cmd] out, err, exitcode = capture(command) lines = out.split(b"\n") - - app_info = lines[starting_line(lines) + 8] - info = json.loads(app_info) + info = read_app_info(lines) assert info["noisy_exceptions"] is expected diff --git a/tests/test_ext_integration.py b/tests/test_ext_integration.py new file mode 100644 index 00000000..ec311a02 --- /dev/null +++ b/tests/test_ext_integration.py @@ -0,0 +1,84 @@ +import sys + +from unittest.mock import MagicMock + +import pytest + +from sanic import Sanic + + +try: + import sanic_ext + + SANIC_EXT_IN_ENV = True +except ImportError: + SANIC_EXT_IN_ENV = False + + +@pytest.fixture +def stoppable_app(app): + @app.before_server_start + async def stop(*_): + app.stop() + + return app + + +def test_ext_is_loaded(stoppable_app: Sanic, sanic_ext): + stoppable_app.run() + sanic_ext.Extend.assert_called_once_with(stoppable_app) + + +def test_ext_is_not_loaded(stoppable_app: Sanic, sanic_ext): + stoppable_app.config.AUTO_EXTEND = False + stoppable_app.run() + sanic_ext.Extend.assert_not_called() + + +def test_extend_with_args(stoppable_app: Sanic, sanic_ext): + stoppable_app.extend(built_in_extensions=False) + stoppable_app.run() + sanic_ext.Extend.assert_called_once_with( + stoppable_app, built_in_extensions=False, config=None, extensions=None + ) + + +def test_access_object_sets_up_extension(app: Sanic, sanic_ext): + app.ext + sanic_ext.Extend.assert_called_once_with(app) + + +def test_extend_cannot_be_called_multiple_times(app: Sanic, sanic_ext): + app.extend() + + message = "Cannot extend Sanic after Sanic Extensions has been setup." + with pytest.raises(RuntimeError, match=message): + app.extend() + sanic_ext.Extend.assert_called_once_with( + app, extensions=None, built_in_extensions=True, config=None + ) + + +@pytest.mark.skipif( + SANIC_EXT_IN_ENV, + reason="Running tests with sanic_ext already in the environment", +) +def test_fail_if_not_loaded(app: Sanic): + del sys.modules["sanic_ext"] + with pytest.raises( + RuntimeError, match="Sanic Extensions is not installed.*" + ): + app.extend(built_in_extensions=False) + + +def test_can_access_app_ext_while_running(app: Sanic, sanic_ext, ext_instance): + class IceCream: + flavor: str + + @app.before_server_start + async def injections(*_): + app.ext.injection(IceCream) + app.stop() + + app.run() + ext_instance.injection.assert_called_with(IceCream)