Streaming responses

This commit is contained in:
Adam Hopkins 2019-05-27 02:11:52 +03:00
parent 3ead529693
commit 22c0d97783
4 changed files with 114 additions and 37 deletions

View File

@ -1,12 +1,12 @@
"""
1. Create a simple Sanic app
2. Run with an ASGI server:
0. Run with an ASGI server:
$ uvicorn run_asgi:app
or
$ hypercorn run_asgi:app
"""
import os
from pathlib import Path
from sanic import Sanic, response
@ -14,17 +14,17 @@ app = Sanic(__name__)
@app.route("/text")
def handler(request):
def handler_text(request):
return response.text("Hello")
@app.route("/json")
def handler_foo(request):
return response.text("bar")
def handler_json(request):
return response.json({"foo": "bar"})
@app.websocket("/ws")
async def feed(request, ws):
async def handler_ws(request, ws):
name = "<someone>"
while True:
data = f"Hello {name}"
@ -36,12 +36,23 @@ async def feed(request, ws):
@app.route("/file")
async def test_file(request):
return await response.file(os.path.abspath("setup.py"))
async def handler_file(request):
return await response.file(Path("../") / "setup.py")
@app.route("/file_stream")
async def test_file_stream(request):
async def handler_file_stream(request):
return await response.file_stream(
os.path.abspath("setup.py"), chunk_size=1024
Path("../") / "setup.py", chunk_size=1024
)
@app.route("/stream", stream=True)
async def handler_stream(request):
while True:
body = await request.stream.read()
if body is None:
break
body = body.decode("utf-8").replace("1", "A")
# await response.write(body)
return stream(streaming)

View File

@ -1,7 +1,7 @@
from typing import Any, Awaitable, Callable, MutableMapping, Union
import asyncio
from multidict import CIMultiDict
from functools import partial
from sanic.request import Request
from sanic.response import HTTPResponse, StreamingHTTPResponse
from sanic.websocket import WebSocketConnection
@ -13,9 +13,54 @@ ASGISend = Callable[[ASGIMessage], Awaitable[None]]
ASGIReceive = Callable[[], Awaitable[ASGIMessage]]
class MockProtocol:
def __init__(self, transport: "MockTransport", loop):
self.transport = transport
self._not_paused = asyncio.Event(loop=loop)
self._not_paused.set()
self._complete = asyncio.Event(loop=loop)
def pause_writing(self):
self._not_paused.clear()
def resume_writing(self):
self._not_paused.set()
async def complete(self):
self._not_paused.set()
await self.transport.send(
{"type": "http.response.body", "body": b"", "more_body": False}
)
@property
def is_complete(self):
return self._complete.is_set()
async def push_data(self, data):
if not self.is_complete:
await self.transport.send(
{"type": "http.response.body", "body": data, "more_body": True}
)
async def drain(self):
print("draining")
await self._not_paused.wait()
class MockTransport:
def __init__(self, scope: ASGIScope) -> None:
def __init__(
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
) -> None:
self.scope = scope
self._receive = receive
self._send = send
self._protocol = None
self.loop = None
def get_protocol(self):
if not self._protocol:
self._protocol = MockProtocol(self, self.loop)
return self._protocol
def get_extra_info(self, info: str) -> Union[str, bool]:
if info == "peername":
@ -32,6 +77,18 @@ class MockTransport:
self._websocket_connection = WebSocketConnection(send, receive)
return self._websocket_connection
def add_task(self):
raise NotImplementedError
async def send(self, data):
print(">> sending. more:", data.get("more_body"))
# TODO:
# - Validation on data and that it is formatted properly and is valid
await self._send(data)
async def receive(self):
return await self._receive()
class ASGIApp:
def __init__(self) -> None:
@ -43,8 +100,9 @@ class ASGIApp:
) -> "ASGIApp":
instance = cls()
instance.sanic_app = sanic_app
instance.receive = receive
instance.send = send
instance.transport = MockTransport(scope, receive, send)
instance.transport.add_task = sanic_app.loop.create_task
instance.transport.loop = sanic_app.loop
url_bytes = scope.get("root_path", "") + scope["path"]
url_bytes = url_bytes.encode("latin-1")
@ -60,8 +118,6 @@ class ASGIApp:
True if headers.get("expect") == "100-continue" else False
)
transport = MockTransport(scope)
if scope["type"] == "http":
version = scope["http_version"]
method = scope["method"]
@ -69,7 +125,9 @@ class ASGIApp:
version = "1.1"
method = "GET"
instance.ws = transport.create_websocket_connection(send, receive)
instance.ws = instance.transport.create_websocket_connection(
send, receive
)
await instance.ws.accept()
else:
pass
@ -77,7 +135,7 @@ class ASGIApp:
# - close connection
instance.request = Request(
url_bytes, headers, version, method, transport, sanic_app
url_bytes, headers, version, method, instance.transport, sanic_app
)
if sanic_app.is_request_stream:
@ -92,7 +150,7 @@ class ASGIApp:
body = b""
more_body = True
while more_body:
message = await self.receive()
message = await self.transport.receive()
body += message.get("body", b"")
more_body = message.get("more_body", False)
@ -105,7 +163,7 @@ class ASGIApp:
more_body = True
while more_body:
message = await self.receive()
message = await self.transport.receive()
chunk = message.get("body", b"")
await self.request.stream.put(chunk)
# self.sanic_app.loop.create_task(self.request.stream.put(chunk))
@ -131,29 +189,37 @@ class ASGIApp:
"""
Write the response.
"""
if isinstance(response, StreamingHTTPResponse):
raise NotImplementedError("Not supported")
headers = [
(str(name).encode("latin-1"), str(value).encode("latin-1"))
for name, value in response.headers.items()
]
if "content-length" not in response.headers:
if "content-length" not in response.headers and not isinstance(
response, StreamingHTTPResponse
):
headers += [
(b"content-length", str(len(response.body)).encode("latin-1"))
]
await self.send(
await self.transport.send(
{
"type": "http.response.start",
"status": response.status,
"headers": headers,
}
)
await self.send(
{
"type": "http.response.body",
"body": response.body,
"more_body": False,
}
)
if isinstance(response, StreamingHTTPResponse):
response.protocol = self.transport.get_protocol()
await response.stream()
await response.protocol.complete()
else:
await self.transport.send(
{
"type": "http.response.body",
"body": response.body,
"more_body": False,
}
)

View File

@ -87,9 +87,9 @@ class StreamingHTTPResponse(BaseHTTPResponse):
data = self._encode_body(data)
if self.chunked:
self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data))
await self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data))
else:
self.protocol.push_data(data)
await self.protocol.push_data(data)
await self.protocol.drain()
async def stream(
@ -105,11 +105,11 @@ class StreamingHTTPResponse(BaseHTTPResponse):
keep_alive=keep_alive,
keep_alive_timeout=keep_alive_timeout,
)
self.protocol.push_data(headers)
await self.protocol.push_data(headers)
await self.protocol.drain()
await self.streaming_fn(self)
if self.chunked:
self.protocol.push_data(b"0\r\n\r\n")
await self.protocol.push_data(b"0\r\n\r\n")
# no need to await drain here after this write, because it is the
# very last thing we write and nothing needs to wait for it.

View File

@ -457,7 +457,7 @@ class HttpProtocol(asyncio.Protocol):
async def drain(self):
await self._not_paused.wait()
def push_data(self, data):
async def push_data(self, data):
self.transport.write(data)
async def stream_response(self, response):