diff --git a/.codeclimate.yml b/.codeclimate.yml deleted file mode 100644 index 13a5783d..00000000 --- a/.codeclimate.yml +++ /dev/null @@ -1,28 +0,0 @@ -exclude_patterns: - - "sanic/__main__.py" - - "sanic/application/logo.py" - - "sanic/application/motd.py" - - "sanic/reloader_helpers.py" - - "sanic/simple.py" - - "sanic/utils.py" - - ".github/" - - "changelogs/" - - "docker/" - - "docs/" - - "examples/" - - "scripts/" - - "tests/" -checks: - argument-count: - enabled: false - file-lines: - config: - threshold: 1000 - method-count: - config: - threshold: 40 - complex-logic: - enabled: false - method-complexity: - config: - threshold: 10 diff --git a/.coveragerc b/.coveragerc index 63bec82c..22856065 100644 --- a/.coveragerc +++ b/.coveragerc @@ -3,13 +3,12 @@ branch = True source = sanic omit = site-packages - sanic/application/logo.py - sanic/application/motd.py - sanic/cli sanic/__main__.py + sanic/compat.py sanic/reloader_helpers.py sanic/simple.py sanic/utils.py + sanic/cli [html] directory = coverage @@ -21,3 +20,12 @@ exclude_lines = noqa NOQA pragma: no cover +omit = + site-packages + sanic/__main__.py + sanic/compat.py + sanic/reloader_helpers.py + sanic/simple.py + sanic/utils.py + sanic/cli +skip_empty = True diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 5108c247..1113fa80 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -2,9 +2,13 @@ name: "CodeQL" on: push: - branches: [ main ] + branches: + - main + - "*LTS" pull_request: - branches: [ main ] + branches: + - main + - "*LTS" types: [opened, synchronize, reopened, ready_for_review] schedule: - cron: '25 16 * * 0' diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 92d93aa7..13e82615 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -3,13 +3,15 @@ on: push: branches: - main + - "*LTS" tags: - "!*" # Do not execute on tags pull_request: - types: [opened, synchronize, reopened, ready_for_review] + branches: + - main + - "*LTS" jobs: test: - if: github.event.pull_request.draft == false runs-on: ${{ matrix.os }} strategy: matrix: @@ -19,7 +21,6 @@ jobs: steps: - uses: actions/checkout@v2 - - uses: actions/setup-python@v1 with: python-version: ${{ matrix.python-version }} @@ -28,9 +29,10 @@ jobs: run: | python -m pip install --upgrade pip pip install tox - - uses: paambaati/codeclimate-action@v2.5.3 - if: always() - env: - CC_TEST_REPORTER_ID: ${{ secrets.CODECLIMATE }} + - name: Run coverage + run: tox -e coverage + continue-on-error: true + - uses: codecov/codecov-action@v2 with: - coverageCommand: tox -e coverage + files: ./coverage.xml + fail_ci_if_error: false diff --git a/.github/workflows/pr-bandit.yml b/.github/workflows/pr-bandit.yml index ca91312a..2bd70204 100644 --- a/.github/workflows/pr-bandit.yml +++ b/.github/workflows/pr-bandit.yml @@ -3,6 +3,7 @@ on: pull_request: branches: - main + - "*LTS" types: [opened, synchronize, reopened, ready_for_review] jobs: diff --git a/.github/workflows/pr-docs.yml b/.github/workflows/pr-docs.yml index 7b3c2f6e..8479aef5 100644 --- a/.github/workflows/pr-docs.yml +++ b/.github/workflows/pr-docs.yml @@ -3,6 +3,7 @@ on: pull_request: branches: - main + - "*LTS" types: [opened, synchronize, reopened, ready_for_review] jobs: diff --git a/.github/workflows/pr-linter.yml b/.github/workflows/pr-linter.yml index 9ed45d0a..11ad9d29 100644 --- a/.github/workflows/pr-linter.yml +++ b/.github/workflows/pr-linter.yml @@ -3,6 +3,7 @@ on: pull_request: branches: - main + - "*LTS" types: [opened, synchronize, reopened, ready_for_review] jobs: diff --git a/.github/workflows/pr-python310.yml b/.github/workflows/pr-python310.yml index f3f7c607..5e66deec 100644 --- a/.github/workflows/pr-python310.yml +++ b/.github/workflows/pr-python310.yml @@ -3,6 +3,7 @@ on: pull_request: branches: - main + - "*LTS" types: [opened, synchronize, reopened, ready_for_review] jobs: diff --git a/.github/workflows/pr-python37.yml b/.github/workflows/pr-python37.yml index 50f79c6e..c0051d33 100644 --- a/.github/workflows/pr-python37.yml +++ b/.github/workflows/pr-python37.yml @@ -3,6 +3,7 @@ on: pull_request: branches: - main + - "*LTS" types: [opened, synchronize, reopened, ready_for_review] jobs: diff --git a/.github/workflows/pr-python38.yml b/.github/workflows/pr-python38.yml index 1e0b8050..09e93f3f 100644 --- a/.github/workflows/pr-python38.yml +++ b/.github/workflows/pr-python38.yml @@ -3,6 +3,7 @@ on: pull_request: branches: - main + - "*LTS" types: [opened, synchronize, reopened, ready_for_review] jobs: diff --git a/.github/workflows/pr-python39.yml b/.github/workflows/pr-python39.yml index 1abd6bcb..ff479459 100644 --- a/.github/workflows/pr-python39.yml +++ b/.github/workflows/pr-python39.yml @@ -3,6 +3,7 @@ on: pull_request: branches: - main + - "*LTS" types: [opened, synchronize, reopened, ready_for_review] jobs: diff --git a/.github/workflows/pr-type-check.yml b/.github/workflows/pr-type-check.yml index 2fae03be..58a90ee3 100644 --- a/.github/workflows/pr-type-check.yml +++ b/.github/workflows/pr-type-check.yml @@ -3,6 +3,7 @@ on: pull_request: branches: - main + - "*LTS" types: [opened, synchronize, reopened, ready_for_review] jobs: @@ -15,7 +16,7 @@ jobs: matrix: os: [ubuntu-latest] config: - - { python-version: 3.7, tox-env: type-checking} + # - { python-version: 3.7, tox-env: type-checking} - { python-version: 3.8, tox-env: type-checking} - { python-version: 3.9, tox-env: type-checking} - { python-version: "3.10", tox-env: type-checking} diff --git a/.github/workflows/pr-windows.yml b/.github/workflows/pr-windows.yml index 9721b5b5..050ee2bb 100644 --- a/.github/workflows/pr-windows.yml +++ b/.github/workflows/pr-windows.yml @@ -3,6 +3,7 @@ on: pull_request: branches: - main + - "*LTS" types: [opened, synchronize, reopened, ready_for_review] jobs: diff --git a/README.rst b/README.rst index 6b6d0408..316c323d 100644 --- a/README.rst +++ b/README.rst @@ -66,7 +66,7 @@ Sanic | Build fast. Run fast. Sanic is a **Python 3.7+** web server and web framework that's written to go fast. It allows the usage of the ``async/await`` syntax added in Python 3.5, which makes your code non-blocking and speedy. -Sanic is also ASGI compliant, so you can deploy it with an `alternative ASGI webserver `_. +Sanic is also ASGI compliant, so you can deploy it with an `alternative ASGI webserver `_. `Source code on GitHub `_ | `Help and discussion board `_ | `User Guide `_ | `Chat on Discord `_ diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 00000000..5905afd5 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,27 @@ +coverage: + status: + patch: + default: + target: auto + threshold: 0.75 + project: + default: + target: auto + threshold: 0.5 + precision: 3 +codecov: + require_ci_to_pass: false +ignore: + - "sanic/__main__.py" + - "sanic/compat.py" + - "sanic/reloader_helpers.py" + - "sanic/simple.py" + - "sanic/utils.py" + - "sanic/cli" + - ".github/" + - "changelogs/" + - "docker/" + - "docs/" + - "examples/" + - "scripts/" + - "tests/" diff --git a/docs/sanic/releases/21/21.12.md b/docs/sanic/releases/21/21.12.md index 6c4dc419..f8f0d954 100644 --- a/docs/sanic/releases/21/21.12.md +++ b/docs/sanic/releases/21/21.12.md @@ -1,3 +1,9 @@ +## Version 21.12.1 + +- [#2349](https://github.com/sanic-org/sanic/pull/2349) Only display MOTD on startup +- [#2354](https://github.com/sanic-org/sanic/pull/2354) Ignore name argument in Python 3.7 +- [#2355](https://github.com/sanic-org/sanic/pull/2355) Add config.update support for all config values + ## Version 21.12.0 ### Features diff --git a/sanic/app.py b/sanic/app.py index 55c4920e..5f3a0438 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -3,28 +3,24 @@ from __future__ import annotations import asyncio import logging import logging.config -import os -import platform import re import sys from asyncio import ( AbstractEventLoop, CancelledError, - Protocol, Task, ensure_future, get_event_loop, + get_running_loop, wait_for, ) from asyncio.futures import Future from collections import defaultdict, deque +from contextlib import suppress from functools import partial -from importlib import import_module from inspect import isawaitable -from pathlib import Path from socket import socket -from ssl import SSLContext from traceback import format_exc from types import SimpleNamespace from typing import ( @@ -38,7 +34,6 @@ from typing import ( Dict, Iterable, List, - Literal, Optional, Set, Tuple, @@ -55,11 +50,8 @@ from sanic_routing.exceptions import ( # type: ignore ) from sanic_routing.route import Route # type: ignore -from sanic import reloader_helpers from sanic.application.ext import setup_ext -from sanic.application.logo import get_logo -from sanic.application.motd import MOTD -from sanic.application.state import ApplicationState, Mode +from sanic.application.state import ApplicationState, Mode, ServerStage from sanic.asgi import ASGIApp from sanic.base.root import BaseSanic from sanic.blueprint_group import BlueprintGroup @@ -73,18 +65,15 @@ from sanic.exceptions import ( URLBuildError, ) from sanic.handlers import ErrorHandler -from sanic.helpers import _default from sanic.http import Stage -from sanic.http.constants import HTTP -from sanic.http.tls import process_to_context from sanic.log import ( LOGGING_CONFIG_DEFAULTS, - Colors, deprecation, error_logger, logger, ) from sanic.mixins.listeners import ListenerEvent +from sanic.mixins.runner import RunnerMixin from sanic.models.futures import ( FutureException, FutureListener, @@ -99,10 +88,6 @@ from sanic.models.handler_types import Sanic as SanicVar from sanic.request import Request from sanic.response import BaseHTTPResponse, HTTPResponse, ResponseStream from sanic.router import Router -from sanic.server import AsyncioServer, HttpProtocol -from sanic.server import Signal as ServerSignal -from sanic.server import serve, serve_multiple, serve_single, try_use_uvloop -from sanic.server.protocols.websocket_protocol import WebSocketProtocol from sanic.server.websockets.impl import ConnectionClosed from sanic.signals import Signal, SignalRouter from sanic.touchup import TouchUp, TouchUpMeta @@ -116,15 +101,13 @@ if TYPE_CHECKING: # no cov Extend = TypeVar("Extend") # type: ignore -if OS_IS_WINDOWS: +if OS_IS_WINDOWS: # no cov enable_windows_color_support() filterwarnings("once", category=DeprecationWarning) -SANIC_PACKAGES = ("sanic-routing", "sanic-testing", "sanic-ext") - -class Sanic(BaseSanic, metaclass=TouchUpMeta): +class Sanic(BaseSanic, RunnerMixin, metaclass=TouchUpMeta): """ The main application instance """ @@ -223,7 +206,6 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): self.blueprints: Dict[str, Blueprint] = {} self.configure_logging: bool = configure_logging self.ctx: Any = ctx or SimpleNamespace() - self.debug = False self.error_handler: ErrorHandler = error_handler or ErrorHandler() self.listeners: Dict[str, List[ListenerType[Any]]] = defaultdict(list) self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {} @@ -267,7 +249,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): Only supported when using the `app.run` method. """ - if not self.is_running and self.asgi is False: + if self.state.stage is ServerStage.STOPPED and self.asgi is False: raise SanicException( "Loop can only be retrieved after the app has started " "running. Not supported with `create_server` function" @@ -1054,286 +1036,6 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): # Execution # -------------------------------------------------------------------- # - def make_coffee(self, *args, **kwargs): - self.state.coffee = True - self.run(*args, **kwargs) - - def run( - self, - host: Optional[str] = None, - port: Optional[int] = None, - *, - debug: bool = False, - auto_reload: Optional[bool] = None, - ssl: Union[None, SSLContext, dict, str, list, tuple] = None, - sock: Optional[socket] = None, - workers: int = 1, - protocol: Optional[Type[Protocol]] = None, - backlog: int = 100, - register_sys_signals: bool = True, - access_log: Optional[bool] = None, - unix: Optional[str] = None, - loop: AbstractEventLoop = None, - reload_dir: Optional[Union[List[str], str]] = None, - noisy_exceptions: Optional[bool] = None, - motd: bool = True, - fast: bool = False, - verbosity: int = 0, - motd_display: Optional[Dict[str, str]] = None, - version: HTTP = HTTP.VERSION_1, - ) -> None: - """ - Run the HTTP Server and listen until keyboard interrupt or term - signal. On termination, drain connections before closing. - - :param host: Address to host on - :type host: str - :param port: Port to host on - :type port: int - :param debug: Enables debug output (slows server) - :type debug: bool - :param auto_reload: Reload app whenever its source code is changed. - Enabled by default in debug mode. - :type auto_relaod: bool - :param ssl: SSLContext, or location of certificate and key - for SSL encryption of worker(s) - :type ssl: str, dict, SSLContext or list - :param sock: Socket for the server to accept connections from - :type sock: socket - :param workers: Number of processes received before it is respected - :type workers: int - :param protocol: Subclass of asyncio Protocol class - :type protocol: type[Protocol] - :param backlog: a number of unaccepted connections that the system - will allow before refusing new connections - :type backlog: int - :param register_sys_signals: Register SIG* events - :type register_sys_signals: bool - :param access_log: Enables writing access logs (slows server) - :type access_log: bool - :param unix: Unix socket to listen on instead of TCP port - :type unix: str - :param noisy_exceptions: Log exceptions that are normally considered - to be quiet/silent - :type noisy_exceptions: bool - :return: Nothing - """ - self.state.verbosity = verbosity - - if fast and workers != 1: - raise RuntimeError("You cannot use both fast=True and workers=X") - - if motd_display: - self.config.MOTD_DISPLAY.update(motd_display) - - if reload_dir: - if isinstance(reload_dir, str): - reload_dir = [reload_dir] - - for directory in reload_dir: - direc = Path(directory) - if not direc.is_dir(): - logger.warning( - f"Directory {directory} could not be located" - ) - self.state.reload_dirs.add(Path(directory)) - - if loop is not None: - raise TypeError( - "loop is not a valid argument. To use an existing loop, " - "change to create_server().\nSee more: " - "https://sanic.readthedocs.io/en/latest/sanic/deploying.html" - "#asynchronous-support" - ) - - if auto_reload or auto_reload is None and debug: - auto_reload = True - if os.environ.get("SANIC_SERVER_RUNNING") != "true": - return reloader_helpers.watchdog(1.0, self) - - if sock is None: - host, port = host or "127.0.0.1", port or 8000 - - if protocol is None: - protocol = ( - WebSocketProtocol if self.websocket_enabled else HttpProtocol - ) - - # Set explicitly passed configuration values - for attribute, value in { - "ACCESS_LOG": access_log, - "AUTO_RELOAD": auto_reload, - "MOTD": motd, - "NOISY_EXCEPTIONS": noisy_exceptions, - }.items(): - if value is not None: - setattr(self.config, attribute, value) - - if fast: - self.state.fast = True - try: - workers = len(os.sched_getaffinity(0)) - except AttributeError: - workers = os.cpu_count() or 1 - - server_settings = self._helper( - host=host, - port=port, - debug=debug, - ssl=ssl, - sock=sock, - unix=unix, - workers=workers, - protocol=protocol, - backlog=backlog, - register_sys_signals=register_sys_signals, - version=version, - ) - - if self.config.USE_UVLOOP is True or ( - self.config.USE_UVLOOP is _default and not OS_IS_WINDOWS - ): - try_use_uvloop() - - try: - self.is_running = True - self.is_stopping = False - if workers > 1 and os.name != "posix": - logger.warn( - f"Multiprocessing is currently not supported on {os.name}," - " using workers=1 instead" - ) - workers = 1 - if workers == 1: - serve_single(server_settings) - else: - serve_multiple(server_settings, workers) - except BaseException: - error_logger.exception( - "Experienced exception while trying to serve" - ) - raise - finally: - self.is_running = False - logger.info("Server Stopped") - print("END OF RUN") - - def stop(self): - """ - This kills the Sanic - """ - if not self.is_stopping: - self.shutdown_tasks(timeout=0) - self.is_stopping = True - get_event_loop().stop() - - async def create_server( - self, - host: Optional[str] = None, - port: Optional[int] = None, - *, - debug: bool = False, - ssl: Union[None, SSLContext, dict, str, list, tuple] = None, - sock: Optional[socket] = None, - protocol: Type[Protocol] = None, - backlog: int = 100, - access_log: Optional[bool] = None, - unix: Optional[str] = None, - return_asyncio_server: bool = False, - asyncio_server_kwargs: Dict[str, Any] = None, - noisy_exceptions: Optional[bool] = None, - ) -> Optional[AsyncioServer]: - """ - Asynchronous version of :func:`run`. - - This method will take care of the operations necessary to invoke - the *before_start* events via :func:`trigger_events` method invocation - before starting the *sanic* app in Async mode. - - .. note:: - This does not support multiprocessing and is not the preferred - way to run a :class:`Sanic` application. - - :param host: Address to host on - :type host: str - :param port: Port to host on - :type port: int - :param debug: Enables debug output (slows server) - :type debug: bool - :param ssl: SSLContext, or location of certificate and key - for SSL encryption of worker(s) - :type ssl: SSLContext or dict - :param sock: Socket for the server to accept connections from - :type sock: socket - :param protocol: Subclass of asyncio Protocol class - :type protocol: type[Protocol] - :param backlog: a number of unaccepted connections that the system - will allow before refusing new connections - :type backlog: int - :param access_log: Enables writing access logs (slows server) - :type access_log: bool - :param return_asyncio_server: flag that defines whether there's a need - to return asyncio.Server or - start it serving right away - :type return_asyncio_server: bool - :param asyncio_server_kwargs: key-value arguments for - asyncio/uvloop create_server method - :type asyncio_server_kwargs: dict - :param noisy_exceptions: Log exceptions that are normally considered - to be quiet/silent - :type noisy_exceptions: bool - :return: AsyncioServer if return_asyncio_server is true, else Nothing - """ - - if sock is None: - host, port = host or "127.0.0.1", port or 8000 - - if protocol is None: - protocol = ( - WebSocketProtocol if self.websocket_enabled else HttpProtocol - ) - - # Set explicitly passed configuration values - for attribute, value in { - "ACCESS_LOG": access_log, - "NOISY_EXCEPTIONS": noisy_exceptions, - }.items(): - if value is not None: - setattr(self.config, attribute, value) - - server_settings = self._helper( - host=host, - port=port, - debug=debug, - ssl=ssl, - sock=sock, - unix=unix, - loop=get_event_loop(), - protocol=protocol, - backlog=backlog, - run_async=return_asyncio_server, - ) - - if self.config.USE_UVLOOP is not _default: - error_logger.warning( - "You are trying to change the uvloop configuration, but " - "this is only effective when using the run(...) method. " - "When using the create_server(...) method Sanic will use " - "the already existing loop." - ) - - main_start = server_settings.pop("main_start", None) - main_stop = server_settings.pop("main_stop", None) - if main_start or main_stop: - logger.warning( - "Listener events for the main process are not available " - "with create_server()" - ) - - return await serve( - asyncio_server_kwargs=asyncio_server_kwargs, **server_settings - ) - async def _run_request_middleware( self, request, request_name=None ): # no cov @@ -1417,104 +1119,6 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): break return response - def _helper( - self, - host: Optional[str] = None, - port: Optional[int] = None, - debug: bool = False, - ssl: Union[None, SSLContext, dict, str, list, tuple] = None, - sock: Optional[socket] = None, - unix: Optional[str] = None, - workers: int = 1, - loop: AbstractEventLoop = None, - protocol: Type[Protocol] = HttpProtocol, - backlog: int = 100, - register_sys_signals: bool = True, - run_async: bool = False, - version: Union[HTTP, Literal[1], Literal[3]] = HTTP.VERSION_1, - ): - """Helper function used by `run` and `create_server`.""" - if self.config.PROXIES_COUNT and self.config.PROXIES_COUNT < 0: - raise ValueError( - "PROXIES_COUNT cannot be negative. " - "https://sanic.readthedocs.io/en/latest/sanic/config.html" - "#proxy-configuration" - ) - - if isinstance(version, int): - version = HTTP(version) - ssl = process_to_context(ssl) - - self.debug = debug - self.state.host = host - self.state.port = port - self.state.workers = workers - self.state.ssl = ssl - self.state.unix = unix - self.state.sock = sock - - server_settings = { - "protocol": protocol, - "host": host, - "port": port, - "sock": sock, - "unix": unix, - "ssl": ssl, - "app": self, - "signal": ServerSignal(), - "loop": loop, - "register_sys_signals": register_sys_signals, - "backlog": backlog, - "version": version, - } - - self.motd(self.serve_location) - - if sys.stdout.isatty() and not self.state.is_debug: - error_logger.warning( - f"{Colors.YELLOW}Sanic is running in PRODUCTION mode. " - "Consider using '--debug' or '--dev' while actively " - f"developing your application.{Colors.END}" - ) - - # Register start/stop events - for event_name, settings_name, reverse in ( - ("main_process_start", "main_start", False), - ("main_process_stop", "main_stop", True), - ): - listeners = self.listeners[event_name].copy() - if reverse: - listeners.reverse() - # Prepend sanic to the arguments when listeners are triggered - listeners = [partial(listener, self) for listener in listeners] - server_settings[settings_name] = listeners # type: ignore - - if run_async: - server_settings["run_async"] = True - - return server_settings - - @property - def serve_location(self) -> str: - serve_location = "" - proto = "http" - if self.state.ssl is not None: - proto = "https" - if self.state.unix: - serve_location = f"{self.state.unix} {proto}://..." - elif self.state.sock: - serve_location = f"{self.state.sock.getsockname()} {proto}://..." - elif self.state.host and self.state.port: - # colon(:) is legal for a host only in an ipv6 address - display_host = ( - f"[{self.state.host}]" - if ":" in self.state.host - else self.state.host - ) - serve_location = f"{proto}://{display_host}:{self.state.port}" - - return serve_location - def _build_endpoint_name(self, *parts): parts = [self.name, *parts] return ".".join(parts) @@ -1561,10 +1165,20 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): name: Optional[str] = None, register: bool = True, ) -> Task: - prepped = cls._prep_task(task, app, loop) - task = loop.create_task(prepped, name=name) + if not isinstance(task, Future): + prepped = cls._prep_task(task, app, loop) + if sys.version_info < (3, 8): # no cov + task = loop.create_task(prepped) + if name: + error_logger.warning( + "Cannot set a name for a task when using Python 3.7. " + "Your task will be created without a name." + ) + task.get_name = lambda: name + else: + task = loop.create_task(prepped, name=name) - if name and register: + if name and register and sys.version_info > (3, 7): app._task_registry[name] = task return task @@ -1598,12 +1212,6 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): :param task: future, couroutine or awaitable """ - if name and sys.version_info == (3, 7): - name = None - error_logger.warning( - "Cannot set a name for a task when using Python 3.7. Your " - "task will be created without a name." - ) try: loop = self.loop # Will raise SanicError if loop is not started return self._loop_add_task( @@ -1626,10 +1234,6 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): def get_task( self, name: str, *, raise_exception: bool = True ) -> Optional[Task]: - if sys.version_info == (3, 7): - raise RuntimeError( - "This feature is only supported on using Python 3.8+." - ) try: return self._task_registry[name] except KeyError: @@ -1646,17 +1250,13 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): *, raise_exception: bool = True, ) -> None: - if sys.version_info == (3, 7): - raise RuntimeError( - "This feature is only supported on using Python 3.8+." - ) task = self.get_task(name, raise_exception=raise_exception) if task and not task.cancelled(): args: Tuple[str, ...] = () if msg: if sys.version_info >= (3, 9): args = (msg,) - else: + else: # no cov raise RuntimeError( "Cancelling a task with a message is only supported " "on Python 3.9+." @@ -1668,10 +1268,6 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): ... def purge_tasks(self): - if sys.version_info == (3, 7): - raise RuntimeError( - "This feature is only supported on using Python 3.8+." - ) for task in self.tasks: if task.done() or task.cancelled(): name = task.get_name() @@ -1684,27 +1280,22 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): def shutdown_tasks( self, timeout: Optional[float] = None, increment: float = 0.1 ): - if sys.version_info == (3, 7): - raise RuntimeError( - "This feature is only supported on using Python 3.8+." - ) for task in self.tasks: - task.cancel() + if task.get_name() != "RunServer": + task.cancel() if timeout is None: timeout = self.config.GRACEFUL_SHUTDOWN_TIMEOUT while len(self._task_registry) and timeout: - self.loop.run_until_complete(asyncio.sleep(increment)) + with suppress(RuntimeError): + running_loop = get_running_loop() + running_loop.run_until_complete(asyncio.sleep(increment)) self.purge_tasks() timeout -= increment @property def tasks(self): - if sys.version_info == (3, 7): - raise RuntimeError( - "This feature is only supported on using Python 3.8+." - ) return iter(self._task_registry.values()) # -------------------------------------------------------------------- # @@ -1718,7 +1309,8 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): details: https://asgi.readthedocs.io/en/latest """ self.asgi = True - self.motd("") + if scope["type"] == "lifespan": + self.motd("") self._asgi_app = await ASGIApp.create(self, scope, receive, send) asgi_app = self._asgi_app await asgi_app() @@ -1753,6 +1345,13 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): @debug.setter def debug(self, value: bool): + deprecation( + "Setting the value of a Sanic application's debug value directly " + "is deprecated and will be removed in v22.9. Please set it using " + "the CLI, app.run, app.prepare, or directly set " + "app.state.mode to Mode.DEBUG.", + 22.9, + ) mode = Mode.DEBUG if value else Mode.PRODUCTION self.state.mode = mode @@ -1770,80 +1369,60 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): @property def is_running(self): + deprecation( + "Use of the is_running property is no longer used by Sanic " + "internally. The property is now deprecated and will be removed " + "in version 22.9. You may continue to set the property for your " + "own needs until that time. If you would like to check whether " + "the application is operational, please use app.state.stage. More " + "information is available at ___.", + 22.9, + ) return self.state.is_running @is_running.setter def is_running(self, value: bool): + deprecation( + "Use of the is_running property is no longer used by Sanic " + "internally. The property is now deprecated and will be removed " + "in version 22.9. You may continue to set the property for your " + "own needs until that time. If you would like to check whether " + "the application is operational, please use app.state.stage. More " + "information is available at ___.", + 22.9, + ) self.state.is_running = value @property def is_stopping(self): + deprecation( + "Use of the is_stopping property is no longer used by Sanic " + "internally. The property is now deprecated and will be removed " + "in version 22.9. You may continue to set the property for your " + "own needs until that time. If you would like to check whether " + "the application is operational, please use app.state.stage. More " + "information is available at ___.", + 22.9, + ) return self.state.is_stopping @is_stopping.setter def is_stopping(self, value: bool): + deprecation( + "Use of the is_stopping property is no longer used by Sanic " + "internally. The property is now deprecated and will be removed " + "in version 22.9. You may continue to set the property for your " + "own needs until that time. If you would like to check whether " + "the application is operational, please use app.state.stage. More " + "information is available at ___.", + 22.9, + ) self.state.is_stopping = value @property def reload_dirs(self): return self.state.reload_dirs - def motd(self, serve_location): - if self.config.MOTD: - mode = [f"{self.state.mode},"] - if self.state.fast: - mode.append("goin' fast") - if self.state.asgi: - mode.append("ASGI") - else: - if self.state.workers == 1: - mode.append("single worker") - else: - mode.append(f"w/ {self.state.workers} workers") - - display = { - "mode": " ".join(mode), - "server": self.state.server, - "python": platform.python_version(), - "platform": platform.platform(), - } - extra = {} - if self.config.AUTO_RELOAD: - reload_display = "enabled" - if self.state.reload_dirs: - reload_display += ", ".join( - [ - "", - *( - str(path.absolute()) - for path in self.state.reload_dirs - ), - ] - ) - display["auto-reload"] = reload_display - - packages = [] - for package_name in SANIC_PACKAGES: - module_name = package_name.replace("-", "_") - try: - module = import_module(module_name) - packages.append(f"{package_name}=={module.__version__}") - except ImportError: - ... - - if packages: - display["packages"] = ", ".join(packages) - - if self.config.MOTD_DISPLAY: - extra.update(self.config.MOTD_DISPLAY) - - logo = ( - get_logo(coffee=self.state.coffee) - if self.config.LOGO == "" or self.config.LOGO is True - else self.config.LOGO - ) - MOTD.output(logo, serve_location, display, extra) - @property def ext(self) -> Extend: if not hasattr(self, "_ext"): @@ -1941,7 +1520,6 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): async def _startup(self): self._future_registry.clear() - # Startup Sanic Extensions if not hasattr(self, "_ext"): setup_ext(self) if hasattr(self, "_ext"): @@ -1964,8 +1542,11 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): self.__class__._uvloop_setting = self.config.USE_UVLOOP # Startup time optimizations - ErrorHandler.finalize(self.error_handler, config=self.config) - TouchUp.run(self) + if self.state.primary: + # TODO: + # - Raise warning if secondary apps have error handler config + ErrorHandler.finalize(self.error_handler, config=self.config) + TouchUp.run(self) self.state.is_started = True diff --git a/sanic/application/constants.py b/sanic/application/constants.py new file mode 100644 index 00000000..9d46cb8e --- /dev/null +++ b/sanic/application/constants.py @@ -0,0 +1,23 @@ +from enum import Enum, IntEnum, auto + + +class StrEnum(str, Enum): + def _generate_next_value_(name: str, *args) -> str: # type: ignore + return name.lower() + + +class Server(StrEnum): + SANIC = auto() + ASGI = auto() + GUNICORN = auto() + + +class Mode(StrEnum): + PRODUCTION = auto() + DEBUG = auto() + + +class ServerStage(IntEnum): + STOPPED = auto() + PARTIAL = auto() + SERVING = auto() diff --git a/sanic/application/state.py b/sanic/application/state.py index 0345dad5..4dad18e8 100644 --- a/sanic/application/state.py +++ b/sanic/application/state.py @@ -3,33 +3,25 @@ from __future__ import annotations import logging from dataclasses import dataclass, field -from enum import Enum, auto from pathlib import Path from socket import socket from ssl import SSLContext -from typing import TYPE_CHECKING, Any, Optional, Set, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union +from sanic.application.constants import Mode, Server, ServerStage from sanic.log import logger +from sanic.server.async_server import AsyncioServer -if TYPE_CHECKING: +if TYPE_CHECKING: # no cov from sanic import Sanic -class StrEnum(str, Enum): - def _generate_next_value_(name: str, *args) -> str: # type: ignore - return name.lower() - - -class Server(StrEnum): - SANIC = auto() - ASGI = auto() - GUNICORN = auto() - - -class Mode(StrEnum): - PRODUCTION = auto() - DEBUG = auto() +@dataclass +class ApplicationServerInfo: + settings: Dict[str, Any] + stage: ServerStage = field(default=ServerStage.STOPPED) + server: Optional[AsyncioServer] = field(default=None) @dataclass @@ -45,12 +37,15 @@ class ApplicationState: unix: Optional[str] = field(default=None) mode: Mode = field(default=Mode.PRODUCTION) reload_dirs: Set[Path] = field(default_factory=set) + auto_reload: bool = field(default=False) server: Server = field(default=Server.SANIC) is_running: bool = field(default=False) is_started: bool = field(default=False) is_stopping: bool = field(default=False) verbosity: int = field(default=0) workers: int = field(default=0) + primary: bool = field(default=True) + server_info: List[ApplicationServerInfo] = field(default_factory=list) # This property relates to the ApplicationState instance and should # not be changed except in the __post_init__ method @@ -77,3 +72,17 @@ class ApplicationState: @property def is_debug(self): return self.mode is Mode.DEBUG + + @property + def stage(self) -> ServerStage: + if not self.server_info: + return ServerStage.STOPPED + + if all(info.stage is ServerStage.SERVING for info in self.server_info): + return ServerStage.SERVING + elif any( + info.stage is ServerStage.SERVING for info in self.server_info + ): + return ServerStage.PARTIAL + + return ServerStage.STOPPED diff --git a/sanic/cli/app.py b/sanic/cli/app.py index 9c5e55d2..83410bb3 100644 --- a/sanic/cli/app.py +++ b/sanic/cli/app.py @@ -80,13 +80,6 @@ Or, a path to a directory to run as a simple HTTP server: error_logger.exception("Failed to run app") def _precheck(self): - if self.args.debug and self.main_process: - error_logger.warning( - "Starting in v22.3, --debug will no " - "longer automatically run the auto-reloader.\n Switch to " - "--dev to continue using that functionality." - ) - # # Custom TLS mismatch handling for better diagnostics if self.main_process and ( # one of cert/key missing @@ -177,16 +170,11 @@ Or, a path to a directory to run as a simple HTTP server: "version": version, } - if self.args.auto_reload: - kwargs["auto_reload"] = True + for maybe_arg in ("auto_reload", "dev"): + if getattr(self.args, maybe_arg, False): + kwargs[maybe_arg] = True if self.args.path: - if self.args.auto_reload or self.args.debug: - kwargs["reload_dir"] = self.args.path - else: - error_logger.warning( - "Ignoring '--reload-dir' since auto reloading was not " - "enabled. If you would like to watch directories for " - "changes, consider using --debug or --auto-reload." - ) + kwargs["auto_reload"] = True + kwargs["reload_dir"] = self.args.path return kwargs diff --git a/sanic/cli/arguments.py b/sanic/cli/arguments.py index 05e228f4..b56132cc 100644 --- a/sanic/cli/arguments.py +++ b/sanic/cli/arguments.py @@ -204,18 +204,10 @@ class DevelopmentGroup(Group): "--debug", dest="debug", action="store_true", - help="Run the server in debug mode", - ) - self.container.add_argument( - "-d", - "--dev", - dest="debug", - action="store_true", help=( - "Currently is an alias for --debug. But starting in v22.3, \n" - "--debug will no longer automatically trigger auto_restart. \n" - "However, --dev will continue, effectively making it the \n" - "same as debug + auto_reload." + "Run the server in DEBUG mode. It includes DEBUG logging,\n" + "additional context on exceptions, and other settings\n" + "not-safe for PRODUCTION, but helpful for debugging problems." ), ) self.container.add_argument( @@ -236,6 +228,13 @@ class DevelopmentGroup(Group): action="append", help="Extra directories to watch and reload on changes", ) + self.container.add_argument( + "-d", + "--dev", + dest="dev", + action="store_true", + help=("debug + auto reload."), + ) class OutputGroup(Group): diff --git a/sanic/config.py b/sanic/config.py index c99d9abe..7e519fe9 100644 --- a/sanic/config.py +++ b/sanic/config.py @@ -130,22 +130,27 @@ class Config(dict, metaclass=DescriptorMeta): raise AttributeError(f"Config has no '{ke.args[0]}'") def __setattr__(self, attr, value) -> None: - if attr in self.__class__.__setters__: - try: - super().__setattr__(attr, value) - except AttributeError: - ... - else: - return None self.update({attr: value}) def __setitem__(self, attr, value) -> None: self.update({attr: value}) def update(self, *other, **kwargs) -> None: - other_mapping = {k: v for item in other for k, v in dict(item).items()} - super().update(*other, **kwargs) - for attr, value in {**other_mapping, **kwargs}.items(): + kwargs.update({k: v for item in other for k, v in dict(item).items()}) + setters: Dict[str, Any] = { + k: kwargs.pop(k) + for k in {**kwargs}.keys() + if k in self.__class__.__setters__ + } + + for key, value in setters.items(): + try: + super().__setattr__(key, value) + except AttributeError: + ... + + super().update(**kwargs) + for attr, value in {**setters, **kwargs}.items(): self._post_set(attr, value) def _post_set(self, attr, value) -> None: diff --git a/sanic/handlers.py b/sanic/handlers.py index 44ff77a7..b3f01566 100644 --- a/sanic/handlers.py +++ b/sanic/handlers.py @@ -1,13 +1,12 @@ from __future__ import annotations -from inspect import signature from typing import Dict, List, Optional, Tuple, Type, Union from sanic.config import Config from sanic.errorpages import ( DEFAULT_FORMAT, BaseRenderer, - HTMLRenderer, + TextRenderer, exception_response, ) from sanic.exceptions import ( @@ -35,13 +34,11 @@ class ErrorHandler: """ - # Beginning in v22.3, the base renderer will be TextRenderer def __init__( self, fallback: Union[str, Default] = _default, - base: Type[BaseRenderer] = HTMLRenderer, + base: Type[BaseRenderer] = TextRenderer, ): - self.handlers: List[Tuple[Type[BaseException], RouteHandler]] = [] self.cached_handlers: Dict[ Tuple[Type[BaseException], Optional[str]], Optional[RouteHandler] ] = {} @@ -53,14 +50,14 @@ class ErrorHandler: self._warn_fallback_deprecation() @property - def fallback(self): + def fallback(self): # no cov # This is for backwards compat and can be removed in v22.6 if self._fallback is _default: return DEFAULT_FORMAT return self._fallback @fallback.setter - def fallback(self, value: str): + def fallback(self, value: str): # no cov self._warn_fallback_deprecation() if not isinstance(value, str): raise SanicException( @@ -95,8 +92,8 @@ class ErrorHandler: def finalize( cls, error_handler: ErrorHandler, + config: Config, fallback: Optional[str] = None, - config: Optional[Config] = None, ): if fallback: deprecation( @@ -107,14 +104,10 @@ class ErrorHandler: 22.6, ) - if config is None: - deprecation( - "Starting in v22.3, config will be a required argument " - "for ErrorHandler.finalize().", - 22.3, - ) + if not fallback: + fallback = config.FALLBACK_ERROR_FORMAT - if fallback and fallback != DEFAULT_FORMAT: + if fallback != DEFAULT_FORMAT: if error_handler._fallback is not _default: error_logger.warning( f"Setting the fallback value to {fallback}. This changes " @@ -128,27 +121,9 @@ class ErrorHandler: f"Error handler is non-conforming: {type(error_handler)}" ) - sig = signature(error_handler.lookup) - if len(sig.parameters) == 1: - deprecation( - "You are using a deprecated error handler. The lookup " - "method should accept two positional parameters: " - "(exception, route_name: Optional[str]). " - "Until you upgrade your ErrorHandler.lookup, Blueprint " - "specific exceptions will not work properly. Beginning " - "in v22.3, the legacy style lookup method will not " - "work at all.", - 22.3, - ) - legacy_lookup = error_handler._legacy_lookup - error_handler._lookup = legacy_lookup # type: ignore - def _full_lookup(self, exception, route_name: Optional[str] = None): return self.lookup(exception, route_name) - def _legacy_lookup(self, exception, route_name: Optional[str] = None): - return self.lookup(exception) - def add(self, exception, handler, route_names: Optional[List[str]] = None): """ Add a new exception handler to an already existing handler object. @@ -162,9 +137,6 @@ class ErrorHandler: :return: None """ - # self.handlers is deprecated and will be removed in version 22.3 - self.handlers.append((exception, handler)) - if route_names: for route in route_names: self.cached_handlers[(exception, route)] = handler @@ -236,7 +208,7 @@ class ErrorHandler: except Exception: try: url = repr(request.url) - except AttributeError: + except AttributeError: # no cov url = "unknown" response_message = ( "Exception raised in exception handler " '"%s" for uri: %s' @@ -281,7 +253,7 @@ class ErrorHandler: if quiet is False or noisy is True: try: url = repr(request.url) - except AttributeError: + except AttributeError: # no cov url = "unknown" error_logger.exception( diff --git a/sanic/headers.py b/sanic/headers.py index b744974c..b4457653 100644 --- a/sanic/headers.py +++ b/sanic/headers.py @@ -2,7 +2,7 @@ from __future__ import annotations import re -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from urllib.parse import unquote from sanic.exceptions import InvalidHeader @@ -394,3 +394,17 @@ def parse_accept(accept: str) -> AcceptContainer: return AcceptContainer( sorted(accept_list, key=_sort_accept_value, reverse=True) ) + + +def parse_credentials( + header: Optional[str], + prefixes: Union[List, Tuple, Set] = None, +) -> Tuple[Optional[str], Optional[str]]: + """Parses any header with the aim to retrieve any credentials from it.""" + if not prefixes or not isinstance(prefixes, (list, tuple, set)): + prefixes = ("Basic", "Bearer", "Token") + if header is not None: + for prefix in prefixes: + if prefix in header: + return prefix, header.partition(prefix)[-1].strip() + return None, header diff --git a/sanic/http/http3.py b/sanic/http/http3.py index ce0d0797..3d1b2b56 100644 --- a/sanic/http/http3.py +++ b/sanic/http/http3.py @@ -3,9 +3,8 @@ from __future__ import annotations import asyncio from abc import ABC -from ast import Mod from ssl import SSLContext -from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union +from typing import TYPE_CHECKING, Callable, Dict, Optional, Union from aioquic.h0.connection import H0_ALPN, H0Connection from aioquic.h3.connection import H3_ALPN, H3Connection diff --git a/sanic/http/tls.py b/sanic/http/tls.py index 844fe435..232e1b62 100644 --- a/sanic/http/tls.py +++ b/sanic/http/tls.py @@ -11,10 +11,10 @@ from ssl import SSLContext from tempfile import mkdtemp from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union -from sanic.application.state import Mode +from sanic.application.constants import Mode from sanic.constants import DEFAULT_LOCAL_TLS_CERT, DEFAULT_LOCAL_TLS_KEY from sanic.exceptions import SanicException -from sanic.helpers import Default, _default +from sanic.helpers import Default from sanic.log import logger diff --git a/sanic/log.py b/sanic/log.py index 000b3ed8..4b3b960c 100644 --- a/sanic/log.py +++ b/sanic/log.py @@ -6,7 +6,7 @@ from typing import Any, Dict from warnings import warn -LOGGING_CONFIG_DEFAULTS: Dict[str, Any] = dict( +LOGGING_CONFIG_DEFAULTS: Dict[str, Any] = dict( # no cov version=1, disable_existing_loggers=False, loggers={ @@ -57,7 +57,7 @@ LOGGING_CONFIG_DEFAULTS: Dict[str, Any] = dict( ) -class Colors(str, Enum): +class Colors(str, Enum): # no cov END = "\033[0m" BLUE = "\033[01;34m" GREEN = "\033[01;32m" @@ -65,23 +65,23 @@ class Colors(str, Enum): RED = "\033[01;31m" -logger = logging.getLogger("sanic.root") +logger = logging.getLogger("sanic.root") # no cov """ General Sanic logger """ -error_logger = logging.getLogger("sanic.error") +error_logger = logging.getLogger("sanic.error") # no cov """ Logger used by Sanic for error logging """ -access_logger = logging.getLogger("sanic.access") +access_logger = logging.getLogger("sanic.access") # no cov """ Logger used by Sanic for access logging """ -def deprecation(message: str, version: float): +def deprecation(message: str, version: float): # no cov version_info = f"[DEPRECATION v{version}] " if sys.stdout.isatty(): version_info = f"{Colors.RED}{version_info}" diff --git a/sanic/mixins/runner.py b/sanic/mixins/runner.py new file mode 100644 index 00000000..39aa42fe --- /dev/null +++ b/sanic/mixins/runner.py @@ -0,0 +1,716 @@ +from __future__ import annotations + +import os +import platform +import sys + +from asyncio import ( + AbstractEventLoop, + CancelledError, + Protocol, + all_tasks, + get_event_loop, + get_running_loop, +) +from contextlib import suppress +from functools import partial +from importlib import import_module +from pathlib import Path +from socket import socket +from ssl import SSLContext +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Literal, + Optional, + Set, + Type, + Union, +) + +from sanic import reloader_helpers +from sanic.application.logo import get_logo +from sanic.application.motd import MOTD +from sanic.application.state import ApplicationServerInfo, Mode, ServerStage +from sanic.base.meta import SanicMeta +from sanic.compat import OS_IS_WINDOWS +from sanic.helpers import _default +from sanic.http.constants import HTTP +from sanic.http.tls import process_to_context +from sanic.log import Colors, error_logger, logger +from sanic.models.handler_types import ListenerType +from sanic.server import Signal as ServerSignal +from sanic.server import try_use_uvloop +from sanic.server.async_server import AsyncioServer +from sanic.server.protocols.http_protocol import HttpProtocol +from sanic.server.protocols.websocket_protocol import WebSocketProtocol +from sanic.server.runners import serve, serve_multiple, serve_single + + +if TYPE_CHECKING: # no cov + from sanic import Sanic + from sanic.application.state import ApplicationState + from sanic.config import Config + +SANIC_PACKAGES = ("sanic-routing", "sanic-testing", "sanic-ext") +HTTPVersion = Union[HTTP, Literal[1], Literal[3]] + + +class RunnerMixin(metaclass=SanicMeta): + _app_registry: Dict[str, Sanic] + config: Config + listeners: Dict[str, List[ListenerType[Any]]] + state: ApplicationState + websocket_enabled: bool + + def make_coffee(self, *args, **kwargs): + self.state.coffee = True + self.run(*args, **kwargs) + + def run( + self, + host: Optional[str] = None, + port: Optional[int] = None, + *, + dev: bool = False, + debug: bool = False, + auto_reload: Optional[bool] = None, + version: HTTPVersion = HTTP.VERSION_1, + ssl: Union[None, SSLContext, dict, str, list, tuple] = None, + sock: Optional[socket] = None, + workers: int = 1, + protocol: Optional[Type[Protocol]] = None, + backlog: int = 100, + register_sys_signals: bool = True, + access_log: Optional[bool] = None, + unix: Optional[str] = None, + loop: AbstractEventLoop = None, + reload_dir: Optional[Union[List[str], str]] = None, + noisy_exceptions: Optional[bool] = None, + motd: bool = True, + fast: bool = False, + verbosity: int = 0, + motd_display: Optional[Dict[str, str]] = None, + ) -> None: + """ + Run the HTTP Server and listen until keyboard interrupt or term + signal. On termination, drain connections before closing. + + :param host: Address to host on + :type host: str + :param port: Port to host on + :type port: int + :param debug: Enables debug output (slows server) + :type debug: bool + :param auto_reload: Reload app whenever its source code is changed. + Enabled by default in debug mode. + :type auto_relaod: bool + :param ssl: SSLContext, or location of certificate and key + for SSL encryption of worker(s) + :type ssl: str, dict, SSLContext or list + :param sock: Socket for the server to accept connections from + :type sock: socket + :param workers: Number of processes received before it is respected + :type workers: int + :param protocol: Subclass of asyncio Protocol class + :type protocol: type[Protocol] + :param backlog: a number of unaccepted connections that the system + will allow before refusing new connections + :type backlog: int + :param register_sys_signals: Register SIG* events + :type register_sys_signals: bool + :param access_log: Enables writing access logs (slows server) + :type access_log: bool + :param unix: Unix socket to listen on instead of TCP port + :type unix: str + :param noisy_exceptions: Log exceptions that are normally considered + to be quiet/silent + :type noisy_exceptions: bool + :return: Nothing + """ + self.prepare( + host=host, + port=port, + dev=dev, + debug=debug, + auto_reload=auto_reload, + version=version, + ssl=ssl, + sock=sock, + workers=workers, + protocol=protocol, + backlog=backlog, + register_sys_signals=register_sys_signals, + access_log=access_log, + unix=unix, + loop=loop, + reload_dir=reload_dir, + noisy_exceptions=noisy_exceptions, + motd=motd, + fast=fast, + verbosity=verbosity, + motd_display=motd_display, + ) + + self.__class__.serve(primary=self) # type: ignore + + def prepare( + self, + host: Optional[str] = None, + port: Optional[int] = None, + *, + dev: bool = False, + debug: bool = False, + auto_reload: Optional[bool] = None, + version: HTTPVersion = HTTP.VERSION_1, + ssl: Union[None, SSLContext, dict, str, list, tuple] = None, + sock: Optional[socket] = None, + workers: int = 1, + protocol: Optional[Type[Protocol]] = None, + backlog: int = 100, + register_sys_signals: bool = True, + access_log: Optional[bool] = None, + unix: Optional[str] = None, + loop: AbstractEventLoop = None, + reload_dir: Optional[Union[List[str], str]] = None, + noisy_exceptions: Optional[bool] = None, + motd: bool = True, + fast: bool = False, + verbosity: int = 0, + motd_display: Optional[Dict[str, str]] = None, + ) -> None: + if dev: + debug = True + auto_reload = True + + self.state.verbosity = verbosity + if not self.state.auto_reload: + self.state.auto_reload = bool(auto_reload) + + if fast and workers != 1: + raise RuntimeError("You cannot use both fast=True and workers=X") + + if motd_display: + self.config.MOTD_DISPLAY.update(motd_display) + + if reload_dir: + if isinstance(reload_dir, str): + reload_dir = [reload_dir] + + for directory in reload_dir: + direc = Path(directory) + if not direc.is_dir(): + logger.warning( + f"Directory {directory} could not be located" + ) + self.state.reload_dirs.add(Path(directory)) + + if loop is not None: + raise TypeError( + "loop is not a valid argument. To use an existing loop, " + "change to create_server().\nSee more: " + "https://sanic.readthedocs.io/en/latest/sanic/deploying.html" + "#asynchronous-support" + ) + + if ( + self.__class__.should_auto_reload() + and os.environ.get("SANIC_SERVER_RUNNING") != "true" + ): # no cov + return + + if sock is None: + host, port = host or "127.0.0.1", port or 8000 + + if protocol is None: + protocol = ( + WebSocketProtocol if self.websocket_enabled else HttpProtocol + ) + + # Set explicitly passed configuration values + for attribute, value in { + "ACCESS_LOG": access_log, + "AUTO_RELOAD": auto_reload, + "MOTD": motd, + "NOISY_EXCEPTIONS": noisy_exceptions, + }.items(): + if value is not None: + setattr(self.config, attribute, value) + + if fast: + self.state.fast = True + try: + workers = len(os.sched_getaffinity(0)) + except AttributeError: # no cov + workers = os.cpu_count() or 1 + + server_settings = self._helper( + host=host, + port=port, + debug=debug, + version=version, + ssl=ssl, + sock=sock, + unix=unix, + workers=workers, + protocol=protocol, + backlog=backlog, + register_sys_signals=register_sys_signals, + ) + self.state.server_info.append( + ApplicationServerInfo(settings=server_settings) + ) + + if self.config.USE_UVLOOP is True or ( + self.config.USE_UVLOOP is _default and not OS_IS_WINDOWS + ): + try_use_uvloop() + + async def create_server( + self, + host: Optional[str] = None, + port: Optional[int] = None, + *, + debug: bool = False, + ssl: Union[None, SSLContext, dict, str, list, tuple] = None, + sock: Optional[socket] = None, + protocol: Type[Protocol] = None, + backlog: int = 100, + access_log: Optional[bool] = None, + unix: Optional[str] = None, + return_asyncio_server: bool = False, + asyncio_server_kwargs: Dict[str, Any] = None, + noisy_exceptions: Optional[bool] = None, + ) -> Optional[AsyncioServer]: + """ + Asynchronous version of :func:`run`. + + This method will take care of the operations necessary to invoke + the *before_start* events via :func:`trigger_events` method invocation + before starting the *sanic* app in Async mode. + + .. note:: + This does not support multiprocessing and is not the preferred + way to run a :class:`Sanic` application. + + :param host: Address to host on + :type host: str + :param port: Port to host on + :type port: int + :param debug: Enables debug output (slows server) + :type debug: bool + :param ssl: SSLContext, or location of certificate and key + for SSL encryption of worker(s) + :type ssl: SSLContext or dict + :param sock: Socket for the server to accept connections from + :type sock: socket + :param protocol: Subclass of asyncio Protocol class + :type protocol: type[Protocol] + :param backlog: a number of unaccepted connections that the system + will allow before refusing new connections + :type backlog: int + :param access_log: Enables writing access logs (slows server) + :type access_log: bool + :param return_asyncio_server: flag that defines whether there's a need + to return asyncio.Server or + start it serving right away + :type return_asyncio_server: bool + :param asyncio_server_kwargs: key-value arguments for + asyncio/uvloop create_server method + :type asyncio_server_kwargs: dict + :param noisy_exceptions: Log exceptions that are normally considered + to be quiet/silent + :type noisy_exceptions: bool + :return: AsyncioServer if return_asyncio_server is true, else Nothing + """ + + if sock is None: + host, port = host or "127.0.0.1", port or 8000 + + if protocol is None: + protocol = ( + WebSocketProtocol if self.websocket_enabled else HttpProtocol + ) + + # Set explicitly passed configuration values + for attribute, value in { + "ACCESS_LOG": access_log, + "NOISY_EXCEPTIONS": noisy_exceptions, + }.items(): + if value is not None: + setattr(self.config, attribute, value) + + server_settings = self._helper( + host=host, + port=port, + debug=debug, + ssl=ssl, + sock=sock, + unix=unix, + loop=get_event_loop(), + protocol=protocol, + backlog=backlog, + run_async=return_asyncio_server, + ) + + if self.config.USE_UVLOOP is not _default: + error_logger.warning( + "You are trying to change the uvloop configuration, but " + "this is only effective when using the run(...) method. " + "When using the create_server(...) method Sanic will use " + "the already existing loop." + ) + + main_start = server_settings.pop("main_start", None) + main_stop = server_settings.pop("main_stop", None) + if main_start or main_stop: + logger.warning( + "Listener events for the main process are not available " + "with create_server()" + ) + + return await serve( + asyncio_server_kwargs=asyncio_server_kwargs, **server_settings + ) + + def stop(self): + """ + This kills the Sanic + """ + if self.state.stage is not ServerStage.STOPPED: + self.shutdown_tasks(timeout=0) + for task in all_tasks(): + with suppress(AttributeError): + if task.get_name() == "RunServer": + task.cancel() + get_event_loop().stop() + + def _helper( + self, + host: Optional[str] = None, + port: Optional[int] = None, + debug: bool = False, + version: HTTPVersion = HTTP.VERSION_1, + ssl: Union[None, SSLContext, dict, str, list, tuple] = None, + sock: Optional[socket] = None, + unix: Optional[str] = None, + workers: int = 1, + loop: AbstractEventLoop = None, + protocol: Type[Protocol] = HttpProtocol, + backlog: int = 100, + register_sys_signals: bool = True, + run_async: bool = False, + ) -> Dict[str, Any]: + """Helper function used by `run` and `create_server`.""" + if self.config.PROXIES_COUNT and self.config.PROXIES_COUNT < 0: + raise ValueError( + "PROXIES_COUNT cannot be negative. " + "https://sanic.readthedocs.io/en/latest/sanic/config.html" + "#proxy-configuration" + ) + + if isinstance(version, int): + version = HTTP(version) + + ssl = process_to_context(ssl) + + if not self.state.is_debug: + self.state.mode = Mode.DEBUG if debug else Mode.PRODUCTION + + self.state.host = host or "" + self.state.port = port or 0 + self.state.workers = workers + self.state.ssl = ssl + self.state.unix = unix + self.state.sock = sock + + server_settings = { + "protocol": protocol, + "host": host, + "port": port, + "version": version, + "sock": sock, + "unix": unix, + "ssl": ssl, + "app": self, + "signal": ServerSignal(), + "loop": loop, + "register_sys_signals": register_sys_signals, + "backlog": backlog, + } + + self.motd(self.serve_location) + + if sys.stdout.isatty() and not self.state.is_debug: + error_logger.warning( + f"{Colors.YELLOW}Sanic is running in PRODUCTION mode. " + "Consider using '--debug' or '--dev' while actively " + f"developing your application.{Colors.END}" + ) + + # Register start/stop events + for event_name, settings_name, reverse in ( + ("main_process_start", "main_start", False), + ("main_process_stop", "main_stop", True), + ): + listeners = self.listeners[event_name].copy() + if reverse: + listeners.reverse() + # Prepend sanic to the arguments when listeners are triggered + listeners = [partial(listener, self) for listener in listeners] + server_settings[settings_name] = listeners # type: ignore + + if run_async: + server_settings["run_async"] = True + + return server_settings + + def motd(self, serve_location): + if self.config.MOTD: + mode = [f"{self.state.mode},"] + if self.state.fast: + mode.append("goin' fast") + if self.state.asgi: + mode.append("ASGI") + else: + if self.state.workers == 1: + mode.append("single worker") + else: + mode.append(f"w/ {self.state.workers} workers") + + display = { + "mode": " ".join(mode), + "server": self.state.server, + "python": platform.python_version(), + "platform": platform.platform(), + } + extra = {} + if self.config.AUTO_RELOAD: + reload_display = "enabled" + if self.state.reload_dirs: + reload_display += ", ".join( + [ + "", + *( + str(path.absolute()) + for path in self.state.reload_dirs + ), + ] + ) + display["auto-reload"] = reload_display + + packages = [] + for package_name in SANIC_PACKAGES: + module_name = package_name.replace("-", "_") + try: + module = import_module(module_name) + packages.append(f"{package_name}=={module.__version__}") + except ImportError: + ... + + if packages: + display["packages"] = ", ".join(packages) + + if self.config.MOTD_DISPLAY: + extra.update(self.config.MOTD_DISPLAY) + + logo = ( + get_logo(coffee=self.state.coffee) + if self.config.LOGO == "" or self.config.LOGO is True + else self.config.LOGO + ) + + MOTD.output(logo, serve_location, display, extra) + + @property + def serve_location(self) -> str: + serve_location = "" + proto = "http" + if self.state.ssl is not None: + proto = "https" + if self.state.unix: + serve_location = f"{self.state.unix} {proto}://..." + elif self.state.sock: + serve_location = f"{self.state.sock.getsockname()} {proto}://..." + elif self.state.host and self.state.port: + # colon(:) is legal for a host only in an ipv6 address + display_host = ( + f"[{self.state.host}]" + if ":" in self.state.host + else self.state.host + ) + serve_location = f"{proto}://{display_host}:{self.state.port}" + + return serve_location + + @classmethod + def should_auto_reload(cls) -> bool: + return any(app.state.auto_reload for app in cls._app_registry.values()) + + @classmethod + def serve(cls, primary: Optional[Sanic] = None) -> None: + apps = list(cls._app_registry.values()) + + if not primary: + try: + primary = apps[0] + except IndexError: + raise RuntimeError("Did not find any applications.") + + # We want to run auto_reload if ANY of the applications have it enabled + if ( + cls.should_auto_reload() + and os.environ.get("SANIC_SERVER_RUNNING") != "true" + ): + reload_dirs: Set[Path] = primary.state.reload_dirs.union( + *(app.state.reload_dirs for app in apps) + ) + return reloader_helpers.watchdog(1.0, reload_dirs) + + # This exists primarily for unit testing + if not primary.state.server_info: # no cov + for app in apps: + app.state.server_info.clear() + return + + primary_server_info = primary.state.server_info[0] + primary.before_server_start(partial(primary._start_servers, apps=apps)) + + try: + primary_server_info.stage = ServerStage.SERVING + + if primary.state.workers > 1 and os.name != "posix": # no cov + logger.warn( + f"Multiprocessing is currently not supported on {os.name}," + " using workers=1 instead" + ) + primary.state.workers = 1 + if primary.state.workers == 1: + serve_single(primary_server_info.settings) + elif primary.state.workers == 0: + raise RuntimeError("Cannot serve with no workers") + else: + serve_multiple( + primary_server_info.settings, primary.state.workers + ) + except BaseException: + error_logger.exception( + "Experienced exception while trying to serve" + ) + raise + finally: + primary_server_info.stage = ServerStage.STOPPED + logger.info("Server Stopped") + for app in apps: + app.state.server_info.clear() + app.router.reset() + app.signal_router.reset() + + async def _start_servers( + self, + primary: Sanic, + _, + apps: List[Sanic], + ) -> None: + for app in apps: + if ( + app.name is not primary.name + and app.state.workers != primary.state.workers + and app.state.server_info + ): + message = ( + f"The primary application {repr(primary)} is running " + f"with {primary.state.workers} worker(s). All " + "application instances will run with the same number. " + f"You requested {repr(app)} to run with " + f"{app.state.workers} worker(s), which will be ignored " + "in favor of the primary application." + ) + if sys.stdout.isatty(): + message = "".join( + [ + Colors.YELLOW, + message, + Colors.END, + ] + ) + error_logger.warning(message, exc_info=True) + for server_info in app.state.server_info: + if server_info.stage is not ServerStage.SERVING: + app.state.primary = False + handlers = [ + *server_info.settings.pop("main_start", []), + *server_info.settings.pop("main_stop", []), + ] + if handlers: + error_logger.warning( + f"Sanic found {len(handlers)} listener(s) on " + "secondary applications attached to the main " + "process. These will be ignored since main " + "process listeners can only be attached to your " + "primary application: " + f"{repr(primary)}" + ) + + if not server_info.settings["loop"]: + server_info.settings["loop"] = get_running_loop() + + try: + server_info.server = await serve( + **server_info.settings, + run_async=True, + reuse_port=bool(primary.state.workers - 1), + ) + except OSError as e: # no cov + first_message = ( + "An OSError was detected on startup. " + "The encountered error was: " + ) + second_message = str(e) + if sys.stdout.isatty(): + message_parts = [ + Colors.YELLOW, + first_message, + Colors.RED, + second_message, + Colors.END, + ] + else: + message_parts = [first_message, second_message] + message = "".join(message_parts) + error_logger.warning(message, exc_info=True) + continue + primary.add_task( + self._run_server(app, server_info), name="RunServer" + ) + + async def _run_server( + self, + app: RunnerMixin, + server_info: ApplicationServerInfo, + ) -> None: + + try: + # We should never get to this point without a server + # This is primarily to keep mypy happy + if not server_info.server: # no cov + raise RuntimeError("Could not locate AsyncioServer") + if app.state.stage is ServerStage.STOPPED: + server_info.stage = ServerStage.SERVING + await server_info.server.startup() + await server_info.server.before_start() + await server_info.server.after_start() + await server_info.server.serve_forever() + except CancelledError: + # We should never get to this point without a server + # This is primarily to keep mypy happy + if not server_info.server: # no cov + raise RuntimeError("Could not locate AsyncioServer") + await server_info.server.before_stop() + await server_info.server.close() + await server_info.server.after_stop() + finally: + server_info.stage = ServerStage.STOPPED + server_info.server = None diff --git a/sanic/models/asgi.py b/sanic/models/asgi.py index 57b755ee..74e0e8b2 100644 --- a/sanic/models/asgi.py +++ b/sanic/models/asgi.py @@ -13,7 +13,7 @@ ASGISend = Callable[[ASGIMessage], Awaitable[None]] ASGIReceive = Callable[[], Awaitable[ASGIMessage]] -class MockProtocol: +class MockProtocol: # no cov def __init__(self, transport: "MockTransport", loop): # This should be refactored when < 3.8 support is dropped self.transport = transport @@ -56,7 +56,7 @@ class MockProtocol: await self._not_paused.wait() -class MockTransport: +class MockTransport: # no cov _protocol: Optional[MockProtocol] def __init__( diff --git a/sanic/models/http_types.py b/sanic/models/http_types.py new file mode 100644 index 00000000..595eaf0e --- /dev/null +++ b/sanic/models/http_types.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from base64 import b64decode +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass() +class Credentials: + auth_type: Optional[str] + token: Optional[str] + _username: Optional[str] = field(default=None) + _password: Optional[str] = field(default=None) + + def __post_init__(self): + if self._auth_is_basic: + self._username, self._password = ( + b64decode(self.token.encode("utf-8")).decode().split(":") + ) + + @property + def username(self): + if not self._auth_is_basic: + raise AttributeError("Username is available for Basic Auth only") + return self._username + + @property + def password(self): + if not self._auth_is_basic: + raise AttributeError("Password is available for Basic Auth only") + return self._password + + @property + def _auth_is_basic(self) -> bool: + return self.auth_type == "Basic" diff --git a/sanic/models/server_types.py b/sanic/models/server_types.py index ad8872e1..ba9f2918 100644 --- a/sanic/models/server_types.py +++ b/sanic/models/server_types.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from ssl import SSLObject from types import SimpleNamespace from typing import Any, Dict, Optional diff --git a/sanic/reloader_helpers.py b/sanic/reloader_helpers.py index 3c726edb..4111cc71 100644 --- a/sanic/reloader_helpers.py +++ b/sanic/reloader_helpers.py @@ -77,7 +77,7 @@ def _check_file(filename, mtimes): return need_reload -def watchdog(sleep_interval, app): +def watchdog(sleep_interval, reload_dirs): """Watch project files, restart worker process if a change happened. :param sleep_interval: interval in second. @@ -100,7 +100,7 @@ def watchdog(sleep_interval, app): changed = set() for filename in itertools.chain( _iter_module_files(), - *(d.glob("**/*") for d in app.reload_dirs), + *(d.glob("**/*") for d in reload_dirs), ): try: if _check_file(filename, mtimes): diff --git a/sanic/request.py b/sanic/request.py index ee535ac2..a1ce91c3 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -15,6 +15,8 @@ from typing import ( from sanic_routing.route import Route # type: ignore +from sanic.models.http_types import Credentials + if TYPE_CHECKING: # no cov from sanic.server import ConnInfo @@ -38,6 +40,7 @@ from sanic.headers import ( Options, parse_accept, parse_content_header, + parse_credentials, parse_forwarded, parse_host, parse_xforwarded, @@ -99,11 +102,13 @@ class Request: "method", "parsed_accept", "parsed_args", - "parsed_not_grouped_args", + "parsed_credentials", "parsed_files", "parsed_form", - "parsed_json", "parsed_forwarded", + "parsed_json", + "parsed_not_grouped_args", + "parsed_token", "raw_url", "responded", "request_middleware_started", @@ -123,6 +128,7 @@ class Request: app: Sanic, head: bytes = b"", ): + self.raw_url = url_bytes # TODO: Content-Encoding detection self._parsed_url = parse_url(url_bytes) @@ -142,9 +148,11 @@ class Request: self.ctx = SimpleNamespace() self.parsed_forwarded: Optional[Options] = None self.parsed_accept: Optional[AcceptContainer] = None + self.parsed_credentials: Optional[Credentials] = None self.parsed_json = None self.parsed_form = None self.parsed_files = None + self.parsed_token: Optional[str] = None self.parsed_args: DefaultDict[ Tuple[bool, bool, str, str], RequestParameters ] = defaultdict(RequestParameters) @@ -336,20 +344,41 @@ class Request: return self.parsed_accept @property - def token(self): + def token(self) -> Optional[str]: """Attempt to return the auth header token. :return: token related to request """ - prefixes = ("Bearer", "Token") - auth_header = self.headers.getone("authorization", None) + if self.parsed_token is None: + prefixes = ("Bearer", "Token") + _, token = parse_credentials( + self.headers.getone("authorization", None), prefixes + ) + self.parsed_token = token + return self.parsed_token - if auth_header is not None: - for prefix in prefixes: - if prefix in auth_header: - return auth_header.partition(prefix)[-1].strip() + @property + def credentials(self) -> Optional[Credentials]: + """Attempt to return the auth header value. - return auth_header + Covers NoAuth, Basic Auth, Bearer Token, Api Token authentication + schemas. + + :return: A Credentials object with token, or username and password + related to the request + """ + if self.parsed_credentials is None: + try: + prefix, credentials = parse_credentials( + self.headers.getone("authorization", None) + ) + if credentials: + self.parsed_credentials = Credentials( + auth_type=prefix, token=credentials + ) + except ValueError: + pass + return self.parsed_credentials @property def form(self): diff --git a/sanic/server/protocols/base_protocol.py b/sanic/server/protocols/base_protocol.py index 63d4bfb5..3a271669 100644 --- a/sanic/server/protocols/base_protocol.py +++ b/sanic/server/protocols/base_protocol.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Optional -if TYPE_CHECKING: +if TYPE_CHECKING: # no cov from sanic.app import Sanic import asyncio diff --git a/sanic/server/protocols/http_protocol.py b/sanic/server/protocols/http_protocol.py index df2ed699..c215208a 100644 --- a/sanic/server/protocols/http_protocol.py +++ b/sanic/server/protocols/http_protocol.py @@ -8,7 +8,7 @@ from sanic.http.http3 import Http3 from sanic.touchup.meta import TouchUpMeta -if TYPE_CHECKING: +if TYPE_CHECKING: # no cov from sanic.app import Sanic from asyncio import CancelledError diff --git a/sanic/server/protocols/websocket_protocol.py b/sanic/server/protocols/websocket_protocol.py index ad6f8f95..d9539696 100644 --- a/sanic/server/protocols/websocket_protocol.py +++ b/sanic/server/protocols/websocket_protocol.py @@ -5,13 +5,13 @@ from websockets.server import ServerConnection from websockets.typing import Subprotocol from sanic.exceptions import ServerError -from sanic.log import deprecation, error_logger +from sanic.log import error_logger from sanic.server import HttpProtocol from ..websockets.impl import WebsocketImplProtocol -if TYPE_CHECKING: +if TYPE_CHECKING: # no cov from websockets import http11 @@ -29,9 +29,6 @@ class WebSocketProtocol(HttpProtocol): *args, websocket_timeout: float = 10.0, websocket_max_size: Optional[int] = None, - websocket_max_queue: Optional[int] = None, # max_queue is deprecated - websocket_read_limit: Optional[int] = None, # read_limit is deprecated - websocket_write_limit: Optional[int] = None, # write_limit deprecated websocket_ping_interval: Optional[float] = 20.0, websocket_ping_timeout: Optional[float] = 20.0, **kwargs, @@ -40,27 +37,6 @@ class WebSocketProtocol(HttpProtocol): self.websocket: Optional[WebsocketImplProtocol] = None self.websocket_timeout = websocket_timeout self.websocket_max_size = websocket_max_size - if websocket_max_queue is not None and websocket_max_queue > 0: - # TODO: Reminder remove this warning in v22.3 - deprecation( - "Websocket no longer uses queueing, so websocket_max_queue" - " is no longer required.", - 22.3, - ) - if websocket_read_limit is not None and websocket_read_limit > 0: - # TODO: Reminder remove this warning in v22.3 - deprecation( - "Websocket no longer uses read buffers, so " - "websocket_read_limit is not required.", - 22.3, - ) - if websocket_write_limit is not None and websocket_write_limit > 0: - # TODO: Reminder remove this warning in v22.3 - deprecation( - "Websocket no longer uses write buffers, so " - "websocket_write_limit is not required.", - 22.3, - ) self.websocket_ping_interval = websocket_ping_interval self.websocket_ping_timeout = websocket_ping_timeout diff --git a/sanic/server/runners.py b/sanic/server/runners.py index b5a9d9c2..980f8348 100644 --- a/sanic/server/runners.py +++ b/sanic/server/runners.py @@ -141,7 +141,7 @@ def serve( try: http_server = loop.run_until_complete(server_coroutine) except BaseException: - error_logger.exception("Unable to start server") + error_logger.exception("Unable to start server", exc_info=True) return # Ignore SIGINT when run_multiple diff --git a/sanic/server/websockets/frame.py b/sanic/server/websockets/frame.py index fef27db1..e4972516 100644 --- a/sanic/server/websockets/frame.py +++ b/sanic/server/websockets/frame.py @@ -9,7 +9,7 @@ from websockets.typing import Data from sanic.exceptions import ServerError -if TYPE_CHECKING: +if TYPE_CHECKING: # no cov from .impl import WebsocketImplProtocol UTF8Decoder = codecs.getincrementaldecoder("utf-8") @@ -37,7 +37,7 @@ class WebsocketFrameAssembler: "get_id", "put_id", ) - if TYPE_CHECKING: + if TYPE_CHECKING: # no cov protocol: "WebsocketImplProtocol" read_mutex: asyncio.Lock write_mutex: asyncio.Lock @@ -131,7 +131,7 @@ class WebsocketFrameAssembler: if self.paused: self.protocol.resume_frames() self.paused = False - if not self.get_in_progress: + if not self.get_in_progress: # no cov # This should be guarded against with the read_mutex, # exception is here as a failsafe raise ServerError( @@ -204,7 +204,7 @@ class WebsocketFrameAssembler: if self.paused: self.protocol.resume_frames() self.paused = False - if not self.get_in_progress: + if not self.get_in_progress: # no cov # This should be guarded against with the read_mutex, # exception is here as a failsafe raise ServerError( @@ -212,7 +212,7 @@ class WebsocketFrameAssembler: "asynchronous get was in progress." ) self.get_in_progress = False - if not self.message_complete.is_set(): + if not self.message_complete.is_set(): # no cov # This should be guarded against with the read_mutex, # exception is here as a failsafe raise ServerError( @@ -220,7 +220,7 @@ class WebsocketFrameAssembler: "message was complete." ) self.message_complete.clear() - if self.message_fetched.is_set(): + if self.message_fetched.is_set(): # no cov # This should be guarded against with the read_mutex, # and get_in_progress check, this exception is # here as a failsafe diff --git a/sanic/touchup/schemes/ode.py b/sanic/touchup/schemes/ode.py index aa7d4bd9..7c6ed3d7 100644 --- a/sanic/touchup/schemes/ode.py +++ b/sanic/touchup/schemes/ode.py @@ -10,12 +10,14 @@ from .base import BaseScheme class OptionalDispatchEvent(BaseScheme): ident = "ODE" + SYNC_SIGNAL_NAMESPACES = "http." def __init__(self, app) -> None: super().__init__(app) + self._sync_events() self._registered_events = [ - signal.path for signal in app.signal_router.routes + signal.name for signal in app.signal_router.routes ] def run(self, method, module_globals): @@ -31,6 +33,35 @@ class OptionalDispatchEvent(BaseScheme): return exec_locals[method.__name__] + def _sync_events(self): + all_events = set() + app_events = {} + for app in self.app.__class__._app_registry.values(): + if app.state.server_info: + app_events[app] = { + signal.name for signal in app.signal_router.routes + } + all_events.update(app_events[app]) + + for app, events in app_events.items(): + missing = { + x + for x in all_events.difference(events) + if any(x.startswith(y) for y in self.SYNC_SIGNAL_NAMESPACES) + } + if missing: + was_finalized = app.signal_router.finalized + if was_finalized: # no cov + app.signal_router.reset() + for event in missing: + app.signal(event)(self.noop) + if was_finalized: # no cov + app.signal_router.finalize() + + @staticmethod + async def noop(**_): # no cov + ... + class RemoveDispatch(NodeTransformer): def __init__(self, registered_events, verbosity: int = 0) -> None: diff --git a/sanic/worker.py b/sanic/worker.py deleted file mode 100644 index befe8d78..00000000 --- a/sanic/worker.py +++ /dev/null @@ -1,243 +0,0 @@ -import asyncio -import logging -import os -import signal -import sys -import traceback - -from gunicorn.workers import base # type: ignore - -from sanic.compat import UVLOOP_INSTALLED -from sanic.log import logger -from sanic.server import HttpProtocol, Signal, serve, try_use_uvloop -from sanic.server.protocols.websocket_protocol import WebSocketProtocol - - -try: - import ssl # type: ignore -except ImportError: # no cov - ssl = None # type: ignore - -if UVLOOP_INSTALLED: # no cov - try_use_uvloop() - - -class GunicornWorker(base.Worker): - - http_protocol = HttpProtocol - websocket_protocol = WebSocketProtocol - - def __init__(self, *args, **kw): # pragma: no cover - super().__init__(*args, **kw) - cfg = self.cfg - if cfg.is_ssl: - self.ssl_context = self._create_ssl_context(cfg) - else: - self.ssl_context = None - self.servers = {} - self.connections = set() - self.exit_code = 0 - self.signal = Signal() - - def init_process(self): - # create new event_loop after fork - asyncio.get_event_loop().close() - - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - super().init_process() - - def run(self): - is_debug = self.log.loglevel == logging.DEBUG - protocol = ( - self.websocket_protocol - if self.app.callable.websocket_enabled - else self.http_protocol - ) - - self._server_settings = self.app.callable._helper( - loop=self.loop, - debug=is_debug, - protocol=protocol, - ssl=self.ssl_context, - run_async=True, - ) - self._server_settings["signal"] = self.signal - self._server_settings.pop("sock") - self._await(self.app.callable._startup()) - self._await( - self.app.callable._server_event("init", "before", loop=self.loop) - ) - - main_start = self._server_settings.pop("main_start", None) - main_stop = self._server_settings.pop("main_stop", None) - - if main_start or main_stop: # noqa - logger.warning( - "Listener events for the main process are not available " - "with GunicornWorker" - ) - - try: - self._await(self._run()) - self.app.callable.is_running = True - self._await( - self.app.callable._server_event( - "init", "after", loop=self.loop - ) - ) - self.loop.run_until_complete(self._check_alive()) - self._await( - self.app.callable._server_event( - "shutdown", "before", loop=self.loop - ) - ) - self.loop.run_until_complete(self.close()) - except BaseException: - traceback.print_exc() - finally: - try: - self._await( - self.app.callable._server_event( - "shutdown", "after", loop=self.loop - ) - ) - except BaseException: - traceback.print_exc() - finally: - self.loop.close() - - sys.exit(self.exit_code) - - async def close(self): - if self.servers: - # stop accepting connections - self.log.info( - "Stopping server: %s, connections: %s", - self.pid, - len(self.connections), - ) - for server in self.servers: - server.close() - await server.wait_closed() - self.servers.clear() - - # prepare connections for closing - self.signal.stopped = True - for conn in self.connections: - conn.close_if_idle() - - # gracefully shutdown timeout - start_shutdown = 0 - graceful_shutdown_timeout = self.cfg.graceful_timeout - while self.connections and ( - start_shutdown < graceful_shutdown_timeout - ): - await asyncio.sleep(0.1) - start_shutdown = start_shutdown + 0.1 - - # Force close non-idle connection after waiting for - # graceful_shutdown_timeout - for conn in self.connections: - if hasattr(conn, "websocket") and conn.websocket: - conn.websocket.fail_connection(code=1001) - else: - conn.abort() - - async def _run(self): - for sock in self.sockets: - state = dict(requests_count=0) - self._server_settings["host"] = None - self._server_settings["port"] = None - server = await serve( - sock=sock, - connections=self.connections, - state=state, - **self._server_settings - ) - self.servers[server] = state - - async def _check_alive(self): - # If our parent changed then we shut down. - pid = os.getpid() - try: - while self.alive: - self.notify() - - req_count = sum( - self.servers[srv]["requests_count"] for srv in self.servers - ) - if self.max_requests and req_count > self.max_requests: - self.alive = False - self.log.info( - "Max requests exceeded, shutting down: %s", self - ) - elif pid == os.getpid() and self.ppid != os.getppid(): - self.alive = False - self.log.info("Parent changed, shutting down: %s", self) - else: - await asyncio.sleep(1.0, loop=self.loop) - except (Exception, BaseException, GeneratorExit, KeyboardInterrupt): - pass - - @staticmethod - def _create_ssl_context(cfg): - """Creates SSLContext instance for usage in asyncio.create_server. - See ssl.SSLSocket.__init__ for more details. - """ - ctx = ssl.SSLContext(cfg.ssl_version) - ctx.load_cert_chain(cfg.certfile, cfg.keyfile) - ctx.verify_mode = cfg.cert_reqs - if cfg.ca_certs: - ctx.load_verify_locations(cfg.ca_certs) - if cfg.ciphers: - ctx.set_ciphers(cfg.ciphers) - return ctx - - def init_signals(self): - # Set up signals through the event loop API. - - self.loop.add_signal_handler( - signal.SIGQUIT, self.handle_quit, signal.SIGQUIT, None - ) - - self.loop.add_signal_handler( - signal.SIGTERM, self.handle_exit, signal.SIGTERM, None - ) - - self.loop.add_signal_handler( - signal.SIGINT, self.handle_quit, signal.SIGINT, None - ) - - self.loop.add_signal_handler( - signal.SIGWINCH, self.handle_winch, signal.SIGWINCH, None - ) - - self.loop.add_signal_handler( - signal.SIGUSR1, self.handle_usr1, signal.SIGUSR1, None - ) - - self.loop.add_signal_handler( - signal.SIGABRT, self.handle_abort, signal.SIGABRT, None - ) - - # Don't let SIGTERM and SIGUSR1 disturb active requests - # by interrupting system calls - signal.siginterrupt(signal.SIGTERM, False) - signal.siginterrupt(signal.SIGUSR1, False) - - def handle_quit(self, sig, frame): - self.alive = False - self.app.callable.is_running = False - self.cfg.worker_int(self) - - def handle_abort(self, sig, frame): - self.alive = False - self.exit_code = 1 - self.cfg.worker_abort(self) - sys.exit(1) - - def _await(self, coro): - fut = asyncio.ensure_future(coro, loop=self.loop) - self.loop.run_until_complete(fut) diff --git a/tests/asyncmock.py b/tests/asyncmock.py new file mode 100644 index 00000000..eec17646 --- /dev/null +++ b/tests/asyncmock.py @@ -0,0 +1,34 @@ +""" +For 3.7 compat + +""" + + +from unittest.mock import Mock + + +class AsyncMock(Mock): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.await_count = 0 + + def __call__(self, *args, **kwargs): + self.call_count += 1 + parent = super(AsyncMock, self) + + async def dummy(): + self.await_count += 1 + return parent.__call__(*args, **kwargs) + + return dummy() + + def __await__(self): + return self().__await__() + + def assert_awaited_once(self): + if not self.await_count == 1: + msg = ( + f"Expected to have been awaited once." + f" Awaited {self.await_count} times." + ) + raise AssertionError(msg) diff --git a/tests/conftest.py b/tests/conftest.py index 22decde5..fe4ba47d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -175,6 +175,21 @@ def run_startup(caplog): return run +@pytest.fixture +def run_multi(caplog): + def run(app, level=logging.DEBUG): + @app.after_server_start + async def stop(app, _): + app.stop() + + with caplog.at_level(level): + Sanic.serve() + + return caplog.record_tuples + + return run + + @pytest.fixture(scope="function") def message_in_records(): def msg_in_log(records: List[LogRecord], msg: str): diff --git a/tests/test_app.py b/tests/test_app.py index 467aaee4..1c8f705b 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -197,7 +197,7 @@ def test_app_enable_websocket(app, websocket_enabled, enable): assert app.websocket_enabled == True -@patch("sanic.app.WebSocketProtocol") +@patch("sanic.mixins.runner.WebSocketProtocol") def test_app_websocket_parameters(websocket_protocol_mock, app): app.config.WEBSOCKET_MAX_SIZE = 44 app.config.WEBSOCKET_PING_TIMEOUT = 48 @@ -473,13 +473,14 @@ def test_custom_context(): assert app.ctx == ctx -def test_uvloop_config(app, monkeypatch): +@pytest.mark.parametrize("use", (False, True)) +def test_uvloop_config(app, monkeypatch, use): @app.get("/test") def handler(request): return text("ok") try_use_uvloop = Mock() - monkeypatch.setattr(sanic.app, "try_use_uvloop", try_use_uvloop) + monkeypatch.setattr(sanic.mixins.runner, "try_use_uvloop", try_use_uvloop) # Default config app.test_client.get("/test") @@ -489,14 +490,13 @@ def test_uvloop_config(app, monkeypatch): try_use_uvloop.assert_called_once() try_use_uvloop.reset_mock() - app.config["USE_UVLOOP"] = False + app.config["USE_UVLOOP"] = use app.test_client.get("/test") - try_use_uvloop.assert_not_called() - try_use_uvloop.reset_mock() - app.config["USE_UVLOOP"] = True - app.test_client.get("/test") - try_use_uvloop.assert_called_once() + if use: + try_use_uvloop.assert_called_once() + else: + try_use_uvloop.assert_not_called() def test_uvloop_cannot_never_called_with_create_server(caplog, monkeypatch): @@ -506,7 +506,7 @@ def test_uvloop_cannot_never_called_with_create_server(caplog, monkeypatch): apps[2].config.USE_UVLOOP = True try_use_uvloop = Mock() - monkeypatch.setattr(sanic.app, "try_use_uvloop", try_use_uvloop) + monkeypatch.setattr(sanic.mixins.runner, "try_use_uvloop", try_use_uvloop) loop = asyncio.get_event_loop() @@ -569,3 +569,8 @@ def test_cannot_run_fast_and_workers(app): message = "You cannot use both fast=True and workers=X" with pytest.raises(RuntimeError, match=message): app.run(fast=True, workers=4) + + +def test_no_workers(app): + with pytest.raises(RuntimeError, match="Cannot serve with no workers"): + app.run(workers=0) diff --git a/tests/test_cli.py b/tests/test_cli.py index 254c91d4..7ec1c28c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -103,7 +103,7 @@ def test_tls_wrong_options(cmd): assert not out lines = err.decode().split("\n") - errmsg = lines[8] + errmsg = lines[6] assert errmsg == "TLS certificates must be specified by either of:" @@ -118,10 +118,10 @@ def test_host_port_localhost(cmd): command = ["sanic", "fake.server.app", *cmd] out, err, exitcode = capture(command) lines = out.split(b"\n") - firstline = lines[starting_line(lines) + 1] + expected = b"Goin' Fast @ http://localhost:9999" assert exitcode != 1 - assert firstline == b"Goin' Fast @ http://localhost:9999" + assert expected in lines, f"Lines found: {lines}\nErr output: {err}" @pytest.mark.parametrize( @@ -135,10 +135,10 @@ def test_host_port_ipv4(cmd): command = ["sanic", "fake.server.app", *cmd] out, err, exitcode = capture(command) lines = out.split(b"\n") - firstline = lines[starting_line(lines) + 1] + expected = b"Goin' Fast @ http://127.0.0.127:9999" assert exitcode != 1 - assert firstline == b"Goin' Fast @ http://127.0.0.127:9999" + assert expected in lines, f"Lines found: {lines}\nErr output: {err}" @pytest.mark.parametrize( @@ -152,10 +152,10 @@ def test_host_port_ipv6_any(cmd): command = ["sanic", "fake.server.app", *cmd] out, err, exitcode = capture(command) lines = out.split(b"\n") - firstline = lines[starting_line(lines) + 1] + expected = b"Goin' Fast @ http://[::]:9999" assert exitcode != 1 - assert firstline == b"Goin' Fast @ http://[::]:9999" + assert expected in lines, f"Lines found: {lines}\nErr output: {err}" @pytest.mark.parametrize( @@ -169,10 +169,10 @@ def test_host_port_ipv6_loopback(cmd): command = ["sanic", "fake.server.app", *cmd] out, err, exitcode = capture(command) lines = out.split(b"\n") - firstline = lines[starting_line(lines) + 1] + expected = b"Goin' Fast @ http://[::1]:9999" assert exitcode != 1 - assert firstline == b"Goin' Fast @ http://[::1]:9999" + assert expected in lines, f"Lines found: {lines}\nErr output: {err}" @pytest.mark.parametrize( @@ -191,24 +191,40 @@ def test_num_workers(num, cmd): out, err, exitcode = capture(command) lines = out.split(b"\n") - worker_lines = [ - line - for line in lines - if b"Starting worker" in line or b"Stopping worker" in line - ] + if num == 1: + expected = b"mode: production, single worker" + else: + expected = (f"mode: production, w/ {num} workers").encode() + assert exitcode != 1 - assert len(worker_lines) == num * 2, f"Lines found: {lines}" + assert expected in lines, f"Expected {expected}\nLines found: {lines}" -@pytest.mark.parametrize("cmd", ("--debug", "-d")) +@pytest.mark.parametrize("cmd", ("--debug",)) def test_debug(cmd): command = ["sanic", "fake.server.app", cmd] out, err, exitcode = capture(command) lines = out.split(b"\n") info = read_app_info(lines) - assert info["debug"] is True - assert info["auto_reload"] is True + assert info["debug"] is True, f"Lines found: {lines}\nErr output: {err}" + assert ( + info["auto_reload"] is False + ), f"Lines found: {lines}\nErr output: {err}" + assert "dev" not in info, f"Lines found: {lines}\nErr output: {err}" + + +@pytest.mark.parametrize("cmd", ("--dev", "-d")) +def test_dev(cmd): + command = ["sanic", "fake.server.app", cmd] + out, err, exitcode = capture(command) + lines = out.split(b"\n") + info = read_app_info(lines) + + assert info["debug"] is True, f"Lines found: {lines}\nErr output: {err}" + assert ( + info["auto_reload"] is True + ), f"Lines found: {lines}\nErr output: {err}" @pytest.mark.parametrize("cmd", ("--auto-reload", "-r")) @@ -218,8 +234,11 @@ def test_auto_reload(cmd): lines = out.split(b"\n") info = read_app_info(lines) - assert info["debug"] is False - assert info["auto_reload"] is True + assert info["debug"] is False, f"Lines found: {lines}\nErr output: {err}" + assert ( + info["auto_reload"] is True + ), f"Lines found: {lines}\nErr output: {err}" + assert "dev" not in info, f"Lines found: {lines}\nErr output: {err}" @pytest.mark.parametrize( @@ -231,7 +250,9 @@ def test_access_logs(cmd, expected): lines = out.split(b"\n") info = read_app_info(lines) - assert info["access_log"] is expected + assert ( + info["access_log"] is expected + ), f"Lines found: {lines}\nErr output: {err}" @pytest.mark.parametrize("cmd", ("--version", "-v")) @@ -256,4 +277,6 @@ def test_noisy_exceptions(cmd, expected): lines = out.split(b"\n") info = read_app_info(lines) - assert info["noisy_exceptions"] is expected + assert ( + info["noisy_exceptions"] is expected + ), f"Lines found: {lines}\nErr output: {err}" diff --git a/tests/test_config.py b/tests/test_config.py index 9237b55c..d8a7bd85 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -5,7 +5,7 @@ from os import environ from pathlib import Path from tempfile import TemporaryDirectory from textwrap import dedent -from unittest.mock import Mock +from unittest.mock import Mock, call import pytest @@ -301,6 +301,9 @@ def test_config_access_log_passing_in_run(app: Sanic): app.run(port=1340, access_log=False) assert app.config.ACCESS_LOG is False + app.router.reset() + app.signal_router.reset() + app.run(port=1340, access_log=True) assert app.config.ACCESS_LOG is True @@ -399,5 +402,36 @@ def test_config_set_methods(app: Sanic, monkeypatch: MonkeyPatch): post_set.assert_called_once_with("FOO", 5) post_set.reset_mock() - app.config.update_config({"FOO": 6}) - post_set.assert_called_once_with("FOO", 6) + app.config.update({"FOO": 6}, {"BAR": 7}) + post_set.assert_has_calls( + calls=[ + call("FOO", 6), + call("BAR", 7), + ] + ) + post_set.reset_mock() + + app.config.update({"FOO": 8}, BAR=9) + post_set.assert_has_calls( + calls=[ + call("FOO", 8), + call("BAR", 9), + ], + any_order=True, + ) + post_set.reset_mock() + + app.config.update_config({"FOO": 10}) + post_set.assert_called_once_with("FOO", 10) + + +def test_negative_proxy_count(app: Sanic): + app.config.PROXIES_COUNT = -1 + + message = ( + "PROXIES_COUNT cannot be negative. " + "https://sanic.readthedocs.io/en/latest/sanic/config.html" + "#proxy-configuration" + ) + with pytest.raises(ValueError, match=message): + app.prepare() diff --git a/tests/test_create_task.py b/tests/test_create_task.py index c98666a9..a11bc302 100644 --- a/tests/test_create_task.py +++ b/tests/test_create_task.py @@ -2,6 +2,7 @@ import asyncio import sys from threading import Event +from unittest.mock import Mock import pytest @@ -77,6 +78,25 @@ def test_create_named_task(app): app.run() +def test_named_task_called(app): + e = Event() + + async def coro(): + e.set() + + @app.route("/") + async def isset(request): + await asyncio.sleep(0.05) + return text(str(e.is_set())) + + @app.before_server_start + async def setup(app, _): + app.add_task(coro, name="dummy_task") + + request, response = app.test_client.get("/") + assert response.body == b"True" + + @pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7") def test_create_named_task_fails_outside_app(app): async def dummy(): diff --git a/tests/test_errorpages.py b/tests/test_errorpages.py index f8e425b0..042c51d6 100644 --- a/tests/test_errorpages.py +++ b/tests/test_errorpages.py @@ -67,7 +67,7 @@ def test_auto_fallback_with_data(app): _, response = app.test_client.get("/error") assert response.status == 500 - assert response.content_type == "text/html; charset=utf-8" + assert response.content_type == "text/plain; charset=utf-8" _, response = app.test_client.post("/error", json={"foo": "bar"}) assert response.status == 500 @@ -75,7 +75,7 @@ def test_auto_fallback_with_data(app): _, response = app.test_client.post("/error", data={"foo": "bar"}) assert response.status == 500 - assert response.content_type == "text/html; charset=utf-8" + assert response.content_type == "text/plain; charset=utf-8" def test_auto_fallback_with_content_type(app): @@ -91,7 +91,7 @@ def test_auto_fallback_with_content_type(app): "/error", headers={"content-type": "foo/bar", "accept": "*/*"} ) assert response.status == 500 - assert response.content_type == "text/html; charset=utf-8" + assert response.content_type == "text/plain; charset=utf-8" def test_route_error_format_set_on_auto(app): @@ -174,6 +174,17 @@ def test_route_error_format_unknown(app): ... +def test_fallback_with_content_type_html(app): + app.config.FALLBACK_ERROR_FORMAT = "auto" + + _, response = app.test_client.get( + "/error", + headers={"content-type": "application/json", "accept": "text/html"}, + ) + assert response.status == 500 + assert response.content_type == "text/html; charset=utf-8" + + def test_fallback_with_content_type_mismatch_accept(app): app.config.FALLBACK_ERROR_FORMAT = "auto" @@ -186,10 +197,10 @@ def test_fallback_with_content_type_mismatch_accept(app): _, response = app.test_client.get( "/error", - headers={"content-type": "text/plain", "accept": "foo/bar"}, + headers={"content-type": "text/html", "accept": "foo/bar"}, ) assert response.status == 500 - assert response.content_type == "text/html; charset=utf-8" + assert response.content_type == "text/plain; charset=utf-8" app.router.reset() @@ -208,7 +219,7 @@ def test_fallback_with_content_type_mismatch_accept(app): headers={"accept": "foo/bar"}, ) assert response.status == 500 - assert response.content_type == "text/html; charset=utf-8" + assert response.content_type == "text/plain; charset=utf-8" _, response = app.test_client.get( "/alt1", headers={"accept": "foo/bar,*/*"}, @@ -221,7 +232,7 @@ def test_fallback_with_content_type_mismatch_accept(app): headers={"accept": "foo/bar"}, ) assert response.status == 500 - assert response.content_type == "text/html; charset=utf-8" + assert response.content_type == "text/plain; charset=utf-8" _, response = app.test_client.get( "/alt2", headers={"accept": "foo/bar,*/*"}, @@ -234,6 +245,13 @@ def test_fallback_with_content_type_mismatch_accept(app): headers={"accept": "foo/bar"}, ) assert response.status == 500 + assert response.content_type == "text/plain; charset=utf-8" + + _, response = app.test_client.get( + "/alt3", + headers={"accept": "foo/bar,text/html"}, + ) + assert response.status == 500 assert response.content_type == "text/html; charset=utf-8" @@ -288,6 +306,10 @@ def test_allow_fallback_error_format_set_main_process_start(app): def test_setting_fallback_on_config_changes_as_expected(app): app.error_handler = ErrorHandler() + _, response = app.test_client.get("/error") + assert response.content_type == "text/plain; charset=utf-8" + + app.config.FALLBACK_ERROR_FORMAT = "html" _, response = app.test_client.get("/error") assert response.content_type == "text/html; charset=utf-8" @@ -334,6 +356,22 @@ def test_config_fallback_before_and_after_startup(app): assert response.content_type == "text/plain; charset=utf-8" +def test_config_fallback_using_update_dict(app): + app.config.update({"FALLBACK_ERROR_FORMAT": "text"}) + + _, response = app.test_client.get("/error") + assert response.status == 500 + assert response.content_type == "text/plain; charset=utf-8" + + +def test_config_fallback_using_update_kwarg(app): + app.config.update(FALLBACK_ERROR_FORMAT="text") + + _, response = app.test_client.get("/error") + assert response.status == 500 + assert response.content_type == "text/plain; charset=utf-8" + + def test_config_fallback_bad_value(app): message = "Unknown format: fake" with pytest.raises(SanicException, match=message): diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 3ec2959d..e603718d 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,5 +1,4 @@ import logging -import warnings import pytest @@ -34,6 +33,7 @@ class SanicExceptionTestException(Exception): @pytest.fixture(scope="module") def exception_app(): app = Sanic("test_exceptions") + app.config.FALLBACK_ERROR_FORMAT = "html" @app.route("/") def handler(request): diff --git a/tests/test_exceptions_handler.py b/tests/test_exceptions_handler.py index 8be3d28e..e9bdb21e 100644 --- a/tests/test_exceptions_handler.py +++ b/tests/test_exceptions_handler.py @@ -216,31 +216,6 @@ def test_exception_handler_processed_request_middleware( assert response.text == "Done." -def test_single_arg_exception_handler_notice( - exception_handler_app: Sanic, caplog: LogCaptureFixture -): - class CustomErrorHandler(ErrorHandler): - def lookup(self, exception): - return super().lookup(exception, None) - - exception_handler_app.error_handler = CustomErrorHandler() - - message = ( - "[DEPRECATION v22.3] You are using a deprecated error handler. The " - "lookup method should accept two positional parameters: (exception, " - "route_name: Optional[str]). Until you upgrade your " - "ErrorHandler.lookup, Blueprint specific exceptions will not work " - "properly. Beginning in v22.3, the legacy style lookup method will " - "not work at all." - ) - with pytest.warns(DeprecationWarning) as record: - _, response = exception_handler_app.test_client.get("/1") - - assert len(record) == 1 - assert record[0].message.args[0] == message - assert response.status == 400 - - def test_error_handler_noisy_log( exception_handler_app: Sanic, monkeypatch: MonkeyPatch ): @@ -279,7 +254,7 @@ def test_exception_handler_response_was_sent( @app.route("/2") async def handler2(request: Request): - response = await request.respond() + await request.respond() raise ServerError("Exception") with caplog.at_level(logging.WARNING): diff --git a/tests/test_motd.py b/tests/test_motd.py index fe45bc47..f3f95a25 100644 --- a/tests/test_motd.py +++ b/tests/test_motd.py @@ -1,11 +1,15 @@ import logging +import os import platform +import sys from unittest.mock import Mock -from sanic import __version__ +import pytest + +from sanic import Sanic, __version__ from sanic.application.logo import BASE_LOGO -from sanic.application.motd import MOTDTTY +from sanic.application.motd import MOTD, MOTDTTY def test_logo_base(app, run_startup): @@ -83,3 +87,25 @@ def test_motd_display(caplog): └───────────────────────┴────────┘ """ ) + + +@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not on 3.7") +def test_reload_dirs(app): + app.config.LOGO = None + app.config.AUTO_RELOAD = True + app.prepare(reload_dir="./", auto_reload=True, motd_display={"foo": "bar"}) + + existing = MOTD.output + MOTD.output = Mock() + + app.motd("foo") + + MOTD.output.assert_called_once() + assert ( + MOTD.output.call_args.args[2]["auto-reload"] + == f"enabled, {os.getcwd()}" + ) + assert MOTD.output.call_args.args[3] == {"foo": "bar"} + + MOTD.output = existing + Sanic._app_registry = {} diff --git a/tests/test_multi_serve.py b/tests/test_multi_serve.py new file mode 100644 index 00000000..dde72b5c --- /dev/null +++ b/tests/test_multi_serve.py @@ -0,0 +1,207 @@ +import logging + +from unittest.mock import Mock + +import pytest + +from sanic import Sanic +from sanic.response import text +from sanic.server.async_server import AsyncioServer +from sanic.signals import Event +from sanic.touchup.schemes.ode import OptionalDispatchEvent + + +try: + from unittest.mock import AsyncMock +except ImportError: + from asyncmock import AsyncMock # type: ignore + + +@pytest.fixture +def app_one(): + app = Sanic("One") + + @app.get("/one") + async def one(request): + return text("one") + + return app + + +@pytest.fixture +def app_two(): + app = Sanic("Two") + + @app.get("/two") + async def two(request): + return text("two") + + return app + + +@pytest.fixture(autouse=True) +def clean(): + Sanic._app_registry = {} + yield + + +def test_serve_same_app_multiple_tuples(app_one, run_multi): + app_one.prepare(port=23456) + app_one.prepare(port=23457) + + logs = run_multi(app_one) + assert ( + "sanic.root", + logging.INFO, + "Goin' Fast @ http://127.0.0.1:23456", + ) in logs + assert ( + "sanic.root", + logging.INFO, + "Goin' Fast @ http://127.0.0.1:23457", + ) in logs + + +def test_serve_multiple_apps(app_one, app_two, run_multi): + app_one.prepare(port=23456) + app_two.prepare(port=23457) + + logs = run_multi(app_one) + assert ( + "sanic.root", + logging.INFO, + "Goin' Fast @ http://127.0.0.1:23456", + ) in logs + assert ( + "sanic.root", + logging.INFO, + "Goin' Fast @ http://127.0.0.1:23457", + ) in logs + + +def test_listeners_on_secondary_app(app_one, app_two, run_multi): + app_one.prepare(port=23456) + app_two.prepare(port=23457) + + before_start = AsyncMock() + after_start = AsyncMock() + before_stop = AsyncMock() + after_stop = AsyncMock() + + app_two.before_server_start(before_start) + app_two.after_server_start(after_start) + app_two.before_server_stop(before_stop) + app_two.after_server_stop(after_stop) + + run_multi(app_one) + + before_start.assert_awaited_once() + after_start.assert_awaited_once() + before_stop.assert_awaited_once() + after_stop.assert_awaited_once() + + +@pytest.mark.parametrize( + "events", + ( + (Event.HTTP_LIFECYCLE_BEGIN,), + (Event.HTTP_LIFECYCLE_BEGIN, Event.HTTP_LIFECYCLE_COMPLETE), + ( + Event.HTTP_LIFECYCLE_BEGIN, + Event.HTTP_LIFECYCLE_COMPLETE, + Event.HTTP_LIFECYCLE_REQUEST, + ), + ), +) +def test_signal_synchronization(app_one, app_two, run_multi, events): + app_one.prepare(port=23456) + app_two.prepare(port=23457) + + for event in events: + app_one.signal(event)(AsyncMock()) + + run_multi(app_one) + + assert len(app_two.signal_router.routes) == len(events) + 1 + + signal_handlers = { + signal.handler + for signal in app_two.signal_router.routes + if signal.name.startswith("http") + } + + assert len(signal_handlers) == 1 + assert list(signal_handlers)[0] is OptionalDispatchEvent.noop + + +def test_warning_main_process_listeners_on_secondary( + app_one, app_two, run_multi +): + app_two.main_process_start(AsyncMock()) + app_two.main_process_stop(AsyncMock()) + app_one.prepare(port=23456) + app_two.prepare(port=23457) + + log = run_multi(app_one) + + message = ( + f"Sanic found 2 listener(s) on " + "secondary applications attached to the main " + "process. These will be ignored since main " + "process listeners can only be attached to your " + "primary application: " + f"{repr(app_one)}" + ) + + assert ("sanic.error", logging.WARNING, message) in log + + +def test_no_applications(): + Sanic._app_registry = {} + message = "Did not find any applications." + with pytest.raises(RuntimeError, match=message): + Sanic.serve() + + +def test_oserror_warning(app_one, app_two, run_multi, capfd): + orig = AsyncioServer.__await__ + AsyncioServer.__await__ = Mock(side_effect=OSError("foo")) + app_one.prepare(port=23456, workers=2) + app_two.prepare(port=23457, workers=2) + + run_multi(app_one) + + captured = capfd.readouterr() + assert ( + "An OSError was detected on startup. The encountered error was: foo" + ) in captured.err + + AsyncioServer.__await__ = orig + + +def test_running_multiple_offset_warning(app_one, app_two, run_multi, capfd): + app_one.prepare(port=23456, workers=2) + app_two.prepare(port=23457) + + run_multi(app_one) + + captured = capfd.readouterr() + assert ( + f"The primary application {repr(app_one)} is running " + "with 2 worker(s). All " + "application instances will run with the same number. " + f"You requested {repr(app_two)} to run with " + "1 worker(s), which will be ignored " + "in favor of the primary application." + ) in captured.err + + +def test_running_multiple_secondary(app_one, app_two, run_multi, capfd): + app_one.prepare(port=23456, workers=2) + app_two.prepare(port=23457) + + before_start = AsyncMock() + app_two.before_server_start(before_start) + run_multi(app_one) + + before_start.await_count == 2 diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index 066b89ac..ab7c18de 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -132,11 +132,11 @@ def test_main_process_event(app, caplog): logger.info("main_process_stop") @app.main_process_start - def main_process_start(app, loop): + def main_process_start2(app, loop): logger.info("main_process_start") @app.main_process_stop - def main_process_stop(app, loop): + def main_process_stop2(app, loop): logger.info("main_process_stop") with caplog.at_level(logging.INFO): diff --git a/tests/test_pipelining.py b/tests/test_pipelining.py index 2bb29c52..6c998756 100644 --- a/tests/test_pipelining.py +++ b/tests/test_pipelining.py @@ -62,19 +62,15 @@ def test_streaming_body_requests(app): data = ["hello", "world"] - class Data(AsyncByteStream): - def __init__(self, data): - self.data = data - - async def __aiter__(self): - for value in self.data: - yield value.encode("utf-8") - client = ReusableClient(app, port=1234) + async def stream(data): + for value in data: + yield value.encode("utf-8") + with client: - _, response1 = client.post("/", data=Data(data)) - _, response2 = client.post("/", data=Data(data)) + _, response1 = client.post("/", data=stream(data)) + _, response2 = client.post("/", data=stream(data)) assert response1.status == response2.status == 200 assert response1.json["data"] == response2.json["data"] == data diff --git a/tests/test_prepare.py b/tests/test_prepare.py new file mode 100644 index 00000000..db8a8db5 --- /dev/null +++ b/tests/test_prepare.py @@ -0,0 +1,71 @@ +import logging +import os + +from pathlib import Path +from unittest.mock import Mock + +import pytest + +from sanic import Sanic +from sanic.application.state import ApplicationServerInfo + + +@pytest.fixture(autouse=True) +def no_skip(): + should_auto_reload = Sanic.should_auto_reload + Sanic.should_auto_reload = Mock(return_value=False) + yield + Sanic._app_registry = {} + Sanic.should_auto_reload = should_auto_reload + + +def get_primary(app: Sanic) -> ApplicationServerInfo: + return app.state.server_info[0] + + +def test_dev(app: Sanic): + app.prepare(dev=True) + + assert app.state.is_debug + assert app.state.auto_reload + + +def test_motd_display(app: Sanic): + app.prepare(motd_display={"foo": "bar"}) + + assert app.config.MOTD_DISPLAY["foo"] == "bar" + del app.config.MOTD_DISPLAY["foo"] + + +@pytest.mark.parametrize("dirs", ("./foo", ("./foo", "./bar"))) +def test_reload_dir(app: Sanic, dirs, caplog): + messages = [] + with caplog.at_level(logging.WARNING): + app.prepare(reload_dir=dirs) + + if isinstance(dirs, str): + dirs = (dirs,) + for d in dirs: + assert Path(d) in app.state.reload_dirs + messages.append( + f"Directory {d} could not be located", + ) + + for message in messages: + assert ("sanic.root", logging.WARNING, message) in caplog.record_tuples + + +def test_fast(app: Sanic, run_multi): + app.prepare(fast=True) + try: + workers = len(os.sched_getaffinity(0)) + except AttributeError: + workers = os.cpu_count() or 1 + + assert app.state.fast + assert app.state.workers == workers + + logs = run_multi(app, logging.INFO) + + messages = [m[2] for m in logs] + assert f"mode: production, goin' fast w/ {workers} workers" in messages diff --git a/tests/test_requests.py b/tests/test_requests.py index c8f6e3f0..d752f045 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1,3 +1,4 @@ +import base64 import logging from json import dumps as json_dumps @@ -15,11 +16,15 @@ from sanic_testing.testing import ( ) from sanic import Blueprint, Sanic -from sanic.exceptions import SanicException, ServerError -from sanic.request import DEFAULT_HTTP_CONTENT_TYPE, Request, RequestParameters +from sanic.exceptions import ServerError +from sanic.request import DEFAULT_HTTP_CONTENT_TYPE, RequestParameters from sanic.response import html, json, text +def encode_basic_auth_credentials(username, password): + return base64.b64encode(f"{username}:{password}".encode()).decode("ascii") + + # ------------------------------------------------------------ # # GET # ------------------------------------------------------------ # @@ -362,93 +367,95 @@ async def test_uri_template_asgi(app): assert request.uri_template == "/foo//bar/" -def test_token(app): +@pytest.mark.parametrize( + ("auth_type", "token"), + [ + # uuid4 generated token set in "Authorization" header + (None, "a1d895e0-553a-421a-8e22-5ff8ecb48cbf"), + # uuid4 generated token with API Token authorization + ("Token", "a1d895e0-553a-421a-8e22-5ff8ecb48cbf"), + # uuid4 generated token with Bearer Token authorization + ("Bearer", "a1d895e0-553a-421a-8e22-5ff8ecb48cbf"), + # no Authorization header + (None, None), + ], +) +def test_token(app, auth_type, token): @app.route("/") async def handler(request): return text("OK") - # uuid4 generated token. - token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" - headers = { - "content-type": "application/json", - "Authorization": f"{token}", - } + if token: + headers = { + "content-type": "application/json", + "Authorization": f"{auth_type} {token}" + if auth_type + else f"{token}", + } + else: + headers = {"content-type": "application/json"} request, response = app.test_client.get("/", headers=headers) - assert request.token == token - token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" - headers = { - "content-type": "application/json", - "Authorization": f"Token {token}", - } - request, response = app.test_client.get("/", headers=headers) - - assert request.token == token - - token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" - headers = { - "content-type": "application/json", - "Authorization": f"Bearer {token}", - } - - request, response = app.test_client.get("/", headers=headers) - - assert request.token == token - - # no Authorization headers - headers = {"content-type": "application/json"} - - request, response = app.test_client.get("/", headers=headers) - - assert request.token is None - - -@pytest.mark.asyncio -async def test_token_asgi(app): +@pytest.mark.parametrize( + ("auth_type", "token", "username", "password"), + [ + # uuid4 generated token set in "Authorization" header + (None, "a1d895e0-553a-421a-8e22-5ff8ecb48cbf", None, None), + # uuid4 generated token with API Token authorization + ("Token", "a1d895e0-553a-421a-8e22-5ff8ecb48cbf", None, None), + # uuid4 generated token with Bearer Token authorization + ("Bearer", "a1d895e0-553a-421a-8e22-5ff8ecb48cbf", None, None), + # username and password with Basic Auth authorization + ( + "Basic", + encode_basic_auth_credentials("some_username", "some_pass"), + "some_username", + "some_pass", + ), + # no Authorization header + (None, None, None, None), + ], +) +def test_credentials(app, capfd, auth_type, token, username, password): @app.route("/") async def handler(request): return text("OK") - # uuid4 generated token. - token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" - headers = { - "content-type": "application/json", - "Authorization": f"{token}", - } + if token: + headers = { + "content-type": "application/json", + "Authorization": f"{auth_type} {token}" + if auth_type + else f"{token}", + } + else: + headers = {"content-type": "application/json"} - request, response = await app.asgi_client.get("/", headers=headers) + request, response = app.test_client.get("/", headers=headers) - assert request.token == token + if auth_type == "Basic": + assert request.credentials.username == username + assert request.credentials.password == password + else: + _, err = capfd.readouterr() + with pytest.raises(AttributeError): + request.credentials.password + assert "Password is available for Basic Auth only" in err + request.credentials.username + assert "Username is available for Basic Auth only" in err - token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" - headers = { - "content-type": "application/json", - "Authorization": f"Token {token}", - } - - request, response = await app.asgi_client.get("/", headers=headers) - - assert request.token == token - - token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" - headers = { - "content-type": "application/json", - "Authorization": f"Bearer {token}", - } - - request, response = await app.asgi_client.get("/", headers=headers) - - assert request.token == token - - # no Authorization headers - headers = {"content-type": "application/json"} - - request, response = await app.asgi_client.get("/", headers=headers) - - assert request.token is None + if token: + assert request.credentials.token == token + assert request.credentials.auth_type == auth_type + else: + assert request.credentials is None + assert not hasattr(request.credentials, "token") + assert not hasattr(request.credentials, "auth_type") + assert not hasattr(request.credentials, "_username") + assert not hasattr(request.credentials, "_password") def test_content_type(app): @@ -1714,7 +1721,6 @@ async def test_request_query_args_custom_parsing_asgi(app): def test_request_cookies(app): - cookies = {"test": "OK"} @app.get("/") @@ -1729,7 +1735,6 @@ def test_request_cookies(app): @pytest.mark.asyncio async def test_request_cookies_asgi(app): - cookies = {"test": "OK"} @app.get("/") diff --git a/tests/test_server_events.py b/tests/test_server_events.py index 058a7cf6..cd3f5266 100644 --- a/tests/test_server_events.py +++ b/tests/test_server_events.py @@ -35,7 +35,7 @@ def create_listener(listener_name, in_list): def start_stop_app(random_name_app, **run_kwargs): def stop_on_alarm(signum, frame): - raise KeyboardInterrupt("SIGINT for sanic to stop gracefully") + random_name_app.stop() signal.signal(signal.SIGALRM, stop_on_alarm) signal.alarm(1) @@ -130,6 +130,9 @@ async def test_trigger_before_events_create_server_missing_event(app): def test_create_server_trigger_events(app): """Test if create_server can trigger server events""" + def stop_on_alarm(signum, frame): + raise KeyboardInterrupt("...") + flag1 = False flag2 = False flag3 = False @@ -137,8 +140,7 @@ def test_create_server_trigger_events(app): async def stop(app, loop): nonlocal flag1 flag1 = True - await asyncio.sleep(0.1) - app.stop() + signal.alarm(1) async def before_stop(app, loop): nonlocal flag2 @@ -155,6 +157,8 @@ def test_create_server_trigger_events(app): loop = asyncio.get_event_loop() # Use random port for tests + + signal.signal(signal.SIGALRM, stop_on_alarm) with closing(socket()) as sock: sock.bind(("127.0.0.1", 0)) diff --git a/tests/test_server_loop.py b/tests/test_server_loop.py index fbd5cc2b..30077178 100644 --- a/tests/test_server_loop.py +++ b/tests/test_server_loop.py @@ -4,8 +4,8 @@ from unittest.mock import Mock, patch import pytest -from sanic.server import loop from sanic.compat import OS_IS_WINDOWS, UVLOOP_INSTALLED +from sanic.server import loop @pytest.mark.skipif( diff --git a/tests/test_tasks.py b/tests/test_tasks.py index a1b98a81..63de50af 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1,5 +1,4 @@ import asyncio -import sys from asyncio.tasks import Task from unittest.mock import Mock, call @@ -7,9 +6,15 @@ from unittest.mock import Mock, call import pytest from sanic.app import Sanic +from sanic.application.state import ApplicationServerInfo, ServerStage from sanic.response import empty +try: + from unittest.mock import AsyncMock +except ImportError: + from asyncmock import AsyncMock # type: ignore + pytestmark = pytest.mark.asyncio @@ -20,11 +25,14 @@ async def dummy(n=0): @pytest.fixture(autouse=True) -def mark_app_running(app): - app.is_running = True +def mark_app_running(app: Sanic): + app.state.server_info.append( + ApplicationServerInfo( + stage=ServerStage.SERVING, settings={}, server=AsyncMock() + ) + ) -@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7") async def test_add_task_returns_task(app: Sanic): task = app.add_task(dummy()) @@ -32,7 +40,6 @@ async def test_add_task_returns_task(app: Sanic): assert len(app._task_registry) == 0 -@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7") async def test_add_task_with_name(app: Sanic): task = app.add_task(dummy(), name="dummy") @@ -44,7 +51,6 @@ async def test_add_task_with_name(app: Sanic): assert task in app._task_registry.values() -@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7") async def test_cancel_task(app: Sanic): task = app.add_task(dummy(3), name="dummy") @@ -62,7 +68,6 @@ async def test_cancel_task(app: Sanic): assert task.cancelled() -@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7") async def test_purge_tasks(app: Sanic): app.add_task(dummy(3), name="dummy") @@ -75,7 +80,6 @@ async def test_purge_tasks(app: Sanic): assert len(app._task_registry) == 0 -@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7") def test_shutdown_tasks_on_app_stop(): class TestSanic(Sanic): shutdown_tasks = Mock() diff --git a/tests/test_unix_socket.py b/tests/test_unix_socket.py index c26ebe06..ea3b4b1d 100644 --- a/tests/test_unix_socket.py +++ b/tests/test_unix_socket.py @@ -72,14 +72,12 @@ def test_unix_socket_creation(caplog): assert not os.path.exists(SOCKPATH) -def test_invalid_paths(): +@pytest.mark.parametrize("path", (".", "no-such-directory/sanictest.sock")) +def test_invalid_paths(path): app = Sanic(name=__name__) - with pytest.raises(FileExistsError): - app.run(unix=".") - - with pytest.raises(FileNotFoundError): - app.run(unix="no-such-directory/sanictest.sock") + with pytest.raises((FileExistsError, FileNotFoundError)): + app.run(unix=path) def test_dont_replace_file(): @@ -201,7 +199,7 @@ async def test_zero_downtime(): for _ in range(40): async with httpx.AsyncClient(transport=transport) as client: r = await client.get("http://localhost/sleep/0.1") - assert r.status_code == 200 + assert r.status_code == 200, r.content assert r.text == "Slept 0.1 seconds.\n" def spawn(): @@ -209,6 +207,7 @@ async def test_zero_downtime(): sys.executable, "-m", "sanic", + "--debug", "--unix", SOCKPATH, "examples.delayed_response.app", diff --git a/tests/test_websockets.py b/tests/test_websockets.py new file mode 100644 index 00000000..329eff45 --- /dev/null +++ b/tests/test_websockets.py @@ -0,0 +1,243 @@ +import re + +from asyncio import Event, Queue, TimeoutError +from unittest.mock import Mock, call + +import pytest + +from websockets.frames import CTRL_OPCODES, DATA_OPCODES, Frame + +from sanic.exceptions import ServerError +from sanic.server.websockets.frame import WebsocketFrameAssembler + + +try: + from unittest.mock import AsyncMock +except ImportError: + from asyncmock import AsyncMock # type: ignore + + +@pytest.mark.asyncio +async def test_ws_frame_get_message_incomplete_timeout_0(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete = AsyncMock(spec=Event) + assembler.message_complete.is_set = Mock(return_value=False) + data = await assembler.get(0) + + assert data is None + assembler.message_complete.is_set.assert_called_once() + + +@pytest.mark.asyncio +async def test_ws_frame_get_message_in_progress(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.get_in_progress = True + + message = re.escape( + "Called get() on Websocket frame assembler " + "while asynchronous get is already in progress." + ) + + with pytest.raises(ServerError, match=message): + await assembler.get() + + +@pytest.mark.asyncio +async def test_ws_frame_get_message_incomplete(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete.wait = AsyncMock(return_value=True) + assembler.message_complete.is_set = Mock(return_value=False) + data = await assembler.get() + + assert data is None + assembler.message_complete.wait.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_ws_frame_get_message(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete.wait = AsyncMock(return_value=True) + assembler.message_complete.is_set = Mock(return_value=True) + data = await assembler.get() + + assert data == b"" + assembler.message_complete.wait.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_ws_frame_get_message_with_timeout(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete.wait = AsyncMock(return_value=True) + assembler.message_complete.is_set = Mock(return_value=True) + data = await assembler.get(0.1) + + assert data == b"" + assembler.message_complete.wait.assert_awaited_once() + assert assembler.message_complete.is_set.call_count == 2 + + +@pytest.mark.asyncio +async def test_ws_frame_get_message_with_timeouterror(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete.wait = AsyncMock(return_value=True) + assembler.message_complete.is_set = Mock(return_value=True) + assembler.message_complete.wait.side_effect = TimeoutError("...") + data = await assembler.get(0.1) + + assert data == b"" + assembler.message_complete.wait.assert_awaited_once() + assert assembler.message_complete.is_set.call_count == 2 + + +@pytest.mark.asyncio +async def test_ws_frame_get_not_completed(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete = AsyncMock(spec=Event) + assembler.message_complete.is_set = Mock(return_value=False) + data = await assembler.get() + + assert data is None + + +@pytest.mark.asyncio +async def test_ws_frame_get_not_completed_start(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete = AsyncMock(spec=Event) + assembler.message_complete.is_set = Mock(side_effect=[False, True]) + data = await assembler.get(0.1) + + assert data is None + + +@pytest.mark.asyncio +async def test_ws_frame_get_paused(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete = AsyncMock(spec=Event) + assembler.message_complete.is_set = Mock(side_effect=[False, True]) + assembler.paused = True + data = await assembler.get() + + assert data is None + assembler.protocol.resume_frames.assert_called_once() + + +@pytest.mark.asyncio +async def test_ws_frame_get_data(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete = AsyncMock(spec=Event) + assembler.message_complete.is_set = Mock(return_value=True) + assembler.chunks = [b"foo", b"bar"] + data = await assembler.get() + + assert data == b"foobar" + + +@pytest.mark.asyncio +async def test_ws_frame_get_iter_in_progress(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.get_in_progress = True + + message = re.escape( + "Called get_iter on Websocket frame assembler " + "while asynchronous get is already in progress." + ) + + with pytest.raises(ServerError, match=message): + [x async for x in assembler.get_iter()] + + +@pytest.mark.asyncio +async def test_ws_frame_get_iter_none_in_queue(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete.set() + assembler.chunks = [b"foo", b"bar"] + + chunks = [x async for x in assembler.get_iter()] + + assert chunks == [b"foo", b"bar"] + + +@pytest.mark.asyncio +async def test_ws_frame_get_iter_paused(): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete.set() + assembler.paused = True + + [x async for x in assembler.get_iter()] + assembler.protocol.resume_frames.assert_called_once() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("opcode", DATA_OPCODES) +async def test_ws_frame_put_not_fetched(opcode): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_fetched.set() + + message = re.escape( + "Websocket put() got a new message when the previous message was " + "not yet fetched." + ) + with pytest.raises(ServerError, match=message): + await assembler.put(Frame(opcode, b"")) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("opcode", DATA_OPCODES) +async def test_ws_frame_put_fetched(opcode): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_fetched = AsyncMock() + assembler.message_fetched.is_set = Mock(return_value=False) + + await assembler.put(Frame(opcode, b"")) + assembler.message_fetched.wait.assert_awaited_once() + assembler.message_fetched.clear.assert_called_once() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("opcode", DATA_OPCODES) +async def test_ws_frame_put_message_complete(opcode): + assembler = WebsocketFrameAssembler(Mock()) + assembler.message_complete.set() + + message = re.escape( + "Websocket put() got a new message when a message was " + "already in its chamber." + ) + with pytest.raises(ServerError, match=message): + await assembler.put(Frame(opcode, b"")) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("opcode", DATA_OPCODES) +async def test_ws_frame_put_message_into_queue(opcode): + assembler = WebsocketFrameAssembler(Mock()) + assembler.chunks_queue = AsyncMock(spec=Queue) + assembler.message_fetched = AsyncMock() + assembler.message_fetched.is_set = Mock(return_value=False) + + await assembler.put(Frame(opcode, b"foo")) + + assembler.chunks_queue.put.has_calls( + call(b"foo"), + call(None), + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("opcode", DATA_OPCODES) +async def test_ws_frame_put_not_fin(opcode): + assembler = WebsocketFrameAssembler(Mock()) + + retval = await assembler.put(Frame(opcode, b"foo", fin=False)) + + assert retval is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("opcode", CTRL_OPCODES) +async def test_ws_frame_put_skip_ctrl(opcode): + assembler = WebsocketFrameAssembler(Mock()) + + retval = await assembler.put(Frame(opcode, b"")) + + assert retval is None diff --git a/tests/test_worker.py b/tests/test_worker.py deleted file mode 100644 index cdc30a05..00000000 --- a/tests/test_worker.py +++ /dev/null @@ -1,200 +0,0 @@ -import asyncio -import json -import shlex -import subprocess -import time -import urllib.request - -from unittest import mock - -import pytest - -from sanic_testing.testing import ASGI_PORT as PORT - -from sanic.app import Sanic -from sanic.worker import GunicornWorker - - -@pytest.fixture -def gunicorn_worker(): - command = ( - "gunicorn " - f"--bind 127.0.0.1:{PORT} " - "--worker-class sanic.worker.GunicornWorker " - "examples.hello_world:app" - ) - worker = subprocess.Popen(shlex.split(command)) - time.sleep(2) - yield - worker.kill() - - -@pytest.fixture -def gunicorn_worker_with_access_logs(): - command = ( - "gunicorn " - f"--bind 127.0.0.1:{PORT + 1} " - "--worker-class sanic.worker.GunicornWorker " - "examples.hello_world:app" - ) - worker = subprocess.Popen(shlex.split(command), stdout=subprocess.PIPE) - time.sleep(2) - return worker - - -@pytest.fixture -def gunicorn_worker_with_env_var(): - command = ( - 'env SANIC_ACCESS_LOG="False" ' - "gunicorn " - f"--bind 127.0.0.1:{PORT + 2} " - "--worker-class sanic.worker.GunicornWorker " - "--log-level info " - "examples.hello_world:app" - ) - worker = subprocess.Popen(shlex.split(command), stdout=subprocess.PIPE) - time.sleep(2) - return worker - - -def test_gunicorn_worker(gunicorn_worker): - with urllib.request.urlopen(f"http://localhost:{PORT}/") as f: - res = json.loads(f.read(100).decode()) - assert res["test"] - - -def test_gunicorn_worker_no_logs(gunicorn_worker_with_env_var): - """ - if SANIC_ACCESS_LOG was set to False do not show access logs - """ - with urllib.request.urlopen(f"http://localhost:{PORT + 2}/") as _: - gunicorn_worker_with_env_var.kill() - logs = list( - filter( - lambda x: b"sanic.access" in x, - gunicorn_worker_with_env_var.stdout.read().split(b"\n"), - ) - ) - assert len(logs) == 0 - - -def test_gunicorn_worker_with_logs(gunicorn_worker_with_access_logs): - """ - default - show access logs - """ - with urllib.request.urlopen(f"http://localhost:{PORT + 1}/") as _: - gunicorn_worker_with_access_logs.kill() - assert ( - b"(sanic.access)[INFO][127.0.0.1" - in gunicorn_worker_with_access_logs.stdout.read() - ) - - -class GunicornTestWorker(GunicornWorker): - def __init__(self): - self.app = mock.Mock() - self.app.callable = Sanic("test_gunicorn_worker") - self.servers = {} - self.exit_code = 0 - self.cfg = mock.Mock() - self.notify = mock.Mock() - - -@pytest.fixture -def worker(): - return GunicornTestWorker() - - -def test_worker_init_process(worker): - with mock.patch("sanic.worker.asyncio") as mock_asyncio: - try: - worker.init_process() - except TypeError: - pass - - assert mock_asyncio.get_event_loop.return_value.close.called - assert mock_asyncio.new_event_loop.called - assert mock_asyncio.set_event_loop.called - - -def test_worker_init_signals(worker): - worker.loop = mock.Mock() - worker.init_signals() - assert worker.loop.add_signal_handler.called - - -def test_handle_abort(worker): - with mock.patch("sanic.worker.sys") as mock_sys: - worker.handle_abort(object(), object()) - assert not worker.alive - assert worker.exit_code == 1 - mock_sys.exit.assert_called_with(1) - - -def test_handle_quit(worker): - worker.handle_quit(object(), object()) - assert not worker.alive - assert worker.exit_code == 0 - - -async def _a_noop(*a, **kw): - pass - - -def test_run_max_requests_exceeded(worker): - loop = asyncio.new_event_loop() - worker.ppid = 1 - worker.alive = True - sock = mock.Mock() - sock.cfg_addr = ("localhost", 8080) - worker.sockets = [sock] - worker.wsgi = mock.Mock() - worker.connections = set() - worker.log = mock.Mock() - worker.loop = loop - worker.servers = { - "server1": {"requests_count": 14}, - "server2": {"requests_count": 15}, - } - worker.max_requests = 10 - worker._run = mock.Mock(wraps=_a_noop) - - # exceeding request count - _runner = asyncio.ensure_future(worker._check_alive(), loop=loop) - loop.run_until_complete(_runner) - - assert not worker.alive - worker.notify.assert_called_with() - worker.log.info.assert_called_with( - "Max requests exceeded, shutting " "down: %s", worker - ) - - -def test_worker_close(worker): - loop = asyncio.new_event_loop() - asyncio.sleep = mock.Mock(wraps=_a_noop) - worker.ppid = 1 - worker.pid = 2 - worker.cfg.graceful_timeout = 1.0 - worker.signal = mock.Mock() - worker.signal.stopped = False - worker.wsgi = mock.Mock() - conn = mock.Mock() - conn.websocket = mock.Mock() - conn.websocket.fail_connection = mock.Mock(wraps=_a_noop) - worker.connections = set([conn]) - worker.log = mock.Mock() - worker.loop = loop - server = mock.Mock() - server.close = mock.Mock(wraps=lambda *a, **kw: None) - server.wait_closed = mock.Mock(wraps=_a_noop) - worker.servers = {server: {"requests_count": 14}} - worker.max_requests = 10 - - # close worker - _close = asyncio.ensure_future(worker.close(), loop=loop) - loop.run_until_complete(_close) - - assert worker.signal.stopped - assert conn.websocket.fail_connection.called - assert len(worker.servers) == 0 diff --git a/tox.ini b/tox.ini index 609ceb48..68919e59 100644 --- a/tox.ini +++ b/tox.ini @@ -7,6 +7,9 @@ setenv = {py37,py38,py39,py310,pyNightly}-no-ext: SANIC_NO_UJSON=1 {py37,py38,py39,py310,pyNightly}-no-ext: SANIC_NO_UVLOOP=1 extras = test +allowlist_externals = + pytest + coverage commands = pytest {posargs:tests --cov sanic} - coverage combine --append @@ -41,7 +44,7 @@ commands = [testenv:docs] platform = linux|linux2|darwin -whitelist_externals = make +allowlist_externals = make extras = docs commands = make docs-test