diff --git a/.coveragerc b/.coveragerc index f0624d2e..724b2872 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,7 +1,7 @@ [run] branch = True source = sanic -omit = site-packages, sanic/utils.py +omit = site-packages, sanic/utils.py, sanic/__main__.py [html] directory = coverage diff --git a/requirements-dev.txt b/requirements-dev.txt index 674ef91d..3d94c51d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -10,3 +10,4 @@ tox ujson; sys_platform != "win32" and implementation_name == "cpython" uvloop; sys_platform != "win32" and implementation_name == "cpython" gunicorn +multidict>=4.0,<5.0 diff --git a/requirements.txt b/requirements.txt index 05968bd8..e320e781 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ httptools ujson; sys_platform != "win32" and implementation_name == "cpython" uvloop; sys_platform != "win32" and implementation_name == "cpython" websockets>=5.0,<6.0 +multidict>=4.0,<5.0 diff --git a/sanic/cookies.py b/sanic/cookies.py index f4cbf6a3..19daac11 100644 --- a/sanic/cookies.py +++ b/sanic/cookies.py @@ -47,16 +47,15 @@ class CookieJar(dict): super().__init__() self.headers = headers self.cookie_headers = {} + self.header_key = "Set-Cookie" def __setitem__(self, key, value): # If this cookie doesn't exist, add it to the header keys - cookie_header = self.cookie_headers.get(key) - if not cookie_header: + if not self.cookie_headers.get(key): cookie = Cookie(key, value) cookie['path'] = '/' - cookie_header = MultiHeader("Set-Cookie") - self.cookie_headers[key] = cookie_header - self.headers[cookie_header] = cookie + self.cookie_headers[key] = self.header_key + self.headers.add(self.header_key, cookie) return super().__setitem__(key, cookie) else: self[key].value = value @@ -67,7 +66,11 @@ class CookieJar(dict): self[key]['max-age'] = 0 else: cookie_header = self.cookie_headers[key] - del self.headers[cookie_header] + # remove it from header + cookies = self.headers.popall(cookie_header) + for cookie in cookies: + if cookie.key != key: + self.headers.add(cookie_header, cookie) del self.cookie_headers[key] return super().__delitem__(key) @@ -124,18 +127,3 @@ class Cookie(dict): output.append('%s=%s' % (self._keys[key], value)) return "; ".join(output).encode(encoding) - -# ------------------------------------------------------------ # -# Header Trickery -# ------------------------------------------------------------ # - - -class MultiHeader: - """String-holding object which allow us to set a header within response - that has a unique key, but may contain duplicate header names - """ - def __init__(self, name): - self.name = name - - def encode(self): - return self.name.encode() diff --git a/sanic/response.py b/sanic/response.py index b32e9daf..62daf91e 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -7,6 +7,7 @@ except BaseException: from json import dumps as json_dumps from aiofiles import open as open_async +from multidict import CIMultiDict from sanic import http from sanic.cookies import CookieJar @@ -53,7 +54,7 @@ class StreamingHTTPResponse(BaseHTTPResponse): self.content_type = content_type self.streaming_fn = streaming_fn self.status = status - self.headers = headers or {} + self.headers = CIMultiDict(headers or {}) self._cookies = None def write(self, data): @@ -124,7 +125,7 @@ class HTTPResponse(BaseHTTPResponse): self.body = body_bytes self.status = status - self.headers = headers or {} + self.headers = CIMultiDict(headers or {}) self._cookies = None def output( diff --git a/sanic/server.py b/sanic/server.py index fc8291b5..11e54edc 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -18,6 +18,7 @@ from time import time from httptools import HttpRequestParser from httptools.parser.errors import HttpParserError +from multidict import CIMultiDict try: import uvloop @@ -39,25 +40,6 @@ class Signal: stopped = False -class CIDict(dict): - """Case Insensitive dict where all keys are converted to lowercase - This does not maintain the inputted case when calling items() or keys() - in favor of speed, since headers are case insensitive - """ - - def get(self, key, default=None): - return super().get(key.casefold(), default) - - def __getitem__(self, key): - return super().__getitem__(key.casefold()) - - def __setitem__(self, key, value): - return super().__setitem__(key.casefold(), value) - - def __contains__(self, key): - return super().__contains__(key.casefold()) - - class HttpProtocol(asyncio.Protocol): __slots__ = ( # event loop, connection @@ -256,7 +238,7 @@ class HttpProtocol(asyncio.Protocol): def on_headers_complete(self): self.request = self.request_class( url_bytes=self.url, - headers=CIDict(self.headers), + headers=CIMultiDict(self.headers), version=self.parser.get_http_version(), method=self.parser.get_method().decode(), transport=self.transport diff --git a/setup.py b/setup.py index 73cb559f..34703ab4 100644 --- a/setup.py +++ b/setup.py @@ -61,6 +61,7 @@ requirements = [ ujson, 'aiofiles>=0.3.0', 'websockets>=5.0,<6.0', + 'multidict>=4.0,<5.0', ] if strtobool(os.environ.get("SANIC_NO_UJSON", "no")): print("Installing without uJSON") diff --git a/tests/test_cookies.py b/tests/test_cookies.py index 84b493cb..61f50735 100644 --- a/tests/test_cookies.py +++ b/tests/test_cookies.py @@ -25,6 +25,7 @@ def test_cookies(): assert response.text == 'Cookies are: working!' assert response_cookies['right_back'].value == 'at you' + @pytest.mark.parametrize("httponly,expected", [ (False, False), (True, True), diff --git a/tests/test_response.py b/tests/test_response.py index 12049460..36259970 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -64,6 +64,25 @@ def test_method_not_allowed(): assert response.headers['Content-Length'] == '0' +def test_response_header(): + app = Sanic('test_response_header') + @app.get('/') + async def test(request): + return json({ + "ok": True + }, headers={ + 'CONTENT-TYPE': 'application/json' + }) + + request, response = app.test_client.get('/') + assert dict(response.headers) == { + 'Connection': 'keep-alive', + 'Keep-Alive': '2', + 'Content-Length': '11', + 'Content-Type': 'application/json', + } + + @pytest.fixture def json_app(): app = Sanic('json')