Changed output to use a default_header dictionary and a ChainMap to unnecessary conditionals and simplified range parsing logic

This commit is contained in:
Kyle Blöm 2017-01-28 11:18:52 -08:00
parent ee5e145e2d
commit 8619e50845
3 changed files with 34 additions and 52 deletions

View File

@ -1,6 +1,7 @@
from aiofiles import open as open_async from aiofiles import open as open_async
from mimetypes import guess_type from mimetypes import guess_type
from os import path from os import path
from collections import ChainMap
from ujson import dumps as json_dumps from ujson import dumps as json_dumps
@ -97,26 +98,24 @@ class HTTPResponse:
def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None): def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None):
# This is all returned in a kind-of funky way # This is all returned in a kind-of funky way
# We tried to make this as fast as possible in pure python # We tried to make this as fast as possible in pure python
if keep_alive and keep_alive_timeout: default_header = dict()
if 'Keep-Alive' not in self.headers:
self.headers['Keep-Alive'] = keep_alive_timeout
if 'Connection' not in self.headers:
if keep_alive: if keep_alive:
self.headers['Connection'] = 'keep-alive' if keep_alive_timeout:
default_header['Keep-Alive'] = keep_alive_timeout
default_header['Connection'] = 'keep-alive'
else: else:
self.headers['Connection'] = 'close' default_header['Connection'] = 'close'
if 'Content-Length' not in self.headers: default_header['Content-Length'] = len(self.body)
self.headers['Content-Length'] = len(self.body) default_header['Content-Type'] = self.content_type
if 'Content-Type' not in self.headers:
self.headers['Content-Type'] = self.content_type
headers = b'' headers = b''
if self.headers: for name, value in ChainMap(self.headers, default_header).items():
for name, value in self.headers.items():
try: try:
headers += (b'%b: %b\r\n' % ( headers += (
b'%b: %b\r\n' % (
name.encode(), value.encode('utf-8'))) name.encode(), value.encode('utf-8')))
except AttributeError: except AttributeError:
headers += (b'%b: %b\r\n' % ( headers += (
b'%b: %b\r\n' % (
str(name).encode(), str(value).encode('utf-8'))) str(name).encode(), str(value).encode('utf-8')))
# Try to pull from the common codes first # Try to pull from the common codes first
# Speeds up response rate 6% over pulling from all # Speeds up response rate 6% over pulling from all
@ -148,46 +147,29 @@ class ContentRangeHandler:
__slots__ = ('start', 'end', 'size', 'total', 'headers') __slots__ = ('start', 'end', 'size', 'total', 'headers')
def __init__(self, request, stats): def __init__(self, request, stats):
self.start = self.size = 0 self.size = self.start = 0
self.end = None self.end = None
self.headers = dict() self.headers = dict()
self.total = stats.st_size self.total = stats.st_size
_range = request.headers.get('Range') _range = request.headers.get('Range')
if _range: if _range is None:
self.start, self.end = ContentRangeHandler.parse_range(_range) return
if self.start is not None and self.end is not None: unit, _, value = tuple(map(str.strip, _range.partition('=')))
self.size = self.end - self.start
elif self.end is not None:
self.size = self.end
elif self.start is not None:
self.size = self.total - self.start
else:
self.size = self.total
self.headers['Content-Range'] = "bytes %s-%s/%s" % (
self.start, self.end, self.total)
else:
self.size = self.total
def __bool__(self):
return self.size > 0
@staticmethod
def parse_range(range_header):
unit, _, value = tuple(map(str.strip, range_header.partition('=')))
if unit != 'bytes': if unit != 'bytes':
return None return
start_b, _, end_b = tuple(map(str.strip, value.partition('-'))) start_b, _, end_b = tuple(map(str.strip, value.partition('-')))
try: try:
start = int(start_b) if start_b.strip() else None self.start = int(start_b) if start_b else 0
end = int(end_b) if end_b.strip() else None self.end = int(end_b) if end_b else 0
except ValueError: except ValueError:
return None self.start = self.end = 0
if end is not None: return
if start is None: self.size = self.end - self.start
if end != 0: self.headers['Content-Range'] = "bytes %s-%s/%s" % (
start = -end self.start, self.end, self.total)
end = None
return start, end def __bool__(self):
return self.size != 0
def json(body, status=200, headers=None, **kwargs): def json(body, status=200, headers=None, **kwargs):

View File

@ -74,7 +74,7 @@ def register(app, uri, file_or_directory, pattern,
if _range.start >= _range.total or _range.end == 0: if _range.start >= _range.total or _range.end == 0:
raise ContentRangeError('Content-Range malformed', raise ContentRangeError('Content-Range malformed',
_range) _range)
if _range.start == 0 and _range.size == _range.total: if _range.start == 0 and _range.size == 0:
_range = None _range = None
else: else:
headers['Content-Length'] = str(_range.size) headers['Content-Length'] = str(_range.size)

View File

@ -74,7 +74,7 @@ def test_static_head_request(static_file_path, static_file_content):
assert int(response.headers['Content-Length']) == len(static_file_content) assert int(response.headers['Content-Length']) == len(static_file_content)
def test_static_content_range(static_file_path, static_file_content): def test_static_content_range_correct(static_file_path, static_file_content):
app = Sanic('test_static') app = Sanic('test_static')
app.static('/testing.file', static_file_path, use_content_range=True) app.static('/testing.file', static_file_path, use_content_range=True)