Merge branch 'master' into forbidden-exception

This commit is contained in:
François
2017-06-28 17:25:40 +02:00
committed by GitHub
11 changed files with 305 additions and 96 deletions

View File

@@ -201,4 +201,10 @@ class Config(dict):
for k, v in os.environ.items():
if k.startswith(SANIC_PREFIX):
_, config_key = k.split(SANIC_PREFIX, 1)
self[config_key] = v
try:
self[config_key] = int(v)
except ValueError:
try:
self[config_key] = float(v)
except ValueError:
self[config_key] = v

View File

@@ -203,6 +203,34 @@ class InvalidRangeType(ContentRangeError):
pass
@add_status_code(401)
class Unauthorized(SanicException):
    """
    Unauthorized exception (401 HTTP status code).

    :param message: Message describing the authentication failure.
    :param scheme: Name of the authentication scheme to be used.
    :param realm: Description of the protected area. (optional)
    :param challenge: A dict containing values to add to the WWW-Authenticate
        header that is generated. This is especially useful when dealing with
        the Digest scheme. (optional)
    """

    def __init__(self, message, scheme, realm="", challenge=None):
        super().__init__(message)

        # Extra challenge parameters (e.g. nonce/qop for the Digest scheme)
        # are rendered as comma-separated ``key='value'`` pairs and appended
        # after the realm.
        adds = ""
        if challenge is not None:
            values = ["{!s}={!r}".format(k, v) for k, v in challenge.items()]
            adds = ', '.join(values)
            adds = ', {}'.format(adds)

        self.headers = {
            "WWW-Authenticate": "{} realm='{}'{}".format(scheme, realm, adds)
        }
def abort(status_code, message=None):
"""
Raise an exception based on SanicException. Returns the HTTP response

View File

@@ -86,11 +86,15 @@ class Request(dict):
:return: token related to request
"""
prefixes = ('Token ', 'Bearer ')
auth_header = self.headers.get('Authorization')
if auth_header is not None and 'Token ' in auth_header:
return auth_header.partition('Token ')[-1]
else:
return auth_header
if auth_header is not None:
for prefix in prefixes:
if prefix in auth_header:
return auth_header.partition(prefix)[-1]
return auth_header
@property
def form(self):

View File

@@ -75,7 +75,7 @@ class HttpProtocol(asyncio.Protocol):
signal=Signal(), connections=set(), request_timeout=60,
request_max_size=None, request_class=None, has_log=True,
keep_alive=True, is_request_stream=False, router=None,
**kwargs):
state=None, debug=False, **kwargs):
self.loop = loop
self.transport = None
self.request = None
@@ -99,12 +99,17 @@ class HttpProtocol(asyncio.Protocol):
self._request_handler_task = None
self._request_stream_task = None
self._keep_alive = keep_alive
self.state = state if state else {}
if 'requests_count' not in self.state:
self.state['requests_count'] = 0
self._debug = debug
@property
def keep_alive(self):
    """
    Whether the connection may be reused after the current request.

    True only while keep-alive is enabled on the protocol, no stop
    signal has been received, and the request parser reports that the
    client allows connection reuse.
    """
    return (
        self._keep_alive and
        not self.signal.stopped and
        self.parser.should_keep_alive())
# -------------------------------------------- #
# Connection
@@ -154,11 +159,17 @@ class HttpProtocol(asyncio.Protocol):
self.headers = []
self.parser = HttpRequestParser(self)
# requests count
self.state['requests_count'] = self.state['requests_count'] + 1
# Parse request chunk or close connection
try:
self.parser.feed_data(data)
except HttpParserError:
exception = InvalidUsage('Bad Request')
message = 'Bad Request'
if self._debug:
message += '\n' + traceback.format_exc()
exception = InvalidUsage(message)
self.write_error(exception)
def on_url(self, url):
@@ -389,7 +400,7 @@ def serve(host, port, request_handler, error_handler, before_start=None,
register_sys_signals=True, run_async=False, connections=None,
signal=Signal(), request_class=None, has_log=True, keep_alive=True,
is_request_stream=False, router=None, websocket_max_size=None,
websocket_max_queue=None):
websocket_max_queue=None, state=None):
"""Start asynchronous HTTP Server on an individual process.
:param host: Address to host on
@@ -427,8 +438,6 @@ def serve(host, port, request_handler, error_handler, before_start=None,
if debug:
loop.set_debug(debug)
trigger_events(before_start, loop)
connections = connections if connections is not None else set()
server = partial(
protocol,
@@ -445,7 +454,9 @@ def serve(host, port, request_handler, error_handler, before_start=None,
is_request_stream=is_request_stream,
router=router,
websocket_max_size=websocket_max_size,
websocket_max_queue=websocket_max_queue
websocket_max_queue=websocket_max_queue,
state=state,
debug=debug,
)
server_coroutine = loop.create_server(
@@ -457,6 +468,7 @@ def serve(host, port, request_handler, error_handler, before_start=None,
sock=sock,
backlog=backlog
)
# Instead of pulling time at the end of every request,
# pull it once per minute
loop.call_soon(partial(update_current_time, loop))
@@ -464,6 +476,8 @@ def serve(host, port, request_handler, error_handler, before_start=None,
if run_async:
return server_coroutine
trigger_events(before_start, loop)
try:
http_server = loop.run_until_complete(server_coroutine)
except:

View File

@@ -29,7 +29,7 @@ class GunicornWorker(base.Worker):
self.ssl_context = self._create_ssl_context(cfg)
else:
self.ssl_context = None
self.servers = []
self.servers = {}
self.connections = set()
self.exit_code = 0
self.signal = Signal()
@@ -96,11 +96,16 @@ class GunicornWorker(base.Worker):
async def _run(self):
    """
    Start one HTTP server per configured listening socket.

    Each server is created with its own shared ``state`` dict (holding
    the per-server request counter) and recorded in ``self.servers``
    mapping server -> state, so the request counts can later be summed
    across all servers (e.g. for max_requests handling).
    """
    for sock in self.sockets:
        # Fresh per-server mutable state; the protocol increments
        # state['requests_count'] for every parsed request.
        state = dict(requests_count=0)
        # host/port are irrelevant when binding to an explicit socket.
        self._server_settings["host"] = None
        self._server_settings["port"] = None
        server = await serve(
            sock=sock,
            connections=self.connections,
            state=state,
            **self._server_settings
        )
        self.servers[server] = state
async def _check_alive(self):
# If our parent changed then we shut down.
@@ -109,7 +114,15 @@ class GunicornWorker(base.Worker):
while self.alive:
self.notify()
if pid == os.getpid() and self.ppid != os.getppid():
req_count = sum(
self.servers[srv]["requests_count"] for srv in self.servers
)
if self.max_requests and req_count > self.max_requests:
self.alive = False
self.log.info(
"Max requests exceeded, shutting down: %s", self
)
elif pid == os.getpid() and self.ppid != os.getppid():
self.alive = False
self.log.info("Parent changed, shutting down: %s", self)
else:
@@ -166,3 +179,4 @@ class GunicornWorker(base.Worker):
self.alive = False
self.exit_code = 1
self.cfg.worker_abort(self)
sys.exit(1)