feat: fixes exception due to unread bytes in stream (#1897)
* feat: fixes exception due to unread bytes in stream * feat: additonal unit tests to cover changes * fix: automated changes by `make fix-import` * fix: additonal changes by `make fix-import` Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
This commit is contained in:
parent
58e4087d4b
commit
3f7c9ea3f5
|
@ -42,7 +42,7 @@ class BaseHTTPResponse:
|
|||
body=b"",
|
||||
):
|
||||
""".. deprecated:: 20.3:
|
||||
This function is not public API and will be removed."""
|
||||
This function is not public API and will be removed."""
|
||||
|
||||
# self.headers get priority over content_type
|
||||
if self.content_type and "Content-Type" not in self.headers:
|
||||
|
@ -249,7 +249,10 @@ def raw(
|
|||
:param content_type: the content type (string) of the response.
|
||||
"""
|
||||
return HTTPResponse(
|
||||
body=body, status=status, headers=headers, content_type=content_type,
|
||||
body=body,
|
||||
status=status,
|
||||
headers=headers,
|
||||
content_type=content_type,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -452,7 +452,7 @@ class Router:
|
|||
return route_handler, [], kwargs, route.uri, route.name
|
||||
|
||||
def is_stream_handler(self, request):
|
||||
""" Handler for request is stream or not.
|
||||
"""Handler for request is stream or not.
|
||||
:param request: Request object
|
||||
:return: bool
|
||||
"""
|
||||
|
|
|
@ -418,12 +418,13 @@ class HttpProtocol(asyncio.Protocol):
|
|||
async def stream_append(self):
|
||||
while self._body_chunks:
|
||||
body = self._body_chunks.popleft()
|
||||
if self.request.stream.is_full():
|
||||
self.transport.pause_reading()
|
||||
await self.request.stream.put(body)
|
||||
self.transport.resume_reading()
|
||||
else:
|
||||
await self.request.stream.put(body)
|
||||
if self.request:
|
||||
if self.request.stream.is_full():
|
||||
self.transport.pause_reading()
|
||||
await self.request.stream.put(body)
|
||||
self.transport.resume_reading()
|
||||
else:
|
||||
await self.request.stream.put(body)
|
||||
|
||||
def on_message_complete(self):
|
||||
# Entire request (headers and whole body) is received.
|
||||
|
|
|
@ -103,7 +103,9 @@ class SanicTestClient:
|
|||
|
||||
if self.port:
|
||||
server_kwargs = dict(
|
||||
host=host or self.host, port=self.port, **server_kwargs,
|
||||
host=host or self.host,
|
||||
port=self.port,
|
||||
**server_kwargs,
|
||||
)
|
||||
host, port = host or self.host, self.port
|
||||
else:
|
||||
|
|
|
@ -174,7 +174,7 @@ class GunicornWorker(base.Worker):
|
|||
|
||||
@staticmethod
|
||||
def _create_ssl_context(cfg):
|
||||
""" Creates SSLContext instance for usage in asyncio.create_server.
|
||||
"""Creates SSLContext instance for usage in asyncio.create_server.
|
||||
See ssl.SSLSocket.__init__ for more details.
|
||||
"""
|
||||
ctx = ssl.SSLContext(cfg.ssl_version)
|
||||
|
|
|
@ -244,8 +244,8 @@ async def handler3(request):
|
|||
|
||||
def test_keep_alive_timeout_reuse():
|
||||
"""If the server keep-alive timeout and client keep-alive timeout are
|
||||
both longer than the delay, the client _and_ server will successfully
|
||||
reuse the existing connection."""
|
||||
both longer than the delay, the client _and_ server will successfully
|
||||
reuse the existing connection."""
|
||||
try:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
|
|
@ -46,8 +46,8 @@ def test_custom_context(app):
|
|||
invalid = str(e)
|
||||
|
||||
j = loads(response.body)
|
||||
j['response_mw_valid'] = user
|
||||
j['response_mw_invalid'] = invalid
|
||||
j["response_mw_valid"] = user
|
||||
j["response_mw_invalid"] = invalid
|
||||
return json(j)
|
||||
|
||||
request, response = app.test_client.get("/")
|
||||
|
@ -59,8 +59,7 @@ def test_custom_context(app):
|
|||
"has_missing": False,
|
||||
"invalid": "'types.SimpleNamespace' object has no attribute 'missing'",
|
||||
"response_mw_valid": "sanic",
|
||||
"response_mw_invalid":
|
||||
"'types.SimpleNamespace' object has no attribute 'missing'"
|
||||
"response_mw_invalid": "'types.SimpleNamespace' object has no attribute 'missing'",
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import pytest
|
||||
import asyncio
|
||||
|
||||
from sanic.blueprints import Blueprint
|
||||
from sanic.exceptions import HeaderExpectationFailed
|
||||
|
@ -6,6 +7,7 @@ from sanic.request import StreamBuffer
|
|||
from sanic.response import json, stream, text
|
||||
from sanic.views import CompositionView, HTTPMethodView
|
||||
from sanic.views import stream as stream_decorator
|
||||
from sanic.server import HttpProtocol
|
||||
|
||||
|
||||
data = "abc" * 1_000_000
|
||||
|
@ -337,6 +339,22 @@ def test_request_stream_handle_exception(app):
|
|||
assert "Method GET not allowed for URL /post/random_id" in response.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_stream_unread(app):
|
||||
"""ensure no error is raised when leaving unread bytes in byte-buffer"""
|
||||
|
||||
err = None
|
||||
protocol = HttpProtocol(loop=asyncio.get_event_loop(), app=app)
|
||||
try:
|
||||
protocol.request = None
|
||||
protocol._body_chunks.append("this is a test")
|
||||
await protocol.stream_append()
|
||||
except AttributeError as e:
|
||||
err = e
|
||||
|
||||
assert err is None and not protocol._body_chunks
|
||||
|
||||
|
||||
def test_request_stream_blueprint(app):
|
||||
"""for self.is_request_stream = True"""
|
||||
bp = Blueprint("test_blueprint_request_stream_blueprint")
|
||||
|
|
Loading…
Reference in New Issue
Block a user