feat: fixes exception due to unread bytes in stream (#1897)

* feat: fixes exception due to unread bytes in stream

* feat: additonal unit tests to cover changes

* fix: automated changes by `make fix-import`

* fix: additonal changes by `make fix-import`

Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
This commit is contained in:
Andrew Scott 2020-08-27 00:22:02 -07:00 committed by GitHub
parent 58e4087d4b
commit 3f7c9ea3f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 40 additions and 17 deletions

View File

@ -249,7 +249,10 @@ def raw(
:param content_type: the content type (string) of the response. :param content_type: the content type (string) of the response.
""" """
return HTTPResponse( return HTTPResponse(
body=body, status=status, headers=headers, content_type=content_type, body=body,
status=status,
headers=headers,
content_type=content_type,
) )

View File

@ -418,6 +418,7 @@ class HttpProtocol(asyncio.Protocol):
async def stream_append(self): async def stream_append(self):
while self._body_chunks: while self._body_chunks:
body = self._body_chunks.popleft() body = self._body_chunks.popleft()
if self.request:
if self.request.stream.is_full(): if self.request.stream.is_full():
self.transport.pause_reading() self.transport.pause_reading()
await self.request.stream.put(body) await self.request.stream.put(body)

View File

@ -103,7 +103,9 @@ class SanicTestClient:
if self.port: if self.port:
server_kwargs = dict( server_kwargs = dict(
host=host or self.host, port=self.port, **server_kwargs, host=host or self.host,
port=self.port,
**server_kwargs,
) )
host, port = host or self.host, self.port host, port = host or self.host, self.port
else: else:

View File

@ -46,8 +46,8 @@ def test_custom_context(app):
invalid = str(e) invalid = str(e)
j = loads(response.body) j = loads(response.body)
j['response_mw_valid'] = user j["response_mw_valid"] = user
j['response_mw_invalid'] = invalid j["response_mw_invalid"] = invalid
return json(j) return json(j)
request, response = app.test_client.get("/") request, response = app.test_client.get("/")
@ -59,8 +59,7 @@ def test_custom_context(app):
"has_missing": False, "has_missing": False,
"invalid": "'types.SimpleNamespace' object has no attribute 'missing'", "invalid": "'types.SimpleNamespace' object has no attribute 'missing'",
"response_mw_valid": "sanic", "response_mw_valid": "sanic",
"response_mw_invalid": "response_mw_invalid": "'types.SimpleNamespace' object has no attribute 'missing'",
"'types.SimpleNamespace' object has no attribute 'missing'"
} }

View File

@ -1,4 +1,5 @@
import pytest import pytest
import asyncio
from sanic.blueprints import Blueprint from sanic.blueprints import Blueprint
from sanic.exceptions import HeaderExpectationFailed from sanic.exceptions import HeaderExpectationFailed
@ -6,6 +7,7 @@ from sanic.request import StreamBuffer
from sanic.response import json, stream, text from sanic.response import json, stream, text
from sanic.views import CompositionView, HTTPMethodView from sanic.views import CompositionView, HTTPMethodView
from sanic.views import stream as stream_decorator from sanic.views import stream as stream_decorator
from sanic.server import HttpProtocol
data = "abc" * 1_000_000 data = "abc" * 1_000_000
@ -337,6 +339,22 @@ def test_request_stream_handle_exception(app):
assert "Method GET not allowed for URL /post/random_id" in response.text assert "Method GET not allowed for URL /post/random_id" in response.text
@pytest.mark.asyncio
async def test_request_stream_unread(app):
"""ensure no error is raised when leaving unread bytes in byte-buffer"""
err = None
protocol = HttpProtocol(loop=asyncio.get_event_loop(), app=app)
try:
protocol.request = None
protocol._body_chunks.append("this is a test")
await protocol.stream_append()
except AttributeError as e:
err = e
assert err is None and not protocol._body_chunks
def test_request_stream_blueprint(app): def test_request_stream_blueprint(app):
"""for self.is_request_stream = True""" """for self.is_request_stream = True"""
bp = Blueprint("test_blueprint_request_stream_blueprint") bp = Blueprint("test_blueprint_request_stream_blueprint")