Merge pull request #302 from channelcat/request-headers-ci

Trimmed down features of CIMultiDict
This commit is contained in:
Channel Cat 2017-01-16 17:08:12 -08:00 committed by GitHub
commit 2aa380c5a3
3 changed files with 24 additions and 7 deletions

View File

@ -2,4 +2,3 @@ httptools
ujson
uvloop
aiofiles
multidict

View File

@ -83,10 +83,10 @@ class HTTPResponse:
if body is not None:
try:
# Try to encode it regularly
self.body = body.encode('utf-8')
self.body = body.encode()
except AttributeError:
# Convert it to a str if you can't
self.body = str(body).encode('utf-8')
self.body = str(body).encode()
else:
self.body = body_bytes

View File

@ -1,7 +1,6 @@
import asyncio
from functools import partial
from inspect import isawaitable
from multidict import CIMultiDict
from signal import SIGINT, SIGTERM
from time import time
from httptools import HttpRequestParser
@ -18,11 +17,30 @@ from .request import Request
from .exceptions import RequestTimeout, PayloadTooLarge, InvalidUsage
current_time = None
class Signal:
stopped = False
current_time = None
class CIDict(dict):
"""
Case Insensitive dict where all keys are converted to lowercase
This does not maintain the inputted case when calling items() or keys()
in favor of speed, since headers are case insensitive
"""
def get(self, key, default=None):
return super().get(key.casefold(), default)
def __getitem__(self, key):
return super().__getitem__(key.casefold())
def __setitem__(self, key, value):
return super().__setitem__(key.casefold(), value)
def __contains__(self, key):
return super().__contains__(key.casefold())
class HttpProtocol(asyncio.Protocol):
@ -118,12 +136,12 @@ class HttpProtocol(asyncio.Protocol):
exception = PayloadTooLarge('Payload Too Large')
self.write_error(exception)
self.headers.append((name.decode(), value.decode('utf-8')))
self.headers.append((name.decode().casefold(), value.decode()))
def on_headers_complete(self):
self.request = Request(
url_bytes=self.url,
headers=CIMultiDict(self.headers),
headers=CIDict(self.headers),
version=self.parser.get_http_version(),
method=self.parser.get_method().decode(),
transport=self.transport