diff --git a/.codeclimate.yml b/.codeclimate.yml deleted file mode 100644 index 13a5783d..00000000 --- a/.codeclimate.yml +++ /dev/null @@ -1,28 +0,0 @@ -exclude_patterns: - - "sanic/__main__.py" - - "sanic/application/logo.py" - - "sanic/application/motd.py" - - "sanic/reloader_helpers.py" - - "sanic/simple.py" - - "sanic/utils.py" - - ".github/" - - "changelogs/" - - "docker/" - - "docs/" - - "examples/" - - "scripts/" - - "tests/" -checks: - argument-count: - enabled: false - file-lines: - config: - threshold: 1000 - method-count: - config: - threshold: 40 - complex-logic: - enabled: false - method-complexity: - config: - threshold: 10 diff --git a/.coveragerc b/.coveragerc index 63bec82c..22856065 100644 --- a/.coveragerc +++ b/.coveragerc @@ -3,13 +3,12 @@ branch = True source = sanic omit = site-packages - sanic/application/logo.py - sanic/application/motd.py - sanic/cli sanic/__main__.py + sanic/compat.py sanic/reloader_helpers.py sanic/simple.py sanic/utils.py + sanic/cli [html] directory = coverage @@ -21,3 +20,12 @@ exclude_lines = noqa NOQA pragma: no cover +omit = + site-packages + sanic/__main__.py + sanic/compat.py + sanic/reloader_helpers.py + sanic/simple.py + sanic/utils.py + sanic/cli +skip_empty = True diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 5108c247..1113fa80 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -2,9 +2,13 @@ name: "CodeQL" on: push: - branches: [ main ] + branches: + - main + - "*LTS" pull_request: - branches: [ main ] + branches: + - main + - "*LTS" types: [opened, synchronize, reopened, ready_for_review] schedule: - cron: '25 16 * * 0' diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 92d93aa7..13e82615 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -3,13 +3,15 @@ on: push: branches: - main + - "*LTS" tags: - "!*" # Do not execute on tags pull_request: - types: [opened, synchronize, reopened, ready_for_review] + branches: + - main + - "*LTS" jobs: test: - if: github.event.pull_request.draft == false runs-on: ${{ matrix.os }} strategy: matrix: @@ -19,7 +21,6 @@ jobs: steps: - uses: actions/checkout@v2 - - uses: actions/setup-python@v1 with: python-version: ${{ matrix.python-version }} @@ -28,9 +29,10 @@ jobs: run: | python -m pip install --upgrade pip pip install tox - - uses: paambaati/codeclimate-action@v2.5.3 - if: always() - env: - CC_TEST_REPORTER_ID: ${{ secrets.CODECLIMATE }} + - name: Run coverage + run: tox -e coverage + continue-on-error: true + - uses: codecov/codecov-action@v2 with: - coverageCommand: tox -e coverage + files: ./coverage.xml + fail_ci_if_error: false diff --git a/.github/workflows/pr-bandit.yml b/.github/workflows/pr-bandit.yml index ca91312a..2bd70204 100644 --- a/.github/workflows/pr-bandit.yml +++ b/.github/workflows/pr-bandit.yml @@ -3,6 +3,7 @@ on: pull_request: branches: - main + - "*LTS" types: [opened, synchronize, reopened, ready_for_review] jobs: diff --git a/.github/workflows/pr-docs.yml b/.github/workflows/pr-docs.yml index 7b3c2f6e..8479aef5 100644 --- a/.github/workflows/pr-docs.yml +++ b/.github/workflows/pr-docs.yml @@ -3,6 +3,7 @@ on: pull_request: branches: - main + - "*LTS" types: [opened, synchronize, reopened, ready_for_review] jobs: diff --git a/.github/workflows/pr-linter.yml b/.github/workflows/pr-linter.yml index 9ed45d0a..11ad9d29 100644 --- a/.github/workflows/pr-linter.yml +++ b/.github/workflows/pr-linter.yml @@ -3,6 +3,7 @@ on: pull_request: branches: - main + - "*LTS" types: [opened, synchronize, reopened, ready_for_review] jobs: diff --git a/.github/workflows/pr-python310.yml b/.github/workflows/pr-python310.yml index f3f7c607..5e66deec 100644 --- a/.github/workflows/pr-python310.yml +++ b/.github/workflows/pr-python310.yml @@ -3,6 +3,7 @@ on: pull_request: branches: - main + - "*LTS" types: [opened, synchronize, reopened, ready_for_review] jobs: diff --git a/.github/workflows/pr-python37.yml b/.github/workflows/pr-python37.yml index 50f79c6e..c0051d33 100644 --- a/.github/workflows/pr-python37.yml +++ b/.github/workflows/pr-python37.yml @@ -3,6 +3,7 @@ on: pull_request: branches: - main + - "*LTS" types: [opened, synchronize, reopened, ready_for_review] jobs: diff --git a/.github/workflows/pr-python38.yml b/.github/workflows/pr-python38.yml index 1e0b8050..09e93f3f 100644 --- a/.github/workflows/pr-python38.yml +++ b/.github/workflows/pr-python38.yml @@ -3,6 +3,7 @@ on: pull_request: branches: - main + - "*LTS" types: [opened, synchronize, reopened, ready_for_review] jobs: diff --git a/.github/workflows/pr-python39.yml b/.github/workflows/pr-python39.yml index 1abd6bcb..ff479459 100644 --- a/.github/workflows/pr-python39.yml +++ b/.github/workflows/pr-python39.yml @@ -3,6 +3,7 @@ on: pull_request: branches: - main + - "*LTS" types: [opened, synchronize, reopened, ready_for_review] jobs: diff --git a/.github/workflows/pr-type-check.yml b/.github/workflows/pr-type-check.yml index 2fae03be..58a90ee3 100644 --- a/.github/workflows/pr-type-check.yml +++ b/.github/workflows/pr-type-check.yml @@ -3,6 +3,7 @@ on: pull_request: branches: - main + - "*LTS" types: [opened, synchronize, reopened, ready_for_review] jobs: @@ -15,7 +16,7 @@ jobs: matrix: os: [ubuntu-latest] config: - - { python-version: 3.7, tox-env: type-checking} + # - { python-version: 3.7, tox-env: type-checking} - { python-version: 3.8, tox-env: type-checking} - { python-version: 3.9, tox-env: type-checking} - { python-version: "3.10", tox-env: type-checking} diff --git a/.github/workflows/pr-windows.yml b/.github/workflows/pr-windows.yml index 9721b5b5..050ee2bb 100644 --- a/.github/workflows/pr-windows.yml +++ b/.github/workflows/pr-windows.yml @@ -3,6 +3,7 @@ on: pull_request: branches: - main + - "*LTS" types: [opened, synchronize, reopened, ready_for_review] jobs: diff --git a/README.rst b/README.rst index 6b6d0408..316c323d 100644 --- a/README.rst +++ b/README.rst @@ -66,7 +66,7 @@ Sanic | Build fast. Run fast. Sanic is a **Python 3.7+** web server and web framework that's written to go fast. It allows the usage of the ``async/await`` syntax added in Python 3.5, which makes your code non-blocking and speedy. -Sanic is also ASGI compliant, so you can deploy it with an `alternative ASGI webserver `_. +Sanic is also ASGI compliant, so you can deploy it with an `alternative ASGI webserver `_. `Source code on GitHub `_ | `Help and discussion board `_ | `User Guide `_ | `Chat on Discord `_ diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 00000000..5905afd5 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,27 @@ +coverage: + status: + patch: + default: + target: auto + threshold: 0.75 + project: + default: + target: auto + threshold: 0.5 + precision: 3 +codecov: + require_ci_to_pass: false +ignore: + - "sanic/__main__.py" + - "sanic/compat.py" + - "sanic/reloader_helpers.py" + - "sanic/simple.py" + - "sanic/utils.py" + - "sanic/cli" + - ".github/" + - "changelogs/" + - "docker/" + - "docs/" + - "examples/" + - "scripts/" + - "tests/" diff --git a/docs/sanic/releases/21/21.12.md b/docs/sanic/releases/21/21.12.md index 6c4dc419..f8f0d954 100644 --- a/docs/sanic/releases/21/21.12.md +++ b/docs/sanic/releases/21/21.12.md @@ -1,3 +1,9 @@ +## Version 21.12.1 + +- [#2349](https://github.com/sanic-org/sanic/pull/2349) Only display MOTD on startup +- [#2354](https://github.com/sanic-org/sanic/pull/2354) Ignore name argument in Python 3.7 +- [#2355](https://github.com/sanic-org/sanic/pull/2355) Add config.update support for all config values + ## Version 21.12.0 ### Features diff --git a/sanic/app.py b/sanic/app.py index 55c4920e..5f3a0438 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -3,28 +3,24 @@ from __future__ import annotations import asyncio import logging import logging.config -import os -import platform import re import sys from asyncio import ( AbstractEventLoop, CancelledError, - Protocol, Task, ensure_future, get_event_loop, + get_running_loop, wait_for, ) from asyncio.futures import Future from collections import defaultdict, deque +from contextlib import suppress from functools import partial -from importlib import import_module from inspect import isawaitable -from pathlib import Path from socket import socket -from ssl import SSLContext from traceback import format_exc from types import SimpleNamespace from typing import ( @@ -38,7 +34,6 @@ from typing import ( Dict, Iterable, List, - Literal, Optional, Set, Tuple, @@ -55,11 +50,8 @@ from sanic_routing.exceptions import ( # type: ignore ) from sanic_routing.route import Route # type: ignore -from sanic import reloader_helpers from sanic.application.ext import setup_ext -from sanic.application.logo import get_logo -from sanic.application.motd import MOTD -from sanic.application.state import ApplicationState, Mode +from sanic.application.state import ApplicationState, Mode, ServerStage from sanic.asgi import ASGIApp from sanic.base.root import BaseSanic from sanic.blueprint_group import BlueprintGroup @@ -73,18 +65,15 @@ from sanic.exceptions import ( URLBuildError, ) from sanic.handlers import ErrorHandler -from sanic.helpers import _default from sanic.http import Stage -from sanic.http.constants import HTTP -from sanic.http.tls import process_to_context from sanic.log import ( LOGGING_CONFIG_DEFAULTS, - Colors, deprecation, error_logger, logger, ) from sanic.mixins.listeners import ListenerEvent +from sanic.mixins.runner import RunnerMixin from sanic.models.futures import ( FutureException, FutureListener, @@ -99,10 +88,6 @@ from sanic.models.handler_types import Sanic as SanicVar from sanic.request import Request from sanic.response import BaseHTTPResponse, HTTPResponse, ResponseStream from sanic.router import Router -from sanic.server import AsyncioServer, HttpProtocol -from sanic.server import Signal as ServerSignal -from sanic.server import serve, serve_multiple, serve_single, try_use_uvloop -from sanic.server.protocols.websocket_protocol import WebSocketProtocol from sanic.server.websockets.impl import ConnectionClosed from sanic.signals import Signal, SignalRouter from sanic.touchup import TouchUp, TouchUpMeta @@ -116,15 +101,13 @@ if TYPE_CHECKING: # no cov Extend = TypeVar("Extend") # type: ignore -if OS_IS_WINDOWS: +if OS_IS_WINDOWS: # no cov enable_windows_color_support() filterwarnings("once", category=DeprecationWarning) -SANIC_PACKAGES = ("sanic-routing", "sanic-testing", "sanic-ext") - -class Sanic(BaseSanic, metaclass=TouchUpMeta): +class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta): """ The main application instance """ @@ -223,7 +206,6 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): self.blueprints: Dict[str, Blueprint] = {} self.configure_logging: bool = configure_logging self.ctx: Any = ctx or SimpleNamespace() - self.debug = False self.error_handler: ErrorHandler = error_handler or ErrorHandler() self.listeners: Dict[str, List[ListenerType[Any]]] = defaultdict(list) self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {} @@ -267,7 +249,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): Only supported when using the `app.run` method. """ - if not self.is_running and self.asgi is False: + if self.state.stage is ServerStage.STOPPED and self.asgi is False: raise SanicException( "Loop can only be retrieved after the app has started " "running. Not supported with `create_server` function" @@ -1054,286 +1036,6 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): # Execution # -------------------------------------------------------------------- # - def make_coffee(self, *args, **kwargs): - self.state.coffee = True - self.run(*args, **kwargs) - - def run( - self, - host: Optional[str] = None, - port: Optional[int] = None, - *, - debug: bool = False, - auto_reload: Optional[bool] = None, - ssl: Union[None, SSLContext, dict, str, list, tuple] = None, - sock: Optional[socket] = None, - workers: int = 1, - protocol: Optional[Type[Protocol]] = None, - backlog: int = 100, - register_sys_signals: bool = True, - access_log: Optional[bool] = None, - unix: Optional[str] = None, - loop: AbstractEventLoop = None, - reload_dir: Optional[Union[List[str], str]] = None, - noisy_exceptions: Optional[bool] = None, - motd: bool = True, - fast: bool = False, - verbosity: int = 0, - motd_display: Optional[Dict[str, str]] = None, - version: HTTP = HTTP.VERSION_1, - ) -> None: - """ - Run the HTTP Server and listen until keyboard interrupt or term - signal. On termination, drain connections before closing. - - :param host: Address to host on - :type host: str - :param port: Port to host on - :type port: int - :param debug: Enables debug output (slows server) - :type debug: bool - :param auto_reload: Reload app whenever its source code is changed. - Enabled by default in debug mode. - :type auto_relaod: bool - :param ssl: SSLContext, or location of certificate and key - for SSL encryption of worker(s) - :type ssl: str, dict, SSLContext or list - :param sock: Socket for the server to accept connections from - :type sock: socket - :param workers: Number of processes received before it is respected - :type workers: int - :param protocol: Subclass of asyncio Protocol class - :type protocol: type[Protocol] - :param backlog: a number of unaccepted connections that the system - will allow before refusing new connections - :type backlog: int - :param register_sys_signals: Register SIG* events - :type register_sys_signals: bool - :param access_log: Enables writing access logs (slows server) - :type access_log: bool - :param unix: Unix socket to listen on instead of TCP port - :type unix: str - :param noisy_exceptions: Log exceptions that are normally considered - to be quiet/silent - :type noisy_exceptions: bool - :return: Nothing - """ - self.state.verbosity = verbosity - - if fast and workers != 1: - raise RuntimeError("You cannot use both fast=True and workers=X") - - if motd_display: - self.config.MOTD_DISPLAY.update(motd_display) - - if reload_dir: - if isinstance(reload_dir, str): - reload_dir = [reload_dir] - - for directory in reload_dir: - direc = Path(directory) - if not direc.is_dir(): - logger.warning( - f"Directory {directory} could not be located" - ) - self.state.reload_dirs.add(Path(directory)) - - if loop is not None: - raise TypeError( - "loop is not a valid argument. To use an existing loop, " - "change to create_server().\nSee more: " - "https://sanic.readthedocs.io/en/latest/sanic/deploying.html" - "#asynchronous-support" - ) - - if auto_reload or auto_reload is None and debug: - auto_reload = True - if os.environ.get("SANIC_SERVER_RUNNING") != "true": - return reloader_helpers.watchdog(1.0, self) - - if sock is None: - host, port = host or "127.0.0.1", port or 8000 - - if protocol is None: - protocol = ( - WebSocketProtocol if self.websocket_enabled else HttpProtocol - ) - - # Set explicitly passed configuration values - for attribute, value in { - "ACCESS_LOG": access_log, - "AUTO_RELOAD": auto_reload, - "MOTD": motd, - "NOISY_EXCEPTIONS": noisy_exceptions, - }.items(): - if value is not None: - setattr(self.config, attribute, value) - - if fast: - self.state.fast = True - try: - workers = len(os.sched_getaffinity(0)) - except AttributeError: - workers = os.cpu_count() or 1 - - server_settings = self._helper( - host=host, - port=port, - debug=debug, - ssl=ssl, - sock=sock, - unix=unix, - workers=workers, - protocol=protocol, - backlog=backlog, - register_sys_signals=register_sys_signals, - version=version, - ) - - if self.config.USE_UVLOOP is True or ( - self.config.USE_UVLOOP is _default and not OS_IS_WINDOWS - ): - try_use_uvloop() - - try: - self.is_running = True - self.is_stopping = False - if workers > 1 and os.name != "posix": - logger.warn( - f"Multiprocessing is currently not supported on {os.name}," - " using workers=1 instead" - ) - workers = 1 - if workers == 1: - serve_single(server_settings) - else: - serve_multiple(server_settings, workers) - except BaseException: - error_logger.exception( - "Experienced exception while trying to serve" - ) - raise - finally: - self.is_running = False - logger.info("Server Stopped") - print("END OF RUN") - - def stop(self): - """ - This kills the Sanic - """ - if not self.is_stopping: - self.shutdown_tasks(timeout=0) - self.is_stopping = True - get_event_loop().stop() - - async def create_server( - self, - host: Optional[str] = None, - port: Optional[int] = None, - *, - debug: bool = False, - ssl: Union[None, SSLContext, dict, str, list, tuple] = None, - sock: Optional[socket] = None, - protocol: Type[Protocol] = None, - backlog: int = 100, - access_log: Optional[bool] = None, - unix: Optional[str] = None, - return_asyncio_server: bool = False, - asyncio_server_kwargs: Dict[str, Any] = None, - noisy_exceptions: Optional[bool] = None, - ) -> Optional[AsyncioServer]: - """ - Asynchronous version of :func:`run`. - - This method will take care of the operations necessary to invoke - the *before_start* events via :func:`trigger_events` method invocation - before starting the *sanic* app in Async mode. - - .. note:: - This does not support multiprocessing and is not the preferred - way to run a :class:`Sanic` application. - - :param host: Address to host on - :type host: str - :param port: Port to host on - :type port: int - :param debug: Enables debug output (slows server) - :type debug: bool - :param ssl: SSLContext, or location of certificate and key - for SSL encryption of worker(s) - :type ssl: SSLContext or dict - :param sock: Socket for the server to accept connections from - :type sock: socket - :param protocol: Subclass of asyncio Protocol class - :type protocol: type[Protocol] - :param backlog: a number of unaccepted connections that the system - will allow before refusing new connections - :type backlog: int - :param access_log: Enables writing access logs (slows server) - :type access_log: bool - :param return_asyncio_server: flag that defines whether there's a need - to return asyncio.Server or - start it serving right away - :type return_asyncio_server: bool - :param asyncio_server_kwargs: key-value arguments for - asyncio/uvloop create_server method - :type asyncio_server_kwargs: dict - :param noisy_exceptions: Log exceptions that are normally considered - to be quiet/silent - :type noisy_exceptions: bool - :return: AsyncioServer if return_asyncio_server is true, else Nothing - """ - - if sock is None: - host, port = host or "127.0.0.1", port or 8000 - - if protocol is None: - protocol = ( - WebSocketProtocol if self.websocket_enabled else HttpProtocol - ) - - # Set explicitly passed configuration values - for attribute, value in { - "ACCESS_LOG": access_log, - "NOISY_EXCEPTIONS": noisy_exceptions, - }.items(): - if value is not None: - setattr(self.config, attribute, value) - - server_settings = self._helper( - host=host, - port=port, - debug=debug, - ssl=ssl, - sock=sock, - unix=unix, - loop=get_event_loop(), - protocol=protocol, - backlog=backlog, - run_async=return_asyncio_server, - ) - - if self.config.USE_UVLOOP is not _default: - error_logger.warning( - "You are trying to change the uvloop configuration, but " - "this is only effective when using the run(...) method. " - "When using the create_server(...) method Sanic will use " - "the already existing loop." - ) - - main_start = server_settings.pop("main_start", None) - main_stop = server_settings.pop("main_stop", None) - if main_start or main_stop: - logger.warning( - "Listener events for the main process are not available " - "with create_server()" - ) - - return await serve( - asyncio_server_kwargs=asyncio_server_kwargs, **server_settings - ) - async def _run_request_middleware( self, request, request_name=None ): # no cov @@ -1417,104 +1119,6 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): break return response - def _helper( - self, - host: Optional[str] = None, - port: Optional[int] = None, - debug: bool = False, - ssl: Union[None, SSLContext, dict, str, list, tuple] = None, - sock: Optional[socket] = None, - unix: Optional[str] = None, - workers: int = 1, - loop: AbstractEventLoop = None, - protocol: Type[Protocol] = HttpProtocol, - backlog: int = 100, - register_sys_signals: bool = True, - run_async: bool = False, - version: Union[HTTP, Literal[1], Literal[3]] = HTTP.VERSION_1, - ): - """Helper function used by `run` and `create_server`.""" - if self.config.PROXIES_COUNT and self.config.PROXIES_COUNT < 0: - raise ValueError( - "PROXIES_COUNT cannot be negative. " - "https://sanic.readthedocs.io/en/latest/sanic/config.html" - "#proxy-configuration" - ) - - if isinstance(version, int): - version = HTTP(version) - ssl = process_to_context(ssl) - - self.debug = debug - self.state.host = host - self.state.port = port - self.state.workers = workers - self.state.ssl = ssl - self.state.unix = unix - self.state.sock = sock - - server_settings = { - "protocol": protocol, - "host": host, - "port": port, - "sock": sock, - "unix": unix, - "ssl": ssl, - "app": self, - "signal": ServerSignal(), - "loop": loop, - "register_sys_signals": register_sys_signals, - "backlog": backlog, - "version": version, - } - - self.motd(self.serve_location) - - if sys.stdout.isatty() and not self.state.is_debug: - error_logger.warning( - f"{Colors.YELLOW}Sanic is running in PRODUCTION mode. " - "Consider using '--debug' or '--dev' while actively " - f"developing your application.{Colors.END}" - ) - - # Register start/stop events - for event_name, settings_name, reverse in ( - ("main_process_start", "main_start", False), - ("main_process_stop", "main_stop", True), - ): - listeners = self.listeners[event_name].copy() - if reverse: - listeners.reverse() - # Prepend sanic to the arguments when listeners are triggered - listeners = [partial(listener, self) for listener in listeners] - server_settings[settings_name] = listeners # type: ignore - - if run_async: - server_settings["run_async"] = True - - return server_settings - - @property - def serve_location(self) -> str: - serve_location = "" - proto = "http" - if self.state.ssl is not None: - proto = "https" - if self.state.unix: - serve_location = f"{self.state.unix} {proto}://..." - elif self.state.sock: - serve_location = f"{self.state.sock.getsockname()} {proto}://..." - elif self.state.host and self.state.port: - # colon(:) is legal for a host only in an ipv6 address - display_host = ( - f"[{self.state.host}]" - if ":" in self.state.host - else self.state.host - ) - serve_location = f"{proto}://{display_host}:{self.state.port}" - - return serve_location - def _build_endpoint_name(self, *parts): parts = [self.name, *parts] return ".".join(parts) @@ -1561,10 +1165,20 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): name: Optional[str] = None, register: bool = True, ) -> Task: - prepped = cls._prep_task(task, app, loop) - task = loop.create_task(prepped, name=name) + if not isinstance(task, Future): + prepped = cls._prep_task(task, app, loop) + if sys.version_info < (3, 8): # no cov + task = loop.create_task(prepped) + if name: + error_logger.warning( + "Cannot set a name for a task when using Python 3.7. " + "Your task will be created without a name." + ) + task.get_name = lambda: name + else: + task = loop.create_task(prepped, name=name) - if name and register: + if name and register and sys.version_info > (3, 7): app._task_registry[name] = task return task @@ -1598,12 +1212,6 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): :param task: future, couroutine or awaitable """ - if name and sys.version_info == (3, 7): - name = None - error_logger.warning( - "Cannot set a name for a task when using Python 3.7. Your " - "task will be created without a name." - ) try: loop = self.loop # Will raise SanicError if loop is not started return self._loop_add_task( @@ -1626,10 +1234,6 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): def get_task( self, name: str, *, raise_exception: bool = True ) -> Optional[Task]: - if sys.version_info == (3, 7): - raise RuntimeError( - "This feature is only supported on using Python 3.8+." - ) try: return self._task_registry[name] except KeyError: @@ -1646,17 +1250,13 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): *, raise_exception: bool = True, ) -> None: - if sys.version_info == (3, 7): - raise RuntimeError( - "This feature is only supported on using Python 3.8+." - ) task = self.get_task(name, raise_exception=raise_exception) if task and not task.cancelled(): args: Tuple[str, ...] = () if msg: if sys.version_info >= (3, 9): args = (msg,) - else: + else: # no cov raise RuntimeError( "Cancelling a task with a message is only supported " "on Python 3.9+." @@ -1668,10 +1268,6 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): ... def purge_tasks(self): - if sys.version_info == (3, 7): - raise RuntimeError( - "This feature is only supported on using Python 3.8+." - ) for task in self.tasks: if task.done() or task.cancelled(): name = task.get_name() @@ -1684,27 +1280,22 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): def shutdown_tasks( self, timeout: Optional[float] = None, increment: float = 0.1 ): - if sys.version_info == (3, 7): - raise RuntimeError( - "This feature is only supported on using Python 3.8+." - ) for task in self.tasks: - task.cancel() + if task.get_name() != "RunServer": + task.cancel() if timeout is None: timeout = self.config.GRACEFUL_SHUTDOWN_TIMEOUT while len(self._task_registry) and timeout: - self.loop.run_until_complete(asyncio.sleep(increment)) + with suppress(RuntimeError): + running_loop = get_running_loop() + running_loop.run_until_complete(asyncio.sleep(increment)) self.purge_tasks() timeout -= increment @property def tasks(self): - if sys.version_info == (3, 7): - raise RuntimeError( - "This feature is only supported on using Python 3.8+." - ) return iter(self._task_registry.values()) # -------------------------------------------------------------------- # @@ -1718,7 +1309,8 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): details: https://asgi.readthedocs.io/en/latest """ self.asgi = True - self.motd("") + if scope["type"] == "lifespan": + self.motd("") self._asgi_app = await ASGIApp.create(self, scope, receive, send) asgi_app = self._asgi_app await asgi_app() @@ -1753,6 +1345,13 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): @debug.setter def debug(self, value: bool): + deprecation( + "Setting the value of a Sanic application's debug value directly " + "is deprecated and will be removed in v22.9. Please set it using " + "the CLI, app.run, app.prepare, or directly set " + "app.state.mode to Mode.DEBUG.", + 22.9, + ) mode = Mode.DEBUG if value else Mode.PRODUCTION self.state.mode = mode @@ -1770,80 +1369,60 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): @property def is_running(self): + deprecation( + "Use of the is_running property is no longer used by Sanic " + "internally. The property is now deprecated and will be removed " + "in version 22.9. You may continue to set the property for your " + "own needs until that time. If you would like to check whether " + "the application is operational, please use app.state.stage. More " + "information is available at ___.", + 22.9, + ) return self.state.is_running @is_running.setter def is_running(self, value: bool): + deprecation( + "Use of the is_running property is no longer used by Sanic " + "internally. The property is now deprecated and will be removed " + "in version 22.9. You may continue to set the property for your " + "own needs until that time. If you would like to check whether " + "the application is operational, please use app.state.stage. More " + "information is available at ___.", + 22.9, + ) self.state.is_running = value @property def is_stopping(self): + deprecation( + "Use of the is_stopping property is no longer used by Sanic " + "internally. The property is now deprecated and will be removed " + "in version 22.9. You may continue to set the property for your " + "own needs until that time. If you would like to check whether " + "the application is operational, please use app.state.stage. More " + "information is available at ___.", + 22.9, + ) return self.state.is_stopping @is_stopping.setter def is_stopping(self, value: bool): + deprecation( + "Use of the is_stopping property is no longer used by Sanic " + "internally. The property is now deprecated and will be removed " + "in version 22.9. You may continue to set the property for your " + "own needs until that time. If you would like to check whether " + "the application is operational, please use app.state.stage. More " + "information is available at ___.", + 22.9, + ) self.state.is_stopping = value @property def reload_dirs(self): return self.state.reload_dirs - def motd(self, serve_location): - if self.config.MOTD: - mode = [f"{self.state.mode},"] - if self.state.fast: - mode.append("goin' fast") - if self.state.asgi: - mode.append("ASGI") - else: - if self.state.workers == 1: - mode.append("single worker") - else: - mode.append(f"w/ {self.state.workers} workers") - - display = { - "mode": " ".join(mode), - "server": self.state.server, - "python": platform.python_version(), - "platform": platform.platform(), - } - extra = {} - if self.config.AUTO_RELOAD: - reload_display = "enabled" - if self.state.reload_dirs: - reload_display += ", ".join( - [ - "", - *( - str(path.absolute()) - for path in self.state.reload_dirs - ), - ] - ) - display["auto-reload"] = reload_display - - packages = [] - for package_name in SANIC_PACKAGES: - module_name = package_name.replace("-", "_") - try: - module = import_module(module_name) - packages.append(f"{package_name}=={module.__version__}") - except ImportError: - ... - - if packages: - display["packages"] = ", ".join(packages) - - if self.config.MOTD_DISPLAY: - extra.update(self.config.MOTD_DISPLAY) - - logo = ( - get_logo(coffee=self.state.coffee) - if self.config.LOGO == "" or self.config.LOGO is True - else self.config.LOGO - ) - MOTD.output(logo, serve_location, display, extra) - @property def ext(self) -> Extend: if not hasattr(self, "_ext"): @@ -1941,7 +1520,6 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): async def _startup(self): self._future_registry.clear() - # Startup Sanic Extensions if not hasattr(self, "_ext"): setup_ext(self) if hasattr(self, "_ext"): @@ -1964,8 +1542,11 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): self.__class__._uvloop_setting = self.config.USE_UVLOOP # Startup time optimizations - ErrorHandler.finalize(self.error_handler, config=self.config) - TouchUp.run(self) + if self.state.primary: + # TODO: + # - Raise warning if secondary apps have error handler config + ErrorHandler.finalize(self.error_handler, config=self.config) + TouchUp.run(self) self.state.is_started = True diff --git a/sanic/application/constants.py b/sanic/application/constants.py new file mode 100644 index 00000000..9d46cb8e --- /dev/null +++ b/sanic/application/constants.py @@ -0,0 +1,23 @@ +from enum import Enum, IntEnum, auto + + +class StrEnum(str, Enum): + def _generate_next_value_(name: str, *args) -> str: # type: ignore + return name.lower() + + +class Server(StrEnum): + SANIC = auto() + ASGI = auto() + GUNICORN = auto() + + +class Mode(StrEnum): + PRODUCTION = auto() + DEBUG = auto() + + +class ServerStage(IntEnum): + STOPPED = auto() + PARTIAL = auto() + SERVING = auto() diff --git a/sanic/application/spinner.py b/sanic/application/spinner.py new file mode 100644 index 00000000..c1e35338 --- /dev/null +++ b/sanic/application/spinner.py @@ -0,0 +1,88 @@ +import os +import sys +import time + +from contextlib import contextmanager +from curses.ascii import SP +from queue import Queue +from threading import Thread + + +if os.name == "nt": + import ctypes + import msvcrt + + class _CursorInfo(ctypes.Structure): + _fields_ = [("size", ctypes.c_int), ("visible", ctypes.c_byte)] + + +class Spinner: + def __init__(self, message: str) -> None: + self.message = message + self.queue: Queue[int] = Queue() + self.spinner = self.cursor() + self.thread = Thread(target=self.run) + + def start(self): + self.queue.put(1) + self.thread.start() + self.hide() + + def run(self): + while self.queue.get(): + output = f"\r{self.message} [{next(self.spinner)}]" + sys.stdout.write(output) + sys.stdout.flush() + time.sleep(0.1) + self.queue.put(1) + + def stop(self): + self.queue.put(0) + self.thread.join() + self.show() + + @staticmethod + def cursor(): + while True: + for cursor in "|/-\\": + yield cursor + + @staticmethod + def hide(): + if os.name == "nt": + ci = _CursorInfo() + handle = ctypes.windll.kernel32.GetStdHandle(-11) + ctypes.windll.kernel32.GetConsoleCursorInfo( + handle, ctypes.byref(ci) + ) + ci.visible = False + ctypes.windll.kernel32.SetConsoleCursorInfo( + handle, ctypes.byref(ci) + ) + elif os.name == "posix": + sys.stdout.write("\033[?25l") + sys.stdout.flush() + + @staticmethod + def show(): + if os.name == "nt": + ci = _CursorInfo() + handle = ctypes.windll.kernel32.GetStdHandle(-11) + ctypes.windll.kernel32.GetConsoleCursorInfo( + handle, ctypes.byref(ci) + ) + ci.visible = True + ctypes.windll.kernel32.SetConsoleCursorInfo( + handle, ctypes.byref(ci) + ) + elif os.name == "posix": + sys.stdout.write("\033[?25h") + sys.stdout.flush() + + +@contextmanager +def loading(message: str = "Loading"): + spinner = Spinner(message) + spinner.start() + yield + spinner.stop() diff --git a/sanic/application/state.py b/sanic/application/state.py index 0345dad5..4dad18e8 100644 --- a/sanic/application/state.py +++ b/sanic/application/state.py @@ -3,33 +3,25 @@ from __future__ import annotations import logging from dataclasses import dataclass, field -from enum import Enum, auto from pathlib import Path from socket import socket from ssl import SSLContext -from typing import TYPE_CHECKING, Any, Optional, Set, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union +from sanic.application.constants import Mode, Server, ServerStage from sanic.log import logger +from sanic.server.async_server import AsyncioServer -if TYPE_CHECKING: +if TYPE_CHECKING: # no cov from sanic import Sanic -class StrEnum(str, Enum): - def _generate_next_value_(name: str, *args) -> str: # type: ignore - return name.lower() - - -class Server(StrEnum): - SANIC = auto() - ASGI = auto() - GUNICORN = auto() - - -class Mode(StrEnum): - PRODUCTION = auto() - DEBUG = auto() +@dataclass +class ApplicationServerInfo: + settings: Dict[str, Any] + stage: ServerStage = field(default=ServerStage.STOPPED) + server: Optional[AsyncioServer] = field(default=None) @dataclass @@ -45,12 +37,15 @@ class ApplicationState: unix: Optional[str] = field(default=None) mode: Mode = field(default=Mode.PRODUCTION) reload_dirs: Set[Path] = field(default_factory=set) + auto_reload: bool = field(default=False) server: Server = field(default=Server.SANIC) is_running: bool = field(default=False) is_started: bool = field(default=False) is_stopping: bool = field(default=False) verbosity: int = field(default=0) workers: int = field(default=0) + primary: bool = field(default=True) + server_info: List[ApplicationServerInfo] = field(default_factory=list) # This property relates to the ApplicationState instance and should # not be changed except in the __post_init__ method @@ -77,3 +72,17 @@ class ApplicationState: @property def is_debug(self): return self.mode is Mode.DEBUG + + @property + def stage(self) -> ServerStage: + if not self.server_info: + return ServerStage.STOPPED + + if all(info.stage is ServerStage.SERVING for info in self.server_info): + return ServerStage.SERVING + elif any( + info.stage is ServerStage.SERVING for info in self.server_info + ): + return ServerStage.PARTIAL + + return ServerStage.STOPPED diff --git a/sanic/asgi.py b/sanic/asgi.py index 5ef15a91..26140168 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -1,14 +1,15 @@ +from __future__ import annotations + import warnings -from typing import Optional +from typing import TYPE_CHECKING, Optional from urllib.parse import quote -import sanic.app # noqa - from sanic.compat import Header from sanic.exceptions import ServerError from sanic.helpers import _default from sanic.http import Stage +from sanic.log import logger from sanic.models.asgi import ASGIReceive, ASGIScope, ASGISend, MockTransport from sanic.request import Request from sanic.response import BaseHTTPResponse @@ -16,30 +17,35 @@ from sanic.server import ConnInfo from sanic.server.websockets.connection import WebSocketConnection +if TYPE_CHECKING: # no cov + from sanic import Sanic + + class Lifespan: - def __init__(self, asgi_app: "ASGIApp") -> None: + def __init__(self, asgi_app: ASGIApp) -> None: self.asgi_app = asgi_app - if ( - "server.init.before" - in self.asgi_app.sanic_app.signal_router.name_index - ): - warnings.warn( - 'You have set a listener for "before_server_start" ' - "in ASGI mode. " - "It will be executed as early as possible, but not before " - "the ASGI server is started." - ) - if ( - "server.shutdown.after" - in self.asgi_app.sanic_app.signal_router.name_index - ): - warnings.warn( - 'You have set a listener for "after_server_stop" ' - "in ASGI mode. " - "It will be executed as late as possible, but not after " - "the ASGI server is stopped." - ) + if self.asgi_app.sanic_app.state.verbosity > 0: + if ( + "server.init.before" + in self.asgi_app.sanic_app.signal_router.name_index + ): + logger.debug( + 'You have set a listener for "before_server_start" ' + "in ASGI mode. " + "It will be executed as early as possible, but not before " + "the ASGI server is started." + ) + if ( + "server.shutdown.after" + in self.asgi_app.sanic_app.signal_router.name_index + ): + logger.debug( + 'You have set a listener for "after_server_stop" ' + "in ASGI mode. " + "It will be executed as late as possible, but not after " + "the ASGI server is stopped." + ) async def startup(self) -> None: """ @@ -88,7 +94,7 @@ class Lifespan: class ASGIApp: - sanic_app: "sanic.app.Sanic" + sanic_app: Sanic request: Request transport: MockTransport lifespan: Lifespan diff --git a/sanic/cli/app.py b/sanic/cli/app.py index 2d12c2c7..0ce93adc 100644 --- a/sanic/cli/app.py +++ b/sanic/cli/app.py @@ -11,7 +11,6 @@ from typing import Any, List, Union from sanic.app import Sanic from sanic.application.logo import get_logo from sanic.cli.arguments import Group -from sanic.http.constants import HTTP from sanic.log import error_logger from sanic.simple import create_simple_server @@ -59,10 +58,13 @@ Or, a path to a directory to run as a simple HTTP server: os.environ.get("SANIC_RELOADER_PROCESS", "") != "true" ) self.args: List[Any] = [] + self.groups: List[Group] = [] def attach(self): for group in Group._registry: - group.create(self.parser).attach() + instance = group.create(self.parser) + instance.attach() + self.groups.append(instance) def run(self): # This is to provide backwards compat -v to display version @@ -75,18 +77,15 @@ Or, a path to a directory to run as a simple HTTP server: try: app = self._get_app() kwargs = self._build_run_kwargs() - app.run(**kwargs) except ValueError: error_logger.exception("Failed to run app") + else: + for http_version in self.args.http: + app.prepare(**kwargs, version=http_version) + + Sanic.serve() def _precheck(self): - if self.args.debug and self.main_process: - error_logger.warning( - "Starting in v22.3, --debug will no " - "longer automatically run the auto-reloader.\n Switch to " - "--dev to continue using that functionality." - ) - # # Custom TLS mismatch handling for better diagnostics if self.main_process and ( # one of cert/key missing @@ -144,11 +143,14 @@ Or, a path to a directory to run as a simple HTTP server: " Example File: project/sanic_server.py -> app\n" " Example Module: project.sanic_server.app" ) + sys.exit(1) else: raise e return app def _build_run_kwargs(self): + for group in self.groups: + group.prepare(self.args) ssl: Union[None, dict, str, list] = [] if self.args.tlshost: ssl.append(None) @@ -161,7 +163,6 @@ Or, a path to a directory to run as a simple HTTP server: elif len(ssl) == 1 and ssl[0] is not None: # Use only one cert, no TLSSelector. ssl = ssl[0] - version = HTTP(self.args.http) kwargs = { "access_log": self.args.access_log, "debug": self.args.debug, @@ -178,16 +179,12 @@ Or, a path to a directory to run as a simple HTTP server: "auto_tls": self.args.auto_tls, } - if self.args.auto_reload: - kwargs["auto_reload"] = True + for maybe_arg in ("auto_reload", "dev"): + if getattr(self.args, maybe_arg, False): + kwargs[maybe_arg] = True if self.args.path: - if self.args.auto_reload or self.args.debug: - kwargs["reload_dir"] = self.args.path - else: - error_logger.warning( - "Ignoring '--reload-dir' since auto reloading was not " - "enabled. If you would like to watch directories for " - "changes, consider using --debug or --auto-reload." - ) + kwargs["auto_reload"] = True + kwargs["reload_dir"] = self.args.path + return kwargs diff --git a/sanic/cli/arguments.py b/sanic/cli/arguments.py index e68b535c..33348cec 100644 --- a/sanic/cli/arguments.py +++ b/sanic/cli/arguments.py @@ -6,6 +6,7 @@ from typing import List, Optional, Type, Union from sanic_routing import __version__ as __routing_version__ # type: ignore from sanic import __version__ +from sanic.http.constants import HTTP class Group: @@ -38,6 +39,9 @@ class Group: "--no-" + args[0][2:], *args[1:], action="store_false", **kwargs ) + def prepare(self, args) -> None: + ... + class GeneralGroup(Group): name = None @@ -87,25 +91,39 @@ class HTTPVersionGroup(Group): name = "HTTP version" def attach(self): - group = self.container.add_mutually_exclusive_group() - group.add_argument( + http_values = [http.value for http in HTTP.__members__.values()] + + self.container.add_argument( "--http", dest="http", + action="append", + choices=http_values, type=int, - default=1, help=( - "Which HTTP version to use: HTTP/1.1 or HTTP/3. Value should " - "be either 1 or 3 [default 1]" + "Which HTTP version to use: HTTP/1.1 or HTTP/3. Value should\n" + "be either 1, or 3. [default 1]" ), ) - group.add_argument( + self.container.add_argument( + "-1", + dest="http", + action="append_const", + const=1, + help=("Run Sanic server using HTTP/1.1"), + ) + self.container.add_argument( "-3", dest="http", - action="store_const", + action="append_const", const=3, help=("Run Sanic server using HTTP/3"), ) + def prepare(self, args): + if not args.http: + args.http = [1] + args.http = tuple(sorted(set(map(HTTP, args.http)), reverse=True)) + class SocketGroup(Group): name = "Socket binding" @@ -116,7 +134,6 @@ class SocketGroup(Group): "--host", dest="host", type=str, - default="127.0.0.1", help="Host address [default 127.0.0.1]", ) self.container.add_argument( @@ -124,7 +141,6 @@ class SocketGroup(Group): "--port", dest="port", type=int, - default=8000, help="Port to serve on [default 8000]", ) self.container.add_argument( diff --git a/sanic/config.py b/sanic/config.py index c99d9abe..7e519fe9 100644 --- a/sanic/config.py +++ b/sanic/config.py @@ -130,22 +130,27 @@ class Config(dict, metaclass=DescriptorMeta): raise AttributeError(f"Config has no '{ke.args[0]}'") def __setattr__(self, attr, value) -> None: - if attr in self.__class__.__setters__: - try: - super().__setattr__(attr, value) - except AttributeError: - ... - else: - return None self.update({attr: value}) def __setitem__(self, attr, value) -> None: self.update({attr: value}) def update(self, *other, **kwargs) -> None: - other_mapping = {k: v for item in other for k, v in dict(item).items()} - super().update(*other, **kwargs) - for attr, value in {**other_mapping, **kwargs}.items(): + kwargs.update({k: v for item in other for k, v in dict(item).items()}) + setters: Dict[str, Any] = { + k: kwargs.pop(k) + for k in {**kwargs}.keys() + if k in self.__class__.__setters__ + } + + for key, value in setters.items(): + try: + super().__setattr__(key, value) + except AttributeError: + ... + + super().update(**kwargs) + for attr, value in {**setters, **kwargs}.items(): self._post_set(attr, value) def _post_set(self, attr, value) -> None: diff --git a/sanic/handlers.py b/sanic/handlers.py index 44ff77a7..b3f01566 100644 --- a/sanic/handlers.py +++ b/sanic/handlers.py @@ -1,13 +1,12 @@ from __future__ import annotations -from inspect import signature from typing import Dict, List, Optional, Tuple, Type, Union from sanic.config import Config from sanic.errorpages import ( DEFAULT_FORMAT, BaseRenderer, - HTMLRenderer, + TextRenderer, exception_response, ) from sanic.exceptions import ( @@ -35,13 +34,11 @@ class ErrorHandler: """ - # Beginning in v22.3, the base renderer will be TextRenderer def __init__( self, fallback: Union[str, Default] = _default, - base: Type[BaseRenderer] = HTMLRenderer, + base: Type[BaseRenderer] = TextRenderer, ): - self.handlers: List[Tuple[Type[BaseException], RouteHandler]] = [] self.cached_handlers: Dict[ Tuple[Type[BaseException], Optional[str]], Optional[RouteHandler] ] = {} @@ -53,14 +50,14 @@ class ErrorHandler: self._warn_fallback_deprecation() @property - def fallback(self): + def fallback(self): # no cov # This is for backwards compat and can be removed in v22.6 if self._fallback is _default: return DEFAULT_FORMAT return self._fallback @fallback.setter - def fallback(self, value: str): + def fallback(self, value: str): # no cov self._warn_fallback_deprecation() if not isinstance(value, str): raise SanicException( @@ -95,8 +92,8 @@ class ErrorHandler: def finalize( cls, error_handler: ErrorHandler, + config: Config, fallback: Optional[str] = None, - config: Optional[Config] = None, ): if fallback: deprecation( @@ -107,14 +104,10 @@ class ErrorHandler: 22.6, ) - if config is None: - deprecation( - "Starting in v22.3, config will be a required argument " - "for ErrorHandler.finalize().", - 22.3, - ) + if not fallback: + fallback = config.FALLBACK_ERROR_FORMAT - if fallback and fallback != DEFAULT_FORMAT: + if fallback != DEFAULT_FORMAT: if error_handler._fallback is not _default: error_logger.warning( f"Setting the fallback value to {fallback}. This changes " @@ -128,27 +121,9 @@ class ErrorHandler: f"Error handler is non-conforming: {type(error_handler)}" ) - sig = signature(error_handler.lookup) - if len(sig.parameters) == 1: - deprecation( - "You are using a deprecated error handler. The lookup " - "method should accept two positional parameters: " - "(exception, route_name: Optional[str]). " - "Until you upgrade your ErrorHandler.lookup, Blueprint " - "specific exceptions will not work properly. Beginning " - "in v22.3, the legacy style lookup method will not " - "work at all.", - 22.3, - ) - legacy_lookup = error_handler._legacy_lookup - error_handler._lookup = legacy_lookup # type: ignore - def _full_lookup(self, exception, route_name: Optional[str] = None): return self.lookup(exception, route_name) - def _legacy_lookup(self, exception, route_name: Optional[str] = None): - return self.lookup(exception) - def add(self, exception, handler, route_names: Optional[List[str]] = None): """ Add a new exception handler to an already existing handler object. @@ -162,9 +137,6 @@ class ErrorHandler: :return: None """ - # self.handlers is deprecated and will be removed in version 22.3 - self.handlers.append((exception, handler)) - if route_names: for route in route_names: self.cached_handlers[(exception, route)] = handler @@ -236,7 +208,7 @@ class ErrorHandler: except Exception: try: url = repr(request.url) - except AttributeError: + except AttributeError: # no cov url = "unknown" response_message = ( "Exception raised in exception handler " '"%s" for uri: %s' @@ -281,7 +253,7 @@ class ErrorHandler: if quiet is False or noisy is True: try: url = repr(request.url) - except AttributeError: + except AttributeError: # no cov url = "unknown" error_logger.exception( diff --git a/sanic/headers.py b/sanic/headers.py index b744974c..b4457653 100644 --- a/sanic/headers.py +++ b/sanic/headers.py @@ -2,7 +2,7 @@ from __future__ import annotations import re -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from urllib.parse import unquote from sanic.exceptions import InvalidHeader @@ -394,3 +394,17 @@ def parse_accept(accept: str) -> AcceptContainer: return AcceptContainer( sorted(accept_list, key=_sort_accept_value, reverse=True) ) + + +def parse_credentials( + header: Optional[str], + prefixes: Union[List, Tuple, Set] = None, +) -> Tuple[Optional[str], Optional[str]]: + """Parses any header with the aim to retrieve any credentials from it.""" + if not prefixes or not isinstance(prefixes, (list, tuple, set)): + prefixes = ("Basic", "Bearer", "Token") + if header is not None: + for prefix in prefixes: + if prefix in header: + return prefix, header.partition(prefix)[-1].strip() + return None, header diff --git a/sanic/http/constants.py b/sanic/http/constants.py index 359e9ec0..df3eebeb 100644 --- a/sanic/http/constants.py +++ b/sanic/http/constants.py @@ -1,4 +1,4 @@ -from enum import Enum +from enum import Enum, IntEnum class Stage(Enum): @@ -20,7 +20,7 @@ class Stage(Enum): FAILED = 100 # Unrecoverable state (error while sending response) -class HTTP(Enum): +class HTTP(IntEnum): VERSION_1 = 1 VERSION_3 = 3 diff --git a/sanic/http/http1.py b/sanic/http/http1.py index 96232fb5..111357cb 100644 --- a/sanic/http/http1.py +++ b/sanic/http/http1.py @@ -333,6 +333,12 @@ class Http(metaclass=TouchUpMeta): self.response_func = self.head_response_ignored headers["connection"] = "keep-alive" if self.keep_alive else "close" + + # This header may be removed or modified by the AltSvcCheck Touchup + # service. At server start, we either remove this header from ever + # being assigned, or we change the value as required. + headers["alt-svc"] = "" + ret = format_http1_response(status, res.processed_headers) if data: ret += data diff --git a/sanic/http/http3.py b/sanic/http/http3.py index ce0d0797..3d1b2b56 100644 --- a/sanic/http/http3.py +++ b/sanic/http/http3.py @@ -3,9 +3,8 @@ from __future__ import annotations import asyncio from abc import ABC -from ast import Mod from ssl import SSLContext -from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union +from typing import TYPE_CHECKING, Callable, Dict, Optional, Union from aioquic.h0.connection import H0_ALPN, H0Connection from aioquic.h3.connection import H3_ALPN, H3Connection diff --git a/sanic/http/tls.py b/sanic/http/tls.py index 844fe435..387f669a 100644 --- a/sanic/http/tls.py +++ b/sanic/http/tls.py @@ -3,18 +3,19 @@ from __future__ import annotations import os import ssl import subprocess +import sys from contextlib import suppress -from inspect import currentframe, getframeinfo from pathlib import Path from ssl import SSLContext from tempfile import mkdtemp from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union -from sanic.application.state import Mode +from sanic.application.constants import Mode +from sanic.application.spinner import loading from sanic.constants import DEFAULT_LOCAL_TLS_CERT, DEFAULT_LOCAL_TLS_KEY from sanic.exceptions import SanicException -from sanic.helpers import Default, _default +from sanic.helpers import Default from sanic.log import logger @@ -233,7 +234,7 @@ def get_ssl_context(app: Sanic, ssl: Optional[SSLContext]) -> SSLContext: if app.state.mode is Mode.PRODUCTION: raise SanicException( - "Cannot run Sanic as an HTTP/3 server in PRODUCTION mode " + "Cannot run Sanic as an HTTPS server in PRODUCTION mode " "without passing a TLS certificate. If you are developing " "locally, please enable DEVELOPMENT mode and Sanic will " "generate a localhost TLS certificate. For more information " @@ -283,15 +284,32 @@ def generate_local_certificate( ): check_mkcert() - cmd = [ - "mkcert", - "-key-file", - str(key_path), - "-cert-file", - str(cert_path), - localhost, - ] - subprocess.run(cmd, check=True) + if not key_path.parent.exists() or not cert_path.parent.exists(): + raise SanicException( + f"Cannot generate certificate at [{key_path}, {cert_path}]. One " + "or more of the directories does not exist." + ) + + message = "Generating TLS certificate" + with loading(message): + cmd = [ + "mkcert", + "-key-file", + str(key_path), + "-cert-file", + str(cert_path), + localhost, + ] + resp = subprocess.run( + cmd, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + sys.stdout.write("\r" + " " * (len(message) + 4)) + sys.stdout.flush() + sys.stdout.write(resp.stdout) def check_mkcert(): diff --git a/sanic/log.py b/sanic/log.py index 000b3ed8..4b3b960c 100644 --- a/sanic/log.py +++ b/sanic/log.py @@ -6,7 +6,7 @@ from typing import Any, Dict from warnings import warn -LOGGING_CONFIG_DEFAULTS: Dict[str, Any] = dict( +LOGGING_CONFIG_DEFAULTS: Dict[str, Any] = dict( # no cov version=1, disable_existing_loggers=False, loggers={ @@ -57,7 +57,7 @@ LOGGING_CONFIG_DEFAULTS: Dict[str, Any] = dict( ) -class Colors(str, Enum): +class Colors(str, Enum): # no cov END = "\033[0m" BLUE = "\033[01;34m" GREEN = "\033[01;32m" @@ -65,23 +65,23 @@ class Colors(str, Enum): RED = "\033[01;31m" -logger = logging.getLogger("sanic.root") +logger = logging.getLogger("sanic.root") # no cov """ General Sanic logger """ -error_logger = logging.getLogger("sanic.error") +error_logger = logging.getLogger("sanic.error") # no cov """ Logger used by Sanic for error logging """ -access_logger = logging.getLogger("sanic.access") +access_logger = logging.getLogger("sanic.access") # no cov """ Logger used by Sanic for access logging """ -def deprecation(message: str, version: float): +def deprecation(message: str, version: float): # no cov version_info = f"[DEPRECATION v{version}] " if sys.stdout.isatty(): version_info = f"{Colors.RED}{version_info}" diff --git a/sanic/models/asgi.py b/sanic/models/asgi.py index 57b755ee..74e0e8b2 100644 --- a/sanic/models/asgi.py +++ b/sanic/models/asgi.py @@ -13,7 +13,7 @@ ASGISend = Callable[[ASGIMessage], Awaitable[None]] ASGIReceive = Callable[[], Awaitable[ASGIMessage]] -class MockProtocol: +class MockProtocol: # no cov def __init__(self, transport: "MockTransport", loop): # This should be refactored when < 3.8 support is dropped self.transport = transport @@ -56,7 +56,7 @@ class MockProtocol: await self._not_paused.wait() -class MockTransport: +class MockTransport: # no cov _protocol: Optional[MockProtocol] def __init__( diff --git a/sanic/models/http_types.py b/sanic/models/http_types.py new file mode 100644 index 00000000..595eaf0e --- /dev/null +++ b/sanic/models/http_types.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from base64 import b64decode +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass() +class Credentials: + auth_type: Optional[str] + token: Optional[str] + _username: Optional[str] = field(default=None) + _password: Optional[str] = field(default=None) + + def __post_init__(self): + if self._auth_is_basic: + self._username, self._password = ( + b64decode(self.token.encode("utf-8")).decode().split(":") + ) + + @property + def username(self): + if not self._auth_is_basic: + raise AttributeError("Username is available for Basic Auth only") + return self._username + + @property + def password(self): + if not self._auth_is_basic: + raise AttributeError("Password is available for Basic Auth only") + return self._password + + @property + def _auth_is_basic(self) -> bool: + return self.auth_type == "Basic" diff --git a/sanic/models/server_types.py b/sanic/models/server_types.py index ad8872e1..ba9f2918 100644 --- a/sanic/models/server_types.py +++ b/sanic/models/server_types.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from ssl import SSLObject from types import SimpleNamespace from typing import Any, Dict, Optional diff --git a/sanic/reloader_helpers.py b/sanic/reloader_helpers.py index 3c726edb..4111cc71 100644 --- a/sanic/reloader_helpers.py +++ b/sanic/reloader_helpers.py @@ -77,7 +77,7 @@ def _check_file(filename, mtimes): return need_reload -def watchdog(sleep_interval, app): +def watchdog(sleep_interval, reload_dirs): """Watch project files, restart worker process if a change happened. :param sleep_interval: interval in second. @@ -100,7 +100,7 @@ def watchdog(sleep_interval, app): changed = set() for filename in itertools.chain( _iter_module_files(), - *(d.glob("**/*") for d in app.reload_dirs), + *(d.glob("**/*") for d in reload_dirs), ): try: if _check_file(filename, mtimes): diff --git a/sanic/request.py b/sanic/request.py index ee535ac2..a1ce91c3 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -15,6 +15,8 @@ from typing import ( from sanic_routing.route import Route # type: ignore +from sanic.models.http_types import Credentials + if TYPE_CHECKING: # no cov from sanic.server import ConnInfo @@ -38,6 +40,7 @@ from sanic.headers import ( Options, parse_accept, parse_content_header, + parse_credentials, parse_forwarded, parse_host, parse_xforwarded, @@ -99,11 +102,13 @@ class Request: "method", "parsed_accept", "parsed_args", - "parsed_not_grouped_args", + "parsed_credentials", "parsed_files", "parsed_form", - "parsed_json", "parsed_forwarded", + "parsed_json", + "parsed_not_grouped_args", + "parsed_token", "raw_url", "responded", "request_middleware_started", @@ -123,6 +128,7 @@ class Request: app: Sanic, head: bytes = b"", ): + self.raw_url = url_bytes # TODO: Content-Encoding detection self._parsed_url = parse_url(url_bytes) @@ -142,9 +148,11 @@ class Request: self.ctx = SimpleNamespace() self.parsed_forwarded: Optional[Options] = None self.parsed_accept: Optional[AcceptContainer] = None + self.parsed_credentials: Optional[Credentials] = None self.parsed_json = None self.parsed_form = None self.parsed_files = None + self.parsed_token: Optional[str] = None self.parsed_args: DefaultDict[ Tuple[bool, bool, str, str], RequestParameters ] = defaultdict(RequestParameters) @@ -336,20 +344,41 @@ class Request: return self.parsed_accept @property - def token(self): + def token(self) -> Optional[str]: """Attempt to return the auth header token. :return: token related to request """ - prefixes = ("Bearer", "Token") - auth_header = self.headers.getone("authorization", None) + if self.parsed_token is None: + prefixes = ("Bearer", "Token") + _, token = parse_credentials( + self.headers.getone("authorization", None), prefixes + ) + self.parsed_token = token + return self.parsed_token - if auth_header is not None: - for prefix in prefixes: - if prefix in auth_header: - return auth_header.partition(prefix)[-1].strip() + @property + def credentials(self) -> Optional[Credentials]: + """Attempt to return the auth header value. - return auth_header + Covers NoAuth, Basic Auth, Bearer Token, Api Token authentication + schemas. + + :return: A Credentials object with token, or username and password + related to the request + """ + if self.parsed_credentials is None: + try: + prefix, credentials = parse_credentials( + self.headers.getone("authorization", None) + ) + if credentials: + self.parsed_credentials = Credentials( + auth_type=prefix, token=credentials + ) + except ValueError: + pass + return self.parsed_credentials @property def form(self): diff --git a/sanic/server/protocols/base_protocol.py b/sanic/server/protocols/base_protocol.py index 63d4bfb5..3a271669 100644 --- a/sanic/server/protocols/base_protocol.py +++ b/sanic/server/protocols/base_protocol.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Optional -if TYPE_CHECKING: +if TYPE_CHECKING: # no cov from sanic.app import Sanic import asyncio diff --git a/sanic/server/protocols/http_protocol.py b/sanic/server/protocols/http_protocol.py index df2ed699..c215208a 100644 --- a/sanic/server/protocols/http_protocol.py +++ b/sanic/server/protocols/http_protocol.py @@ -8,7 +8,7 @@ from sanic.http.http3 import Http3 from sanic.touchup.meta import TouchUpMeta -if TYPE_CHECKING: +if TYPE_CHECKING: # no cov from sanic.app import Sanic from asyncio import CancelledError diff --git a/sanic/server/protocols/websocket_protocol.py b/sanic/server/protocols/websocket_protocol.py index ad6f8f95..d9539696 100644 --- a/sanic/server/protocols/websocket_protocol.py +++ b/sanic/server/protocols/websocket_protocol.py @@ -5,13 +5,13 @@ from websockets.server import ServerConnection from websockets.typing import Subprotocol from sanic.exceptions import ServerError -from sanic.log import deprecation, error_logger +from sanic.log import error_logger from sanic.server import HttpProtocol from ..websockets.impl import WebsocketImplProtocol -if TYPE_CHECKING: +if TYPE_CHECKING: # no cov from websockets import http11 @@ -29,9 +29,6 @@ class WebSocketProtocol(HttpProtocol): *args, websocket_timeout: float = 10.0, websocket_max_size: Optional[int] = None, - websocket_max_queue: Optional[int] = None, # max_queue is deprecated - websocket_read_limit: Optional[int] = None, # read_limit is deprecated - websocket_write_limit: Optional[int] = None, # write_limit deprecated websocket_ping_interval: Optional[float] = 20.0, websocket_ping_timeout: Optional[float] = 20.0, **kwargs, @@ -40,27 +37,6 @@ class WebSocketProtocol(HttpProtocol): self.websocket: Optional[WebsocketImplProtocol] = None self.websocket_timeout = websocket_timeout self.websocket_max_size = websocket_max_size - if websocket_max_queue is not None and websocket_max_queue > 0: - # TODO: Reminder remove this warning in v22.3 - deprecation( - "Websocket no longer uses queueing, so websocket_max_queue" - " is no longer required.", - 22.3, - ) - if websocket_read_limit is not None and websocket_read_limit > 0: - # TODO: Reminder remove this warning in v22.3 - deprecation( - "Websocket no longer uses read buffers, so " - "websocket_read_limit is not required.", - 22.3, - ) - if websocket_write_limit is not None and websocket_write_limit > 0: - # TODO: Reminder remove this warning in v22.3 - deprecation( - "Websocket no longer uses write buffers, so " - "websocket_write_limit is not required.", - 22.3, - ) self.websocket_ping_interval = websocket_ping_interval self.websocket_ping_timeout = websocket_ping_timeout diff --git a/sanic/server/runners.py b/sanic/server/runners.py index b5a9d9c2..86779472 100644 --- a/sanic/server/runners.py +++ b/sanic/server/runners.py @@ -95,8 +95,47 @@ def serve( app.asgi = False if version is HTTP.VERSION_3: - return serve_http_3(host, port, app, loop, ssl) + return _serve_http_3(host, port, app, loop, ssl) + return _serve_http_1( + host, + port, + app, + ssl, + sock, + unix, + reuse_port, + loop, + protocol, + backlog, + register_sys_signals, + run_multiple, + run_async, + connections, + signal, + state, + asyncio_server_kwargs, + ) + +def _serve_http_1( + host, + port, + app, + ssl, + sock, + unix, + reuse_port, + loop, + protocol, + backlog, + register_sys_signals, + run_multiple, + run_async, + connections, + signal, + state, + asyncio_server_kwargs, +): connections = connections if connections is not None else set() protocol_kwargs = _build_protocol_kwargs(protocol, app.config) server = partial( @@ -141,7 +180,7 @@ def serve( try: http_server = loop.run_until_complete(server_coroutine) except BaseException: - error_logger.exception("Unable to start server") + error_logger.exception("Unable to start server", exc_info=True) return # Ignore SIGINT when run_multiple @@ -201,7 +240,7 @@ def serve( remove_unix_socket(unix) -def serve_http_3( +def _serve_http_3( host, port, app, diff --git a/sanic/server/websockets/frame.py b/sanic/server/websockets/frame.py index fef27db1..e4972516 100644 --- a/sanic/server/websockets/frame.py +++ b/sanic/server/websockets/frame.py @@ -9,7 +9,7 @@ from websockets.typing import Data from sanic.exceptions import ServerError -if TYPE_CHECKING: +if TYPE_CHECKING: # no cov from .impl import WebsocketImplProtocol UTF8Decoder = codecs.getincrementaldecoder("utf-8") @@ -37,7 +37,7 @@ class WebsocketFrameAssembler: "get_id", "put_id", ) - if TYPE_CHECKING: + if TYPE_CHECKING: # no cov protocol: "WebsocketImplProtocol" read_mutex: asyncio.Lock write_mutex: asyncio.Lock @@ -131,7 +131,7 @@ class WebsocketFrameAssembler: if self.paused: self.protocol.resume_frames() self.paused = False - if not self.get_in_progress: + if not self.get_in_progress: # no cov # This should be guarded against with the read_mutex, # exception is here as a failsafe raise ServerError( @@ -204,7 +204,7 @@ class WebsocketFrameAssembler: if self.paused: self.protocol.resume_frames() self.paused = False - if not self.get_in_progress: + if not self.get_in_progress: # no cov # This should be guarded against with the read_mutex, # exception is here as a failsafe raise ServerError( @@ -212,7 +212,7 @@ class WebsocketFrameAssembler: "asynchronous get was in progress." ) self.get_in_progress = False - if not self.message_complete.is_set(): + if not self.message_complete.is_set(): # no cov # This should be guarded against with the read_mutex, # exception is here as a failsafe raise ServerError( @@ -220,7 +220,7 @@ class WebsocketFrameAssembler: "message was complete." ) self.message_complete.clear() - if self.message_fetched.is_set(): + if self.message_fetched.is_set(): # no cov # This should be guarded against with the read_mutex, # and get_in_progress check, this exception is # here as a failsafe diff --git a/sanic/server/websockets/impl.py b/sanic/server/websockets/impl.py index ed0d7fed..aaccfaca 100644 --- a/sanic/server/websockets/impl.py +++ b/sanic/server/websockets/impl.py @@ -518,8 +518,12 @@ class WebsocketImplProtocol: ) try: self.recv_cancel = asyncio.Future() + tasks = ( + self.recv_cancel, + asyncio.ensure_future(self.assembler.get(timeout)), + ) done, pending = await asyncio.wait( - (self.recv_cancel, self.assembler.get(timeout)), + tasks, return_when=asyncio.FIRST_COMPLETED, ) done_task = next(iter(done)) @@ -570,8 +574,12 @@ class WebsocketImplProtocol: self.can_pause = False self.recv_cancel = asyncio.Future() while True: + tasks = ( + self.recv_cancel, + asyncio.ensure_future(self.assembler.get(timeout=0)), + ) done, pending = await asyncio.wait( - (self.recv_cancel, self.assembler.get(timeout=0)), + tasks, return_when=asyncio.FIRST_COMPLETED, ) done_task = next(iter(done)) diff --git a/sanic/touchup/schemes/__init__.py b/sanic/touchup/schemes/__init__.py index 87057a5f..dd4145ab 100644 --- a/sanic/touchup/schemes/__init__.py +++ b/sanic/touchup/schemes/__init__.py @@ -1,3 +1,4 @@ +from .altsvc import AltSvcCheck # noqa from .base import BaseScheme from .ode import OptionalDispatchEvent # noqa diff --git a/sanic/touchup/schemes/base.py b/sanic/touchup/schemes/base.py index d16619b2..9e32c323 100644 --- a/sanic/touchup/schemes/base.py +++ b/sanic/touchup/schemes/base.py @@ -1,5 +1,8 @@ from abc import ABC, abstractmethod -from typing import Set, Type +from ast import NodeTransformer, parse +from inspect import getsource +from textwrap import dedent +from typing import Any, Dict, List, Set, Type class BaseScheme(ABC): @@ -10,11 +13,26 @@ class BaseScheme(ABC): self.app = app @abstractmethod - def run(self, method, module_globals) -> None: + def visitors(self) -> List[NodeTransformer]: ... def __init_subclass__(cls): BaseScheme._registry.add(cls) - def __call__(self, method, module_globals): - return self.run(method, module_globals) + def __call__(self): + return self.visitors() + + @classmethod + def build(cls, method, module_globals, app): + raw_source = getsource(method) + src = dedent(raw_source) + node = parse(src) + + for scheme in cls._registry: + for visitor in scheme(app)(): + node = visitor.visit(node) + + compiled_src = compile(node, method.__name__, "exec") + exec_locals: Dict[str, Any] = {} + exec(compiled_src, module_globals, exec_locals) # nosec + return exec_locals[method.__name__] diff --git a/sanic/touchup/schemes/ode.py b/sanic/touchup/schemes/ode.py index aa7d4bd9..c9b78c8b 100644 --- a/sanic/touchup/schemes/ode.py +++ b/sanic/touchup/schemes/ode.py @@ -1,7 +1,5 @@ -from ast import Attribute, Await, Dict, Expr, NodeTransformer, parse -from inspect import getsource -from textwrap import dedent -from typing import Any +from ast import Attribute, Await, Expr, NodeTransformer +from typing import Any, List from sanic.log import logger @@ -10,26 +8,49 @@ from .base import BaseScheme class OptionalDispatchEvent(BaseScheme): ident = "ODE" + SYNC_SIGNAL_NAMESPACES = "http." def __init__(self, app) -> None: super().__init__(app) + self._sync_events() self._registered_events = [ - signal.path for signal in app.signal_router.routes + signal.name for signal in app.signal_router.routes ] - def run(self, method, module_globals): - raw_source = getsource(method) - src = dedent(raw_source) - tree = parse(src) - node = RemoveDispatch( - self._registered_events, self.app.state.verbosity - ).visit(tree) - compiled_src = compile(node, method.__name__, "exec") - exec_locals: Dict[str, Any] = {} - exec(compiled_src, module_globals, exec_locals) # nosec + def visitors(self) -> List[NodeTransformer]: + return [ + RemoveDispatch(self._registered_events, self.app.state.verbosity) + ] - return exec_locals[method.__name__] + def _sync_events(self): + all_events = set() + app_events = {} + for app in self.app.__class__._app_registry.values(): + if app.state.server_info: + app_events[app] = { + signal.name for signal in app.signal_router.routes + } + all_events.update(app_events[app]) + + for app, events in app_events.items(): + missing = { + x + for x in all_events.difference(events) + if any(x.startswith(y) for y in self.SYNC_SIGNAL_NAMESPACES) + } + if missing: + was_finalized = app.signal_router.finalized + if was_finalized: # no cov + app.signal_router.reset() + for event in missing: + app.signal(event)(self.noop) + if was_finalized: # no cov + app.signal_router.finalize() + + @staticmethod + async def noop(**_): # no cov + ... class RemoveDispatch(NodeTransformer): diff --git a/sanic/touchup/service.py b/sanic/touchup/service.py index 95792dca..b1b996fb 100644 --- a/sanic/touchup/service.py +++ b/sanic/touchup/service.py @@ -21,10 +21,8 @@ class TouchUp: module = getmodule(target) module_globals = dict(getmembers(module)) - - for scheme in BaseScheme._registry: - modified = scheme(app)(method, module_globals) - setattr(target, method_name, modified) + modified = BaseScheme.build(method, module_globals, app) + setattr(target, method_name, modified) target.__touched__ = True diff --git a/sanic/worker.py b/sanic/worker.py deleted file mode 100644 index befe8d78..00000000 --- a/sanic/worker.py +++ /dev/null @@ -1,243 +0,0 @@ -import asyncio -import logging -import os -import signal -import sys -import traceback - -from gunicorn.workers import base # type: ignore - -from sanic.compat import UVLOOP_INSTALLED -from sanic.log import logger -from sanic.server import HttpProtocol, Signal, serve, try_use_uvloop -from sanic.server.protocols.websocket_protocol import WebSocketProtocol - - -try: - import ssl # type: ignore -except ImportError: # no cov - ssl = None # type: ignore - -if UVLOOP_INSTALLED: # no cov - try_use_uvloop() - - -class GunicornWorker(base.Worker): - - http_protocol = HttpProtocol - websocket_protocol = WebSocketProtocol - - def __init__(self, *args, **kw): # pragma: no cover - super().__init__(*args, **kw) - cfg = self.cfg - if cfg.is_ssl: - self.ssl_context = self._create_ssl_context(cfg) - else: - self.ssl_context = None - self.servers = {} - self.connections = set() - self.exit_code = 0 - self.signal = Signal() - - def init_process(self): - # create new event_loop after fork - asyncio.get_event_loop().close() - - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - super().init_process() - - def run(self): - is_debug = self.log.loglevel == logging.DEBUG - protocol = ( - self.websocket_protocol - if self.app.callable.websocket_enabled - else self.http_protocol - ) - - self._server_settings = self.app.callable._helper( - loop=self.loop, - debug=is_debug, - protocol=protocol, - ssl=self.ssl_context, - run_async=True, - ) - self._server_settings["signal"] = self.signal - self._server_settings.pop("sock") - self._await(self.app.callable._startup()) - self._await( - self.app.callable._server_event("init", "before", loop=self.loop) - ) - - main_start = self._server_settings.pop("main_start", None) - main_stop = self._server_settings.pop("main_stop", None) - - if main_start or main_stop: # noqa - logger.warning( - "Listener events for the main process are not available " - "with GunicornWorker" - ) - - try: - self._await(self._run()) - self.app.callable.is_running = True - self._await( - self.app.callable._server_event( - "init", "after", loop=self.loop - ) - ) - self.loop.run_until_complete(self._check_alive()) - self._await( - self.app.callable._server_event( - "shutdown", "before", loop=self.loop - ) - ) - self.loop.run_until_complete(self.close()) - except BaseException: - traceback.print_exc() - finally: - try: - self._await( - self.app.callable._server_event( - "shutdown", "after", loop=self.loop - ) - ) - except BaseException: - traceback.print_exc() - finally: - self.loop.close() - - sys.exit(self.exit_code) - - async def close(self): - if self.servers: - # stop accepting connections - self.log.info( - "Stopping server: %s, connections: %s", - self.pid, - len(self.connections), - ) - for server in self.servers: - server.close() - await server.wait_closed() - self.servers.clear() - - # prepare connections for closing - self.signal.stopped = True - for conn in self.connections: - conn.close_if_idle() - - # gracefully shutdown timeout - start_shutdown = 0 - graceful_shutdown_timeout = self.cfg.graceful_timeout - while self.connections and ( - start_shutdown < graceful_shutdown_timeout - ): - await asyncio.sleep(0.1) - start_shutdown = start_shutdown + 0.1 - - # Force close non-idle connection after waiting for - # graceful_shutdown_timeout - for conn in self.connections: - if hasattr(conn, "websocket") and conn.websocket: - conn.websocket.fail_connection(code=1001) - else: - conn.abort() - - async def _run(self): - for sock in self.sockets: - state = dict(requests_count=0) - self._server_settings["host"] = None - self._server_settings["port"] = None - server = await serve( - sock=sock, - connections=self.connections, - state=state, - **self._server_settings - ) - self.servers[server] = state - - async def _check_alive(self): - # If our parent changed then we shut down. - pid = os.getpid() - try: - while self.alive: - self.notify() - - req_count = sum( - self.servers[srv]["requests_count"] for srv in self.servers - ) - if self.max_requests and req_count > self.max_requests: - self.alive = False - self.log.info( - "Max requests exceeded, shutting down: %s", self - ) - elif pid == os.getpid() and self.ppid != os.getppid(): - self.alive = False - self.log.info("Parent changed, shutting down: %s", self) - else: - await asyncio.sleep(1.0, loop=self.loop) - except (Exception, BaseException, GeneratorExit, KeyboardInterrupt): - pass - - @staticmethod - def _create_ssl_context(cfg): - """Creates SSLContext instance for usage in asyncio.create_server. - See ssl.SSLSocket.__init__ for more details. - """ - ctx = ssl.SSLContext(cfg.ssl_version) - ctx.load_cert_chain(cfg.certfile, cfg.keyfile) - ctx.verify_mode = cfg.cert_reqs - if cfg.ca_certs: - ctx.load_verify_locations(cfg.ca_certs) - if cfg.ciphers: - ctx.set_ciphers(cfg.ciphers) - return ctx - - def init_signals(self): - # Set up signals through the event loop API. - - self.loop.add_signal_handler( - signal.SIGQUIT, self.handle_quit, signal.SIGQUIT, None - ) - - self.loop.add_signal_handler( - signal.SIGTERM, self.handle_exit, signal.SIGTERM, None - ) - - self.loop.add_signal_handler( - signal.SIGINT, self.handle_quit, signal.SIGINT, None - ) - - self.loop.add_signal_handler( - signal.SIGWINCH, self.handle_winch, signal.SIGWINCH, None - ) - - self.loop.add_signal_handler( - signal.SIGUSR1, self.handle_usr1, signal.SIGUSR1, None - ) - - self.loop.add_signal_handler( - signal.SIGABRT, self.handle_abort, signal.SIGABRT, None - ) - - # Don't let SIGTERM and SIGUSR1 disturb active requests - # by interrupting system calls - signal.siginterrupt(signal.SIGTERM, False) - signal.siginterrupt(signal.SIGUSR1, False) - - def handle_quit(self, sig, frame): - self.alive = False - self.app.callable.is_running = False - self.cfg.worker_int(self) - - def handle_abort(self, sig, frame): - self.alive = False - self.exit_code = 1 - self.cfg.worker_abort(self) - sys.exit(1) - - def _await(self, coro): - fut = asyncio.ensure_future(coro, loop=self.loop) - self.loop.run_until_complete(fut) diff --git a/setup.py b/setup.py index ea6f285a..5dc874c3 100644 --- a/setup.py +++ b/setup.py @@ -148,6 +148,7 @@ extras_require = { "docs": docs_require, "all": all_require, "ext": ["sanic-ext"], + "http3": ["aioquic"], } setup_kwargs["install_requires"] = requirements diff --git a/tests/asyncmock.py b/tests/asyncmock.py new file mode 100644 index 00000000..eec17646 --- /dev/null +++ b/tests/asyncmock.py @@ -0,0 +1,34 @@ +""" +For 3.7 compat + +""" + + +from unittest.mock import Mock + + +class AsyncMock(Mock): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.await_count = 0 + + def __call__(self, *args, **kwargs): + self.call_count += 1 + parent = super(AsyncMock, self) + + async def dummy(): + self.await_count += 1 + return parent.__call__(*args, **kwargs) + + return dummy() + + def __await__(self): + return self().__await__() + + def assert_awaited_once(self): + if not self.await_count == 1: + msg = ( + f"Expected to have been awaited once." + f" Awaited {self.await_count} times." + ) + raise AssertionError(msg) diff --git a/tests/conftest.py b/tests/conftest.py index 22decde5..fe4ba47d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -175,6 +175,21 @@ def run_startup(caplog): return run +@pytest.fixture +def run_multi(caplog): + def run(app, level=logging.DEBUG): + @app.after_server_start + async def stop(app, _): + app.stop() + + with caplog.at_level(level): + Sanic.serve() + + return caplog.record_tuples + + return run + + @pytest.fixture(scope="function") def message_in_records(): def msg_in_log(records: List[LogRecord], msg: str): diff --git a/tests/test_app.py b/tests/test_app.py index 467aaee4..1c8f705b 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -197,7 +197,7 @@ def test_app_enable_websocket(app, websocket_enabled, enable): assert app.websocket_enabled == True -@patch("sanic.app.WebSocketProtocol") +@patch("sanic.mixins.runner.WebSocketProtocol") def test_app_websocket_parameters(websocket_protocol_mock, app): app.config.WEBSOCKET_MAX_SIZE = 44 app.config.WEBSOCKET_PING_TIMEOUT = 48 @@ -473,13 +473,14 @@ def test_custom_context(): assert app.ctx == ctx -def test_uvloop_config(app, monkeypatch): +@pytest.mark.parametrize("use", (False, True)) +def test_uvloop_config(app, monkeypatch, use): @app.get("/test") def handler(request): return text("ok") try_use_uvloop = Mock() - monkeypatch.setattr(sanic.app, "try_use_uvloop", try_use_uvloop) + monkeypatch.setattr(sanic.mixins.runner, "try_use_uvloop", try_use_uvloop) # Default config app.test_client.get("/test") @@ -489,14 +490,13 @@ def test_uvloop_config(app, monkeypatch): try_use_uvloop.assert_called_once() try_use_uvloop.reset_mock() - app.config["USE_UVLOOP"] = False + app.config["USE_UVLOOP"] = use app.test_client.get("/test") - try_use_uvloop.assert_not_called() - try_use_uvloop.reset_mock() - app.config["USE_UVLOOP"] = True - app.test_client.get("/test") - try_use_uvloop.assert_called_once() + if use: + try_use_uvloop.assert_called_once() + else: + try_use_uvloop.assert_not_called() def test_uvloop_cannot_never_called_with_create_server(caplog, monkeypatch): @@ -506,7 +506,7 @@ def test_uvloop_cannot_never_called_with_create_server(caplog, monkeypatch): apps[2].config.USE_UVLOOP = True try_use_uvloop = Mock() - monkeypatch.setattr(sanic.app, "try_use_uvloop", try_use_uvloop) + monkeypatch.setattr(sanic.mixins.runner, "try_use_uvloop", try_use_uvloop) loop = asyncio.get_event_loop() @@ -569,3 +569,8 @@ def test_cannot_run_fast_and_workers(app): message = "You cannot use both fast=True and workers=X" with pytest.raises(RuntimeError, match=message): app.run(fast=True, workers=4) + + +def test_no_workers(app): + with pytest.raises(RuntimeError, match="Cannot serve with no workers"): + app.run(workers=0) diff --git a/tests/test_asgi.py b/tests/test_asgi.py index d00a70bd..3687f576 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -1,4 +1,5 @@ import asyncio +import logging from collections import deque, namedtuple @@ -6,6 +7,7 @@ import pytest import uvicorn from sanic import Sanic +from sanic.application.state import Mode from sanic.asgi import MockTransport from sanic.exceptions import Forbidden, InvalidUsage, ServiceUnavailable from sanic.request import Request @@ -44,7 +46,7 @@ def protocol(transport): return transport.get_protocol() -def test_listeners_triggered(): +def test_listeners_triggered(caplog): app = Sanic("app") before_server_start = False after_server_start = False @@ -82,9 +84,31 @@ def test_listeners_triggered(): config = uvicorn.Config(app=app, loop="asyncio", limit_max_requests=0) server = CustomServer(config=config) - with pytest.warns(UserWarning): + start_message = ( + 'You have set a listener for "before_server_start" in ASGI mode. ' + "It will be executed as early as possible, but not before the ASGI " + "server is started." + ) + stop_message = ( + 'You have set a listener for "after_server_stop" in ASGI mode. ' + "It will be executed as late as possible, but not after the ASGI " + "server is stopped." + ) + + with caplog.at_level(logging.DEBUG): server.run() + assert ( + "sanic.root", + logging.DEBUG, + start_message, + ) not in caplog.record_tuples + assert ( + "sanic.root", + logging.DEBUG, + stop_message, + ) not in caplog.record_tuples + all_tasks = asyncio.all_tasks(asyncio.get_event_loop()) for task in all_tasks: task.cancel() @@ -94,8 +118,38 @@ def test_listeners_triggered(): assert before_server_stop assert after_server_stop + app.state.mode = Mode.DEBUG + with caplog.at_level(logging.DEBUG): + server.run() -def test_listeners_triggered_async(app): + assert ( + "sanic.root", + logging.DEBUG, + start_message, + ) not in caplog.record_tuples + assert ( + "sanic.root", + logging.DEBUG, + stop_message, + ) not in caplog.record_tuples + + app.state.verbosity = 2 + with caplog.at_level(logging.DEBUG): + server.run() + + assert ( + "sanic.root", + logging.DEBUG, + start_message, + ) in caplog.record_tuples + assert ( + "sanic.root", + logging.DEBUG, + stop_message, + ) in caplog.record_tuples + + +def test_listeners_triggered_async(app, caplog): before_server_start = False after_server_start = False before_server_stop = False @@ -132,9 +186,31 @@ def test_listeners_triggered_async(app): config = uvicorn.Config(app=app, loop="asyncio", limit_max_requests=0) server = CustomServer(config=config) - with pytest.warns(UserWarning): + start_message = ( + 'You have set a listener for "before_server_start" in ASGI mode. ' + "It will be executed as early as possible, but not before the ASGI " + "server is started." + ) + stop_message = ( + 'You have set a listener for "after_server_stop" in ASGI mode. ' + "It will be executed as late as possible, but not after the ASGI " + "server is stopped." + ) + + with caplog.at_level(logging.DEBUG): server.run() + assert ( + "sanic.root", + logging.DEBUG, + start_message, + ) not in caplog.record_tuples + assert ( + "sanic.root", + logging.DEBUG, + stop_message, + ) not in caplog.record_tuples + all_tasks = asyncio.all_tasks(asyncio.get_event_loop()) for task in all_tasks: task.cancel() @@ -144,6 +220,36 @@ def test_listeners_triggered_async(app): assert before_server_stop assert after_server_stop + app.state.mode = Mode.DEBUG + with caplog.at_level(logging.DEBUG): + server.run() + + assert ( + "sanic.root", + logging.DEBUG, + start_message, + ) not in caplog.record_tuples + assert ( + "sanic.root", + logging.DEBUG, + stop_message, + ) not in caplog.record_tuples + + app.state.verbosity = 2 + with caplog.at_level(logging.DEBUG): + server.run() + + assert ( + "sanic.root", + logging.DEBUG, + start_message, + ) in caplog.record_tuples + assert ( + "sanic.root", + logging.DEBUG, + stop_message, + ) in caplog.record_tuples + def test_non_default_uvloop_config_raises_warning(app): app.config.USE_UVLOOP = True diff --git a/tests/test_cli.py b/tests/test_cli.py index 254c91d4..7ec1c28c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -103,7 +103,7 @@ def test_tls_wrong_options(cmd): assert not out lines = err.decode().split("\n") - errmsg = lines[8] + errmsg = lines[6] assert errmsg == "TLS certificates must be specified by either of:" @@ -118,10 +118,10 @@ def test_host_port_localhost(cmd): command = ["sanic", "fake.server.app", *cmd] out, err, exitcode = capture(command) lines = out.split(b"\n") - firstline = lines[starting_line(lines) + 1] + expected = b"Goin' Fast @ http://localhost:9999" assert exitcode != 1 - assert firstline == b"Goin' Fast @ http://localhost:9999" + assert expected in lines, f"Lines found: {lines}\nErr output: {err}" @pytest.mark.parametrize( @@ -135,10 +135,10 @@ def test_host_port_ipv4(cmd): command = ["sanic", "fake.server.app", *cmd] out, err, exitcode = capture(command) lines = out.split(b"\n") - firstline = lines[starting_line(lines) + 1] + expected = b"Goin' Fast @ http://127.0.0.127:9999" assert exitcode != 1 - assert firstline == b"Goin' Fast @ http://127.0.0.127:9999" + assert expected in lines, f"Lines found: {lines}\nErr output: {err}" @pytest.mark.parametrize( @@ -152,10 +152,10 @@ def test_host_port_ipv6_any(cmd): command = ["sanic", "fake.server.app", *cmd] out, err, exitcode = capture(command) lines = out.split(b"\n") - firstline = lines[starting_line(lines) + 1] + expected = b"Goin' Fast @ http://[::]:9999" assert exitcode != 1 - assert firstline == b"Goin' Fast @ http://[::]:9999" + assert expected in lines, f"Lines found: {lines}\nErr output: {err}" @pytest.mark.parametrize( @@ -169,10 +169,10 @@ def test_host_port_ipv6_loopback(cmd): command = ["sanic", "fake.server.app", *cmd] out, err, exitcode = capture(command) lines = out.split(b"\n") - firstline = lines[starting_line(lines) + 1] + expected = b"Goin' Fast @ http://[::1]:9999" assert exitcode != 1 - assert firstline == b"Goin' Fast @ http://[::1]:9999" + assert expected in lines, f"Lines found: {lines}\nErr output: {err}" @pytest.mark.parametrize( @@ -191,24 +191,40 @@ def test_num_workers(num, cmd): out, err, exitcode = capture(command) lines = out.split(b"\n") - worker_lines = [ - line - for line in lines - if b"Starting worker" in line or b"Stopping worker" in line - ] + if num == 1: + expected = b"mode: production, single worker" + else: + expected = (f"mode: production, w/ {num} workers").encode() + assert exitcode != 1 - assert len(worker_lines) == num * 2, f"Lines found: {lines}" + assert expected in lines, f"Expected {expected}\nLines found: {lines}" -@pytest.mark.parametrize("cmd", ("--debug", "-d")) +@pytest.mark.parametrize("cmd", ("--debug",)) def test_debug(cmd): command = ["sanic", "fake.server.app", cmd] out, err, exitcode = capture(command) lines = out.split(b"\n") info = read_app_info(lines) - assert info["debug"] is True - assert info["auto_reload"] is True + assert info["debug"] is True, f"Lines found: {lines}\nErr output: {err}" + assert ( + info["auto_reload"] is False + ), f"Lines found: {lines}\nErr output: {err}" + assert "dev" not in info, f"Lines found: {lines}\nErr output: {err}" + + +@pytest.mark.parametrize("cmd", ("--dev", "-d")) +def test_dev(cmd): + command = ["sanic", "fake.server.app", cmd] + out, err, exitcode = capture(command) + lines = out.split(b"\n") + info = read_app_info(lines) + + assert info["debug"] is True, f"Lines found: {lines}\nErr output: {err}" + assert ( + info["auto_reload"] is True + ), f"Lines found: {lines}\nErr output: {err}" @pytest.mark.parametrize("cmd", ("--auto-reload", "-r")) @@ -218,8 +234,11 @@ def test_auto_reload(cmd): lines = out.split(b"\n") info = read_app_info(lines) - assert info["debug"] is False - assert info["auto_reload"] is True + assert info["debug"] is False, f"Lines found: {lines}\nErr output: {err}" + assert ( + info["auto_reload"] is True + ), f"Lines found: {lines}\nErr output: {err}" + assert "dev" not in info, f"Lines found: {lines}\nErr output: {err}" @pytest.mark.parametrize( @@ -231,7 +250,9 @@ def test_access_logs(cmd, expected): lines = out.split(b"\n") info = read_app_info(lines) - assert info["access_log"] is expected + assert ( + info["access_log"] is expected + ), f"Lines found: {lines}\nErr output: {err}" @pytest.mark.parametrize("cmd", ("--version", "-v")) @@ -256,4 +277,6 @@ def test_noisy_exceptions(cmd, expected): lines = out.split(b"\n") info = read_app_info(lines) - assert info["noisy_exceptions"] is expected + assert ( + info["noisy_exceptions"] is expected + ), f"Lines found: {lines}\nErr output: {err}" diff --git a/tests/test_config.py b/tests/test_config.py index 9237b55c..d8a7bd85 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -5,7 +5,7 @@ from os import environ from pathlib import Path from tempfile import TemporaryDirectory from textwrap import dedent -from unittest.mock import Mock +from unittest.mock import Mock, call import pytest @@ -301,6 +301,9 @@ def test_config_access_log_passing_in_run(app: Sanic): app.run(port=1340, access_log=False) assert app.config.ACCESS_LOG is False + app.router.reset() + app.signal_router.reset() + app.run(port=1340, access_log=True) assert app.config.ACCESS_LOG is True @@ -399,5 +402,36 @@ def test_config_set_methods(app: Sanic, monkeypatch: MonkeyPatch): post_set.assert_called_once_with("FOO", 5) post_set.reset_mock() - app.config.update_config({"FOO": 6}) - post_set.assert_called_once_with("FOO", 6) + app.config.update({"FOO": 6}, {"BAR": 7}) + post_set.assert_has_calls( + calls=[ + call("FOO", 6), + call("BAR", 7), + ] + ) + post_set.reset_mock() + + app.config.update({"FOO": 8}, BAR=9) + post_set.assert_has_calls( + calls=[ + call("FOO", 8), + call("BAR", 9), + ], + any_order=True, + ) + post_set.reset_mock() + + app.config.update_config({"FOO": 10}) + post_set.assert_called_once_with("FOO", 10) + + +def test_negative_proxy_count(app: Sanic): + app.config.PROXIES_COUNT = -1 + + message = ( + "PROXIES_COUNT cannot be negative. " + "https://sanic.readthedocs.io/en/latest/sanic/config.html" + "#proxy-configuration" + ) + with pytest.raises(ValueError, match=message): + app.prepare() diff --git a/tests/test_create_task.py b/tests/test_create_task.py index c98666a9..a11bc302 100644 --- a/tests/test_create_task.py +++ b/tests/test_create_task.py @@ -2,6 +2,7 @@ import asyncio import sys from threading import Event +from unittest.mock import Mock import pytest @@ -77,6 +78,25 @@ def test_create_named_task(app): app.run() +def test_named_task_called(app): + e = Event() + + async def coro(): + e.set() + + @app.route("/") + async def isset(request): + await asyncio.sleep(0.05) + return text(str(e.is_set())) + + @app.before_server_start + async def setup(app, _): + app.add_task(coro, name="dummy_task") + + request, response = app.test_client.get("/") + assert response.body == b"True" + + @pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7") def test_create_named_task_fails_outside_app(app): async def dummy(): diff --git a/tests/test_errorpages.py b/tests/test_errorpages.py index f8e425b0..042c51d6 100644 --- a/tests/test_errorpages.py +++ b/tests/test_errorpages.py @@ -67,7 +67,7 @@ def test_auto_fallback_with_data(app): _, response = app.test_client.get("/error") assert response.status == 500 - assert response.content_type == "text/html; charset=utf-8" + assert response.content_type == "text/plain; charset=utf-8" _, response = app.test_client.post("/error", json={"foo": "bar"}) assert response.status == 500 @@ -75,7 +75,7 @@ def test_auto_fallback_with_data(app): _, response = app.test_client.post("/error", data={"foo": "bar"}) assert response.status == 500 - assert response.content_type == "text/html; charset=utf-8" + assert response.content_type == "text/plain; charset=utf-8" def test_auto_fallback_with_content_type(app): @@ -91,7 +91,7 @@ def test_auto_fallback_with_content_type(app): "/error", headers={"content-type": "foo/bar", "accept": "*/*"} ) assert response.status == 500 - assert response.content_type == "text/html; charset=utf-8" + assert response.content_type == "text/plain; charset=utf-8" def test_route_error_format_set_on_auto(app): @@ -174,6 +174,17 @@ def test_route_error_format_unknown(app): ... +def test_fallback_with_content_type_html(app): + app.config.FALLBACK_ERROR_FORMAT = "auto" + + _, response = app.test_client.get( + "/error", + headers={"content-type": "application/json", "accept": "text/html"}, + ) + assert response.status == 500 + assert response.content_type == "text/html; charset=utf-8" + + def test_fallback_with_content_type_mismatch_accept(app): app.config.FALLBACK_ERROR_FORMAT = "auto" @@ -186,10 +197,10 @@ def test_fallback_with_content_type_mismatch_accept(app): _, response = app.test_client.get( "/error", - headers={"content-type": "text/plain", "accept": "foo/bar"}, + headers={"content-type": "text/html", "accept": "foo/bar"}, ) assert response.status == 500 - assert response.content_type == "text/html; charset=utf-8" + assert response.content_type == "text/plain; charset=utf-8" app.router.reset() @@ -208,7 +219,7 @@ def test_fallback_with_content_type_mismatch_accept(app): headers={"accept": "foo/bar"}, ) assert response.status == 500 - assert response.content_type == "text/html; charset=utf-8" + assert response.content_type == "text/plain; charset=utf-8" _, response = app.test_client.get( "/alt1", headers={"accept": "foo/bar,*/*"}, @@ -221,7 +232,7 @@ def test_fallback_with_content_type_mismatch_accept(app): headers={"accept": "foo/bar"}, ) assert response.status == 500 - assert response.content_type == "text/html; charset=utf-8" + assert response.content_type == "text/plain; charset=utf-8" _, response = app.test_client.get( "/alt2", headers={"accept": "foo/bar,*/*"}, @@ -234,6 +245,13 @@ def test_fallback_with_content_type_mismatch_accept(app): headers={"accept": "foo/bar"}, ) assert response.status == 500 + assert response.content_type == "text/plain; charset=utf-8" + + _, response = app.test_client.get( + "/alt3", + headers={"accept": "foo/bar,text/html"}, + ) + assert response.status == 500 assert response.content_type == "text/html; charset=utf-8" @@ -288,6 +306,10 @@ def test_allow_fallback_error_format_set_main_process_start(app): def test_setting_fallback_on_config_changes_as_expected(app): app.error_handler = ErrorHandler() + _, response = app.test_client.get("/error") + assert response.content_type == "text/plain; charset=utf-8" + + app.config.FALLBACK_ERROR_FORMAT = "html" _, response = app.test_client.get("/error") assert response.content_type == "text/html; charset=utf-8" @@ -334,6 +356,22 @@ def test_config_fallback_before_and_after_startup(app): assert response.content_type == "text/plain; charset=utf-8" +def test_config_fallback_using_update_dict(app): + app.config.update({"FALLBACK_ERROR_FORMAT": "text"}) + + _, response = app.test_client.get("/error") + assert response.status == 500 + assert response.content_type == "text/plain; charset=utf-8" + + +def test_config_fallback_using_update_kwarg(app): + app.config.update(FALLBACK_ERROR_FORMAT="text") + + _, response = app.test_client.get("/error") + assert response.status == 500 + assert response.content_type == "text/plain; charset=utf-8" + + def test_config_fallback_bad_value(app): message = "Unknown format: fake" with pytest.raises(SanicException, match=message): diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 3ec2959d..e603718d 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,5 +1,4 @@ import logging -import warnings import pytest @@ -34,6 +33,7 @@ class SanicExceptionTestException(Exception): @pytest.fixture(scope="module") def exception_app(): app = Sanic("test_exceptions") + app.config.FALLBACK_ERROR_FORMAT = "html" @app.route("/") def handler(request): diff --git a/tests/test_exceptions_handler.py b/tests/test_exceptions_handler.py index 8be3d28e..e9bdb21e 100644 --- a/tests/test_exceptions_handler.py +++ b/tests/test_exceptions_handler.py @@ -216,31 +216,6 @@ def test_exception_handler_processed_request_middleware( assert response.text == "Done." -def test_single_arg_exception_handler_notice( - exception_handler_app: Sanic, caplog: LogCaptureFixture -): - class CustomErrorHandler(ErrorHandler): - def lookup(self, exception): - return super().lookup(exception, None) - - exception_handler_app.error_handler = CustomErrorHandler() - - message = ( - "[DEPRECATION v22.3] You are using a deprecated error handler. The " - "lookup method should accept two positional parameters: (exception, " - "route_name: Optional[str]). Until you upgrade your " - "ErrorHandler.lookup, Blueprint specific exceptions will not work " - "properly. Beginning in v22.3, the legacy style lookup method will " - "not work at all." - ) - with pytest.warns(DeprecationWarning) as record: - _, response = exception_handler_app.test_client.get("/1") - - assert len(record) == 1 - assert record[0].message.args[0] == message - assert response.status == 400 - - def test_error_handler_noisy_log( exception_handler_app: Sanic, monkeypatch: MonkeyPatch ): @@ -279,7 +254,7 @@ def test_exception_handler_response_was_sent( @app.route("/2") async def handler2(request: Request): - response = await request.respond() + await request.respond() raise ServerError("Exception") with caplog.at_level(logging.WARNING): diff --git a/tests/test_motd.py b/tests/test_motd.py index fe45bc47..f3f95a25 100644 --- a/tests/test_motd.py +++ b/tests/test_motd.py @@ -1,11 +1,15 @@ import logging +import os import platform +import sys from unittest.mock import Mock -from sanic import __version__ +import pytest + +from sanic import Sanic, __version__ from sanic.application.logo import BASE_LOGO -from sanic.application.motd import MOTDTTY +from sanic.application.motd import MOTD, MOTDTTY def test_logo_base(app, run_startup): @@ -83,3 +87,25 @@ def test_motd_display(caplog): └───────────────────────┴────────┘ """ ) + + +@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not on 3.7") +def test_reload_dirs(app): + app.config.LOGO = None + app.config.AUTO_RELOAD = True + app.prepare(reload_dir="./", auto_reload=True, motd_display={"foo": "bar"}) + + existing = MOTD.output + MOTD.output = Mock() + + app.motd("foo") + + MOTD.output.assert_called_once() + assert ( + MOTD.output.call_args.args[2]["auto-reload"] + == f"enabled, {os.getcwd()}" + ) + assert MOTD.output.call_args.args[3] == {"foo": "bar"} + + MOTD.output = existing + Sanic._app_registry = {} diff --git a/tests/test_multi_serve.py b/tests/test_multi_serve.py new file mode 100644 index 00000000..dde72b5c --- /dev/null +++ b/tests/test_multi_serve.py @@ -0,0 +1,207 @@ +import logging + +from unittest.mock import Mock + +import pytest + +from sanic import Sanic +from sanic.response import text +from sanic.server.async_server import AsyncioServer +from sanic.signals import Event +from sanic.touchup.schemes.ode import OptionalDispatchEvent + + +try: + from unittest.mock import AsyncMock +except ImportError: + from asyncmock import AsyncMock # type: ignore + + +@pytest.fixture +def app_one(): + app = Sanic("One") + + @app.get("/one") + async def one(request): + return text("one") + + return app + + +@pytest.fixture +def app_two(): + app = Sanic("Two") + + @app.get("/two") + async def two(request): + return text("two") + + return app + + +@pytest.fixture(autouse=True) +def clean(): + Sanic._app_registry = {} + yield + + +def test_serve_same_app_multiple_tuples(app_one, run_multi): + app_one.prepare(port=23456) + app_one.prepare(port=23457) + + logs = run_multi(app_one) + assert ( + "sanic.root", + logging.INFO, + "Goin' Fast @ http://127.0.0.1:23456", + ) in logs + assert ( + "sanic.root", + logging.INFO, + "Goin' Fast @ http://127.0.0.1:23457", + ) in logs + + +def test_serve_multiple_apps(app_one, app_two, run_multi): + app_one.prepare(port=23456) + app_two.prepare(port=23457) + + logs = run_multi(app_one) + assert ( + "sanic.root", + logging.INFO, + "Goin' Fast @ http://127.0.0.1:23456", + ) in logs + assert ( + "sanic.root", + logging.INFO, + "Goin' Fast @ http://127.0.0.1:23457", + ) in logs + + +def test_listeners_on_secondary_app(app_one, app_two, run_multi): + app_one.prepare(port=23456) + app_two.prepare(port=23457) + + before_start = AsyncMock() + after_start = AsyncMock() + before_stop = AsyncMock() + after_stop = AsyncMock() + + app_two.before_server_start(before_start) + app_two.after_server_start(after_start) + app_two.before_server_stop(before_stop) + app_two.after_server_stop(after_stop) + + run_multi(app_one) + + before_start.assert_awaited_once() + after_start.assert_awaited_once() + before_stop.assert_awaited_once() + after_stop.assert_awaited_once() + + +@pytest.mark.parametrize( + "events", + ( + (Event.HTTP_LIFECYCLE_BEGIN,), + (Event.HTTP_LIFECYCLE_BEGIN, Event.HTTP_LIFECYCLE_COMPLETE), + ( + Event.HTTP_LIFECYCLE_BEGIN, + Event.HTTP_LIFECYCLE_COMPLETE, + Event.HTTP_LIFECYCLE_REQUEST, + ), + ), +) +def test_signal_synchronization(app_one, app_two, run_multi, events): + app_one.prepare(port=23456) + app_two.prepare(port=23457) + + for event in events: + app_one.signal(event)(AsyncMock()) + + run_multi(app_one) + + assert len(app_two.signal_router.routes) == len(events) + 1 + + signal_handlers = { + signal.handler + for signal in app_two.signal_router.routes + if signal.name.startswith("http") + } + + assert len(signal_handlers) == 1 + assert list(signal_handlers)[0] is OptionalDispatchEvent.noop + + +def test_warning_main_process_listeners_on_secondary( + app_one, app_two, run_multi +): + app_two.main_process_start(AsyncMock()) + app_two.main_process_stop(AsyncMock()) + app_one.prepare(port=23456) + app_two.prepare(port=23457) + + log = run_multi(app_one) + + message = ( + f"Sanic found 2 listener(s) on " + "secondary applications attached to the main " + "process. These will be ignored since main " + "process listeners can only be attached to your " + "primary application: " + f"{repr(app_one)}" + ) + + assert ("sanic.error", logging.WARNING, message) in log + + +def test_no_applications(): + Sanic._app_registry = {} + message = "Did not find any applications." + with pytest.raises(RuntimeError, match=message): + Sanic.serve() + + +def test_oserror_warning(app_one, app_two, run_multi, capfd): + orig = AsyncioServer.__await__ + AsyncioServer.__await__ = Mock(side_effect=OSError("foo")) + app_one.prepare(port=23456, workers=2) + app_two.prepare(port=23457, workers=2) + + run_multi(app_one) + + captured = capfd.readouterr() + assert ( + "An OSError was detected on startup. The encountered error was: foo" + ) in captured.err + + AsyncioServer.__await__ = orig + + +def test_running_multiple_offset_warning(app_one, app_two, run_multi, capfd): + app_one.prepare(port=23456, workers=2) + app_two.prepare(port=23457) + + run_multi(app_one) + + captured = capfd.readouterr() + assert ( + f"The primary application {repr(app_one)} is running " + "with 2 worker(s). All " + "application instances will run with the same number. " + f"You requested {repr(app_two)} to run with " + "1 worker(s), which will be ignored " + "in favor of the primary application." + ) in captured.err + + +def test_running_multiple_secondary(app_one, app_two, run_multi, capfd): + app_one.prepare(port=23456, workers=2) + app_two.prepare(port=23457) + + before_start = AsyncMock() + app_two.before_server_start(before_start) + run_multi(app_one) + + before_start.await_count == 2 diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index 066b89ac..ab7c18de 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -132,11 +132,11 @@ def test_main_process_event(app, caplog): logger.info("main_process_stop") @app.main_process_start - def main_process_start(app, loop): + def main_process_start2(app, loop): logger.info("main_process_start") @app.main_process_stop - def main_process_stop(app, loop): + def main_process_stop2(app, loop): logger.info("main_process_stop") with caplog.at_level(logging.INFO): diff --git a/tests/test_pipelining.py b/tests/test_pipelining.py index 2bb29c52..6c998756 100644 --- a/tests/test_pipelining.py +++ b/tests/test_pipelining.py @@ -62,19 +62,15 @@ def test_streaming_body_requests(app): data = ["hello", "world"] - class Data(AsyncByteStream): - def __init__(self, data): - self.data = data - - async def __aiter__(self): - for value in self.data: - yield value.encode("utf-8") - client = ReusableClient(app, port=1234) + async def stream(data): + for value in data: + yield value.encode("utf-8") + with client: - _, response1 = client.post("/", data=Data(data)) - _, response2 = client.post("/", data=Data(data)) + _, response1 = client.post("/", data=stream(data)) + _, response2 = client.post("/", data=stream(data)) assert response1.status == response2.status == 200 assert response1.json["data"] == response2.json["data"] == data diff --git a/tests/test_prepare.py b/tests/test_prepare.py new file mode 100644 index 00000000..db8a8db5 --- /dev/null +++ b/tests/test_prepare.py @@ -0,0 +1,71 @@ +import logging +import os + +from pathlib import Path +from unittest.mock import Mock + +import pytest + +from sanic import Sanic +from sanic.application.state import ApplicationServerInfo + + +@pytest.fixture(autouse=True) +def no_skip(): + should_auto_reload = Sanic.should_auto_reload + Sanic.should_auto_reload = Mock(return_value=False) + yield + Sanic._app_registry = {} + Sanic.should_auto_reload = should_auto_reload + + +def get_primary(app: Sanic) -> ApplicationServerInfo: + return app.state.server_info[0] + + +def test_dev(app: Sanic): + app.prepare(dev=True) + + assert app.state.is_debug + assert app.state.auto_reload + + +def test_motd_display(app: Sanic): + app.prepare(motd_display={"foo": "bar"}) + + assert app.config.MOTD_DISPLAY["foo"] == "bar" + del app.config.MOTD_DISPLAY["foo"] + + +@pytest.mark.parametrize("dirs", ("./foo", ("./foo", "./bar"))) +def test_reload_dir(app: Sanic, dirs, caplog): + messages = [] + with caplog.at_level(logging.WARNING): + app.prepare(reload_dir=dirs) + + if isinstance(dirs, str): + dirs = (dirs,) + for d in dirs: + assert Path(d) in app.state.reload_dirs + messages.append( + f"Directory {d} could not be located", + ) + + for message in messages: + assert ("sanic.root", logging.WARNING, message) in caplog.record_tuples + + +def test_fast(app: Sanic, run_multi): + app.prepare(fast=True) + try: + workers = len(os.sched_getaffinity(0)) + except AttributeError: + workers = os.cpu_count() or 1 + + assert app.state.fast + assert app.state.workers == workers + + logs = run_multi(app, logging.INFO) + + messages = [m[2] for m in logs] + assert f"mode: production, goin' fast w/ {workers} workers" in messages diff --git a/tests/test_requests.py b/tests/test_requests.py index c8f6e3f0..d752f045 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1,3 +1,4 @@ +import base64 import logging from json import dumps as json_dumps @@ -15,11 +16,15 @@ from sanic_testing.testing import ( ) from sanic import Blueprint, Sanic -from sanic.exceptions import SanicException, ServerError -from sanic.request import DEFAULT_HTTP_CONTENT_TYPE, Request, RequestParameters +from sanic.exceptions import ServerError +from sanic.request import DEFAULT_HTTP_CONTENT_TYPE, RequestParameters from sanic.response import html, json, text +def encode_basic_auth_credentials(username, password): + return base64.b64encode(f"{username}:{password}".encode()).decode("ascii") + + # ------------------------------------------------------------ # # GET # ------------------------------------------------------------ # @@ -362,93 +367,95 @@ async def test_uri_template_asgi(app): assert request.uri_template == "/foo//bar/" -def test_token(app): +@pytest.mark.parametrize( + ("auth_type", "token"), + [ + # uuid4 generated token set in "Authorization" header + (None, "a1d895e0-553a-421a-8e22-5ff8ecb48cbf"), + # uuid4 generated token with API Token authorization + ("Token", "a1d895e0-553a-421a-8e22-5ff8ecb48cbf"), + # uuid4 generated token with Bearer Token authorization + ("Bearer", "a1d895e0-553a-421a-8e22-5ff8ecb48cbf"), + # no Authorization header + (None, None), + ], +) +def test_token(app, auth_type, token): @app.route("/") async def handler(request): return text("OK") - # uuid4 generated token. - token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" - headers = { - "content-type": "application/json", - "Authorization": f"{token}", - } + if token: + headers = { + "content-type": "application/json", + "Authorization": f"{auth_type} {token}" + if auth_type + else f"{token}", + } + else: + headers = {"content-type": "application/json"} request, response = app.test_client.get("/", headers=headers) - assert request.token == token - token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" - headers = { - "content-type": "application/json", - "Authorization": f"Token {token}", - } - request, response = app.test_client.get("/", headers=headers) - - assert request.token == token - - token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" - headers = { - "content-type": "application/json", - "Authorization": f"Bearer {token}", - } - - request, response = app.test_client.get("/", headers=headers) - - assert request.token == token - - # no Authorization headers - headers = {"content-type": "application/json"} - - request, response = app.test_client.get("/", headers=headers) - - assert request.token is None - - -@pytest.mark.asyncio -async def test_token_asgi(app): +@pytest.mark.parametrize( + ("auth_type", "token", "username", "password"), + [ + # uuid4 generated token set in "Authorization" header + (None, "a1d895e0-553a-421a-8e22-5ff8ecb48cbf", None, None), + # uuid4 generated token with API Token authorization + ("Token", "a1d895e0-553a-421a-8e22-5ff8ecb48cbf", None, None), + # uuid4 generated token with Bearer Token authorization + ("Bearer", "a1d895e0-553a-421a-8e22-5ff8ecb48cbf", None, None), + # username and password with Basic Auth authorization + ( + "Basic", + encode_basic_auth_credentials("some_username", "some_pass"), + "some_username", + "some_pass", + ), + # no Authorization header + (None, None, None, None), + ], +) +def test_credentials(app, capfd, auth_type, token, username, password): @app.route("/") async def handler(request): return text("OK") - # uuid4 generated token. - token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" - headers = { - "content-type": "application/json", - "Authorization": f"{token}", - } + if token: + headers = { + "content-type": "application/json", + "Authorization": f"{auth_type} {token}" + if auth_type + else f"{token}", + } + else: + headers = {"content-type": "application/json"} - request, response = await app.asgi_client.get("/", headers=headers) + request, response = app.test_client.get("/", headers=headers) - assert request.token == token + if auth_type == "Basic": + assert request.credentials.username == username + assert request.credentials.password == password + else: + _, err = capfd.readouterr() + with pytest.raises(AttributeError): + request.credentials.password + assert "Password is available for Basic Auth only" in err + request.credentials.username + assert "Username is available for Basic Auth only" in err - token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" - headers = { - "content-type": "application/json", - "Authorization": f"Token {token}", - } - - request, response = await app.asgi_client.get("/", headers=headers) - - assert request.token == token - - token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" - headers = { - "content-type": "application/json", - "Authorization": f"Bearer {token}", - } - - request, response = await app.asgi_client.get("/", headers=headers) - - assert request.token == token - - # no Authorization headers - headers = {"content-type": "application/json"} - - request, response = await app.asgi_client.get("/", headers=headers) - - assert request.token is None + if token: + assert request.credentials.token == token + assert request.credentials.auth_type == auth_type + else: + assert request.credentials is None + assert not hasattr(request.credentials, "token") + assert not hasattr(request.credentials, "auth_type") + assert not hasattr(request.credentials, "_username") + assert not hasattr(request.credentials, "_password") def test_content_type(app): @@ -1714,7 +1721,6 @@ async def test_request_query_args_custom_parsing_asgi(app): def test_request_cookies(app): - cookies = {"test": "OK"} @app.get("/") @@ -1729,7 +1735,6 @@ def test_request_cookies(app): @pytest.mark.asyncio async def test_request_cookies_asgi(app): - cookies = {"test": "OK"} @app.get("/") diff --git a/tests/test_server_events.py b/tests/test_server_events.py index 058a7cf6..cd3f5266 100644 --- a/tests/test_server_events.py +++ b/tests/test_server_events.py @@ -35,7 +35,7 @@ def create_listener(listener_name, in_list): def start_stop_app(random_name_app, **run_kwargs): def stop_on_alarm(signum, frame): - raise KeyboardInterrupt("SIGINT for sanic to stop gracefully") + random_name_app.stop() signal.signal(signal.SIGALRM, stop_on_alarm) signal.alarm(1) @@ -130,6 +130,9 @@ async def test_trigger_before_events_create_server_missing_event(app): def test_create_server_trigger_events(app): """Test if create_server can trigger server events""" + def stop_on_alarm(signum, frame): + raise KeyboardInterrupt("...") + flag1 = False flag2 = False flag3 = False @@ -137,8 +140,7 @@ def test_create_server_trigger_events(app): async def stop(app, loop): nonlocal flag1 flag1 = True - await asyncio.sleep(0.1) - app.stop() + signal.alarm(1) async def before_stop(app, loop): nonlocal flag2 @@ -155,6 +157,8 @@ def test_create_server_trigger_events(app): loop = asyncio.get_event_loop() # Use random port for tests + + signal.signal(signal.SIGALRM, stop_on_alarm) with closing(socket()) as sock: sock.bind(("127.0.0.1", 0)) diff --git a/tests/test_server_loop.py b/tests/test_server_loop.py index fbd5cc2b..30077178 100644 --- a/tests/test_server_loop.py +++ b/tests/test_server_loop.py @@ -4,8 +4,8 @@ from unittest.mock import Mock, patch import pytest -from sanic.server import loop from sanic.compat import OS_IS_WINDOWS, UVLOOP_INSTALLED +from sanic.server import loop @pytest.mark.skipif( diff --git a/tests/test_tasks.py b/tests/test_tasks.py index a1b98a81..63de50af 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1,5 +1,4 @@ import asyncio -import sys from asyncio.tasks import Task from unittest.mock import Mock, call @@ -7,9 +6,15 @@ from unittest.mock import Mock, call import pytest from sanic.app import Sanic +from sanic.application.state import ApplicationServerInfo, ServerStage from sanic.response import empty +try: + from unittest.mock import AsyncMock +except ImportError: + from asyncmock import AsyncMock # type: ignore + pytestmark = pytest.mark.asyncio @@ -20,11 +25,14 @@ async def dummy(n=0): @pytest.fixture(autouse=True) -def mark_app_running(app): - app.is_running = True +def mark_app_running(app: Sanic): + app.state.server_info.append( + ApplicationServerInfo( + stage=ServerStage.SERVING, settings={}, server=AsyncMock() + ) + ) -@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7") async def test_add_task_returns_task(app: Sanic): task = app.add_task(dummy()) @@ -32,7 +40,6 @@ async def test_add_task_returns_task(app: Sanic): assert len(app._task_registry) == 0 -@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7") async def test_add_task_with_name(app: Sanic): task = app.add_task(dummy(), name="dummy") @@ -44,7 +51,6 @@ async def test_add_task_with_name(app: Sanic): assert task in app._task_registry.values() -@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7") async def test_cancel_task(app: Sanic): task = app.add_task(dummy(3), name="dummy") @@ -62,7 +68,6 @@ async def test_cancel_task(app: Sanic): assert task.cancelled() -@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7") async def test_purge_tasks(app: Sanic): app.add_task(dummy(3), name="dummy") @@ -75,7 +80,6 @@ async def test_purge_tasks(app: Sanic): assert len(app._task_registry) == 0 -@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7") def test_shutdown_tasks_on_app_stop(): class TestSanic(Sanic): shutdown_tasks = Mock() diff --git a/tests/test_unix_socket.py b/tests/test_unix_socket.py index c26ebe06..ea3b4b1d 100644 --- a/tests/test_unix_socket.py +++ b/tests/test_unix_socket.py @@ -72,14 +72,12 @@ def test_unix_socket_creation(caplog): assert not os.path.exists(SOCKPATH) -def test_invalid_paths(): +@pytest.mark.parametrize("path", (".", "no-such-directory/sanictest.sock")) +def test_invalid_paths(path): app = Sanic(name=__name__) - with pytest.raises(FileExistsError): - app.run(unix=".") - - with pytest.raises(FileNotFoundError): - app.run(unix="no-such-directory/sanictest.sock") + with pytest.raises((FileExistsError, FileNotFoundError)): + app.run(unix=path) def test_dont_replace_file(): @@ -201,7 +199,7 @@ async def test_zero_downtime(): for _ in range(40): async with httpx.AsyncClient(transport=transport) as client: r = await client.get("http://localhost/sleep/0.1") - assert r.status_code == 200 + assert r.status_code == 200, r.content assert r.text == "Slept 0.1 seconds.\n" def spawn(): @@ -209,6 +207,7 @@ async def test_zero_downtime(): sys.executable, "-m", "sanic", + "--debug", "--unix", SOCKPATH, "examples.delayed_response.app", diff --git a/tests/test_websockets.py b/tests/test_websockets.py new file mode 100644 index 00000000..329eff45 --- /dev/null +++ b/tests/test_websockets.py @@ -0,0 +1,243 @@ +import re + +from asyncio import Event, Queue, TimeoutError +from unittest.mock import Mock, call + +import pytest + +from websockets.frames import CTRL_OPCODES, DATA_OPCODES, Frame + +from sanic.exceptions import ServerError +from sanic.server.websockets.frame import WebsocketFrameAssembler + + +try: + from unittest.mock import AsyncMock +except ImportError: + from asyncmock import AsyncMock # type: ignore + + +@pytest.mark.asyncio +async def test_ws_frame_get_message_incomplete_timeout_0(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete = AsyncMock(spec=Event) + assembler.message_complete.is_set = Mock(return_value=False) + data = await assembler.get(0) + + assert data is None + assembler.message_complete.is_set.assert_called_once() + + +@pytest.mark.asyncio +async def test_ws_frame_get_message_in_progress(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.get_in_progress = True + + message = re.escape( + "Called get() on Websocket frame assembler " + "while asynchronous get is already in progress." + ) + + with pytest.raises(ServerError, match=message): + await assembler.get() + + +@pytest.mark.asyncio +async def test_ws_frame_get_message_incomplete(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete.wait = AsyncMock(return_value=True) + assembler.message_complete.is_set = Mock(return_value=False) + data = await assembler.get() + + assert data is None + assembler.message_complete.wait.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_ws_frame_get_message(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete.wait = AsyncMock(return_value=True) + assembler.message_complete.is_set = Mock(return_value=True) + data = await assembler.get() + + assert data == b"" + assembler.message_complete.wait.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_ws_frame_get_message_with_timeout(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete.wait = AsyncMock(return_value=True) + assembler.message_complete.is_set = Mock(return_value=True) + data = await assembler.get(0.1) + + assert data == b"" + assembler.message_complete.wait.assert_awaited_once() + assert assembler.message_complete.is_set.call_count == 2 + + +@pytest.mark.asyncio +async def test_ws_frame_get_message_with_timeouterror(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete.wait = AsyncMock(return_value=True) + assembler.message_complete.is_set = Mock(return_value=True) + assembler.message_complete.wait.side_effect = TimeoutError("...") + data = await assembler.get(0.1) + + assert data == b"" + assembler.message_complete.wait.assert_awaited_once() + assert assembler.message_complete.is_set.call_count == 2 + + +@pytest.mark.asyncio +async def test_ws_frame_get_not_completed(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete = AsyncMock(spec=Event) + assembler.message_complete.is_set = Mock(return_value=False) + data = await assembler.get() + + assert data is None + + +@pytest.mark.asyncio +async def test_ws_frame_get_not_completed_start(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete = AsyncMock(spec=Event) + assembler.message_complete.is_set = Mock(side_effect=[False, True]) + data = await assembler.get(0.1) + + assert data is None + + +@pytest.mark.asyncio +async def test_ws_frame_get_paused(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete = AsyncMock(spec=Event) + assembler.message_complete.is_set = Mock(side_effect=[False, True]) + assembler.paused = True + data = await assembler.get() + + assert data is None + assembler.protocol.resume_frames.assert_called_once() + + +@pytest.mark.asyncio +async def test_ws_frame_get_data(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete = AsyncMock(spec=Event) + assembler.message_complete.is_set = Mock(return_value=True) + assembler.chunks = [b"foo", b"bar"] + data = await assembler.get() + + assert data == b"foobar" + + +@pytest.mark.asyncio +async def test_ws_frame_get_iter_in_progress(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.get_in_progress = True + + message = re.escape( + "Called get_iter on Websocket frame assembler " + "while asynchronous get is already in progress." + ) + + with pytest.raises(ServerError, match=message): + [x async for x in assembler.get_iter()] + + +@pytest.mark.asyncio +async def test_ws_frame_get_iter_none_in_queue(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete.set() + assembler.chunks = [b"foo", b"bar"] + + chunks = [x async for x in assembler.get_iter()] + + assert chunks == [b"foo", b"bar"] + + +@pytest.mark.asyncio +async def test_ws_frame_get_iter_paused(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete.set() + assembler.paused = True + + [x async for x in assembler.get_iter()] + assembler.protocol.resume_frames.assert_called_once() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("opcode", DATA_OPCODES) +async def test_ws_frame_put_not_fetched(opcode): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_fetched.set() + + message = re.escape( + "Websocket put() got a new message when the previous message was " + "not yet fetched." + ) + with pytest.raises(ServerError, match=message): + await assembler.put(Frame(opcode, b"")) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("opcode", DATA_OPCODES) +async def test_ws_frame_put_fetched(opcode): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_fetched = AsyncMock() + assembler.message_fetched.is_set = Mock(return_value=False) + + await assembler.put(Frame(opcode, b"")) + assembler.message_fetched.wait.assert_awaited_once() + assembler.message_fetched.clear.assert_called_once() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("opcode", DATA_OPCODES) +async def test_ws_frame_put_message_complete(opcode): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete.set() + + message = re.escape( + "Websocket put() got a new message when a message was " + "already in its chamber." + ) + with pytest.raises(ServerError, match=message): + await assembler.put(Frame(opcode, b"")) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("opcode", DATA_OPCODES) +async def test_ws_frame_put_message_into_queue(opcode): + assembler = WebsocketFrameAssembler(Mock()) + assembler.chunks_queue = AsyncMock(spec=Queue) + assembler.message_fetched = AsyncMock() + assembler.message_fetched.is_set = Mock(return_value=False) + + await assembler.put(Frame(opcode, b"foo")) + + assembler.chunks_queue.put.has_calls( + call(b"foo"), + call(None), + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("opcode", DATA_OPCODES) +async def test_ws_frame_put_not_fin(opcode): + assembler = WebsocketFrameAssembler(Mock()) + + retval = await assembler.put(Frame(opcode, b"foo", fin=False)) + + assert retval is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("opcode", CTRL_OPCODES) +async def test_ws_frame_put_skip_ctrl(opcode): + assembler = WebsocketFrameAssembler(Mock()) + + retval = await assembler.put(Frame(opcode, b"")) + + assert retval is None diff --git a/tests/test_worker.py b/tests/test_worker.py deleted file mode 100644 index cdc30a05..00000000 --- a/tests/test_worker.py +++ /dev/null @@ -1,200 +0,0 @@ -import asyncio -import json -import shlex -import subprocess -import time -import urllib.request - -from unittest import mock - -import pytest - -from sanic_testing.testing import ASGI_PORT as PORT - -from sanic.app import Sanic -from sanic.worker import GunicornWorker - - -@pytest.fixture -def gunicorn_worker(): - command = ( - "gunicorn " - f"--bind 127.0.0.1:{PORT} " - "--worker-class sanic.worker.GunicornWorker " - "examples.hello_world:app" - ) - worker = subprocess.Popen(shlex.split(command)) - time.sleep(2) - yield - worker.kill() - - -@pytest.fixture -def gunicorn_worker_with_access_logs(): - command = ( - "gunicorn " - f"--bind 127.0.0.1:{PORT + 1} " - "--worker-class sanic.worker.GunicornWorker " - "examples.hello_world:app" - ) - worker = subprocess.Popen(shlex.split(command), stdout=subprocess.PIPE) - time.sleep(2) - return worker - - -@pytest.fixture -def gunicorn_worker_with_env_var(): - command = ( - 'env SANIC_ACCESS_LOG="False" ' - "gunicorn " - f"--bind 127.0.0.1:{PORT + 2} " - "--worker-class sanic.worker.GunicornWorker " - "--log-level info " - "examples.hello_world:app" - ) - worker = subprocess.Popen(shlex.split(command), stdout=subprocess.PIPE) - time.sleep(2) - return worker - - -def test_gunicorn_worker(gunicorn_worker): - with urllib.request.urlopen(f"http://localhost:{PORT}/") as f: - res = json.loads(f.read(100).decode()) - assert res["test"] - - -def test_gunicorn_worker_no_logs(gunicorn_worker_with_env_var): - """ - if SANIC_ACCESS_LOG was set to False do not show access logs - """ - with urllib.request.urlopen(f"http://localhost:{PORT + 2}/") as _: - gunicorn_worker_with_env_var.kill() - logs = list( - filter( - lambda x: b"sanic.access" in x, - gunicorn_worker_with_env_var.stdout.read().split(b"\n"), - ) - ) - assert len(logs) == 0 - - -def test_gunicorn_worker_with_logs(gunicorn_worker_with_access_logs): - """ - default - show access logs - """ - with urllib.request.urlopen(f"http://localhost:{PORT + 1}/") as _: - gunicorn_worker_with_access_logs.kill() - assert ( - b"(sanic.access)[INFO][127.0.0.1" - in gunicorn_worker_with_access_logs.stdout.read() - ) - - -class GunicornTestWorker(GunicornWorker): - def __init__(self): - self.app = mock.Mock() - self.app.callable = Sanic("test_gunicorn_worker") - self.servers = {} - self.exit_code = 0 - self.cfg = mock.Mock() - self.notify = mock.Mock() - - -@pytest.fixture -def worker(): - return GunicornTestWorker() - - -def test_worker_init_process(worker): - with mock.patch("sanic.worker.asyncio") as mock_asyncio: - try: - worker.init_process() - except TypeError: - pass - - assert mock_asyncio.get_event_loop.return_value.close.called - assert mock_asyncio.new_event_loop.called - assert mock_asyncio.set_event_loop.called - - -def test_worker_init_signals(worker): - worker.loop = mock.Mock() - worker.init_signals() - assert worker.loop.add_signal_handler.called - - -def test_handle_abort(worker): - with mock.patch("sanic.worker.sys") as mock_sys: - worker.handle_abort(object(), object()) - assert not worker.alive - assert worker.exit_code == 1 - mock_sys.exit.assert_called_with(1) - - -def test_handle_quit(worker): - worker.handle_quit(object(), object()) - assert not worker.alive - assert worker.exit_code == 0 - - -async def _a_noop(*a, **kw): - pass - - -def test_run_max_requests_exceeded(worker): - loop = asyncio.new_event_loop() - worker.ppid = 1 - worker.alive = True - sock = mock.Mock() - sock.cfg_addr = ("localhost", 8080) - worker.sockets = [sock] - worker.wsgi = mock.Mock() - worker.connections = set() - worker.log = mock.Mock() - worker.loop = loop - worker.servers = { - "server1": {"requests_count": 14}, - "server2": {"requests_count": 15}, - } - worker.max_requests = 10 - worker._run = mock.Mock(wraps=_a_noop) - - # exceeding request count - _runner = asyncio.ensure_future(worker._check_alive(), loop=loop) - loop.run_until_complete(_runner) - - assert not worker.alive - worker.notify.assert_called_with() - worker.log.info.assert_called_with( - "Max requests exceeded, shutting " "down: %s", worker - ) - - -def test_worker_close(worker): - loop = asyncio.new_event_loop() - asyncio.sleep = mock.Mock(wraps=_a_noop) - worker.ppid = 1 - worker.pid = 2 - worker.cfg.graceful_timeout = 1.0 - worker.signal = mock.Mock() - worker.signal.stopped = False - worker.wsgi = mock.Mock() - conn = mock.Mock() - conn.websocket = mock.Mock() - conn.websocket.fail_connection = mock.Mock(wraps=_a_noop) - worker.connections = set([conn]) - worker.log = mock.Mock() - worker.loop = loop - server = mock.Mock() - server.close = mock.Mock(wraps=lambda *a, **kw: None) - server.wait_closed = mock.Mock(wraps=_a_noop) - worker.servers = {server: {"requests_count": 14}} - worker.max_requests = 10 - - # close worker - _close = asyncio.ensure_future(worker.close(), loop=loop) - loop.run_until_complete(_close) - - assert worker.signal.stopped - assert conn.websocket.fail_connection.called - assert len(worker.servers) == 0 diff --git a/tox.ini b/tox.ini index 609ceb48..68919e59 100644 --- a/tox.ini +++ b/tox.ini @@ -7,6 +7,9 @@ setenv = {py37,py38,py39,py310,pyNightly}-no-ext: SANIC_NO_UJSON=1 {py37,py38,py39,py310,pyNightly}-no-ext: SANIC_NO_UVLOOP=1 extras = test +allowlist_externals = + pytest + coverage commands = pytest {posargs:tests --cov sanic} - coverage combine --append @@ -41,7 +44,7 @@ commands = [testenv:docs] platform = linux|linux2|darwin -whitelist_externals = make +allowlist_externals = make extras = docs commands = make docs-test