feat: fixes exception due to unread bytes in stream (#1897)

* feat: fixes exception due to unread bytes in stream

* feat: additonal unit tests to cover changes

* fix: automated changes by `make fix-import`

* fix: additonal changes by `make fix-import`

Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
This commit is contained in:
Andrew Scott 2020-08-27 00:22:02 -07:00 committed by GitHub
parent 58e4087d4b
commit 3f7c9ea3f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 40 additions and 17 deletions

View File

@ -42,7 +42,7 @@ class BaseHTTPResponse:
body=b"", body=b"",
): ):
""".. deprecated:: 20.3: """.. deprecated:: 20.3:
This function is not public API and will be removed.""" This function is not public API and will be removed."""
# self.headers get priority over content_type # self.headers get priority over content_type
if self.content_type and "Content-Type" not in self.headers: if self.content_type and "Content-Type" not in self.headers:
@ -249,7 +249,10 @@ def raw(
:param content_type: the content type (string) of the response. :param content_type: the content type (string) of the response.
""" """
return HTTPResponse( return HTTPResponse(
body=body, status=status, headers=headers, content_type=content_type, body=body,
status=status,
headers=headers,
content_type=content_type,
) )

View File

@ -452,7 +452,7 @@ class Router:
return route_handler, [], kwargs, route.uri, route.name return route_handler, [], kwargs, route.uri, route.name
def is_stream_handler(self, request): def is_stream_handler(self, request):
""" Handler for request is stream or not. """Handler for request is stream or not.
:param request: Request object :param request: Request object
:return: bool :return: bool
""" """

View File

@ -418,12 +418,13 @@ class HttpProtocol(asyncio.Protocol):
async def stream_append(self): async def stream_append(self):
while self._body_chunks: while self._body_chunks:
body = self._body_chunks.popleft() body = self._body_chunks.popleft()
if self.request.stream.is_full(): if self.request:
self.transport.pause_reading() if self.request.stream.is_full():
await self.request.stream.put(body) self.transport.pause_reading()
self.transport.resume_reading() await self.request.stream.put(body)
else: self.transport.resume_reading()
await self.request.stream.put(body) else:
await self.request.stream.put(body)
def on_message_complete(self): def on_message_complete(self):
# Entire request (headers and whole body) is received. # Entire request (headers and whole body) is received.

View File

@ -103,7 +103,9 @@ class SanicTestClient:
if self.port: if self.port:
server_kwargs = dict( server_kwargs = dict(
host=host or self.host, port=self.port, **server_kwargs, host=host or self.host,
port=self.port,
**server_kwargs,
) )
host, port = host or self.host, self.port host, port = host or self.host, self.port
else: else:

View File

@ -174,7 +174,7 @@ class GunicornWorker(base.Worker):
@staticmethod @staticmethod
def _create_ssl_context(cfg): def _create_ssl_context(cfg):
""" Creates SSLContext instance for usage in asyncio.create_server. """Creates SSLContext instance for usage in asyncio.create_server.
See ssl.SSLSocket.__init__ for more details. See ssl.SSLSocket.__init__ for more details.
""" """
ctx = ssl.SSLContext(cfg.ssl_version) ctx = ssl.SSLContext(cfg.ssl_version)

View File

@ -244,8 +244,8 @@ async def handler3(request):
def test_keep_alive_timeout_reuse(): def test_keep_alive_timeout_reuse():
"""If the server keep-alive timeout and client keep-alive timeout are """If the server keep-alive timeout and client keep-alive timeout are
both longer than the delay, the client _and_ server will successfully both longer than the delay, the client _and_ server will successfully
reuse the existing connection.""" reuse the existing connection."""
try: try:
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)

View File

@ -46,8 +46,8 @@ def test_custom_context(app):
invalid = str(e) invalid = str(e)
j = loads(response.body) j = loads(response.body)
j['response_mw_valid'] = user j["response_mw_valid"] = user
j['response_mw_invalid'] = invalid j["response_mw_invalid"] = invalid
return json(j) return json(j)
request, response = app.test_client.get("/") request, response = app.test_client.get("/")
@ -59,8 +59,7 @@ def test_custom_context(app):
"has_missing": False, "has_missing": False,
"invalid": "'types.SimpleNamespace' object has no attribute 'missing'", "invalid": "'types.SimpleNamespace' object has no attribute 'missing'",
"response_mw_valid": "sanic", "response_mw_valid": "sanic",
"response_mw_invalid": "response_mw_invalid": "'types.SimpleNamespace' object has no attribute 'missing'",
"'types.SimpleNamespace' object has no attribute 'missing'"
} }

View File

@ -1,4 +1,5 @@
import pytest import pytest
import asyncio
from sanic.blueprints import Blueprint from sanic.blueprints import Blueprint
from sanic.exceptions import HeaderExpectationFailed from sanic.exceptions import HeaderExpectationFailed
@ -6,6 +7,7 @@ from sanic.request import StreamBuffer
from sanic.response import json, stream, text from sanic.response import json, stream, text
from sanic.views import CompositionView, HTTPMethodView from sanic.views import CompositionView, HTTPMethodView
from sanic.views import stream as stream_decorator from sanic.views import stream as stream_decorator
from sanic.server import HttpProtocol
data = "abc" * 1_000_000 data = "abc" * 1_000_000
@ -337,6 +339,22 @@ def test_request_stream_handle_exception(app):
assert "Method GET not allowed for URL /post/random_id" in response.text assert "Method GET not allowed for URL /post/random_id" in response.text
@pytest.mark.asyncio
async def test_request_stream_unread(app):
"""ensure no error is raised when leaving unread bytes in byte-buffer"""
err = None
protocol = HttpProtocol(loop=asyncio.get_event_loop(), app=app)
try:
protocol.request = None
protocol._body_chunks.append("this is a test")
await protocol.stream_append()
except AttributeError as e:
err = e
assert err is None and not protocol._body_chunks
def test_request_stream_blueprint(app): def test_request_stream_blueprint(app):
"""for self.is_request_stream = True""" """for self.is_request_stream = True"""
bp = Blueprint("test_blueprint_request_stream_blueprint") bp = Blueprint("test_blueprint_request_stream_blueprint")