Added Range request options for static files

This commit is contained in:
Kyle Blöm 2017-01-26 18:37:16 -08:00
parent 0f64702d72
commit 4942769fbe
4 changed files with 137 additions and 39 deletions

View File

@ -104,6 +104,7 @@ INTERNAL_SERVER_ERROR_HTML = '''
class SanicException(Exception): class SanicException(Exception):
def __init__(self, message, status_code=None): def __init__(self, message, status_code=None):
super().__init__(message) super().__init__(message)
if status_code is not None: if status_code is not None:
self.status_code = status_code self.status_code = status_code
@ -137,6 +138,17 @@ class PayloadTooLarge(SanicException):
status_code = 413 status_code = 413
class ContentRangeError(SanicException):
status_code = 416
def __init__(self, message, content_range):
super().__init__(message)
self.headers = {
'Content-Type': 'text/plain',
"Content-Range": "bytes */%s" % (content_range.total,)
}
class Handler: class Handler:
handlers = None handlers = None
@ -191,7 +203,9 @@ class Handler:
if issubclass(type(exception), SanicException): if issubclass(type(exception), SanicException):
return text( return text(
'Error: {}'.format(exception), 'Error: {}'.format(exception),
status=getattr(exception, 'status_code', 500)) status=getattr(exception, 'status_code', 500),
headers=getattr(exception, 'headers', dict())
)
elif self.debug: elif self.debug:
html_output = self._render_traceback_html(exception, request) html_output = self._render_traceback_html(exception, request)

View File

@ -97,21 +97,27 @@ class HTTPResponse:
def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None): def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None):
# This is all returned in a kind-of funky way # This is all returned in a kind-of funky way
# We tried to make this as fast as possible in pure python # We tried to make this as fast as possible in pure python
timeout_header = b''
if keep_alive and keep_alive_timeout: if keep_alive and keep_alive_timeout:
timeout_header = b'Keep-Alive: timeout=%d\r\n' % keep_alive_timeout if 'Keep-Alive' not in self.headers:
self.headers['Keep-Alive'] = keep_alive_timeout
if 'Connection' not in self.headers:
if keep_alive:
self.headers['Connection'] = 'keep-alive'
else:
self.headers['Connection'] = 'close'
if 'Content-Length' not in self.headers:
self.headers['Content-Length'] = len(self.body)
if 'Content-Type' not in self.headers:
self.headers['Content-Type'] = self.content_type
headers = b'' headers = b''
if self.headers: if self.headers:
for name, value in self.headers.items(): for name, value in self.headers.items():
try: try:
headers += ( headers += (b'%b: %b\r\n' % (
b'%b: %b\r\n' % (name.encode(), value.encode('utf-8'))) name.encode(), value.encode('utf-8')))
except AttributeError: except AttributeError:
headers += ( headers += (b'%b: %b\r\n' % (
b'%b: %b\r\n' % ( str(name).encode(), str(value).encode('utf-8')))
str(name).encode(), str(value).encode('utf-8')))
# Try to pull from the common codes first # Try to pull from the common codes first
# Speeds up response rate 6% over pulling from all # Speeds up response rate 6% over pulling from all
status = COMMON_STATUS_CODES.get(self.status) status = COMMON_STATUS_CODES.get(self.status)
@ -119,18 +125,11 @@ class HTTPResponse:
status = ALL_STATUS_CODES.get(self.status) status = ALL_STATUS_CODES.get(self.status)
return (b'HTTP/%b %d %b\r\n' return (b'HTTP/%b %d %b\r\n'
b'Content-Type: %b\r\n' b'%b\r\n'
b'Content-Length: %d\r\n'
b'Connection: %b\r\n'
b'%b%b\r\n'
b'%b') % ( b'%b') % (
version.encode(), version.encode(),
self.status, self.status,
status, status,
self.content_type.encode(),
len(self.body),
b'keep-alive' if keep_alive else b'close',
timeout_header,
headers, headers,
self.body self.body
) )
@ -142,13 +141,62 @@ class HTTPResponse:
return self._cookies return self._cookies
class ContentRangeHandler:
"""
This class is for parsing the request header
"""
__slots__ = ('start', 'end', 'size', 'total', 'headers')
def __init__(self, request, stats):
self.start = self.size = 0
self.end = None
self.headers = dict()
self.total = stats.st_size
range_header = request.headers.get('Range')
if range_header:
self.start, self.end = ContentRangeHandler.parse_range(range_header)
if self.start is not None and self.end is not None:
self.size = self.end - self.start
elif self.end is not None:
self.size = self.end
elif self.start is not None:
self.size = self.total - self.start
else:
self.size = self.total
self.headers['Content-Range'] = "bytes %s-%s/%s" % (
self.start, self.end, self.total)
else:
self.size = self.total
def __bool__(self):
return self.size > 0
@staticmethod
def parse_range(range_header):
unit, _, value = tuple(map(str.strip, range_header.partition('=')))
if unit != 'bytes':
return None
start_b, _, end_b = tuple(map(str.strip, value.partition('-')))
try:
start = int(start_b) if start_b.strip() else None
end = int(end_b) if end_b.strip() else None
except ValueError:
return None
if end is not None:
if start is None:
if end != 0:
start = -end
end = None
return start, end
def json(body, status=200, headers=None, **kwargs): def json(body, status=200, headers=None, **kwargs):
""" """
Returns response object with body in json format. Returns response object with body in json format.
:param body: Response data to be serialized. :param body: Response data to be serialized.
:param status: Response code. :param status: Response code.
:param headers: Custom Headers. :param headers: Custom Headers.
:param \**kwargs: Remaining arguments that are passed to the json encoder. :param kwargs: Remaining arguments that are passed to the json encoder.
""" """
return HTTPResponse(json_dumps(body, **kwargs), headers=headers, return HTTPResponse(json_dumps(body, **kwargs), headers=headers,
status=status, content_type="application/json") status=status, content_type="application/json")
@ -176,17 +224,24 @@ def html(body, status=200, headers=None):
content_type="text/html; charset=utf-8") content_type="text/html; charset=utf-8")
async def file(location, mime_type=None, headers=None): async def file(location, mime_type=None, headers=None, _range=None):
""" """
Returns response object with file data. Returns response object with file data.
:param location: Location of file on system. :param location: Location of file on system.
:param mime_type: Specific mime_type. :param mime_type: Specific mime_type.
:param headers: Custom Headers. :param headers: Custom Headers.
:param _range:
""" """
filename = path.split(location)[-1] filename = path.split(location)[-1]
async with open_async(location, mode='rb') as _file: async with open_async(location, mode='rb') as _file:
out_stream = await _file.read() if _range:
await _file.seek(_range.start)
out_stream = await _file.read(_range.size)
headers['Content-Range'] = 'bytes %s-%s/%s' % (
_range.start, _range.end, _range.total)
else:
out_stream = await _file.read()
mime_type = mime_type or guess_type(filename)[0] or 'text/plain' mime_type = mime_type or guess_type(filename)[0] or 'text/plain'

View File

@ -78,22 +78,22 @@ class Sanic:
# Shorthand method decorators # Shorthand method decorators
def get(self, uri, host=None): def get(self, uri, host=None):
return self.route(uri, methods=["GET"], host=host) return self.route(uri, methods=frozenset({"GET"}), host=host)
def post(self, uri, host=None): def post(self, uri, host=None):
return self.route(uri, methods=["POST"], host=host) return self.route(uri, methods=frozenset({"POST"}), host=host)
def put(self, uri, host=None): def put(self, uri, host=None):
return self.route(uri, methods=["PUT"], host=host) return self.route(uri, methods=frozenset({"PUT"}), host=host)
def head(self, uri, host=None): def head(self, uri, host=None):
return self.route(uri, methods=["HEAD"], host=host) return self.route(uri, methods=frozenset({"HEAD"}), host=host)
def options(self, uri, host=None): def options(self, uri, host=None):
return self.route(uri, methods=["OPTIONS"], host=host) return self.route(uri, methods=frozenset({"OPTIONS"}), host=host)
def patch(self, uri, host=None): def patch(self, uri, host=None):
return self.route(uri, methods=["PATCH"], host=host) return self.route(uri, methods=frozenset({"PATCH"}), host=host)
def add_route(self, handler, uri, methods=None, host=None): def add_route(self, handler, uri, methods=None, host=None):
""" """
@ -117,7 +117,7 @@ class Sanic:
""" """
Decorates a function to be registered as a handler for exceptions Decorates a function to be registered as a handler for exceptions
:param \*exceptions: exceptions :param exceptions: exceptions
:return: decorated function :return: decorated function
""" """
@ -152,13 +152,13 @@ class Sanic:
# Static Files # Static Files
def static(self, uri, file_or_directory, pattern='.+', def static(self, uri, file_or_directory, pattern='.+',
use_modified_since=True): use_modified_since=True, use_content_range=False):
""" """
Registers a root to serve files from. The input can either be a file Registers a root to serve files from. The input can either be a file
or a directory. See or a directory. See
""" """
static_register(self, uri, file_or_directory, pattern, static_register(self, uri, file_or_directory, pattern,
use_modified_since) use_modified_since, use_content_range)
def blueprint(self, blueprint, **options): def blueprint(self, blueprint, **options):
""" """

View File

@ -2,14 +2,16 @@ from aiofiles.os import stat
from os import path from os import path
from re import sub from re import sub
from time import strftime, gmtime from time import strftime, gmtime
from mimetypes import guess_type
from urllib.parse import unquote from urllib.parse import unquote
from .exceptions import FileNotFound, InvalidUsage from .exceptions import FileNotFound, InvalidUsage, ContentRangeError
from .response import file, HTTPResponse from .response import file, HTTPResponse, ContentRangeHandler
def register(app, uri, file_or_directory, pattern, use_modified_since): def register(app, uri, file_or_directory, pattern,
# TODO: Though sanic is not a file server, I feel like we should atleast use_modified_since, use_content_range):
# TODO: Though sanic is not a file server, I feel like we should at least
# make a good effort here. Modified-since is nice, but we could # make a good effort here. Modified-since is nice, but we could
# also look into etags, expires, and caching # also look into etags, expires, and caching
""" """
@ -23,8 +25,9 @@ def register(app, uri, file_or_directory, pattern, use_modified_since):
:param use_modified_since: If true, send file modified time, and return :param use_modified_since: If true, send file modified time, and return
not modified if the browser's matches the not modified if the browser's matches the
server's server's
:param use_content_range: If true, process header for range requests
and sends the file part that is requested
""" """
# If we're not trying to match a file directly, # If we're not trying to match a file directly,
# serve from the folder # serve from the folder
if not path.isfile(file_or_directory): if not path.isfile(file_or_directory):
@ -50,6 +53,7 @@ def register(app, uri, file_or_directory, pattern, use_modified_since):
headers = {} headers = {}
# Check if the client has been sent this file before # Check if the client has been sent this file before
# and it has not been modified since # and it has not been modified since
stats = None
if use_modified_since: if use_modified_since:
stats = await stat(file_path) stats = await stat(file_path)
modified_since = strftime('%a, %d %b %Y %H:%M:%S GMT', modified_since = strftime('%a, %d %b %Y %H:%M:%S GMT',
@ -57,11 +61,36 @@ def register(app, uri, file_or_directory, pattern, use_modified_since):
if request.headers.get('If-Modified-Since') == modified_since: if request.headers.get('If-Modified-Since') == modified_since:
return HTTPResponse(status=304) return HTTPResponse(status=304)
headers['Last-Modified'] = modified_since headers['Last-Modified'] = modified_since
_range = None
return await file(file_path, headers=headers) if use_content_range:
except: if not stats:
stats = await stat(file_path)
headers['Accept-Ranges'] = 'bytes'
headers['Content-Length'] = str(stats.st_size)
if request.method != 'HEAD':
_range = ContentRangeHandler(request, stats)
# If the start byte is greater than the size
# of the entire file or if the end is
if _range.start >= _range.total or _range.end == 0:
raise ContentRangeError('Content-Range malformed',
_range)
if _range.start == 0 and _range.size == _range.total:
_range = None
else:
headers['Content-Length'] = str(_range.size)
for k, v in _range.headers.items():
headers[k] = v
if request.method == 'HEAD':
return HTTPResponse(
headers=headers,
content_type=guess_type(file_path)[0] or 'text/plain')
else:
return await file(file_path, headers=headers, _range=_range)
except ContentRangeError:
raise
except Exception:
raise FileNotFound('File not found', raise FileNotFound('File not found',
path=file_or_directory, path=file_or_directory,
relative_url=file_uri) relative_url=file_uri)
app.route(uri, methods=['GET'])(_handler) app.route(uri, methods=['GET', 'HEAD'])(_handler)