From c7307c1f85dfa25cced1620925d1848510fd5b72 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Mon, 22 Feb 2021 13:39:10 +0200 Subject: [PATCH] Add support for 3.7 checking of mock transport --- sanic/asgi.py | 95 ++-------------------------------- sanic/models/protocol_types.py | 9 ---- sanic/request.py | 2 +- sanic/server.py | 2 +- 4 files changed, 5 insertions(+), 103 deletions(-) delete mode 100644 sanic/models/protocol_types.py diff --git a/sanic/asgi.py b/sanic/asgi.py index 73b2c99e..dd9f8c11 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -1,108 +1,19 @@ -import asyncio import warnings from inspect import isawaitable -from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union +from typing import Optional from urllib.parse import quote import sanic.app # noqa from sanic.compat import Header -from sanic.exceptions import InvalidUsage, ServerError +from sanic.exceptions import ServerError +from sanic.models.asgi import ASGIReceive, ASGIScope, ASGISend, MockTransport from sanic.request import Request from sanic.server import ConnInfo from sanic.websocket import WebSocketConnection -ASGIScope = MutableMapping[str, Any] -ASGIMessage = MutableMapping[str, Any] -ASGISend = Callable[[ASGIMessage], Awaitable[None]] -ASGIReceive = Callable[[], Awaitable[ASGIMessage]] - - -class MockProtocol: - def __init__(self, transport: "MockTransport", loop): - self.transport = transport - self._not_paused = asyncio.Event(loop=loop) - self._not_paused.set() - self._complete = asyncio.Event(loop=loop) - - def pause_writing(self) -> None: - self._not_paused.clear() - - def resume_writing(self) -> None: - self._not_paused.set() - - async def complete(self) -> None: - self._not_paused.set() - await self.transport.send( - {"type": "http.response.body", "body": b"", "more_body": False} - ) - - @property - def is_complete(self) -> bool: - return self._complete.is_set() - - async def push_data(self, data: bytes) -> None: - if not self.is_complete: - await self.transport.send( - {"type": "http.response.body", "body": data, "more_body": True} - ) - - async def drain(self) -> None: - await self._not_paused.wait() - - -class MockTransport: - _protocol: Optional[MockProtocol] - - def __init__( - self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend - ) -> None: - self.scope = scope - self._receive = receive - self._send = send - self._protocol = None - self.loop = None - - def get_protocol(self) -> MockProtocol: - if not self._protocol: - self._protocol = MockProtocol(self, self.loop) - return self._protocol - - def get_extra_info(self, info: str) -> Union[str, bool, None]: - if info == "peername": - return self.scope.get("client") - elif info == "sslcontext": - return self.scope.get("scheme") in ["https", "wss"] - return None - - def get_websocket_connection(self) -> WebSocketConnection: - try: - return self._websocket_connection - except AttributeError: - raise InvalidUsage("Improper websocket connection.") - - def create_websocket_connection( - self, send: ASGISend, receive: ASGIReceive - ) -> WebSocketConnection: - self._websocket_connection = WebSocketConnection( - send, receive, self.scope.get("subprotocols", []) - ) - return self._websocket_connection - - def add_task(self) -> None: - raise NotImplementedError - - async def send(self, data) -> None: - # TODO: - # - Validation on data and that it is formatted properly and is valid - await self._send(data) - - async def receive(self) -> ASGIMessage: - return await self._receive() - - class Lifespan: def __init__(self, asgi_app: "ASGIApp") -> None: self.asgi_app = asgi_app diff --git a/sanic/models/protocol_types.py b/sanic/models/protocol_types.py deleted file mode 100644 index 48616ad5..00000000 --- a/sanic/models/protocol_types.py +++ /dev/null @@ -1,9 +0,0 @@ -from typing import Protocol, Union - - -class TransportProtocol(Protocol): - def get_protocol(self): - ... - - def get_extra_info(self, info: str) -> Union[str, bool, None]: - ... diff --git a/sanic/request.py b/sanic/request.py index c17bb324..895c73ad 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -38,7 +38,7 @@ from sanic.headers import ( parse_xforwarded, ) from sanic.log import error_logger, logger -from sanic.models.protocol_types import TransportProtocol +from sanic.models.transport import TransportProtocol from sanic.response import BaseHTTPResponse, HTTPResponse diff --git a/sanic/server.py b/sanic/server.py index 82869208..39f6ced2 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -39,7 +39,7 @@ from sanic.config import Config from sanic.exceptions import RequestTimeout, ServiceUnavailable from sanic.http import Http, Stage from sanic.log import logger -from sanic.models.protocol_types import TransportProtocol +from sanic.models.transport import TransportProtocol from sanic.request import Request