Merge pull request #378 from aquacash5/master

Added the tests, code formatting changes, and the Range Request feature.
This commit is contained in:
Eli Uriegas 2017-02-08 19:39:29 -06:00 committed by GitHub
commit 4f856e8783
7 changed files with 360 additions and 67 deletions

View File

@ -1,8 +1,3 @@
from .response import text, html
from .log import log
from traceback import format_exc, extract_tb
import sys
TRACEBACK_STYLE = ''' TRACEBACK_STYLE = '''
<style> <style>
body { body {
@ -104,6 +99,7 @@ INTERNAL_SERVER_ERROR_HTML = '''
class SanicException(Exception): class SanicException(Exception):
def __init__(self, message, status_code=None): def __init__(self, message, status_code=None):
super().__init__(message) super().__init__(message)
if status_code is not None: if status_code is not None:
self.status_code = status_code self.status_code = status_code
@ -141,6 +137,25 @@ class PayloadTooLarge(SanicException):
status_code = 413 status_code = 413
class HeaderNotFound(SanicException):
status_code = 400
class ContentRangeError(SanicException):
status_code = 416
def __init__(self, message, content_range):
super().__init__(message)
self.headers = {
'Content-Type': 'text/plain',
"Content-Range": "bytes */%s" % (content_range.total,)
}
class InvalidRangeType(ContentRangeError):
pass
class Handler: class Handler:
handlers = None handlers = None
cached_handlers = None cached_handlers = None

128
sanic/handlers.py Normal file
View File

@ -0,0 +1,128 @@
import sys
from traceback import format_exc, extract_tb
from .exceptions import ContentRangeError
from .exceptions import INTERNAL_SERVER_ERROR_HTML, TRACEBACK_LINE_HTML
from .exceptions import SanicException, HeaderNotFound, InvalidRangeType
from .exceptions import TRACEBACK_STYLE, TRACEBACK_WRAPPER_HTML
from .log import log
from .response import text, html
class ErrorHandler:
handlers = None
def __init__(self):
self.handlers = {}
self.debug = False
def _render_traceback_html(self, exception, request):
exc_type, exc_value, tb = sys.exc_info()
frames = extract_tb(tb)
frame_html = []
for frame in frames:
frame_html.append(TRACEBACK_LINE_HTML.format(frame))
return TRACEBACK_WRAPPER_HTML.format(
style=TRACEBACK_STYLE,
exc_name=exc_type.__name__,
exc_value=exc_value,
frame_html=''.join(frame_html),
uri=request.url)
def add(self, exception, handler):
self.handlers[exception] = handler
def response(self, request, exception):
"""
Fetches and executes an exception handler and returns a response object
:param request: Request
:param exception: Exception to handle
:return: Response object
"""
handler = self.handlers.get(type(exception), self.default)
try:
response = handler(request=request, exception=exception)
except Exception:
log.error(format_exc())
if self.debug:
response_message = (
'Exception raised in exception handler "{}" '
'for uri: "{}"\n{}').format(
handler.__name__, request.url, format_exc())
log.error(response_message)
return text(response_message, 500)
else:
return text('An error occurred while handling an error', 500)
return response
def default(self, request, exception):
log.error(format_exc())
if issubclass(type(exception), SanicException):
return text(
'Error: {}'.format(exception),
status=getattr(exception, 'status_code', 500),
headers=getattr(exception, 'headers', dict())
)
elif self.debug:
html_output = self._render_traceback_html(exception, request)
response_message = (
'Exception occurred while handling uri: "{}"\n{}'.format(
request.url, format_exc()))
log.error(response_message)
return html(html_output, status=500)
else:
return html(INTERNAL_SERVER_ERROR_HTML, status=500)
class ContentRangeHandler:
"""
This class is for parsing the request header
"""
__slots__ = ('start', 'end', 'size', 'total', 'headers')
def __init__(self, request, stats):
self.total = stats.st_size
_range = request.headers.get('Range')
if _range is None:
raise HeaderNotFound('Range Header Not Found')
unit, _, value = tuple(map(str.strip, _range.partition('=')))
if unit != 'bytes':
raise InvalidRangeType(
'%s is not a valid Range Type' % (unit,), self)
start_b, _, end_b = tuple(map(str.strip, value.partition('-')))
try:
self.start = int(start_b) if start_b else None
except ValueError:
raise ContentRangeError(
'\'%s\' is invalid for Content Range' % (start_b,), self)
try:
self.end = int(end_b) if end_b else None
except ValueError:
raise ContentRangeError(
'\'%s\' is invalid for Content Range' % (end_b,), self)
if self.end is None:
if self.start is None:
raise ContentRangeError(
'Invalid for Content Range parameters', self)
else:
# this case represents `Content-Range: bytes 5-`
self.end = self.total
else:
if self.start is None:
# this case represents `Content-Range: bytes -5`
self.start = self.total - self.end
self.end = self.total
if self.start >= self.end:
raise ContentRangeError(
'Invalid for Content Range parameters', self)
self.size = self.end - self.start
self.headers = {
'Content-Range': "bytes %s-%s/%s" % (
self.start, self.end, self.total)}
def __bool__(self):
return self.size > 0

View File

@ -1,9 +1,10 @@
from aiofiles import open as open_async from collections import ChainMap
from mimetypes import guess_type from mimetypes import guess_type
from os import path from os import path
from ujson import dumps as json_dumps from ujson import dumps as json_dumps
from aiofiles import open as open_async
from .cookies import CookieJar from .cookies import CookieJar
COMMON_STATUS_CODES = { COMMON_STATUS_CODES = {
@ -97,21 +98,25 @@ class HTTPResponse:
def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None): def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None):
# This is all returned in a kind-of funky way # This is all returned in a kind-of funky way
# We tried to make this as fast as possible in pure python # We tried to make this as fast as possible in pure python
timeout_header = b'' default_header = dict()
if keep_alive and keep_alive_timeout: if keep_alive:
timeout_header = b'Keep-Alive: timeout=%d\r\n' % keep_alive_timeout if keep_alive_timeout:
default_header['Keep-Alive'] = keep_alive_timeout
default_header['Connection'] = 'keep-alive'
else:
default_header['Connection'] = 'close'
default_header['Content-Length'] = len(self.body)
default_header['Content-Type'] = self.content_type
headers = b'' headers = b''
if self.headers: for name, value in ChainMap(self.headers, default_header).items():
for name, value in self.headers.items():
try: try:
headers += ( headers += (
b'%b: %b\r\n' % (name.encode(), value.encode('utf-8'))) b'%b: %b\r\n' % (
name.encode(), value.encode('utf-8')))
except AttributeError: except AttributeError:
headers += ( headers += (
b'%b: %b\r\n' % ( b'%b: %b\r\n' % (
str(name).encode(), str(value).encode('utf-8'))) str(name).encode(), str(value).encode('utf-8')))
# Try to pull from the common codes first # Try to pull from the common codes first
# Speeds up response rate 6% over pulling from all # Speeds up response rate 6% over pulling from all
status = COMMON_STATUS_CODES.get(self.status) status = COMMON_STATUS_CODES.get(self.status)
@ -119,18 +124,11 @@ class HTTPResponse:
status = ALL_STATUS_CODES.get(self.status) status = ALL_STATUS_CODES.get(self.status)
return (b'HTTP/%b %d %b\r\n' return (b'HTTP/%b %d %b\r\n'
b'Content-Type: %b\r\n' b'%b\r\n'
b'Content-Length: %d\r\n'
b'Connection: %b\r\n'
b'%b%b\r\n'
b'%b') % ( b'%b') % (
version.encode(), version.encode(),
self.status, self.status,
status, status,
self.content_type.encode(),
len(self.body),
b'keep-alive' if keep_alive else b'close',
timeout_header,
headers, headers,
self.body self.body
) )
@ -148,7 +146,7 @@ def json(body, status=200, headers=None, **kwargs):
:param body: Response data to be serialized. :param body: Response data to be serialized.
:param status: Response code. :param status: Response code.
:param headers: Custom Headers. :param headers: Custom Headers.
:param \**kwargs: Remaining arguments that are passed to the json encoder. :param kwargs: Remaining arguments that are passed to the json encoder.
""" """
return HTTPResponse(json_dumps(body, **kwargs), headers=headers, return HTTPResponse(json_dumps(body, **kwargs), headers=headers,
status=status, content_type="application/json") status=status, content_type="application/json")
@ -176,16 +174,23 @@ def html(body, status=200, headers=None):
content_type="text/html; charset=utf-8") content_type="text/html; charset=utf-8")
async def file(location, mime_type=None, headers=None): async def file(location, mime_type=None, headers=None, _range=None):
""" """
Returns response object with file data. Returns response object with file data.
:param location: Location of file on system. :param location: Location of file on system.
:param mime_type: Specific mime_type. :param mime_type: Specific mime_type.
:param headers: Custom Headers. :param headers: Custom Headers.
:param _range:
""" """
filename = path.split(location)[-1] filename = path.split(location)[-1]
async with open_async(location, mode='rb') as _file: async with open_async(location, mode='rb') as _file:
if _range:
await _file.seek(_range.start)
out_stream = await _file.read(_range.size)
headers['Content-Range'] = 'bytes %s-%s/%s' % (
_range.start, _range.end, _range.total)
else:
out_stream = await _file.read() out_stream = await _file.read()
mime_type = mime_type or guess_type(filename)[0] or 'text/plain' mime_type = mime_type or guess_type(filename)[0] or 'text/plain'

View File

@ -1,17 +1,16 @@
import logging import logging
import re
import warnings
from asyncio import get_event_loop from asyncio import get_event_loop
from collections import deque from collections import deque
from functools import partial from functools import partial
from inspect import isawaitable, stack, getmodulename from inspect import isawaitable, stack, getmodulename
import re
from traceback import format_exc from traceback import format_exc
from urllib.parse import urlencode, urlunparse from urllib.parse import urlencode, urlunparse
import warnings
from .config import Config from .config import Config
from .constants import HTTP_METHODS from .constants import HTTP_METHODS
from .exceptions import Handler from .exceptions import Handler, ServerError, URLBuildError
from .exceptions import ServerError, URLBuildError
from .log import log from .log import log
from .response import HTTPResponse from .response import HTTPResponse
from .router import Router from .router import Router
@ -36,7 +35,7 @@ class Sanic:
name = getmodulename(frame_records[1]) name = getmodulename(frame_records[1])
self.name = name self.name = name
self.router = router or Router() self.router = router or Router()
self.error_handler = error_handler or Handler() self.error_handler = error_handler or ErrorHandler()
self.config = Config() self.config = Config()
self.request_middleware = deque() self.request_middleware = deque()
self.response_middleware = deque() self.response_middleware = deque()
@ -60,6 +59,7 @@ class Sanic:
:param uri: path of the URL :param uri: path of the URL
:param methods: list or tuple of methods allowed :param methods: list or tuple of methods allowed
:param host:
:return: decorated function :return: decorated function
""" """
@ -77,25 +77,25 @@ class Sanic:
# Shorthand method decorators # Shorthand method decorators
def get(self, uri, host=None): def get(self, uri, host=None):
return self.route(uri, methods=["GET"], host=host) return self.route(uri, methods=frozenset({"GET"}), host=host)
def post(self, uri, host=None): def post(self, uri, host=None):
return self.route(uri, methods=["POST"], host=host) return self.route(uri, methods=frozenset({"POST"}), host=host)
def put(self, uri, host=None): def put(self, uri, host=None):
return self.route(uri, methods=["PUT"], host=host) return self.route(uri, methods=frozenset({"PUT"}), host=host)
def head(self, uri, host=None): def head(self, uri, host=None):
return self.route(uri, methods=["HEAD"], host=host) return self.route(uri, methods=frozenset({"HEAD"}), host=host)
def options(self, uri, host=None): def options(self, uri, host=None):
return self.route(uri, methods=["OPTIONS"], host=host) return self.route(uri, methods=frozenset({"OPTIONS"}), host=host)
def patch(self, uri, host=None): def patch(self, uri, host=None):
return self.route(uri, methods=["PATCH"], host=host) return self.route(uri, methods=frozenset({"PATCH"}), host=host)
def delete(self, uri, host=None): def delete(self, uri, host=None):
return self.route(uri, methods=["DELETE"], host=host) return self.route(uri, methods=frozenset({"DELETE"}), host=host)
def add_route(self, handler, uri, methods=frozenset({'GET'}), host=None): def add_route(self, handler, uri, methods=frozenset({'GET'}), host=None):
""" """
@ -107,6 +107,7 @@ class Sanic:
:param uri: path of the URL :param uri: path of the URL
:param methods: list or tuple of methods allowed, these are overridden :param methods: list or tuple of methods allowed, these are overridden
if using a HTTPMethodView if using a HTTPMethodView
:param host:
:return: function or class instance :return: function or class instance
""" """
# Handle HTTPMethodView differently # Handle HTTPMethodView differently
@ -123,7 +124,7 @@ class Sanic:
""" """
Decorates a function to be registered as a handler for exceptions Decorates a function to be registered as a handler for exceptions
:param \*exceptions: exceptions :param exceptions: exceptions
:return: decorated function :return: decorated function
""" """
@ -158,13 +159,13 @@ class Sanic:
# Static Files # Static Files
def static(self, uri, file_or_directory, pattern='.+', def static(self, uri, file_or_directory, pattern='.+',
use_modified_since=True): use_modified_since=True, use_content_range=False):
""" """
Registers a root to serve files from. The input can either be a file Registers a root to serve files from. The input can either be a file
or a directory. See or a directory. See
""" """
static_register(self, uri, file_or_directory, pattern, static_register(self, uri, file_or_directory, pattern,
use_modified_since) use_modified_since, use_content_range)
def blueprint(self, blueprint, **options): def blueprint(self, blueprint, **options):
""" """
@ -388,6 +389,10 @@ class Sanic:
:param sock: Socket for the server to accept connections from :param sock: Socket for the server to accept connections from
:param workers: Number of processes :param workers: Number of processes
received before it is respected received before it is respected
:param loop:
:param backlog:
:param stop_event:
:param register_sys_signals:
:param protocol: Subclass of asyncio protocol class :param protocol: Subclass of asyncio protocol class
:return: Nothing :return: Nothing
""" """
@ -402,11 +407,9 @@ class Sanic:
serve(**server_settings) serve(**server_settings)
else: else:
serve_multiple(server_settings, workers, stop_event) serve_multiple(server_settings, workers, stop_event)
except Exception as e: except Exception as e:
log.exception( log.exception(
'Experienced exception while trying to serve') 'Experienced exception while trying to serve')
log.info("Server Stopped") log.info("Server Stopped")
def stop(self): def stop(self):

View File

@ -1,14 +1,19 @@
from aiofiles.os import stat from mimetypes import guess_type
from os import path from os import path
from re import sub from re import sub
from time import strftime, gmtime from time import strftime, gmtime
from urllib.parse import unquote from urllib.parse import unquote
from .exceptions import FileNotFound, InvalidUsage from aiofiles.os import stat
from .exceptions import FileNotFound, InvalidUsage, ContentRangeError
from .exceptions import HeaderNotFound
from .handlers import ContentRangeHandler
from .response import file, HTTPResponse from .response import file, HTTPResponse
def register(app, uri, file_or_directory, pattern, use_modified_since): def register(app, uri, file_or_directory, pattern,
use_modified_since, use_content_range):
# TODO: Though sanic is not a file server, I feel like we should at least # TODO: Though sanic is not a file server, I feel like we should at least
# make a good effort here. Modified-since is nice, but we could # make a good effort here. Modified-since is nice, but we could
# also look into etags, expires, and caching # also look into etags, expires, and caching
@ -23,8 +28,9 @@ def register(app, uri, file_or_directory, pattern, use_modified_since):
:param use_modified_since: If true, send file modified time, and return :param use_modified_since: If true, send file modified time, and return
not modified if the browser's matches the not modified if the browser's matches the
server's server's
:param use_content_range: If true, process header for range requests
and sends the file part that is requested
""" """
# If we're not trying to match a file directly, # If we're not trying to match a file directly,
# serve from the folder # serve from the folder
if not path.isfile(file_or_directory): if not path.isfile(file_or_directory):
@ -50,18 +56,41 @@ def register(app, uri, file_or_directory, pattern, use_modified_since):
headers = {} headers = {}
# Check if the client has been sent this file before # Check if the client has been sent this file before
# and it has not been modified since # and it has not been modified since
stats = None
if use_modified_since: if use_modified_since:
stats = await stat(file_path) stats = await stat(file_path)
modified_since = strftime('%a, %d %b %Y %H:%M:%S GMT', modified_since = strftime(
gmtime(stats.st_mtime)) '%a, %d %b %Y %H:%M:%S GMT', gmtime(stats.st_mtime))
if request.headers.get('If-Modified-Since') == modified_since: if request.headers.get('If-Modified-Since') == modified_since:
return HTTPResponse(status=304) return HTTPResponse(status=304)
headers['Last-Modified'] = modified_since headers['Last-Modified'] = modified_since
_range = None
return await file(file_path, headers=headers) if use_content_range:
except: _range = None
if not stats:
stats = await stat(file_path)
headers['Accept-Ranges'] = 'bytes'
headers['Content-Length'] = str(stats.st_size)
if request.method != 'HEAD':
try:
_range = ContentRangeHandler(request, stats)
except HeaderNotFound:
pass
else:
del headers['Content-Length']
for key, value in _range.headers.items():
headers[key] = value
if request.method == 'HEAD':
return HTTPResponse(
headers=headers,
content_type=guess_type(file_path)[0] or 'text/plain')
else:
return await file(file_path, headers=headers, _range=_range)
except ContentRangeError:
raise
except Exception:
raise FileNotFound('File not found', raise FileNotFound('File not found',
path=file_or_directory, path=file_or_directory,
relative_url=file_uri) relative_url=file_uri)
app.route(uri, methods=['GET'])(_handler) app.route(uri, methods=['GET', 'HEAD'])(_handler)

View File

@ -1,11 +1,13 @@
from json import loads as json_loads, dumps as json_dumps from json import loads as json_loads, dumps as json_dumps
from sanic import Sanic
from sanic.response import json, text, redirect
from sanic.utils import sanic_endpoint_test
from sanic.exceptions import ServerError
import pytest import pytest
from sanic import Sanic
from sanic.exceptions import ServerError
from sanic.response import json, text, redirect
from sanic.utils import sanic_endpoint_test
# ------------------------------------------------------------ # # ------------------------------------------------------------ #
# GET # GET
# ------------------------------------------------------------ # # ------------------------------------------------------------ #
@ -112,7 +114,8 @@ def test_query_string():
async def handler(request): async def handler(request):
return text('OK') return text('OK')
request, response = sanic_endpoint_test(app, params=[("test1", "1"), ("test2", "false"), ("test2", "true")]) request, response = sanic_endpoint_test(
app, params=[("test1", "1"), ("test2", "false"), ("test2", "true")])
assert request.args.get('test1') == '1' assert request.args.get('test1') == '1'
assert request.args.get('test2') == 'false' assert request.args.get('test2') == 'false'
@ -150,7 +153,8 @@ def test_post_json():
payload = {'test': 'OK'} payload = {'test': 'OK'}
headers = {'content-type': 'application/json'} headers = {'content-type': 'application/json'}
request, response = sanic_endpoint_test(app, data=json_dumps(payload), headers=headers) request, response = sanic_endpoint_test(
app, data=json_dumps(payload), headers=headers)
assert request.json.get('test') == 'OK' assert request.json.get('test') == 'OK'
assert response.text == 'OK' assert response.text == 'OK'

View File

@ -48,3 +48,112 @@ def test_static_directory(file_name, base_uri, static_file_directory):
app, uri='{}/{}'.format(base_uri, file_name)) app, uri='{}/{}'.format(base_uri, file_name))
assert response.status == 200 assert response.status == 200
assert response.body == get_file_content(static_file_directory, file_name) assert response.body == get_file_content(static_file_directory, file_name)
@pytest.mark.parametrize('file_name', ['test.file', 'decode me.txt'])
def test_static_head_request(
file_name, static_file_content, static_file_directory):
app = Sanic('test_static')
app.static(
'/testing.file', get_file_path(static_file_directory, file_name),
use_content_range=True)
request, response = sanic_endpoint_test(
app, uri='/testing.file', method='head')
assert response.status == 200
assert 'Accept-Ranges' in response.headers
assert 'Content-Length' in response.headers
assert int(response.headers['Content-Length']) == len(get_file_content(static_file_directory, file_name))
@pytest.mark.parametrize('file_name', ['test.file', 'decode me.txt'])
def test_static_content_range_correct(
file_name, static_file_content, static_file_directory):
app = Sanic('test_static')
app.static(
'/testing.file', get_file_path(static_file_directory, file_name),
use_content_range=True)
headers = {
'Range': 'bytes=12-19'
}
request, response = sanic_endpoint_test(
app, uri='/testing.file', headers=headers)
assert response.status == 200
assert 'Content-Length' in response.headers
assert 'Content-Range' in response.headers
static_content = bytes(get_file_content(static_file_directory, file_name))[12:19]
assert int(response.headers['Content-Length']) == len(get_file_content(static_file_directory, file_name))
assert response.body == get_file_content(static_file_directory, file_name)
@pytest.mark.parametrize('file_name', ['test.file', 'decode me.txt'])
def test_static_content_range_front(
file_name, static_file_content, static_file_directory):
app = Sanic('test_static')
app.static(
'/testing.file', get_file_path(static_file_directory, file_name),
use_content_range=True)
headers = {
'Range': 'bytes=12-'
}
request, response = sanic_endpoint_test(
app, uri='/testing.file', headers=headers)
assert response.status == 200
assert 'Content-Length' in response.headers
assert 'Content-Range' in response.headers
static_content = bytes(get_file_content(static_file_directory, file_name))[12:]
assert int(response.headers['Content-Length']) == len(get_file_content(static_file_directory, file_name))
assert response.body == get_file_content(static_file_directory, file_name)
@pytest.mark.parametrize('file_name', ['test.file', 'decode me.txt'])
def test_static_content_range_back(
file_name, static_file_content, static_file_directory):
app = Sanic('test_static')
app.static(
'/testing.file', get_file_path(static_file_directory, file_name),
use_content_range=True)
headers = {
'Range': 'bytes=-12'
}
request, response = sanic_endpoint_test(
app, uri='/testing.file', headers=headers)
assert response.status == 200
assert 'Content-Length' in response.headers
assert 'Content-Range' in response.headers
static_content = bytes(get_file_content(static_file_directory, file_name))[-12:]
assert int(response.headers['Content-Length']) == len(get_file_content(static_file_directory, file_name))
assert response.body == get_file_content(static_file_directory, file_name)
@pytest.mark.parametrize('file_name', ['test.file', 'decode me.txt'])
def test_static_content_range_empty(
file_name, static_file_content, static_file_directory):
app = Sanic('test_static')
app.static(
'/testing.file', get_file_path(static_file_directory, file_name),
use_content_range=True)
request, response = sanic_endpoint_test(app, uri='/testing.file')
assert response.status == 200
assert 'Content-Length' in response.headers
assert 'Content-Range' not in response.headers
assert int(response.headers['Content-Length']) == len(get_file_content(static_file_directory, file_name))
assert response.body == bytes(get_file_content(static_file_directory, file_name))
@pytest.mark.parametrize('file_name', ['test.file', 'decode me.txt'])
def test_static_content_range_error(static_file_path, static_file_content):
app = Sanic('test_static')
app.static(
'/testing.file', get_file_path(static_file_directory, file_name),
use_content_range=True)
headers = {
'Range': 'bytes=1-0'
}
request, response = sanic_endpoint_test(
app, uri='/testing.file', headers=headers)
assert response.status == 416
assert 'Content-Length' in response.headers
assert 'Content-Range' in response.headers
assert response.headers['Content-Range'] == "bytes */%s" % (
len(static_file_content),)