Add support for 3.7 checking of mock transport

This commit is contained in:
Adam Hopkins 2021-02-22 13:39:10 +02:00
parent 070236677c
commit c7307c1f85
4 changed files with 5 additions and 103 deletions

View File

@ -1,108 +1,19 @@
import asyncio
import warnings
from inspect import isawaitable
from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union
from typing import Optional
from urllib.parse import quote
import sanic.app # noqa
from sanic.compat import Header
from sanic.exceptions import InvalidUsage, ServerError
from sanic.exceptions import ServerError
from sanic.models.asgi import ASGIReceive, ASGIScope, ASGISend, MockTransport
from sanic.request import Request
from sanic.server import ConnInfo
from sanic.websocket import WebSocketConnection
ASGIScope = MutableMapping[str, Any]
ASGIMessage = MutableMapping[str, Any]
ASGISend = Callable[[ASGIMessage], Awaitable[None]]
ASGIReceive = Callable[[], Awaitable[ASGIMessage]]
class MockProtocol:
def __init__(self, transport: "MockTransport", loop):
self.transport = transport
self._not_paused = asyncio.Event(loop=loop)
self._not_paused.set()
self._complete = asyncio.Event(loop=loop)
def pause_writing(self) -> None:
self._not_paused.clear()
def resume_writing(self) -> None:
self._not_paused.set()
async def complete(self) -> None:
self._not_paused.set()
await self.transport.send(
{"type": "http.response.body", "body": b"", "more_body": False}
)
@property
def is_complete(self) -> bool:
return self._complete.is_set()
async def push_data(self, data: bytes) -> None:
if not self.is_complete:
await self.transport.send(
{"type": "http.response.body", "body": data, "more_body": True}
)
async def drain(self) -> None:
await self._not_paused.wait()
class MockTransport:
_protocol: Optional[MockProtocol]
def __init__(
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
) -> None:
self.scope = scope
self._receive = receive
self._send = send
self._protocol = None
self.loop = None
def get_protocol(self) -> MockProtocol:
if not self._protocol:
self._protocol = MockProtocol(self, self.loop)
return self._protocol
def get_extra_info(self, info: str) -> Union[str, bool, None]:
if info == "peername":
return self.scope.get("client")
elif info == "sslcontext":
return self.scope.get("scheme") in ["https", "wss"]
return None
def get_websocket_connection(self) -> WebSocketConnection:
try:
return self._websocket_connection
except AttributeError:
raise InvalidUsage("Improper websocket connection.")
def create_websocket_connection(
self, send: ASGISend, receive: ASGIReceive
) -> WebSocketConnection:
self._websocket_connection = WebSocketConnection(
send, receive, self.scope.get("subprotocols", [])
)
return self._websocket_connection
def add_task(self) -> None:
raise NotImplementedError
async def send(self, data) -> None:
# TODO:
# - Validation on data and that it is formatted properly and is valid
await self._send(data)
async def receive(self) -> ASGIMessage:
return await self._receive()
class Lifespan:
def __init__(self, asgi_app: "ASGIApp") -> None:
self.asgi_app = asgi_app

View File

@ -1,9 +0,0 @@
from typing import Protocol, Union
class TransportProtocol(Protocol):
def get_protocol(self):
...
def get_extra_info(self, info: str) -> Union[str, bool, None]:
...

View File

@ -38,7 +38,7 @@ from sanic.headers import (
parse_xforwarded,
)
from sanic.log import error_logger, logger
from sanic.models.protocol_types import TransportProtocol
from sanic.models.transport import TransportProtocol
from sanic.response import BaseHTTPResponse, HTTPResponse

View File

@ -39,7 +39,7 @@ from sanic.config import Config
from sanic.exceptions import RequestTimeout, ServiceUnavailable
from sanic.http import Http, Stage
from sanic.log import logger
from sanic.models.protocol_types import TransportProtocol
from sanic.models.transport import TransportProtocol
from sanic.request import Request