diff --git a/examples/run_asgi.py b/examples/run_asgi.py index 81333383..818faaf9 100644 --- a/examples/run_asgi.py +++ b/examples/run_asgi.py @@ -1,12 +1,12 @@ """ 1. Create a simple Sanic app -2. Run with an ASGI server: +0. Run with an ASGI server: $ uvicorn run_asgi:app or $ hypercorn run_asgi:app """ -import os +from pathlib import Path from sanic import Sanic, response @@ -14,17 +14,17 @@ app = Sanic(__name__) @app.route("/text") -def handler(request): +def handler_text(request): return response.text("Hello") @app.route("/json") -def handler_foo(request): - return response.text("bar") +def handler_json(request): + return response.json({"foo": "bar"}) @app.websocket("/ws") -async def feed(request, ws): +async def handler_ws(request, ws): name = "" while True: data = f"Hello {name}" @@ -36,12 +36,23 @@ async def feed(request, ws): @app.route("/file") -async def test_file(request): - return await response.file(os.path.abspath("setup.py")) +async def handler_file(request): + return await response.file(Path("../") / "setup.py") @app.route("/file_stream") -async def test_file_stream(request): +async def handler_file_stream(request): return await response.file_stream( - os.path.abspath("setup.py"), chunk_size=1024 + Path("../") / "setup.py", chunk_size=1024 ) + + +@app.route("/stream", stream=True) +async def handler_stream(request): + while True: + body = await request.stream.read() + if body is None: + break + body = body.decode("utf-8").replace("1", "A") + # await response.write(body) + return stream(streaming) diff --git a/sanic/asgi.py b/sanic/asgi.py index fa853f5a..8ed448e3 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -1,7 +1,7 @@ from typing import Any, Awaitable, Callable, MutableMapping, Union - +import asyncio from multidict import CIMultiDict - +from functools import partial from sanic.request import Request from sanic.response import HTTPResponse, StreamingHTTPResponse from sanic.websocket import WebSocketConnection @@ -13,9 +13,54 @@ ASGISend = Callable[[ASGIMessage], Awaitable[None]] ASGIReceive = Callable[[], Awaitable[ASGIMessage]] +class MockProtocol: + def __init__(self, transport: "MockTransport", loop): + self.transport = transport + self._not_paused = asyncio.Event(loop=loop) + self._not_paused.set() + self._complete = asyncio.Event(loop=loop) + + def pause_writing(self): + self._not_paused.clear() + + def resume_writing(self): + self._not_paused.set() + + async def complete(self): + self._not_paused.set() + await self.transport.send( + {"type": "http.response.body", "body": b"", "more_body": False} + ) + + @property + def is_complete(self): + return self._complete.is_set() + + async def push_data(self, data): + if not self.is_complete: + await self.transport.send( + {"type": "http.response.body", "body": data, "more_body": True} + ) + + async def drain(self): + print("draining") + await self._not_paused.wait() + + class MockTransport: - def __init__(self, scope: ASGIScope) -> None: + def __init__( + self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend + ) -> None: self.scope = scope + self._receive = receive + self._send = send + self._protocol = None + self.loop = None + + def get_protocol(self): + if not self._protocol: + self._protocol = MockProtocol(self, self.loop) + return self._protocol def get_extra_info(self, info: str) -> Union[str, bool]: if info == "peername": @@ -32,6 +77,18 @@ class MockTransport: self._websocket_connection = WebSocketConnection(send, receive) return self._websocket_connection + def add_task(self): + raise NotImplementedError + + async def send(self, data): + print(">> sending. more:", data.get("more_body")) + # TODO: + # - Validation on data and that it is formatted properly and is valid + await self._send(data) + + async def receive(self): + return await self._receive() + class ASGIApp: def __init__(self) -> None: @@ -43,8 +100,9 @@ class ASGIApp: ) -> "ASGIApp": instance = cls() instance.sanic_app = sanic_app - instance.receive = receive - instance.send = send + instance.transport = MockTransport(scope, receive, send) + instance.transport.add_task = sanic_app.loop.create_task + instance.transport.loop = sanic_app.loop url_bytes = scope.get("root_path", "") + scope["path"] url_bytes = url_bytes.encode("latin-1") @@ -60,8 +118,6 @@ class ASGIApp: True if headers.get("expect") == "100-continue" else False ) - transport = MockTransport(scope) - if scope["type"] == "http": version = scope["http_version"] method = scope["method"] @@ -69,7 +125,9 @@ class ASGIApp: version = "1.1" method = "GET" - instance.ws = transport.create_websocket_connection(send, receive) + instance.ws = instance.transport.create_websocket_connection( + send, receive + ) await instance.ws.accept() else: pass @@ -77,7 +135,7 @@ class ASGIApp: # - close connection instance.request = Request( - url_bytes, headers, version, method, transport, sanic_app + url_bytes, headers, version, method, instance.transport, sanic_app ) if sanic_app.is_request_stream: @@ -92,7 +150,7 @@ class ASGIApp: body = b"" more_body = True while more_body: - message = await self.receive() + message = await self.transport.receive() body += message.get("body", b"") more_body = message.get("more_body", False) @@ -105,7 +163,7 @@ class ASGIApp: more_body = True while more_body: - message = await self.receive() + message = await self.transport.receive() chunk = message.get("body", b"") await self.request.stream.put(chunk) # self.sanic_app.loop.create_task(self.request.stream.put(chunk)) @@ -131,29 +189,37 @@ class ASGIApp: """ Write the response. """ - if isinstance(response, StreamingHTTPResponse): - raise NotImplementedError("Not supported") headers = [ (str(name).encode("latin-1"), str(value).encode("latin-1")) for name, value in response.headers.items() ] - if "content-length" not in response.headers: + + if "content-length" not in response.headers and not isinstance( + response, StreamingHTTPResponse + ): headers += [ (b"content-length", str(len(response.body)).encode("latin-1")) ] - await self.send( + await self.transport.send( { "type": "http.response.start", "status": response.status, "headers": headers, } ) - await self.send( - { - "type": "http.response.body", - "body": response.body, - "more_body": False, - } - ) + + if isinstance(response, StreamingHTTPResponse): + response.protocol = self.transport.get_protocol() + await response.stream() + await response.protocol.complete() + + else: + await self.transport.send( + { + "type": "http.response.body", + "body": response.body, + "more_body": False, + } + ) diff --git a/sanic/response.py b/sanic/response.py index be178eff..34f59e66 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -87,9 +87,9 @@ class StreamingHTTPResponse(BaseHTTPResponse): data = self._encode_body(data) if self.chunked: - self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data)) + await self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data)) else: - self.protocol.push_data(data) + await self.protocol.push_data(data) await self.protocol.drain() async def stream( @@ -105,11 +105,11 @@ class StreamingHTTPResponse(BaseHTTPResponse): keep_alive=keep_alive, keep_alive_timeout=keep_alive_timeout, ) - self.protocol.push_data(headers) + await self.protocol.push_data(headers) await self.protocol.drain() await self.streaming_fn(self) if self.chunked: - self.protocol.push_data(b"0\r\n\r\n") + await self.protocol.push_data(b"0\r\n\r\n") # no need to await drain here after this write, because it is the # very last thing we write and nothing needs to wait for it. diff --git a/sanic/server.py b/sanic/server.py index 4f0bea37..c7e96676 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -457,7 +457,7 @@ class HttpProtocol(asyncio.Protocol): async def drain(self): await self._not_paused.wait() - def push_data(self, data): + async def push_data(self, data): self.transport.write(data) async def stream_response(self, response):