Pausable response streams (#1179)
* This commit adds handlers for the asyncio/uvloop protocol callbacks for pause_writing and resume_writing. These are needed for the correct functioning of built-in tcp flow-control provided by uvloop and asyncio. This is somewhat of a breaking change, because the `write` function in user streaming callbacks now must be `await`ed. This is necessary because it is possible now that the http protocol may be paused, and any calls to write may need to wait on an async event to be called to become unpaused. Updated examples and tests to reflect this change. This change does not apply to websocket connections. A change to websocket connections may be required to match this change. * Fix a couple of PEP8 errors caused by previous rebase. * update docs add await syntax to response.write in response-streaming docs. * remove commented out code from a test file
This commit is contained in:
parent
a87934d434
commit
30e6a310f1
|
@ -37,7 +37,7 @@ async def handler(request):
|
||||||
if body is None:
|
if body is None:
|
||||||
break
|
break
|
||||||
body = body.decode('utf-8').replace('1', 'A')
|
body = body.decode('utf-8').replace('1', 'A')
|
||||||
response.write(body)
|
await response.write(body)
|
||||||
return stream(streaming)
|
return stream(streaming)
|
||||||
|
|
||||||
|
|
||||||
|
@ -85,8 +85,8 @@ app = Sanic(__name__)
|
||||||
@app.route("/")
|
@app.route("/")
|
||||||
async def test(request):
|
async def test(request):
|
||||||
async def sample_streaming_fn(response):
|
async def sample_streaming_fn(response):
|
||||||
response.write('foo,')
|
await response.write('foo,')
|
||||||
response.write('bar')
|
await response.write('bar')
|
||||||
|
|
||||||
return stream(sample_streaming_fn, content_type='text/csv')
|
return stream(sample_streaming_fn, content_type='text/csv')
|
||||||
```
|
```
|
||||||
|
@ -100,7 +100,7 @@ async def index(request):
|
||||||
conn = await asyncpg.connect(database='test')
|
conn = await asyncpg.connect(database='test')
|
||||||
async with conn.transaction():
|
async with conn.transaction():
|
||||||
async for record in conn.cursor('SELECT generate_series(0, 10)'):
|
async for record in conn.cursor('SELECT generate_series(0, 10)'):
|
||||||
response.write(record[0])
|
await response.write(record[0])
|
||||||
|
|
||||||
return stream(stream_from_db)
|
return stream(stream_from_db)
|
||||||
```
|
```
|
||||||
|
|
|
@ -30,7 +30,7 @@ async def handler(request):
|
||||||
if body is None:
|
if body is None:
|
||||||
break
|
break
|
||||||
body = body.decode('utf-8').replace('1', 'A')
|
body = body.decode('utf-8').replace('1', 'A')
|
||||||
response.write(body)
|
await response.write(body)
|
||||||
return stream(streaming)
|
return stream(streaming)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -46,7 +46,7 @@ class BaseHTTPResponse:
|
||||||
|
|
||||||
class StreamingHTTPResponse(BaseHTTPResponse):
|
class StreamingHTTPResponse(BaseHTTPResponse):
|
||||||
__slots__ = (
|
__slots__ = (
|
||||||
'transport', 'streaming_fn', 'status',
|
'protocol', 'streaming_fn', 'status',
|
||||||
'content_type', 'headers', '_cookies'
|
'content_type', 'headers', '_cookies'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -58,7 +58,7 @@ class StreamingHTTPResponse(BaseHTTPResponse):
|
||||||
self.headers = CIMultiDict(headers or {})
|
self.headers = CIMultiDict(headers or {})
|
||||||
self._cookies = None
|
self._cookies = None
|
||||||
|
|
||||||
def write(self, data):
|
async def write(self, data):
|
||||||
"""Writes a chunk of data to the streaming response.
|
"""Writes a chunk of data to the streaming response.
|
||||||
|
|
||||||
:param data: bytes-ish data to be written.
|
:param data: bytes-ish data to be written.
|
||||||
|
@ -66,8 +66,9 @@ class StreamingHTTPResponse(BaseHTTPResponse):
|
||||||
if type(data) != bytes:
|
if type(data) != bytes:
|
||||||
data = self._encode_body(data)
|
data = self._encode_body(data)
|
||||||
|
|
||||||
self.transport.write(
|
self.protocol.push_data(
|
||||||
b"%x\r\n%b\r\n" % (len(data), data))
|
b"%x\r\n%b\r\n" % (len(data), data))
|
||||||
|
await self.protocol.drain()
|
||||||
|
|
||||||
async def stream(
|
async def stream(
|
||||||
self, version="1.1", keep_alive=False, keep_alive_timeout=None):
|
self, version="1.1", keep_alive=False, keep_alive_timeout=None):
|
||||||
|
@ -77,10 +78,12 @@ class StreamingHTTPResponse(BaseHTTPResponse):
|
||||||
headers = self.get_headers(
|
headers = self.get_headers(
|
||||||
version, keep_alive=keep_alive,
|
version, keep_alive=keep_alive,
|
||||||
keep_alive_timeout=keep_alive_timeout)
|
keep_alive_timeout=keep_alive_timeout)
|
||||||
self.transport.write(headers)
|
self.protocol.push_data(headers)
|
||||||
|
await self.protocol.drain()
|
||||||
await self.streaming_fn(self)
|
await self.streaming_fn(self)
|
||||||
self.transport.write(b'0\r\n\r\n')
|
self.protocol.push_data(b'0\r\n\r\n')
|
||||||
|
# no need to await drain here after this write, because it is the
|
||||||
|
# very last thing we write and nothing needs to wait for it.
|
||||||
|
|
||||||
def get_headers(
|
def get_headers(
|
||||||
self, version="1.1", keep_alive=False, keep_alive_timeout=None):
|
self, version="1.1", keep_alive=False, keep_alive_timeout=None):
|
||||||
|
@ -298,13 +301,13 @@ async def file_stream(location, status=200, chunk_size=4096, mime_type=None,
|
||||||
if len(content) < 1:
|
if len(content) < 1:
|
||||||
break
|
break
|
||||||
to_send -= len(content)
|
to_send -= len(content)
|
||||||
response.write(content)
|
await response.write(content)
|
||||||
else:
|
else:
|
||||||
while True:
|
while True:
|
||||||
content = await _file.read(chunk_size)
|
content = await _file.read(chunk_size)
|
||||||
if len(content) < 1:
|
if len(content) < 1:
|
||||||
break
|
break
|
||||||
response.write(content)
|
await response.write(content)
|
||||||
finally:
|
finally:
|
||||||
await _file.close()
|
await _file.close()
|
||||||
return # Returning from this fn closes the stream
|
return # Returning from this fn closes the stream
|
||||||
|
|
|
@ -55,7 +55,8 @@ class HttpProtocol(asyncio.Protocol):
|
||||||
# connection management
|
# connection management
|
||||||
'_total_request_size', '_request_timeout_handler',
|
'_total_request_size', '_request_timeout_handler',
|
||||||
'_response_timeout_handler', '_keep_alive_timeout_handler',
|
'_response_timeout_handler', '_keep_alive_timeout_handler',
|
||||||
'_last_request_time', '_last_response_time', '_is_stream_handler')
|
'_last_request_time', '_last_response_time', '_is_stream_handler',
|
||||||
|
'_not_paused')
|
||||||
|
|
||||||
def __init__(self, *, loop, request_handler, error_handler,
|
def __init__(self, *, loop, request_handler, error_handler,
|
||||||
signal=Signal(), connections=set(), request_timeout=60,
|
signal=Signal(), connections=set(), request_timeout=60,
|
||||||
|
@ -82,6 +83,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||||
self.request_class = request_class or Request
|
self.request_class = request_class or Request
|
||||||
self.is_request_stream = is_request_stream
|
self.is_request_stream = is_request_stream
|
||||||
self._is_stream_handler = False
|
self._is_stream_handler = False
|
||||||
|
self._not_paused = asyncio.Event(loop=loop)
|
||||||
self._total_request_size = 0
|
self._total_request_size = 0
|
||||||
self._request_timeout_handler = None
|
self._request_timeout_handler = None
|
||||||
self._response_timeout_handler = None
|
self._response_timeout_handler = None
|
||||||
|
@ -96,6 +98,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||||
if 'requests_count' not in self.state:
|
if 'requests_count' not in self.state:
|
||||||
self.state['requests_count'] = 0
|
self.state['requests_count'] = 0
|
||||||
self._debug = debug
|
self._debug = debug
|
||||||
|
self._not_paused.set()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def keep_alive(self):
|
def keep_alive(self):
|
||||||
|
@ -124,6 +127,12 @@ class HttpProtocol(asyncio.Protocol):
|
||||||
if self._keep_alive_timeout_handler:
|
if self._keep_alive_timeout_handler:
|
||||||
self._keep_alive_timeout_handler.cancel()
|
self._keep_alive_timeout_handler.cancel()
|
||||||
|
|
||||||
|
def pause_writing(self):
|
||||||
|
self._not_paused.clear()
|
||||||
|
|
||||||
|
def resume_writing(self):
|
||||||
|
self._not_paused.set()
|
||||||
|
|
||||||
def request_timeout_callback(self):
|
def request_timeout_callback(self):
|
||||||
# See the docstring in the RequestTimeout exception, to see
|
# See the docstring in the RequestTimeout exception, to see
|
||||||
# exactly what this timeout is checking for.
|
# exactly what this timeout is checking for.
|
||||||
|
@ -351,6 +360,12 @@ class HttpProtocol(asyncio.Protocol):
|
||||||
self._last_response_time = current_time
|
self._last_response_time = current_time
|
||||||
self.cleanup()
|
self.cleanup()
|
||||||
|
|
||||||
|
async def drain(self):
|
||||||
|
await self._not_paused.wait()
|
||||||
|
|
||||||
|
def push_data(self, data):
|
||||||
|
self.transport.write(data)
|
||||||
|
|
||||||
async def stream_response(self, response):
|
async def stream_response(self, response):
|
||||||
"""
|
"""
|
||||||
Streams a response to the client asynchronously. Attaches
|
Streams a response to the client asynchronously. Attaches
|
||||||
|
@ -360,9 +375,10 @@ class HttpProtocol(asyncio.Protocol):
|
||||||
if self._response_timeout_handler:
|
if self._response_timeout_handler:
|
||||||
self._response_timeout_handler.cancel()
|
self._response_timeout_handler.cancel()
|
||||||
self._response_timeout_handler = None
|
self._response_timeout_handler = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
keep_alive = self.keep_alive
|
keep_alive = self.keep_alive
|
||||||
response.transport = self.transport
|
response.protocol = self
|
||||||
await response.stream(
|
await response.stream(
|
||||||
self.request.version, keep_alive, self.keep_alive_timeout)
|
self.request.version, keep_alive, self.keep_alive_timeout)
|
||||||
self.log_response(response)
|
self.log_response(response)
|
||||||
|
|
|
@ -83,7 +83,7 @@ def test_request_stream_app():
|
||||||
body = await request.stream.get()
|
body = await request.stream.get()
|
||||||
if body is None:
|
if body is None:
|
||||||
break
|
break
|
||||||
response.write(body.decode('utf-8'))
|
await response.write(body.decode('utf-8'))
|
||||||
return stream(streaming)
|
return stream(streaming)
|
||||||
|
|
||||||
@app.put('/_put')
|
@app.put('/_put')
|
||||||
|
@ -100,7 +100,7 @@ def test_request_stream_app():
|
||||||
body = await request.stream.get()
|
body = await request.stream.get()
|
||||||
if body is None:
|
if body is None:
|
||||||
break
|
break
|
||||||
response.write(body.decode('utf-8'))
|
await response.write(body.decode('utf-8'))
|
||||||
return stream(streaming)
|
return stream(streaming)
|
||||||
|
|
||||||
@app.patch('/_patch')
|
@app.patch('/_patch')
|
||||||
|
@ -117,7 +117,7 @@ def test_request_stream_app():
|
||||||
body = await request.stream.get()
|
body = await request.stream.get()
|
||||||
if body is None:
|
if body is None:
|
||||||
break
|
break
|
||||||
response.write(body.decode('utf-8'))
|
await response.write(body.decode('utf-8'))
|
||||||
return stream(streaming)
|
return stream(streaming)
|
||||||
|
|
||||||
assert app.is_request_stream is True
|
assert app.is_request_stream is True
|
||||||
|
@ -177,7 +177,7 @@ def test_request_stream_handle_exception():
|
||||||
body = await request.stream.get()
|
body = await request.stream.get()
|
||||||
if body is None:
|
if body is None:
|
||||||
break
|
break
|
||||||
response.write(body.decode('utf-8'))
|
await response.write(body.decode('utf-8'))
|
||||||
return stream(streaming)
|
return stream(streaming)
|
||||||
|
|
||||||
# 404
|
# 404
|
||||||
|
@ -231,7 +231,7 @@ def test_request_stream_blueprint():
|
||||||
body = await request.stream.get()
|
body = await request.stream.get()
|
||||||
if body is None:
|
if body is None:
|
||||||
break
|
break
|
||||||
response.write(body.decode('utf-8'))
|
await response.write(body.decode('utf-8'))
|
||||||
return stream(streaming)
|
return stream(streaming)
|
||||||
|
|
||||||
@bp.put('/_put')
|
@bp.put('/_put')
|
||||||
|
@ -248,7 +248,7 @@ def test_request_stream_blueprint():
|
||||||
body = await request.stream.get()
|
body = await request.stream.get()
|
||||||
if body is None:
|
if body is None:
|
||||||
break
|
break
|
||||||
response.write(body.decode('utf-8'))
|
await response.write(body.decode('utf-8'))
|
||||||
return stream(streaming)
|
return stream(streaming)
|
||||||
|
|
||||||
@bp.patch('/_patch')
|
@bp.patch('/_patch')
|
||||||
|
@ -265,7 +265,7 @@ def test_request_stream_blueprint():
|
||||||
body = await request.stream.get()
|
body = await request.stream.get()
|
||||||
if body is None:
|
if body is None:
|
||||||
break
|
break
|
||||||
response.write(body.decode('utf-8'))
|
await response.write(body.decode('utf-8'))
|
||||||
return stream(streaming)
|
return stream(streaming)
|
||||||
|
|
||||||
app.blueprint(bp)
|
app.blueprint(bp)
|
||||||
|
@ -380,7 +380,7 @@ def test_request_stream():
|
||||||
body = await request.stream.get()
|
body = await request.stream.get()
|
||||||
if body is None:
|
if body is None:
|
||||||
break
|
break
|
||||||
response.write(body.decode('utf-8'))
|
await response.write(body.decode('utf-8'))
|
||||||
return stream(streaming)
|
return stream(streaming)
|
||||||
|
|
||||||
@app.get('/get')
|
@app.get('/get')
|
||||||
|
|
|
@ -10,6 +10,7 @@ from random import choice
|
||||||
|
|
||||||
from sanic import Sanic
|
from sanic import Sanic
|
||||||
from sanic.response import HTTPResponse, stream, StreamingHTTPResponse, file, file_stream, json
|
from sanic.response import HTTPResponse, stream, StreamingHTTPResponse, file, file_stream, json
|
||||||
|
from sanic.server import HttpProtocol
|
||||||
from sanic.testing import HOST, PORT
|
from sanic.testing import HOST, PORT
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
@ -30,9 +31,10 @@ def test_response_body_not_a_string():
|
||||||
|
|
||||||
|
|
||||||
async def sample_streaming_fn(response):
|
async def sample_streaming_fn(response):
|
||||||
response.write('foo,')
|
await response.write('foo,')
|
||||||
await asyncio.sleep(.001)
|
await asyncio.sleep(.001)
|
||||||
response.write('bar')
|
await response.write('bar')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_method_not_allowed():
|
def test_method_not_allowed():
|
||||||
|
@ -189,20 +191,30 @@ def test_stream_response_includes_chunked_header():
|
||||||
|
|
||||||
def test_stream_response_writes_correct_content_to_transport(streaming_app):
|
def test_stream_response_writes_correct_content_to_transport(streaming_app):
|
||||||
response = StreamingHTTPResponse(sample_streaming_fn)
|
response = StreamingHTTPResponse(sample_streaming_fn)
|
||||||
response.transport = MagicMock(asyncio.Transport)
|
response.protocol = MagicMock(HttpProtocol)
|
||||||
|
response.protocol.transport = MagicMock(asyncio.Transport)
|
||||||
|
|
||||||
|
async def mock_drain():
|
||||||
|
pass
|
||||||
|
|
||||||
|
def mock_push_data(data):
|
||||||
|
response.protocol.transport.write(data)
|
||||||
|
|
||||||
|
response.protocol.push_data = mock_push_data
|
||||||
|
response.protocol.drain = mock_drain
|
||||||
|
|
||||||
@streaming_app.listener('after_server_start')
|
@streaming_app.listener('after_server_start')
|
||||||
async def run_stream(app, loop):
|
async def run_stream(app, loop):
|
||||||
await response.stream()
|
await response.stream()
|
||||||
assert response.transport.write.call_args_list[1][0][0] == (
|
assert response.protocol.transport.write.call_args_list[1][0][0] == (
|
||||||
b'4\r\nfoo,\r\n'
|
b'4\r\nfoo,\r\n'
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.transport.write.call_args_list[2][0][0] == (
|
assert response.protocol.transport.write.call_args_list[2][0][0] == (
|
||||||
b'3\r\nbar\r\n'
|
b'3\r\nbar\r\n'
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.transport.write.call_args_list[3][0][0] == (
|
assert response.protocol.transport.write.call_args_list[3][0][0] == (
|
||||||
b'0\r\n\r\n'
|
b'0\r\n\r\n'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user