diff --git a/docs/sanic/streaming.md b/docs/sanic/streaming.md index a785322a..bf3ca664 100644 --- a/docs/sanic/streaming.md +++ b/docs/sanic/streaming.md @@ -37,7 +37,7 @@ async def handler(request): if body is None: break body = body.decode('utf-8').replace('1', 'A') - response.write(body) + await response.write(body) return stream(streaming) @@ -85,8 +85,8 @@ app = Sanic(__name__) @app.route("/") async def test(request): async def sample_streaming_fn(response): - response.write('foo,') - response.write('bar') + await response.write('foo,') + await response.write('bar') return stream(sample_streaming_fn, content_type='text/csv') ``` @@ -100,7 +100,7 @@ async def index(request): conn = await asyncpg.connect(database='test') async with conn.transaction(): async for record in conn.cursor('SELECT generate_series(0, 10)'): - response.write(record[0]) + await response.write(record[0]) return stream(stream_from_db) ``` diff --git a/examples/request_stream/server.py b/examples/request_stream/server.py index e53a224c..d3d35aef 100644 --- a/examples/request_stream/server.py +++ b/examples/request_stream/server.py @@ -30,7 +30,7 @@ async def handler(request): if body is None: break body = body.decode('utf-8').replace('1', 'A') - response.write(body) + await response.write(body) return stream(streaming) diff --git a/sanic/response.py b/sanic/response.py index 2d2f5b96..f169b4f2 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -46,7 +46,7 @@ class BaseHTTPResponse: class StreamingHTTPResponse(BaseHTTPResponse): __slots__ = ( - 'transport', 'streaming_fn', 'status', + 'protocol', 'streaming_fn', 'status', 'content_type', 'headers', '_cookies' ) @@ -58,7 +58,7 @@ class StreamingHTTPResponse(BaseHTTPResponse): self.headers = CIMultiDict(headers or {}) self._cookies = None - def write(self, data): + async def write(self, data): """Writes a chunk of data to the streaming response. :param data: bytes-ish data to be written. @@ -66,8 +66,9 @@ class StreamingHTTPResponse(BaseHTTPResponse): if type(data) != bytes: data = self._encode_body(data) - self.transport.write( + self.protocol.push_data( b"%x\r\n%b\r\n" % (len(data), data)) + await self.protocol.drain() async def stream( self, version="1.1", keep_alive=False, keep_alive_timeout=None): @@ -77,10 +78,12 @@ class StreamingHTTPResponse(BaseHTTPResponse): headers = self.get_headers( version, keep_alive=keep_alive, keep_alive_timeout=keep_alive_timeout) - self.transport.write(headers) - + self.protocol.push_data(headers) + await self.protocol.drain() await self.streaming_fn(self) - self.transport.write(b'0\r\n\r\n') + self.protocol.push_data(b'0\r\n\r\n') + # no need to await drain here after this write, because it is the + # very last thing we write and nothing needs to wait for it. def get_headers( self, version="1.1", keep_alive=False, keep_alive_timeout=None): @@ -298,13 +301,13 @@ async def file_stream(location, status=200, chunk_size=4096, mime_type=None, if len(content) < 1: break to_send -= len(content) - response.write(content) + await response.write(content) else: while True: content = await _file.read(chunk_size) if len(content) < 1: break - response.write(content) + await response.write(content) finally: await _file.close() return # Returning from this fn closes the stream diff --git a/sanic/server.py b/sanic/server.py index 11e54edc..d5a8f211 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -55,7 +55,8 @@ class HttpProtocol(asyncio.Protocol): # connection management '_total_request_size', '_request_timeout_handler', '_response_timeout_handler', '_keep_alive_timeout_handler', - '_last_request_time', '_last_response_time', '_is_stream_handler') + '_last_request_time', '_last_response_time', '_is_stream_handler', + '_not_paused') def __init__(self, *, loop, request_handler, error_handler, signal=Signal(), connections=set(), request_timeout=60, @@ -82,6 +83,7 @@ class HttpProtocol(asyncio.Protocol): self.request_class = request_class or Request self.is_request_stream = is_request_stream self._is_stream_handler = False + self._not_paused = asyncio.Event(loop=loop) self._total_request_size = 0 self._request_timeout_handler = None self._response_timeout_handler = None @@ -96,6 +98,7 @@ class HttpProtocol(asyncio.Protocol): if 'requests_count' not in self.state: self.state['requests_count'] = 0 self._debug = debug + self._not_paused.set() @property def keep_alive(self): @@ -124,6 +127,12 @@ class HttpProtocol(asyncio.Protocol): if self._keep_alive_timeout_handler: self._keep_alive_timeout_handler.cancel() + def pause_writing(self): + self._not_paused.clear() + + def resume_writing(self): + self._not_paused.set() + def request_timeout_callback(self): # See the docstring in the RequestTimeout exception, to see # exactly what this timeout is checking for. @@ -351,6 +360,12 @@ class HttpProtocol(asyncio.Protocol): self._last_response_time = current_time self.cleanup() + async def drain(self): + await self._not_paused.wait() + + def push_data(self, data): + self.transport.write(data) + async def stream_response(self, response): """ Streams a response to the client asynchronously. Attaches @@ -360,9 +375,10 @@ class HttpProtocol(asyncio.Protocol): if self._response_timeout_handler: self._response_timeout_handler.cancel() self._response_timeout_handler = None + try: keep_alive = self.keep_alive - response.transport = self.transport + response.protocol = self await response.stream( self.request.version, keep_alive, self.keep_alive_timeout) self.log_response(response) diff --git a/tests/test_request_stream.py b/tests/test_request_stream.py index 4ca4e44e..b14aa519 100644 --- a/tests/test_request_stream.py +++ b/tests/test_request_stream.py @@ -83,7 +83,7 @@ def test_request_stream_app(): body = await request.stream.get() if body is None: break - response.write(body.decode('utf-8')) + await response.write(body.decode('utf-8')) return stream(streaming) @app.put('/_put') @@ -100,7 +100,7 @@ def test_request_stream_app(): body = await request.stream.get() if body is None: break - response.write(body.decode('utf-8')) + await response.write(body.decode('utf-8')) return stream(streaming) @app.patch('/_patch') @@ -117,7 +117,7 @@ def test_request_stream_app(): body = await request.stream.get() if body is None: break - response.write(body.decode('utf-8')) + await response.write(body.decode('utf-8')) return stream(streaming) assert app.is_request_stream is True @@ -177,7 +177,7 @@ def test_request_stream_handle_exception(): body = await request.stream.get() if body is None: break - response.write(body.decode('utf-8')) + await response.write(body.decode('utf-8')) return stream(streaming) # 404 @@ -231,7 +231,7 @@ def test_request_stream_blueprint(): body = await request.stream.get() if body is None: break - response.write(body.decode('utf-8')) + await response.write(body.decode('utf-8')) return stream(streaming) @bp.put('/_put') @@ -248,7 +248,7 @@ def test_request_stream_blueprint(): body = await request.stream.get() if body is None: break - response.write(body.decode('utf-8')) + await response.write(body.decode('utf-8')) return stream(streaming) @bp.patch('/_patch') @@ -265,7 +265,7 @@ def test_request_stream_blueprint(): body = await request.stream.get() if body is None: break - response.write(body.decode('utf-8')) + await response.write(body.decode('utf-8')) return stream(streaming) app.blueprint(bp) @@ -380,7 +380,7 @@ def test_request_stream(): body = await request.stream.get() if body is None: break - response.write(body.decode('utf-8')) + await response.write(body.decode('utf-8')) return stream(streaming) @app.get('/get') diff --git a/tests/test_response.py b/tests/test_response.py index 96c06ce6..6dcd2ea6 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -10,6 +10,7 @@ from random import choice from sanic import Sanic from sanic.response import HTTPResponse, stream, StreamingHTTPResponse, file, file_stream, json +from sanic.server import HttpProtocol from sanic.testing import HOST, PORT from unittest.mock import MagicMock @@ -30,9 +31,10 @@ def test_response_body_not_a_string(): async def sample_streaming_fn(response): - response.write('foo,') + await response.write('foo,') await asyncio.sleep(.001) - response.write('bar') + await response.write('bar') + def test_method_not_allowed(): @@ -189,20 +191,30 @@ def test_stream_response_includes_chunked_header(): def test_stream_response_writes_correct_content_to_transport(streaming_app): response = StreamingHTTPResponse(sample_streaming_fn) - response.transport = MagicMock(asyncio.Transport) + response.protocol = MagicMock(HttpProtocol) + response.protocol.transport = MagicMock(asyncio.Transport) + + async def mock_drain(): + pass + + def mock_push_data(data): + response.protocol.transport.write(data) + + response.protocol.push_data = mock_push_data + response.protocol.drain = mock_drain @streaming_app.listener('after_server_start') async def run_stream(app, loop): await response.stream() - assert response.transport.write.call_args_list[1][0][0] == ( + assert response.protocol.transport.write.call_args_list[1][0][0] == ( b'4\r\nfoo,\r\n' ) - assert response.transport.write.call_args_list[2][0][0] == ( + assert response.protocol.transport.write.call_args_list[2][0][0] == ( b'3\r\nbar\r\n' ) - assert response.transport.write.call_args_list[3][0][0] == ( + assert response.protocol.transport.write.call_args_list[3][0][0] == ( b'0\r\n\r\n' )