Merge pull request #473 from subyraman/explore-streams-v2

Add `stream` method for streaming content, add docs and examples
This commit is contained in:
Eli Uriegas 2017-03-05 17:51:44 -08:00 committed by GitHub
commit 19592e8eea
5 changed files with 318 additions and 44 deletions

32
docs/sanic/streaming.md Normal file
View File

@ -0,0 +1,32 @@
# Streaming
Sanic allows you to stream content to the client with the `stream` method. This method accepts a coroutine callback which is passed a `StreamingHTTPResponse` object that is written to. A simple example is like follows:
```python
from sanic import Sanic
from sanic.response import stream
app = Sanic(__name__)
@app.route("/")
async def test(request):
async def sample_streaming_fn(response):
response.write('foo,')
response.write('bar')
return stream(sample_streaming_fn, content_type='text/csv')
```
This is useful in situations where you want to stream content to the client that originates in an external service, like a database. For example, you can stream database records to the client with the asynchronous cursor that `asyncpg` provides:
```python
@app.route("/")
async def index(request):
async def stream_from_db(response):
conn = await asyncpg.connect(database='test')
async with conn.transaction():
async for record in conn.cursor('SELECT generate_series(0, 10)'):
response.write(record[0])
return stream(stream_from_db)
```

View File

@ -13,7 +13,7 @@ from sanic.constants import HTTP_METHODS
from sanic.exceptions import ServerError, URLBuildError, SanicException from sanic.exceptions import ServerError, URLBuildError, SanicException
from sanic.handlers import ErrorHandler from sanic.handlers import ErrorHandler
from sanic.log import log from sanic.log import log
from sanic.response import HTTPResponse from sanic.response import HTTPResponse, StreamingHTTPResponse
from sanic.router import Router from sanic.router import Router
from sanic.server import serve, serve_multiple, HttpProtocol from sanic.server import serve, serve_multiple, HttpProtocol
from sanic.static import register as static_register from sanic.static import register as static_register
@ -391,14 +391,17 @@ class Sanic:
def converted_response_type(self, response): def converted_response_type(self, response):
pass pass
async def handle_request(self, request, response_callback): async def handle_request(self, request, write_callback, stream_callback):
"""Take a request from the HTTP Server and return a response object """Take a request from the HTTP Server and return a response object
to be sent back The HTTP Server only expects a response object, so to be sent back The HTTP Server only expects a response object, so
exception handling must be done here exception handling must be done here
:param request: HTTP Request object :param request: HTTP Request object
:param response_callback: Response function to be called with the :param write_callback: Synchronous response function to be
response as the only argument called with the response as the only argument
:param stream_callback: Coroutine that handles streaming a
StreamingHTTPResponse if produced by the handler.
:return: Nothing :return: Nothing
""" """
try: try:
@ -467,7 +470,11 @@ class Sanic:
response = HTTPResponse( response = HTTPResponse(
"An error occurred while handling an error") "An error occurred while handling an error")
response_callback(response) # pass the response to the correct callback
if isinstance(response, StreamingHTTPResponse):
await stream_callback(response)
else:
write_callback(response)
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
# Testing # Testing

View File

@ -73,37 +73,16 @@ ALL_STATUS_CODES = {
} }
class HTTPResponse: class BaseHTTPResponse:
__slots__ = ('body', 'status', 'content_type', 'headers', '_cookies') def _encode_body(self, data):
def __init__(self, body=None, status=200, headers=None,
content_type='text/plain', body_bytes=b''):
self.content_type = content_type
if body is not None:
try: try:
# Try to encode it regularly # Try to encode it regularly
self.body = body.encode() return data.encode()
except AttributeError: except AttributeError:
# Convert it to a str if you can't # Convert it to a str if you can't
self.body = str(body).encode() return str(data).encode()
else:
self.body = body_bytes
self.status = status def _parse_headers(self):
self.headers = headers or {}
self._cookies = None
def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None):
# This is all returned in a kind-of funky way
# We tried to make this as fast as possible in pure python
timeout_header = b''
if keep_alive and keep_alive_timeout is not None:
timeout_header = b'Keep-Alive: %d\r\n' % keep_alive_timeout
self.headers['Content-Length'] = self.headers.get(
'Content-Length', len(self.body))
self.headers['Content-Type'] = self.headers.get(
'Content-Type', self.content_type)
headers = b'' headers = b''
for name, value in self.headers.items(): for name, value in self.headers.items():
try: try:
@ -115,6 +94,114 @@ class HTTPResponse:
b'%b: %b\r\n' % ( b'%b: %b\r\n' % (
str(name).encode(), str(value).encode('utf-8'))) str(name).encode(), str(value).encode('utf-8')))
return headers
@property
def cookies(self):
if self._cookies is None:
self._cookies = CookieJar(self.headers)
return self._cookies
class StreamingHTTPResponse(BaseHTTPResponse):
__slots__ = (
'transport', 'streaming_fn',
'status', 'content_type', 'headers', '_cookies')
def __init__(self, streaming_fn, status=200, headers=None,
content_type='text/plain'):
self.content_type = content_type
self.streaming_fn = streaming_fn
self.status = status
self.headers = headers or {}
self._cookies = None
def write(self, data):
"""Writes a chunk of data to the streaming response.
:param data: bytes-ish data to be written.
"""
if type(data) != bytes:
data = self._encode_body(data)
self.transport.write(
b"%b\r\n%b\r\n" % (str(len(data)).encode(), data))
async def stream(
self, version="1.1", keep_alive=False, keep_alive_timeout=None):
"""Streams headers, runs the `streaming_fn` callback that writes content
to the response body, then finalizes the response body.
"""
headers = self.get_headers(
version, keep_alive=keep_alive,
keep_alive_timeout=keep_alive_timeout)
self.transport.write(headers)
await self.streaming_fn(self)
self.transport.write(b'0\r\n\r\n')
def get_headers(
self, version="1.1", keep_alive=False, keep_alive_timeout=None):
# This is all returned in a kind-of funky way
# We tried to make this as fast as possible in pure python
timeout_header = b''
if keep_alive and keep_alive_timeout is not None:
timeout_header = b'Keep-Alive: %d\r\n' % keep_alive_timeout
self.headers['Transfer-Encoding'] = 'chunked'
self.headers.pop('Content-Length', None)
self.headers['Content-Type'] = self.headers.get(
'Content-Type', self.content_type)
headers = self._parse_headers()
# Try to pull from the common codes first
# Speeds up response rate 6% over pulling from all
status = COMMON_STATUS_CODES.get(self.status)
if not status:
status = ALL_STATUS_CODES.get(self.status)
return (b'HTTP/%b %d %b\r\n'
b'%b'
b'%b\r\n') % (
version.encode(),
self.status,
status,
timeout_header,
headers
)
class HTTPResponse(BaseHTTPResponse):
__slots__ = ('body', 'status', 'content_type', 'headers', '_cookies')
def __init__(self, body=None, status=200, headers=None,
content_type='text/plain', body_bytes=b''):
self.content_type = content_type
if body is not None:
self.body = self._encode_body(body)
else:
self.body = body_bytes
self.status = status
self.headers = headers or {}
self._cookies = None
def output(
self, version="1.1", keep_alive=False, keep_alive_timeout=None):
# This is all returned in a kind-of funky way
# We tried to make this as fast as possible in pure python
timeout_header = b''
if keep_alive and keep_alive_timeout is not None:
timeout_header = b'Keep-Alive: %d\r\n' % keep_alive_timeout
self.headers['Content-Length'] = self.headers.get(
'Content-Length', len(self.body))
self.headers['Content-Type'] = self.headers.get(
'Content-Type', self.content_type)
headers = self._parse_headers()
# Try to pull from the common codes first # Try to pull from the common codes first
# Speeds up response rate 6% over pulling from all # Speeds up response rate 6% over pulling from all
status = COMMON_STATUS_CODES.get(self.status) status = COMMON_STATUS_CODES.get(self.status)
@ -164,7 +251,8 @@ def text(body, status=200, headers=None,
:param content_type: :param content_type:
the content type (string) of the response the content type (string) of the response
""" """
return HTTPResponse(body, status=status, headers=headers, return HTTPResponse(
body, status=status, headers=headers,
content_type=content_type) content_type=content_type)
@ -220,6 +308,32 @@ async def file(location, mime_type=None, headers=None, _range=None):
body_bytes=out_stream) body_bytes=out_stream)
def stream(
streaming_fn, status=200, headers=None,
content_type="text/plain; charset=utf-8"):
"""Accepts an coroutine `streaming_fn` which can be used to
write chunks to a streaming response. Returns a `StreamingHTTPResponse`.
Example usage:
```
@app.route("/")
async def index(request):
async def streaming_fn(response):
await response.write('foo')
await response.write('bar')
return stream(streaming_fn, content_type='text/plain')
```
:param streaming_fn: A coroutine accepts a response and
writes content to that response.
:param mime_type: Specific mime_type.
:param headers: Custom Headers.
"""
return StreamingHTTPResponse(
streaming_fn, headers=headers, content_type=content_type, status=200)
def redirect(to, headers=None, status=302, def redirect(to, headers=None, status=302,
content_type="text/html; charset=utf-8"): content_type="text/html; charset=utf-8"):
"""Abort execution and cause a 302 redirect (by default). """Abort execution and cause a 302 redirect (by default).

View File

@ -159,20 +159,63 @@ class HttpProtocol(asyncio.Protocol):
def on_message_complete(self): def on_message_complete(self):
if self.request.body: if self.request.body:
self.request.body = b''.join(self.request.body) self.request.body = b''.join(self.request.body)
self._request_handler_task = self.loop.create_task( self._request_handler_task = self.loop.create_task(
self.request_handler(self.request, self.write_response)) self.request_handler(
self.request,
self.write_response,
self.stream_response))
# -------------------------------------------- # # -------------------------------------------- #
# Responding # Responding
# -------------------------------------------- # # -------------------------------------------- #
def write_response(self, response): def write_response(self, response):
"""
Writes response content synchronously to the transport.
"""
try:
keep_alive = ( keep_alive = (
self.parser.should_keep_alive() and not self.signal.stopped) self.parser.should_keep_alive() and not self.signal.stopped)
try:
self.transport.write( self.transport.write(
response.output( response.output(
self.request.version, keep_alive, self.request_timeout)) self.request.version, keep_alive,
self.request_timeout))
except AttributeError:
log.error(
('Invalid response object for url {}, '
'Expected Type: HTTPResponse, Actual Type: {}').format(
self.url, type(response)))
self.write_error(ServerError('Invalid response type'))
except RuntimeError:
log.error(
'Connection lost before response written @ {}'.format(
self.request.ip))
except Exception as e:
self.bail_out(
"Writing response failed, connection closed {}".format(
repr(e)))
finally:
if not keep_alive:
self.transport.close()
else:
self._last_request_time = current_time
self.cleanup()
async def stream_response(self, response):
"""
Streams a response to the client asynchronously. Attaches
the transport to the response so the response consumer can
write to the response as needed.
"""
try:
keep_alive = (
self.parser.should_keep_alive() and not self.signal.stopped)
response.transport = self.transport
await response.stream(
self.request.version, keep_alive, self.request_timeout)
except AttributeError: except AttributeError:
log.error( log.error(
('Invalid response object for url {}, ' ('Invalid response object for url {}, '
@ -191,7 +234,6 @@ class HttpProtocol(asyncio.Protocol):
if not keep_alive: if not keep_alive:
self.transport.close() self.transport.close()
else: else:
# Record that we received data
self._last_request_time = current_time self._last_request_time = current_time
self.cleanup() self.cleanup()

View File

@ -1,8 +1,12 @@
import asyncio
import pytest
from random import choice from random import choice
from sanic import Sanic from sanic import Sanic
from sanic.response import HTTPResponse from sanic.response import HTTPResponse, stream, StreamingHTTPResponse
from sanic.testing import HOST, PORT
from unittest.mock import MagicMock
def test_response_body_not_a_string(): def test_response_body_not_a_string():
"""Test when a response body sent from the application is not a string""" """Test when a response body sent from the application is not a string"""
@ -15,3 +19,78 @@ def test_response_body_not_a_string():
request, response = app.test_client.get('/hello') request, response = app.test_client.get('/hello')
assert response.text == str(random_num) assert response.text == str(random_num)
async def sample_streaming_fn(response):
response.write('foo,')
await asyncio.sleep(.001)
response.write('bar')
@pytest.fixture
def streaming_app():
app = Sanic('streaming')
@app.route("/")
async def test(request):
return stream(sample_streaming_fn, content_type='text/csv')
return app
def test_streaming_adds_correct_headers(streaming_app):
request, response = streaming_app.test_client.get('/')
assert response.headers['Transfer-Encoding'] == 'chunked'
assert response.headers['Content-Type'] == 'text/csv'
def test_streaming_returns_correct_content(streaming_app):
request, response = streaming_app.test_client.get('/')
assert response.text == 'foo,bar'
@pytest.mark.parametrize('status', [200, 201, 400, 401])
def test_stream_response_status_returns_correct_headers(status):
response = StreamingHTTPResponse(sample_streaming_fn, status=status)
headers = response.get_headers()
assert b"HTTP/1.1 %s" % str(status).encode() in headers
@pytest.mark.parametrize('keep_alive_timeout', [10, 20, 30])
def test_stream_response_keep_alive_returns_correct_headers(
keep_alive_timeout):
response = StreamingHTTPResponse(sample_streaming_fn)
headers = response.get_headers(
keep_alive=True, keep_alive_timeout=keep_alive_timeout)
assert b"Keep-Alive: %s\r\n" % str(keep_alive_timeout).encode() in headers
def test_stream_response_includes_chunked_header():
response = StreamingHTTPResponse(sample_streaming_fn)
headers = response.get_headers()
assert b"Transfer-Encoding: chunked\r\n" in headers
def test_stream_response_writes_correct_content_to_transport(streaming_app):
response = StreamingHTTPResponse(sample_streaming_fn)
response.transport = MagicMock(asyncio.Transport)
@streaming_app.listener('after_server_start')
async def run_stream(app, loop):
await response.stream()
assert response.transport.write.call_args_list[1][0][0] == (
b'4\r\nfoo,\r\n'
)
assert response.transport.write.call_args_list[2][0][0] == (
b'3\r\nbar\r\n'
)
assert response.transport.write.call_args_list[3][0][0] == (
b'0\r\n\r\n'
)
app.stop()
streaming_app.run(host=HOST, port=PORT)