Streaming responses

This commit is contained in:
Adam Hopkins 2019-05-27 02:11:52 +03:00
parent 3ead529693
commit 22c0d97783
4 changed files with 114 additions and 37 deletions

View File

@ -1,12 +1,12 @@
""" """
1. Create a simple Sanic app 1. Create a simple Sanic app
2. Run with an ASGI server: 0. Run with an ASGI server:
$ uvicorn run_asgi:app $ uvicorn run_asgi:app
or or
$ hypercorn run_asgi:app $ hypercorn run_asgi:app
""" """
import os from pathlib import Path
from sanic import Sanic, response from sanic import Sanic, response
@ -14,17 +14,17 @@ app = Sanic(__name__)
@app.route("/text") @app.route("/text")
def handler(request): def handler_text(request):
return response.text("Hello") return response.text("Hello")
@app.route("/json") @app.route("/json")
def handler_foo(request): def handler_json(request):
return response.text("bar") return response.json({"foo": "bar"})
@app.websocket("/ws") @app.websocket("/ws")
async def feed(request, ws): async def handler_ws(request, ws):
name = "<someone>" name = "<someone>"
while True: while True:
data = f"Hello {name}" data = f"Hello {name}"
@ -36,12 +36,23 @@ async def feed(request, ws):
@app.route("/file") @app.route("/file")
async def test_file(request): async def handler_file(request):
return await response.file(os.path.abspath("setup.py")) return await response.file(Path("../") / "setup.py")
@app.route("/file_stream") @app.route("/file_stream")
async def test_file_stream(request): async def handler_file_stream(request):
return await response.file_stream( return await response.file_stream(
os.path.abspath("setup.py"), chunk_size=1024 Path("../") / "setup.py", chunk_size=1024
) )
@app.route("/stream", stream=True)
async def handler_stream(request):
while True:
body = await request.stream.read()
if body is None:
break
body = body.decode("utf-8").replace("1", "A")
# await response.write(body)
return stream(streaming)

View File

@ -1,7 +1,7 @@
from typing import Any, Awaitable, Callable, MutableMapping, Union from typing import Any, Awaitable, Callable, MutableMapping, Union
import asyncio
from multidict import CIMultiDict from multidict import CIMultiDict
from functools import partial
from sanic.request import Request from sanic.request import Request
from sanic.response import HTTPResponse, StreamingHTTPResponse from sanic.response import HTTPResponse, StreamingHTTPResponse
from sanic.websocket import WebSocketConnection from sanic.websocket import WebSocketConnection
@ -13,9 +13,54 @@ ASGISend = Callable[[ASGIMessage], Awaitable[None]]
ASGIReceive = Callable[[], Awaitable[ASGIMessage]] ASGIReceive = Callable[[], Awaitable[ASGIMessage]]
class MockProtocol:
def __init__(self, transport: "MockTransport", loop):
self.transport = transport
self._not_paused = asyncio.Event(loop=loop)
self._not_paused.set()
self._complete = asyncio.Event(loop=loop)
def pause_writing(self):
self._not_paused.clear()
def resume_writing(self):
self._not_paused.set()
async def complete(self):
self._not_paused.set()
await self.transport.send(
{"type": "http.response.body", "body": b"", "more_body": False}
)
@property
def is_complete(self):
return self._complete.is_set()
async def push_data(self, data):
if not self.is_complete:
await self.transport.send(
{"type": "http.response.body", "body": data, "more_body": True}
)
async def drain(self):
print("draining")
await self._not_paused.wait()
class MockTransport: class MockTransport:
def __init__(self, scope: ASGIScope) -> None: def __init__(
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
) -> None:
self.scope = scope self.scope = scope
self._receive = receive
self._send = send
self._protocol = None
self.loop = None
def get_protocol(self):
if not self._protocol:
self._protocol = MockProtocol(self, self.loop)
return self._protocol
def get_extra_info(self, info: str) -> Union[str, bool]: def get_extra_info(self, info: str) -> Union[str, bool]:
if info == "peername": if info == "peername":
@ -32,6 +77,18 @@ class MockTransport:
self._websocket_connection = WebSocketConnection(send, receive) self._websocket_connection = WebSocketConnection(send, receive)
return self._websocket_connection return self._websocket_connection
def add_task(self):
raise NotImplementedError
async def send(self, data):
print(">> sending. more:", data.get("more_body"))
# TODO:
# - Validation on data and that it is formatted properly and is valid
await self._send(data)
async def receive(self):
return await self._receive()
class ASGIApp: class ASGIApp:
def __init__(self) -> None: def __init__(self) -> None:
@ -43,8 +100,9 @@ class ASGIApp:
) -> "ASGIApp": ) -> "ASGIApp":
instance = cls() instance = cls()
instance.sanic_app = sanic_app instance.sanic_app = sanic_app
instance.receive = receive instance.transport = MockTransport(scope, receive, send)
instance.send = send instance.transport.add_task = sanic_app.loop.create_task
instance.transport.loop = sanic_app.loop
url_bytes = scope.get("root_path", "") + scope["path"] url_bytes = scope.get("root_path", "") + scope["path"]
url_bytes = url_bytes.encode("latin-1") url_bytes = url_bytes.encode("latin-1")
@ -60,8 +118,6 @@ class ASGIApp:
True if headers.get("expect") == "100-continue" else False True if headers.get("expect") == "100-continue" else False
) )
transport = MockTransport(scope)
if scope["type"] == "http": if scope["type"] == "http":
version = scope["http_version"] version = scope["http_version"]
method = scope["method"] method = scope["method"]
@ -69,7 +125,9 @@ class ASGIApp:
version = "1.1" version = "1.1"
method = "GET" method = "GET"
instance.ws = transport.create_websocket_connection(send, receive) instance.ws = instance.transport.create_websocket_connection(
send, receive
)
await instance.ws.accept() await instance.ws.accept()
else: else:
pass pass
@ -77,7 +135,7 @@ class ASGIApp:
# - close connection # - close connection
instance.request = Request( instance.request = Request(
url_bytes, headers, version, method, transport, sanic_app url_bytes, headers, version, method, instance.transport, sanic_app
) )
if sanic_app.is_request_stream: if sanic_app.is_request_stream:
@ -92,7 +150,7 @@ class ASGIApp:
body = b"" body = b""
more_body = True more_body = True
while more_body: while more_body:
message = await self.receive() message = await self.transport.receive()
body += message.get("body", b"") body += message.get("body", b"")
more_body = message.get("more_body", False) more_body = message.get("more_body", False)
@ -105,7 +163,7 @@ class ASGIApp:
more_body = True more_body = True
while more_body: while more_body:
message = await self.receive() message = await self.transport.receive()
chunk = message.get("body", b"") chunk = message.get("body", b"")
await self.request.stream.put(chunk) await self.request.stream.put(chunk)
# self.sanic_app.loop.create_task(self.request.stream.put(chunk)) # self.sanic_app.loop.create_task(self.request.stream.put(chunk))
@ -131,29 +189,37 @@ class ASGIApp:
""" """
Write the response. Write the response.
""" """
if isinstance(response, StreamingHTTPResponse):
raise NotImplementedError("Not supported")
headers = [ headers = [
(str(name).encode("latin-1"), str(value).encode("latin-1")) (str(name).encode("latin-1"), str(value).encode("latin-1"))
for name, value in response.headers.items() for name, value in response.headers.items()
] ]
if "content-length" not in response.headers:
if "content-length" not in response.headers and not isinstance(
response, StreamingHTTPResponse
):
headers += [ headers += [
(b"content-length", str(len(response.body)).encode("latin-1")) (b"content-length", str(len(response.body)).encode("latin-1"))
] ]
await self.send( await self.transport.send(
{ {
"type": "http.response.start", "type": "http.response.start",
"status": response.status, "status": response.status,
"headers": headers, "headers": headers,
} }
) )
await self.send(
{ if isinstance(response, StreamingHTTPResponse):
"type": "http.response.body", response.protocol = self.transport.get_protocol()
"body": response.body, await response.stream()
"more_body": False, await response.protocol.complete()
}
) else:
await self.transport.send(
{
"type": "http.response.body",
"body": response.body,
"more_body": False,
}
)

View File

@ -87,9 +87,9 @@ class StreamingHTTPResponse(BaseHTTPResponse):
data = self._encode_body(data) data = self._encode_body(data)
if self.chunked: if self.chunked:
self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data)) await self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data))
else: else:
self.protocol.push_data(data) await self.protocol.push_data(data)
await self.protocol.drain() await self.protocol.drain()
async def stream( async def stream(
@ -105,11 +105,11 @@ class StreamingHTTPResponse(BaseHTTPResponse):
keep_alive=keep_alive, keep_alive=keep_alive,
keep_alive_timeout=keep_alive_timeout, keep_alive_timeout=keep_alive_timeout,
) )
self.protocol.push_data(headers) await self.protocol.push_data(headers)
await self.protocol.drain() await self.protocol.drain()
await self.streaming_fn(self) await self.streaming_fn(self)
if self.chunked: if self.chunked:
self.protocol.push_data(b"0\r\n\r\n") await self.protocol.push_data(b"0\r\n\r\n")
# no need to await drain here after this write, because it is the # no need to await drain here after this write, because it is the
# very last thing we write and nothing needs to wait for it. # very last thing we write and nothing needs to wait for it.

View File

@ -457,7 +457,7 @@ class HttpProtocol(asyncio.Protocol):
async def drain(self): async def drain(self):
await self._not_paused.wait() await self._not_paused.wait()
def push_data(self, data): async def push_data(self, data):
self.transport.write(data) self.transport.write(data)
async def stream_response(self, response): async def stream_response(self, response):