Backport stream header fix (#1959)

Resolve headers as body in ASGI mode

* Bump version to 19.12.3

* Update multidict==5.0.0
This commit is contained in:
Adam Hopkins 2020-10-25 14:32:18 +02:00 committed by GitHub
parent 2a44a27236
commit c5070bd449
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 42 additions and 21 deletions

View File

@ -1 +1 @@
__version__ = "19.12.2" __version__ = "19.12.3"

View File

@ -347,6 +347,8 @@ class ASGIApp:
if name not in (b"Set-Cookie",) if name not in (b"Set-Cookie",)
] ]
response.asgi = True
if "content-length" not in response.headers and not isinstance( if "content-length" not in response.headers and not isinstance(
response, StreamingHTTPResponse response, StreamingHTTPResponse
): ):

View File

@ -129,27 +129,27 @@ class Request:
def get(self, key, default=None): def get(self, key, default=None):
""".. deprecated:: 19.9 """.. deprecated:: 19.9
Custom context is now stored in `request.custom_context.yourkey`""" Custom context is now stored in `request.custom_context.yourkey`"""
return self.ctx.__dict__.get(key, default) return self.ctx.__dict__.get(key, default)
def __contains__(self, key): def __contains__(self, key):
""".. deprecated:: 19.9 """.. deprecated:: 19.9
Custom context is now stored in `request.custom_context.yourkey`""" Custom context is now stored in `request.custom_context.yourkey`"""
return key in self.ctx.__dict__ return key in self.ctx.__dict__
def __getitem__(self, key): def __getitem__(self, key):
""".. deprecated:: 19.9 """.. deprecated:: 19.9
Custom context is now stored in `request.custom_context.yourkey`""" Custom context is now stored in `request.custom_context.yourkey`"""
return self.ctx.__dict__[key] return self.ctx.__dict__[key]
def __delitem__(self, key): def __delitem__(self, key):
""".. deprecated:: 19.9 """.. deprecated:: 19.9
Custom context is now stored in `request.custom_context.yourkey`""" Custom context is now stored in `request.custom_context.yourkey`"""
del self.ctx.__dict__[key] del self.ctx.__dict__[key]
def __setitem__(self, key, value): def __setitem__(self, key, value):
""".. deprecated:: 19.9 """.. deprecated:: 19.9
Custom context is now stored in `request.custom_context.yourkey`""" Custom context is now stored in `request.custom_context.yourkey`"""
setattr(self.ctx, key, value) setattr(self.ctx, key, value)
def body_init(self): def body_init(self):

View File

@ -22,6 +22,9 @@ except ImportError:
class BaseHTTPResponse: class BaseHTTPResponse:
def __init__(self):
self.asgi = False
def _encode_body(self, data): def _encode_body(self, data):
try: try:
# Try to encode it regularly # Try to encode it regularly
@ -59,6 +62,8 @@ class StreamingHTTPResponse(BaseHTTPResponse):
content_type="text/plain", content_type="text/plain",
chunked=True, chunked=True,
): ):
super().__init__()
self.content_type = content_type self.content_type = content_type
self.streaming_fn = streaming_fn self.streaming_fn = streaming_fn
self.status = status self.status = status
@ -94,8 +99,9 @@ class StreamingHTTPResponse(BaseHTTPResponse):
keep_alive=keep_alive, keep_alive=keep_alive,
keep_alive_timeout=keep_alive_timeout, keep_alive_timeout=keep_alive_timeout,
) )
await self.protocol.push_data(headers) if not getattr(self, "asgi", False):
await self.protocol.drain() await self.protocol.push_data(headers)
await self.protocol.drain()
await self.streaming_fn(self) await self.streaming_fn(self)
if self.chunked: if self.chunked:
await self.protocol.push_data(b"0\r\n\r\n") await self.protocol.push_data(b"0\r\n\r\n")
@ -145,6 +151,8 @@ class HTTPResponse(BaseHTTPResponse):
content_type=None, content_type=None,
body_bytes=b"", body_bytes=b"",
): ):
super().__init__()
self.content_type = content_type self.content_type = content_type
if body is not None: if body is not None:

View File

@ -484,7 +484,7 @@ class Router:
return route_handler, [], kwargs, route.uri, route.name return route_handler, [], kwargs, route.uri, route.name
def is_stream_handler(self, request): def is_stream_handler(self, request):
""" Handler for request is stream or not. """Handler for request is stream or not.
:param request: Request object :param request: Request object
:return: bool :return: bool
""" """

View File

@ -174,7 +174,7 @@ class GunicornWorker(base.Worker):
@staticmethod @staticmethod
def _create_ssl_context(cfg): def _create_ssl_context(cfg):
""" Creates SSLContext instance for usage in asyncio.create_server. """Creates SSLContext instance for usage in asyncio.create_server.
See ssl.SSLSocket.__init__ for more details. See ssl.SSLSocket.__init__ for more details.
""" """
ctx = ssl.SSLContext(cfg.ssl_version) ctx = ssl.SSLContext(cfg.ssl_version)

View File

@ -5,7 +5,6 @@ import codecs
import os import os
import re import re
import sys import sys
from distutils.util import strtobool from distutils.util import strtobool
from setuptools import setup from setuptools import setup
@ -39,9 +38,7 @@ def open_local(paths, mode="r", encoding="utf8"):
with open_local(["sanic", "__version__.py"], encoding="latin1") as fp: with open_local(["sanic", "__version__.py"], encoding="latin1") as fp:
try: try:
version = re.findall( version = re.findall(r"^__version__ = \"([^']+)\"\r?$", fp.read(), re.M)[0]
r"^__version__ = \"([^']+)\"\r?$", fp.read(), re.M
)[0]
except IndexError: except IndexError:
raise RuntimeError("Unable to determine version.") raise RuntimeError("Unable to determine version.")
@ -71,9 +68,7 @@ setup_kwargs = {
], ],
} }
env_dependency = ( env_dependency = '; sys_platform != "win32" ' 'and implementation_name == "cpython"'
'; sys_platform != "win32" ' 'and implementation_name == "cpython"'
)
ujson = "ujson>=1.35" + env_dependency ujson = "ujson>=1.35" + env_dependency
uvloop = "uvloop>=0.5.3" + env_dependency uvloop = "uvloop>=0.5.3" + env_dependency
@ -83,13 +78,13 @@ requirements = [
ujson, ujson,
"aiofiles>=0.3.0", "aiofiles>=0.3.0",
"websockets>=7.0,<9.0", "websockets>=7.0,<9.0",
"multidict>=4.0,<5.0", "multidict==5.0.0",
"httpx==0.9.3", "httpx==0.9.3",
] ]
tests_require = [ tests_require = [
"pytest==5.2.1", "pytest==5.2.1",
"multidict>=4.0,<5.0", "multidict==5.0.0",
"gunicorn", "gunicorn",
"pytest-cov", "pytest-cov",
"httpcore==0.3.0", "httpcore==0.3.0",

View File

@ -230,8 +230,8 @@ async def handler3(request):
def test_keep_alive_timeout_reuse(): def test_keep_alive_timeout_reuse():
"""If the server keep-alive timeout and client keep-alive timeout are """If the server keep-alive timeout and client keep-alive timeout are
both longer than the delay, the client _and_ server will successfully both longer than the delay, the client _and_ server will successfully
reuse the existing connection.""" reuse the existing connection."""
try: try:
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)

View File

@ -232,6 +232,12 @@ def test_chunked_streaming_returns_correct_content(streaming_app):
assert response.text == "foo,bar" assert response.text == "foo,bar"
@pytest.mark.asyncio
async def test_chunked_streaming_returns_correct_content_asgi(streaming_app):
request, response = await streaming_app.asgi_client.get("/")
assert response.text == "4\r\nfoo,\r\n3\r\nbar\r\n0\r\n\r\n"
def test_non_chunked_streaming_adds_correct_headers(non_chunked_streaming_app): def test_non_chunked_streaming_adds_correct_headers(non_chunked_streaming_app):
request, response = non_chunked_streaming_app.test_client.get("/") request, response = non_chunked_streaming_app.test_client.get("/")
assert "Transfer-Encoding" not in response.headers assert "Transfer-Encoding" not in response.headers
@ -239,6 +245,16 @@ def test_non_chunked_streaming_adds_correct_headers(non_chunked_streaming_app):
assert response.headers["Content-Length"] == "7" assert response.headers["Content-Length"] == "7"
@pytest.mark.asyncio
async def test_non_chunked_streaming_adds_correct_headers_asgi(
non_chunked_streaming_app,
):
request, response = await non_chunked_streaming_app.asgi_client.get("/")
assert "Transfer-Encoding" not in response.headers
assert response.headers["Content-Type"] == "text/csv"
assert response.headers["Content-Length"] == "7"
def test_non_chunked_streaming_returns_correct_content( def test_non_chunked_streaming_returns_correct_content(
non_chunked_streaming_app, non_chunked_streaming_app,
): ):