Merge branch 'master' into unauthorized-exception
This commit is contained in:
47
sanic/app.py
47
sanic/app.py
@@ -113,7 +113,7 @@ class Sanic:
|
||||
|
||||
# Decorator
|
||||
def route(self, uri, methods=frozenset({'GET'}), host=None,
|
||||
strict_slashes=False, stream=False):
|
||||
strict_slashes=False, stream=False, version=None):
|
||||
"""Decorate a function to be registered as a route
|
||||
|
||||
:param uri: path of the URL
|
||||
@@ -136,42 +136,49 @@ class Sanic:
|
||||
if stream:
|
||||
handler.is_stream = stream
|
||||
self.router.add(uri=uri, methods=methods, handler=handler,
|
||||
host=host, strict_slashes=strict_slashes)
|
||||
host=host, strict_slashes=strict_slashes,
|
||||
version=version)
|
||||
return handler
|
||||
|
||||
return response
|
||||
|
||||
# Shorthand method decorators
|
||||
def get(self, uri, host=None, strict_slashes=False):
|
||||
def get(self, uri, host=None, strict_slashes=False, version=None):
|
||||
return self.route(uri, methods=frozenset({"GET"}), host=host,
|
||||
strict_slashes=strict_slashes)
|
||||
strict_slashes=strict_slashes, version=version)
|
||||
|
||||
def post(self, uri, host=None, strict_slashes=False, stream=False):
|
||||
def post(self, uri, host=None, strict_slashes=False, stream=False,
|
||||
version=None):
|
||||
return self.route(uri, methods=frozenset({"POST"}), host=host,
|
||||
strict_slashes=strict_slashes, stream=stream)
|
||||
strict_slashes=strict_slashes, stream=stream,
|
||||
version=version)
|
||||
|
||||
def put(self, uri, host=None, strict_slashes=False, stream=False):
|
||||
def put(self, uri, host=None, strict_slashes=False, stream=False,
|
||||
version=None):
|
||||
return self.route(uri, methods=frozenset({"PUT"}), host=host,
|
||||
strict_slashes=strict_slashes, stream=stream)
|
||||
strict_slashes=strict_slashes, stream=stream,
|
||||
version=version)
|
||||
|
||||
def head(self, uri, host=None, strict_slashes=False):
|
||||
def head(self, uri, host=None, strict_slashes=False, version=None):
|
||||
return self.route(uri, methods=frozenset({"HEAD"}), host=host,
|
||||
strict_slashes=strict_slashes)
|
||||
strict_slashes=strict_slashes, version=version)
|
||||
|
||||
def options(self, uri, host=None, strict_slashes=False):
|
||||
def options(self, uri, host=None, strict_slashes=False, version=None):
|
||||
return self.route(uri, methods=frozenset({"OPTIONS"}), host=host,
|
||||
strict_slashes=strict_slashes)
|
||||
strict_slashes=strict_slashes, version=version)
|
||||
|
||||
def patch(self, uri, host=None, strict_slashes=False, stream=False):
|
||||
def patch(self, uri, host=None, strict_slashes=False, stream=False,
|
||||
version=None):
|
||||
return self.route(uri, methods=frozenset({"PATCH"}), host=host,
|
||||
strict_slashes=strict_slashes, stream=stream)
|
||||
strict_slashes=strict_slashes, stream=stream,
|
||||
version=version)
|
||||
|
||||
def delete(self, uri, host=None, strict_slashes=False):
|
||||
def delete(self, uri, host=None, strict_slashes=False, version=None):
|
||||
return self.route(uri, methods=frozenset({"DELETE"}), host=host,
|
||||
strict_slashes=strict_slashes)
|
||||
strict_slashes=strict_slashes, version=version)
|
||||
|
||||
def add_route(self, handler, uri, methods=frozenset({'GET'}), host=None,
|
||||
strict_slashes=False):
|
||||
strict_slashes=False, version=None):
|
||||
"""A helper method to register class instance or
|
||||
functions as a handler to the application url
|
||||
routes.
|
||||
@@ -204,7 +211,8 @@ class Sanic:
|
||||
break
|
||||
|
||||
self.route(uri=uri, methods=methods, host=host,
|
||||
strict_slashes=strict_slashes, stream=stream)(handler)
|
||||
strict_slashes=strict_slashes, stream=stream,
|
||||
version=version)(handler)
|
||||
return handler
|
||||
|
||||
# Decorator
|
||||
@@ -701,7 +709,8 @@ class Sanic:
|
||||
'backlog': backlog,
|
||||
'has_log': has_log,
|
||||
'websocket_max_size': self.config.WEBSOCKET_MAX_SIZE,
|
||||
'websocket_max_queue': self.config.WEBSOCKET_MAX_QUEUE
|
||||
'websocket_max_queue': self.config.WEBSOCKET_MAX_QUEUE,
|
||||
'graceful_shutdown_timeout': self.config.GRACEFUL_SHUTDOWN_TIMEOUT
|
||||
}
|
||||
|
||||
# -------------------------------------------- #
|
||||
|
||||
@@ -4,8 +4,8 @@ from sanic.constants import HTTP_METHODS
|
||||
from sanic.views import CompositionView
|
||||
|
||||
FutureRoute = namedtuple('Route',
|
||||
['handler', 'uri', 'methods',
|
||||
'host', 'strict_slashes', 'stream'])
|
||||
['handler', 'uri', 'methods', 'host',
|
||||
'strict_slashes', 'stream', 'version'])
|
||||
FutureListener = namedtuple('Listener', ['handler', 'uri', 'methods', 'host'])
|
||||
FutureMiddleware = namedtuple('Route', ['middleware', 'args', 'kwargs'])
|
||||
FutureException = namedtuple('Route', ['handler', 'args', 'kwargs'])
|
||||
@@ -14,7 +14,7 @@ FutureStatic = namedtuple('Route',
|
||||
|
||||
|
||||
class Blueprint:
|
||||
def __init__(self, name, url_prefix=None, host=None):
|
||||
def __init__(self, name, url_prefix=None, host=None, version=None):
|
||||
"""Create a new blueprint
|
||||
|
||||
:param name: unique name of the blueprint
|
||||
@@ -30,6 +30,7 @@ class Blueprint:
|
||||
self.listeners = defaultdict(list)
|
||||
self.middlewares = []
|
||||
self.statics = []
|
||||
self.version = version
|
||||
|
||||
def register(self, app, options):
|
||||
"""Register the blueprint to the sanic app."""
|
||||
@@ -43,12 +44,16 @@ class Blueprint:
|
||||
future.handler.__blueprintname__ = self.name
|
||||
# Prepend the blueprint URI prefix if available
|
||||
uri = url_prefix + future.uri if url_prefix else future.uri
|
||||
|
||||
version = future.version or self.version
|
||||
|
||||
app.route(
|
||||
uri=uri[1:] if uri.startswith('//') else uri,
|
||||
methods=future.methods,
|
||||
host=future.host or self.host,
|
||||
strict_slashes=future.strict_slashes,
|
||||
stream=future.stream
|
||||
stream=future.stream,
|
||||
version=version
|
||||
)(future.handler)
|
||||
|
||||
for future in self.websocket_routes:
|
||||
@@ -89,7 +94,7 @@ class Blueprint:
|
||||
app.listener(event)(listener)
|
||||
|
||||
def route(self, uri, methods=frozenset({'GET'}), host=None,
|
||||
strict_slashes=False, stream=False):
|
||||
strict_slashes=False, stream=False, version=None):
|
||||
"""Create a blueprint route from a decorated function.
|
||||
|
||||
:param uri: endpoint at which the route will be accessible.
|
||||
@@ -97,13 +102,13 @@ class Blueprint:
|
||||
"""
|
||||
def decorator(handler):
|
||||
route = FutureRoute(
|
||||
handler, uri, methods, host, strict_slashes, stream)
|
||||
handler, uri, methods, host, strict_slashes, stream, version)
|
||||
self.routes.append(route)
|
||||
return handler
|
||||
return decorator
|
||||
|
||||
def add_route(self, handler, uri, methods=frozenset({'GET'}), host=None,
|
||||
strict_slashes=False):
|
||||
strict_slashes=False, version=None):
|
||||
"""Create a blueprint route from a function.
|
||||
|
||||
:param handler: function for handling uri requests. Accepts function,
|
||||
@@ -125,21 +130,22 @@ class Blueprint:
|
||||
methods = handler.handlers.keys()
|
||||
|
||||
self.route(uri=uri, methods=methods, host=host,
|
||||
strict_slashes=strict_slashes)(handler)
|
||||
strict_slashes=strict_slashes, version=version)(handler)
|
||||
return handler
|
||||
|
||||
def websocket(self, uri, host=None, strict_slashes=False):
|
||||
def websocket(self, uri, host=None, strict_slashes=False, version=None):
|
||||
"""Create a blueprint websocket route from a decorated function.
|
||||
|
||||
:param uri: endpoint at which the route will be accessible.
|
||||
"""
|
||||
def decorator(handler):
|
||||
route = FutureRoute(handler, uri, [], host, strict_slashes, False)
|
||||
route = FutureRoute(handler, uri, [], host, strict_slashes,
|
||||
False, version)
|
||||
self.websocket_routes.append(route)
|
||||
return handler
|
||||
return decorator
|
||||
|
||||
def add_websocket_route(self, handler, uri, host=None):
|
||||
def add_websocket_route(self, handler, uri, host=None, version=None):
|
||||
"""Create a blueprint websocket route from a function.
|
||||
|
||||
:param handler: function for handling uri requests. Accepts function,
|
||||
@@ -147,7 +153,7 @@ class Blueprint:
|
||||
:param uri: endpoint at which the route will be accessible.
|
||||
:return: function or class instance
|
||||
"""
|
||||
self.websocket(uri=uri, host=host)(handler)
|
||||
self.websocket(uri=uri, host=host, version=version)(handler)
|
||||
return handler
|
||||
|
||||
def listener(self, event):
|
||||
@@ -193,30 +199,36 @@ class Blueprint:
|
||||
self.statics.append(static)
|
||||
|
||||
# Shorthand method decorators
|
||||
def get(self, uri, host=None, strict_slashes=False):
|
||||
def get(self, uri, host=None, strict_slashes=False, version=None):
|
||||
return self.route(uri, methods=["GET"], host=host,
|
||||
strict_slashes=strict_slashes)
|
||||
strict_slashes=strict_slashes, version=version)
|
||||
|
||||
def post(self, uri, host=None, strict_slashes=False, stream=False):
|
||||
def post(self, uri, host=None, strict_slashes=False, stream=False,
|
||||
version=None):
|
||||
return self.route(uri, methods=["POST"], host=host,
|
||||
strict_slashes=strict_slashes, stream=stream)
|
||||
strict_slashes=strict_slashes, stream=stream,
|
||||
version=version)
|
||||
|
||||
def put(self, uri, host=None, strict_slashes=False, stream=False):
|
||||
def put(self, uri, host=None, strict_slashes=False, stream=False,
|
||||
version=None):
|
||||
return self.route(uri, methods=["PUT"], host=host,
|
||||
strict_slashes=strict_slashes, stream=stream)
|
||||
strict_slashes=strict_slashes, stream=stream,
|
||||
version=version)
|
||||
|
||||
def head(self, uri, host=None, strict_slashes=False):
|
||||
def head(self, uri, host=None, strict_slashes=False, version=None):
|
||||
return self.route(uri, methods=["HEAD"], host=host,
|
||||
strict_slashes=strict_slashes)
|
||||
strict_slashes=strict_slashes, version=version)
|
||||
|
||||
def options(self, uri, host=None, strict_slashes=False):
|
||||
def options(self, uri, host=None, strict_slashes=False, version=None):
|
||||
return self.route(uri, methods=["OPTIONS"], host=host,
|
||||
strict_slashes=strict_slashes)
|
||||
strict_slashes=strict_slashes, version=version)
|
||||
|
||||
def patch(self, uri, host=None, strict_slashes=False, stream=False):
|
||||
def patch(self, uri, host=None, strict_slashes=False, stream=False,
|
||||
version=None):
|
||||
return self.route(uri, methods=["PATCH"], host=host,
|
||||
strict_slashes=strict_slashes, stream=stream)
|
||||
strict_slashes=strict_slashes, stream=stream,
|
||||
version=version)
|
||||
|
||||
def delete(self, uri, host=None, strict_slashes=False):
|
||||
def delete(self, uri, host=None, strict_slashes=False, version=None):
|
||||
return self.route(uri, methods=["DELETE"], host=host,
|
||||
strict_slashes=strict_slashes)
|
||||
strict_slashes=strict_slashes, version=version)
|
||||
|
||||
@@ -12,11 +12,12 @@ _address_dict = {
|
||||
'Windows': ('localhost', 514),
|
||||
'Darwin': '/var/run/syslog',
|
||||
'Linux': '/dev/log',
|
||||
'FreeBSD': '/dev/log'
|
||||
'FreeBSD': '/var/run/log'
|
||||
}
|
||||
|
||||
LOGGING = {
|
||||
'version': 1,
|
||||
'disable_existing_loggers': False,
|
||||
'filters': {
|
||||
'accessFilter': {
|
||||
'()': DefaultFilter,
|
||||
@@ -127,6 +128,7 @@ class Config(dict):
|
||||
self.KEEP_ALIVE = keep_alive
|
||||
self.WEBSOCKET_MAX_SIZE = 2 ** 20 # 1 megabytes
|
||||
self.WEBSOCKET_MAX_QUEUE = 32
|
||||
self.GRACEFUL_SHUTDOWN_TIMEOUT = 15.0 # 15 sec
|
||||
|
||||
if load_env:
|
||||
self.load_environment_vars()
|
||||
@@ -195,10 +197,16 @@ class Config(dict):
|
||||
|
||||
def load_environment_vars(self):
|
||||
"""
|
||||
Looks for any SANIC_ prefixed environment variables and applies
|
||||
Looks for any ``SANIC_`` prefixed environment variables and applies
|
||||
them to the configuration if present.
|
||||
"""
|
||||
for k, v in os.environ.items():
|
||||
if k.startswith(SANIC_PREFIX):
|
||||
_, config_key = k.split(SANIC_PREFIX, 1)
|
||||
self[config_key] = v
|
||||
try:
|
||||
self[config_key] = int(v)
|
||||
except ValueError:
|
||||
try:
|
||||
self[config_key] = float(v)
|
||||
except ValueError:
|
||||
self[config_key] = v
|
||||
|
||||
@@ -194,6 +194,11 @@ class ContentRangeError(SanicException):
|
||||
}
|
||||
|
||||
|
||||
@add_status_code(403)
|
||||
class Forbidden(SanicException):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidRangeType(ContentRangeError):
|
||||
pass
|
||||
|
||||
@@ -205,8 +210,8 @@ class Unauthorized(SanicException):
|
||||
|
||||
:param scheme: Name of the authentication scheme to be used.
|
||||
:param challenge: A dict containing values to add to the WWW-Authenticate
|
||||
header that is generated. This is especially useful when dealing with the
|
||||
Digest scheme. (optional)
|
||||
header that is generated. This is especially useful when dealing with
|
||||
the Digest scheme. (optional)
|
||||
|
||||
Examples::
|
||||
|
||||
@@ -227,7 +232,6 @@ class Unauthorized(SanicException):
|
||||
# With a Bearer auth-scheme, realm is optional:
|
||||
challenge = {"realm": "Restricted Area"}
|
||||
raise Unauthorized("Auth required.", "Bearer", challenge)
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -249,9 +253,10 @@ def abort(status_code, message=None):
|
||||
"""
|
||||
Raise an exception based on SanicException. Returns the HTTP response
|
||||
message appropriate for the given status code, unless provided.
|
||||
|
||||
:param status_code: The HTTP status code to return.
|
||||
:param message: The HTTP response body. Defaults to the messages
|
||||
in response.py for the given status code.
|
||||
in response.py for the given status code.
|
||||
"""
|
||||
if message is None:
|
||||
message = COMMON_STATUS_CODES.get(status_code,
|
||||
|
||||
@@ -45,7 +45,7 @@ class Request(dict):
|
||||
__slots__ = (
|
||||
'app', 'headers', 'version', 'method', '_cookies', 'transport',
|
||||
'body', 'parsed_json', 'parsed_args', 'parsed_form', 'parsed_files',
|
||||
'_ip', '_parsed_url', 'uri_template', 'stream'
|
||||
'_ip', '_parsed_url', 'uri_template', 'stream', '_remote_addr'
|
||||
)
|
||||
|
||||
def __init__(self, url_bytes, headers, version, method, transport):
|
||||
@@ -86,11 +86,15 @@ class Request(dict):
|
||||
|
||||
:return: token related to request
|
||||
"""
|
||||
prefixes = ('Bearer', 'Token')
|
||||
auth_header = self.headers.get('Authorization')
|
||||
if auth_header is not None and 'Token ' in auth_header:
|
||||
return auth_header.partition('Token ')[-1]
|
||||
else:
|
||||
return auth_header
|
||||
|
||||
if auth_header is not None:
|
||||
for prefix in prefixes:
|
||||
if prefix in auth_header:
|
||||
return auth_header.partition(prefix)[-1].strip()
|
||||
|
||||
return auth_header
|
||||
|
||||
@property
|
||||
def form(self):
|
||||
@@ -138,7 +142,7 @@ class Request(dict):
|
||||
@property
|
||||
def cookies(self):
|
||||
if self._cookies is None:
|
||||
cookie = self.headers.get('Cookie') or self.headers.get('cookie')
|
||||
cookie = self.headers.get('Cookie')
|
||||
if cookie is not None:
|
||||
cookies = SimpleCookie()
|
||||
cookies.load(cookie)
|
||||
@@ -155,6 +159,25 @@ class Request(dict):
|
||||
(None, None))
|
||||
return self._ip
|
||||
|
||||
@property
|
||||
def remote_addr(self):
|
||||
"""Attempt to return the original client ip based on X-Forwarded-For.
|
||||
|
||||
:return: original client ip.
|
||||
"""
|
||||
if not hasattr(self, '_remote_addr'):
|
||||
forwarded_for = self.headers.get('X-Forwarded-For', '').split(',')
|
||||
remote_addrs = [
|
||||
addr for addr in [
|
||||
addr.strip() for addr in forwarded_for
|
||||
] if addr
|
||||
]
|
||||
if len(remote_addrs) > 0:
|
||||
self._remote_addr = remote_addrs[0]
|
||||
else:
|
||||
self._remote_addr = ''
|
||||
return self._remote_addr
|
||||
|
||||
@property
|
||||
def scheme(self):
|
||||
if self.app.websocket_enabled \
|
||||
@@ -234,15 +257,15 @@ def parse_multipart_form(body, boundary):
|
||||
break
|
||||
|
||||
colon_index = form_line.index(':')
|
||||
form_header_field = form_line[0:colon_index]
|
||||
form_header_field = form_line[0:colon_index].lower()
|
||||
form_header_value, form_parameters = parse_header(
|
||||
form_line[colon_index + 2:])
|
||||
|
||||
if form_header_field == 'Content-Disposition':
|
||||
if form_header_field == 'content-disposition':
|
||||
if 'filename' in form_parameters:
|
||||
file_name = form_parameters['filename']
|
||||
field_name = form_parameters.get('name')
|
||||
elif form_header_field == 'Content-Type':
|
||||
elif form_header_field == 'content-type':
|
||||
file_type = form_header_value
|
||||
|
||||
post_data = form_part[line_index:-4]
|
||||
|
||||
@@ -237,6 +237,7 @@ def json(body, status=200, headers=None,
|
||||
content_type="application/json", **kwargs):
|
||||
"""
|
||||
Returns response object with body in json format.
|
||||
|
||||
:param body: Response data to be serialized.
|
||||
:param status: Response code.
|
||||
:param headers: Custom Headers.
|
||||
@@ -250,6 +251,7 @@ def text(body, status=200, headers=None,
|
||||
content_type="text/plain; charset=utf-8"):
|
||||
"""
|
||||
Returns response object with body in text format.
|
||||
|
||||
:param body: Response data to be encoded.
|
||||
:param status: Response code.
|
||||
:param headers: Custom Headers.
|
||||
@@ -264,6 +266,7 @@ def raw(body, status=200, headers=None,
|
||||
content_type="application/octet-stream"):
|
||||
"""
|
||||
Returns response object without encoding the body.
|
||||
|
||||
:param body: Response data.
|
||||
:param status: Response code.
|
||||
:param headers: Custom Headers.
|
||||
@@ -276,6 +279,7 @@ def raw(body, status=200, headers=None,
|
||||
def html(body, status=200, headers=None):
|
||||
"""
|
||||
Returns response object with body in html format.
|
||||
|
||||
:param body: Response data to be encoded.
|
||||
:param status: Response code.
|
||||
:param headers: Custom Headers.
|
||||
|
||||
@@ -98,8 +98,25 @@ class Router:
|
||||
|
||||
return name, _type, pattern
|
||||
|
||||
def add(self, uri, methods, handler, host=None, strict_slashes=False):
|
||||
def add(self, uri, methods, handler, host=None, strict_slashes=False,
|
||||
version=None):
|
||||
"""Add a handler to the route list
|
||||
|
||||
:param uri: path to match
|
||||
:param methods: sequence of accepted method names. If none are
|
||||
provided, any method is allowed
|
||||
:param handler: request handler function.
|
||||
When executed, it should provide a response object.
|
||||
:param strict_slashes: strict to trailing slash
|
||||
:param version: current version of the route or blueprint. See
|
||||
docs for further details.
|
||||
:return: Nothing
|
||||
"""
|
||||
if version is not None:
|
||||
if uri.startswith('/'):
|
||||
uri = "/".join(["/v{}".format(str(version)), uri[1:]])
|
||||
else:
|
||||
uri = "/".join(["/v{}".format(str(version)), uri])
|
||||
# add regular version
|
||||
self._add(uri, methods, handler, host)
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ class HttpProtocol(asyncio.Protocol):
|
||||
signal=Signal(), connections=set(), request_timeout=60,
|
||||
request_max_size=None, request_class=None, has_log=True,
|
||||
keep_alive=True, is_request_stream=False, router=None,
|
||||
**kwargs):
|
||||
state=None, debug=False, **kwargs):
|
||||
self.loop = loop
|
||||
self.transport = None
|
||||
self.request = None
|
||||
@@ -99,12 +99,18 @@ class HttpProtocol(asyncio.Protocol):
|
||||
self._request_handler_task = None
|
||||
self._request_stream_task = None
|
||||
self._keep_alive = keep_alive
|
||||
self._header_fragment = b''
|
||||
self.state = state if state else {}
|
||||
if 'requests_count' not in self.state:
|
||||
self.state['requests_count'] = 0
|
||||
self._debug = debug
|
||||
|
||||
@property
|
||||
def keep_alive(self):
|
||||
return (self._keep_alive
|
||||
and not self.signal.stopped
|
||||
and self.parser.should_keep_alive())
|
||||
return (
|
||||
self._keep_alive and
|
||||
not self.signal.stopped and
|
||||
self.parser.should_keep_alive())
|
||||
|
||||
# -------------------------------------------- #
|
||||
# Connection
|
||||
@@ -154,22 +160,39 @@ class HttpProtocol(asyncio.Protocol):
|
||||
self.headers = []
|
||||
self.parser = HttpRequestParser(self)
|
||||
|
||||
# requests count
|
||||
self.state['requests_count'] = self.state['requests_count'] + 1
|
||||
|
||||
# Parse request chunk or close connection
|
||||
try:
|
||||
self.parser.feed_data(data)
|
||||
except HttpParserError:
|
||||
exception = InvalidUsage('Bad Request')
|
||||
message = 'Bad Request'
|
||||
if self._debug:
|
||||
message += '\n' + traceback.format_exc()
|
||||
exception = InvalidUsage(message)
|
||||
self.write_error(exception)
|
||||
|
||||
def on_url(self, url):
|
||||
self.url = url
|
||||
if not self.url:
|
||||
self.url = url
|
||||
else:
|
||||
self.url += url
|
||||
|
||||
def on_header(self, name, value):
|
||||
if name == b'Content-Length' and int(value) > self.request_max_size:
|
||||
exception = PayloadTooLarge('Payload Too Large')
|
||||
self.write_error(exception)
|
||||
self._header_fragment += name
|
||||
|
||||
self.headers.append((name.decode().casefold(), value.decode()))
|
||||
if value is not None:
|
||||
if self._header_fragment == b'Content-Length' \
|
||||
and int(value) > self.request_max_size:
|
||||
exception = PayloadTooLarge('Payload Too Large')
|
||||
self.write_error(exception)
|
||||
|
||||
self.headers.append(
|
||||
(self._header_fragment.decode().casefold(),
|
||||
value.decode()))
|
||||
|
||||
self._header_fragment = b''
|
||||
|
||||
def on_headers_complete(self):
|
||||
self.request = self.request_class(
|
||||
@@ -357,6 +380,14 @@ class HttpProtocol(asyncio.Protocol):
|
||||
return True
|
||||
return False
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Force close the connection.
|
||||
"""
|
||||
if self.transport is not None:
|
||||
self.transport.close()
|
||||
self.transport = None
|
||||
|
||||
|
||||
def update_current_time(loop):
|
||||
"""Cache the current time, since it is needed at the end of every
|
||||
@@ -389,7 +420,8 @@ def serve(host, port, request_handler, error_handler, before_start=None,
|
||||
register_sys_signals=True, run_async=False, connections=None,
|
||||
signal=Signal(), request_class=None, has_log=True, keep_alive=True,
|
||||
is_request_stream=False, router=None, websocket_max_size=None,
|
||||
websocket_max_queue=None):
|
||||
websocket_max_queue=None, state=None,
|
||||
graceful_shutdown_timeout=15.0):
|
||||
"""Start asynchronous HTTP Server on an individual process.
|
||||
|
||||
:param host: Address to host on
|
||||
@@ -427,8 +459,6 @@ def serve(host, port, request_handler, error_handler, before_start=None,
|
||||
if debug:
|
||||
loop.set_debug(debug)
|
||||
|
||||
trigger_events(before_start, loop)
|
||||
|
||||
connections = connections if connections is not None else set()
|
||||
server = partial(
|
||||
protocol,
|
||||
@@ -445,7 +475,9 @@ def serve(host, port, request_handler, error_handler, before_start=None,
|
||||
is_request_stream=is_request_stream,
|
||||
router=router,
|
||||
websocket_max_size=websocket_max_size,
|
||||
websocket_max_queue=websocket_max_queue
|
||||
websocket_max_queue=websocket_max_queue,
|
||||
state=state,
|
||||
debug=debug,
|
||||
)
|
||||
|
||||
server_coroutine = loop.create_server(
|
||||
@@ -457,6 +489,7 @@ def serve(host, port, request_handler, error_handler, before_start=None,
|
||||
sock=sock,
|
||||
backlog=backlog
|
||||
)
|
||||
|
||||
# Instead of pulling time at the end of every request,
|
||||
# pull it once per minute
|
||||
loop.call_soon(partial(update_current_time, loop))
|
||||
@@ -464,6 +497,8 @@ def serve(host, port, request_handler, error_handler, before_start=None,
|
||||
if run_async:
|
||||
return server_coroutine
|
||||
|
||||
trigger_events(before_start, loop)
|
||||
|
||||
try:
|
||||
http_server = loop.run_until_complete(server_coroutine)
|
||||
except:
|
||||
@@ -499,8 +534,26 @@ def serve(host, port, request_handler, error_handler, before_start=None,
|
||||
for connection in connections:
|
||||
connection.close_if_idle()
|
||||
|
||||
while connections:
|
||||
# Gracefully shutdown timeout.
|
||||
# We should provide graceful_shutdown_timeout,
|
||||
# instead of letting connection hangs forever.
|
||||
# Let's roughly calcucate time.
|
||||
start_shutdown = 0
|
||||
while connections and (start_shutdown < graceful_shutdown_timeout):
|
||||
loop.run_until_complete(asyncio.sleep(0.1))
|
||||
start_shutdown = start_shutdown + 0.1
|
||||
|
||||
# Force close non-idle connection after waiting for
|
||||
# graceful_shutdown_timeout
|
||||
coros = []
|
||||
for conn in connections:
|
||||
if hasattr(conn, "websocket") and conn.websocket:
|
||||
coros.append(conn.websocket.close_connection(force=True))
|
||||
else:
|
||||
conn.close()
|
||||
|
||||
_shutdown = asyncio.gather(*coros, loop=loop)
|
||||
loop.run_until_complete(_shutdown)
|
||||
|
||||
trigger_events(after_stop, loop)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import sys
|
||||
import signal
|
||||
import asyncio
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
try:
|
||||
import ssl
|
||||
@@ -29,7 +30,7 @@ class GunicornWorker(base.Worker):
|
||||
self.ssl_context = self._create_ssl_context(cfg)
|
||||
else:
|
||||
self.ssl_context = None
|
||||
self.servers = []
|
||||
self.servers = {}
|
||||
self.connections = set()
|
||||
self.exit_code = 0
|
||||
self.signal = Signal()
|
||||
@@ -69,10 +70,16 @@ class GunicornWorker(base.Worker):
|
||||
trigger_events(self._server_settings.get('before_stop', []),
|
||||
self.loop)
|
||||
self.loop.run_until_complete(self.close())
|
||||
except:
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
trigger_events(self._server_settings.get('after_stop', []),
|
||||
self.loop)
|
||||
self.loop.close()
|
||||
try:
|
||||
trigger_events(self._server_settings.get('after_stop', []),
|
||||
self.loop)
|
||||
except:
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
self.loop.close()
|
||||
|
||||
sys.exit(self.exit_code)
|
||||
|
||||
@@ -91,16 +98,37 @@ class GunicornWorker(base.Worker):
|
||||
for conn in self.connections:
|
||||
conn.close_if_idle()
|
||||
|
||||
while self.connections:
|
||||
# gracefully shutdown timeout
|
||||
start_shutdown = 0
|
||||
graceful_shutdown_timeout = self.cfg.graceful_timeout
|
||||
while self.connections and \
|
||||
(start_shutdown < graceful_shutdown_timeout):
|
||||
await asyncio.sleep(0.1)
|
||||
start_shutdown = start_shutdown + 0.1
|
||||
|
||||
# Force close non-idle connection after waiting for
|
||||
# graceful_shutdown_timeout
|
||||
coros = []
|
||||
for conn in self.connections:
|
||||
if hasattr(conn, "websocket") and conn.websocket:
|
||||
coros.append(conn.websocket.close_connection(force=True))
|
||||
else:
|
||||
conn.close()
|
||||
_shutdown = asyncio.gather(*coros, loop=self.loop)
|
||||
await _shutdown
|
||||
|
||||
async def _run(self):
|
||||
for sock in self.sockets:
|
||||
self.servers.append(await serve(
|
||||
state = dict(requests_count=0)
|
||||
self._server_settings["host"] = None
|
||||
self._server_settings["port"] = None
|
||||
server = await serve(
|
||||
sock=sock,
|
||||
connections=self.connections,
|
||||
state=state,
|
||||
**self._server_settings
|
||||
))
|
||||
)
|
||||
self.servers[server] = state
|
||||
|
||||
async def _check_alive(self):
|
||||
# If our parent changed then we shut down.
|
||||
@@ -109,7 +137,15 @@ class GunicornWorker(base.Worker):
|
||||
while self.alive:
|
||||
self.notify()
|
||||
|
||||
if pid == os.getpid() and self.ppid != os.getppid():
|
||||
req_count = sum(
|
||||
self.servers[srv]["requests_count"] for srv in self.servers
|
||||
)
|
||||
if self.max_requests and req_count > self.max_requests:
|
||||
self.alive = False
|
||||
self.log.info(
|
||||
"Max requests exceeded, shutting down: %s", self
|
||||
)
|
||||
elif pid == os.getpid() and self.ppid != os.getppid():
|
||||
self.alive = False
|
||||
self.log.info("Parent changed, shutting down: %s", self)
|
||||
else:
|
||||
@@ -166,3 +202,4 @@ class GunicornWorker(base.Worker):
|
||||
self.alive = False
|
||||
self.exit_code = 1
|
||||
self.cfg.worker_abort(self)
|
||||
sys.exit(1)
|
||||
|
||||
Reference in New Issue
Block a user