Changed output to use a default_header dictionary and a ChainMap to unnecessary conditionals and simplified range parsing logic
This commit is contained in:
parent
ee5e145e2d
commit
8619e50845
|
@ -1,6 +1,7 @@
|
||||||
from aiofiles import open as open_async
|
from aiofiles import open as open_async
|
||||||
from mimetypes import guess_type
|
from mimetypes import guess_type
|
||||||
from os import path
|
from os import path
|
||||||
|
from collections import ChainMap
|
||||||
|
|
||||||
from ujson import dumps as json_dumps
|
from ujson import dumps as json_dumps
|
||||||
|
|
||||||
|
@ -97,26 +98,24 @@ class HTTPResponse:
|
||||||
def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None):
|
def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None):
|
||||||
# This is all returned in a kind-of funky way
|
# This is all returned in a kind-of funky way
|
||||||
# We tried to make this as fast as possible in pure python
|
# We tried to make this as fast as possible in pure python
|
||||||
if keep_alive and keep_alive_timeout:
|
default_header = dict()
|
||||||
if 'Keep-Alive' not in self.headers:
|
if keep_alive:
|
||||||
self.headers['Keep-Alive'] = keep_alive_timeout
|
if keep_alive_timeout:
|
||||||
if 'Connection' not in self.headers:
|
default_header['Keep-Alive'] = keep_alive_timeout
|
||||||
if keep_alive:
|
default_header['Connection'] = 'keep-alive'
|
||||||
self.headers['Connection'] = 'keep-alive'
|
else:
|
||||||
else:
|
default_header['Connection'] = 'close'
|
||||||
self.headers['Connection'] = 'close'
|
default_header['Content-Length'] = len(self.body)
|
||||||
if 'Content-Length' not in self.headers:
|
default_header['Content-Type'] = self.content_type
|
||||||
self.headers['Content-Length'] = len(self.body)
|
|
||||||
if 'Content-Type' not in self.headers:
|
|
||||||
self.headers['Content-Type'] = self.content_type
|
|
||||||
headers = b''
|
headers = b''
|
||||||
if self.headers:
|
for name, value in ChainMap(self.headers, default_header).items():
|
||||||
for name, value in self.headers.items():
|
try:
|
||||||
try:
|
headers += (
|
||||||
headers += (b'%b: %b\r\n' % (
|
b'%b: %b\r\n' % (
|
||||||
name.encode(), value.encode('utf-8')))
|
name.encode(), value.encode('utf-8')))
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
headers += (b'%b: %b\r\n' % (
|
headers += (
|
||||||
|
b'%b: %b\r\n' % (
|
||||||
str(name).encode(), str(value).encode('utf-8')))
|
str(name).encode(), str(value).encode('utf-8')))
|
||||||
# Try to pull from the common codes first
|
# Try to pull from the common codes first
|
||||||
# Speeds up response rate 6% over pulling from all
|
# Speeds up response rate 6% over pulling from all
|
||||||
|
@ -148,46 +147,29 @@ class ContentRangeHandler:
|
||||||
__slots__ = ('start', 'end', 'size', 'total', 'headers')
|
__slots__ = ('start', 'end', 'size', 'total', 'headers')
|
||||||
|
|
||||||
def __init__(self, request, stats):
|
def __init__(self, request, stats):
|
||||||
self.start = self.size = 0
|
self.size = self.start = 0
|
||||||
self.end = None
|
self.end = None
|
||||||
self.headers = dict()
|
self.headers = dict()
|
||||||
self.total = stats.st_size
|
self.total = stats.st_size
|
||||||
_range = request.headers.get('Range')
|
_range = request.headers.get('Range')
|
||||||
if _range:
|
if _range is None:
|
||||||
self.start, self.end = ContentRangeHandler.parse_range(_range)
|
return
|
||||||
if self.start is not None and self.end is not None:
|
unit, _, value = tuple(map(str.strip, _range.partition('=')))
|
||||||
self.size = self.end - self.start
|
|
||||||
elif self.end is not None:
|
|
||||||
self.size = self.end
|
|
||||||
elif self.start is not None:
|
|
||||||
self.size = self.total - self.start
|
|
||||||
else:
|
|
||||||
self.size = self.total
|
|
||||||
self.headers['Content-Range'] = "bytes %s-%s/%s" % (
|
|
||||||
self.start, self.end, self.total)
|
|
||||||
else:
|
|
||||||
self.size = self.total
|
|
||||||
|
|
||||||
def __bool__(self):
|
|
||||||
return self.size > 0
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def parse_range(range_header):
|
|
||||||
unit, _, value = tuple(map(str.strip, range_header.partition('=')))
|
|
||||||
if unit != 'bytes':
|
if unit != 'bytes':
|
||||||
return None
|
return
|
||||||
start_b, _, end_b = tuple(map(str.strip, value.partition('-')))
|
start_b, _, end_b = tuple(map(str.strip, value.partition('-')))
|
||||||
try:
|
try:
|
||||||
start = int(start_b) if start_b.strip() else None
|
self.start = int(start_b) if start_b else 0
|
||||||
end = int(end_b) if end_b.strip() else None
|
self.end = int(end_b) if end_b else 0
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return None
|
self.start = self.end = 0
|
||||||
if end is not None:
|
return
|
||||||
if start is None:
|
self.size = self.end - self.start
|
||||||
if end != 0:
|
self.headers['Content-Range'] = "bytes %s-%s/%s" % (
|
||||||
start = -end
|
self.start, self.end, self.total)
|
||||||
end = None
|
|
||||||
return start, end
|
def __bool__(self):
|
||||||
|
return self.size != 0
|
||||||
|
|
||||||
|
|
||||||
def json(body, status=200, headers=None, **kwargs):
|
def json(body, status=200, headers=None, **kwargs):
|
||||||
|
|
|
@ -74,7 +74,7 @@ def register(app, uri, file_or_directory, pattern,
|
||||||
if _range.start >= _range.total or _range.end == 0:
|
if _range.start >= _range.total or _range.end == 0:
|
||||||
raise ContentRangeError('Content-Range malformed',
|
raise ContentRangeError('Content-Range malformed',
|
||||||
_range)
|
_range)
|
||||||
if _range.start == 0 and _range.size == _range.total:
|
if _range.start == 0 and _range.size == 0:
|
||||||
_range = None
|
_range = None
|
||||||
else:
|
else:
|
||||||
headers['Content-Length'] = str(_range.size)
|
headers['Content-Length'] = str(_range.size)
|
||||||
|
|
|
@ -74,7 +74,7 @@ def test_static_head_request(static_file_path, static_file_content):
|
||||||
assert int(response.headers['Content-Length']) == len(static_file_content)
|
assert int(response.headers['Content-Length']) == len(static_file_content)
|
||||||
|
|
||||||
|
|
||||||
def test_static_content_range(static_file_path, static_file_content):
|
def test_static_content_range_correct(static_file_path, static_file_content):
|
||||||
app = Sanic('test_static')
|
app = Sanic('test_static')
|
||||||
app.static('/testing.file', static_file_path, use_content_range=True)
|
app.static('/testing.file', static_file_path, use_content_range=True)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user