Add Request.cancel_stream()

This commit is contained in:
38elements 2017-06-06 21:07:32 +09:00
parent 639c9f579d
commit e87c29a248
5 changed files with 61 additions and 5 deletions

View File

@ -41,6 +41,18 @@ async def handler(request):
return stream(streaming) return stream(streaming)
@app.post('/cancel_stream', stream=True)
async def cancel_stream(request):
request.cancel_stream()
result = ''
while True:
body = await request.stream.get()
if body is None:
break
result += body.decode('utf-8')
return text(result)
@bp.put('/bp_stream', stream=True) @bp.put('/bp_stream', stream=True)
async def bp_handler(request): async def bp_handler(request):
result = '' result = ''

View File

@ -9,6 +9,18 @@ bp = Blueprint('blueprint_request_stream')
app = Sanic('request_stream') app = Sanic('request_stream')
@app.post('/cancel_stream', stream=True)
async def cancel_stream(request):
request.cancel_stream()
result = ''
while True:
body = await request.stream.get()
if body is None:
break
result += body.decode('utf-8')
return text(result)
class SimpleView(HTTPMethodView): class SimpleView(HTTPMethodView):
@stream_decorator @stream_decorator

View File

@ -45,7 +45,7 @@ class Request(dict):
__slots__ = ( __slots__ = (
'app', 'headers', 'version', 'method', '_cookies', 'transport', 'app', 'headers', 'version', 'method', '_cookies', 'transport',
'body', 'parsed_json', 'parsed_args', 'parsed_form', 'parsed_files', 'body', 'parsed_json', 'parsed_args', 'parsed_form', 'parsed_files',
'_ip', '_parsed_url', 'uri_template', 'stream' '_ip', '_parsed_url', 'uri_template', 'stream', 'streamable'
) )
def __init__(self, url_bytes, headers, version, method, transport): def __init__(self, url_bytes, headers, version, method, transport):
@ -67,6 +67,7 @@ class Request(dict):
self.uri_template = None self.uri_template = None
self._cookies = None self._cookies = None
self.stream = None self.stream = None
self.streamable = True
@property @property
def json(self): def json(self):
@ -195,6 +196,9 @@ class Request(dict):
self.query_string, self.query_string,
None)) None))
def cancel_stream(self):
self.streamable = False
File = namedtuple('File', ['type', 'body', 'name']) File = namedtuple('File', ['type', 'body', 'name'])

View File

@ -69,7 +69,7 @@ class HttpProtocol(asyncio.Protocol):
'has_log', 'has_log',
# connection management # connection management
'_total_request_size', '_timeout_handler', '_last_communication_time', '_total_request_size', '_timeout_handler', '_last_communication_time',
'_is_stream_handler') '_is_stream_handler', '_is_cancel_stream')
def __init__(self, *, loop, request_handler, error_handler, def __init__(self, *, loop, request_handler, error_handler,
signal=Signal(), connections=set(), request_timeout=60, signal=Signal(), connections=set(), request_timeout=60,
@ -93,6 +93,7 @@ class HttpProtocol(asyncio.Protocol):
self.request_class = request_class or Request self.request_class = request_class or Request
self.is_request_stream = is_request_stream self.is_request_stream = is_request_stream
self._is_stream_handler = False self._is_stream_handler = False
self._is_cancel_stream = False
self._total_request_size = 0 self._total_request_size = 0
self._timeout_handler = None self._timeout_handler = None
self._last_request_time = None self._last_request_time = None
@ -185,11 +186,20 @@ class HttpProtocol(asyncio.Protocol):
if self._is_stream_handler: if self._is_stream_handler:
self.request.stream = asyncio.Queue() self.request.stream = asyncio.Queue()
self.execute_request_handler() self.execute_request_handler()
if not self.request.streamable:
self._is_cancel_stream = True
self.transport.close()
def on_body(self, body): def on_body(self, body):
if self.is_request_stream and self._is_stream_handler: if self.is_request_stream and self._is_stream_handler:
self._request_stream_task = self.loop.create_task( if self.request.streamable:
self.request.stream.put(body)) self._request_stream_task = self.loop.create_task(
self.request.stream.put(body))
elif self._is_cancel_stream is True:
return
else:
self._is_cancel_stream = True
self.transport.close()
return return
self.request.body.append(body) self.request.body.append(body)
@ -346,6 +356,7 @@ class HttpProtocol(asyncio.Protocol):
self._request_stream_task = None self._request_stream_task = None
self._total_request_size = 0 self._total_request_size = 0
self._is_stream_handler = False self._is_stream_handler = False
self._is_cancel_stream = False
def close_if_idle(self): def close_if_idle(self):
"""Close the connection if a request is not being sent or received """Close the connection if a request is not being sent or received

View File

@ -6,7 +6,7 @@ from sanic.views import HTTPMethodView
from sanic.views import stream as stream_decorator from sanic.views import stream as stream_decorator
from sanic.response import stream, text from sanic.response import stream, text
data = "abc" * 100000 data = 'abc' * 100000
def test_request_stream_method_view(): def test_request_stream_method_view():
@ -49,6 +49,21 @@ def test_request_stream_app():
app = Sanic('test_request_stream_app') app = Sanic('test_request_stream_app')
contents = []
@app.post('/cancel_stream', stream=True)
async def cancel_stream(request):
assert isinstance(request.stream, asyncio.Queue)
request.cancel_stream()
result = ''
while True:
body = await request.stream.get()
contents.append(body)
if body is None:
break
result += body.decode('utf-8')
return text(result)
@app.get('/get') @app.get('/get')
async def get(request): async def get(request):
assert request.stream is None assert request.stream is None
@ -122,6 +137,8 @@ def test_request_stream_app():
assert app.is_request_stream is True assert app.is_request_stream is True
assert len(contents) == 0
request, response = app.test_client.get('/get') request, response = app.test_client.get('/get')
assert response.status == 200 assert response.status == 200
assert response.text == 'GET' assert response.text == 'GET'