Trimmed down features of CIMultiDict

This commit is contained in:
Channel Cat 2017-01-16 16:12:42 -08:00
parent 48d496936a
commit 41918eaf0a
3 changed files with 23 additions and 6 deletions

View File

@ -2,4 +2,3 @@ httptools
ujson ujson
uvloop uvloop
aiofiles aiofiles
multidict

View File

@ -106,11 +106,11 @@ class HTTPResponse:
for name, value in self.headers.items(): for name, value in self.headers.items():
try: try:
headers += ( headers += (
b'%b: %b\r\n' % (name.encode(), value.encode('utf-8'))) b'%b: %b\r\n' % (name.title().encode(), value.encode('utf-8')))
except AttributeError: except AttributeError:
headers += ( headers += (
b'%b: %b\r\n' % ( b'%b: %b\r\n' % (
str(name).encode(), str(value).encode('utf-8'))) str(name).title().encode(), str(value).encode('utf-8')))
# Try to pull from the common codes first # Try to pull from the common codes first
# Speeds up response rate 6% over pulling from all # Speeds up response rate 6% over pulling from all

View File

@ -1,7 +1,6 @@
import asyncio import asyncio
from functools import partial from functools import partial
from inspect import isawaitable from inspect import isawaitable
from multidict import CIMultiDict
from signal import SIGINT, SIGTERM from signal import SIGINT, SIGTERM
from time import time from time import time
from httptools import HttpRequestParser from httptools import HttpRequestParser
@ -18,11 +17,30 @@ from .request import Request
from .exceptions import RequestTimeout, PayloadTooLarge, InvalidUsage from .exceptions import RequestTimeout, PayloadTooLarge, InvalidUsage
current_time = None
class Signal: class Signal:
stopped = False stopped = False
current_time = None class CIDict(dict):
"""
Case Insensitive dict where all keys are converted to lowercase
This does not maintain the inputted case when calling items() or keys()
in favor of speed, since headers are case insensitive
"""
def get(self, key, default=None):
return super().get(key.casefold(), default)
def __getitem__(self, key):
return super().__getitem__(key.casefold())
def __setitem__(self, key, value):
return super().__setitem__(key.casefold(), value)
def __contains__(self, key):
return super().__contains__(key.casefold())
class HttpProtocol(asyncio.Protocol): class HttpProtocol(asyncio.Protocol):
@ -127,7 +145,7 @@ class HttpProtocol(asyncio.Protocol):
self.request = Request( self.request = Request(
url_bytes=self.url, url_bytes=self.url,
headers=CIMultiDict(self.headers), headers=CIDict(self.headers),
version=self.parser.get_http_version(), version=self.parser.get_http_version(),
method=self.parser.get_method().decode() method=self.parser.get_method().decode()
) )