Fix response ci header (#1244)

* add unit tests, which should fail

* fix CIDict

* moving CIDict to avoid circular imports

* fix unit tests

* use multidict for headers

* fix cookie

* add version constraint for multidict

* omit test coverage for __main__.py

* make flake8 happy

* consolidate check in for loop

* travisci retry build
This commit is contained in:
7 2018-07-11 01:44:21 -07:00 committed by Raphael Deem
parent becbc5f9ef
commit 334649dfd4
9 changed files with 38 additions and 44 deletions

View File

@ -1,7 +1,7 @@
[run] [run]
branch = True branch = True
source = sanic source = sanic
omit = site-packages, sanic/utils.py omit = site-packages, sanic/utils.py, sanic/__main__.py
[html] [html]
directory = coverage directory = coverage

View File

@ -10,3 +10,4 @@ tox
ujson; sys_platform != "win32" and implementation_name == "cpython" ujson; sys_platform != "win32" and implementation_name == "cpython"
uvloop; sys_platform != "win32" and implementation_name == "cpython" uvloop; sys_platform != "win32" and implementation_name == "cpython"
gunicorn gunicorn
multidict>=4.0,<5.0

View File

@ -3,3 +3,4 @@ httptools
ujson; sys_platform != "win32" and implementation_name == "cpython" ujson; sys_platform != "win32" and implementation_name == "cpython"
uvloop; sys_platform != "win32" and implementation_name == "cpython" uvloop; sys_platform != "win32" and implementation_name == "cpython"
websockets>=5.0,<6.0 websockets>=5.0,<6.0
multidict>=4.0,<5.0

View File

@ -47,16 +47,15 @@ class CookieJar(dict):
super().__init__() super().__init__()
self.headers = headers self.headers = headers
self.cookie_headers = {} self.cookie_headers = {}
self.header_key = "Set-Cookie"
def __setitem__(self, key, value): def __setitem__(self, key, value):
# If this cookie doesn't exist, add it to the header keys # If this cookie doesn't exist, add it to the header keys
cookie_header = self.cookie_headers.get(key) if not self.cookie_headers.get(key):
if not cookie_header:
cookie = Cookie(key, value) cookie = Cookie(key, value)
cookie['path'] = '/' cookie['path'] = '/'
cookie_header = MultiHeader("Set-Cookie") self.cookie_headers[key] = self.header_key
self.cookie_headers[key] = cookie_header self.headers.add(self.header_key, cookie)
self.headers[cookie_header] = cookie
return super().__setitem__(key, cookie) return super().__setitem__(key, cookie)
else: else:
self[key].value = value self[key].value = value
@ -67,7 +66,11 @@ class CookieJar(dict):
self[key]['max-age'] = 0 self[key]['max-age'] = 0
else: else:
cookie_header = self.cookie_headers[key] cookie_header = self.cookie_headers[key]
del self.headers[cookie_header] # remove it from header
cookies = self.headers.popall(cookie_header)
for cookie in cookies:
if cookie.key != key:
self.headers.add(cookie_header, cookie)
del self.cookie_headers[key] del self.cookie_headers[key]
return super().__delitem__(key) return super().__delitem__(key)
@ -124,18 +127,3 @@ class Cookie(dict):
output.append('%s=%s' % (self._keys[key], value)) output.append('%s=%s' % (self._keys[key], value))
return "; ".join(output).encode(encoding) return "; ".join(output).encode(encoding)
# ------------------------------------------------------------ #
# Header Trickery
# ------------------------------------------------------------ #
class MultiHeader:
"""String-holding object which allow us to set a header within response
that has a unique key, but may contain duplicate header names
"""
def __init__(self, name):
self.name = name
def encode(self):
return self.name.encode()

View File

@ -7,6 +7,7 @@ except BaseException:
from json import dumps as json_dumps from json import dumps as json_dumps
from aiofiles import open as open_async from aiofiles import open as open_async
from multidict import CIMultiDict
from sanic import http from sanic import http
from sanic.cookies import CookieJar from sanic.cookies import CookieJar
@ -53,7 +54,7 @@ class StreamingHTTPResponse(BaseHTTPResponse):
self.content_type = content_type self.content_type = content_type
self.streaming_fn = streaming_fn self.streaming_fn = streaming_fn
self.status = status self.status = status
self.headers = headers or {} self.headers = CIMultiDict(headers or {})
self._cookies = None self._cookies = None
def write(self, data): def write(self, data):
@ -124,7 +125,7 @@ class HTTPResponse(BaseHTTPResponse):
self.body = body_bytes self.body = body_bytes
self.status = status self.status = status
self.headers = headers or {} self.headers = CIMultiDict(headers or {})
self._cookies = None self._cookies = None
def output( def output(

View File

@ -18,6 +18,7 @@ from time import time
from httptools import HttpRequestParser from httptools import HttpRequestParser
from httptools.parser.errors import HttpParserError from httptools.parser.errors import HttpParserError
from multidict import CIMultiDict
try: try:
import uvloop import uvloop
@ -39,25 +40,6 @@ class Signal:
stopped = False stopped = False
class CIDict(dict):
"""Case Insensitive dict where all keys are converted to lowercase
This does not maintain the inputted case when calling items() or keys()
in favor of speed, since headers are case insensitive
"""
def get(self, key, default=None):
return super().get(key.casefold(), default)
def __getitem__(self, key):
return super().__getitem__(key.casefold())
def __setitem__(self, key, value):
return super().__setitem__(key.casefold(), value)
def __contains__(self, key):
return super().__contains__(key.casefold())
class HttpProtocol(asyncio.Protocol): class HttpProtocol(asyncio.Protocol):
__slots__ = ( __slots__ = (
# event loop, connection # event loop, connection
@ -256,7 +238,7 @@ class HttpProtocol(asyncio.Protocol):
def on_headers_complete(self): def on_headers_complete(self):
self.request = self.request_class( self.request = self.request_class(
url_bytes=self.url, url_bytes=self.url,
headers=CIDict(self.headers), headers=CIMultiDict(self.headers),
version=self.parser.get_http_version(), version=self.parser.get_http_version(),
method=self.parser.get_method().decode(), method=self.parser.get_method().decode(),
transport=self.transport transport=self.transport

View File

@ -61,6 +61,7 @@ requirements = [
ujson, ujson,
'aiofiles>=0.3.0', 'aiofiles>=0.3.0',
'websockets>=5.0,<6.0', 'websockets>=5.0,<6.0',
'multidict>=4.0,<5.0',
] ]
if strtobool(os.environ.get("SANIC_NO_UJSON", "no")): if strtobool(os.environ.get("SANIC_NO_UJSON", "no")):
print("Installing without uJSON") print("Installing without uJSON")

View File

@ -25,6 +25,7 @@ def test_cookies():
assert response.text == 'Cookies are: working!' assert response.text == 'Cookies are: working!'
assert response_cookies['right_back'].value == 'at you' assert response_cookies['right_back'].value == 'at you'
@pytest.mark.parametrize("httponly,expected", [ @pytest.mark.parametrize("httponly,expected", [
(False, False), (False, False),
(True, True), (True, True),

View File

@ -64,6 +64,25 @@ def test_method_not_allowed():
assert response.headers['Content-Length'] == '0' assert response.headers['Content-Length'] == '0'
def test_response_header():
app = Sanic('test_response_header')
@app.get('/')
async def test(request):
return json({
"ok": True
}, headers={
'CONTENT-TYPE': 'application/json'
})
request, response = app.test_client.get('/')
assert dict(response.headers) == {
'Connection': 'keep-alive',
'Keep-Alive': '2',
'Content-Length': '11',
'Content-Type': 'application/json',
}
@pytest.fixture @pytest.fixture
def json_app(): def json_app():
app = Sanic('json') app = Sanic('json')