feat: fixes exception due to unread bytes in stream (#1897)
* feat: fixes exception due to unread bytes in stream * feat: additonal unit tests to cover changes * fix: automated changes by `make fix-import` * fix: additonal changes by `make fix-import` Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
This commit is contained in:
parent
58e4087d4b
commit
3f7c9ea3f5
|
@ -249,7 +249,10 @@ def raw(
|
||||||
:param content_type: the content type (string) of the response.
|
:param content_type: the content type (string) of the response.
|
||||||
"""
|
"""
|
||||||
return HTTPResponse(
|
return HTTPResponse(
|
||||||
body=body, status=status, headers=headers, content_type=content_type,
|
body=body,
|
||||||
|
status=status,
|
||||||
|
headers=headers,
|
||||||
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -452,7 +452,7 @@ class Router:
|
||||||
return route_handler, [], kwargs, route.uri, route.name
|
return route_handler, [], kwargs, route.uri, route.name
|
||||||
|
|
||||||
def is_stream_handler(self, request):
|
def is_stream_handler(self, request):
|
||||||
""" Handler for request is stream or not.
|
"""Handler for request is stream or not.
|
||||||
:param request: Request object
|
:param request: Request object
|
||||||
:return: bool
|
:return: bool
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -418,6 +418,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||||
async def stream_append(self):
|
async def stream_append(self):
|
||||||
while self._body_chunks:
|
while self._body_chunks:
|
||||||
body = self._body_chunks.popleft()
|
body = self._body_chunks.popleft()
|
||||||
|
if self.request:
|
||||||
if self.request.stream.is_full():
|
if self.request.stream.is_full():
|
||||||
self.transport.pause_reading()
|
self.transport.pause_reading()
|
||||||
await self.request.stream.put(body)
|
await self.request.stream.put(body)
|
||||||
|
|
|
@ -103,7 +103,9 @@ class SanicTestClient:
|
||||||
|
|
||||||
if self.port:
|
if self.port:
|
||||||
server_kwargs = dict(
|
server_kwargs = dict(
|
||||||
host=host or self.host, port=self.port, **server_kwargs,
|
host=host or self.host,
|
||||||
|
port=self.port,
|
||||||
|
**server_kwargs,
|
||||||
)
|
)
|
||||||
host, port = host or self.host, self.port
|
host, port = host or self.host, self.port
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -174,7 +174,7 @@ class GunicornWorker(base.Worker):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_ssl_context(cfg):
|
def _create_ssl_context(cfg):
|
||||||
""" Creates SSLContext instance for usage in asyncio.create_server.
|
"""Creates SSLContext instance for usage in asyncio.create_server.
|
||||||
See ssl.SSLSocket.__init__ for more details.
|
See ssl.SSLSocket.__init__ for more details.
|
||||||
"""
|
"""
|
||||||
ctx = ssl.SSLContext(cfg.ssl_version)
|
ctx = ssl.SSLContext(cfg.ssl_version)
|
||||||
|
|
|
@ -46,8 +46,8 @@ def test_custom_context(app):
|
||||||
invalid = str(e)
|
invalid = str(e)
|
||||||
|
|
||||||
j = loads(response.body)
|
j = loads(response.body)
|
||||||
j['response_mw_valid'] = user
|
j["response_mw_valid"] = user
|
||||||
j['response_mw_invalid'] = invalid
|
j["response_mw_invalid"] = invalid
|
||||||
return json(j)
|
return json(j)
|
||||||
|
|
||||||
request, response = app.test_client.get("/")
|
request, response = app.test_client.get("/")
|
||||||
|
@ -59,8 +59,7 @@ def test_custom_context(app):
|
||||||
"has_missing": False,
|
"has_missing": False,
|
||||||
"invalid": "'types.SimpleNamespace' object has no attribute 'missing'",
|
"invalid": "'types.SimpleNamespace' object has no attribute 'missing'",
|
||||||
"response_mw_valid": "sanic",
|
"response_mw_valid": "sanic",
|
||||||
"response_mw_invalid":
|
"response_mw_invalid": "'types.SimpleNamespace' object has no attribute 'missing'",
|
||||||
"'types.SimpleNamespace' object has no attribute 'missing'"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from sanic.blueprints import Blueprint
|
from sanic.blueprints import Blueprint
|
||||||
from sanic.exceptions import HeaderExpectationFailed
|
from sanic.exceptions import HeaderExpectationFailed
|
||||||
|
@ -6,6 +7,7 @@ from sanic.request import StreamBuffer
|
||||||
from sanic.response import json, stream, text
|
from sanic.response import json, stream, text
|
||||||
from sanic.views import CompositionView, HTTPMethodView
|
from sanic.views import CompositionView, HTTPMethodView
|
||||||
from sanic.views import stream as stream_decorator
|
from sanic.views import stream as stream_decorator
|
||||||
|
from sanic.server import HttpProtocol
|
||||||
|
|
||||||
|
|
||||||
data = "abc" * 1_000_000
|
data = "abc" * 1_000_000
|
||||||
|
@ -337,6 +339,22 @@ def test_request_stream_handle_exception(app):
|
||||||
assert "Method GET not allowed for URL /post/random_id" in response.text
|
assert "Method GET not allowed for URL /post/random_id" in response.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_request_stream_unread(app):
|
||||||
|
"""ensure no error is raised when leaving unread bytes in byte-buffer"""
|
||||||
|
|
||||||
|
err = None
|
||||||
|
protocol = HttpProtocol(loop=asyncio.get_event_loop(), app=app)
|
||||||
|
try:
|
||||||
|
protocol.request = None
|
||||||
|
protocol._body_chunks.append("this is a test")
|
||||||
|
await protocol.stream_append()
|
||||||
|
except AttributeError as e:
|
||||||
|
err = e
|
||||||
|
|
||||||
|
assert err is None and not protocol._body_chunks
|
||||||
|
|
||||||
|
|
||||||
def test_request_stream_blueprint(app):
|
def test_request_stream_blueprint(app):
|
||||||
"""for self.is_request_stream = True"""
|
"""for self.is_request_stream = True"""
|
||||||
bp = Blueprint("test_blueprint_request_stream_blueprint")
|
bp = Blueprint("test_blueprint_request_stream_blueprint")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user