Fix response ci header (#1244)
* add unit tests, which should fail * fix CIDict * moving CIDict to avoid circular imports * fix unit tests * use multidict for headers * fix cookie * add version constraint for multidict * omit test coverage for __main__.py * make flake8 happy * consolidate check in for loop * travisci retry build
This commit is contained in:
@@ -47,16 +47,15 @@ class CookieJar(dict):
|
||||
super().__init__()
|
||||
self.headers = headers
|
||||
self.cookie_headers = {}
|
||||
self.header_key = "Set-Cookie"
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
# If this cookie doesn't exist, add it to the header keys
|
||||
cookie_header = self.cookie_headers.get(key)
|
||||
if not cookie_header:
|
||||
if not self.cookie_headers.get(key):
|
||||
cookie = Cookie(key, value)
|
||||
cookie['path'] = '/'
|
||||
cookie_header = MultiHeader("Set-Cookie")
|
||||
self.cookie_headers[key] = cookie_header
|
||||
self.headers[cookie_header] = cookie
|
||||
self.cookie_headers[key] = self.header_key
|
||||
self.headers.add(self.header_key, cookie)
|
||||
return super().__setitem__(key, cookie)
|
||||
else:
|
||||
self[key].value = value
|
||||
@@ -67,7 +66,11 @@ class CookieJar(dict):
|
||||
self[key]['max-age'] = 0
|
||||
else:
|
||||
cookie_header = self.cookie_headers[key]
|
||||
del self.headers[cookie_header]
|
||||
# remove it from header
|
||||
cookies = self.headers.popall(cookie_header)
|
||||
for cookie in cookies:
|
||||
if cookie.key != key:
|
||||
self.headers.add(cookie_header, cookie)
|
||||
del self.cookie_headers[key]
|
||||
return super().__delitem__(key)
|
||||
|
||||
@@ -124,18 +127,3 @@ class Cookie(dict):
|
||||
output.append('%s=%s' % (self._keys[key], value))
|
||||
|
||||
return "; ".join(output).encode(encoding)
|
||||
|
||||
# ------------------------------------------------------------ #
|
||||
# Header Trickery
|
||||
# ------------------------------------------------------------ #
|
||||
|
||||
|
||||
class MultiHeader:
|
||||
"""String-holding object which allow us to set a header within response
|
||||
that has a unique key, but may contain duplicate header names
|
||||
"""
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def encode(self):
|
||||
return self.name.encode()
|
||||
|
||||
@@ -7,6 +7,7 @@ except BaseException:
|
||||
from json import dumps as json_dumps
|
||||
|
||||
from aiofiles import open as open_async
|
||||
from multidict import CIMultiDict
|
||||
|
||||
from sanic import http
|
||||
from sanic.cookies import CookieJar
|
||||
@@ -53,7 +54,7 @@ class StreamingHTTPResponse(BaseHTTPResponse):
|
||||
self.content_type = content_type
|
||||
self.streaming_fn = streaming_fn
|
||||
self.status = status
|
||||
self.headers = headers or {}
|
||||
self.headers = CIMultiDict(headers or {})
|
||||
self._cookies = None
|
||||
|
||||
def write(self, data):
|
||||
@@ -124,7 +125,7 @@ class HTTPResponse(BaseHTTPResponse):
|
||||
self.body = body_bytes
|
||||
|
||||
self.status = status
|
||||
self.headers = headers or {}
|
||||
self.headers = CIMultiDict(headers or {})
|
||||
self._cookies = None
|
||||
|
||||
def output(
|
||||
|
||||
@@ -18,6 +18,7 @@ from time import time
|
||||
|
||||
from httptools import HttpRequestParser
|
||||
from httptools.parser.errors import HttpParserError
|
||||
from multidict import CIMultiDict
|
||||
|
||||
try:
|
||||
import uvloop
|
||||
@@ -39,25 +40,6 @@ class Signal:
|
||||
stopped = False
|
||||
|
||||
|
||||
class CIDict(dict):
|
||||
"""Case Insensitive dict where all keys are converted to lowercase
|
||||
This does not maintain the inputted case when calling items() or keys()
|
||||
in favor of speed, since headers are case insensitive
|
||||
"""
|
||||
|
||||
def get(self, key, default=None):
|
||||
return super().get(key.casefold(), default)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return super().__getitem__(key.casefold())
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
return super().__setitem__(key.casefold(), value)
|
||||
|
||||
def __contains__(self, key):
|
||||
return super().__contains__(key.casefold())
|
||||
|
||||
|
||||
class HttpProtocol(asyncio.Protocol):
|
||||
__slots__ = (
|
||||
# event loop, connection
|
||||
@@ -256,7 +238,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||
def on_headers_complete(self):
|
||||
self.request = self.request_class(
|
||||
url_bytes=self.url,
|
||||
headers=CIDict(self.headers),
|
||||
headers=CIMultiDict(self.headers),
|
||||
version=self.parser.get_http_version(),
|
||||
method=self.parser.get_method().decode(),
|
||||
transport=self.transport
|
||||
|
||||
Reference in New Issue
Block a user