Changed output to use a default_header dictionary and a ChainMap to unnecessary conditionals and simplified range parsing logic
This commit is contained in:
parent
b4a23bf1a6
commit
ca02ed3b9f
|
@ -1,6 +1,7 @@
|
|||
from aiofiles import open as open_async
|
||||
from mimetypes import guess_type
|
||||
from os import path
|
||||
from collections import ChainMap
|
||||
|
||||
from ujson import dumps as json_dumps
|
||||
|
||||
|
@ -97,26 +98,24 @@ class HTTPResponse:
|
|||
def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None):
|
||||
# This is all returned in a kind-of funky way
|
||||
# We tried to make this as fast as possible in pure python
|
||||
if keep_alive and keep_alive_timeout:
|
||||
if 'Keep-Alive' not in self.headers:
|
||||
self.headers['Keep-Alive'] = keep_alive_timeout
|
||||
if 'Connection' not in self.headers:
|
||||
default_header = dict()
|
||||
if keep_alive:
|
||||
self.headers['Connection'] = 'keep-alive'
|
||||
if keep_alive_timeout:
|
||||
default_header['Keep-Alive'] = keep_alive_timeout
|
||||
default_header['Connection'] = 'keep-alive'
|
||||
else:
|
||||
self.headers['Connection'] = 'close'
|
||||
if 'Content-Length' not in self.headers:
|
||||
self.headers['Content-Length'] = len(self.body)
|
||||
if 'Content-Type' not in self.headers:
|
||||
self.headers['Content-Type'] = self.content_type
|
||||
default_header['Connection'] = 'close'
|
||||
default_header['Content-Length'] = len(self.body)
|
||||
default_header['Content-Type'] = self.content_type
|
||||
headers = b''
|
||||
if self.headers:
|
||||
for name, value in self.headers.items():
|
||||
for name, value in ChainMap(self.headers, default_header).items():
|
||||
try:
|
||||
headers += (b'%b: %b\r\n' % (
|
||||
headers += (
|
||||
b'%b: %b\r\n' % (
|
||||
name.encode(), value.encode('utf-8')))
|
||||
except AttributeError:
|
||||
headers += (b'%b: %b\r\n' % (
|
||||
headers += (
|
||||
b'%b: %b\r\n' % (
|
||||
str(name).encode(), str(value).encode('utf-8')))
|
||||
# Try to pull from the common codes first
|
||||
# Speeds up response rate 6% over pulling from all
|
||||
|
@ -148,46 +147,29 @@ class ContentRangeHandler:
|
|||
__slots__ = ('start', 'end', 'size', 'total', 'headers')
|
||||
|
||||
def __init__(self, request, stats):
|
||||
self.start = self.size = 0
|
||||
self.size = self.start = 0
|
||||
self.end = None
|
||||
self.headers = dict()
|
||||
self.total = stats.st_size
|
||||
_range = request.headers.get('Range')
|
||||
if _range:
|
||||
self.start, self.end = ContentRangeHandler.parse_range(_range)
|
||||
if self.start is not None and self.end is not None:
|
||||
self.size = self.end - self.start
|
||||
elif self.end is not None:
|
||||
self.size = self.end
|
||||
elif self.start is not None:
|
||||
self.size = self.total - self.start
|
||||
else:
|
||||
self.size = self.total
|
||||
self.headers['Content-Range'] = "bytes %s-%s/%s" % (
|
||||
self.start, self.end, self.total)
|
||||
else:
|
||||
self.size = self.total
|
||||
|
||||
def __bool__(self):
|
||||
return self.size > 0
|
||||
|
||||
@staticmethod
|
||||
def parse_range(range_header):
|
||||
unit, _, value = tuple(map(str.strip, range_header.partition('=')))
|
||||
if _range is None:
|
||||
return
|
||||
unit, _, value = tuple(map(str.strip, _range.partition('=')))
|
||||
if unit != 'bytes':
|
||||
return None
|
||||
return
|
||||
start_b, _, end_b = tuple(map(str.strip, value.partition('-')))
|
||||
try:
|
||||
start = int(start_b) if start_b.strip() else None
|
||||
end = int(end_b) if end_b.strip() else None
|
||||
self.start = int(start_b) if start_b else 0
|
||||
self.end = int(end_b) if end_b else 0
|
||||
except ValueError:
|
||||
return None
|
||||
if end is not None:
|
||||
if start is None:
|
||||
if end != 0:
|
||||
start = -end
|
||||
end = None
|
||||
return start, end
|
||||
self.start = self.end = 0
|
||||
return
|
||||
self.size = self.end - self.start
|
||||
self.headers['Content-Range'] = "bytes %s-%s/%s" % (
|
||||
self.start, self.end, self.total)
|
||||
|
||||
def __bool__(self):
|
||||
return self.size != 0
|
||||
|
||||
|
||||
def json(body, status=200, headers=None, **kwargs):
|
||||
|
|
|
@ -74,7 +74,7 @@ def register(app, uri, file_or_directory, pattern,
|
|||
if _range.start >= _range.total or _range.end == 0:
|
||||
raise ContentRangeError('Content-Range malformed',
|
||||
_range)
|
||||
if _range.start == 0 and _range.size == _range.total:
|
||||
if _range.start == 0 and _range.size == 0:
|
||||
_range = None
|
||||
else:
|
||||
headers['Content-Length'] = str(_range.size)
|
||||
|
|
|
@ -74,7 +74,7 @@ def test_static_head_request(static_file_path, static_file_content):
|
|||
assert int(response.headers['Content-Length']) == len(static_file_content)
|
||||
|
||||
|
||||
def test_static_content_range(static_file_path, static_file_content):
|
||||
def test_static_content_range_correct(static_file_path, static_file_content):
|
||||
app = Sanic('test_static')
|
||||
app.static('/testing.file', static_file_path, use_content_range=True)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user