Merge remote-tracking branch 'upstream/master'
# Conflicts: # sanic/server.py
This commit is contained in:
13
sanic/app.py
13
sanic/app.py
@@ -33,7 +33,9 @@ class Sanic:
|
||||
logging.config.dictConfig(log_config)
|
||||
# Only set up a default log handler if the
|
||||
# end-user application didn't set anything up.
|
||||
if not logging.root.handlers and log.level == logging.NOTSET:
|
||||
if not (logging.root.handlers and
|
||||
log.level == logging.NOTSET and
|
||||
log_config):
|
||||
formatter = logging.Formatter(
|
||||
"%(asctime)s: %(levelname)s: %(message)s")
|
||||
handler = logging.StreamHandler()
|
||||
@@ -543,7 +545,7 @@ class Sanic:
|
||||
def run(self, host=None, port=None, debug=False, ssl=None,
|
||||
sock=None, workers=1, protocol=None,
|
||||
backlog=100, stop_event=None, register_sys_signals=True,
|
||||
log_config=LOGGING):
|
||||
log_config=None):
|
||||
"""Run the HTTP Server and listen until keyboard interrupt or term
|
||||
signal. On termination, drain connections before closing.
|
||||
|
||||
@@ -565,6 +567,7 @@ class Sanic:
|
||||
host, port = host or "127.0.0.1", port or 8000
|
||||
|
||||
if log_config:
|
||||
self.log_config = log_config
|
||||
logging.config.dictConfig(log_config)
|
||||
if protocol is None:
|
||||
protocol = (WebSocketProtocol if self.websocket_enabled
|
||||
@@ -578,7 +581,7 @@ class Sanic:
|
||||
host=host, port=port, debug=debug, ssl=ssl, sock=sock,
|
||||
workers=workers, protocol=protocol, backlog=backlog,
|
||||
register_sys_signals=register_sys_signals,
|
||||
has_log=log_config is not None)
|
||||
has_log=self.log_config is not None)
|
||||
|
||||
try:
|
||||
self.is_running = True
|
||||
@@ -696,7 +699,9 @@ class Sanic:
|
||||
'loop': loop,
|
||||
'register_sys_signals': register_sys_signals,
|
||||
'backlog': backlog,
|
||||
'has_log': has_log
|
||||
'has_log': has_log,
|
||||
'websocket_max_size': self.config.WEBSOCKET_MAX_SIZE,
|
||||
'websocket_max_queue': self.config.WEBSOCKET_MAX_QUEUE
|
||||
}
|
||||
|
||||
# -------------------------------------------- #
|
||||
|
||||
@@ -122,9 +122,11 @@ class Config(dict):
|
||||
▌ ▐ ▀▀▄▄▄▀
|
||||
▀▀▄▄▀
|
||||
"""
|
||||
self.REQUEST_MAX_SIZE = 100000000 # 100 megababies
|
||||
self.REQUEST_MAX_SIZE = 100000000 # 100 megabytes
|
||||
self.REQUEST_TIMEOUT = 60 # 60 seconds
|
||||
self.KEEP_ALIVE = keep_alive
|
||||
self.WEBSOCKET_MAX_SIZE = 2 ** 20 # 1 megabytes
|
||||
self.WEBSOCKET_MAX_QUEUE = 32
|
||||
|
||||
if load_env:
|
||||
self.load_environment_vars()
|
||||
@@ -199,4 +201,10 @@ class Config(dict):
|
||||
for k, v in os.environ.items():
|
||||
if k.startswith(SANIC_PREFIX):
|
||||
_, config_key = k.split(SANIC_PREFIX, 1)
|
||||
self[config_key] = v
|
||||
try:
|
||||
self[config_key] = int(v)
|
||||
except ValueError:
|
||||
try:
|
||||
self[config_key] = float(v)
|
||||
except ValueError:
|
||||
self[config_key] = v
|
||||
|
||||
@@ -198,6 +198,34 @@ class InvalidRangeType(ContentRangeError):
|
||||
pass
|
||||
|
||||
|
||||
@add_status_code(401)
|
||||
class Unauthorized(SanicException):
|
||||
"""
|
||||
Unauthorized exception (401 HTTP status code).
|
||||
|
||||
:param scheme: Name of the authentication scheme to be used.
|
||||
:param realm: Description of the protected area. (optional)
|
||||
:param challenge: A dict containing values to add to the WWW-Authenticate
|
||||
header that is generated. This is especially useful when dealing with the
|
||||
Digest scheme. (optional)
|
||||
"""
|
||||
pass
|
||||
|
||||
def __init__(self, message, scheme, realm="", challenge=None):
|
||||
super().__init__(message)
|
||||
|
||||
adds = ""
|
||||
|
||||
if challenge is not None:
|
||||
values = ["{!s}={!r}".format(k, v) for k, v in challenge.items()]
|
||||
adds = ', '.join(values)
|
||||
adds = ', {}'.format(adds)
|
||||
|
||||
self.headers = {
|
||||
"WWW-Authenticate": "{} realm='{}'{}".format(scheme, realm, adds)
|
||||
}
|
||||
|
||||
|
||||
def abort(status_code, message=None):
|
||||
"""
|
||||
Raise an exception based on SanicException. Returns the HTTP response
|
||||
|
||||
@@ -86,11 +86,15 @@ class Request(dict):
|
||||
|
||||
:return: token related to request
|
||||
"""
|
||||
prefixes = ('Token ', 'Bearer ')
|
||||
auth_header = self.headers.get('Authorization')
|
||||
if auth_header is not None and 'Token ' in auth_header:
|
||||
return auth_header.partition('Token ')[-1]
|
||||
else:
|
||||
return auth_header
|
||||
|
||||
if auth_header is not None:
|
||||
for prefix in prefixes:
|
||||
if prefix in auth_header:
|
||||
return auth_header.partition(prefix)[-1]
|
||||
|
||||
return auth_header
|
||||
|
||||
@property
|
||||
def form(self):
|
||||
@@ -174,6 +178,15 @@ class Request(dict):
|
||||
# so pull it from the headers
|
||||
return self.headers.get('Host', '')
|
||||
|
||||
@property
|
||||
def content_type(self):
|
||||
return self.headers.get('Content-Type', DEFAULT_HTTP_CONTENT_TYPE)
|
||||
|
||||
@property
|
||||
def match_info(self):
|
||||
"""return matched info after resolving route"""
|
||||
return self.app.router.get(self)[2]
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
return self._parsed_url.path.decode('utf-8')
|
||||
|
||||
@@ -233,7 +233,8 @@ class HTTPResponse(BaseHTTPResponse):
|
||||
return self._cookies
|
||||
|
||||
|
||||
def json(body, status=200, headers=None, **kwargs):
|
||||
def json(body, status=200, headers=None,
|
||||
content_type="application/json", **kwargs):
|
||||
"""
|
||||
Returns response object with body in json format.
|
||||
:param body: Response data to be serialized.
|
||||
@@ -242,7 +243,7 @@ def json(body, status=200, headers=None, **kwargs):
|
||||
:param kwargs: Remaining arguments that are passed to the json encoder.
|
||||
"""
|
||||
return HTTPResponse(json_dumps(body, **kwargs), headers=headers,
|
||||
status=status, content_type="application/json")
|
||||
status=status, content_type=content_type)
|
||||
|
||||
|
||||
def text(body, status=200, headers=None,
|
||||
|
||||
@@ -351,7 +351,10 @@ class Router:
|
||||
:param request: Request object
|
||||
:return: bool
|
||||
"""
|
||||
handler = self.get(request)[0]
|
||||
try:
|
||||
handler = self.get(request)[0]
|
||||
except (NotFound, InvalidUsage):
|
||||
return False
|
||||
if (hasattr(handler, 'view_class') and
|
||||
hasattr(handler.view_class, request.method.lower())):
|
||||
handler = getattr(handler.view_class, request.method.lower())
|
||||
|
||||
@@ -74,7 +74,8 @@ class HttpProtocol(asyncio.Protocol):
|
||||
def __init__(self, *, loop, request_handler, error_handler,
|
||||
signal=Signal(), connections=set(), request_timeout=60,
|
||||
request_max_size=None, request_class=None, has_log=True,
|
||||
keep_alive=True, is_request_stream=False, router=None):
|
||||
keep_alive=True, is_request_stream=False, router=None,
|
||||
state=None, debug=False, **kwargs):
|
||||
self.loop = loop
|
||||
self.transport = None
|
||||
self.request = None
|
||||
@@ -99,12 +100,17 @@ class HttpProtocol(asyncio.Protocol):
|
||||
self._request_stream_task = None
|
||||
self._keep_alive = keep_alive
|
||||
self._header_fragment = b''
|
||||
self.state = state if state else {}
|
||||
if 'requests_count' not in self.state:
|
||||
self.state['requests_count'] = 0
|
||||
self._debug = debug
|
||||
|
||||
@property
|
||||
def keep_alive(self):
|
||||
return (self._keep_alive
|
||||
and not self.signal.stopped
|
||||
and self.parser.should_keep_alive())
|
||||
return (
|
||||
self._keep_alive and
|
||||
not self.signal.stopped and
|
||||
self.parser.should_keep_alive())
|
||||
|
||||
# -------------------------------------------- #
|
||||
# Connection
|
||||
@@ -154,11 +160,17 @@ class HttpProtocol(asyncio.Protocol):
|
||||
self.headers = []
|
||||
self.parser = HttpRequestParser(self)
|
||||
|
||||
# requests count
|
||||
self.state['requests_count'] = self.state['requests_count'] + 1
|
||||
|
||||
# Parse request chunk or close connection
|
||||
try:
|
||||
self.parser.feed_data(data)
|
||||
except HttpParserError:
|
||||
exception = InvalidUsage('Bad Request')
|
||||
message = 'Bad Request'
|
||||
if self._debug:
|
||||
message += '\n' + traceback.format_exc()
|
||||
exception = InvalidUsage(message)
|
||||
self.write_error(exception)
|
||||
|
||||
def on_url(self, url):
|
||||
@@ -399,7 +411,8 @@ def serve(host, port, request_handler, error_handler, before_start=None,
|
||||
reuse_port=False, loop=None, protocol=HttpProtocol, backlog=100,
|
||||
register_sys_signals=True, run_async=False, connections=None,
|
||||
signal=Signal(), request_class=None, has_log=True, keep_alive=True,
|
||||
is_request_stream=False, router=None):
|
||||
is_request_stream=False, router=None, websocket_max_size=None,
|
||||
websocket_max_queue=None, state=None):
|
||||
"""Start asynchronous HTTP Server on an individual process.
|
||||
|
||||
:param host: Address to host on
|
||||
@@ -437,8 +450,6 @@ def serve(host, port, request_handler, error_handler, before_start=None,
|
||||
if debug:
|
||||
loop.set_debug(debug)
|
||||
|
||||
trigger_events(before_start, loop)
|
||||
|
||||
connections = connections if connections is not None else set()
|
||||
server = partial(
|
||||
protocol,
|
||||
@@ -454,6 +465,10 @@ def serve(host, port, request_handler, error_handler, before_start=None,
|
||||
keep_alive=keep_alive,
|
||||
is_request_stream=is_request_stream,
|
||||
router=router,
|
||||
websocket_max_size=websocket_max_size,
|
||||
websocket_max_queue=websocket_max_queue,
|
||||
state=state,
|
||||
debug=debug,
|
||||
)
|
||||
|
||||
server_coroutine = loop.create_server(
|
||||
@@ -465,6 +480,7 @@ def serve(host, port, request_handler, error_handler, before_start=None,
|
||||
sock=sock,
|
||||
backlog=backlog
|
||||
)
|
||||
|
||||
# Instead of pulling time at the end of every request,
|
||||
# pull it once per minute
|
||||
loop.call_soon(partial(update_current_time, loop))
|
||||
@@ -472,6 +488,8 @@ def serve(host, port, request_handler, error_handler, before_start=None,
|
||||
if run_async:
|
||||
return server_coroutine
|
||||
|
||||
trigger_events(before_start, loop)
|
||||
|
||||
try:
|
||||
http_server = loop.run_until_complete(server_coroutine)
|
||||
except:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import traceback
|
||||
from json import JSONDecodeError
|
||||
|
||||
from sanic.log import log
|
||||
|
||||
@@ -28,6 +29,14 @@ class SanicTestClient:
|
||||
response.text = await response.text()
|
||||
except UnicodeDecodeError as e:
|
||||
response.text = None
|
||||
|
||||
try:
|
||||
response.json = await response.json()
|
||||
except (JSONDecodeError,
|
||||
UnicodeDecodeError,
|
||||
aiohttp.ClientResponseError):
|
||||
response.json = None
|
||||
|
||||
response.body = await response.read()
|
||||
return response
|
||||
|
||||
|
||||
@@ -6,9 +6,12 @@ from websockets import ConnectionClosed # noqa
|
||||
|
||||
|
||||
class WebSocketProtocol(HttpProtocol):
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args, websocket_max_size=None,
|
||||
websocket_max_queue=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.websocket = None
|
||||
self.websocket_max_size = websocket_max_size
|
||||
self.websocket_max_queue = websocket_max_queue
|
||||
|
||||
def connection_timeout(self):
|
||||
# timeouts make no sense for websocket routes
|
||||
@@ -62,6 +65,9 @@ class WebSocketProtocol(HttpProtocol):
|
||||
request.transport.write(rv)
|
||||
|
||||
# hook up the websocket protocol
|
||||
self.websocket = WebSocketCommonProtocol()
|
||||
self.websocket = WebSocketCommonProtocol(
|
||||
max_size=self.websocket_max_size,
|
||||
max_queue=self.websocket_max_queue
|
||||
)
|
||||
self.websocket.connection_made(request.transport)
|
||||
return self.websocket
|
||||
|
||||
@@ -29,7 +29,7 @@ class GunicornWorker(base.Worker):
|
||||
self.ssl_context = self._create_ssl_context(cfg)
|
||||
else:
|
||||
self.ssl_context = None
|
||||
self.servers = []
|
||||
self.servers = {}
|
||||
self.connections = set()
|
||||
self.exit_code = 0
|
||||
self.signal = Signal()
|
||||
@@ -96,11 +96,16 @@ class GunicornWorker(base.Worker):
|
||||
|
||||
async def _run(self):
|
||||
for sock in self.sockets:
|
||||
self.servers.append(await serve(
|
||||
state = dict(requests_count=0)
|
||||
self._server_settings["host"] = None
|
||||
self._server_settings["port"] = None
|
||||
server = await serve(
|
||||
sock=sock,
|
||||
connections=self.connections,
|
||||
state=state,
|
||||
**self._server_settings
|
||||
))
|
||||
)
|
||||
self.servers[server] = state
|
||||
|
||||
async def _check_alive(self):
|
||||
# If our parent changed then we shut down.
|
||||
@@ -109,7 +114,15 @@ class GunicornWorker(base.Worker):
|
||||
while self.alive:
|
||||
self.notify()
|
||||
|
||||
if pid == os.getpid() and self.ppid != os.getppid():
|
||||
req_count = sum(
|
||||
self.servers[srv]["requests_count"] for srv in self.servers
|
||||
)
|
||||
if self.max_requests and req_count > self.max_requests:
|
||||
self.alive = False
|
||||
self.log.info(
|
||||
"Max requests exceeded, shutting down: %s", self
|
||||
)
|
||||
elif pid == os.getpid() and self.ppid != os.getppid():
|
||||
self.alive = False
|
||||
self.log.info("Parent changed, shutting down: %s", self)
|
||||
else:
|
||||
@@ -159,9 +172,11 @@ class GunicornWorker(base.Worker):
|
||||
|
||||
def handle_quit(self, sig, frame):
|
||||
self.alive = False
|
||||
self.app.callable.is_running = False
|
||||
self.cfg.worker_int(self)
|
||||
|
||||
def handle_abort(self, sig, frame):
|
||||
self.alive = False
|
||||
self.exit_code = 1
|
||||
self.cfg.worker_abort(self)
|
||||
sys.exit(1)
|
||||
|
||||
Reference in New Issue
Block a user