This commit is contained in:
Suby Raman 2017-02-21 11:05:06 -05:00
parent ff5d4276bc
commit 4e8aac4b41
5 changed files with 273 additions and 42 deletions

29
docs/sanic/streaming.md Normal file
View File

@ -0,0 +1,29 @@
# Streaming
Sanic allows you to stream content to the client with the `stream` method. This method accepts a coroutine callback which is passed a `StreamingHTTPResponse` object that is written to. A simple example is like follows:
```python
app = Sanic(__name__)
@app.route("/")
async def test(request):
async def sample_streaming_fn(response):
await response.write('foo,')
await response.write('bar')
return stream(sample_streaming_fn, content_type='text/csv')
```
This is useful in situations where you want to stream content to the client that originates in an external service, like a database. For example, you can stream database records to the client with the asynchronous cursor that `asyncpg` provides:
```python
@app.route("/")
async def index(request):
async def stream_from_db(response):
conn = await asyncpg.connect(database='test')
async with conn.transaction():
async for record in conn.cursor('SELECT generate_series(0, 10)'):
await response.write(record[0])
return stream(stream_from_db)
```

View File

@ -416,7 +416,7 @@ class Sanic:
response = HTTPResponse(
"An error occurred while handling an error")
response_callback(response)
await response_callback(response)
# -------------------------------------------------------------------- #
# Testing

View File

@ -1,5 +1,7 @@
import asyncio
from mimetypes import guess_type
from os import path
from inspect import isawaitable
from ujson import dumps as json_dumps
from aiofiles import open as open_async
@ -73,37 +75,16 @@ ALL_STATUS_CODES = {
}
class HTTPResponse:
__slots__ = ('body', 'status', 'content_type', 'headers', '_cookies')
class BaseHTTPResponse:
def _encode_body(self, data):
try:
# Try to encode it regularly
return data.encode()
except AttributeError:
# Convert it to a str if you can't
return str(data).encode()
def __init__(self, body=None, status=200, headers=None,
content_type='text/plain', body_bytes=b''):
self.content_type = content_type
if body is not None:
try:
# Try to encode it regularly
self.body = body.encode()
except AttributeError:
# Convert it to a str if you can't
self.body = str(body).encode()
else:
self.body = body_bytes
self.status = status
self.headers = headers or {}
self._cookies = None
def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None):
# This is all returned in a kind-of funky way
# We tried to make this as fast as possible in pure python
timeout_header = b''
if keep_alive and keep_alive_timeout is not None:
timeout_header = b'Keep-Alive: %d\r\n' % keep_alive_timeout
self.headers['Content-Length'] = self.headers.get(
'Content-Length', len(self.body))
self.headers['Content-Type'] = self.headers.get(
'Content-Type', self.content_type)
def _parse_headers(self):
headers = b''
for name, value in self.headers.items():
try:
@ -115,6 +96,112 @@ class HTTPResponse:
b'%b: %b\r\n' % (
str(name).encode(), str(value).encode('utf-8')))
return headers
@property
def cookies(self):
if self._cookies is None:
self._cookies = CookieJar(self.headers)
return self._cookies
class StreamingHTTPResponse(BaseHTTPResponse):
__slots__ = (
'transport', 'streaming_fn',
'status', 'content_type', 'headers', '_cookies')
def __init__(self, streaming_fn, status=200, headers=None,
content_type='text/plain', body_bytes=b''):
self.content_type = content_type
self.streaming_fn = streaming_fn
self.status = status
self.headers = headers or {}
self._cookies = None
async def write(self, data):
"""Writes a chunk of data to the streaming response.
:param data: bytes-ish data to be written.
"""
data = self._encode_body(data)
self.transport.write(
b"%b\r\n%b\r\n" % (str(len(data)).encode(), data))
async def stream(
self, version="1.1", keep_alive=False, keep_alive_timeout=None):
"""Streams headers, runs the `streaming_fn` callback that writes content
to the response body, then finalizes the response body.
"""
headers = self.get_headers(
version, keep_alive=keep_alive,
keep_alive_timeout=keep_alive_timeout)
self.transport.write(headers)
await self.streaming_fn(self)
self.transport.write(b'0\r\n\r\n')
def get_headers(
self, version="1.1", keep_alive=False, keep_alive_timeout=None):
# This is all returned in a kind-of funky way
# We tried to make this as fast as possible in pure python
timeout_header = b''
if keep_alive and keep_alive_timeout is not None:
timeout_header = b'Keep-Alive: %d\r\n' % keep_alive_timeout
self.headers['Transfer-Encoding'] = 'chunked'
self.headers.pop('Content-Length', None)
self.headers['Content-Type'] = self.headers.get(
'Content-Type', self.content_type)
headers = self._parse_headers()
# Try to pull from the common codes first
# Speeds up response rate 6% over pulling from all
status = COMMON_STATUS_CODES.get(self.status)
if not status:
status = ALL_STATUS_CODES.get(self.status)
return (b'HTTP/%b %d %b\r\n'
b'%b'
b'%b\r\n') % (
version.encode(),
self.status,
status,
timeout_header,
headers
)
class HTTPResponse(BaseHTTPResponse):
__slots__ = ('body', 'status', 'content_type', 'headers', '_cookies')
def __init__(self, body=None, status=200, headers=None,
content_type='text/plain', body_bytes=b''):
self.content_type = content_type
if body is not None:
self.body = self._encode_body(body)
else:
self.body = body_bytes
self.status = status
self.headers = headers or {}
self._cookies = None
def output(
self, version="1.1", keep_alive=False, keep_alive_timeout=None):
# This is all returned in a kind-of funky way
# We tried to make this as fast as possible in pure python
timeout_header = b''
if keep_alive and keep_alive_timeout is not None:
timeout_header = b'Keep-Alive: %d\r\n' % keep_alive_timeout
self.headers['Content-Length'] = self.headers.get(
'Content-Length', len(self.body))
self.headers['Content-Type'] = self.headers.get(
'Content-Type', self.content_type)
headers = self._parse_headers()
# Try to pull from the common codes first
# Speeds up response rate 6% over pulling from all
status = COMMON_STATUS_CODES.get(self.status)
@ -164,8 +251,9 @@ def text(body, status=200, headers=None,
:param content_type:
the content type (string) of the response
"""
return HTTPResponse(body, status=status, headers=headers,
content_type=content_type)
return HTTPResponse(
body, status=status, headers=headers,
content_type=content_type)
def raw(body, status=200, headers=None,
@ -220,6 +308,32 @@ async def file(location, mime_type=None, headers=None, _range=None):
body_bytes=out_stream)
def stream(
streaming_fn, status=200, headers=None,
content_type="text/plain; charset=utf-8"):
"""Accepts an coroutine `streaming_fn` which can be used to
write chunks to a streaming response. Returns a `StreamingHTTPResponse`.
Example usage:
```
@app.route("/")
async def index(request):
async def streaming_fn(response):
await response.write('foo')
await response.write('bar')
return stream(streaming_fn, content_type='text/plain')
```
:param streaming_fn: A coroutine accepts a response and
writes content to that response.
:param mime_type: Specific mime_type.
:param headers: Custom Headers.
"""
return StreamingHTTPResponse(
streaming_fn, headers=headers, content_type=content_type, status=200)
def redirect(to, headers=None, status=302,
content_type="text/html; charset=utf-8"):
"""Abort execution and cause a 302 redirect (by default).

View File

@ -21,6 +21,7 @@ except ImportError:
from sanic.log import log
from sanic.request import Request
from sanic.response import StreamingHTTPResponse
from sanic.exceptions import (
RequestTimeout, PayloadTooLarge, InvalidUsage, ServerError)
@ -159,20 +160,29 @@ class HttpProtocol(asyncio.Protocol):
def on_message_complete(self):
if self.request.body:
self.request.body = b''.join(self.request.body)
self._request_handler_task = self.loop.create_task(
self.request_handler(self.request, self.write_response))
# -------------------------------------------- #
# Responding
# -------------------------------------------- #
def write_response(self, response):
keep_alive = (
self.parser.should_keep_alive() and not self.signal.stopped)
async def write_response(self, response):
try:
self.transport.write(
response.output(
self.request.version, keep_alive, self.request_timeout))
keep_alive = (
self.parser.should_keep_alive() and not self.signal.stopped)
if isinstance(response, StreamingHTTPResponse):
# streaming responses should have direct write access to the
# transport
response.transport = self.transport
await response.stream(
self.request.version, keep_alive, self.request_timeout)
else:
self.transport.write(
response.output(
self.request.version, keep_alive,
self.request_timeout))
except AttributeError:
log.error(
('Invalid response object for url {}, '
@ -191,7 +201,6 @@ class HttpProtocol(asyncio.Protocol):
if not keep_alive:
self.transport.close()
else:
# Record that we received data
self._last_request_time = current_time
self.cleanup()

View File

@ -1,8 +1,12 @@
import asyncio
import pytest
from random import choice
from sanic import Sanic
from sanic.response import HTTPResponse
from sanic.response import HTTPResponse, stream, StreamingHTTPResponse
from sanic.testing import HOST, PORT
from unittest.mock import MagicMock
def test_response_body_not_a_string():
"""Test when a response body sent from the application is not a string"""
@ -15,3 +19,78 @@ def test_response_body_not_a_string():
request, response = app.test_client.get('/hello')
assert response.text == str(random_num)
async def sample_streaming_fn(response):
await response.write('foo,')
await asyncio.sleep(.001)
await response.write('bar')
@pytest.fixture
def streaming_app():
app = Sanic('streaming')
@app.route("/")
async def test(request):
return stream(sample_streaming_fn, content_type='text/csv')
return app
def test_streaming_adds_correct_headers(streaming_app):
request, response = streaming_app.test_client.get('/')
assert response.headers['Transfer-Encoding'] == 'chunked'
assert response.headers['Content-Type'] == 'text/csv'
def test_streaming_returns_correct_content(streaming_app):
request, response = streaming_app.test_client.get('/')
assert response.text == 'foo,bar'
@pytest.mark.parametrize('status', [200, 201, 400, 401])
def test_stream_response_status_returns_correct_headers(status):
response = StreamingHTTPResponse(sample_streaming_fn, status=status)
headers = response.get_headers()
assert b"HTTP/1.1 %s" % str(status).encode() in headers
@pytest.mark.parametrize('keep_alive_timeout', [10, 20, 30])
def test_stream_response_keep_alive_returns_correct_headers(
keep_alive_timeout):
response = StreamingHTTPResponse(sample_streaming_fn)
headers = response.get_headers(
keep_alive=True, keep_alive_timeout=keep_alive_timeout)
assert b"Keep-Alive: %s\r\n" % str(keep_alive_timeout).encode() in headers
def test_stream_response_includes_chunked_header():
response = StreamingHTTPResponse(sample_streaming_fn)
headers = response.get_headers()
assert b"Transfer-Encoding: chunked\r\n" in headers
def test_stream_response_writes_correct_content_to_transport(streaming_app):
response = StreamingHTTPResponse(sample_streaming_fn)
response.transport = MagicMock(asyncio.Transport)
@streaming_app.listener('after_server_start')
async def run_stream(app, loop):
await response.stream()
assert response.transport.write.call_args_list[1][0][0] == (
b'4\r\nfoo,\r\n'
)
assert response.transport.write.call_args_list[2][0][0] == (
b'3\r\nbar\r\n'
)
assert response.transport.write.call_args_list[3][0][0] == (
b'0\r\n\r\n'
)
app.stop()
streaming_app.run(host=HOST, port=PORT)