Merge branch 'master' into unauthorized-exception

This commit is contained in:
Raphael Deem
2017-07-21 01:54:45 -07:00
committed by GitHub
29 changed files with 1316 additions and 222 deletions

View File

@@ -113,7 +113,7 @@ class Sanic:
# Decorator
def route(self, uri, methods=frozenset({'GET'}), host=None,
strict_slashes=False, stream=False):
strict_slashes=False, stream=False, version=None):
"""Decorate a function to be registered as a route
:param uri: path of the URL
@@ -136,42 +136,49 @@ class Sanic:
if stream:
handler.is_stream = stream
self.router.add(uri=uri, methods=methods, handler=handler,
host=host, strict_slashes=strict_slashes)
host=host, strict_slashes=strict_slashes,
version=version)
return handler
return response
# Shorthand method decorators
def get(self, uri, host=None, strict_slashes=False):
def get(self, uri, host=None, strict_slashes=False, version=None):
return self.route(uri, methods=frozenset({"GET"}), host=host,
strict_slashes=strict_slashes)
strict_slashes=strict_slashes, version=version)
def post(self, uri, host=None, strict_slashes=False, stream=False):
def post(self, uri, host=None, strict_slashes=False, stream=False,
version=None):
return self.route(uri, methods=frozenset({"POST"}), host=host,
strict_slashes=strict_slashes, stream=stream)
strict_slashes=strict_slashes, stream=stream,
version=version)
def put(self, uri, host=None, strict_slashes=False, stream=False):
def put(self, uri, host=None, strict_slashes=False, stream=False,
version=None):
return self.route(uri, methods=frozenset({"PUT"}), host=host,
strict_slashes=strict_slashes, stream=stream)
strict_slashes=strict_slashes, stream=stream,
version=version)
def head(self, uri, host=None, strict_slashes=False):
def head(self, uri, host=None, strict_slashes=False, version=None):
return self.route(uri, methods=frozenset({"HEAD"}), host=host,
strict_slashes=strict_slashes)
strict_slashes=strict_slashes, version=version)
def options(self, uri, host=None, strict_slashes=False):
def options(self, uri, host=None, strict_slashes=False, version=None):
return self.route(uri, methods=frozenset({"OPTIONS"}), host=host,
strict_slashes=strict_slashes)
strict_slashes=strict_slashes, version=version)
def patch(self, uri, host=None, strict_slashes=False, stream=False):
def patch(self, uri, host=None, strict_slashes=False, stream=False,
version=None):
return self.route(uri, methods=frozenset({"PATCH"}), host=host,
strict_slashes=strict_slashes, stream=stream)
strict_slashes=strict_slashes, stream=stream,
version=version)
def delete(self, uri, host=None, strict_slashes=False):
def delete(self, uri, host=None, strict_slashes=False, version=None):
return self.route(uri, methods=frozenset({"DELETE"}), host=host,
strict_slashes=strict_slashes)
strict_slashes=strict_slashes, version=version)
def add_route(self, handler, uri, methods=frozenset({'GET'}), host=None,
strict_slashes=False):
strict_slashes=False, version=None):
"""A helper method to register class instance or
functions as a handler to the application url
routes.
@@ -204,7 +211,8 @@ class Sanic:
break
self.route(uri=uri, methods=methods, host=host,
strict_slashes=strict_slashes, stream=stream)(handler)
strict_slashes=strict_slashes, stream=stream,
version=version)(handler)
return handler
# Decorator
@@ -701,7 +709,8 @@ class Sanic:
'backlog': backlog,
'has_log': has_log,
'websocket_max_size': self.config.WEBSOCKET_MAX_SIZE,
'websocket_max_queue': self.config.WEBSOCKET_MAX_QUEUE
'websocket_max_queue': self.config.WEBSOCKET_MAX_QUEUE,
'graceful_shutdown_timeout': self.config.GRACEFUL_SHUTDOWN_TIMEOUT
}
# -------------------------------------------- #

View File

@@ -4,8 +4,8 @@ from sanic.constants import HTTP_METHODS
from sanic.views import CompositionView
FutureRoute = namedtuple('Route',
['handler', 'uri', 'methods',
'host', 'strict_slashes', 'stream'])
['handler', 'uri', 'methods', 'host',
'strict_slashes', 'stream', 'version'])
FutureListener = namedtuple('Listener', ['handler', 'uri', 'methods', 'host'])
FutureMiddleware = namedtuple('Route', ['middleware', 'args', 'kwargs'])
FutureException = namedtuple('Route', ['handler', 'args', 'kwargs'])
@@ -14,7 +14,7 @@ FutureStatic = namedtuple('Route',
class Blueprint:
def __init__(self, name, url_prefix=None, host=None):
def __init__(self, name, url_prefix=None, host=None, version=None):
"""Create a new blueprint
:param name: unique name of the blueprint
@@ -30,6 +30,7 @@ class Blueprint:
self.listeners = defaultdict(list)
self.middlewares = []
self.statics = []
self.version = version
def register(self, app, options):
"""Register the blueprint to the sanic app."""
@@ -43,12 +44,16 @@ class Blueprint:
future.handler.__blueprintname__ = self.name
# Prepend the blueprint URI prefix if available
uri = url_prefix + future.uri if url_prefix else future.uri
version = future.version or self.version
app.route(
uri=uri[1:] if uri.startswith('//') else uri,
methods=future.methods,
host=future.host or self.host,
strict_slashes=future.strict_slashes,
stream=future.stream
stream=future.stream,
version=version
)(future.handler)
for future in self.websocket_routes:
@@ -89,7 +94,7 @@ class Blueprint:
app.listener(event)(listener)
def route(self, uri, methods=frozenset({'GET'}), host=None,
strict_slashes=False, stream=False):
strict_slashes=False, stream=False, version=None):
"""Create a blueprint route from a decorated function.
:param uri: endpoint at which the route will be accessible.
@@ -97,13 +102,13 @@ class Blueprint:
"""
def decorator(handler):
route = FutureRoute(
handler, uri, methods, host, strict_slashes, stream)
handler, uri, methods, host, strict_slashes, stream, version)
self.routes.append(route)
return handler
return decorator
def add_route(self, handler, uri, methods=frozenset({'GET'}), host=None,
strict_slashes=False):
strict_slashes=False, version=None):
"""Create a blueprint route from a function.
:param handler: function for handling uri requests. Accepts function,
@@ -125,21 +130,22 @@ class Blueprint:
methods = handler.handlers.keys()
self.route(uri=uri, methods=methods, host=host,
strict_slashes=strict_slashes)(handler)
strict_slashes=strict_slashes, version=version)(handler)
return handler
def websocket(self, uri, host=None, strict_slashes=False):
def websocket(self, uri, host=None, strict_slashes=False, version=None):
"""Create a blueprint websocket route from a decorated function.
:param uri: endpoint at which the route will be accessible.
"""
def decorator(handler):
route = FutureRoute(handler, uri, [], host, strict_slashes, False)
route = FutureRoute(handler, uri, [], host, strict_slashes,
False, version)
self.websocket_routes.append(route)
return handler
return decorator
def add_websocket_route(self, handler, uri, host=None):
def add_websocket_route(self, handler, uri, host=None, version=None):
"""Create a blueprint websocket route from a function.
:param handler: function for handling uri requests. Accepts function,
@@ -147,7 +153,7 @@ class Blueprint:
:param uri: endpoint at which the route will be accessible.
:return: function or class instance
"""
self.websocket(uri=uri, host=host)(handler)
self.websocket(uri=uri, host=host, version=version)(handler)
return handler
def listener(self, event):
@@ -193,30 +199,36 @@ class Blueprint:
self.statics.append(static)
# Shorthand method decorators
def get(self, uri, host=None, strict_slashes=False):
def get(self, uri, host=None, strict_slashes=False, version=None):
return self.route(uri, methods=["GET"], host=host,
strict_slashes=strict_slashes)
strict_slashes=strict_slashes, version=version)
def post(self, uri, host=None, strict_slashes=False, stream=False):
def post(self, uri, host=None, strict_slashes=False, stream=False,
version=None):
return self.route(uri, methods=["POST"], host=host,
strict_slashes=strict_slashes, stream=stream)
strict_slashes=strict_slashes, stream=stream,
version=version)
def put(self, uri, host=None, strict_slashes=False, stream=False):
def put(self, uri, host=None, strict_slashes=False, stream=False,
version=None):
return self.route(uri, methods=["PUT"], host=host,
strict_slashes=strict_slashes, stream=stream)
strict_slashes=strict_slashes, stream=stream,
version=version)
def head(self, uri, host=None, strict_slashes=False):
def head(self, uri, host=None, strict_slashes=False, version=None):
return self.route(uri, methods=["HEAD"], host=host,
strict_slashes=strict_slashes)
strict_slashes=strict_slashes, version=version)
def options(self, uri, host=None, strict_slashes=False):
def options(self, uri, host=None, strict_slashes=False, version=None):
return self.route(uri, methods=["OPTIONS"], host=host,
strict_slashes=strict_slashes)
strict_slashes=strict_slashes, version=version)
def patch(self, uri, host=None, strict_slashes=False, stream=False):
def patch(self, uri, host=None, strict_slashes=False, stream=False,
version=None):
return self.route(uri, methods=["PATCH"], host=host,
strict_slashes=strict_slashes, stream=stream)
strict_slashes=strict_slashes, stream=stream,
version=version)
def delete(self, uri, host=None, strict_slashes=False):
def delete(self, uri, host=None, strict_slashes=False, version=None):
return self.route(uri, methods=["DELETE"], host=host,
strict_slashes=strict_slashes)
strict_slashes=strict_slashes, version=version)

View File

@@ -12,11 +12,12 @@ _address_dict = {
'Windows': ('localhost', 514),
'Darwin': '/var/run/syslog',
'Linux': '/dev/log',
'FreeBSD': '/dev/log'
'FreeBSD': '/var/run/log'
}
LOGGING = {
'version': 1,
'disable_existing_loggers': False,
'filters': {
'accessFilter': {
'()': DefaultFilter,
@@ -127,6 +128,7 @@ class Config(dict):
self.KEEP_ALIVE = keep_alive
self.WEBSOCKET_MAX_SIZE = 2 ** 20 # 1 megabytes
self.WEBSOCKET_MAX_QUEUE = 32
self.GRACEFUL_SHUTDOWN_TIMEOUT = 15.0 # 15 sec
if load_env:
self.load_environment_vars()
@@ -195,10 +197,16 @@ class Config(dict):
def load_environment_vars(self):
"""
Looks for any SANIC_ prefixed environment variables and applies
Looks for any ``SANIC_`` prefixed environment variables and applies
them to the configuration if present.
"""
for k, v in os.environ.items():
if k.startswith(SANIC_PREFIX):
_, config_key = k.split(SANIC_PREFIX, 1)
self[config_key] = v
try:
self[config_key] = int(v)
except ValueError:
try:
self[config_key] = float(v)
except ValueError:
self[config_key] = v

View File

@@ -194,6 +194,11 @@ class ContentRangeError(SanicException):
}
@add_status_code(403)
class Forbidden(SanicException):
pass
class InvalidRangeType(ContentRangeError):
pass
@@ -205,8 +210,8 @@ class Unauthorized(SanicException):
:param scheme: Name of the authentication scheme to be used.
:param challenge: A dict containing values to add to the WWW-Authenticate
header that is generated. This is especially useful when dealing with the
Digest scheme. (optional)
header that is generated. This is especially useful when dealing with
the Digest scheme. (optional)
Examples::
@@ -227,7 +232,6 @@ class Unauthorized(SanicException):
# With a Bearer auth-scheme, realm is optional:
challenge = {"realm": "Restricted Area"}
raise Unauthorized("Auth required.", "Bearer", challenge)
"""
pass
@@ -249,9 +253,10 @@ def abort(status_code, message=None):
"""
Raise an exception based on SanicException. Returns the HTTP response
message appropriate for the given status code, unless provided.
:param status_code: The HTTP status code to return.
:param message: The HTTP response body. Defaults to the messages
in response.py for the given status code.
in response.py for the given status code.
"""
if message is None:
message = COMMON_STATUS_CODES.get(status_code,

View File

@@ -45,7 +45,7 @@ class Request(dict):
__slots__ = (
'app', 'headers', 'version', 'method', '_cookies', 'transport',
'body', 'parsed_json', 'parsed_args', 'parsed_form', 'parsed_files',
'_ip', '_parsed_url', 'uri_template', 'stream'
'_ip', '_parsed_url', 'uri_template', 'stream', '_remote_addr'
)
def __init__(self, url_bytes, headers, version, method, transport):
@@ -86,11 +86,15 @@ class Request(dict):
:return: token related to request
"""
prefixes = ('Bearer', 'Token')
auth_header = self.headers.get('Authorization')
if auth_header is not None and 'Token ' in auth_header:
return auth_header.partition('Token ')[-1]
else:
return auth_header
if auth_header is not None:
for prefix in prefixes:
if prefix in auth_header:
return auth_header.partition(prefix)[-1].strip()
return auth_header
@property
def form(self):
@@ -138,7 +142,7 @@ class Request(dict):
@property
def cookies(self):
if self._cookies is None:
cookie = self.headers.get('Cookie') or self.headers.get('cookie')
cookie = self.headers.get('Cookie')
if cookie is not None:
cookies = SimpleCookie()
cookies.load(cookie)
@@ -155,6 +159,25 @@ class Request(dict):
(None, None))
return self._ip
@property
def remote_addr(self):
"""Attempt to return the original client ip based on X-Forwarded-For.
:return: original client ip.
"""
if not hasattr(self, '_remote_addr'):
forwarded_for = self.headers.get('X-Forwarded-For', '').split(',')
remote_addrs = [
addr for addr in [
addr.strip() for addr in forwarded_for
] if addr
]
if len(remote_addrs) > 0:
self._remote_addr = remote_addrs[0]
else:
self._remote_addr = ''
return self._remote_addr
@property
def scheme(self):
if self.app.websocket_enabled \
@@ -234,15 +257,15 @@ def parse_multipart_form(body, boundary):
break
colon_index = form_line.index(':')
form_header_field = form_line[0:colon_index]
form_header_field = form_line[0:colon_index].lower()
form_header_value, form_parameters = parse_header(
form_line[colon_index + 2:])
if form_header_field == 'Content-Disposition':
if form_header_field == 'content-disposition':
if 'filename' in form_parameters:
file_name = form_parameters['filename']
field_name = form_parameters.get('name')
elif form_header_field == 'Content-Type':
elif form_header_field == 'content-type':
file_type = form_header_value
post_data = form_part[line_index:-4]

View File

@@ -237,6 +237,7 @@ def json(body, status=200, headers=None,
content_type="application/json", **kwargs):
"""
Returns response object with body in json format.
:param body: Response data to be serialized.
:param status: Response code.
:param headers: Custom Headers.
@@ -250,6 +251,7 @@ def text(body, status=200, headers=None,
content_type="text/plain; charset=utf-8"):
"""
Returns response object with body in text format.
:param body: Response data to be encoded.
:param status: Response code.
:param headers: Custom Headers.
@@ -264,6 +266,7 @@ def raw(body, status=200, headers=None,
content_type="application/octet-stream"):
"""
Returns response object without encoding the body.
:param body: Response data.
:param status: Response code.
:param headers: Custom Headers.
@@ -276,6 +279,7 @@ def raw(body, status=200, headers=None,
def html(body, status=200, headers=None):
"""
Returns response object with body in html format.
:param body: Response data to be encoded.
:param status: Response code.
:param headers: Custom Headers.

View File

@@ -98,8 +98,25 @@ class Router:
return name, _type, pattern
def add(self, uri, methods, handler, host=None, strict_slashes=False):
def add(self, uri, methods, handler, host=None, strict_slashes=False,
version=None):
"""Add a handler to the route list
:param uri: path to match
:param methods: sequence of accepted method names. If none are
provided, any method is allowed
:param handler: request handler function.
When executed, it should provide a response object.
:param strict_slashes: strict to trailing slash
:param version: current version of the route or blueprint. See
docs for further details.
:return: Nothing
"""
if version is not None:
if uri.startswith('/'):
uri = "/".join(["/v{}".format(str(version)), uri[1:]])
else:
uri = "/".join(["/v{}".format(str(version)), uri])
# add regular version
self._add(uri, methods, handler, host)

View File

@@ -75,7 +75,7 @@ class HttpProtocol(asyncio.Protocol):
signal=Signal(), connections=set(), request_timeout=60,
request_max_size=None, request_class=None, has_log=True,
keep_alive=True, is_request_stream=False, router=None,
**kwargs):
state=None, debug=False, **kwargs):
self.loop = loop
self.transport = None
self.request = None
@@ -99,12 +99,18 @@ class HttpProtocol(asyncio.Protocol):
self._request_handler_task = None
self._request_stream_task = None
self._keep_alive = keep_alive
self._header_fragment = b''
self.state = state if state else {}
if 'requests_count' not in self.state:
self.state['requests_count'] = 0
self._debug = debug
@property
def keep_alive(self):
return (self._keep_alive
and not self.signal.stopped
and self.parser.should_keep_alive())
return (
self._keep_alive and
not self.signal.stopped and
self.parser.should_keep_alive())
# -------------------------------------------- #
# Connection
@@ -154,22 +160,39 @@ class HttpProtocol(asyncio.Protocol):
self.headers = []
self.parser = HttpRequestParser(self)
# requests count
self.state['requests_count'] = self.state['requests_count'] + 1
# Parse request chunk or close connection
try:
self.parser.feed_data(data)
except HttpParserError:
exception = InvalidUsage('Bad Request')
message = 'Bad Request'
if self._debug:
message += '\n' + traceback.format_exc()
exception = InvalidUsage(message)
self.write_error(exception)
def on_url(self, url):
self.url = url
if not self.url:
self.url = url
else:
self.url += url
def on_header(self, name, value):
if name == b'Content-Length' and int(value) > self.request_max_size:
exception = PayloadTooLarge('Payload Too Large')
self.write_error(exception)
self._header_fragment += name
self.headers.append((name.decode().casefold(), value.decode()))
if value is not None:
if self._header_fragment == b'Content-Length' \
and int(value) > self.request_max_size:
exception = PayloadTooLarge('Payload Too Large')
self.write_error(exception)
self.headers.append(
(self._header_fragment.decode().casefold(),
value.decode()))
self._header_fragment = b''
def on_headers_complete(self):
self.request = self.request_class(
@@ -357,6 +380,14 @@ class HttpProtocol(asyncio.Protocol):
return True
return False
def close(self):
"""
Force close the connection.
"""
if self.transport is not None:
self.transport.close()
self.transport = None
def update_current_time(loop):
"""Cache the current time, since it is needed at the end of every
@@ -389,7 +420,8 @@ def serve(host, port, request_handler, error_handler, before_start=None,
register_sys_signals=True, run_async=False, connections=None,
signal=Signal(), request_class=None, has_log=True, keep_alive=True,
is_request_stream=False, router=None, websocket_max_size=None,
websocket_max_queue=None):
websocket_max_queue=None, state=None,
graceful_shutdown_timeout=15.0):
"""Start asynchronous HTTP Server on an individual process.
:param host: Address to host on
@@ -427,8 +459,6 @@ def serve(host, port, request_handler, error_handler, before_start=None,
if debug:
loop.set_debug(debug)
trigger_events(before_start, loop)
connections = connections if connections is not None else set()
server = partial(
protocol,
@@ -445,7 +475,9 @@ def serve(host, port, request_handler, error_handler, before_start=None,
is_request_stream=is_request_stream,
router=router,
websocket_max_size=websocket_max_size,
websocket_max_queue=websocket_max_queue
websocket_max_queue=websocket_max_queue,
state=state,
debug=debug,
)
server_coroutine = loop.create_server(
@@ -457,6 +489,7 @@ def serve(host, port, request_handler, error_handler, before_start=None,
sock=sock,
backlog=backlog
)
# Instead of pulling time at the end of every request,
# pull it once per minute
loop.call_soon(partial(update_current_time, loop))
@@ -464,6 +497,8 @@ def serve(host, port, request_handler, error_handler, before_start=None,
if run_async:
return server_coroutine
trigger_events(before_start, loop)
try:
http_server = loop.run_until_complete(server_coroutine)
except:
@@ -499,8 +534,26 @@ def serve(host, port, request_handler, error_handler, before_start=None,
for connection in connections:
connection.close_if_idle()
while connections:
# Gracefully shutdown timeout.
# We should provide graceful_shutdown_timeout,
# instead of letting connection hangs forever.
# Let's roughly calcucate time.
start_shutdown = 0
while connections and (start_shutdown < graceful_shutdown_timeout):
loop.run_until_complete(asyncio.sleep(0.1))
start_shutdown = start_shutdown + 0.1
# Force close non-idle connection after waiting for
# graceful_shutdown_timeout
coros = []
for conn in connections:
if hasattr(conn, "websocket") and conn.websocket:
coros.append(conn.websocket.close_connection(force=True))
else:
conn.close()
_shutdown = asyncio.gather(*coros, loop=loop)
loop.run_until_complete(_shutdown)
trigger_events(after_stop, loop)

View File

@@ -3,6 +3,7 @@ import sys
import signal
import asyncio
import logging
import traceback
try:
import ssl
@@ -29,7 +30,7 @@ class GunicornWorker(base.Worker):
self.ssl_context = self._create_ssl_context(cfg)
else:
self.ssl_context = None
self.servers = []
self.servers = {}
self.connections = set()
self.exit_code = 0
self.signal = Signal()
@@ -69,10 +70,16 @@ class GunicornWorker(base.Worker):
trigger_events(self._server_settings.get('before_stop', []),
self.loop)
self.loop.run_until_complete(self.close())
except:
traceback.print_exc()
finally:
trigger_events(self._server_settings.get('after_stop', []),
self.loop)
self.loop.close()
try:
trigger_events(self._server_settings.get('after_stop', []),
self.loop)
except:
traceback.print_exc()
finally:
self.loop.close()
sys.exit(self.exit_code)
@@ -91,16 +98,37 @@ class GunicornWorker(base.Worker):
for conn in self.connections:
conn.close_if_idle()
while self.connections:
# gracefully shutdown timeout
start_shutdown = 0
graceful_shutdown_timeout = self.cfg.graceful_timeout
while self.connections and \
(start_shutdown < graceful_shutdown_timeout):
await asyncio.sleep(0.1)
start_shutdown = start_shutdown + 0.1
# Force close non-idle connection after waiting for
# graceful_shutdown_timeout
coros = []
for conn in self.connections:
if hasattr(conn, "websocket") and conn.websocket:
coros.append(conn.websocket.close_connection(force=True))
else:
conn.close()
_shutdown = asyncio.gather(*coros, loop=self.loop)
await _shutdown
async def _run(self):
for sock in self.sockets:
self.servers.append(await serve(
state = dict(requests_count=0)
self._server_settings["host"] = None
self._server_settings["port"] = None
server = await serve(
sock=sock,
connections=self.connections,
state=state,
**self._server_settings
))
)
self.servers[server] = state
async def _check_alive(self):
# If our parent changed then we shut down.
@@ -109,7 +137,15 @@ class GunicornWorker(base.Worker):
while self.alive:
self.notify()
if pid == os.getpid() and self.ppid != os.getppid():
req_count = sum(
self.servers[srv]["requests_count"] for srv in self.servers
)
if self.max_requests and req_count > self.max_requests:
self.alive = False
self.log.info(
"Max requests exceeded, shutting down: %s", self
)
elif pid == os.getpid() and self.ppid != os.getppid():
self.alive = False
self.log.info("Parent changed, shutting down: %s", self)
else:
@@ -166,3 +202,4 @@ class GunicornWorker(base.Worker):
self.alive = False
self.exit_code = 1
self.cfg.worker_abort(self)
sys.exit(1)