run black against sanic module

This commit is contained in:
Yun Xu 2018-10-13 17:55:33 -07:00
parent 9ae6dfb6d2
commit aa9bf04dfe
22 changed files with 1657 additions and 1036 deletions

2
pyproject.toml Normal file
View File

@ -0,0 +1,2 @@
[tool.black]
line-length = 79

View File

@ -1,6 +1,6 @@
from sanic.app import Sanic
from sanic.blueprints import Blueprint
__version__ = '0.8.3'
__version__ = "0.8.3"
__all__ = ['Sanic', 'Blueprint']
__all__ = ["Sanic", "Blueprint"]

View File

@ -5,16 +5,18 @@ from sanic.log import logger
from sanic.app import Sanic
if __name__ == "__main__":
parser = ArgumentParser(prog='sanic')
parser.add_argument('--host', dest='host', type=str, default='127.0.0.1')
parser.add_argument('--port', dest='port', type=int, default=8000)
parser.add_argument('--cert', dest='cert', type=str,
help='location of certificate for SSL')
parser.add_argument('--key', dest='key', type=str,
help='location of keyfile for SSL.')
parser.add_argument('--workers', dest='workers', type=int, default=1, )
parser.add_argument('--debug', dest='debug', action="store_true")
parser.add_argument('module')
parser = ArgumentParser(prog="sanic")
parser.add_argument("--host", dest="host", type=str, default="127.0.0.1")
parser.add_argument("--port", dest="port", type=int, default=8000)
parser.add_argument(
"--cert", dest="cert", type=str, help="location of certificate for SSL"
)
parser.add_argument(
"--key", dest="key", type=str, help="location of keyfile for SSL."
)
parser.add_argument("--workers", dest="workers", type=int, default=1)
parser.add_argument("--debug", dest="debug", action="store_true")
parser.add_argument("module")
args = parser.parse_args()
try:
@ -25,20 +27,29 @@ if __name__ == "__main__":
module = import_module(module_name)
app = getattr(module, app_name, None)
if not isinstance(app, Sanic):
raise ValueError("Module is not a Sanic app, it is a {}. "
"Perhaps you meant {}.app?"
.format(type(app).__name__, args.module))
raise ValueError(
"Module is not a Sanic app, it is a {}. "
"Perhaps you meant {}.app?".format(
type(app).__name__, args.module
)
)
if args.cert is not None or args.key is not None:
ssl = {'cert': args.cert, 'key': args.key}
ssl = {"cert": args.cert, "key": args.key}
else:
ssl = None
app.run(host=args.host, port=args.port,
workers=args.workers, debug=args.debug, ssl=ssl)
app.run(
host=args.host,
port=args.port,
workers=args.workers,
debug=args.debug,
ssl=ssl,
)
except ImportError as e:
logger.error("No module named {} found.\n"
" Example File: project/sanic_server.py -> app\n"
" Example Module: project.sanic_server.app"
.format(e.name))
logger.error(
"No module named {} found.\n"
" Example File: project/sanic_server.py -> app\n"
" Example Module: project.sanic_server.app".format(e.name)
)
except ValueError as e:
logger.exception("Failed to run app")

View File

@ -27,10 +27,17 @@ import sanic.reloader_helpers as reloader_helpers
class Sanic:
def __init__(self, name=None, router=None, error_handler=None,
load_env=True, request_class=None,
strict_slashes=False, log_config=None,
configure_logging=True):
def __init__(
self,
name=None,
router=None,
error_handler=None,
load_env=True,
request_class=None,
strict_slashes=False,
log_config=None,
configure_logging=True,
):
# Get name from previous stack frame
if name is None:
@ -71,8 +78,9 @@ class Sanic:
"""
if not self.is_running:
raise SanicException(
'Loop can only be retrieved after the app has started '
'running. Not supported with `create_server` function')
"Loop can only be retrieved after the app has started "
"running. Not supported with `create_server` function"
)
return get_event_loop()
# -------------------------------------------------------------------- #
@ -96,7 +104,8 @@ class Sanic:
else:
self.loop.create_task(task)
except SanicException:
@self.listener('before_server_start')
@self.listener("before_server_start")
def run(app, loop):
if callable(task):
try:
@ -133,8 +142,16 @@ class Sanic:
return self.listener(event)(listener)
# Decorator
def route(self, uri, methods=frozenset({'GET'}), host=None,
strict_slashes=None, stream=False, version=None, name=None):
def route(
self,
uri,
methods=frozenset({"GET"}),
host=None,
strict_slashes=None,
stream=False,
version=None,
name=None,
):
"""Decorate a function to be registered as a route
:param uri: path of the URL
@ -149,8 +166,8 @@ class Sanic:
# Fix case where the user did not prefix the URL with a /
# and will probably get confused as to why it's not working
if not uri.startswith('/'):
uri = '/' + uri
if not uri.startswith("/"):
uri = "/" + uri
if stream:
self.is_request_stream = True
@ -164,63 +181,141 @@ class Sanic:
if stream:
handler.is_stream = stream
self.router.add(uri=uri, methods=methods, handler=handler,
host=host, strict_slashes=strict_slashes,
version=version, name=name)
self.router.add(
uri=uri,
methods=methods,
handler=handler,
host=host,
strict_slashes=strict_slashes,
version=version,
name=name,
)
return handler
else:
raise ValueError(
'Required parameter `request` missing '
'in the {0}() route?'.format(
handler.__name__))
"Required parameter `request` missing "
"in the {0}() route?".format(handler.__name__)
)
return response
# Shorthand method decorators
def get(self, uri, host=None, strict_slashes=None, version=None,
name=None):
return self.route(uri, methods=frozenset({"GET"}), host=host,
strict_slashes=strict_slashes, version=version,
name=name)
def get(
self, uri, host=None, strict_slashes=None, version=None, name=None
):
return self.route(
uri,
methods=frozenset({"GET"}),
host=host,
strict_slashes=strict_slashes,
version=version,
name=name,
)
def post(self, uri, host=None, strict_slashes=None, stream=False,
version=None, name=None):
return self.route(uri, methods=frozenset({"POST"}), host=host,
strict_slashes=strict_slashes, stream=stream,
version=version, name=name)
def post(
self,
uri,
host=None,
strict_slashes=None,
stream=False,
version=None,
name=None,
):
return self.route(
uri,
methods=frozenset({"POST"}),
host=host,
strict_slashes=strict_slashes,
stream=stream,
version=version,
name=name,
)
def put(self, uri, host=None, strict_slashes=None, stream=False,
version=None, name=None):
return self.route(uri, methods=frozenset({"PUT"}), host=host,
strict_slashes=strict_slashes, stream=stream,
version=version, name=name)
def put(
self,
uri,
host=None,
strict_slashes=None,
stream=False,
version=None,
name=None,
):
return self.route(
uri,
methods=frozenset({"PUT"}),
host=host,
strict_slashes=strict_slashes,
stream=stream,
version=version,
name=name,
)
def head(self, uri, host=None, strict_slashes=None, version=None,
name=None):
return self.route(uri, methods=frozenset({"HEAD"}), host=host,
strict_slashes=strict_slashes, version=version,
name=name)
def head(
self, uri, host=None, strict_slashes=None, version=None, name=None
):
return self.route(
uri,
methods=frozenset({"HEAD"}),
host=host,
strict_slashes=strict_slashes,
version=version,
name=name,
)
def options(self, uri, host=None, strict_slashes=None, version=None,
name=None):
return self.route(uri, methods=frozenset({"OPTIONS"}), host=host,
strict_slashes=strict_slashes, version=version,
name=name)
def options(
self, uri, host=None, strict_slashes=None, version=None, name=None
):
return self.route(
uri,
methods=frozenset({"OPTIONS"}),
host=host,
strict_slashes=strict_slashes,
version=version,
name=name,
)
def patch(self, uri, host=None, strict_slashes=None, stream=False,
version=None, name=None):
return self.route(uri, methods=frozenset({"PATCH"}), host=host,
strict_slashes=strict_slashes, stream=stream,
version=version, name=name)
def patch(
self,
uri,
host=None,
strict_slashes=None,
stream=False,
version=None,
name=None,
):
return self.route(
uri,
methods=frozenset({"PATCH"}),
host=host,
strict_slashes=strict_slashes,
stream=stream,
version=version,
name=name,
)
def delete(self, uri, host=None, strict_slashes=None, version=None,
name=None):
return self.route(uri, methods=frozenset({"DELETE"}), host=host,
strict_slashes=strict_slashes, version=version,
name=name)
def delete(
self, uri, host=None, strict_slashes=None, version=None, name=None
):
return self.route(
uri,
methods=frozenset({"DELETE"}),
host=host,
strict_slashes=strict_slashes,
version=version,
name=name,
)
def add_route(self, handler, uri, methods=frozenset({'GET'}), host=None,
strict_slashes=None, version=None, name=None, stream=False):
def add_route(
self,
handler,
uri,
methods=frozenset({"GET"}),
host=None,
strict_slashes=None,
version=None,
name=None,
stream=False,
):
"""A helper method to register class instance or
functions as a handler to the application url
routes.
@ -237,35 +332,42 @@ class Sanic:
:return: function or class instance
"""
# Handle HTTPMethodView differently
if hasattr(handler, 'view_class'):
if hasattr(handler, "view_class"):
methods = set()
for method in HTTP_METHODS:
_handler = getattr(handler.view_class, method.lower(), None)
if _handler:
methods.add(method)
if hasattr(_handler, 'is_stream'):
if hasattr(_handler, "is_stream"):
stream = True
# handle composition view differently
if isinstance(handler, CompositionView):
methods = handler.handlers.keys()
for _handler in handler.handlers.values():
if hasattr(_handler, 'is_stream'):
if hasattr(_handler, "is_stream"):
stream = True
break
if strict_slashes is None:
strict_slashes = self.strict_slashes
self.route(uri=uri, methods=methods, host=host,
strict_slashes=strict_slashes, stream=stream,
version=version, name=name)(handler)
self.route(
uri=uri,
methods=methods,
host=host,
strict_slashes=strict_slashes,
stream=stream,
version=version,
name=name,
)(handler)
return handler
# Decorator
def websocket(self, uri, host=None, strict_slashes=None,
subprotocols=None, name=None):
def websocket(
self, uri, host=None, strict_slashes=None, subprotocols=None, name=None
):
"""Decorate a function to be registered as a websocket route
:param uri: path of the URL
:param subprotocols: optional list of strings with the supported
@ -277,8 +379,8 @@ class Sanic:
# Fix case where the user did not prefix the URL with a /
# and will probably get confused as to why it's not working
if not uri.startswith('/'):
uri = '/' + uri
if not uri.startswith("/"):
uri = "/" + uri
if strict_slashes is None:
strict_slashes = self.strict_slashes
@ -307,21 +409,38 @@ class Sanic:
self.websocket_tasks.remove(fut)
await ws.close()
self.router.add(uri=uri, handler=websocket_handler,
methods=frozenset({'GET'}), host=host,
strict_slashes=strict_slashes, name=name)
self.router.add(
uri=uri,
handler=websocket_handler,
methods=frozenset({"GET"}),
host=host,
strict_slashes=strict_slashes,
name=name,
)
return handler
return response
def add_websocket_route(self, handler, uri, host=None,
strict_slashes=None, subprotocols=None, name=None):
def add_websocket_route(
self,
handler,
uri,
host=None,
strict_slashes=None,
subprotocols=None,
name=None,
):
"""A helper method to register a function as a websocket route."""
if strict_slashes is None:
strict_slashes = self.strict_slashes
return self.websocket(uri, host=host, strict_slashes=strict_slashes,
subprotocols=subprotocols, name=name)(handler)
return self.websocket(
uri,
host=host,
strict_slashes=strict_slashes,
subprotocols=subprotocols,
name=name,
)(handler)
def enable_websocket(self, enable=True):
"""Enable or disable the support for websocket.
@ -332,7 +451,7 @@ class Sanic:
if not self.websocket_enabled:
# if the server is stopped, we want to cancel any ongoing
# websocket tasks, to allow the server to exit promptly
@self.listener('before_server_stop')
@self.listener("before_server_stop")
def cancel_websocket_tasks(app, loop):
for task in self.websocket_tasks:
task.cancel()
@ -361,10 +480,10 @@ class Sanic:
return response
def register_middleware(self, middleware, attach_to='request'):
if attach_to == 'request':
def register_middleware(self, middleware, attach_to="request"):
if attach_to == "request":
self.request_middleware.append(middleware)
if attach_to == 'response':
if attach_to == "response":
self.response_middleware.appendleft(middleware)
return middleware
@ -379,21 +498,40 @@ class Sanic:
return self.register_middleware(middleware_or_request)
else:
return partial(self.register_middleware,
attach_to=middleware_or_request)
return partial(
self.register_middleware, attach_to=middleware_or_request
)
# Static Files
def static(self, uri, file_or_directory, pattern=r'/?.+',
use_modified_since=True, use_content_range=False,
stream_large_files=False, name='static', host=None,
strict_slashes=None, content_type=None):
def static(
self,
uri,
file_or_directory,
pattern=r"/?.+",
use_modified_since=True,
use_content_range=False,
stream_large_files=False,
name="static",
host=None,
strict_slashes=None,
content_type=None,
):
"""Register a root to serve files from. The input can either be a
file or a directory. See
"""
static_register(self, uri, file_or_directory, pattern,
use_modified_since, use_content_range,
stream_large_files, name, host, strict_slashes,
content_type)
static_register(
self,
uri,
file_or_directory,
pattern,
use_modified_since,
use_content_range,
stream_large_files,
name,
host,
strict_slashes,
content_type,
)
def blueprint(self, blueprint, **options):
"""Register a blueprint on the application.
@ -407,10 +545,10 @@ class Sanic:
self.blueprint(item, **options)
return
if blueprint.name in self.blueprints:
assert self.blueprints[blueprint.name] is blueprint, \
'A blueprint with the name "%s" is already registered. ' \
'Blueprint names must be unique.' % \
(blueprint.name,)
assert self.blueprints[blueprint.name] is blueprint, (
'A blueprint with the name "%s" is already registered. '
"Blueprint names must be unique." % (blueprint.name,)
)
else:
self.blueprints[blueprint.name] = blueprint
self._blueprint_order.append(blueprint)
@ -419,11 +557,13 @@ class Sanic:
def register_blueprint(self, *args, **kwargs):
# TODO: deprecate 1.0
if self.debug:
warnings.simplefilter('default')
warnings.warn("Use of register_blueprint will be deprecated in "
"version 1.0. Please use the blueprint method"
" instead",
DeprecationWarning)
warnings.simplefilter("default")
warnings.warn(
"Use of register_blueprint will be deprecated in "
"version 1.0. Please use the blueprint method"
" instead",
DeprecationWarning,
)
return self.blueprint(*args, **kwargs)
def url_for(self, view_name: str, **kwargs):
@ -449,67 +589,66 @@ class Sanic:
# find the route by the supplied view name
kw = {}
# special static files url_for
if view_name == 'static':
kw.update(name=kwargs.pop('name', 'static'))
elif view_name.endswith('.static'): # blueprint.static
kwargs.pop('name', None)
if view_name == "static":
kw.update(name=kwargs.pop("name", "static"))
elif view_name.endswith(".static"): # blueprint.static
kwargs.pop("name", None)
kw.update(name=view_name)
uri, route = self.router.find_route_by_view_name(view_name, **kw)
if not (uri and route):
raise URLBuildError('Endpoint with name `{}` was not found'.format(
view_name))
raise URLBuildError(
"Endpoint with name `{}` was not found".format(view_name)
)
if view_name == 'static' or view_name.endswith('.static'):
filename = kwargs.pop('filename', None)
if view_name == "static" or view_name.endswith(".static"):
filename = kwargs.pop("filename", None)
# it's static folder
if '<file_uri:' in uri:
folder_ = uri.split('<file_uri:', 1)[0]
if folder_.endswith('/'):
if "<file_uri:" in uri:
folder_ = uri.split("<file_uri:", 1)[0]
if folder_.endswith("/"):
folder_ = folder_[:-1]
if filename.startswith('/'):
if filename.startswith("/"):
filename = filename[1:]
uri = '{}/{}'.format(folder_, filename)
uri = "{}/{}".format(folder_, filename)
if uri != '/' and uri.endswith('/'):
if uri != "/" and uri.endswith("/"):
uri = uri[:-1]
out = uri
# find all the parameters we will need to build in the URL
matched_params = re.findall(
self.router.parameter_pattern, uri)
matched_params = re.findall(self.router.parameter_pattern, uri)
# _method is only a placeholder now, don't know how to support it
kwargs.pop('_method', None)
anchor = kwargs.pop('_anchor', '')
kwargs.pop("_method", None)
anchor = kwargs.pop("_anchor", "")
# _external need SERVER_NAME in config or pass _server arg
external = kwargs.pop('_external', False)
scheme = kwargs.pop('_scheme', '')
external = kwargs.pop("_external", False)
scheme = kwargs.pop("_scheme", "")
if scheme and not external:
raise ValueError('When specifying _scheme, _external must be True')
raise ValueError("When specifying _scheme, _external must be True")
netloc = kwargs.pop('_server', None)
netloc = kwargs.pop("_server", None)
if netloc is None and external:
netloc = self.config.get('SERVER_NAME', '')
netloc = self.config.get("SERVER_NAME", "")
if external:
if not scheme:
if ':' in netloc[:8]:
scheme = netloc[:8].split(':', 1)[0]
if ":" in netloc[:8]:
scheme = netloc[:8].split(":", 1)[0]
else:
scheme = 'http'
scheme = "http"
if '://' in netloc[:8]:
netloc = netloc.split('://', 1)[-1]
if "://" in netloc[:8]:
netloc = netloc.split("://", 1)[-1]
for match in matched_params:
name, _type, pattern = self.router.parse_parameter_string(
match)
name, _type, pattern = self.router.parse_parameter_string(match)
# we only want to match against each individual parameter
specific_pattern = '^{}$'.format(pattern)
specific_pattern = "^{}$".format(pattern)
supplied_param = None
if name in kwargs:
@ -517,8 +656,10 @@ class Sanic:
del kwargs[name]
else:
raise URLBuildError(
'Required parameter `{}` was not passed to url_for'.format(
name))
"Required parameter `{}` was not passed to url_for".format(
name
)
)
supplied_param = str(supplied_param)
# determine if the parameter supplied by the caller passes the test
@ -529,25 +670,28 @@ class Sanic:
if _type != str:
msg = (
'Value "{}" for parameter `{}` does not '
'match pattern for type `{}`: {}'.format(
supplied_param, name, _type.__name__, pattern))
"match pattern for type `{}`: {}".format(
supplied_param, name, _type.__name__, pattern
)
)
else:
msg = (
'Value "{}" for parameter `{}` '
'does not satisfy pattern {}'.format(
supplied_param, name, pattern))
"does not satisfy pattern {}".format(
supplied_param, name, pattern
)
)
raise URLBuildError(msg)
# replace the parameter in the URL with the supplied value
replacement_regex = '(<{}.*?>)'.format(name)
replacement_regex = "(<{}.*?>)".format(name)
out = re.sub(
replacement_regex, supplied_param, out)
out = re.sub(replacement_regex, supplied_param, out)
# parse the remainder of the keyword arguments into a querystring
query_string = urlencode(kwargs, doseq=True) if kwargs else ''
query_string = urlencode(kwargs, doseq=True) if kwargs else ""
# scheme://netloc/path;parameters?query#fragment
out = urlunparse((scheme, netloc, out, '', query_string, anchor))
out = urlunparse((scheme, netloc, out, "", query_string, anchor))
return out
@ -594,8 +738,11 @@ class Sanic:
request.uri_template = uri
if handler is None:
raise ServerError(
("'None' was returned while requesting a "
"handler from the router"))
(
"'None' was returned while requesting a "
"handler from the router"
)
)
# Run response handler
response = handler(request, *args, **kwargs)
@ -619,16 +766,20 @@ class Sanic:
response = await response
except Exception as e:
if isinstance(e, SanicException):
response = self.error_handler.default(request=request,
exception=e)
response = self.error_handler.default(
request=request, exception=e
)
elif self.debug:
response = HTTPResponse(
"Error while handling error: {}\nStack: {}".format(
e, format_exc()), status=500)
e, format_exc()
),
status=500,
)
else:
response = HTTPResponse(
"An error occurred while handling an error",
status=500)
"An error occurred while handling an error", status=500
)
finally:
# -------------------------------------------- #
# Response Middleware
@ -636,16 +787,17 @@ class Sanic:
# Don't run response middleware if response is None
if response is not None:
try:
response = await self._run_response_middleware(request,
response)
response = await self._run_response_middleware(
request, response
)
except CancelledError:
# Response middleware can timeout too, as above.
response = None
cancelled = True
except BaseException:
error_logger.exception(
'Exception occurred in one of response '
'middleware handlers'
"Exception occurred in one of response "
"middleware handlers"
)
if cancelled:
raise CancelledError()
@ -668,10 +820,21 @@ class Sanic:
# Execution
# -------------------------------------------------------------------- #
def run(self, host=None, port=None, debug=False, ssl=None,
sock=None, workers=1, protocol=None,
backlog=100, stop_event=None, register_sys_signals=True,
access_log=True, **kwargs):
def run(
self,
host=None,
port=None,
debug=False,
ssl=None,
sock=None,
workers=1,
protocol=None,
backlog=100,
stop_event=None,
register_sys_signals=True,
access_log=True,
**kwargs
):
"""Run the HTTP Server and listen until keyboard interrupt or term
signal. On termination, drain connections before closing.
@ -692,7 +855,7 @@ class Sanic:
# Default auto_reload to false
auto_reload = False
# If debug is set, default it to true (unless on windows)
if debug and os.name == 'posix':
if debug and os.name == "posix":
auto_reload = True
# Allow for overriding either of the defaults
auto_reload = kwargs.get("auto_reload", auto_reload)
@ -701,30 +864,43 @@ class Sanic:
host, port = host or "127.0.0.1", port or 8000
if protocol is None:
protocol = (WebSocketProtocol if self.websocket_enabled
else HttpProtocol)
protocol = (
WebSocketProtocol if self.websocket_enabled else HttpProtocol
)
if stop_event is not None:
if debug:
warnings.simplefilter('default')
warnings.warn("stop_event will be removed from future versions.",
DeprecationWarning)
warnings.simplefilter("default")
warnings.warn(
"stop_event will be removed from future versions.",
DeprecationWarning,
)
# compatibility old access_log params
self.config.ACCESS_LOG = access_log
server_settings = self._helper(
host=host, port=port, debug=debug, ssl=ssl, sock=sock,
workers=workers, protocol=protocol, backlog=backlog,
register_sys_signals=register_sys_signals, auto_reload=auto_reload)
host=host,
port=port,
debug=debug,
ssl=ssl,
sock=sock,
workers=workers,
protocol=protocol,
backlog=backlog,
register_sys_signals=register_sys_signals,
auto_reload=auto_reload,
)
try:
self.is_running = True
if workers == 1:
if auto_reload and os.name != 'posix':
if auto_reload and os.name != "posix":
# This condition must be removed after implementing
# auto reloader for other operating systems.
raise NotImplementedError
if auto_reload and \
os.environ.get('SANIC_SERVER_RUNNING') != 'true':
if (
auto_reload
and os.environ.get("SANIC_SERVER_RUNNING") != "true"
):
reloader_helpers.watchdog(2)
else:
serve(**server_settings)
@ -732,7 +908,8 @@ class Sanic:
serve_multiple(server_settings, workers)
except BaseException:
error_logger.exception(
'Experienced exception while trying to serve')
"Experienced exception while trying to serve"
)
raise
finally:
self.is_running = False
@ -746,10 +923,18 @@ class Sanic:
"""gunicorn compatibility"""
return self
async def create_server(self, host=None, port=None, debug=False,
ssl=None, sock=None, protocol=None,
backlog=100, stop_event=None,
access_log=True):
async def create_server(
self,
host=None,
port=None,
debug=False,
ssl=None,
sock=None,
protocol=None,
backlog=100,
stop_event=None,
access_log=True,
):
"""Asynchronous version of `run`.
NOTE: This does not support multiprocessing and is not the preferred
@ -760,24 +945,34 @@ class Sanic:
host, port = host or "127.0.0.1", port or 8000
if protocol is None:
protocol = (WebSocketProtocol if self.websocket_enabled
else HttpProtocol)
protocol = (
WebSocketProtocol if self.websocket_enabled else HttpProtocol
)
if stop_event is not None:
if debug:
warnings.simplefilter('default')
warnings.warn("stop_event will be removed from future versions.",
DeprecationWarning)
warnings.simplefilter("default")
warnings.warn(
"stop_event will be removed from future versions.",
DeprecationWarning,
)
# compatibility old access_log params
self.config.ACCESS_LOG = access_log
server_settings = self._helper(
host=host, port=port, debug=debug, ssl=ssl, sock=sock,
loop=get_event_loop(), protocol=protocol,
backlog=backlog, run_async=True)
host=host,
port=port,
debug=debug,
ssl=ssl,
sock=sock,
loop=get_event_loop(),
protocol=protocol,
backlog=backlog,
run_async=True,
)
# Trigger before_start events
await self.trigger_events(
server_settings.get('before_start', []),
server_settings.get('loop')
server_settings.get("before_start", []),
server_settings.get("loop"),
)
return await serve(**server_settings)
@ -814,15 +1009,27 @@ class Sanic:
break
return response
def _helper(self, host=None, port=None, debug=False,
ssl=None, sock=None, workers=1, loop=None,
protocol=HttpProtocol, backlog=100, stop_event=None,
register_sys_signals=True, run_async=False, auto_reload=False):
def _helper(
self,
host=None,
port=None,
debug=False,
ssl=None,
sock=None,
workers=1,
loop=None,
protocol=HttpProtocol,
backlog=100,
stop_event=None,
register_sys_signals=True,
run_async=False,
auto_reload=False,
):
"""Helper function used by `run` and `create_server`."""
if isinstance(ssl, dict):
# try common aliaseses
cert = ssl.get('cert') or ssl.get('certificate')
key = ssl.get('key') or ssl.get('keyfile')
cert = ssl.get("cert") or ssl.get("certificate")
key = ssl.get("key") or ssl.get("keyfile")
if cert is None or key is None:
raise ValueError("SSLContext or certificate and key required.")
context = create_default_context(purpose=Purpose.CLIENT_AUTH)
@ -830,40 +1037,42 @@ class Sanic:
ssl = context
if stop_event is not None:
if debug:
warnings.simplefilter('default')
warnings.warn("stop_event will be removed from future versions.",
DeprecationWarning)
warnings.simplefilter("default")
warnings.warn(
"stop_event will be removed from future versions.",
DeprecationWarning,
)
self.error_handler.debug = debug
self.debug = debug
server_settings = {
'protocol': protocol,
'request_class': self.request_class,
'is_request_stream': self.is_request_stream,
'router': self.router,
'host': host,
'port': port,
'sock': sock,
'ssl': ssl,
'signal': Signal(),
'debug': debug,
'request_handler': self.handle_request,
'error_handler': self.error_handler,
'request_timeout': self.config.REQUEST_TIMEOUT,
'response_timeout': self.config.RESPONSE_TIMEOUT,
'keep_alive_timeout': self.config.KEEP_ALIVE_TIMEOUT,
'request_max_size': self.config.REQUEST_MAX_SIZE,
'keep_alive': self.config.KEEP_ALIVE,
'loop': loop,
'register_sys_signals': register_sys_signals,
'backlog': backlog,
'access_log': self.config.ACCESS_LOG,
'websocket_max_size': self.config.WEBSOCKET_MAX_SIZE,
'websocket_max_queue': self.config.WEBSOCKET_MAX_QUEUE,
'websocket_read_limit': self.config.WEBSOCKET_READ_LIMIT,
'websocket_write_limit': self.config.WEBSOCKET_WRITE_LIMIT,
'graceful_shutdown_timeout': self.config.GRACEFUL_SHUTDOWN_TIMEOUT
"protocol": protocol,
"request_class": self.request_class,
"is_request_stream": self.is_request_stream,
"router": self.router,
"host": host,
"port": port,
"sock": sock,
"ssl": ssl,
"signal": Signal(),
"debug": debug,
"request_handler": self.handle_request,
"error_handler": self.error_handler,
"request_timeout": self.config.REQUEST_TIMEOUT,
"response_timeout": self.config.RESPONSE_TIMEOUT,
"keep_alive_timeout": self.config.KEEP_ALIVE_TIMEOUT,
"request_max_size": self.config.REQUEST_MAX_SIZE,
"keep_alive": self.config.KEEP_ALIVE,
"loop": loop,
"register_sys_signals": register_sys_signals,
"backlog": backlog,
"access_log": self.config.ACCESS_LOG,
"websocket_max_size": self.config.WEBSOCKET_MAX_SIZE,
"websocket_max_queue": self.config.WEBSOCKET_MAX_QUEUE,
"websocket_read_limit": self.config.WEBSOCKET_READ_LIMIT,
"websocket_write_limit": self.config.WEBSOCKET_WRITE_LIMIT,
"graceful_shutdown_timeout": self.config.GRACEFUL_SHUTDOWN_TIMEOUT,
}
# -------------------------------------------- #
@ -871,10 +1080,10 @@ class Sanic:
# -------------------------------------------- #
for event_name, settings_name, reverse in (
("before_server_start", "before_start", False),
("after_server_start", "after_start", False),
("before_server_stop", "before_stop", True),
("after_server_stop", "after_stop", True),
("before_server_start", "before_start", False),
("after_server_start", "after_start", False),
("before_server_stop", "before_stop", True),
("after_server_stop", "after_stop", True),
):
listeners = self.listeners[event_name].copy()
if reverse:
@ -886,18 +1095,20 @@ class Sanic:
if self.configure_logging and debug:
logger.setLevel(logging.DEBUG)
if self.config.LOGO is not None and \
os.environ.get('SANIC_SERVER_RUNNING') != 'true':
if (
self.config.LOGO is not None
and os.environ.get("SANIC_SERVER_RUNNING") != "true"
):
logger.debug(self.config.LOGO)
if run_async:
server_settings['run_async'] = True
server_settings["run_async"] = True
# Serve
if host and port and os.environ.get('SANIC_SERVER_RUNNING') != 'true':
if host and port and os.environ.get("SANIC_SERVER_RUNNING") != "true":
proto = "http"
if ssl is not None:
proto = "https"
logger.info('Goin\' Fast @ {}://{}:{}'.format(proto, host, port))
logger.info("Goin' Fast @ {}://{}:{}".format(proto, host, port))
return server_settings

View File

@ -3,21 +3,36 @@ from collections import defaultdict, namedtuple
from sanic.constants import HTTP_METHODS
from sanic.views import CompositionView
FutureRoute = namedtuple('Route',
['handler', 'uri', 'methods', 'host',
'strict_slashes', 'stream', 'version', 'name'])
FutureListener = namedtuple('Listener', ['handler', 'uri', 'methods', 'host'])
FutureMiddleware = namedtuple('Route', ['middleware', 'args', 'kwargs'])
FutureException = namedtuple('Route', ['handler', 'args', 'kwargs'])
FutureStatic = namedtuple('Route',
['uri', 'file_or_directory', 'args', 'kwargs'])
FutureRoute = namedtuple(
"Route",
[
"handler",
"uri",
"methods",
"host",
"strict_slashes",
"stream",
"version",
"name",
],
)
FutureListener = namedtuple("Listener", ["handler", "uri", "methods", "host"])
FutureMiddleware = namedtuple("Route", ["middleware", "args", "kwargs"])
FutureException = namedtuple("Route", ["handler", "args", "kwargs"])
FutureStatic = namedtuple(
"Route", ["uri", "file_or_directory", "args", "kwargs"]
)
class Blueprint:
def __init__(self, name,
url_prefix=None,
host=None, version=None,
strict_slashes=False):
def __init__(
self,
name,
url_prefix=None,
host=None,
version=None,
strict_slashes=False,
):
"""Create a new blueprint
:param name: unique name of the blueprint
@ -38,13 +53,14 @@ class Blueprint:
self.strict_slashes = strict_slashes
@staticmethod
def group(*blueprints, url_prefix=''):
def group(*blueprints, url_prefix=""):
"""Create a list of blueprints, optionally
grouping them under a general URL prefix.
:param blueprints: blueprints to be registered as a group
:param url_prefix: URL route to be prepended to all sub-prefixes
"""
def chain(nested):
"""itertools.chain() but leaves strings untouched"""
for i in nested:
@ -52,10 +68,11 @@ class Blueprint:
yield from chain(i)
else:
yield i
bps = []
for bp in chain(blueprints):
if bp.url_prefix is None:
bp.url_prefix = ''
bp.url_prefix = ""
bp.url_prefix = url_prefix + bp.url_prefix
bps.append(bp)
return bps
@ -63,7 +80,7 @@ class Blueprint:
def register(self, app, options):
"""Register the blueprint to the sanic app."""
url_prefix = options.get('url_prefix', self.url_prefix)
url_prefix = options.get("url_prefix", self.url_prefix)
# Routes
for future in self.routes:
@ -75,14 +92,15 @@ class Blueprint:
version = future.version or self.version
app.route(uri=uri[1:] if uri.startswith('//') else uri,
methods=future.methods,
host=future.host or self.host,
strict_slashes=future.strict_slashes,
stream=future.stream,
version=version,
name=future.name,
)(future.handler)
app.route(
uri=uri[1:] if uri.startswith("//") else uri,
methods=future.methods,
host=future.host or self.host,
strict_slashes=future.strict_slashes,
stream=future.stream,
version=version,
name=future.name,
)(future.handler)
for future in self.websocket_routes:
# attach the blueprint name to the handler so that it can be
@ -90,18 +108,19 @@ class Blueprint:
future.handler.__blueprintname__ = self.name
# Prepend the blueprint URI prefix if available
uri = url_prefix + future.uri if url_prefix else future.uri
app.websocket(uri=uri,
host=future.host or self.host,
strict_slashes=future.strict_slashes,
name=future.name,
)(future.handler)
app.websocket(
uri=uri,
host=future.host or self.host,
strict_slashes=future.strict_slashes,
name=future.name,
)(future.handler)
# Middleware
for future in self.middlewares:
if future.args or future.kwargs:
app.register_middleware(future.middleware,
*future.args,
**future.kwargs)
app.register_middleware(
future.middleware, *future.args, **future.kwargs
)
else:
app.register_middleware(future.middleware)
@ -113,16 +132,25 @@ class Blueprint:
for future in self.statics:
# Prepend the blueprint URI prefix if available
uri = url_prefix + future.uri if url_prefix else future.uri
app.static(uri, future.file_or_directory,
*future.args, **future.kwargs)
app.static(
uri, future.file_or_directory, *future.args, **future.kwargs
)
# Event listeners
for event, listeners in self.listeners.items():
for listener in listeners:
app.listener(event)(listener)
def route(self, uri, methods=frozenset({'GET'}), host=None,
strict_slashes=None, stream=False, version=None, name=None):
def route(
self,
uri,
methods=frozenset({"GET"}),
host=None,
strict_slashes=None,
stream=False,
version=None,
name=None,
):
"""Create a blueprint route from a decorated function.
:param uri: endpoint at which the route will be accessible.
@ -133,14 +161,30 @@ class Blueprint:
def decorator(handler):
route = FutureRoute(
handler, uri, methods, host, strict_slashes, stream, version,
name)
handler,
uri,
methods,
host,
strict_slashes,
stream,
version,
name,
)
self.routes.append(route)
return handler
return decorator
def add_route(self, handler, uri, methods=frozenset({'GET'}), host=None,
strict_slashes=None, version=None, name=None):
def add_route(
self,
handler,
uri,
methods=frozenset({"GET"}),
host=None,
strict_slashes=None,
version=None,
name=None,
):
"""Create a blueprint route from a function.
:param handler: function for handling uri requests. Accepts function,
@ -154,7 +198,7 @@ class Blueprint:
:return: function or class instance
"""
# Handle HTTPMethodView differently
if hasattr(handler, 'view_class'):
if hasattr(handler, "view_class"):
methods = set()
for method in HTTP_METHODS:
@ -168,13 +212,19 @@ class Blueprint:
if isinstance(handler, CompositionView):
methods = handler.handlers.keys()
self.route(uri=uri, methods=methods, host=host,
strict_slashes=strict_slashes, version=version,
name=name)(handler)
self.route(
uri=uri,
methods=methods,
host=host,
strict_slashes=strict_slashes,
version=version,
name=name,
)(handler)
return handler
def websocket(self, uri, host=None, strict_slashes=None, version=None,
name=None):
def websocket(
self, uri, host=None, strict_slashes=None, version=None, name=None
):
"""Create a blueprint websocket route from a decorated function.
:param uri: endpoint at which the route will be accessible.
@ -183,14 +233,17 @@ class Blueprint:
strict_slashes = self.strict_slashes
def decorator(handler):
route = FutureRoute(handler, uri, [], host, strict_slashes,
False, version, name)
route = FutureRoute(
handler, uri, [], host, strict_slashes, False, version, name
)
self.websocket_routes.append(route)
return handler
return decorator
def add_websocket_route(self, handler, uri, host=None, version=None,
name=None):
def add_websocket_route(
self, handler, uri, host=None, version=None, name=None
):
"""Create a blueprint websocket route from a function.
:param handler: function for handling uri requests. Accepts function,
@ -206,13 +259,16 @@ class Blueprint:
:param event: Event to listen to.
"""
def decorator(listener):
self.listeners[event].append(listener)
return listener
return decorator
def middleware(self, *args, **kwargs):
"""Create a blueprint middleware from a decorated function."""
def register_middleware(_middleware):
future_middleware = FutureMiddleware(_middleware, args, kwargs)
self.middlewares.append(future_middleware)
@ -228,10 +284,12 @@ class Blueprint:
def exception(self, *args, **kwargs):
"""Create a blueprint exception from a decorated function."""
def decorator(handler):
exception = FutureException(handler, args, kwargs)
self.exceptions.append(exception)
return handler
return decorator
def static(self, uri, file_or_directory, *args, **kwargs):
@ -240,12 +298,12 @@ class Blueprint:
:param uri: endpoint at which the route will be accessible.
:param file_or_directory: Static asset.
"""
name = kwargs.pop('name', 'static')
if not name.startswith(self.name + '.'):
name = '{}.{}'.format(self.name, name)
name = kwargs.pop("name", "static")
if not name.startswith(self.name + "."):
name = "{}.{}".format(self.name, name)
kwargs.update(name=name)
strict_slashes = kwargs.get('strict_slashes')
strict_slashes = kwargs.get("strict_slashes")
if strict_slashes is None and self.strict_slashes is not None:
kwargs.update(strict_slashes=self.strict_slashes)
@ -253,44 +311,107 @@ class Blueprint:
self.statics.append(static)
# Shorthand method decorators
def get(self, uri, host=None, strict_slashes=None, version=None,
name=None):
return self.route(uri, methods=["GET"], host=host,
strict_slashes=strict_slashes, version=version,
name=name)
def get(
self, uri, host=None, strict_slashes=None, version=None, name=None
):
return self.route(
uri,
methods=["GET"],
host=host,
strict_slashes=strict_slashes,
version=version,
name=name,
)
def post(self, uri, host=None, strict_slashes=None, stream=False,
version=None, name=None):
return self.route(uri, methods=["POST"], host=host,
strict_slashes=strict_slashes, stream=stream,
version=version, name=name)
def post(
self,
uri,
host=None,
strict_slashes=None,
stream=False,
version=None,
name=None,
):
return self.route(
uri,
methods=["POST"],
host=host,
strict_slashes=strict_slashes,
stream=stream,
version=version,
name=name,
)
def put(self, uri, host=None, strict_slashes=None, stream=False,
version=None, name=None):
return self.route(uri, methods=["PUT"], host=host,
strict_slashes=strict_slashes, stream=stream,
version=version, name=name)
def put(
self,
uri,
host=None,
strict_slashes=None,
stream=False,
version=None,
name=None,
):
return self.route(
uri,
methods=["PUT"],
host=host,
strict_slashes=strict_slashes,
stream=stream,
version=version,
name=name,
)
def head(self, uri, host=None, strict_slashes=None, version=None,
name=None):
return self.route(uri, methods=["HEAD"], host=host,
strict_slashes=strict_slashes, version=version,
name=name)
def head(
self, uri, host=None, strict_slashes=None, version=None, name=None
):
return self.route(
uri,
methods=["HEAD"],
host=host,
strict_slashes=strict_slashes,
version=version,
name=name,
)
def options(self, uri, host=None, strict_slashes=None, version=None,
name=None):
return self.route(uri, methods=["OPTIONS"], host=host,
strict_slashes=strict_slashes, version=version,
name=name)
def options(
self, uri, host=None, strict_slashes=None, version=None, name=None
):
return self.route(
uri,
methods=["OPTIONS"],
host=host,
strict_slashes=strict_slashes,
version=version,
name=name,
)
def patch(self, uri, host=None, strict_slashes=None, stream=False,
version=None, name=None):
return self.route(uri, methods=["PATCH"], host=host,
strict_slashes=strict_slashes, stream=stream,
version=version, name=name)
def patch(
self,
uri,
host=None,
strict_slashes=None,
stream=False,
version=None,
name=None,
):
return self.route(
uri,
methods=["PATCH"],
host=host,
strict_slashes=strict_slashes,
stream=stream,
version=version,
name=name,
)
def delete(self, uri, host=None, strict_slashes=None, version=None,
name=None):
return self.route(uri, methods=["DELETE"], host=host,
strict_slashes=strict_slashes, version=version,
name=name)
def delete(
self, uri, host=None, strict_slashes=None, version=None, name=None
):
return self.route(
uri,
methods=["DELETE"],
host=host,
strict_slashes=strict_slashes,
version=version,
name=name,
)

View File

@ -4,7 +4,7 @@ import types
from sanic.exceptions import PyFileError
SANIC_PREFIX = 'SANIC_'
SANIC_PREFIX = "SANIC_"
class Config(dict):
@ -65,9 +65,10 @@ class Config(dict):
"""
config_file = os.environ.get(variable_name)
if not config_file:
raise RuntimeError('The environment variable %r is not set and '
'thus configuration could not be loaded.' %
variable_name)
raise RuntimeError(
"The environment variable %r is not set and "
"thus configuration could not be loaded." % variable_name
)
return self.from_pyfile(config_file)
def from_pyfile(self, filename):
@ -76,14 +77,16 @@ class Config(dict):
:param filename: an absolute path to the config file
"""
module = types.ModuleType('config')
module = types.ModuleType("config")
module.__file__ = filename
try:
with open(filename) as config_file:
exec(compile(config_file.read(), filename, 'exec'),
module.__dict__)
exec(
compile(config_file.read(), filename, "exec"),
module.__dict__,
)
except IOError as e:
e.strerror = 'Unable to load configuration file (%s)' % e.strerror
e.strerror = "Unable to load configuration file (%s)" % e.strerror
raise
except Exception as e:
raise PyFileError(filename) from e

View File

@ -1 +1 @@
HTTP_METHODS = ('GET', 'POST', 'PUT', 'HEAD', 'OPTIONS', 'PATCH', 'DELETE')
HTTP_METHODS = ("GET", "POST", "PUT", "HEAD", "OPTIONS", "PATCH", "DELETE")

View File

@ -8,14 +8,12 @@ import string
# Straight up copied this section of dark magic from SimpleCookie
_LegalChars = string.ascii_letters + string.digits + "!#$%&'*+-.^_`|~:"
_UnescapedChars = _LegalChars + ' ()/<=>?@[]{}'
_UnescapedChars = _LegalChars + " ()/<=>?@[]{}"
_Translator = {n: '\\%03o' % n
for n in set(range(256)) - set(map(ord, _UnescapedChars))}
_Translator.update({
ord('"'): '\\"',
ord('\\'): '\\\\',
})
_Translator = {
n: "\\%03o" % n for n in set(range(256)) - set(map(ord, _UnescapedChars))
}
_Translator.update({ord('"'): '\\"', ord("\\"): "\\\\"})
def _quote(str):
@ -30,7 +28,7 @@ def _quote(str):
return '"' + str.translate(_Translator) + '"'
_is_legal_key = re.compile('[%s]+' % re.escape(_LegalChars)).fullmatch
_is_legal_key = re.compile("[%s]+" % re.escape(_LegalChars)).fullmatch
# ------------------------------------------------------------ #
# Custom SimpleCookie
@ -53,7 +51,7 @@ class CookieJar(dict):
# If this cookie doesn't exist, add it to the header keys
if not self.cookie_headers.get(key):
cookie = Cookie(key, value)
cookie['path'] = '/'
cookie["path"] = "/"
self.cookie_headers[key] = self.header_key
self.headers.add(self.header_key, cookie)
return super().__setitem__(key, cookie)
@ -62,8 +60,8 @@ class CookieJar(dict):
def __delitem__(self, key):
if key not in self.cookie_headers:
self[key] = ''
self[key]['max-age'] = 0
self[key] = ""
self[key]["max-age"] = 0
else:
cookie_header = self.cookie_headers[key]
# remove it from header
@ -77,6 +75,7 @@ class CookieJar(dict):
class Cookie(dict):
"""A stripped down version of Morsel from SimpleCookie #gottagofast"""
_keys = {
"expires": "expires",
"path": "Path",
@ -88,7 +87,7 @@ class Cookie(dict):
"version": "Version",
"samesite": "SameSite",
}
_flags = {'secure', 'httponly'}
_flags = {"secure", "httponly"}
def __init__(self, key, value):
if key in self._keys:
@ -106,24 +105,27 @@ class Cookie(dict):
return super().__setitem__(key, value)
def encode(self, encoding):
output = ['%s=%s' % (self.key, _quote(self.value))]
output = ["%s=%s" % (self.key, _quote(self.value))]
for key, value in self.items():
if key == 'max-age':
if key == "max-age":
try:
output.append('%s=%d' % (self._keys[key], value))
output.append("%s=%d" % (self._keys[key], value))
except TypeError:
output.append('%s=%s' % (self._keys[key], value))
elif key == 'expires':
output.append("%s=%s" % (self._keys[key], value))
elif key == "expires":
try:
output.append('%s=%s' % (
self._keys[key],
value.strftime("%a, %d-%b-%Y %T GMT")
))
output.append(
"%s=%s"
% (
self._keys[key],
value.strftime("%a, %d-%b-%Y %T GMT"),
)
)
except AttributeError:
output.append('%s=%s' % (self._keys[key], value))
output.append("%s=%s" % (self._keys[key], value))
elif key in self._flags and self[key]:
output.append(self._keys[key])
else:
output.append('%s=%s' % (self._keys[key], value))
output.append("%s=%s" % (self._keys[key], value))
return "; ".join(output).encode(encoding)

View File

@ -1,6 +1,6 @@
from sanic.helpers import STATUS_CODES
TRACEBACK_STYLE = '''
TRACEBACK_STYLE = """
<style>
body {
padding: 20px;
@ -61,9 +61,9 @@ TRACEBACK_STYLE = '''
font-size: 14px;
}
</style>
'''
"""
TRACEBACK_WRAPPER_HTML = '''
TRACEBACK_WRAPPER_HTML = """
<html>
<head>
{style}
@ -78,27 +78,27 @@ TRACEBACK_WRAPPER_HTML = '''
</div>
</body>
</html>
'''
"""
TRACEBACK_WRAPPER_INNER_HTML = '''
TRACEBACK_WRAPPER_INNER_HTML = """
<h1>{exc_name}</h1>
<h3><code>{exc_value}</code></h3>
<div class="tb-wrapper">
<p class="tb-header">Traceback (most recent call last):</p>
{frame_html}
</div>
'''
"""
TRACEBACK_BORDER = '''
TRACEBACK_BORDER = """
<div class="tb-border">
<b><i>
The above exception was the direct cause of the
following exception:
</i></b>
</div>
'''
"""
TRACEBACK_LINE_HTML = '''
TRACEBACK_LINE_HTML = """
<div class="frame-line">
<p class="frame-descriptor">
File {0.filename}, line <i>{0.lineno}</i>,
@ -106,15 +106,15 @@ TRACEBACK_LINE_HTML = '''
</p>
<p class="frame-code"><code>{0.line}</code></p>
</div>
'''
"""
INTERNAL_SERVER_ERROR_HTML = '''
INTERNAL_SERVER_ERROR_HTML = """
<h1>Internal Server Error</h1>
<p>
The server encountered an internal error and cannot complete
your request.
</p>
'''
"""
_sanic_exceptions = {}
@ -124,15 +124,16 @@ def add_status_code(code):
"""
Decorator used for adding exceptions to _sanic_exceptions.
"""
def class_decorator(cls):
cls.status_code = code
_sanic_exceptions[code] = cls
return cls
return class_decorator
class SanicException(Exception):
def __init__(self, message, status_code=None):
super().__init__(message)
@ -156,8 +157,8 @@ class MethodNotSupported(SanicException):
super().__init__(message)
self.headers = dict()
self.headers["Allow"] = ", ".join(allowed_methods)
if method in ['HEAD', 'PATCH', 'PUT', 'DELETE']:
self.headers['Content-Length'] = 0
if method in ["HEAD", "PATCH", "PUT", "DELETE"]:
self.headers["Content-Length"] = 0
@add_status_code(500)
@ -169,6 +170,7 @@ class ServerError(SanicException):
class ServiceUnavailable(SanicException):
"""The server is currently unavailable (because it is overloaded or
down for maintenance). Generally, this is a temporary state."""
pass
@ -192,6 +194,7 @@ class RequestTimeout(SanicException):
the connection. The socket connection has actually been lost - the Web
server has 'timed out' on that particular socket connection.
"""
pass
@ -209,8 +212,8 @@ class ContentRangeError(SanicException):
def __init__(self, message, content_range):
super().__init__(message)
self.headers = {
'Content-Type': 'text/plain',
"Content-Range": "bytes */%s" % (content_range.total,)
"Content-Type": "text/plain",
"Content-Range": "bytes */%s" % (content_range.total,),
}
@ -225,7 +228,7 @@ class InvalidRangeType(ContentRangeError):
class PyFileError(Exception):
def __init__(self, file):
super().__init__('could not execute config file %s', file)
super().__init__("could not execute config file %s", file)
@add_status_code(401)
@ -263,13 +266,14 @@ class Unauthorized(SanicException):
scheme="Bearer",
realm="Restricted Area")
"""
def __init__(self, message, status_code=None, scheme=None, **kwargs):
super().__init__(message, status_code)
# if auth-scheme is specified, set "WWW-Authenticate" header
if scheme is not None:
values = ['{!s}="{!s}"'.format(k, v) for k, v in kwargs.items()]
challenge = ', '.join(values)
challenge = ", ".join(values)
self.headers = {
"WWW-Authenticate": "{} {}".format(scheme, challenge).rstrip()
@ -288,6 +292,6 @@ def abort(status_code, message=None):
if message is None:
message = STATUS_CODES.get(status_code)
# These are stored as bytes in the STATUS_CODES dict
message = message.decode('utf8')
message = message.decode("utf8")
sanic_exception = _sanic_exceptions.get(status_code, SanicException)
raise sanic_exception(message=message, status_code=status_code)

View File

@ -11,7 +11,8 @@ from sanic.exceptions import (
TRACEBACK_STYLE,
TRACEBACK_WRAPPER_HTML,
TRACEBACK_WRAPPER_INNER_HTML,
TRACEBACK_BORDER)
TRACEBACK_BORDER,
)
from sanic.log import logger
from sanic.response import text, html
@ -36,7 +37,8 @@ class ErrorHandler:
return TRACEBACK_WRAPPER_INNER_HTML.format(
exc_name=exception.__class__.__name__,
exc_value=exception,
frame_html=''.join(frame_html))
frame_html="".join(frame_html),
)
def _render_traceback_html(self, exception, request):
exc_type, exc_value, tb = sys.exc_info()
@ -51,7 +53,8 @@ class ErrorHandler:
exc_name=exception.__class__.__name__,
exc_value=exception,
inner_html=TRACEBACK_BORDER.join(reversed(exceptions)),
path=request.path)
path=request.path,
)
def add(self, exception, handler):
self.handlers.append((exception, handler))
@ -88,17 +91,18 @@ class ErrorHandler:
url = repr(request.url)
except AttributeError:
url = "unknown"
response_message = ('Exception raised in exception handler '
'"%s" for uri: %s')
response_message = (
"Exception raised in exception handler " '"%s" for uri: %s'
)
logger.exception(response_message, handler.__name__, url)
if self.debug:
return text(response_message % (handler.__name__, url), 500)
else:
return text('An error occurred while handling an error', 500)
return text("An error occurred while handling an error", 500)
return response
def log(self, message, level='error'):
def log(self, message, level="error"):
"""
Deprecated, do not use.
"""
@ -110,14 +114,14 @@ class ErrorHandler:
except AttributeError:
url = "unknown"
response_message = ('Exception occurred while handling uri: %s')
response_message = "Exception occurred while handling uri: %s"
logger.exception(response_message, url)
if issubclass(type(exception), SanicException):
return text(
'Error: {}'.format(exception),
status=getattr(exception, 'status_code', 500),
headers=getattr(exception, 'headers', dict())
"Error: {}".format(exception),
status=getattr(exception, "status_code", 500),
headers=getattr(exception, "headers", dict()),
)
elif self.debug:
html_output = self._render_traceback_html(exception, request)
@ -129,32 +133,37 @@ class ErrorHandler:
class ContentRangeHandler:
"""Class responsible for parsing request header"""
__slots__ = ('start', 'end', 'size', 'total', 'headers')
__slots__ = ("start", "end", "size", "total", "headers")
def __init__(self, request, stats):
self.total = stats.st_size
_range = request.headers.get('Range')
_range = request.headers.get("Range")
if _range is None:
raise HeaderNotFound('Range Header Not Found')
unit, _, value = tuple(map(str.strip, _range.partition('=')))
if unit != 'bytes':
raise HeaderNotFound("Range Header Not Found")
unit, _, value = tuple(map(str.strip, _range.partition("=")))
if unit != "bytes":
raise InvalidRangeType(
'%s is not a valid Range Type' % (unit,), self)
start_b, _, end_b = tuple(map(str.strip, value.partition('-')))
"%s is not a valid Range Type" % (unit,), self
)
start_b, _, end_b = tuple(map(str.strip, value.partition("-")))
try:
self.start = int(start_b) if start_b else None
except ValueError:
raise ContentRangeError(
'\'%s\' is invalid for Content Range' % (start_b,), self)
"'%s' is invalid for Content Range" % (start_b,), self
)
try:
self.end = int(end_b) if end_b else None
except ValueError:
raise ContentRangeError(
'\'%s\' is invalid for Content Range' % (end_b,), self)
"'%s' is invalid for Content Range" % (end_b,), self
)
if self.end is None:
if self.start is None:
raise ContentRangeError(
'Invalid for Content Range parameters', self)
"Invalid for Content Range parameters", self
)
else:
# this case represents `Content-Range: bytes 5-`
self.end = self.total
@ -165,11 +174,13 @@ class ContentRangeHandler:
self.end = self.total
if self.start >= self.end:
raise ContentRangeError(
'Invalid for Content Range parameters', self)
"Invalid for Content Range parameters", self
)
self.size = self.end - self.start
self.headers = {
'Content-Range': "bytes %s-%s/%s" % (
self.start, self.end, self.total)}
"Content-Range": "bytes %s-%s/%s"
% (self.start, self.end, self.total)
}
def __bool__(self):
return self.size > 0

View File

@ -1,93 +1,97 @@
"""Defines basics of HTTP standard."""
STATUS_CODES = {
100: b'Continue',
101: b'Switching Protocols',
102: b'Processing',
200: b'OK',
201: b'Created',
202: b'Accepted',
203: b'Non-Authoritative Information',
204: b'No Content',
205: b'Reset Content',
206: b'Partial Content',
207: b'Multi-Status',
208: b'Already Reported',
226: b'IM Used',
300: b'Multiple Choices',
301: b'Moved Permanently',
302: b'Found',
303: b'See Other',
304: b'Not Modified',
305: b'Use Proxy',
307: b'Temporary Redirect',
308: b'Permanent Redirect',
400: b'Bad Request',
401: b'Unauthorized',
402: b'Payment Required',
403: b'Forbidden',
404: b'Not Found',
405: b'Method Not Allowed',
406: b'Not Acceptable',
407: b'Proxy Authentication Required',
408: b'Request Timeout',
409: b'Conflict',
410: b'Gone',
411: b'Length Required',
412: b'Precondition Failed',
413: b'Request Entity Too Large',
414: b'Request-URI Too Long',
415: b'Unsupported Media Type',
416: b'Requested Range Not Satisfiable',
417: b'Expectation Failed',
418: b'I\'m a teapot',
422: b'Unprocessable Entity',
423: b'Locked',
424: b'Failed Dependency',
426: b'Upgrade Required',
428: b'Precondition Required',
429: b'Too Many Requests',
431: b'Request Header Fields Too Large',
451: b'Unavailable For Legal Reasons',
500: b'Internal Server Error',
501: b'Not Implemented',
502: b'Bad Gateway',
503: b'Service Unavailable',
504: b'Gateway Timeout',
505: b'HTTP Version Not Supported',
506: b'Variant Also Negotiates',
507: b'Insufficient Storage',
508: b'Loop Detected',
510: b'Not Extended',
511: b'Network Authentication Required'
100: b"Continue",
101: b"Switching Protocols",
102: b"Processing",
200: b"OK",
201: b"Created",
202: b"Accepted",
203: b"Non-Authoritative Information",
204: b"No Content",
205: b"Reset Content",
206: b"Partial Content",
207: b"Multi-Status",
208: b"Already Reported",
226: b"IM Used",
300: b"Multiple Choices",
301: b"Moved Permanently",
302: b"Found",
303: b"See Other",
304: b"Not Modified",
305: b"Use Proxy",
307: b"Temporary Redirect",
308: b"Permanent Redirect",
400: b"Bad Request",
401: b"Unauthorized",
402: b"Payment Required",
403: b"Forbidden",
404: b"Not Found",
405: b"Method Not Allowed",
406: b"Not Acceptable",
407: b"Proxy Authentication Required",
408: b"Request Timeout",
409: b"Conflict",
410: b"Gone",
411: b"Length Required",
412: b"Precondition Failed",
413: b"Request Entity Too Large",
414: b"Request-URI Too Long",
415: b"Unsupported Media Type",
416: b"Requested Range Not Satisfiable",
417: b"Expectation Failed",
418: b"I'm a teapot",
422: b"Unprocessable Entity",
423: b"Locked",
424: b"Failed Dependency",
426: b"Upgrade Required",
428: b"Precondition Required",
429: b"Too Many Requests",
431: b"Request Header Fields Too Large",
451: b"Unavailable For Legal Reasons",
500: b"Internal Server Error",
501: b"Not Implemented",
502: b"Bad Gateway",
503: b"Service Unavailable",
504: b"Gateway Timeout",
505: b"HTTP Version Not Supported",
506: b"Variant Also Negotiates",
507: b"Insufficient Storage",
508: b"Loop Detected",
510: b"Not Extended",
511: b"Network Authentication Required",
}
# According to https://tools.ietf.org/html/rfc2616#section-7.1
_ENTITY_HEADERS = frozenset([
'allow',
'content-encoding',
'content-language',
'content-length',
'content-location',
'content-md5',
'content-range',
'content-type',
'expires',
'last-modified',
'extension-header'
])
_ENTITY_HEADERS = frozenset(
[
"allow",
"content-encoding",
"content-language",
"content-length",
"content-location",
"content-md5",
"content-range",
"content-type",
"expires",
"last-modified",
"extension-header",
]
)
# According to https://tools.ietf.org/html/rfc2616#section-13.5.1
_HOP_BY_HOP_HEADERS = frozenset([
'connection',
'keep-alive',
'proxy-authenticate',
'proxy-authorization',
'te',
'trailers',
'transfer-encoding',
'upgrade'
])
_HOP_BY_HOP_HEADERS = frozenset(
[
"connection",
"keep-alive",
"proxy-authenticate",
"proxy-authorization",
"te",
"trailers",
"transfer-encoding",
"upgrade",
]
)
def has_message_body(status):
@ -110,8 +114,7 @@ def is_hop_by_hop_header(header):
return header.lower() in _HOP_BY_HOP_HEADERS
def remove_entity_headers(headers,
allowed=('content-location', 'expires')):
def remove_entity_headers(headers, allowed=("content-location", "expires")):
"""
Removes all the entity headers present in the headers given.
According to RFC 2616 Section 10.3.5,
@ -122,7 +125,9 @@ def remove_entity_headers(headers,
returns the headers without the entity headers
"""
allowed = set([h.lower() for h in allowed])
headers = {header: value for header, value in headers.items()
if not is_entity_header(header)
and header.lower() not in allowed}
headers = {
header: value
for header, value in headers.items()
if not is_entity_header(header) and header.lower() not in allowed
}
return headers

View File

@ -5,59 +5,54 @@ import sys
LOGGING_CONFIG_DEFAULTS = dict(
version=1,
disable_existing_loggers=False,
loggers={
"root": {
"level": "INFO",
"handlers": ["console"]
},
"root": {"level": "INFO", "handlers": ["console"]},
"sanic.error": {
"level": "INFO",
"handlers": ["error_console"],
"propagate": True,
"qualname": "sanic.error"
"qualname": "sanic.error",
},
"sanic.access": {
"level": "INFO",
"handlers": ["access_console"],
"propagate": True,
"qualname": "sanic.access"
}
"qualname": "sanic.access",
},
},
handlers={
"console": {
"class": "logging.StreamHandler",
"formatter": "generic",
"stream": sys.stdout
"stream": sys.stdout,
},
"error_console": {
"class": "logging.StreamHandler",
"formatter": "generic",
"stream": sys.stderr
"stream": sys.stderr,
},
"access_console": {
"class": "logging.StreamHandler",
"formatter": "access",
"stream": sys.stdout
"stream": sys.stdout,
},
},
formatters={
"generic": {
"format": "%(asctime)s [%(process)d] [%(levelname)s] %(message)s",
"datefmt": "[%Y-%m-%d %H:%M:%S %z]",
"class": "logging.Formatter"
"class": "logging.Formatter",
},
"access": {
"format": "%(asctime)s - (%(name)s)[%(levelname)s][%(host)s]: " +
"%(request)s %(message)s %(status)d %(byte)d",
"format": "%(asctime)s - (%(name)s)[%(levelname)s][%(host)s]: "
+ "%(request)s %(message)s %(status)d %(byte)d",
"datefmt": "[%Y-%m-%d %H:%M:%S %z]",
"class": "logging.Formatter"
"class": "logging.Formatter",
},
}
},
)
logger = logging.getLogger('sanic.root')
error_logger = logging.getLogger('sanic.error')
access_logger = logging.getLogger('sanic.access')
logger = logging.getLogger("sanic.root")
error_logger = logging.getLogger("sanic.error")
access_logger = logging.getLogger("sanic.access")

View File

@ -18,7 +18,7 @@ def _iter_module_files():
for module in list(sys.modules.values()):
if module is None:
continue
filename = getattr(module, '__file__', None)
filename = getattr(module, "__file__", None)
if filename:
old = None
while not os.path.isfile(filename):
@ -27,7 +27,7 @@ def _iter_module_files():
if filename == old:
break
else:
if filename[-4:] in ('.pyc', '.pyo'):
if filename[-4:] in (".pyc", ".pyo"):
filename = filename[:-1]
yield filename
@ -45,11 +45,13 @@ def restart_with_reloader():
"""
args = _get_args_for_reloading()
new_environ = os.environ.copy()
new_environ['SANIC_SERVER_RUNNING'] = 'true'
cmd = ' '.join(args)
new_environ["SANIC_SERVER_RUNNING"] = "true"
cmd = " ".join(args)
worker_process = Process(
target=subprocess.call, args=(cmd,),
kwargs=dict(shell=True, env=new_environ))
target=subprocess.call,
args=(cmd,),
kwargs=dict(shell=True, env=new_environ),
)
worker_process.start()
return worker_process
@ -67,8 +69,10 @@ def kill_process_children_unix(pid):
children_list_pid = children_list_file.read().split()
for child_pid in children_list_pid:
children_proc_path = "/proc/%s/task/%s/children" % \
(child_pid, child_pid)
children_proc_path = "/proc/%s/task/%s/children" % (
child_pid,
child_pid,
)
if not os.path.isfile(children_proc_path):
continue
with open(children_proc_path) as children_list_file_2:
@ -90,7 +94,7 @@ def kill_process_children_osx(pid):
:param pid: PID of parent process (process ID)
:return: Nothing
"""
subprocess.run(['pkill', '-P', str(pid)])
subprocess.run(["pkill", "-P", str(pid)])
def kill_process_children(pid):
@ -99,12 +103,12 @@ def kill_process_children(pid):
:param pid: PID of parent process (process ID)
:return: Nothing
"""
if sys.platform == 'darwin':
if sys.platform == "darwin":
kill_process_children_osx(pid)
elif sys.platform == 'linux':
elif sys.platform == "linux":
kill_process_children_unix(pid)
else:
pass # should signal error here
pass # should signal error here
def kill_program_completly(proc):
@ -127,9 +131,11 @@ def watchdog(sleep_interval):
mtimes = {}
worker_process = restart_with_reloader()
signal.signal(
signal.SIGTERM, lambda *args: kill_program_completly(worker_process))
signal.SIGTERM, lambda *args: kill_program_completly(worker_process)
)
signal.signal(
signal.SIGINT, lambda *args: kill_program_completly(worker_process))
signal.SIGINT, lambda *args: kill_program_completly(worker_process)
)
while True:
for filename in _iter_module_files():
try:

View File

@ -10,9 +10,11 @@ try:
from ujson import loads as json_loads
except ImportError:
if sys.version_info[:2] == (3, 5):
def json_loads(data):
# on Python 3.5 json.loads only supports str not bytes
return json.loads(data.decode())
else:
json_loads = json.loads
@ -43,11 +45,28 @@ class RequestParameters(dict):
class Request(dict):
"""Properties of an HTTP request such as URL, headers, etc."""
__slots__ = (
'app', 'headers', 'version', 'method', '_cookies', 'transport',
'body', 'parsed_json', 'parsed_args', 'parsed_form', 'parsed_files',
'_ip', '_parsed_url', 'uri_template', 'stream', '_remote_addr',
'_socket', '_port', '__weakref__', 'raw_url'
"app",
"headers",
"version",
"method",
"_cookies",
"transport",
"body",
"parsed_json",
"parsed_args",
"parsed_form",
"parsed_files",
"_ip",
"_parsed_url",
"uri_template",
"stream",
"_remote_addr",
"_socket",
"_port",
"__weakref__",
"raw_url",
)
def __init__(self, url_bytes, headers, version, method, transport):
@ -73,10 +92,10 @@ class Request(dict):
def __repr__(self):
if self.method is None or not self.path:
return '<{0}>'.format(self.__class__.__name__)
return '<{0}: {1} {2}>'.format(self.__class__.__name__,
self.method,
self.path)
return "<{0}>".format(self.__class__.__name__)
return "<{0}: {1} {2}>".format(
self.__class__.__name__, self.method, self.path
)
def __bool__(self):
if self.transport:
@ -106,8 +125,8 @@ class Request(dict):
:return: token related to request
"""
prefixes = ('Bearer', 'Token')
auth_header = self.headers.get('Authorization')
prefixes = ("Bearer", "Token")
auth_header = self.headers.get("Authorization")
if auth_header is not None:
for prefix in prefixes:
@ -122,17 +141,20 @@ class Request(dict):
self.parsed_form = RequestParameters()
self.parsed_files = RequestParameters()
content_type = self.headers.get(
'Content-Type', DEFAULT_HTTP_CONTENT_TYPE)
"Content-Type", DEFAULT_HTTP_CONTENT_TYPE
)
content_type, parameters = parse_header(content_type)
try:
if content_type == 'application/x-www-form-urlencoded':
if content_type == "application/x-www-form-urlencoded":
self.parsed_form = RequestParameters(
parse_qs(self.body.decode('utf-8')))
elif content_type == 'multipart/form-data':
parse_qs(self.body.decode("utf-8"))
)
elif content_type == "multipart/form-data":
# TODO: Stream this instead of reading to/from memory
boundary = parameters['boundary'].encode('utf-8')
self.parsed_form, self.parsed_files = (
parse_multipart_form(self.body, boundary))
boundary = parameters["boundary"].encode("utf-8")
self.parsed_form, self.parsed_files = parse_multipart_form(
self.body, boundary
)
except Exception:
error_logger.exception("Failed when parsing form")
@ -150,7 +172,8 @@ class Request(dict):
if self.parsed_args is None:
if self.query_string:
self.parsed_args = RequestParameters(
parse_qs(self.query_string))
parse_qs(self.query_string)
)
else:
self.parsed_args = RequestParameters()
return self.parsed_args
@ -162,37 +185,40 @@ class Request(dict):
@property
def cookies(self):
if self._cookies is None:
cookie = self.headers.get('Cookie')
cookie = self.headers.get("Cookie")
if cookie is not None:
cookies = SimpleCookie()
cookies.load(cookie)
self._cookies = {name: cookie.value
for name, cookie in cookies.items()}
self._cookies = {
name: cookie.value for name, cookie in cookies.items()
}
else:
self._cookies = {}
return self._cookies
@property
def ip(self):
if not hasattr(self, '_socket'):
if not hasattr(self, "_socket"):
self._get_address()
return self._ip
@property
def port(self):
if not hasattr(self, '_socket'):
if not hasattr(self, "_socket"):
self._get_address()
return self._port
@property
def socket(self):
if not hasattr(self, '_socket'):
if not hasattr(self, "_socket"):
self._get_address()
return self._socket
def _get_address(self):
self._socket = self.transport.get_extra_info('peername') or \
(None, None)
self._socket = self.transport.get_extra_info("peername") or (
None,
None,
)
self._ip = self._socket[0]
self._port = self._socket[1]
@ -202,29 +228,31 @@ class Request(dict):
:return: original client ip.
"""
if not hasattr(self, '_remote_addr'):
forwarded_for = self.headers.get('X-Forwarded-For', '').split(',')
if not hasattr(self, "_remote_addr"):
forwarded_for = self.headers.get("X-Forwarded-For", "").split(",")
remote_addrs = [
addr for addr in [
addr.strip() for addr in forwarded_for
] if addr
]
addr
for addr in [addr.strip() for addr in forwarded_for]
if addr
]
if len(remote_addrs) > 0:
self._remote_addr = remote_addrs[0]
else:
self._remote_addr = ''
self._remote_addr = ""
return self._remote_addr
@property
def scheme(self):
if self.app.websocket_enabled \
and self.headers.get('upgrade') == 'websocket':
scheme = 'ws'
if (
self.app.websocket_enabled
and self.headers.get("upgrade") == "websocket"
):
scheme = "ws"
else:
scheme = 'http'
scheme = "http"
if self.transport.get_extra_info('sslcontext'):
scheme += 's'
if self.transport.get_extra_info("sslcontext"):
scheme += "s"
return scheme
@ -232,11 +260,11 @@ class Request(dict):
def host(self):
# it appears that httptools doesn't return the host
# so pull it from the headers
return self.headers.get('Host', '')
return self.headers.get("Host", "")
@property
def content_type(self):
return self.headers.get('Content-Type', DEFAULT_HTTP_CONTENT_TYPE)
return self.headers.get("Content-Type", DEFAULT_HTTP_CONTENT_TYPE)
@property
def match_info(self):
@ -245,27 +273,23 @@ class Request(dict):
@property
def path(self):
return self._parsed_url.path.decode('utf-8')
return self._parsed_url.path.decode("utf-8")
@property
def query_string(self):
if self._parsed_url.query:
return self._parsed_url.query.decode('utf-8')
return self._parsed_url.query.decode("utf-8")
else:
return ''
return ""
@property
def url(self):
return urlunparse((
self.scheme,
self.host,
self.path,
None,
self.query_string,
None))
return urlunparse(
(self.scheme, self.host, self.path, None, self.query_string, None)
)
File = namedtuple('File', ['type', 'body', 'name'])
File = namedtuple("File", ["type", "body", "name"])
def parse_multipart_form(body, boundary):
@ -281,37 +305,38 @@ def parse_multipart_form(body, boundary):
form_parts = body.split(boundary)
for form_part in form_parts[1:-1]:
file_name = None
content_type = 'text/plain'
content_charset = 'utf-8'
content_type = "text/plain"
content_charset = "utf-8"
field_name = None
line_index = 2
line_end_index = 0
while not line_end_index == -1:
line_end_index = form_part.find(b'\r\n', line_index)
form_line = form_part[line_index:line_end_index].decode('utf-8')
line_end_index = form_part.find(b"\r\n", line_index)
form_line = form_part[line_index:line_end_index].decode("utf-8")
line_index = line_end_index + 2
if not form_line:
break
colon_index = form_line.index(':')
colon_index = form_line.index(":")
form_header_field = form_line[0:colon_index].lower()
form_header_value, form_parameters = parse_header(
form_line[colon_index + 2:])
form_line[colon_index + 2 :]
)
if form_header_field == 'content-disposition':
file_name = form_parameters.get('filename')
field_name = form_parameters.get('name')
elif form_header_field == 'content-type':
if form_header_field == "content-disposition":
file_name = form_parameters.get("filename")
field_name = form_parameters.get("name")
elif form_header_field == "content-type":
content_type = form_header_value
content_charset = form_parameters.get('charset', 'utf-8')
content_charset = form_parameters.get("charset", "utf-8")
if field_name:
post_data = form_part[line_index:-4]
if file_name:
form_file = File(type=content_type,
name=file_name,
body=post_data)
form_file = File(
type=content_type, name=file_name, body=post_data
)
if field_name in files:
files[field_name].append(form_file)
else:
@ -323,7 +348,9 @@ def parse_multipart_form(body, boundary):
else:
fields[field_name] = [value]
else:
logger.debug('Form-data field does not have a \'name\' parameter \
in the Content-Disposition header')
logger.debug(
"Form-data field does not have a 'name' parameter \
in the Content-Disposition header"
)
return fields, files

View File

@ -24,16 +24,18 @@ class BaseHTTPResponse:
return str(data).encode()
def _parse_headers(self):
headers = b''
headers = b""
for name, value in self.headers.items():
try:
headers += (
b'%b: %b\r\n' % (
name.encode(), value.encode('utf-8')))
headers += b"%b: %b\r\n" % (
name.encode(),
value.encode("utf-8"),
)
except AttributeError:
headers += (
b'%b: %b\r\n' % (
str(name).encode(), str(value).encode('utf-8')))
headers += b"%b: %b\r\n" % (
str(name).encode(),
str(value).encode("utf-8"),
)
return headers
@ -46,12 +48,17 @@ class BaseHTTPResponse:
class StreamingHTTPResponse(BaseHTTPResponse):
__slots__ = (
'protocol', 'streaming_fn', 'status',
'content_type', 'headers', '_cookies'
"protocol",
"streaming_fn",
"status",
"content_type",
"headers",
"_cookies",
)
def __init__(self, streaming_fn, status=200, headers=None,
content_type='text/plain'):
def __init__(
self, streaming_fn, status=200, headers=None, content_type="text/plain"
):
self.content_type = content_type
self.streaming_fn = streaming_fn
self.status = status
@ -66,61 +73,69 @@ class StreamingHTTPResponse(BaseHTTPResponse):
if type(data) != bytes:
data = self._encode_body(data)
self.protocol.push_data(
b"%x\r\n%b\r\n" % (len(data), data))
self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data))
await self.protocol.drain()
async def stream(
self, version="1.1", keep_alive=False, keep_alive_timeout=None):
self, version="1.1", keep_alive=False, keep_alive_timeout=None
):
"""Streams headers, runs the `streaming_fn` callback that writes
content to the response body, then finalizes the response body.
"""
headers = self.get_headers(
version, keep_alive=keep_alive,
keep_alive_timeout=keep_alive_timeout)
version,
keep_alive=keep_alive,
keep_alive_timeout=keep_alive_timeout,
)
self.protocol.push_data(headers)
await self.protocol.drain()
await self.streaming_fn(self)
self.protocol.push_data(b'0\r\n\r\n')
self.protocol.push_data(b"0\r\n\r\n")
# no need to await drain here after this write, because it is the
# very last thing we write and nothing needs to wait for it.
def get_headers(
self, version="1.1", keep_alive=False, keep_alive_timeout=None):
self, version="1.1", keep_alive=False, keep_alive_timeout=None
):
# This is all returned in a kind-of funky way
# We tried to make this as fast as possible in pure python
timeout_header = b''
timeout_header = b""
if keep_alive and keep_alive_timeout is not None:
timeout_header = b'Keep-Alive: %d\r\n' % keep_alive_timeout
timeout_header = b"Keep-Alive: %d\r\n" % keep_alive_timeout
self.headers['Transfer-Encoding'] = 'chunked'
self.headers.pop('Content-Length', None)
self.headers['Content-Type'] = self.headers.get(
'Content-Type', self.content_type)
self.headers["Transfer-Encoding"] = "chunked"
self.headers.pop("Content-Length", None)
self.headers["Content-Type"] = self.headers.get(
"Content-Type", self.content_type
)
headers = self._parse_headers()
if self.status is 200:
status = b'OK'
status = b"OK"
else:
status = STATUS_CODES.get(self.status)
return (b'HTTP/%b %d %b\r\n'
b'%b'
b'%b\r\n') % (
version.encode(),
self.status,
status,
timeout_header,
headers
)
return (b"HTTP/%b %d %b\r\n" b"%b" b"%b\r\n") % (
version.encode(),
self.status,
status,
timeout_header,
headers,
)
class HTTPResponse(BaseHTTPResponse):
__slots__ = ('body', 'status', 'content_type', 'headers', '_cookies')
__slots__ = ("body", "status", "content_type", "headers", "_cookies")
def __init__(self, body=None, status=200, headers=None,
content_type='text/plain', body_bytes=b''):
def __init__(
self,
body=None,
status=200,
headers=None,
content_type="text/plain",
body_bytes=b"",
):
self.content_type = content_type
if body is not None:
@ -132,22 +147,23 @@ class HTTPResponse(BaseHTTPResponse):
self.headers = CIMultiDict(headers or {})
self._cookies = None
def output(
self, version="1.1", keep_alive=False, keep_alive_timeout=None):
def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None):
# This is all returned in a kind-of funky way
# We tried to make this as fast as possible in pure python
timeout_header = b''
timeout_header = b""
if keep_alive and keep_alive_timeout is not None:
timeout_header = b'Keep-Alive: %d\r\n' % keep_alive_timeout
timeout_header = b"Keep-Alive: %d\r\n" % keep_alive_timeout
body = b''
body = b""
if has_message_body(self.status):
body = self.body
self.headers['Content-Length'] = self.headers.get(
'Content-Length', len(self.body))
self.headers["Content-Length"] = self.headers.get(
"Content-Length", len(self.body)
)
self.headers['Content-Type'] = self.headers.get(
'Content-Type', self.content_type)
self.headers["Content-Type"] = self.headers.get(
"Content-Type", self.content_type
)
if self.status in (304, 412):
self.headers = remove_entity_headers(self.headers)
@ -155,23 +171,21 @@ class HTTPResponse(BaseHTTPResponse):
headers = self._parse_headers()
if self.status is 200:
status = b'OK'
status = b"OK"
else:
status = STATUS_CODES.get(self.status, b'UNKNOWN RESPONSE')
status = STATUS_CODES.get(self.status, b"UNKNOWN RESPONSE")
return (b'HTTP/%b %d %b\r\n'
b'Connection: %b\r\n'
b'%b'
b'%b\r\n'
b'%b') % (
version.encode(),
self.status,
status,
b'keep-alive' if keep_alive else b'close',
timeout_header,
headers,
body
)
return (
b"HTTP/%b %d %b\r\n" b"Connection: %b\r\n" b"%b" b"%b\r\n" b"%b"
) % (
version.encode(),
self.status,
status,
b"keep-alive" if keep_alive else b"close",
timeout_header,
headers,
body,
)
@property
def cookies(self):
@ -180,9 +194,14 @@ class HTTPResponse(BaseHTTPResponse):
return self._cookies
def json(body, status=200, headers=None,
content_type="application/json", dumps=json_dumps,
**kwargs):
def json(
body,
status=200,
headers=None,
content_type="application/json",
dumps=json_dumps,
**kwargs
):
"""
Returns response object with body in json format.
@ -191,12 +210,17 @@ def json(body, status=200, headers=None,
:param headers: Custom Headers.
:param kwargs: Remaining arguments that are passed to the json encoder.
"""
return HTTPResponse(dumps(body, **kwargs), headers=headers,
status=status, content_type=content_type)
return HTTPResponse(
dumps(body, **kwargs),
headers=headers,
status=status,
content_type=content_type,
)
def text(body, status=200, headers=None,
content_type="text/plain; charset=utf-8"):
def text(
body, status=200, headers=None, content_type="text/plain; charset=utf-8"
):
"""
Returns response object with body in text format.
@ -206,12 +230,13 @@ def text(body, status=200, headers=None,
:param content_type: the content type (string) of the response
"""
return HTTPResponse(
body, status=status, headers=headers,
content_type=content_type)
body, status=status, headers=headers, content_type=content_type
)
def raw(body, status=200, headers=None,
content_type="application/octet-stream"):
def raw(
body, status=200, headers=None, content_type="application/octet-stream"
):
"""
Returns response object without encoding the body.
@ -220,8 +245,12 @@ def raw(body, status=200, headers=None,
:param headers: Custom Headers.
:param content_type: the content type (string) of the response.
"""
return HTTPResponse(body_bytes=body, status=status, headers=headers,
content_type=content_type)
return HTTPResponse(
body_bytes=body,
status=status,
headers=headers,
content_type=content_type,
)
def html(body, status=200, headers=None):
@ -232,12 +261,22 @@ def html(body, status=200, headers=None):
:param status: Response code.
:param headers: Custom Headers.
"""
return HTTPResponse(body, status=status, headers=headers,
content_type="text/html; charset=utf-8")
return HTTPResponse(
body,
status=status,
headers=headers,
content_type="text/html; charset=utf-8",
)
async def file(location, status=200, mime_type=None, headers=None,
filename=None, _range=None):
async def file(
location,
status=200,
mime_type=None,
headers=None,
filename=None,
_range=None,
):
"""Return a response object with file data.
:param location: Location of file on system.
@ -249,28 +288,40 @@ async def file(location, status=200, mime_type=None, headers=None,
headers = headers or {}
if filename:
headers.setdefault(
'Content-Disposition',
'attachment; filename="{}"'.format(filename))
"Content-Disposition", 'attachment; filename="{}"'.format(filename)
)
filename = filename or path.split(location)[-1]
async with open_async(location, mode='rb') as _file:
async with open_async(location, mode="rb") as _file:
if _range:
await _file.seek(_range.start)
out_stream = await _file.read(_range.size)
headers['Content-Range'] = 'bytes %s-%s/%s' % (
_range.start, _range.end, _range.total)
headers["Content-Range"] = "bytes %s-%s/%s" % (
_range.start,
_range.end,
_range.total,
)
else:
out_stream = await _file.read()
mime_type = mime_type or guess_type(filename)[0] or 'text/plain'
return HTTPResponse(status=status,
headers=headers,
content_type=mime_type,
body_bytes=out_stream)
mime_type = mime_type or guess_type(filename)[0] or "text/plain"
return HTTPResponse(
status=status,
headers=headers,
content_type=mime_type,
body_bytes=out_stream,
)
async def file_stream(location, status=200, chunk_size=4096, mime_type=None,
headers=None, filename=None, _range=None):
async def file_stream(
location,
status=200,
chunk_size=4096,
mime_type=None,
headers=None,
filename=None,
_range=None,
):
"""Return a streaming response object with file data.
:param location: Location of file on system.
@ -283,11 +334,11 @@ async def file_stream(location, status=200, chunk_size=4096, mime_type=None,
headers = headers or {}
if filename:
headers.setdefault(
'Content-Disposition',
'attachment; filename="{}"'.format(filename))
"Content-Disposition", 'attachment; filename="{}"'.format(filename)
)
filename = filename or path.split(location)[-1]
_file = await open_async(location, mode='rb')
_file = await open_async(location, mode="rb")
async def _streaming_fn(response):
nonlocal _file, chunk_size
@ -312,19 +363,27 @@ async def file_stream(location, status=200, chunk_size=4096, mime_type=None,
await _file.close()
return # Returning from this fn closes the stream
mime_type = mime_type or guess_type(filename)[0] or 'text/plain'
mime_type = mime_type or guess_type(filename)[0] or "text/plain"
if _range:
headers['Content-Range'] = 'bytes %s-%s/%s' % (
_range.start, _range.end, _range.total)
return StreamingHTTPResponse(streaming_fn=_streaming_fn,
status=status,
headers=headers,
content_type=mime_type)
headers["Content-Range"] = "bytes %s-%s/%s" % (
_range.start,
_range.end,
_range.total,
)
return StreamingHTTPResponse(
streaming_fn=_streaming_fn,
status=status,
headers=headers,
content_type=mime_type,
)
def stream(
streaming_fn, status=200, headers=None,
content_type="text/plain; charset=utf-8"):
streaming_fn,
status=200,
headers=None,
content_type="text/plain; charset=utf-8",
):
"""Accepts an coroutine `streaming_fn` which can be used to
write chunks to a streaming response. Returns a `StreamingHTTPResponse`.
@ -344,15 +403,13 @@ def stream(
:param headers: Custom Headers.
"""
return StreamingHTTPResponse(
streaming_fn,
headers=headers,
content_type=content_type,
status=status
streaming_fn, headers=headers, content_type=content_type, status=status
)
def redirect(to, headers=None, status=302,
content_type="text/html; charset=utf-8"):
def redirect(
to, headers=None, status=302, content_type="text/html; charset=utf-8"
):
"""Abort execution and cause a 302 redirect (by default).
:param to: path or fully qualified URL to redirect to
@ -367,9 +424,8 @@ def redirect(to, headers=None, status=302,
safe_to = quote_plus(to, safe=":/#?&=@[]!$&'()*+,;")
# According to RFC 7231, a relative URI is now permitted.
headers['Location'] = safe_to
headers["Location"] = safe_to
return HTTPResponse(
status=status,
headers=headers,
content_type=content_type)
status=status, headers=headers, content_type=content_type
)

View File

@ -9,25 +9,28 @@ from sanic.exceptions import NotFound, MethodNotSupported
from sanic.views import CompositionView
Route = namedtuple(
'Route',
['handler', 'methods', 'pattern', 'parameters', 'name', 'uri'])
Parameter = namedtuple('Parameter', ['name', 'cast'])
"Route", ["handler", "methods", "pattern", "parameters", "name", "uri"]
)
Parameter = namedtuple("Parameter", ["name", "cast"])
REGEX_TYPES = {
'string': (str, r'[^/]+'),
'int': (int, r'\d+'),
'number': (float, r'[0-9\\.]+'),
'alpha': (str, r'[A-Za-z]+'),
'path': (str, r'[^/].*?'),
'uuid': (uuid.UUID, r'[A-Fa-f0-9]{8}-[A-Fa-f0-9]{4}-'
r'[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{12}')
"string": (str, r"[^/]+"),
"int": (int, r"\d+"),
"number": (float, r"[0-9\\.]+"),
"alpha": (str, r"[A-Za-z]+"),
"path": (str, r"[^/].*?"),
"uuid": (
uuid.UUID,
r"[A-Fa-f0-9]{8}-[A-Fa-f0-9]{4}-"
r"[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{12}",
),
}
ROUTER_CACHE_SIZE = 1024
def url_hash(url):
return url.count('/')
return url.count("/")
class RouteExists(Exception):
@ -68,10 +71,11 @@ class Router:
also be passed in as the type. The argument given to the function will
always be a string, independent of the type.
"""
routes_static = None
routes_dynamic = None
routes_always_check = None
parameter_pattern = re.compile(r'<(.+?)>')
parameter_pattern = re.compile(r"<(.+?)>")
def __init__(self):
self.routes_all = {}
@ -98,9 +102,9 @@ class Router:
"""
# We could receive NAME or NAME:PATTERN
name = parameter_string
pattern = 'string'
if ':' in parameter_string:
name, pattern = parameter_string.split(':', 1)
pattern = "string"
if ":" in parameter_string:
name, pattern = parameter_string.split(":", 1)
if not name:
raise ValueError(
"Invalid parameter syntax: {}".format(parameter_string)
@ -112,8 +116,16 @@ class Router:
return name, _type, pattern
def add(self, uri, methods, handler, host=None, strict_slashes=False,
version=None, name=None):
def add(
self,
uri,
methods,
handler,
host=None,
strict_slashes=False,
version=None,
name=None,
):
"""Add a handler to the route list
:param uri: path to match
@ -127,8 +139,8 @@ class Router:
:return: Nothing
"""
if version is not None:
version = re.escape(str(version).strip('/').lstrip('v'))
uri = "/".join(["/v{}".format(version), uri.lstrip('/')])
version = re.escape(str(version).strip("/").lstrip("v"))
uri = "/".join(["/v{}".format(version), uri.lstrip("/")])
# add regular version
self._add(uri, methods, handler, host, name)
@ -143,28 +155,26 @@ class Router:
return
# Add versions with and without trailing /
slashed_methods = self.routes_all.get(uri + '/', frozenset({}))
slashed_methods = self.routes_all.get(uri + "/", frozenset({}))
unslashed_methods = self.routes_all.get(uri[:-1], frozenset({}))
if isinstance(methods, Iterable):
_slash_is_missing = all(method in slashed_methods for
method in methods)
_without_slash_is_missing = all(method in unslashed_methods for
method in methods)
_slash_is_missing = all(
method in slashed_methods for method in methods
)
_without_slash_is_missing = all(
method in unslashed_methods for method in methods
)
else:
_slash_is_missing = methods in slashed_methods
_without_slash_is_missing = methods in unslashed_methods
slash_is_missing = (
not uri[-1] == '/' and not _slash_is_missing
)
slash_is_missing = not uri[-1] == "/" and not _slash_is_missing
without_slash_is_missing = (
uri[-1] == '/' and not
_without_slash_is_missing and not
uri == '/'
uri[-1] == "/" and not _without_slash_is_missing and not uri == "/"
)
# add version with trailing slash
if slash_is_missing:
self._add(uri + '/', methods, handler, host, name)
self._add(uri + "/", methods, handler, host, name)
# add version without trailing slash
elif without_slash_is_missing:
self._add(uri[:-1], methods, handler, host, name)
@ -187,8 +197,10 @@ class Router:
else:
if not isinstance(host, Iterable):
raise ValueError("Expected either string or Iterable of "
"host strings, not {!r}".format(host))
raise ValueError(
"Expected either string or Iterable of "
"host strings, not {!r}".format(host)
)
for host_ in host:
self.add(uri, methods, handler, host_, name)
@ -208,38 +220,39 @@ class Router:
if name in parameter_names:
raise ParameterNameConflicts(
"Multiple parameter named <{name}> "
"in route uri {uri}".format(name=name, uri=uri))
"Multiple parameter named <{name}> "
"in route uri {uri}".format(name=name, uri=uri)
)
parameter_names.add(name)
parameter = Parameter(
name=name, cast=_type)
parameter = Parameter(name=name, cast=_type)
parameters.append(parameter)
# Mark the whole route as unhashable if it has the hash key in it
if re.search(r'(^|[^^]){1}/', pattern):
properties['unhashable'] = True
if re.search(r"(^|[^^]){1}/", pattern):
properties["unhashable"] = True
# Mark the route as unhashable if it matches the hash key
elif re.search(r'/', pattern):
properties['unhashable'] = True
elif re.search(r"/", pattern):
properties["unhashable"] = True
return '({})'.format(pattern)
return "({})".format(pattern)
pattern_string = re.sub(self.parameter_pattern, add_parameter, uri)
pattern = re.compile(r'^{}$'.format(pattern_string))
pattern = re.compile(r"^{}$".format(pattern_string))
def merge_route(route, methods, handler):
# merge to the existing route when possible.
if not route.methods or not methods:
# method-unspecified routes are not mergeable.
raise RouteExists(
"Route already registered: {}".format(uri))
raise RouteExists("Route already registered: {}".format(uri))
elif route.methods.intersection(methods):
# already existing method is not overloadable.
duplicated = methods.intersection(route.methods)
raise RouteExists(
"Route already registered: {} [{}]".format(
uri, ','.join(list(duplicated))))
uri, ",".join(list(duplicated))
)
)
if isinstance(route.handler, CompositionView):
view = route.handler
else:
@ -247,19 +260,22 @@ class Router:
view.add(route.methods, route.handler)
view.add(methods, handler)
route = route._replace(
handler=view, methods=methods.union(route.methods))
handler=view, methods=methods.union(route.methods)
)
return route
if parameters:
# TODO: This is too complex, we need to reduce the complexity
if properties['unhashable']:
if properties["unhashable"]:
routes_to_check = self.routes_always_check
ndx, route = self.check_dynamic_route_exists(
pattern, routes_to_check, parameters)
pattern, routes_to_check, parameters
)
else:
routes_to_check = self.routes_dynamic[url_hash(uri)]
ndx, route = self.check_dynamic_route_exists(
pattern, routes_to_check, parameters)
pattern, routes_to_check, parameters
)
if ndx != -1:
# Pop the ndx of the route, no dups of the same route
routes_to_check.pop(ndx)
@ -270,35 +286,41 @@ class Router:
# if available
# special prefix for static files
is_static = False
if name and name.startswith('_static_'):
if name and name.startswith("_static_"):
is_static = True
name = name.split('_static_', 1)[-1]
name = name.split("_static_", 1)[-1]
if hasattr(handler, '__blueprintname__'):
handler_name = '{}.{}'.format(
handler.__blueprintname__, name or handler.__name__)
if hasattr(handler, "__blueprintname__"):
handler_name = "{}.{}".format(
handler.__blueprintname__, name or handler.__name__
)
else:
handler_name = name or getattr(handler, '__name__', None)
handler_name = name or getattr(handler, "__name__", None)
if route:
route = merge_route(route, methods, handler)
else:
route = Route(
handler=handler, methods=methods, pattern=pattern,
parameters=parameters, name=handler_name, uri=uri)
handler=handler,
methods=methods,
pattern=pattern,
parameters=parameters,
name=handler_name,
uri=uri,
)
self.routes_all[uri] = route
if is_static:
pair = self.routes_static_files.get(handler_name)
if not (pair and (pair[0] + '/' == uri or uri + '/' == pair[0])):
if not (pair and (pair[0] + "/" == uri or uri + "/" == pair[0])):
self.routes_static_files[handler_name] = (uri, route)
else:
pair = self.routes_names.get(handler_name)
if not (pair and (pair[0] + '/' == uri or uri + '/' == pair[0])):
if not (pair and (pair[0] + "/" == uri or uri + "/" == pair[0])):
self.routes_names[handler_name] = (uri, route)
if properties['unhashable']:
if properties["unhashable"]:
self.routes_always_check.append(route)
elif parameters:
self.routes_dynamic[url_hash(uri)].append(route)
@ -333,8 +355,10 @@ class Router:
if route in self.routes_always_check:
self.routes_always_check.remove(route)
elif url_hash(uri) in self.routes_dynamic \
and route in self.routes_dynamic[url_hash(uri)]:
elif (
url_hash(uri) in self.routes_dynamic
and route in self.routes_dynamic[url_hash(uri)]
):
self.routes_dynamic[url_hash(uri)].remove(route)
else:
self.routes_static.pop(uri)
@ -353,7 +377,7 @@ class Router:
if not view_name:
return (None, None)
if view_name == 'static' or view_name.endswith('.static'):
if view_name == "static" or view_name.endswith(".static"):
return self.routes_static_files.get(name, (None, None))
return self.routes_names.get(view_name, (None, None))
@ -367,14 +391,15 @@ class Router:
"""
# No virtual hosts specified; default behavior
if not self.hosts:
return self._get(request.path, request.method, '')
return self._get(request.path, request.method, "")
# virtual hosts specified; try to match route to the host header
try:
return self._get(request.path, request.method,
request.headers.get("Host", ''))
return self._get(
request.path, request.method, request.headers.get("Host", "")
)
# try default hosts
except NotFound:
return self._get(request.path, request.method, '')
return self._get(request.path, request.method, "")
def get_supported_methods(self, url):
"""Get a list of supported methods for a url and optional host.
@ -384,7 +409,7 @@ class Router:
"""
route = self.routes_all.get(url)
# if methods are None then this logic will prevent an error
return getattr(route, 'methods', None) or frozenset()
return getattr(route, "methods", None) or frozenset()
@lru_cache(maxsize=ROUTER_CACHE_SIZE)
def _get(self, url, method, host):
@ -399,9 +424,10 @@ class Router:
# Check against known static routes
route = self.routes_static.get(url)
method_not_supported = MethodNotSupported(
'Method {} not allowed for URL {}'.format(method, url),
"Method {} not allowed for URL {}".format(method, url),
method=method,
allowed_methods=self.get_supported_methods(url))
allowed_methods=self.get_supported_methods(url),
)
if route:
if route.methods and method not in route.methods:
raise method_not_supported
@ -427,13 +453,14 @@ class Router:
# Route was found but the methods didn't match
if route_found:
raise method_not_supported
raise NotFound('Requested URL {} not found'.format(url))
raise NotFound("Requested URL {} not found".format(url))
kwargs = {p.name: p.cast(value)
for value, p
in zip(match.groups(1), route.parameters)}
kwargs = {
p.name: p.cast(value)
for value, p in zip(match.groups(1), route.parameters)
}
route_handler = route.handler
if hasattr(route_handler, 'handlers'):
if hasattr(route_handler, "handlers"):
route_handler = route_handler.handlers[method]
return route_handler, [], kwargs, route.uri
@ -446,7 +473,8 @@ class Router:
handler = self.get(request)[0]
except (NotFound, MethodNotSupported):
return False
if (hasattr(handler, 'view_class') and
hasattr(handler.view_class, request.method.lower())):
if hasattr(handler, "view_class") and hasattr(
handler.view_class, request.method.lower()
):
handler = getattr(handler.view_class, request.method.lower())
return hasattr(handler, 'is_stream')
return hasattr(handler, "is_stream")

View File

@ -4,16 +4,8 @@ import traceback
from functools import partial
from inspect import isawaitable
from multiprocessing import Process
from signal import (
SIGTERM, SIGINT, SIG_IGN,
signal as signal_func,
Signals
)
from socket import (
socket,
SOL_SOCKET,
SO_REUSEADDR,
)
from signal import SIGTERM, SIGINT, SIG_IGN, signal as signal_func, Signals
from socket import socket, SOL_SOCKET, SO_REUSEADDR
from time import time
from httptools import HttpRequestParser
@ -22,6 +14,7 @@ from multidict import CIMultiDict
try:
import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
pass
@ -30,8 +23,12 @@ from sanic.log import logger, access_logger
from sanic.response import HTTPResponse
from sanic.request import Request
from sanic.exceptions import (
RequestTimeout, PayloadTooLarge, InvalidUsage, ServerError,
ServiceUnavailable)
RequestTimeout,
PayloadTooLarge,
InvalidUsage,
ServerError,
ServiceUnavailable,
)
current_time = None
@ -43,27 +40,58 @@ class Signal:
class HttpProtocol(asyncio.Protocol):
__slots__ = (
# event loop, connection
'loop', 'transport', 'connections', 'signal',
"loop",
"transport",
"connections",
"signal",
# request params
'parser', 'request', 'url', 'headers',
"parser",
"request",
"url",
"headers",
# request config
'request_handler', 'request_timeout', 'response_timeout',
'keep_alive_timeout', 'request_max_size', 'request_class',
'is_request_stream', 'router',
"request_handler",
"request_timeout",
"response_timeout",
"keep_alive_timeout",
"request_max_size",
"request_class",
"is_request_stream",
"router",
# enable or disable access log purpose
'access_log',
"access_log",
# connection management
'_total_request_size', '_request_timeout_handler',
'_response_timeout_handler', '_keep_alive_timeout_handler',
'_last_request_time', '_last_response_time', '_is_stream_handler',
'_not_paused')
"_total_request_size",
"_request_timeout_handler",
"_response_timeout_handler",
"_keep_alive_timeout_handler",
"_last_request_time",
"_last_response_time",
"_is_stream_handler",
"_not_paused",
)
def __init__(self, *, loop, request_handler, error_handler,
signal=Signal(), connections=set(), request_timeout=60,
response_timeout=60, keep_alive_timeout=5,
request_max_size=None, request_class=None, access_log=True,
keep_alive=True, is_request_stream=False, router=None,
state=None, debug=False, **kwargs):
def __init__(
self,
*,
loop,
request_handler,
error_handler,
signal=Signal(),
connections=set(),
request_timeout=60,
response_timeout=60,
keep_alive_timeout=5,
request_max_size=None,
request_class=None,
access_log=True,
keep_alive=True,
is_request_stream=False,
router=None,
state=None,
debug=False,
**kwargs
):
self.loop = loop
self.transport = None
self.request = None
@ -93,19 +121,20 @@ class HttpProtocol(asyncio.Protocol):
self._request_handler_task = None
self._request_stream_task = None
self._keep_alive = keep_alive
self._header_fragment = b''
self._header_fragment = b""
self.state = state if state else {}
if 'requests_count' not in self.state:
self.state['requests_count'] = 0
if "requests_count" not in self.state:
self.state["requests_count"] = 0
self._debug = debug
self._not_paused.set()
@property
def keep_alive(self):
return (
self._keep_alive and
not self.signal.stopped and
self.parser.should_keep_alive())
self._keep_alive
and not self.signal.stopped
and self.parser.should_keep_alive()
)
# -------------------------------------------- #
# Connection
@ -114,7 +143,8 @@ class HttpProtocol(asyncio.Protocol):
def connection_made(self, transport):
self.connections.add(self)
self._request_timeout_handler = self.loop.call_later(
self.request_timeout, self.request_timeout_callback)
self.request_timeout, self.request_timeout_callback
)
self.transport = transport
self._last_request_time = current_time
@ -145,16 +175,15 @@ class HttpProtocol(asyncio.Protocol):
time_elapsed = current_time - self._last_request_time
if time_elapsed < self.request_timeout:
time_left = self.request_timeout - time_elapsed
self._request_timeout_handler = (
self.loop.call_later(time_left,
self.request_timeout_callback)
self._request_timeout_handler = self.loop.call_later(
time_left, self.request_timeout_callback
)
else:
if self._request_stream_task:
self._request_stream_task.cancel()
if self._request_handler_task:
self._request_handler_task.cancel()
self.write_error(RequestTimeout('Request Timeout'))
self.write_error(RequestTimeout("Request Timeout"))
def response_timeout_callback(self):
# Check if elapsed time since response was initiated exceeds our
@ -162,16 +191,15 @@ class HttpProtocol(asyncio.Protocol):
time_elapsed = current_time - self._last_request_time
if time_elapsed < self.response_timeout:
time_left = self.response_timeout - time_elapsed
self._response_timeout_handler = (
self.loop.call_later(time_left,
self.response_timeout_callback)
self._response_timeout_handler = self.loop.call_later(
time_left, self.response_timeout_callback
)
else:
if self._request_stream_task:
self._request_stream_task.cancel()
if self._request_handler_task:
self._request_handler_task.cancel()
self.write_error(ServiceUnavailable('Response Timeout'))
self.write_error(ServiceUnavailable("Response Timeout"))
def keep_alive_timeout_callback(self):
# Check if elapsed time since last response exceeds our configured
@ -179,12 +207,11 @@ class HttpProtocol(asyncio.Protocol):
time_elapsed = current_time - self._last_response_time
if time_elapsed < self.keep_alive_timeout:
time_left = self.keep_alive_timeout - time_elapsed
self._keep_alive_timeout_handler = (
self.loop.call_later(time_left,
self.keep_alive_timeout_callback)
self._keep_alive_timeout_handler = self.loop.call_later(
time_left, self.keep_alive_timeout_callback
)
else:
logger.debug('KeepAlive Timeout. Closing connection.')
logger.debug("KeepAlive Timeout. Closing connection.")
self.transport.close()
self.transport = None
@ -197,7 +224,7 @@ class HttpProtocol(asyncio.Protocol):
# memory limits
self._total_request_size += len(data)
if self._total_request_size > self.request_max_size:
self.write_error(PayloadTooLarge('Payload Too Large'))
self.write_error(PayloadTooLarge("Payload Too Large"))
# Create parser if this is the first time we're receiving data
if self.parser is None:
@ -206,15 +233,15 @@ class HttpProtocol(asyncio.Protocol):
self.parser = HttpRequestParser(self)
# requests count
self.state['requests_count'] = self.state['requests_count'] + 1
self.state["requests_count"] = self.state["requests_count"] + 1
# Parse request chunk or close connection
try:
self.parser.feed_data(data)
except HttpParserError:
message = 'Bad Request'
message = "Bad Request"
if self._debug:
message += '\n' + traceback.format_exc()
message += "\n" + traceback.format_exc()
self.write_error(InvalidUsage(message))
def on_url(self, url):
@ -227,17 +254,20 @@ class HttpProtocol(asyncio.Protocol):
self._header_fragment += name
if value is not None:
if self._header_fragment == b'Content-Length' \
and int(value) > self.request_max_size:
self.write_error(PayloadTooLarge('Payload Too Large'))
if (
self._header_fragment == b"Content-Length"
and int(value) > self.request_max_size
):
self.write_error(PayloadTooLarge("Payload Too Large"))
try:
value = value.decode()
except UnicodeDecodeError:
value = value.decode('latin_1')
value = value.decode("latin_1")
self.headers.append(
(self._header_fragment.decode().casefold(), value))
(self._header_fragment.decode().casefold(), value)
)
self._header_fragment = b''
self._header_fragment = b""
def on_headers_complete(self):
self.request = self.request_class(
@ -245,7 +275,7 @@ class HttpProtocol(asyncio.Protocol):
headers=CIMultiDict(self.headers),
version=self.parser.get_http_version(),
method=self.parser.get_method().decode(),
transport=self.transport
transport=self.transport,
)
# Remove any existing KeepAlive handler here,
# It will be recreated if required on the new request.
@ -254,7 +284,8 @@ class HttpProtocol(asyncio.Protocol):
self._keep_alive_timeout_handler = None
if self.is_request_stream:
self._is_stream_handler = self.router.is_stream_handler(
self.request)
self.request
)
if self._is_stream_handler:
self.request.stream = asyncio.Queue()
self.execute_request_handler()
@ -262,7 +293,8 @@ class HttpProtocol(asyncio.Protocol):
def on_body(self, body):
if self.is_request_stream and self._is_stream_handler:
self._request_stream_task = self.loop.create_task(
self.request.stream.put(body))
self.request.stream.put(body)
)
return
self.request.body.append(body)
@ -274,47 +306,49 @@ class HttpProtocol(asyncio.Protocol):
self._request_timeout_handler = None
if self.is_request_stream and self._is_stream_handler:
self._request_stream_task = self.loop.create_task(
self.request.stream.put(None))
self.request.stream.put(None)
)
return
self.request.body = b''.join(self.request.body)
self.request.body = b"".join(self.request.body)
self.execute_request_handler()
def execute_request_handler(self):
self._response_timeout_handler = self.loop.call_later(
self.response_timeout, self.response_timeout_callback)
self.response_timeout, self.response_timeout_callback
)
self._last_request_time = current_time
self._request_handler_task = self.loop.create_task(
self.request_handler(
self.request,
self.write_response,
self.stream_response))
self.request, self.write_response, self.stream_response
)
)
# -------------------------------------------- #
# Responding
# -------------------------------------------- #
def log_response(self, response):
if self.access_log:
extra = {
'status': getattr(response, 'status', 0),
}
extra = {"status": getattr(response, "status", 0)}
if isinstance(response, HTTPResponse):
extra['byte'] = len(response.body)
extra["byte"] = len(response.body)
else:
extra['byte'] = -1
extra["byte"] = -1
extra['host'] = 'UNKNOWN'
extra["host"] = "UNKNOWN"
if self.request is not None:
if self.request.ip:
extra['host'] = '{0}:{1}'.format(self.request.ip,
self.request.port)
extra["host"] = "{0}:{1}".format(
self.request.ip, self.request.port
)
extra['request'] = '{0} {1}'.format(self.request.method,
self.request.url)
extra["request"] = "{0} {1}".format(
self.request.method, self.request.url
)
else:
extra['request'] = 'nil'
extra["request"] = "nil"
access_logger.info('', extra=extra)
access_logger.info("", extra=extra)
def write_response(self, response):
"""
@ -327,31 +361,37 @@ class HttpProtocol(asyncio.Protocol):
keep_alive = self.keep_alive
self.transport.write(
response.output(
self.request.version, keep_alive,
self.keep_alive_timeout))
self.request.version, keep_alive, self.keep_alive_timeout
)
)
self.log_response(response)
except AttributeError:
logger.error('Invalid response object for url %s, '
'Expected Type: HTTPResponse, Actual Type: %s',
self.url, type(response))
self.write_error(ServerError('Invalid response type'))
logger.error(
"Invalid response object for url %s, "
"Expected Type: HTTPResponse, Actual Type: %s",
self.url,
type(response),
)
self.write_error(ServerError("Invalid response type"))
except RuntimeError:
if self._debug:
logger.error('Connection lost before response written @ %s',
self.request.ip)
logger.error(
"Connection lost before response written @ %s",
self.request.ip,
)
keep_alive = False
except Exception as e:
self.bail_out(
"Writing response failed, connection closed {}".format(
repr(e)))
"Writing response failed, connection closed {}".format(repr(e))
)
finally:
if not keep_alive:
self.transport.close()
self.transport = None
else:
self._keep_alive_timeout_handler = self.loop.call_later(
self.keep_alive_timeout,
self.keep_alive_timeout_callback)
self.keep_alive_timeout, self.keep_alive_timeout_callback
)
self._last_response_time = current_time
self.cleanup()
@ -375,30 +415,36 @@ class HttpProtocol(asyncio.Protocol):
keep_alive = self.keep_alive
response.protocol = self
await response.stream(
self.request.version, keep_alive, self.keep_alive_timeout)
self.request.version, keep_alive, self.keep_alive_timeout
)
self.log_response(response)
except AttributeError:
logger.error('Invalid response object for url %s, '
'Expected Type: HTTPResponse, Actual Type: %s',
self.url, type(response))
self.write_error(ServerError('Invalid response type'))
logger.error(
"Invalid response object for url %s, "
"Expected Type: HTTPResponse, Actual Type: %s",
self.url,
type(response),
)
self.write_error(ServerError("Invalid response type"))
except RuntimeError:
if self._debug:
logger.error('Connection lost before response written @ %s',
self.request.ip)
logger.error(
"Connection lost before response written @ %s",
self.request.ip,
)
keep_alive = False
except Exception as e:
self.bail_out(
"Writing response failed, connection closed {}".format(
repr(e)))
"Writing response failed, connection closed {}".format(repr(e))
)
finally:
if not keep_alive:
self.transport.close()
self.transport = None
else:
self._keep_alive_timeout_handler = self.loop.call_later(
self.keep_alive_timeout,
self.keep_alive_timeout_callback)
self.keep_alive_timeout, self.keep_alive_timeout_callback
)
self._last_response_time = current_time
self.cleanup()
@ -411,32 +457,37 @@ class HttpProtocol(asyncio.Protocol):
response = None
try:
response = self.error_handler.response(self.request, exception)
version = self.request.version if self.request else '1.1'
version = self.request.version if self.request else "1.1"
self.transport.write(response.output(version))
except RuntimeError:
if self._debug:
logger.error('Connection lost before error written @ %s',
self.request.ip if self.request else 'Unknown')
logger.error(
"Connection lost before error written @ %s",
self.request.ip if self.request else "Unknown",
)
except Exception as e:
self.bail_out(
"Writing error failed, connection closed {}".format(
repr(e)), from_error=True
"Writing error failed, connection closed {}".format(repr(e)),
from_error=True,
)
finally:
if self.parser and (self.keep_alive
or getattr(response, 'status', 0) == 408):
if self.parser and (
self.keep_alive or getattr(response, "status", 0) == 408
):
self.log_response(response)
try:
self.transport.close()
except AttributeError:
logger.debug('Connection lost before server could close it.')
logger.debug("Connection lost before server could close it.")
def bail_out(self, message, from_error=False):
if from_error or self.transport.is_closing():
logger.error("Transport closed @ %s and exception "
"experienced during error handling",
self.transport.get_extra_info('peername'))
logger.debug('Exception:', exc_info=True)
logger.error(
"Transport closed @ %s and exception "
"experienced during error handling",
self.transport.get_extra_info("peername"),
)
logger.debug("Exception:", exc_info=True)
else:
self.write_error(ServerError(message))
logger.error(message)
@ -497,17 +548,43 @@ def trigger_events(events, loop):
loop.run_until_complete(result)
def serve(host, port, request_handler, error_handler, before_start=None,
after_start=None, before_stop=None, after_stop=None, debug=False,
request_timeout=60, response_timeout=60, keep_alive_timeout=5,
ssl=None, sock=None, request_max_size=None, reuse_port=False,
loop=None, protocol=HttpProtocol, backlog=100,
register_sys_signals=True, run_multiple=False, run_async=False,
connections=None, signal=Signal(), request_class=None,
access_log=True, keep_alive=True, is_request_stream=False,
router=None, websocket_max_size=None, websocket_max_queue=None,
websocket_read_limit=2 ** 16, websocket_write_limit=2 ** 16,
state=None, graceful_shutdown_timeout=15.0):
def serve(
host,
port,
request_handler,
error_handler,
before_start=None,
after_start=None,
before_stop=None,
after_stop=None,
debug=False,
request_timeout=60,
response_timeout=60,
keep_alive_timeout=5,
ssl=None,
sock=None,
request_max_size=None,
reuse_port=False,
loop=None,
protocol=HttpProtocol,
backlog=100,
register_sys_signals=True,
run_multiple=False,
run_async=False,
connections=None,
signal=Signal(),
request_class=None,
access_log=True,
keep_alive=True,
is_request_stream=False,
router=None,
websocket_max_size=None,
websocket_max_queue=None,
websocket_read_limit=2 ** 16,
websocket_write_limit=2 ** 16,
state=None,
graceful_shutdown_timeout=15.0,
):
"""Start asynchronous HTTP Server on an individual process.
:param host: Address to host on
@ -592,7 +669,7 @@ def serve(host, port, request_handler, error_handler, before_start=None,
ssl=ssl,
reuse_port=reuse_port,
sock=sock,
backlog=backlog
backlog=backlog,
)
# Instead of pulling time at the end of every request,
@ -623,11 +700,13 @@ def serve(host, port, request_handler, error_handler, before_start=None,
try:
loop.add_signal_handler(_signal, loop.stop)
except NotImplementedError:
logger.warning('Sanic tried to use loop.add_signal_handler '
'but it is not implemented on this platform.')
logger.warning(
"Sanic tried to use loop.add_signal_handler "
"but it is not implemented on this platform."
)
pid = os.getpid()
try:
logger.info('Starting worker [%s]', pid)
logger.info("Starting worker [%s]", pid)
loop.run_forever()
finally:
logger.info("Stopping worker [%s]", pid)
@ -658,9 +737,7 @@ def serve(host, port, request_handler, error_handler, before_start=None,
coros = []
for conn in connections:
if hasattr(conn, "websocket") and conn.websocket:
coros.append(
conn.websocket.close_connection()
)
coros.append(conn.websocket.close_connection())
else:
conn.close()
@ -681,18 +758,18 @@ def serve_multiple(server_settings, workers):
:param stop_event: if provided, is used as a stop signal
:return:
"""
server_settings['reuse_port'] = True
server_settings['run_multiple'] = True
server_settings["reuse_port"] = True
server_settings["run_multiple"] = True
# Handling when custom socket is not provided.
if server_settings.get('sock') is None:
if server_settings.get("sock") is None:
sock = socket()
sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
sock.bind((server_settings['host'], server_settings['port']))
sock.bind((server_settings["host"], server_settings["port"]))
sock.set_inheritable(True)
server_settings['sock'] = sock
server_settings['host'] = None
server_settings['port'] = None
server_settings["sock"] = sock
server_settings["host"] = None
server_settings["port"] = None
def sig_handler(signal, frame):
logger.info("Received signal %s. Shutting down.", Signals(signal).name)
@ -716,4 +793,4 @@ def serve_multiple(server_settings, workers):
# the above processes will block this until they're stopped
for process in processes:
process.terminate()
server_settings.get('sock').close()
server_settings.get("sock").close()

View File

@ -16,10 +16,19 @@ from sanic.handlers import ContentRangeHandler
from sanic.response import file, file_stream, HTTPResponse
def register(app, uri, file_or_directory, pattern,
use_modified_since, use_content_range,
stream_large_files, name='static', host=None,
strict_slashes=None, content_type=None):
def register(
app,
uri,
file_or_directory,
pattern,
use_modified_since,
use_content_range,
stream_large_files,
name="static",
host=None,
strict_slashes=None,
content_type=None,
):
# TODO: Though sanic is not a file server, I feel like we should at least
# make a good effort here. Modified-since is nice, but we could
# also look into etags, expires, and caching
@ -46,12 +55,12 @@ def register(app, uri, file_or_directory, pattern,
# If we're not trying to match a file directly,
# serve from the folder
if not path.isfile(file_or_directory):
uri += '<file_uri:' + pattern + '>'
uri += "<file_uri:" + pattern + ">"
async def _handler(request, file_uri=None):
# Using this to determine if the URL is trying to break out of the path
# served. os.path.realpath seems to be very slow
if file_uri and '../' in file_uri:
if file_uri and "../" in file_uri:
raise InvalidUsage("Invalid URL")
# Merge served directory and requested file if provided
# Strip all / that in the beginning of the URL to help prevent python
@ -59,15 +68,16 @@ def register(app, uri, file_or_directory, pattern,
root_path = file_path = file_or_directory
if file_uri:
file_path = path.join(
file_or_directory, sub('^[/]*', '', file_uri))
file_or_directory, sub("^[/]*", "", file_uri)
)
# URL decode the path sent by the browser otherwise we won't be able to
# match filenames which got encoded (filenames with spaces etc)
file_path = path.abspath(unquote(file_path))
if not file_path.startswith(path.abspath(unquote(root_path))):
raise FileNotFound('File not found',
path=file_or_directory,
relative_url=file_uri)
raise FileNotFound(
"File not found", path=file_or_directory, relative_url=file_uri
)
try:
headers = {}
# Check if the client has been sent this file before
@ -76,29 +86,31 @@ def register(app, uri, file_or_directory, pattern,
if use_modified_since:
stats = await stat(file_path)
modified_since = strftime(
'%a, %d %b %Y %H:%M:%S GMT', gmtime(stats.st_mtime))
if request.headers.get('If-Modified-Since') == modified_since:
"%a, %d %b %Y %H:%M:%S GMT", gmtime(stats.st_mtime)
)
if request.headers.get("If-Modified-Since") == modified_since:
return HTTPResponse(status=304)
headers['Last-Modified'] = modified_since
headers["Last-Modified"] = modified_since
_range = None
if use_content_range:
_range = None
if not stats:
stats = await stat(file_path)
headers['Accept-Ranges'] = 'bytes'
headers['Content-Length'] = str(stats.st_size)
if request.method != 'HEAD':
headers["Accept-Ranges"] = "bytes"
headers["Content-Length"] = str(stats.st_size)
if request.method != "HEAD":
try:
_range = ContentRangeHandler(request, stats)
except HeaderNotFound:
pass
else:
del headers['Content-Length']
del headers["Content-Length"]
for key, value in _range.headers.items():
headers[key] = value
headers['Content-Type'] = content_type \
or guess_type(file_path)[0] or 'text/plain'
if request.method == 'HEAD':
headers["Content-Type"] = (
content_type or guess_type(file_path)[0] or "text/plain"
)
if request.method == "HEAD":
return HTTPResponse(headers=headers)
else:
if stream_large_files:
@ -110,19 +122,25 @@ def register(app, uri, file_or_directory, pattern,
if not stats:
stats = await stat(file_path)
if stats.st_size >= threshold:
return await file_stream(file_path, headers=headers,
_range=_range)
return await file_stream(
file_path, headers=headers, _range=_range
)
return await file(file_path, headers=headers, _range=_range)
except ContentRangeError:
raise
except Exception:
raise FileNotFound('File not found',
path=file_or_directory,
relative_url=file_uri)
raise FileNotFound(
"File not found", path=file_or_directory, relative_url=file_uri
)
# special prefix for static files
if not name.startswith('_static_'):
name = '_static_{}'.format(name)
if not name.startswith("_static_"):
name = "_static_{}".format(name)
app.route(uri, methods=['GET', 'HEAD'], name=name, host=host,
strict_slashes=strict_slashes)(_handler)
app.route(
uri,
methods=["GET", "HEAD"],
name=name,
host=host,
strict_slashes=strict_slashes,
)(_handler)

View File

@ -4,7 +4,7 @@ from sanic.exceptions import MethodNotSupported
from sanic.response import text
HOST = '127.0.0.1'
HOST = "127.0.0.1"
PORT = 42101
@ -15,18 +15,22 @@ class SanicTestClient:
async def _local_request(self, method, uri, cookies=None, *args, **kwargs):
import aiohttp
if uri.startswith(('http:', 'https:', 'ftp:', 'ftps://' '//')):
if uri.startswith(("http:", "https:", "ftp:", "ftps://" "//")):
url = uri
else:
url = 'http://{host}:{port}{uri}'.format(
host=HOST, port=self.port, uri=uri)
url = "http://{host}:{port}{uri}".format(
host=HOST, port=self.port, uri=uri
)
logger.info(url)
conn = aiohttp.TCPConnector(verify_ssl=False)
async with aiohttp.ClientSession(
cookies=cookies, connector=conn) as session:
async with getattr(
session, method.lower())(url, *args, **kwargs) as response:
cookies=cookies, connector=conn
) as session:
async with getattr(session, method.lower())(
url, *args, **kwargs
) as response:
try:
response.text = await response.text()
except UnicodeDecodeError as e:
@ -34,50 +38,60 @@ class SanicTestClient:
try:
response.json = await response.json()
except (JSONDecodeError,
UnicodeDecodeError,
aiohttp.ClientResponseError):
except (
JSONDecodeError,
UnicodeDecodeError,
aiohttp.ClientResponseError,
):
response.json = None
response.body = await response.read()
return response
def _sanic_endpoint_test(
self, method='get', uri='/', gather_request=True,
debug=False, server_kwargs={"auto_reload": False},
*request_args, **request_kwargs):
self,
method="get",
uri="/",
gather_request=True,
debug=False,
server_kwargs={"auto_reload": False},
*request_args,
**request_kwargs
):
results = [None, None]
exceptions = []
if gather_request:
def _collect_request(request):
if results[0] is None:
results[0] = request
self.app.request_middleware.appendleft(_collect_request)
@self.app.exception(MethodNotSupported)
async def error_handler(request, exception):
if request.method in ['HEAD', 'PATCH', 'PUT', 'DELETE']:
if request.method in ["HEAD", "PATCH", "PUT", "DELETE"]:
return text(
'', exception.status_code, headers=exception.headers
"", exception.status_code, headers=exception.headers
)
else:
return self.app.error_handler.default(request, exception)
@self.app.listener('after_server_start')
@self.app.listener("after_server_start")
async def _collect_response(sanic, loop):
try:
response = await self._local_request(
method, uri, *request_args,
**request_kwargs)
method, uri, *request_args, **request_kwargs
)
results[-1] = response
except Exception as e:
logger.exception('Exception')
logger.exception("Exception")
exceptions.append(e)
self.app.stop()
self.app.run(host=HOST, debug=debug, port=self.port, **server_kwargs)
self.app.listeners['after_server_start'].pop()
self.app.listeners["after_server_start"].pop()
if exceptions:
raise ValueError("Exception during request: {}".format(exceptions))
@ -89,31 +103,34 @@ class SanicTestClient:
except BaseException:
raise ValueError(
"Request and response object expected, got ({})".format(
results))
results
)
)
else:
try:
return results[-1]
except BaseException:
raise ValueError(
"Request object expected, got ({})".format(results))
"Request object expected, got ({})".format(results)
)
def get(self, *args, **kwargs):
return self._sanic_endpoint_test('get', *args, **kwargs)
return self._sanic_endpoint_test("get", *args, **kwargs)
def post(self, *args, **kwargs):
return self._sanic_endpoint_test('post', *args, **kwargs)
return self._sanic_endpoint_test("post", *args, **kwargs)
def put(self, *args, **kwargs):
return self._sanic_endpoint_test('put', *args, **kwargs)
return self._sanic_endpoint_test("put", *args, **kwargs)
def delete(self, *args, **kwargs):
return self._sanic_endpoint_test('delete', *args, **kwargs)
return self._sanic_endpoint_test("delete", *args, **kwargs)
def patch(self, *args, **kwargs):
return self._sanic_endpoint_test('patch', *args, **kwargs)
return self._sanic_endpoint_test("patch", *args, **kwargs)
def options(self, *args, **kwargs):
return self._sanic_endpoint_test('options', *args, **kwargs)
return self._sanic_endpoint_test("options", *args, **kwargs)
def head(self, *args, **kwargs):
return self._sanic_endpoint_test('head', *args, **kwargs)
return self._sanic_endpoint_test("head", *args, **kwargs)

View File

@ -48,6 +48,7 @@ class HTTPMethodView:
"""Return view function for use with the routing system, that
dispatches request to appropriate handler method.
"""
def view(*args, **kwargs):
self = view.view_class(*class_args, **class_kwargs)
return self.dispatch_request(*args, **kwargs)
@ -94,11 +95,13 @@ class CompositionView:
for method in methods:
if method not in HTTP_METHODS:
raise InvalidUsage(
'{} is not a valid HTTP method.'.format(method))
"{} is not a valid HTTP method.".format(method)
)
if method in self.handlers:
raise InvalidUsage(
'Method {} is already registered.'.format(method))
"Method {} is already registered.".format(method)
)
self.handlers[method] = handler
def __call__(self, request, *args, **kwargs):

View File

@ -6,11 +6,16 @@ from websockets import ConnectionClosed # noqa
class WebSocketProtocol(HttpProtocol):
def __init__(self, *args, websocket_timeout=10,
websocket_max_size=None,
websocket_max_queue=None,
websocket_read_limit=2 ** 16,
websocket_write_limit=2 ** 16, **kwargs):
def __init__(
self,
*args,
websocket_timeout=10,
websocket_max_size=None,
websocket_max_queue=None,
websocket_read_limit=2 ** 16,
websocket_write_limit=2 ** 16,
**kwargs
):
super().__init__(*args, **kwargs)
self.websocket = None
self.websocket_timeout = websocket_timeout
@ -63,24 +68,26 @@ class WebSocketProtocol(HttpProtocol):
key = handshake.check_request(request.headers)
handshake.build_response(headers, key)
except InvalidHandshake:
raise InvalidUsage('Invalid websocket request')
raise InvalidUsage("Invalid websocket request")
subprotocol = None
if subprotocols and 'Sec-Websocket-Protocol' in request.headers:
if subprotocols and "Sec-Websocket-Protocol" in request.headers:
# select a subprotocol
client_subprotocols = [p.strip() for p in request.headers[
'Sec-Websocket-Protocol'].split(',')]
client_subprotocols = [
p.strip()
for p in request.headers["Sec-Websocket-Protocol"].split(",")
]
for p in client_subprotocols:
if p in subprotocols:
subprotocol = p
headers['Sec-Websocket-Protocol'] = subprotocol
headers["Sec-Websocket-Protocol"] = subprotocol
break
# write the 101 response back to the client
rv = b'HTTP/1.1 101 Switching Protocols\r\n'
rv = b"HTTP/1.1 101 Switching Protocols\r\n"
for k, v in headers.items():
rv += k.encode('utf-8') + b': ' + v.encode('utf-8') + b'\r\n'
rv += b'\r\n'
rv += k.encode("utf-8") + b": " + v.encode("utf-8") + b"\r\n"
rv += b"\r\n"
request.transport.write(rv)
# hook up the websocket protocol
@ -89,7 +96,7 @@ class WebSocketProtocol(HttpProtocol):
max_size=self.websocket_max_size,
max_queue=self.websocket_max_queue,
read_limit=self.websocket_read_limit,
write_limit=self.websocket_write_limit
write_limit=self.websocket_write_limit,
)
self.websocket.subprotocol = subprotocol
self.websocket.connection_made(request.transport)

View File

@ -12,6 +12,7 @@ except ImportError:
try:
import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
pass
@ -50,36 +51,43 @@ class GunicornWorker(base.Worker):
def run(self):
is_debug = self.log.loglevel == logging.DEBUG
protocol = (
self.websocket_protocol if self.app.callable.websocket_enabled
else self.http_protocol)
self.websocket_protocol
if self.app.callable.websocket_enabled
else self.http_protocol
)
self._server_settings = self.app.callable._helper(
loop=self.loop,
debug=is_debug,
protocol=protocol,
ssl=self.ssl_context,
run_async=True)
self._server_settings['signal'] = self.signal
self._server_settings.pop('sock')
trigger_events(self._server_settings.get('before_start', []),
self.loop)
self._server_settings['before_start'] = ()
run_async=True,
)
self._server_settings["signal"] = self.signal
self._server_settings.pop("sock")
trigger_events(
self._server_settings.get("before_start", []), self.loop
)
self._server_settings["before_start"] = ()
self._runner = asyncio.ensure_future(self._run(), loop=self.loop)
try:
self.loop.run_until_complete(self._runner)
self.app.callable.is_running = True
trigger_events(self._server_settings.get('after_start', []),
self.loop)
trigger_events(
self._server_settings.get("after_start", []), self.loop
)
self.loop.run_until_complete(self._check_alive())
trigger_events(self._server_settings.get('before_stop', []),
self.loop)
trigger_events(
self._server_settings.get("before_stop", []), self.loop
)
self.loop.run_until_complete(self.close())
except BaseException:
traceback.print_exc()
finally:
try:
trigger_events(self._server_settings.get('after_stop', []),
self.loop)
trigger_events(
self._server_settings.get("after_stop", []), self.loop
)
except BaseException:
traceback.print_exc()
finally:
@ -90,8 +98,11 @@ class GunicornWorker(base.Worker):
async def close(self):
if self.servers:
# stop accepting connections
self.log.info("Stopping server: %s, connections: %s",
self.pid, len(self.connections))
self.log.info(
"Stopping server: %s, connections: %s",
self.pid,
len(self.connections),
)
for server in self.servers:
server.close()
await server.wait_closed()
@ -105,8 +116,9 @@ class GunicornWorker(base.Worker):
# gracefully shutdown timeout
start_shutdown = 0
graceful_shutdown_timeout = self.cfg.graceful_timeout
while self.connections and \
(start_shutdown < graceful_shutdown_timeout):
while self.connections and (
start_shutdown < graceful_shutdown_timeout
):
await asyncio.sleep(0.1)
start_shutdown = start_shutdown + 0.1
@ -115,9 +127,7 @@ class GunicornWorker(base.Worker):
coros = []
for conn in self.connections:
if hasattr(conn, "websocket") and conn.websocket:
coros.append(
conn.websocket.close_connection()
)
coros.append(conn.websocket.close_connection())
else:
conn.close()
_shutdown = asyncio.gather(*coros, loop=self.loop)
@ -148,8 +158,9 @@ class GunicornWorker(base.Worker):
)
if self.max_requests and req_count > self.max_requests:
self.alive = False
self.log.info("Max requests exceeded, shutting down: %s",
self)
self.log.info(
"Max requests exceeded, shutting down: %s", self
)
elif pid == os.getpid() and self.ppid != os.getppid():
self.alive = False
self.log.info("Parent changed, shutting down: %s", self)
@ -175,23 +186,29 @@ class GunicornWorker(base.Worker):
def init_signals(self):
# Set up signals through the event loop API.
self.loop.add_signal_handler(signal.SIGQUIT, self.handle_quit,
signal.SIGQUIT, None)
self.loop.add_signal_handler(
signal.SIGQUIT, self.handle_quit, signal.SIGQUIT, None
)
self.loop.add_signal_handler(signal.SIGTERM, self.handle_exit,
signal.SIGTERM, None)
self.loop.add_signal_handler(
signal.SIGTERM, self.handle_exit, signal.SIGTERM, None
)
self.loop.add_signal_handler(signal.SIGINT, self.handle_quit,
signal.SIGINT, None)
self.loop.add_signal_handler(
signal.SIGINT, self.handle_quit, signal.SIGINT, None
)
self.loop.add_signal_handler(signal.SIGWINCH, self.handle_winch,
signal.SIGWINCH, None)
self.loop.add_signal_handler(
signal.SIGWINCH, self.handle_winch, signal.SIGWINCH, None
)
self.loop.add_signal_handler(signal.SIGUSR1, self.handle_usr1,
signal.SIGUSR1, None)
self.loop.add_signal_handler(
signal.SIGUSR1, self.handle_usr1, signal.SIGUSR1, None
)
self.loop.add_signal_handler(signal.SIGABRT, self.handle_abort,
signal.SIGABRT, None)
self.loop.add_signal_handler(
signal.SIGABRT, self.handle_abort, signal.SIGABRT, None
)
# Don't let SIGTERM and SIGUSR1 disturb active requests
# by interrupting system calls