Add support for 3.7 checking of mock transport
This commit is contained in:
parent
070236677c
commit
c7307c1f85
|
@ -1,108 +1,19 @@
|
||||||
import asyncio
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from inspect import isawaitable
|
from inspect import isawaitable
|
||||||
from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union
|
from typing import Optional
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
import sanic.app # noqa
|
import sanic.app # noqa
|
||||||
|
|
||||||
from sanic.compat import Header
|
from sanic.compat import Header
|
||||||
from sanic.exceptions import InvalidUsage, ServerError
|
from sanic.exceptions import ServerError
|
||||||
|
from sanic.models.asgi import ASGIReceive, ASGIScope, ASGISend, MockTransport
|
||||||
from sanic.request import Request
|
from sanic.request import Request
|
||||||
from sanic.server import ConnInfo
|
from sanic.server import ConnInfo
|
||||||
from sanic.websocket import WebSocketConnection
|
from sanic.websocket import WebSocketConnection
|
||||||
|
|
||||||
|
|
||||||
ASGIScope = MutableMapping[str, Any]
|
|
||||||
ASGIMessage = MutableMapping[str, Any]
|
|
||||||
ASGISend = Callable[[ASGIMessage], Awaitable[None]]
|
|
||||||
ASGIReceive = Callable[[], Awaitable[ASGIMessage]]
|
|
||||||
|
|
||||||
|
|
||||||
class MockProtocol:
|
|
||||||
def __init__(self, transport: "MockTransport", loop):
|
|
||||||
self.transport = transport
|
|
||||||
self._not_paused = asyncio.Event(loop=loop)
|
|
||||||
self._not_paused.set()
|
|
||||||
self._complete = asyncio.Event(loop=loop)
|
|
||||||
|
|
||||||
def pause_writing(self) -> None:
|
|
||||||
self._not_paused.clear()
|
|
||||||
|
|
||||||
def resume_writing(self) -> None:
|
|
||||||
self._not_paused.set()
|
|
||||||
|
|
||||||
async def complete(self) -> None:
|
|
||||||
self._not_paused.set()
|
|
||||||
await self.transport.send(
|
|
||||||
{"type": "http.response.body", "body": b"", "more_body": False}
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_complete(self) -> bool:
|
|
||||||
return self._complete.is_set()
|
|
||||||
|
|
||||||
async def push_data(self, data: bytes) -> None:
|
|
||||||
if not self.is_complete:
|
|
||||||
await self.transport.send(
|
|
||||||
{"type": "http.response.body", "body": data, "more_body": True}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def drain(self) -> None:
|
|
||||||
await self._not_paused.wait()
|
|
||||||
|
|
||||||
|
|
||||||
class MockTransport:
|
|
||||||
_protocol: Optional[MockProtocol]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
|
|
||||||
) -> None:
|
|
||||||
self.scope = scope
|
|
||||||
self._receive = receive
|
|
||||||
self._send = send
|
|
||||||
self._protocol = None
|
|
||||||
self.loop = None
|
|
||||||
|
|
||||||
def get_protocol(self) -> MockProtocol:
|
|
||||||
if not self._protocol:
|
|
||||||
self._protocol = MockProtocol(self, self.loop)
|
|
||||||
return self._protocol
|
|
||||||
|
|
||||||
def get_extra_info(self, info: str) -> Union[str, bool, None]:
|
|
||||||
if info == "peername":
|
|
||||||
return self.scope.get("client")
|
|
||||||
elif info == "sslcontext":
|
|
||||||
return self.scope.get("scheme") in ["https", "wss"]
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_websocket_connection(self) -> WebSocketConnection:
|
|
||||||
try:
|
|
||||||
return self._websocket_connection
|
|
||||||
except AttributeError:
|
|
||||||
raise InvalidUsage("Improper websocket connection.")
|
|
||||||
|
|
||||||
def create_websocket_connection(
|
|
||||||
self, send: ASGISend, receive: ASGIReceive
|
|
||||||
) -> WebSocketConnection:
|
|
||||||
self._websocket_connection = WebSocketConnection(
|
|
||||||
send, receive, self.scope.get("subprotocols", [])
|
|
||||||
)
|
|
||||||
return self._websocket_connection
|
|
||||||
|
|
||||||
def add_task(self) -> None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def send(self, data) -> None:
|
|
||||||
# TODO:
|
|
||||||
# - Validation on data and that it is formatted properly and is valid
|
|
||||||
await self._send(data)
|
|
||||||
|
|
||||||
async def receive(self) -> ASGIMessage:
|
|
||||||
return await self._receive()
|
|
||||||
|
|
||||||
|
|
||||||
class Lifespan:
|
class Lifespan:
|
||||||
def __init__(self, asgi_app: "ASGIApp") -> None:
|
def __init__(self, asgi_app: "ASGIApp") -> None:
|
||||||
self.asgi_app = asgi_app
|
self.asgi_app = asgi_app
|
||||||
|
|
|
@ -1,9 +0,0 @@
|
||||||
from typing import Protocol, Union
|
|
||||||
|
|
||||||
|
|
||||||
class TransportProtocol(Protocol):
|
|
||||||
def get_protocol(self):
|
|
||||||
...
|
|
||||||
|
|
||||||
def get_extra_info(self, info: str) -> Union[str, bool, None]:
|
|
||||||
...
|
|
|
@ -38,7 +38,7 @@ from sanic.headers import (
|
||||||
parse_xforwarded,
|
parse_xforwarded,
|
||||||
)
|
)
|
||||||
from sanic.log import error_logger, logger
|
from sanic.log import error_logger, logger
|
||||||
from sanic.models.protocol_types import TransportProtocol
|
from sanic.models.transport import TransportProtocol
|
||||||
from sanic.response import BaseHTTPResponse, HTTPResponse
|
from sanic.response import BaseHTTPResponse, HTTPResponse
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -39,7 +39,7 @@ from sanic.config import Config
|
||||||
from sanic.exceptions import RequestTimeout, ServiceUnavailable
|
from sanic.exceptions import RequestTimeout, ServiceUnavailable
|
||||||
from sanic.http import Http, Stage
|
from sanic.http import Http, Stage
|
||||||
from sanic.log import logger
|
from sanic.log import logger
|
||||||
from sanic.models.protocol_types import TransportProtocol
|
from sanic.models.transport import TransportProtocol
|
||||||
from sanic.request import Request
|
from sanic.request import Request
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user