Merge pull request #302 from channelcat/request-headers-ci

Trimmed down features of CIMultiDict
This commit is contained in:
Channel Cat 2017-01-16 17:08:12 -08:00 committed by GitHub
commit 2aa380c5a3
3 changed files with 24 additions and 7 deletions

View File

@ -2,4 +2,3 @@ httptools
ujson ujson
uvloop uvloop
aiofiles aiofiles
multidict

View File

@ -83,10 +83,10 @@ class HTTPResponse:
if body is not None: if body is not None:
try: try:
# Try to encode it regularly # Try to encode it regularly
self.body = body.encode('utf-8') self.body = body.encode()
except AttributeError: except AttributeError:
# Convert it to a str if you can't # Convert it to a str if you can't
self.body = str(body).encode('utf-8') self.body = str(body).encode()
else: else:
self.body = body_bytes self.body = body_bytes

View File

@ -1,7 +1,6 @@
import asyncio import asyncio
from functools import partial from functools import partial
from inspect import isawaitable from inspect import isawaitable
from multidict import CIMultiDict
from signal import SIGINT, SIGTERM from signal import SIGINT, SIGTERM
from time import time from time import time
from httptools import HttpRequestParser from httptools import HttpRequestParser
@ -18,11 +17,30 @@ from .request import Request
from .exceptions import RequestTimeout, PayloadTooLarge, InvalidUsage from .exceptions import RequestTimeout, PayloadTooLarge, InvalidUsage
current_time = None
class Signal: class Signal:
stopped = False stopped = False
current_time = None class CIDict(dict):
"""
Case Insensitive dict where all keys are converted to lowercase
This does not maintain the inputted case when calling items() or keys()
in favor of speed, since headers are case insensitive
"""
def get(self, key, default=None):
return super().get(key.casefold(), default)
def __getitem__(self, key):
return super().__getitem__(key.casefold())
def __setitem__(self, key, value):
return super().__setitem__(key.casefold(), value)
def __contains__(self, key):
return super().__contains__(key.casefold())
class HttpProtocol(asyncio.Protocol): class HttpProtocol(asyncio.Protocol):
@ -118,12 +136,12 @@ class HttpProtocol(asyncio.Protocol):
exception = PayloadTooLarge('Payload Too Large') exception = PayloadTooLarge('Payload Too Large')
self.write_error(exception) self.write_error(exception)
self.headers.append((name.decode(), value.decode('utf-8'))) self.headers.append((name.decode().casefold(), value.decode()))
def on_headers_complete(self): def on_headers_complete(self):
self.request = Request( self.request = Request(
url_bytes=self.url, url_bytes=self.url,
headers=CIMultiDict(self.headers), headers=CIDict(self.headers),
version=self.parser.get_http_version(), version=self.parser.get_http_version(),
method=self.parser.get_method().decode(), method=self.parser.get_method().decode(),
transport=self.transport transport=self.transport