feat: fixes exception due to unread bytes in stream (#1897)

* feat: fixes exception due to unread bytes in stream

* feat: additonal unit tests to cover changes

* fix: automated changes by `make fix-import`

* fix: additonal changes by `make fix-import`

Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
This commit is contained in:
Andrew Scott 2020-08-27 00:22:02 -07:00 committed by GitHub
parent 58e4087d4b
commit 3f7c9ea3f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 40 additions and 17 deletions

View File

@ -42,7 +42,7 @@ class BaseHTTPResponse:
body=b"",
):
""".. deprecated:: 20.3:
This function is not public API and will be removed."""
This function is not public API and will be removed."""
# self.headers get priority over content_type
if self.content_type and "Content-Type" not in self.headers:
@ -249,7 +249,10 @@ def raw(
:param content_type: the content type (string) of the response.
"""
return HTTPResponse(
body=body, status=status, headers=headers, content_type=content_type,
body=body,
status=status,
headers=headers,
content_type=content_type,
)

View File

@ -452,7 +452,7 @@ class Router:
return route_handler, [], kwargs, route.uri, route.name
def is_stream_handler(self, request):
""" Handler for request is stream or not.
"""Handler for request is stream or not.
:param request: Request object
:return: bool
"""

View File

@ -418,12 +418,13 @@ class HttpProtocol(asyncio.Protocol):
async def stream_append(self):
while self._body_chunks:
body = self._body_chunks.popleft()
if self.request.stream.is_full():
self.transport.pause_reading()
await self.request.stream.put(body)
self.transport.resume_reading()
else:
await self.request.stream.put(body)
if self.request:
if self.request.stream.is_full():
self.transport.pause_reading()
await self.request.stream.put(body)
self.transport.resume_reading()
else:
await self.request.stream.put(body)
def on_message_complete(self):
# Entire request (headers and whole body) is received.

View File

@ -103,7 +103,9 @@ class SanicTestClient:
if self.port:
server_kwargs = dict(
host=host or self.host, port=self.port, **server_kwargs,
host=host or self.host,
port=self.port,
**server_kwargs,
)
host, port = host or self.host, self.port
else:

View File

@ -174,7 +174,7 @@ class GunicornWorker(base.Worker):
@staticmethod
def _create_ssl_context(cfg):
""" Creates SSLContext instance for usage in asyncio.create_server.
"""Creates SSLContext instance for usage in asyncio.create_server.
See ssl.SSLSocket.__init__ for more details.
"""
ctx = ssl.SSLContext(cfg.ssl_version)

View File

@ -244,8 +244,8 @@ async def handler3(request):
def test_keep_alive_timeout_reuse():
"""If the server keep-alive timeout and client keep-alive timeout are
both longer than the delay, the client _and_ server will successfully
reuse the existing connection."""
both longer than the delay, the client _and_ server will successfully
reuse the existing connection."""
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

View File

@ -46,8 +46,8 @@ def test_custom_context(app):
invalid = str(e)
j = loads(response.body)
j['response_mw_valid'] = user
j['response_mw_invalid'] = invalid
j["response_mw_valid"] = user
j["response_mw_invalid"] = invalid
return json(j)
request, response = app.test_client.get("/")
@ -59,8 +59,7 @@ def test_custom_context(app):
"has_missing": False,
"invalid": "'types.SimpleNamespace' object has no attribute 'missing'",
"response_mw_valid": "sanic",
"response_mw_invalid":
"'types.SimpleNamespace' object has no attribute 'missing'"
"response_mw_invalid": "'types.SimpleNamespace' object has no attribute 'missing'",
}

View File

@ -1,4 +1,5 @@
import pytest
import asyncio
from sanic.blueprints import Blueprint
from sanic.exceptions import HeaderExpectationFailed
@ -6,6 +7,7 @@ from sanic.request import StreamBuffer
from sanic.response import json, stream, text
from sanic.views import CompositionView, HTTPMethodView
from sanic.views import stream as stream_decorator
from sanic.server import HttpProtocol
data = "abc" * 1_000_000
@ -337,6 +339,22 @@ def test_request_stream_handle_exception(app):
assert "Method GET not allowed for URL /post/random_id" in response.text
@pytest.mark.asyncio
async def test_request_stream_unread(app):
"""ensure no error is raised when leaving unread bytes in byte-buffer"""
err = None
protocol = HttpProtocol(loop=asyncio.get_event_loop(), app=app)
try:
protocol.request = None
protocol._body_chunks.append("this is a test")
await protocol.stream_append()
except AttributeError as e:
err = e
assert err is None and not protocol._body_chunks
def test_request_stream_blueprint(app):
"""for self.is_request_stream = True"""
bp = Blueprint("test_blueprint_request_stream_blueprint")