diff --git a/sanic/models/asgi.py b/sanic/models/asgi.py new file mode 100644 index 00000000..595b0553 --- /dev/null +++ b/sanic/models/asgi.py @@ -0,0 +1,95 @@ +import asyncio + +from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union + +from sanic.exceptions import InvalidUsage +from sanic.websocket import WebSocketConnection + + +ASGIScope = MutableMapping[str, Any] +ASGIMessage = MutableMapping[str, Any] +ASGISend = Callable[[ASGIMessage], Awaitable[None]] +ASGIReceive = Callable[[], Awaitable[ASGIMessage]] + + +class MockProtocol: + def __init__(self, transport: "MockTransport", loop): + self.transport = transport + self._not_paused = asyncio.Event(loop=loop) + self._not_paused.set() + self._complete = asyncio.Event(loop=loop) + + def pause_writing(self) -> None: + self._not_paused.clear() + + def resume_writing(self) -> None: + self._not_paused.set() + + async def complete(self) -> None: + self._not_paused.set() + await self.transport.send( + {"type": "http.response.body", "body": b"", "more_body": False} + ) + + @property + def is_complete(self) -> bool: + return self._complete.is_set() + + async def push_data(self, data: bytes) -> None: + if not self.is_complete: + await self.transport.send( + {"type": "http.response.body", "body": data, "more_body": True} + ) + + async def drain(self) -> None: + await self._not_paused.wait() + + +class MockTransport: + _protocol: Optional[MockProtocol] + + def __init__( + self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend + ) -> None: + self.scope = scope + self._receive = receive + self._send = send + self._protocol = None + self.loop = None + + def get_protocol(self) -> MockProtocol: + if not self._protocol: + self._protocol = MockProtocol(self, self.loop) + return self._protocol + + def get_extra_info(self, info: str) -> Union[str, bool, None]: + if info == "peername": + return self.scope.get("client") + elif info == "sslcontext": + return self.scope.get("scheme") in ["https", "wss"] + return None + + def get_websocket_connection(self) -> WebSocketConnection: + try: + return self._websocket_connection + except AttributeError: + raise InvalidUsage("Improper websocket connection.") + + def create_websocket_connection( + self, send: ASGISend, receive: ASGIReceive + ) -> WebSocketConnection: + self._websocket_connection = WebSocketConnection( + send, receive, self.scope.get("subprotocols", []) + ) + return self._websocket_connection + + def add_task(self) -> None: + raise NotImplementedError + + async def send(self, data) -> None: + # TODO: + # - Validation on data and that it is formatted properly and is valid + await self._send(data) + + async def receive(self) -> ASGIMessage: + return await self._receive() diff --git a/sanic/models/transport.py b/sanic/models/transport.py new file mode 100644 index 00000000..fc612dd6 --- /dev/null +++ b/sanic/models/transport.py @@ -0,0 +1,21 @@ +import sys + +from typing import Union + + +if sys.version_info < (3, 8): + from asyncio import BaseTransport + + from sanic.models.asgi import MockTransport + + TransportProtocol = Union[MockTransport, BaseTransport] +else: + # Protocol is a 3.8+ feature + from typing import Protocol + + class TransportProtocol(Protocol): + def get_protocol(self): + ... + + def get_extra_info(self, info: str) -> Union[str, bool, None]: + ...