rebase
This commit is contained in:
parent
ff5d4276bc
commit
4e8aac4b41
29
docs/sanic/streaming.md
Normal file
29
docs/sanic/streaming.md
Normal file
|
@ -0,0 +1,29 @@
|
|||
# Streaming
|
||||
|
||||
Sanic allows you to stream content to the client with the `stream` method. This method accepts a coroutine callback which is passed a `StreamingHTTPResponse` object that is written to. A simple example is like follows:
|
||||
|
||||
```python
|
||||
app = Sanic(__name__)
|
||||
|
||||
@app.route("/")
|
||||
async def test(request):
|
||||
async def sample_streaming_fn(response):
|
||||
await response.write('foo,')
|
||||
await response.write('bar')
|
||||
|
||||
return stream(sample_streaming_fn, content_type='text/csv')
|
||||
```
|
||||
|
||||
This is useful in situations where you want to stream content to the client that originates in an external service, like a database. For example, you can stream database records to the client with the asynchronous cursor that `asyncpg` provides:
|
||||
|
||||
```python
|
||||
@app.route("/")
|
||||
async def index(request):
|
||||
async def stream_from_db(response):
|
||||
conn = await asyncpg.connect(database='test')
|
||||
async with conn.transaction():
|
||||
async for record in conn.cursor('SELECT generate_series(0, 10)'):
|
||||
await response.write(record[0])
|
||||
|
||||
return stream(stream_from_db)
|
||||
```
|
|
@ -416,7 +416,7 @@ class Sanic:
|
|||
response = HTTPResponse(
|
||||
"An error occurred while handling an error")
|
||||
|
||||
response_callback(response)
|
||||
await response_callback(response)
|
||||
|
||||
# -------------------------------------------------------------------- #
|
||||
# Testing
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
import asyncio
|
||||
from mimetypes import guess_type
|
||||
from os import path
|
||||
from inspect import isawaitable
|
||||
from ujson import dumps as json_dumps
|
||||
|
||||
from aiofiles import open as open_async
|
||||
|
@ -73,37 +75,16 @@ ALL_STATUS_CODES = {
|
|||
}
|
||||
|
||||
|
||||
class HTTPResponse:
|
||||
__slots__ = ('body', 'status', 'content_type', 'headers', '_cookies')
|
||||
|
||||
def __init__(self, body=None, status=200, headers=None,
|
||||
content_type='text/plain', body_bytes=b''):
|
||||
self.content_type = content_type
|
||||
|
||||
if body is not None:
|
||||
class BaseHTTPResponse:
|
||||
def _encode_body(self, data):
|
||||
try:
|
||||
# Try to encode it regularly
|
||||
self.body = body.encode()
|
||||
return data.encode()
|
||||
except AttributeError:
|
||||
# Convert it to a str if you can't
|
||||
self.body = str(body).encode()
|
||||
else:
|
||||
self.body = body_bytes
|
||||
return str(data).encode()
|
||||
|
||||
self.status = status
|
||||
self.headers = headers or {}
|
||||
self._cookies = None
|
||||
|
||||
def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None):
|
||||
# This is all returned in a kind-of funky way
|
||||
# We tried to make this as fast as possible in pure python
|
||||
timeout_header = b''
|
||||
if keep_alive and keep_alive_timeout is not None:
|
||||
timeout_header = b'Keep-Alive: %d\r\n' % keep_alive_timeout
|
||||
self.headers['Content-Length'] = self.headers.get(
|
||||
'Content-Length', len(self.body))
|
||||
self.headers['Content-Type'] = self.headers.get(
|
||||
'Content-Type', self.content_type)
|
||||
def _parse_headers(self):
|
||||
headers = b''
|
||||
for name, value in self.headers.items():
|
||||
try:
|
||||
|
@ -115,6 +96,112 @@ class HTTPResponse:
|
|||
b'%b: %b\r\n' % (
|
||||
str(name).encode(), str(value).encode('utf-8')))
|
||||
|
||||
return headers
|
||||
|
||||
@property
|
||||
def cookies(self):
|
||||
if self._cookies is None:
|
||||
self._cookies = CookieJar(self.headers)
|
||||
return self._cookies
|
||||
|
||||
|
||||
class StreamingHTTPResponse(BaseHTTPResponse):
|
||||
__slots__ = (
|
||||
'transport', 'streaming_fn',
|
||||
'status', 'content_type', 'headers', '_cookies')
|
||||
|
||||
def __init__(self, streaming_fn, status=200, headers=None,
|
||||
content_type='text/plain', body_bytes=b''):
|
||||
self.content_type = content_type
|
||||
self.streaming_fn = streaming_fn
|
||||
self.status = status
|
||||
self.headers = headers or {}
|
||||
self._cookies = None
|
||||
|
||||
async def write(self, data):
|
||||
"""Writes a chunk of data to the streaming response.
|
||||
|
||||
:param data: bytes-ish data to be written.
|
||||
"""
|
||||
data = self._encode_body(data)
|
||||
self.transport.write(
|
||||
b"%b\r\n%b\r\n" % (str(len(data)).encode(), data))
|
||||
|
||||
async def stream(
|
||||
self, version="1.1", keep_alive=False, keep_alive_timeout=None):
|
||||
"""Streams headers, runs the `streaming_fn` callback that writes content
|
||||
to the response body, then finalizes the response body.
|
||||
"""
|
||||
headers = self.get_headers(
|
||||
version, keep_alive=keep_alive,
|
||||
keep_alive_timeout=keep_alive_timeout)
|
||||
self.transport.write(headers)
|
||||
|
||||
await self.streaming_fn(self)
|
||||
self.transport.write(b'0\r\n\r\n')
|
||||
|
||||
def get_headers(
|
||||
self, version="1.1", keep_alive=False, keep_alive_timeout=None):
|
||||
# This is all returned in a kind-of funky way
|
||||
# We tried to make this as fast as possible in pure python
|
||||
timeout_header = b''
|
||||
if keep_alive and keep_alive_timeout is not None:
|
||||
timeout_header = b'Keep-Alive: %d\r\n' % keep_alive_timeout
|
||||
|
||||
self.headers['Transfer-Encoding'] = 'chunked'
|
||||
self.headers.pop('Content-Length', None)
|
||||
self.headers['Content-Type'] = self.headers.get(
|
||||
'Content-Type', self.content_type)
|
||||
|
||||
headers = self._parse_headers()
|
||||
|
||||
# Try to pull from the common codes first
|
||||
# Speeds up response rate 6% over pulling from all
|
||||
status = COMMON_STATUS_CODES.get(self.status)
|
||||
if not status:
|
||||
status = ALL_STATUS_CODES.get(self.status)
|
||||
|
||||
return (b'HTTP/%b %d %b\r\n'
|
||||
b'%b'
|
||||
b'%b\r\n') % (
|
||||
version.encode(),
|
||||
self.status,
|
||||
status,
|
||||
timeout_header,
|
||||
headers
|
||||
)
|
||||
|
||||
|
||||
class HTTPResponse(BaseHTTPResponse):
|
||||
__slots__ = ('body', 'status', 'content_type', 'headers', '_cookies')
|
||||
|
||||
def __init__(self, body=None, status=200, headers=None,
|
||||
content_type='text/plain', body_bytes=b''):
|
||||
self.content_type = content_type
|
||||
|
||||
if body is not None:
|
||||
self.body = self._encode_body(body)
|
||||
else:
|
||||
self.body = body_bytes
|
||||
|
||||
self.status = status
|
||||
self.headers = headers or {}
|
||||
self._cookies = None
|
||||
|
||||
def output(
|
||||
self, version="1.1", keep_alive=False, keep_alive_timeout=None):
|
||||
# This is all returned in a kind-of funky way
|
||||
# We tried to make this as fast as possible in pure python
|
||||
timeout_header = b''
|
||||
if keep_alive and keep_alive_timeout is not None:
|
||||
timeout_header = b'Keep-Alive: %d\r\n' % keep_alive_timeout
|
||||
self.headers['Content-Length'] = self.headers.get(
|
||||
'Content-Length', len(self.body))
|
||||
self.headers['Content-Type'] = self.headers.get(
|
||||
'Content-Type', self.content_type)
|
||||
|
||||
headers = self._parse_headers()
|
||||
|
||||
# Try to pull from the common codes first
|
||||
# Speeds up response rate 6% over pulling from all
|
||||
status = COMMON_STATUS_CODES.get(self.status)
|
||||
|
@ -164,7 +251,8 @@ def text(body, status=200, headers=None,
|
|||
:param content_type:
|
||||
the content type (string) of the response
|
||||
"""
|
||||
return HTTPResponse(body, status=status, headers=headers,
|
||||
return HTTPResponse(
|
||||
body, status=status, headers=headers,
|
||||
content_type=content_type)
|
||||
|
||||
|
||||
|
@ -220,6 +308,32 @@ async def file(location, mime_type=None, headers=None, _range=None):
|
|||
body_bytes=out_stream)
|
||||
|
||||
|
||||
def stream(
|
||||
streaming_fn, status=200, headers=None,
|
||||
content_type="text/plain; charset=utf-8"):
|
||||
"""Accepts an coroutine `streaming_fn` which can be used to
|
||||
write chunks to a streaming response. Returns a `StreamingHTTPResponse`.
|
||||
Example usage:
|
||||
|
||||
```
|
||||
@app.route("/")
|
||||
async def index(request):
|
||||
async def streaming_fn(response):
|
||||
await response.write('foo')
|
||||
await response.write('bar')
|
||||
|
||||
return stream(streaming_fn, content_type='text/plain')
|
||||
```
|
||||
|
||||
:param streaming_fn: A coroutine accepts a response and
|
||||
writes content to that response.
|
||||
:param mime_type: Specific mime_type.
|
||||
:param headers: Custom Headers.
|
||||
"""
|
||||
return StreamingHTTPResponse(
|
||||
streaming_fn, headers=headers, content_type=content_type, status=200)
|
||||
|
||||
|
||||
def redirect(to, headers=None, status=302,
|
||||
content_type="text/html; charset=utf-8"):
|
||||
"""Abort execution and cause a 302 redirect (by default).
|
||||
|
|
|
@ -21,6 +21,7 @@ except ImportError:
|
|||
|
||||
from sanic.log import log
|
||||
from sanic.request import Request
|
||||
from sanic.response import StreamingHTTPResponse
|
||||
from sanic.exceptions import (
|
||||
RequestTimeout, PayloadTooLarge, InvalidUsage, ServerError)
|
||||
|
||||
|
@ -159,20 +160,29 @@ class HttpProtocol(asyncio.Protocol):
|
|||
def on_message_complete(self):
|
||||
if self.request.body:
|
||||
self.request.body = b''.join(self.request.body)
|
||||
|
||||
self._request_handler_task = self.loop.create_task(
|
||||
self.request_handler(self.request, self.write_response))
|
||||
|
||||
# -------------------------------------------- #
|
||||
# Responding
|
||||
# -------------------------------------------- #
|
||||
|
||||
def write_response(self, response):
|
||||
async def write_response(self, response):
|
||||
try:
|
||||
keep_alive = (
|
||||
self.parser.should_keep_alive() and not self.signal.stopped)
|
||||
try:
|
||||
|
||||
if isinstance(response, StreamingHTTPResponse):
|
||||
# streaming responses should have direct write access to the
|
||||
# transport
|
||||
response.transport = self.transport
|
||||
await response.stream(
|
||||
self.request.version, keep_alive, self.request_timeout)
|
||||
else:
|
||||
self.transport.write(
|
||||
response.output(
|
||||
self.request.version, keep_alive, self.request_timeout))
|
||||
self.request.version, keep_alive,
|
||||
self.request_timeout))
|
||||
except AttributeError:
|
||||
log.error(
|
||||
('Invalid response object for url {}, '
|
||||
|
@ -191,7 +201,6 @@ class HttpProtocol(asyncio.Protocol):
|
|||
if not keep_alive:
|
||||
self.transport.close()
|
||||
else:
|
||||
# Record that we received data
|
||||
self._last_request_time = current_time
|
||||
self.cleanup()
|
||||
|
||||
|
|
|
@ -1,8 +1,12 @@
|
|||
import asyncio
|
||||
import pytest
|
||||
from random import choice
|
||||
|
||||
from sanic import Sanic
|
||||
from sanic.response import HTTPResponse
|
||||
from sanic.response import HTTPResponse, stream, StreamingHTTPResponse
|
||||
from sanic.testing import HOST, PORT
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
def test_response_body_not_a_string():
|
||||
"""Test when a response body sent from the application is not a string"""
|
||||
|
@ -15,3 +19,78 @@ def test_response_body_not_a_string():
|
|||
|
||||
request, response = app.test_client.get('/hello')
|
||||
assert response.text == str(random_num)
|
||||
|
||||
|
||||
async def sample_streaming_fn(response):
|
||||
await response.write('foo,')
|
||||
await asyncio.sleep(.001)
|
||||
await response.write('bar')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def streaming_app():
|
||||
app = Sanic('streaming')
|
||||
|
||||
@app.route("/")
|
||||
async def test(request):
|
||||
return stream(sample_streaming_fn, content_type='text/csv')
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def test_streaming_adds_correct_headers(streaming_app):
|
||||
request, response = streaming_app.test_client.get('/')
|
||||
assert response.headers['Transfer-Encoding'] == 'chunked'
|
||||
assert response.headers['Content-Type'] == 'text/csv'
|
||||
|
||||
|
||||
def test_streaming_returns_correct_content(streaming_app):
|
||||
request, response = streaming_app.test_client.get('/')
|
||||
assert response.text == 'foo,bar'
|
||||
|
||||
|
||||
@pytest.mark.parametrize('status', [200, 201, 400, 401])
|
||||
def test_stream_response_status_returns_correct_headers(status):
|
||||
response = StreamingHTTPResponse(sample_streaming_fn, status=status)
|
||||
headers = response.get_headers()
|
||||
assert b"HTTP/1.1 %s" % str(status).encode() in headers
|
||||
|
||||
|
||||
@pytest.mark.parametrize('keep_alive_timeout', [10, 20, 30])
|
||||
def test_stream_response_keep_alive_returns_correct_headers(
|
||||
keep_alive_timeout):
|
||||
response = StreamingHTTPResponse(sample_streaming_fn)
|
||||
headers = response.get_headers(
|
||||
keep_alive=True, keep_alive_timeout=keep_alive_timeout)
|
||||
|
||||
assert b"Keep-Alive: %s\r\n" % str(keep_alive_timeout).encode() in headers
|
||||
|
||||
|
||||
def test_stream_response_includes_chunked_header():
|
||||
response = StreamingHTTPResponse(sample_streaming_fn)
|
||||
headers = response.get_headers()
|
||||
assert b"Transfer-Encoding: chunked\r\n" in headers
|
||||
|
||||
|
||||
def test_stream_response_writes_correct_content_to_transport(streaming_app):
|
||||
response = StreamingHTTPResponse(sample_streaming_fn)
|
||||
response.transport = MagicMock(asyncio.Transport)
|
||||
|
||||
@streaming_app.listener('after_server_start')
|
||||
async def run_stream(app, loop):
|
||||
await response.stream()
|
||||
assert response.transport.write.call_args_list[1][0][0] == (
|
||||
b'4\r\nfoo,\r\n'
|
||||
)
|
||||
|
||||
assert response.transport.write.call_args_list[2][0][0] == (
|
||||
b'3\r\nbar\r\n'
|
||||
)
|
||||
|
||||
assert response.transport.write.call_args_list[3][0][0] == (
|
||||
b'0\r\n\r\n'
|
||||
)
|
||||
|
||||
app.stop()
|
||||
|
||||
streaming_app.run(host=HOST, port=PORT)
|
||||
|
|
Loading…
Reference in New Issue
Block a user