From dc3c4d13932a3f753bd6eb5da815bbf642a02185 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Wed, 12 Jul 2023 23:47:58 +0300 Subject: [PATCH] Add custom typing to config and ctx (#2785) --- sanic/__init__.py | 19 +++ sanic/app.py | 114 ++++++++++++++++- sanic/blueprints.py | 2 +- sanic/errorpages.py | 2 +- sanic/exceptions.py | 2 +- sanic/headers.py | 2 +- sanic/request/types.py | 41 +++++- sanic/router.py | 2 +- sanic/server/websockets/frame.py | 1 + sanic/server/websockets/impl.py | 2 +- setup.py | 3 +- tests/test_app.py | 2 +- tests/test_request.py | 26 ++++ tests/typing/samples/app_custom_config.py | 10 ++ tests/typing/samples/app_custom_ctx.py | 9 ++ tests/typing/samples/app_default.py | 5 + tests/typing/samples/app_fully_custom.py | 14 ++ tests/typing/samples/request_custom_ctx.py | 17 +++ tests/typing/samples/request_custom_sanic.py | 19 +++ tests/typing/samples/request_fully_custom.py | 34 +++++ tests/typing/test_typing.py | 127 +++++++++++++++++++ 21 files changed, 433 insertions(+), 20 deletions(-) create mode 100644 tests/typing/samples/app_custom_config.py create mode 100644 tests/typing/samples/app_custom_ctx.py create mode 100644 tests/typing/samples/app_default.py create mode 100644 tests/typing/samples/app_fully_custom.py create mode 100644 tests/typing/samples/request_custom_ctx.py create mode 100644 tests/typing/samples/request_custom_sanic.py create mode 100644 tests/typing/samples/request_fully_custom.py create mode 100644 tests/typing/test_typing.py diff --git a/sanic/__init__.py b/sanic/__init__.py index 67394ae6..8438ded4 100644 --- a/sanic/__init__.py +++ b/sanic/__init__.py @@ -1,6 +1,11 @@ +from types import SimpleNamespace + +from typing_extensions import TypeAlias + from sanic.__version__ import __version__ from sanic.app import Sanic from sanic.blueprints import Blueprint +from sanic.config import Config from sanic.constants import HTTPMethod from sanic.exceptions import ( BadRequest, @@ -32,15 +37,29 @@ from sanic.response import ( from sanic.server.websockets.impl import WebsocketImplProtocol as Websocket +DefaultSanic: TypeAlias = "Sanic[Config, SimpleNamespace]" +""" +A type alias for a Sanic app with a default config and namespace. +""" + +DefaultRequest: TypeAlias = Request[DefaultSanic, SimpleNamespace] +""" +A type alias for a request with a default Sanic app and namespace. +""" + __all__ = ( "__version__", # Common objects "Sanic", + "Config", "Blueprint", "HTTPMethod", "HTTPResponse", "Request", "Websocket", + # Common types + "DefaultSanic", + "DefaultRequest", # Common exceptions "BadRequest", "ExpectationFailed", diff --git a/sanic/app.py b/sanic/app.py index 02027878..e966a4fd 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -33,6 +33,7 @@ from typing import ( Coroutine, Deque, Dict, + Generic, Iterable, Iterator, List, @@ -42,6 +43,8 @@ from typing import ( Type, TypeVar, Union, + cast, + overload, ) from urllib.parse import urlencode, urlunparse @@ -103,8 +106,17 @@ if TYPE_CHECKING: if OS_IS_WINDOWS: # no cov enable_windows_color_support() +ctx_type = TypeVar("ctx_type") +config_type = TypeVar("config_type", bound=Config) -class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta): + +class Sanic( + Generic[config_type, ctx_type], + StaticHandleMixin, + BaseSanic, + StartupMixin, + metaclass=TouchUpMeta, +): """ The main application instance """ @@ -162,11 +174,99 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta): _app_registry: ClassVar[Dict[str, "Sanic"]] = {} test_mode: ClassVar[bool] = False + @overload + def __init__( + self: Sanic[Config, SimpleNamespace], + name: str, + config: None = None, + ctx: None = None, + router: Optional[Router] = None, + signal_router: Optional[SignalRouter] = None, + error_handler: Optional[ErrorHandler] = None, + env_prefix: Optional[str] = SANIC_PREFIX, + request_class: Optional[Type[Request]] = None, + strict_slashes: bool = False, + log_config: Optional[Dict[str, Any]] = None, + configure_logging: bool = True, + dumps: Optional[Callable[..., AnyStr]] = None, + loads: Optional[Callable[..., Any]] = None, + inspector: bool = False, + inspector_class: Optional[Type[Inspector]] = None, + certloader_class: Optional[Type[CertLoader]] = None, + ) -> None: + ... + + @overload + def __init__( + self: Sanic[config_type, SimpleNamespace], + name: str, + config: Optional[config_type] = None, + ctx: None = None, + router: Optional[Router] = None, + signal_router: Optional[SignalRouter] = None, + error_handler: Optional[ErrorHandler] = None, + env_prefix: Optional[str] = SANIC_PREFIX, + request_class: Optional[Type[Request]] = None, + strict_slashes: bool = False, + log_config: Optional[Dict[str, Any]] = None, + configure_logging: bool = True, + dumps: Optional[Callable[..., AnyStr]] = None, + loads: Optional[Callable[..., Any]] = None, + inspector: bool = False, + inspector_class: Optional[Type[Inspector]] = None, + certloader_class: Optional[Type[CertLoader]] = None, + ) -> None: + ... + + @overload + def __init__( + self: Sanic[Config, ctx_type], + name: str, + config: None = None, + ctx: Optional[ctx_type] = None, + router: Optional[Router] = None, + signal_router: Optional[SignalRouter] = None, + error_handler: Optional[ErrorHandler] = None, + env_prefix: Optional[str] = SANIC_PREFIX, + request_class: Optional[Type[Request]] = None, + strict_slashes: bool = False, + log_config: Optional[Dict[str, Any]] = None, + configure_logging: bool = True, + dumps: Optional[Callable[..., AnyStr]] = None, + loads: Optional[Callable[..., Any]] = None, + inspector: bool = False, + inspector_class: Optional[Type[Inspector]] = None, + certloader_class: Optional[Type[CertLoader]] = None, + ) -> None: + ... + + @overload + def __init__( + self: Sanic[config_type, ctx_type], + name: str, + config: Optional[config_type] = None, + ctx: Optional[ctx_type] = None, + router: Optional[Router] = None, + signal_router: Optional[SignalRouter] = None, + error_handler: Optional[ErrorHandler] = None, + env_prefix: Optional[str] = SANIC_PREFIX, + request_class: Optional[Type[Request]] = None, + strict_slashes: bool = False, + log_config: Optional[Dict[str, Any]] = None, + configure_logging: bool = True, + dumps: Optional[Callable[..., AnyStr]] = None, + loads: Optional[Callable[..., Any]] = None, + inspector: bool = False, + inspector_class: Optional[Type[Inspector]] = None, + certloader_class: Optional[Type[CertLoader]] = None, + ) -> None: + ... + def __init__( self, - name: Optional[str] = None, - config: Optional[Config] = None, - ctx: Optional[Any] = None, + name: str, + config: Optional[config_type] = None, + ctx: Optional[ctx_type] = None, router: Optional[Router] = None, signal_router: Optional[SignalRouter] = None, error_handler: Optional[ErrorHandler] = None, @@ -194,7 +294,9 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta): ) # First setup config - self.config: Config = config or Config(env_prefix=env_prefix) + self.config: config_type = cast( + config_type, config or Config(env_prefix=env_prefix) + ) if inspector: self.config.INSPECTOR = inspector @@ -218,7 +320,7 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta): certloader_class or CertLoader ) self.configure_logging: bool = configure_logging - self.ctx: Any = ctx or SimpleNamespace() + self.ctx: ctx_type = cast(ctx_type, ctx or SimpleNamespace()) self.error_handler: ErrorHandler = error_handler or ErrorHandler() self.inspector_class: Type[Inspector] = inspector_class or Inspector self.listeners: Dict[str, List[ListenerType[Any]]] = defaultdict(list) diff --git a/sanic/blueprints.py b/sanic/blueprints.py index fc9408fe..7551000b 100644 --- a/sanic/blueprints.py +++ b/sanic/blueprints.py @@ -111,7 +111,7 @@ class Blueprint(BaseSanic): def __init__( self, - name: str = None, + name: str, url_prefix: Optional[str] = None, host: Optional[Union[List[str], str]] = None, version: Optional[Union[int, str, float]] = None, diff --git a/sanic/errorpages.py b/sanic/errorpages.py index 7b243ddb..e7ea165b 100644 --- a/sanic/errorpages.py +++ b/sanic/errorpages.py @@ -312,7 +312,7 @@ def exception_response( debug: bool, fallback: str, base: t.Type[BaseRenderer], - renderer: t.Type[t.Optional[BaseRenderer]] = None, + renderer: t.Optional[t.Type[BaseRenderer]] = None, ) -> HTTPResponse: """ Render a response for the default FALLBACK exception handler. diff --git a/sanic/exceptions.py b/sanic/exceptions.py index 51c0bb7f..cebd77db 100644 --- a/sanic/exceptions.py +++ b/sanic/exceptions.py @@ -90,7 +90,7 @@ class SanicException(Exception): super().__init__(message) - self.status_code = status_code + self.status_code = status_code or self.status_code self.quiet = quiet self.headers = headers diff --git a/sanic/headers.py b/sanic/headers.py index b7754adb..ce0f8c91 100644 --- a/sanic/headers.py +++ b/sanic/headers.py @@ -436,7 +436,7 @@ def format_http1_response(status: int, headers: HeaderBytesIterable) -> bytes: def parse_credentials( header: Optional[str], - prefixes: Union[List, Tuple, Set] = None, + prefixes: Optional[Union[List, Tuple, Set]] = None, ) -> Tuple[Optional[str], Optional[str]]: """Parses any header with the aim to retrieve any credentials from it.""" if not prefixes or not isinstance(prefixes, (list, tuple, set)): diff --git a/sanic/request/types.py b/sanic/request/types.py index 2e075e88..98f6e908 100644 --- a/sanic/request/types.py +++ b/sanic/request/types.py @@ -2,11 +2,13 @@ from __future__ import annotations from contextvars import ContextVar from inspect import isawaitable +from types import SimpleNamespace from typing import ( TYPE_CHECKING, Any, DefaultDict, Dict, + Generic, List, Optional, Tuple, @@ -15,6 +17,7 @@ from typing import ( ) from sanic_routing.route import Route +from typing_extensions import TypeVar from sanic.http.constants import HTTP # type: ignore from sanic.http.stream import Stream @@ -23,13 +26,13 @@ from sanic.models.http_types import Credentials if TYPE_CHECKING: - from sanic.server import ConnInfo from sanic.app import Sanic + from sanic.config import Config + from sanic.server import ConnInfo import uuid from collections import defaultdict -from types import SimpleNamespace from urllib.parse import parse_qs, parse_qsl, urlunparse from httptools import parse_url @@ -68,8 +71,21 @@ try: except ImportError: from json import loads as json_loads # type: ignore +if TYPE_CHECKING: + # The default argument of TypeVar is proposed to be added in Python 3.13 + # by PEP 696 (https://www.python.org/dev/peps/pep-0696/). + # Therefore, we use typing_extensions.TypeVar for compatibility. + # For more information, see: + # https://discuss.python.org/t/pep-696-type-defaults-for-typevarlikes + sanic_type = TypeVar( + "sanic_type", bound=Sanic, default=Sanic[Config, SimpleNamespace] + ) +else: + sanic_type = TypeVar("sanic_type") +ctx_type = TypeVar("ctx_type") -class Request: + +class Request(Generic[sanic_type, ctx_type]): """ Properties of an HTTP request such as URL, headers, etc. """ @@ -80,6 +96,7 @@ class Request: __slots__ = ( "__weakref__", "_cookies", + "_ctx", "_id", "_ip", "_parsed_url", @@ -96,7 +113,6 @@ class Request: "app", "body", "conn_info", - "ctx", "head", "headers", "method", @@ -125,7 +141,7 @@ class Request: version: str, method: str, transport: TransportProtocol, - app: Sanic, + app: sanic_type, head: bytes = b"", stream_id: int = 0, ): @@ -149,7 +165,7 @@ class Request: # Init but do not inhale self.body = b"" self.conn_info: Optional[ConnInfo] = None - self.ctx = SimpleNamespace() + self._ctx: Optional[ctx_type] = None self.parsed_accept: Optional[AcceptList] = None self.parsed_args: DefaultDict[ Tuple[bool, bool, str, str], RequestParameters @@ -176,6 +192,10 @@ class Request: class_name = self.__class__.__name__ return f"<{class_name}: {self.method} {self.path}>" + @staticmethod + def make_context() -> ctx_type: + return cast(ctx_type, SimpleNamespace()) + @classmethod def get_current(cls) -> Request: """ @@ -205,6 +225,15 @@ class Request: def generate_id(*_): return uuid.uuid4() + @property + def ctx(self) -> ctx_type: + """ + :return: The current request context + """ + if not self._ctx: + self._ctx = self.make_context() + return self._ctx + @property def stream_id(self): """ diff --git a/sanic/router.py b/sanic/router.py index 17339cb2..854a3e98 100644 --- a/sanic/router.py +++ b/sanic/router.py @@ -75,7 +75,7 @@ class Router(BaseRouter): strict_slashes: bool = False, stream: bool = False, ignore_body: bool = False, - version: Union[str, float, int] = None, + version: Optional[Union[str, float, int]] = None, name: Optional[str] = None, unquote: bool = False, static: bool = False, diff --git a/sanic/server/websockets/frame.py b/sanic/server/websockets/frame.py index 130dc5fa..dacc187c 100644 --- a/sanic/server/websockets/frame.py +++ b/sanic/server/websockets/frame.py @@ -96,6 +96,7 @@ class WebsocketFrameAssembler: If ``timeout`` is set and elapses before a complete message is received, :meth:`get` returns ``None``. """ + completed: bool async with self.read_mutex: if timeout is not None and timeout <= 0: if not self.message_complete.is_set(): diff --git a/sanic/server/websockets/impl.py b/sanic/server/websockets/impl.py index c76ebc04..71439cfd 100644 --- a/sanic/server/websockets/impl.py +++ b/sanic/server/websockets/impl.py @@ -21,7 +21,7 @@ from websockets.frames import Frame, Opcode try: # websockets < 11.0 - from websockets.connection import Event, State + from websockets.connection import Event, State # type: ignore from websockets.server import ServerConnection as ServerProtocol except ImportError: # websockets >= 11.0 from websockets.protocol import Event, State # type: ignore diff --git a/setup.py b/setup.py index bf172c20..1f321a9f 100644 --- a/setup.py +++ b/setup.py @@ -112,6 +112,7 @@ requirements = [ "multidict>=5.0,<7.0", "html5tagger>=1.2.1", "tracerite>=1.0.0", + "typing-extensions>=4.4.0", ] tests_require = [ @@ -126,7 +127,7 @@ tests_require = [ "black", "isort>=5.0.0", "bandit", - "mypy>=0.901,<0.910", + "mypy", "docutils", "pygments", "uvicorn<0.15.0", diff --git a/tests/test_app.py b/tests/test_app.py index b670ef45..aefec3e6 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -293,7 +293,7 @@ def test_handle_request_with_nested_sanic_exception( def test_app_name_required(): - with pytest.raises(SanicException): + with pytest.raises(TypeError): Sanic() diff --git a/tests/test_request.py b/tests/test_request.py index 0762357d..2a527043 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -310,3 +310,29 @@ def test_request_idempotent(method, idempotent): def test_request_cacheable(method, cacheable): request = Request(b"/", {}, None, method, None, None) assert request.is_cacheable is cacheable + + +def test_custom_ctx(): + class CustomContext: + FOO = "foo" + + class CustomRequest(Request[Sanic, CustomContext]): + @staticmethod + def make_context() -> CustomContext: + return CustomContext() + + app = Sanic("Test", request_class=CustomRequest) + + @app.get("/") + async def handler(request: CustomRequest): + return response.json( + [ + isinstance(request, CustomRequest), + isinstance(request.ctx, CustomContext), + request.ctx.FOO, + ] + ) + + _, resp = app.test_client.get("/") + + assert resp.json == [True, True, "foo"] diff --git a/tests/typing/samples/app_custom_config.py b/tests/typing/samples/app_custom_config.py new file mode 100644 index 00000000..5f10e7a9 --- /dev/null +++ b/tests/typing/samples/app_custom_config.py @@ -0,0 +1,10 @@ +from sanic import Sanic +from sanic.config import Config + + +class CustomConfig(Config): + pass + + +app = Sanic("test", config=CustomConfig()) +reveal_type(app) diff --git a/tests/typing/samples/app_custom_ctx.py b/tests/typing/samples/app_custom_ctx.py new file mode 100644 index 00000000..fb0dc1bb --- /dev/null +++ b/tests/typing/samples/app_custom_ctx.py @@ -0,0 +1,9 @@ +from sanic import Sanic + + +class Foo: + pass + + +app = Sanic("test", ctx=Foo()) +reveal_type(app) diff --git a/tests/typing/samples/app_default.py b/tests/typing/samples/app_default.py new file mode 100644 index 00000000..34524c50 --- /dev/null +++ b/tests/typing/samples/app_default.py @@ -0,0 +1,5 @@ +from sanic import Sanic + + +app = Sanic("test") +reveal_type(app) diff --git a/tests/typing/samples/app_fully_custom.py b/tests/typing/samples/app_fully_custom.py new file mode 100644 index 00000000..197f1e03 --- /dev/null +++ b/tests/typing/samples/app_fully_custom.py @@ -0,0 +1,14 @@ +from sanic import Sanic +from sanic.config import Config + + +class CustomConfig(Config): + pass + + +class Foo: + pass + + +app = Sanic("test", config=CustomConfig(), ctx=Foo()) +reveal_type(app) diff --git a/tests/typing/samples/request_custom_ctx.py b/tests/typing/samples/request_custom_ctx.py new file mode 100644 index 00000000..5fa4fe6a --- /dev/null +++ b/tests/typing/samples/request_custom_ctx.py @@ -0,0 +1,17 @@ +from types import SimpleNamespace + +from sanic import Request, Sanic +from sanic.config import Config + + +class Foo: + pass + + +app = Sanic("test") + + +@app.get("/") +async def handler(request: Request[Sanic[Config, SimpleNamespace], Foo]): + reveal_type(request.ctx) + reveal_type(request.app) diff --git a/tests/typing/samples/request_custom_sanic.py b/tests/typing/samples/request_custom_sanic.py new file mode 100644 index 00000000..bdacd92b --- /dev/null +++ b/tests/typing/samples/request_custom_sanic.py @@ -0,0 +1,19 @@ +from types import SimpleNamespace + +from sanic import Request, Sanic +from sanic.config import Config + + +class CustomConfig(Config): + pass + + +app = Sanic("test", config=CustomConfig()) + + +@app.get("/") +async def handler( + request: Request[Sanic[CustomConfig, SimpleNamespace], SimpleNamespace] +): + reveal_type(request.ctx) + reveal_type(request.app) diff --git a/tests/typing/samples/request_fully_custom.py b/tests/typing/samples/request_fully_custom.py new file mode 100644 index 00000000..e2ec4b31 --- /dev/null +++ b/tests/typing/samples/request_fully_custom.py @@ -0,0 +1,34 @@ +from sanic import Request, Sanic +from sanic.config import Config + + +class CustomConfig(Config): + pass + + +class Foo: + pass + + +class RequestContext: + foo: Foo + + +class CustomRequest(Request[Sanic[CustomConfig, Foo], RequestContext]): + @staticmethod + def make_context() -> RequestContext: + ctx = RequestContext() + ctx.foo = Foo() + return ctx + + +app = Sanic( + "test", config=CustomConfig(), ctx=Foo(), request_class=CustomRequest +) + + +@app.get("/") +async def handler(request: CustomRequest): + reveal_type(request) + reveal_type(request.ctx) + reveal_type(request.app) diff --git a/tests/typing/test_typing.py b/tests/typing/test_typing.py new file mode 100644 index 00000000..5ebba266 --- /dev/null +++ b/tests/typing/test_typing.py @@ -0,0 +1,127 @@ +# flake8: noqa: E501 + +import subprocess +import sys + +from pathlib import Path +from typing import List, Tuple + +import pytest + + +CURRENT_DIR = Path(__file__).parent + + +def run_check(path_location: str) -> str: + """Use mypy to check the given path location and return the output.""" + + mypy_path = "mypy" + path = CURRENT_DIR / path_location + command = [mypy_path, path.resolve().as_posix()] + + process = subprocess.run( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + ) + output = process.stdout + process.stderr + return output + + +@pytest.mark.parametrize( + "path_location,expected", + ( + ( + "app_default.py", + [ + ( + "sanic.app.Sanic[sanic.config.Config, types.SimpleNamespace]", + 5, + ) + ], + ), + ( + "app_custom_config.py", + [ + ( + "sanic.app.Sanic[app_custom_config.CustomConfig, types.SimpleNamespace]", + 10, + ) + ], + ), + ( + "app_custom_ctx.py", + [("sanic.app.Sanic[sanic.config.Config, app_custom_ctx.Foo]", 9)], + ), + ( + "app_fully_custom.py", + [ + ( + "sanic.app.Sanic[app_fully_custom.CustomConfig, app_fully_custom.Foo]", + 14, + ) + ], + ), + ( + "request_custom_sanic.py", + [ + ("types.SimpleNamespace", 18), + ( + "sanic.app.Sanic[request_custom_sanic.CustomConfig, types.SimpleNamespace]", + 19, + ), + ], + ), + ( + "request_custom_ctx.py", + [ + ("request_custom_ctx.Foo", 16), + ( + "sanic.app.Sanic[sanic.config.Config, types.SimpleNamespace]", + 17, + ), + ], + ), + ( + "request_fully_custom.py", + [ + ("request_fully_custom.CustomRequest", 32), + ("request_fully_custom.RequestContext", 33), + ( + "sanic.app.Sanic[request_fully_custom.CustomConfig, request_fully_custom.Foo]", + 34, + ), + ], + ), + ), +) +def test_check_app_default( + path_location: str, expected: List[Tuple[str, int]] +) -> None: + output = run_check(f"samples/{path_location}") + + for text, number in expected: + current = CURRENT_DIR / f"samples/{path_location}" + path = current.relative_to(CURRENT_DIR.parent) + + target = Path.cwd() + while True: + note = _text_from_path(current, path, target, number, text) + try: + assert note in output, output + except AssertionError: + target = target.parent + if not target.exists(): + raise + else: + break + + +def _text_from_path( + base: Path, path: Path, target: Path, number: int, text: str +) -> str: + relative_to_cwd = base.relative_to(target) + prefix = ".".join(relative_to_cwd.parts[:-1]) + text = text.replace(path.stem, f"{prefix}.{path.stem}") + return f'{path}:{number}: note: Revealed type is "{text}"'