diff --git a/docs/sanic/streaming.md b/docs/sanic/streaming.md new file mode 100644 index 00000000..c13fffb3 --- /dev/null +++ b/docs/sanic/streaming.md @@ -0,0 +1,32 @@ +# Streaming + +Sanic allows you to stream content to the client with the `stream` method. This method accepts a coroutine callback which is passed a `StreamingHTTPResponse` object that is written to. A simple example follows: + +```python +from sanic import Sanic +from sanic.response import stream + +app = Sanic(__name__) + +@app.route("/") +async def test(request): + async def sample_streaming_fn(response): + response.write('foo,') + response.write('bar') + + return stream(sample_streaming_fn, content_type='text/csv') +``` + +This is useful in situations where you want to stream content to the client that originates in an external service, like a database. For example, you can stream database records to the client with the asynchronous cursor that `asyncpg` provides: + +```python +@app.route("/") +async def index(request): + async def stream_from_db(response): + conn = await asyncpg.connect(database='test') + async with conn.transaction(): + async for record in conn.cursor('SELECT generate_series(0, 10)'): + response.write(record[0]) + + return stream(stream_from_db) +``` \ No newline at end of file diff --git a/sanic/app.py b/sanic/app.py index bb52fdf8..bb7b0efe 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -13,7 +13,7 @@ from sanic.constants import HTTP_METHODS from sanic.exceptions import ServerError, URLBuildError, SanicException from sanic.handlers import ErrorHandler from sanic.log import log -from sanic.response import HTTPResponse +from sanic.response import HTTPResponse, StreamingHTTPResponse from sanic.router import Router from sanic.server import serve, serve_multiple, HttpProtocol from sanic.static import register as static_register @@ -391,14 +391,17 @@ class Sanic: def converted_response_type(self, response): pass - async def handle_request(self, request, 
response_callback): + async def handle_request(self, request, write_callback, stream_callback): """Take a request from the HTTP Server and return a response object to be sent back The HTTP Server only expects a response object, so exception handling must be done here :param request: HTTP Request object - :param response_callback: Response function to be called with the - response as the only argument + :param write_callback: Synchronous response function to be + called with the response as the only argument + :param stream_callback: Coroutine that handles streaming a + StreamingHTTPResponse if produced by the handler. + :return: Nothing """ try: @@ -467,7 +470,11 @@ class Sanic: response = HTTPResponse( "An error occurred while handling an error") - response_callback(response) + # pass the response to the correct callback + if isinstance(response, StreamingHTTPResponse): + await stream_callback(response) + else: + write_callback(response) # -------------------------------------------------------------------- # # Testing diff --git a/sanic/response.py b/sanic/response.py index ad263364..c36c0181 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -73,37 +73,16 @@ ALL_STATUS_CODES = { } -class HTTPResponse: - __slots__ = ('body', 'status', 'content_type', 'headers', '_cookies') +class BaseHTTPResponse: + def _encode_body(self, data): + try: + # Try to encode it regularly + return data.encode() + except AttributeError: + # Convert it to a str if you can't + return str(data).encode() - def __init__(self, body=None, status=200, headers=None, - content_type='text/plain', body_bytes=b''): - self.content_type = content_type - - if body is not None: - try: - # Try to encode it regularly - self.body = body.encode() - except AttributeError: - # Convert it to a str if you can't - self.body = str(body).encode() - else: - self.body = body_bytes - - self.status = status - self.headers = headers or {} - self._cookies = None - - def output(self, version="1.1", 
keep_alive=False, keep_alive_timeout=None): - # This is all returned in a kind-of funky way - # We tried to make this as fast as possible in pure python - timeout_header = b'' - if keep_alive and keep_alive_timeout is not None: - timeout_header = b'Keep-Alive: %d\r\n' % keep_alive_timeout - self.headers['Content-Length'] = self.headers.get( - 'Content-Length', len(self.body)) - self.headers['Content-Type'] = self.headers.get( - 'Content-Type', self.content_type) + def _parse_headers(self): headers = b'' for name, value in self.headers.items(): try: @@ -115,6 +94,114 @@ class HTTPResponse: b'%b: %b\r\n' % ( str(name).encode(), str(value).encode('utf-8'))) + return headers + + @property + def cookies(self): + if self._cookies is None: + self._cookies = CookieJar(self.headers) + return self._cookies + + +class StreamingHTTPResponse(BaseHTTPResponse): + __slots__ = ( + 'transport', 'streaming_fn', + 'status', 'content_type', 'headers', '_cookies') + + def __init__(self, streaming_fn, status=200, headers=None, + content_type='text/plain'): + self.content_type = content_type + self.streaming_fn = streaming_fn + self.status = status + self.headers = headers or {} + self._cookies = None + + def write(self, data): + """Writes a chunk of data to the streaming response. + + :param data: bytes-ish data, sent as one chunk with a hex size prefix. + """ + if not isinstance(data, bytes): + data = self._encode_body(data) + + self.transport.write( + b"%x\r\n%b\r\n" % (len(data), data)) + + async def stream( + self, version="1.1", keep_alive=False, keep_alive_timeout=None): + """Streams headers, runs the `streaming_fn` callback that writes content + to the response body, then finalizes the response body. 
+ """ + headers = self.get_headers( + version, keep_alive=keep_alive, + keep_alive_timeout=keep_alive_timeout) + self.transport.write(headers) + + await self.streaming_fn(self) + self.transport.write(b'0\r\n\r\n') + + def get_headers( + self, version="1.1", keep_alive=False, keep_alive_timeout=None): + # This is all returned in a kind-of funky way + # We tried to make this as fast as possible in pure python + timeout_header = b'' + if keep_alive and keep_alive_timeout is not None: + timeout_header = b'Keep-Alive: %d\r\n' % keep_alive_timeout + + self.headers['Transfer-Encoding'] = 'chunked' + self.headers.pop('Content-Length', None) + self.headers['Content-Type'] = self.headers.get( + 'Content-Type', self.content_type) + + headers = self._parse_headers() + + # Try to pull from the common codes first + # Speeds up response rate 6% over pulling from all + status = COMMON_STATUS_CODES.get(self.status) + if not status: + status = ALL_STATUS_CODES.get(self.status) + + return (b'HTTP/%b %d %b\r\n' + b'%b' + b'%b\r\n') % ( + version.encode(), + self.status, + status, + timeout_header, + headers + ) + + +class HTTPResponse(BaseHTTPResponse): + __slots__ = ('body', 'status', 'content_type', 'headers', '_cookies') + + def __init__(self, body=None, status=200, headers=None, + content_type='text/plain', body_bytes=b''): + self.content_type = content_type + + if body is not None: + self.body = self._encode_body(body) + else: + self.body = body_bytes + + self.status = status + self.headers = headers or {} + self._cookies = None + + def output( + self, version="1.1", keep_alive=False, keep_alive_timeout=None): + # This is all returned in a kind-of funky way + # We tried to make this as fast as possible in pure python + timeout_header = b'' + if keep_alive and keep_alive_timeout is not None: + timeout_header = b'Keep-Alive: %d\r\n' % keep_alive_timeout + self.headers['Content-Length'] = self.headers.get( + 'Content-Length', len(self.body)) + self.headers['Content-Type'] = 
self.headers.get( + 'Content-Type', self.content_type) + + headers = self._parse_headers() + # Try to pull from the common codes first # Speeds up response rate 6% over pulling from all status = COMMON_STATUS_CODES.get(self.status) @@ -164,8 +251,9 @@ def text(body, status=200, headers=None, :param content_type: the content type (string) of the response """ - return HTTPResponse(body, status=status, headers=headers, - content_type=content_type) + return HTTPResponse( + body, status=status, headers=headers, + content_type=content_type) def raw(body, status=200, headers=None, @@ -220,6 +308,32 @@ async def file(location, mime_type=None, headers=None, _range=None): body_bytes=out_stream) +def stream( + streaming_fn, status=200, headers=None, + content_type="text/plain; charset=utf-8"): + """Accepts a coroutine `streaming_fn` which can be used to + write chunks to a streaming response. Returns a `StreamingHTTPResponse`. + Example usage: + + ``` + @app.route("/") + async def index(request): + async def streaming_fn(response): + response.write('foo') + response.write('bar') + + return stream(streaming_fn, content_type='text/plain') + ``` + + :param streaming_fn: A coroutine that accepts a response and + writes content to that response. + :param content_type: Content type of the response. + :param headers: Custom Headers. + """ + return StreamingHTTPResponse(streaming_fn, status=status, + headers=headers, content_type=content_type) + + def redirect(to, headers=None, status=302, content_type="text/html; charset=utf-8"): """Abort execution and cause a 302 redirect (by default). 
diff --git a/sanic/server.py b/sanic/server.py index 28ce0848..0c4bb4b8 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -159,20 +159,63 @@ class HttpProtocol(asyncio.Protocol): def on_message_complete(self): if self.request.body: self.request.body = b''.join(self.request.body) + self._request_handler_task = self.loop.create_task( - self.request_handler(self.request, self.write_response)) + self.request_handler( + self.request, + self.write_response, + self.stream_response)) # -------------------------------------------- # # Responding # -------------------------------------------- # - def write_response(self, response): - keep_alive = ( - self.parser.should_keep_alive() and not self.signal.stopped) + """ + Writes response content synchronously to the transport. + """ try: + keep_alive = ( + self.parser.should_keep_alive() and not self.signal.stopped) + self.transport.write( response.output( - self.request.version, keep_alive, self.request_timeout)) + self.request.version, keep_alive, + self.request_timeout)) + except AttributeError: + log.error( + ('Invalid response object for url {}, ' + 'Expected Type: HTTPResponse, Actual Type: {}').format( + self.url, type(response))) + self.write_error(ServerError('Invalid response type')) + except RuntimeError: + log.error( + 'Connection lost before response written @ {}'.format( + self.request.ip)) + except Exception as e: + self.bail_out( + "Writing response failed, connection closed {}".format( + repr(e))) + finally: + if not keep_alive: + self.transport.close() + else: + self._last_request_time = current_time + self.cleanup() + + async def stream_response(self, response): + """ + Streams a response to the client asynchronously. Attaches + the transport to the response so the response consumer can + write to the response as needed. 
+ """ + + try: + keep_alive = ( + self.parser.should_keep_alive() and not self.signal.stopped) + + response.transport = self.transport + await response.stream( + self.request.version, keep_alive, self.request_timeout) except AttributeError: log.error( ('Invalid response object for url {}, ' @@ -191,7 +234,6 @@ class HttpProtocol(asyncio.Protocol): if not keep_alive: self.transport.close() else: - # Record that we received data self._last_request_time = current_time self.cleanup() diff --git a/tests/test_response.py b/tests/test_response.py index 9639e076..ff5fd42b 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -1,8 +1,12 @@ +import asyncio +import pytest from random import choice from sanic import Sanic -from sanic.response import HTTPResponse +from sanic.response import HTTPResponse, stream, StreamingHTTPResponse +from sanic.testing import HOST, PORT +from unittest.mock import MagicMock def test_response_body_not_a_string(): """Test when a response body sent from the application is not a string""" @@ -15,3 +19,78 @@ def test_response_body_not_a_string(): request, response = app.test_client.get('/hello') assert response.text == str(random_num) + + +async def sample_streaming_fn(response): + response.write('foo,') + await asyncio.sleep(.001) + response.write('bar') + + +@pytest.fixture +def streaming_app(): + app = Sanic('streaming') + + @app.route("/") + async def test(request): + return stream(sample_streaming_fn, content_type='text/csv') + + return app + + +def test_streaming_adds_correct_headers(streaming_app): + request, response = streaming_app.test_client.get('/') + assert response.headers['Transfer-Encoding'] == 'chunked' + assert response.headers['Content-Type'] == 'text/csv' + + +def test_streaming_returns_correct_content(streaming_app): + request, response = streaming_app.test_client.get('/') + assert response.text == 'foo,bar' + + +@pytest.mark.parametrize('status', [200, 201, 400, 401]) +def 
test_stream_response_status_returns_correct_headers(status): + response = StreamingHTTPResponse(sample_streaming_fn, status=status) + headers = response.get_headers() + assert b"HTTP/1.1 %s" % str(status).encode() in headers + + +@pytest.mark.parametrize('keep_alive_timeout', [10, 20, 30]) +def test_stream_response_keep_alive_returns_correct_headers( + keep_alive_timeout): + response = StreamingHTTPResponse(sample_streaming_fn) + headers = response.get_headers( + keep_alive=True, keep_alive_timeout=keep_alive_timeout) + + assert b"Keep-Alive: %s\r\n" % str(keep_alive_timeout).encode() in headers + + +def test_stream_response_includes_chunked_header(): + response = StreamingHTTPResponse(sample_streaming_fn) + headers = response.get_headers() + assert b"Transfer-Encoding: chunked\r\n" in headers + + +def test_stream_response_writes_correct_content_to_transport(streaming_app): + response = StreamingHTTPResponse(sample_streaming_fn) + response.transport = MagicMock(asyncio.Transport) + + @streaming_app.listener('after_server_start') + async def run_stream(app, loop): + await response.stream() + assert response.transport.write.call_args_list[1][0][0] == ( + b'4\r\nfoo,\r\n' + ) + + assert response.transport.write.call_args_list[2][0][0] == ( + b'3\r\nbar\r\n' + ) + + assert response.transport.write.call_args_list[3][0][0] == ( + b'0\r\n\r\n' + ) + + app.stop() + + streaming_app.run(host=HOST, port=PORT)