diff --git a/requirements.txt b/requirements.txt index 3acfbb1f..cef8660e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,3 @@ httptools ujson uvloop aiofiles -multidict diff --git a/sanic/response.py b/sanic/response.py index 9c7bd2b5..ba10b8c4 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -83,10 +83,10 @@ class HTTPResponse: if body is not None: try: # Try to encode it regularly - self.body = body.encode('utf-8') + self.body = body.encode() except AttributeError: # Convert it to a str if you can't - self.body = str(body).encode('utf-8') + self.body = str(body).encode() else: self.body = body_bytes diff --git a/sanic/server.py b/sanic/server.py index ade02564..711117dc 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -1,7 +1,6 @@ import asyncio from functools import partial from inspect import isawaitable -from multidict import CIMultiDict from signal import SIGINT, SIGTERM from time import time from httptools import HttpRequestParser @@ -18,11 +17,30 @@ from .request import Request from .exceptions import RequestTimeout, PayloadTooLarge, InvalidUsage +current_time = None + + class Signal: stopped = False -current_time = None +class CIDict(dict): + """ + Case Insensitive dict where all keys are converted to lowercase + This does not maintain the inputted case when calling items() or keys() + in favor of speed, since headers are case insensitive + """ + def get(self, key, default=None): + return super().get(key.casefold(), default) + + def __getitem__(self, key): + return super().__getitem__(key.casefold()) + + def __setitem__(self, key, value): + return super().__setitem__(key.casefold(), value) + + def __contains__(self, key): + return super().__contains__(key.casefold()) class HttpProtocol(asyncio.Protocol): @@ -118,12 +136,12 @@ class HttpProtocol(asyncio.Protocol): exception = PayloadTooLarge('Payload Too Large') self.write_error(exception) - self.headers.append((name.decode(), value.decode('utf-8'))) + self.headers.append((name.decode().casefold(), value.decode())) def on_headers_complete(self): self.request = Request( url_bytes=self.url, - headers=CIMultiDict(self.headers), + headers=CIDict(self.headers), version=self.parser.get_http_version(), method=self.parser.get_method().decode(), transport=self.transport