run black against sanic module

This commit is contained in:
Yun Xu 2018-10-13 17:55:33 -07:00
parent 9ae6dfb6d2
commit aa9bf04dfe
22 changed files with 1657 additions and 1036 deletions

2
pyproject.toml Normal file
View File

@ -0,0 +1,2 @@
[tool.black]
line-length = 79

View File

@ -1,6 +1,6 @@
from sanic.app import Sanic from sanic.app import Sanic
from sanic.blueprints import Blueprint from sanic.blueprints import Blueprint
__version__ = '0.8.3' __version__ = "0.8.3"
__all__ = ['Sanic', 'Blueprint'] __all__ = ["Sanic", "Blueprint"]

View File

@ -5,16 +5,18 @@ from sanic.log import logger
from sanic.app import Sanic from sanic.app import Sanic
if __name__ == "__main__": if __name__ == "__main__":
parser = ArgumentParser(prog='sanic') parser = ArgumentParser(prog="sanic")
parser.add_argument('--host', dest='host', type=str, default='127.0.0.1') parser.add_argument("--host", dest="host", type=str, default="127.0.0.1")
parser.add_argument('--port', dest='port', type=int, default=8000) parser.add_argument("--port", dest="port", type=int, default=8000)
parser.add_argument('--cert', dest='cert', type=str, parser.add_argument(
help='location of certificate for SSL') "--cert", dest="cert", type=str, help="location of certificate for SSL"
parser.add_argument('--key', dest='key', type=str, )
help='location of keyfile for SSL.') parser.add_argument(
parser.add_argument('--workers', dest='workers', type=int, default=1, ) "--key", dest="key", type=str, help="location of keyfile for SSL."
parser.add_argument('--debug', dest='debug', action="store_true") )
parser.add_argument('module') parser.add_argument("--workers", dest="workers", type=int, default=1)
parser.add_argument("--debug", dest="debug", action="store_true")
parser.add_argument("module")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -25,20 +27,29 @@ if __name__ == "__main__":
module = import_module(module_name) module = import_module(module_name)
app = getattr(module, app_name, None) app = getattr(module, app_name, None)
if not isinstance(app, Sanic): if not isinstance(app, Sanic):
raise ValueError("Module is not a Sanic app, it is a {}. " raise ValueError(
"Perhaps you meant {}.app?" "Module is not a Sanic app, it is a {}. "
.format(type(app).__name__, args.module)) "Perhaps you meant {}.app?".format(
type(app).__name__, args.module
)
)
if args.cert is not None or args.key is not None: if args.cert is not None or args.key is not None:
ssl = {'cert': args.cert, 'key': args.key} ssl = {"cert": args.cert, "key": args.key}
else: else:
ssl = None ssl = None
app.run(host=args.host, port=args.port, app.run(
workers=args.workers, debug=args.debug, ssl=ssl) host=args.host,
port=args.port,
workers=args.workers,
debug=args.debug,
ssl=ssl,
)
except ImportError as e: except ImportError as e:
logger.error("No module named {} found.\n" logger.error(
" Example File: project/sanic_server.py -> app\n" "No module named {} found.\n"
" Example Module: project.sanic_server.app" " Example File: project/sanic_server.py -> app\n"
.format(e.name)) " Example Module: project.sanic_server.app".format(e.name)
)
except ValueError as e: except ValueError as e:
logger.exception("Failed to run app") logger.exception("Failed to run app")

View File

@ -27,10 +27,17 @@ import sanic.reloader_helpers as reloader_helpers
class Sanic: class Sanic:
def __init__(self, name=None, router=None, error_handler=None, def __init__(
load_env=True, request_class=None, self,
strict_slashes=False, log_config=None, name=None,
configure_logging=True): router=None,
error_handler=None,
load_env=True,
request_class=None,
strict_slashes=False,
log_config=None,
configure_logging=True,
):
# Get name from previous stack frame # Get name from previous stack frame
if name is None: if name is None:
@ -71,8 +78,9 @@ class Sanic:
""" """
if not self.is_running: if not self.is_running:
raise SanicException( raise SanicException(
'Loop can only be retrieved after the app has started ' "Loop can only be retrieved after the app has started "
'running. Not supported with `create_server` function') "running. Not supported with `create_server` function"
)
return get_event_loop() return get_event_loop()
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
@ -96,7 +104,8 @@ class Sanic:
else: else:
self.loop.create_task(task) self.loop.create_task(task)
except SanicException: except SanicException:
@self.listener('before_server_start')
@self.listener("before_server_start")
def run(app, loop): def run(app, loop):
if callable(task): if callable(task):
try: try:
@ -133,8 +142,16 @@ class Sanic:
return self.listener(event)(listener) return self.listener(event)(listener)
# Decorator # Decorator
def route(self, uri, methods=frozenset({'GET'}), host=None, def route(
strict_slashes=None, stream=False, version=None, name=None): self,
uri,
methods=frozenset({"GET"}),
host=None,
strict_slashes=None,
stream=False,
version=None,
name=None,
):
"""Decorate a function to be registered as a route """Decorate a function to be registered as a route
:param uri: path of the URL :param uri: path of the URL
@ -149,8 +166,8 @@ class Sanic:
# Fix case where the user did not prefix the URL with a / # Fix case where the user did not prefix the URL with a /
# and will probably get confused as to why it's not working # and will probably get confused as to why it's not working
if not uri.startswith('/'): if not uri.startswith("/"):
uri = '/' + uri uri = "/" + uri
if stream: if stream:
self.is_request_stream = True self.is_request_stream = True
@ -164,63 +181,141 @@ class Sanic:
if stream: if stream:
handler.is_stream = stream handler.is_stream = stream
self.router.add(uri=uri, methods=methods, handler=handler, self.router.add(
host=host, strict_slashes=strict_slashes, uri=uri,
version=version, name=name) methods=methods,
handler=handler,
host=host,
strict_slashes=strict_slashes,
version=version,
name=name,
)
return handler return handler
else: else:
raise ValueError( raise ValueError(
'Required parameter `request` missing ' "Required parameter `request` missing "
'in the {0}() route?'.format( "in the {0}() route?".format(handler.__name__)
handler.__name__)) )
return response return response
# Shorthand method decorators # Shorthand method decorators
def get(self, uri, host=None, strict_slashes=None, version=None, def get(
name=None): self, uri, host=None, strict_slashes=None, version=None, name=None
return self.route(uri, methods=frozenset({"GET"}), host=host, ):
strict_slashes=strict_slashes, version=version, return self.route(
name=name) uri,
methods=frozenset({"GET"}),
host=host,
strict_slashes=strict_slashes,
version=version,
name=name,
)
def post(self, uri, host=None, strict_slashes=None, stream=False, def post(
version=None, name=None): self,
return self.route(uri, methods=frozenset({"POST"}), host=host, uri,
strict_slashes=strict_slashes, stream=stream, host=None,
version=version, name=name) strict_slashes=None,
stream=False,
version=None,
name=None,
):
return self.route(
uri,
methods=frozenset({"POST"}),
host=host,
strict_slashes=strict_slashes,
stream=stream,
version=version,
name=name,
)
def put(self, uri, host=None, strict_slashes=None, stream=False, def put(
version=None, name=None): self,
return self.route(uri, methods=frozenset({"PUT"}), host=host, uri,
strict_slashes=strict_slashes, stream=stream, host=None,
version=version, name=name) strict_slashes=None,
stream=False,
version=None,
name=None,
):
return self.route(
uri,
methods=frozenset({"PUT"}),
host=host,
strict_slashes=strict_slashes,
stream=stream,
version=version,
name=name,
)
def head(self, uri, host=None, strict_slashes=None, version=None, def head(
name=None): self, uri, host=None, strict_slashes=None, version=None, name=None
return self.route(uri, methods=frozenset({"HEAD"}), host=host, ):
strict_slashes=strict_slashes, version=version, return self.route(
name=name) uri,
methods=frozenset({"HEAD"}),
host=host,
strict_slashes=strict_slashes,
version=version,
name=name,
)
def options(self, uri, host=None, strict_slashes=None, version=None, def options(
name=None): self, uri, host=None, strict_slashes=None, version=None, name=None
return self.route(uri, methods=frozenset({"OPTIONS"}), host=host, ):
strict_slashes=strict_slashes, version=version, return self.route(
name=name) uri,
methods=frozenset({"OPTIONS"}),
host=host,
strict_slashes=strict_slashes,
version=version,
name=name,
)
def patch(self, uri, host=None, strict_slashes=None, stream=False, def patch(
version=None, name=None): self,
return self.route(uri, methods=frozenset({"PATCH"}), host=host, uri,
strict_slashes=strict_slashes, stream=stream, host=None,
version=version, name=name) strict_slashes=None,
stream=False,
version=None,
name=None,
):
return self.route(
uri,
methods=frozenset({"PATCH"}),
host=host,
strict_slashes=strict_slashes,
stream=stream,
version=version,
name=name,
)
def delete(self, uri, host=None, strict_slashes=None, version=None, def delete(
name=None): self, uri, host=None, strict_slashes=None, version=None, name=None
return self.route(uri, methods=frozenset({"DELETE"}), host=host, ):
strict_slashes=strict_slashes, version=version, return self.route(
name=name) uri,
methods=frozenset({"DELETE"}),
host=host,
strict_slashes=strict_slashes,
version=version,
name=name,
)
def add_route(self, handler, uri, methods=frozenset({'GET'}), host=None, def add_route(
strict_slashes=None, version=None, name=None, stream=False): self,
handler,
uri,
methods=frozenset({"GET"}),
host=None,
strict_slashes=None,
version=None,
name=None,
stream=False,
):
"""A helper method to register class instance or """A helper method to register class instance or
functions as a handler to the application url functions as a handler to the application url
routes. routes.
@ -237,35 +332,42 @@ class Sanic:
:return: function or class instance :return: function or class instance
""" """
# Handle HTTPMethodView differently # Handle HTTPMethodView differently
if hasattr(handler, 'view_class'): if hasattr(handler, "view_class"):
methods = set() methods = set()
for method in HTTP_METHODS: for method in HTTP_METHODS:
_handler = getattr(handler.view_class, method.lower(), None) _handler = getattr(handler.view_class, method.lower(), None)
if _handler: if _handler:
methods.add(method) methods.add(method)
if hasattr(_handler, 'is_stream'): if hasattr(_handler, "is_stream"):
stream = True stream = True
# handle composition view differently # handle composition view differently
if isinstance(handler, CompositionView): if isinstance(handler, CompositionView):
methods = handler.handlers.keys() methods = handler.handlers.keys()
for _handler in handler.handlers.values(): for _handler in handler.handlers.values():
if hasattr(_handler, 'is_stream'): if hasattr(_handler, "is_stream"):
stream = True stream = True
break break
if strict_slashes is None: if strict_slashes is None:
strict_slashes = self.strict_slashes strict_slashes = self.strict_slashes
self.route(uri=uri, methods=methods, host=host, self.route(
strict_slashes=strict_slashes, stream=stream, uri=uri,
version=version, name=name)(handler) methods=methods,
host=host,
strict_slashes=strict_slashes,
stream=stream,
version=version,
name=name,
)(handler)
return handler return handler
# Decorator # Decorator
def websocket(self, uri, host=None, strict_slashes=None, def websocket(
subprotocols=None, name=None): self, uri, host=None, strict_slashes=None, subprotocols=None, name=None
):
"""Decorate a function to be registered as a websocket route """Decorate a function to be registered as a websocket route
:param uri: path of the URL :param uri: path of the URL
:param subprotocols: optional list of strings with the supported :param subprotocols: optional list of strings with the supported
@ -277,8 +379,8 @@ class Sanic:
# Fix case where the user did not prefix the URL with a / # Fix case where the user did not prefix the URL with a /
# and will probably get confused as to why it's not working # and will probably get confused as to why it's not working
if not uri.startswith('/'): if not uri.startswith("/"):
uri = '/' + uri uri = "/" + uri
if strict_slashes is None: if strict_slashes is None:
strict_slashes = self.strict_slashes strict_slashes = self.strict_slashes
@ -307,21 +409,38 @@ class Sanic:
self.websocket_tasks.remove(fut) self.websocket_tasks.remove(fut)
await ws.close() await ws.close()
self.router.add(uri=uri, handler=websocket_handler, self.router.add(
methods=frozenset({'GET'}), host=host, uri=uri,
strict_slashes=strict_slashes, name=name) handler=websocket_handler,
methods=frozenset({"GET"}),
host=host,
strict_slashes=strict_slashes,
name=name,
)
return handler return handler
return response return response
def add_websocket_route(self, handler, uri, host=None, def add_websocket_route(
strict_slashes=None, subprotocols=None, name=None): self,
handler,
uri,
host=None,
strict_slashes=None,
subprotocols=None,
name=None,
):
"""A helper method to register a function as a websocket route.""" """A helper method to register a function as a websocket route."""
if strict_slashes is None: if strict_slashes is None:
strict_slashes = self.strict_slashes strict_slashes = self.strict_slashes
return self.websocket(uri, host=host, strict_slashes=strict_slashes, return self.websocket(
subprotocols=subprotocols, name=name)(handler) uri,
host=host,
strict_slashes=strict_slashes,
subprotocols=subprotocols,
name=name,
)(handler)
def enable_websocket(self, enable=True): def enable_websocket(self, enable=True):
"""Enable or disable the support for websocket. """Enable or disable the support for websocket.
@ -332,7 +451,7 @@ class Sanic:
if not self.websocket_enabled: if not self.websocket_enabled:
# if the server is stopped, we want to cancel any ongoing # if the server is stopped, we want to cancel any ongoing
# websocket tasks, to allow the server to exit promptly # websocket tasks, to allow the server to exit promptly
@self.listener('before_server_stop') @self.listener("before_server_stop")
def cancel_websocket_tasks(app, loop): def cancel_websocket_tasks(app, loop):
for task in self.websocket_tasks: for task in self.websocket_tasks:
task.cancel() task.cancel()
@ -361,10 +480,10 @@ class Sanic:
return response return response
def register_middleware(self, middleware, attach_to='request'): def register_middleware(self, middleware, attach_to="request"):
if attach_to == 'request': if attach_to == "request":
self.request_middleware.append(middleware) self.request_middleware.append(middleware)
if attach_to == 'response': if attach_to == "response":
self.response_middleware.appendleft(middleware) self.response_middleware.appendleft(middleware)
return middleware return middleware
@ -379,21 +498,40 @@ class Sanic:
return self.register_middleware(middleware_or_request) return self.register_middleware(middleware_or_request)
else: else:
return partial(self.register_middleware, return partial(
attach_to=middleware_or_request) self.register_middleware, attach_to=middleware_or_request
)
# Static Files # Static Files
def static(self, uri, file_or_directory, pattern=r'/?.+', def static(
use_modified_since=True, use_content_range=False, self,
stream_large_files=False, name='static', host=None, uri,
strict_slashes=None, content_type=None): file_or_directory,
pattern=r"/?.+",
use_modified_since=True,
use_content_range=False,
stream_large_files=False,
name="static",
host=None,
strict_slashes=None,
content_type=None,
):
"""Register a root to serve files from. The input can either be a """Register a root to serve files from. The input can either be a
file or a directory. See file or a directory. See
""" """
static_register(self, uri, file_or_directory, pattern, static_register(
use_modified_since, use_content_range, self,
stream_large_files, name, host, strict_slashes, uri,
content_type) file_or_directory,
pattern,
use_modified_since,
use_content_range,
stream_large_files,
name,
host,
strict_slashes,
content_type,
)
def blueprint(self, blueprint, **options): def blueprint(self, blueprint, **options):
"""Register a blueprint on the application. """Register a blueprint on the application.
@ -407,10 +545,10 @@ class Sanic:
self.blueprint(item, **options) self.blueprint(item, **options)
return return
if blueprint.name in self.blueprints: if blueprint.name in self.blueprints:
assert self.blueprints[blueprint.name] is blueprint, \ assert self.blueprints[blueprint.name] is blueprint, (
'A blueprint with the name "%s" is already registered. ' \ 'A blueprint with the name "%s" is already registered. '
'Blueprint names must be unique.' % \ "Blueprint names must be unique." % (blueprint.name,)
(blueprint.name,) )
else: else:
self.blueprints[blueprint.name] = blueprint self.blueprints[blueprint.name] = blueprint
self._blueprint_order.append(blueprint) self._blueprint_order.append(blueprint)
@ -419,11 +557,13 @@ class Sanic:
def register_blueprint(self, *args, **kwargs): def register_blueprint(self, *args, **kwargs):
# TODO: deprecate 1.0 # TODO: deprecate 1.0
if self.debug: if self.debug:
warnings.simplefilter('default') warnings.simplefilter("default")
warnings.warn("Use of register_blueprint will be deprecated in " warnings.warn(
"version 1.0. Please use the blueprint method" "Use of register_blueprint will be deprecated in "
" instead", "version 1.0. Please use the blueprint method"
DeprecationWarning) " instead",
DeprecationWarning,
)
return self.blueprint(*args, **kwargs) return self.blueprint(*args, **kwargs)
def url_for(self, view_name: str, **kwargs): def url_for(self, view_name: str, **kwargs):
@ -449,67 +589,66 @@ class Sanic:
# find the route by the supplied view name # find the route by the supplied view name
kw = {} kw = {}
# special static files url_for # special static files url_for
if view_name == 'static': if view_name == "static":
kw.update(name=kwargs.pop('name', 'static')) kw.update(name=kwargs.pop("name", "static"))
elif view_name.endswith('.static'): # blueprint.static elif view_name.endswith(".static"): # blueprint.static
kwargs.pop('name', None) kwargs.pop("name", None)
kw.update(name=view_name) kw.update(name=view_name)
uri, route = self.router.find_route_by_view_name(view_name, **kw) uri, route = self.router.find_route_by_view_name(view_name, **kw)
if not (uri and route): if not (uri and route):
raise URLBuildError('Endpoint with name `{}` was not found'.format( raise URLBuildError(
view_name)) "Endpoint with name `{}` was not found".format(view_name)
)
if view_name == 'static' or view_name.endswith('.static'): if view_name == "static" or view_name.endswith(".static"):
filename = kwargs.pop('filename', None) filename = kwargs.pop("filename", None)
# it's static folder # it's static folder
if '<file_uri:' in uri: if "<file_uri:" in uri:
folder_ = uri.split('<file_uri:', 1)[0] folder_ = uri.split("<file_uri:", 1)[0]
if folder_.endswith('/'): if folder_.endswith("/"):
folder_ = folder_[:-1] folder_ = folder_[:-1]
if filename.startswith('/'): if filename.startswith("/"):
filename = filename[1:] filename = filename[1:]
uri = '{}/{}'.format(folder_, filename) uri = "{}/{}".format(folder_, filename)
if uri != '/' and uri.endswith('/'): if uri != "/" and uri.endswith("/"):
uri = uri[:-1] uri = uri[:-1]
out = uri out = uri
# find all the parameters we will need to build in the URL # find all the parameters we will need to build in the URL
matched_params = re.findall( matched_params = re.findall(self.router.parameter_pattern, uri)
self.router.parameter_pattern, uri)
# _method is only a placeholder now, don't know how to support it # _method is only a placeholder now, don't know how to support it
kwargs.pop('_method', None) kwargs.pop("_method", None)
anchor = kwargs.pop('_anchor', '') anchor = kwargs.pop("_anchor", "")
# _external need SERVER_NAME in config or pass _server arg # _external need SERVER_NAME in config or pass _server arg
external = kwargs.pop('_external', False) external = kwargs.pop("_external", False)
scheme = kwargs.pop('_scheme', '') scheme = kwargs.pop("_scheme", "")
if scheme and not external: if scheme and not external:
raise ValueError('When specifying _scheme, _external must be True') raise ValueError("When specifying _scheme, _external must be True")
netloc = kwargs.pop('_server', None) netloc = kwargs.pop("_server", None)
if netloc is None and external: if netloc is None and external:
netloc = self.config.get('SERVER_NAME', '') netloc = self.config.get("SERVER_NAME", "")
if external: if external:
if not scheme: if not scheme:
if ':' in netloc[:8]: if ":" in netloc[:8]:
scheme = netloc[:8].split(':', 1)[0] scheme = netloc[:8].split(":", 1)[0]
else: else:
scheme = 'http' scheme = "http"
if '://' in netloc[:8]: if "://" in netloc[:8]:
netloc = netloc.split('://', 1)[-1] netloc = netloc.split("://", 1)[-1]
for match in matched_params: for match in matched_params:
name, _type, pattern = self.router.parse_parameter_string( name, _type, pattern = self.router.parse_parameter_string(match)
match)
# we only want to match against each individual parameter # we only want to match against each individual parameter
specific_pattern = '^{}$'.format(pattern) specific_pattern = "^{}$".format(pattern)
supplied_param = None supplied_param = None
if name in kwargs: if name in kwargs:
@ -517,8 +656,10 @@ class Sanic:
del kwargs[name] del kwargs[name]
else: else:
raise URLBuildError( raise URLBuildError(
'Required parameter `{}` was not passed to url_for'.format( "Required parameter `{}` was not passed to url_for".format(
name)) name
)
)
supplied_param = str(supplied_param) supplied_param = str(supplied_param)
# determine if the parameter supplied by the caller passes the test # determine if the parameter supplied by the caller passes the test
@ -529,25 +670,28 @@ class Sanic:
if _type != str: if _type != str:
msg = ( msg = (
'Value "{}" for parameter `{}` does not ' 'Value "{}" for parameter `{}` does not '
'match pattern for type `{}`: {}'.format( "match pattern for type `{}`: {}".format(
supplied_param, name, _type.__name__, pattern)) supplied_param, name, _type.__name__, pattern
)
)
else: else:
msg = ( msg = (
'Value "{}" for parameter `{}` ' 'Value "{}" for parameter `{}` '
'does not satisfy pattern {}'.format( "does not satisfy pattern {}".format(
supplied_param, name, pattern)) supplied_param, name, pattern
)
)
raise URLBuildError(msg) raise URLBuildError(msg)
# replace the parameter in the URL with the supplied value # replace the parameter in the URL with the supplied value
replacement_regex = '(<{}.*?>)'.format(name) replacement_regex = "(<{}.*?>)".format(name)
out = re.sub( out = re.sub(replacement_regex, supplied_param, out)
replacement_regex, supplied_param, out)
# parse the remainder of the keyword arguments into a querystring # parse the remainder of the keyword arguments into a querystring
query_string = urlencode(kwargs, doseq=True) if kwargs else '' query_string = urlencode(kwargs, doseq=True) if kwargs else ""
# scheme://netloc/path;parameters?query#fragment # scheme://netloc/path;parameters?query#fragment
out = urlunparse((scheme, netloc, out, '', query_string, anchor)) out = urlunparse((scheme, netloc, out, "", query_string, anchor))
return out return out
@ -594,8 +738,11 @@ class Sanic:
request.uri_template = uri request.uri_template = uri
if handler is None: if handler is None:
raise ServerError( raise ServerError(
("'None' was returned while requesting a " (
"handler from the router")) "'None' was returned while requesting a "
"handler from the router"
)
)
# Run response handler # Run response handler
response = handler(request, *args, **kwargs) response = handler(request, *args, **kwargs)
@ -619,16 +766,20 @@ class Sanic:
response = await response response = await response
except Exception as e: except Exception as e:
if isinstance(e, SanicException): if isinstance(e, SanicException):
response = self.error_handler.default(request=request, response = self.error_handler.default(
exception=e) request=request, exception=e
)
elif self.debug: elif self.debug:
response = HTTPResponse( response = HTTPResponse(
"Error while handling error: {}\nStack: {}".format( "Error while handling error: {}\nStack: {}".format(
e, format_exc()), status=500) e, format_exc()
),
status=500,
)
else: else:
response = HTTPResponse( response = HTTPResponse(
"An error occurred while handling an error", "An error occurred while handling an error", status=500
status=500) )
finally: finally:
# -------------------------------------------- # # -------------------------------------------- #
# Response Middleware # Response Middleware
@ -636,16 +787,17 @@ class Sanic:
# Don't run response middleware if response is None # Don't run response middleware if response is None
if response is not None: if response is not None:
try: try:
response = await self._run_response_middleware(request, response = await self._run_response_middleware(
response) request, response
)
except CancelledError: except CancelledError:
# Response middleware can timeout too, as above. # Response middleware can timeout too, as above.
response = None response = None
cancelled = True cancelled = True
except BaseException: except BaseException:
error_logger.exception( error_logger.exception(
'Exception occurred in one of response ' "Exception occurred in one of response "
'middleware handlers' "middleware handlers"
) )
if cancelled: if cancelled:
raise CancelledError() raise CancelledError()
@ -668,10 +820,21 @@ class Sanic:
# Execution # Execution
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
def run(self, host=None, port=None, debug=False, ssl=None, def run(
sock=None, workers=1, protocol=None, self,
backlog=100, stop_event=None, register_sys_signals=True, host=None,
access_log=True, **kwargs): port=None,
debug=False,
ssl=None,
sock=None,
workers=1,
protocol=None,
backlog=100,
stop_event=None,
register_sys_signals=True,
access_log=True,
**kwargs
):
"""Run the HTTP Server and listen until keyboard interrupt or term """Run the HTTP Server and listen until keyboard interrupt or term
signal. On termination, drain connections before closing. signal. On termination, drain connections before closing.
@ -692,7 +855,7 @@ class Sanic:
# Default auto_reload to false # Default auto_reload to false
auto_reload = False auto_reload = False
# If debug is set, default it to true (unless on windows) # If debug is set, default it to true (unless on windows)
if debug and os.name == 'posix': if debug and os.name == "posix":
auto_reload = True auto_reload = True
# Allow for overriding either of the defaults # Allow for overriding either of the defaults
auto_reload = kwargs.get("auto_reload", auto_reload) auto_reload = kwargs.get("auto_reload", auto_reload)
@ -701,30 +864,43 @@ class Sanic:
host, port = host or "127.0.0.1", port or 8000 host, port = host or "127.0.0.1", port or 8000
if protocol is None: if protocol is None:
protocol = (WebSocketProtocol if self.websocket_enabled protocol = (
else HttpProtocol) WebSocketProtocol if self.websocket_enabled else HttpProtocol
)
if stop_event is not None: if stop_event is not None:
if debug: if debug:
warnings.simplefilter('default') warnings.simplefilter("default")
warnings.warn("stop_event will be removed from future versions.", warnings.warn(
DeprecationWarning) "stop_event will be removed from future versions.",
DeprecationWarning,
)
# compatibility old access_log params # compatibility old access_log params
self.config.ACCESS_LOG = access_log self.config.ACCESS_LOG = access_log
server_settings = self._helper( server_settings = self._helper(
host=host, port=port, debug=debug, ssl=ssl, sock=sock, host=host,
workers=workers, protocol=protocol, backlog=backlog, port=port,
register_sys_signals=register_sys_signals, auto_reload=auto_reload) debug=debug,
ssl=ssl,
sock=sock,
workers=workers,
protocol=protocol,
backlog=backlog,
register_sys_signals=register_sys_signals,
auto_reload=auto_reload,
)
try: try:
self.is_running = True self.is_running = True
if workers == 1: if workers == 1:
if auto_reload and os.name != 'posix': if auto_reload and os.name != "posix":
# This condition must be removed after implementing # This condition must be removed after implementing
# auto reloader for other operating systems. # auto reloader for other operating systems.
raise NotImplementedError raise NotImplementedError
if auto_reload and \ if (
os.environ.get('SANIC_SERVER_RUNNING') != 'true': auto_reload
and os.environ.get("SANIC_SERVER_RUNNING") != "true"
):
reloader_helpers.watchdog(2) reloader_helpers.watchdog(2)
else: else:
serve(**server_settings) serve(**server_settings)
@ -732,7 +908,8 @@ class Sanic:
serve_multiple(server_settings, workers) serve_multiple(server_settings, workers)
except BaseException: except BaseException:
error_logger.exception( error_logger.exception(
'Experienced exception while trying to serve') "Experienced exception while trying to serve"
)
raise raise
finally: finally:
self.is_running = False self.is_running = False
@ -746,10 +923,18 @@ class Sanic:
"""gunicorn compatibility""" """gunicorn compatibility"""
return self return self
async def create_server(self, host=None, port=None, debug=False, async def create_server(
ssl=None, sock=None, protocol=None, self,
backlog=100, stop_event=None, host=None,
access_log=True): port=None,
debug=False,
ssl=None,
sock=None,
protocol=None,
backlog=100,
stop_event=None,
access_log=True,
):
"""Asynchronous version of `run`. """Asynchronous version of `run`.
NOTE: This does not support multiprocessing and is not the preferred NOTE: This does not support multiprocessing and is not the preferred
@ -760,24 +945,34 @@ class Sanic:
host, port = host or "127.0.0.1", port or 8000 host, port = host or "127.0.0.1", port or 8000
if protocol is None: if protocol is None:
protocol = (WebSocketProtocol if self.websocket_enabled protocol = (
else HttpProtocol) WebSocketProtocol if self.websocket_enabled else HttpProtocol
)
if stop_event is not None: if stop_event is not None:
if debug: if debug:
warnings.simplefilter('default') warnings.simplefilter("default")
warnings.warn("stop_event will be removed from future versions.", warnings.warn(
DeprecationWarning) "stop_event will be removed from future versions.",
DeprecationWarning,
)
# compatibility old access_log params # compatibility old access_log params
self.config.ACCESS_LOG = access_log self.config.ACCESS_LOG = access_log
server_settings = self._helper( server_settings = self._helper(
host=host, port=port, debug=debug, ssl=ssl, sock=sock, host=host,
loop=get_event_loop(), protocol=protocol, port=port,
backlog=backlog, run_async=True) debug=debug,
ssl=ssl,
sock=sock,
loop=get_event_loop(),
protocol=protocol,
backlog=backlog,
run_async=True,
)
# Trigger before_start events # Trigger before_start events
await self.trigger_events( await self.trigger_events(
server_settings.get('before_start', []), server_settings.get("before_start", []),
server_settings.get('loop') server_settings.get("loop"),
) )
return await serve(**server_settings) return await serve(**server_settings)
@ -814,15 +1009,27 @@ class Sanic:
break break
return response return response
def _helper(self, host=None, port=None, debug=False, def _helper(
ssl=None, sock=None, workers=1, loop=None, self,
protocol=HttpProtocol, backlog=100, stop_event=None, host=None,
register_sys_signals=True, run_async=False, auto_reload=False): port=None,
debug=False,
ssl=None,
sock=None,
workers=1,
loop=None,
protocol=HttpProtocol,
backlog=100,
stop_event=None,
register_sys_signals=True,
run_async=False,
auto_reload=False,
):
"""Helper function used by `run` and `create_server`.""" """Helper function used by `run` and `create_server`."""
if isinstance(ssl, dict): if isinstance(ssl, dict):
# try common aliaseses # try common aliaseses
cert = ssl.get('cert') or ssl.get('certificate') cert = ssl.get("cert") or ssl.get("certificate")
key = ssl.get('key') or ssl.get('keyfile') key = ssl.get("key") or ssl.get("keyfile")
if cert is None or key is None: if cert is None or key is None:
raise ValueError("SSLContext or certificate and key required.") raise ValueError("SSLContext or certificate and key required.")
context = create_default_context(purpose=Purpose.CLIENT_AUTH) context = create_default_context(purpose=Purpose.CLIENT_AUTH)
@ -830,40 +1037,42 @@ class Sanic:
ssl = context ssl = context
if stop_event is not None: if stop_event is not None:
if debug: if debug:
warnings.simplefilter('default') warnings.simplefilter("default")
warnings.warn("stop_event will be removed from future versions.", warnings.warn(
DeprecationWarning) "stop_event will be removed from future versions.",
DeprecationWarning,
)
self.error_handler.debug = debug self.error_handler.debug = debug
self.debug = debug self.debug = debug
server_settings = { server_settings = {
'protocol': protocol, "protocol": protocol,
'request_class': self.request_class, "request_class": self.request_class,
'is_request_stream': self.is_request_stream, "is_request_stream": self.is_request_stream,
'router': self.router, "router": self.router,
'host': host, "host": host,
'port': port, "port": port,
'sock': sock, "sock": sock,
'ssl': ssl, "ssl": ssl,
'signal': Signal(), "signal": Signal(),
'debug': debug, "debug": debug,
'request_handler': self.handle_request, "request_handler": self.handle_request,
'error_handler': self.error_handler, "error_handler": self.error_handler,
'request_timeout': self.config.REQUEST_TIMEOUT, "request_timeout": self.config.REQUEST_TIMEOUT,
'response_timeout': self.config.RESPONSE_TIMEOUT, "response_timeout": self.config.RESPONSE_TIMEOUT,
'keep_alive_timeout': self.config.KEEP_ALIVE_TIMEOUT, "keep_alive_timeout": self.config.KEEP_ALIVE_TIMEOUT,
'request_max_size': self.config.REQUEST_MAX_SIZE, "request_max_size": self.config.REQUEST_MAX_SIZE,
'keep_alive': self.config.KEEP_ALIVE, "keep_alive": self.config.KEEP_ALIVE,
'loop': loop, "loop": loop,
'register_sys_signals': register_sys_signals, "register_sys_signals": register_sys_signals,
'backlog': backlog, "backlog": backlog,
'access_log': self.config.ACCESS_LOG, "access_log": self.config.ACCESS_LOG,
'websocket_max_size': self.config.WEBSOCKET_MAX_SIZE, "websocket_max_size": self.config.WEBSOCKET_MAX_SIZE,
'websocket_max_queue': self.config.WEBSOCKET_MAX_QUEUE, "websocket_max_queue": self.config.WEBSOCKET_MAX_QUEUE,
'websocket_read_limit': self.config.WEBSOCKET_READ_LIMIT, "websocket_read_limit": self.config.WEBSOCKET_READ_LIMIT,
'websocket_write_limit': self.config.WEBSOCKET_WRITE_LIMIT, "websocket_write_limit": self.config.WEBSOCKET_WRITE_LIMIT,
'graceful_shutdown_timeout': self.config.GRACEFUL_SHUTDOWN_TIMEOUT "graceful_shutdown_timeout": self.config.GRACEFUL_SHUTDOWN_TIMEOUT,
} }
# -------------------------------------------- # # -------------------------------------------- #
@ -871,10 +1080,10 @@ class Sanic:
# -------------------------------------------- # # -------------------------------------------- #
for event_name, settings_name, reverse in ( for event_name, settings_name, reverse in (
("before_server_start", "before_start", False), ("before_server_start", "before_start", False),
("after_server_start", "after_start", False), ("after_server_start", "after_start", False),
("before_server_stop", "before_stop", True), ("before_server_stop", "before_stop", True),
("after_server_stop", "after_stop", True), ("after_server_stop", "after_stop", True),
): ):
listeners = self.listeners[event_name].copy() listeners = self.listeners[event_name].copy()
if reverse: if reverse:
@ -886,18 +1095,20 @@ class Sanic:
if self.configure_logging and debug: if self.configure_logging and debug:
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
if self.config.LOGO is not None and \ if (
os.environ.get('SANIC_SERVER_RUNNING') != 'true': self.config.LOGO is not None
and os.environ.get("SANIC_SERVER_RUNNING") != "true"
):
logger.debug(self.config.LOGO) logger.debug(self.config.LOGO)
if run_async: if run_async:
server_settings['run_async'] = True server_settings["run_async"] = True
# Serve # Serve
if host and port and os.environ.get('SANIC_SERVER_RUNNING') != 'true': if host and port and os.environ.get("SANIC_SERVER_RUNNING") != "true":
proto = "http" proto = "http"
if ssl is not None: if ssl is not None:
proto = "https" proto = "https"
logger.info('Goin\' Fast @ {}://{}:{}'.format(proto, host, port)) logger.info("Goin' Fast @ {}://{}:{}".format(proto, host, port))
return server_settings return server_settings

View File

@ -3,21 +3,36 @@ from collections import defaultdict, namedtuple
from sanic.constants import HTTP_METHODS from sanic.constants import HTTP_METHODS
from sanic.views import CompositionView from sanic.views import CompositionView
FutureRoute = namedtuple('Route', FutureRoute = namedtuple(
['handler', 'uri', 'methods', 'host', "Route",
'strict_slashes', 'stream', 'version', 'name']) [
FutureListener = namedtuple('Listener', ['handler', 'uri', 'methods', 'host']) "handler",
FutureMiddleware = namedtuple('Route', ['middleware', 'args', 'kwargs']) "uri",
FutureException = namedtuple('Route', ['handler', 'args', 'kwargs']) "methods",
FutureStatic = namedtuple('Route', "host",
['uri', 'file_or_directory', 'args', 'kwargs']) "strict_slashes",
"stream",
"version",
"name",
],
)
FutureListener = namedtuple("Listener", ["handler", "uri", "methods", "host"])
FutureMiddleware = namedtuple("Route", ["middleware", "args", "kwargs"])
FutureException = namedtuple("Route", ["handler", "args", "kwargs"])
FutureStatic = namedtuple(
"Route", ["uri", "file_or_directory", "args", "kwargs"]
)
class Blueprint: class Blueprint:
def __init__(self, name, def __init__(
url_prefix=None, self,
host=None, version=None, name,
strict_slashes=False): url_prefix=None,
host=None,
version=None,
strict_slashes=False,
):
"""Create a new blueprint """Create a new blueprint
:param name: unique name of the blueprint :param name: unique name of the blueprint
@ -38,13 +53,14 @@ class Blueprint:
self.strict_slashes = strict_slashes self.strict_slashes = strict_slashes
@staticmethod @staticmethod
def group(*blueprints, url_prefix=''): def group(*blueprints, url_prefix=""):
"""Create a list of blueprints, optionally """Create a list of blueprints, optionally
grouping them under a general URL prefix. grouping them under a general URL prefix.
:param blueprints: blueprints to be registered as a group :param blueprints: blueprints to be registered as a group
:param url_prefix: URL route to be prepended to all sub-prefixes :param url_prefix: URL route to be prepended to all sub-prefixes
""" """
def chain(nested): def chain(nested):
"""itertools.chain() but leaves strings untouched""" """itertools.chain() but leaves strings untouched"""
for i in nested: for i in nested:
@ -52,10 +68,11 @@ class Blueprint:
yield from chain(i) yield from chain(i)
else: else:
yield i yield i
bps = [] bps = []
for bp in chain(blueprints): for bp in chain(blueprints):
if bp.url_prefix is None: if bp.url_prefix is None:
bp.url_prefix = '' bp.url_prefix = ""
bp.url_prefix = url_prefix + bp.url_prefix bp.url_prefix = url_prefix + bp.url_prefix
bps.append(bp) bps.append(bp)
return bps return bps
@ -63,7 +80,7 @@ class Blueprint:
def register(self, app, options): def register(self, app, options):
"""Register the blueprint to the sanic app.""" """Register the blueprint to the sanic app."""
url_prefix = options.get('url_prefix', self.url_prefix) url_prefix = options.get("url_prefix", self.url_prefix)
# Routes # Routes
for future in self.routes: for future in self.routes:
@ -75,14 +92,15 @@ class Blueprint:
version = future.version or self.version version = future.version or self.version
app.route(uri=uri[1:] if uri.startswith('//') else uri, app.route(
methods=future.methods, uri=uri[1:] if uri.startswith("//") else uri,
host=future.host or self.host, methods=future.methods,
strict_slashes=future.strict_slashes, host=future.host or self.host,
stream=future.stream, strict_slashes=future.strict_slashes,
version=version, stream=future.stream,
name=future.name, version=version,
)(future.handler) name=future.name,
)(future.handler)
for future in self.websocket_routes: for future in self.websocket_routes:
# attach the blueprint name to the handler so that it can be # attach the blueprint name to the handler so that it can be
@ -90,18 +108,19 @@ class Blueprint:
future.handler.__blueprintname__ = self.name future.handler.__blueprintname__ = self.name
# Prepend the blueprint URI prefix if available # Prepend the blueprint URI prefix if available
uri = url_prefix + future.uri if url_prefix else future.uri uri = url_prefix + future.uri if url_prefix else future.uri
app.websocket(uri=uri, app.websocket(
host=future.host or self.host, uri=uri,
strict_slashes=future.strict_slashes, host=future.host or self.host,
name=future.name, strict_slashes=future.strict_slashes,
)(future.handler) name=future.name,
)(future.handler)
# Middleware # Middleware
for future in self.middlewares: for future in self.middlewares:
if future.args or future.kwargs: if future.args or future.kwargs:
app.register_middleware(future.middleware, app.register_middleware(
*future.args, future.middleware, *future.args, **future.kwargs
**future.kwargs) )
else: else:
app.register_middleware(future.middleware) app.register_middleware(future.middleware)
@ -113,16 +132,25 @@ class Blueprint:
for future in self.statics: for future in self.statics:
# Prepend the blueprint URI prefix if available # Prepend the blueprint URI prefix if available
uri = url_prefix + future.uri if url_prefix else future.uri uri = url_prefix + future.uri if url_prefix else future.uri
app.static(uri, future.file_or_directory, app.static(
*future.args, **future.kwargs) uri, future.file_or_directory, *future.args, **future.kwargs
)
# Event listeners # Event listeners
for event, listeners in self.listeners.items(): for event, listeners in self.listeners.items():
for listener in listeners: for listener in listeners:
app.listener(event)(listener) app.listener(event)(listener)
def route(self, uri, methods=frozenset({'GET'}), host=None, def route(
strict_slashes=None, stream=False, version=None, name=None): self,
uri,
methods=frozenset({"GET"}),
host=None,
strict_slashes=None,
stream=False,
version=None,
name=None,
):
"""Create a blueprint route from a decorated function. """Create a blueprint route from a decorated function.
:param uri: endpoint at which the route will be accessible. :param uri: endpoint at which the route will be accessible.
@ -133,14 +161,30 @@ class Blueprint:
def decorator(handler): def decorator(handler):
route = FutureRoute( route = FutureRoute(
handler, uri, methods, host, strict_slashes, stream, version, handler,
name) uri,
methods,
host,
strict_slashes,
stream,
version,
name,
)
self.routes.append(route) self.routes.append(route)
return handler return handler
return decorator return decorator
def add_route(self, handler, uri, methods=frozenset({'GET'}), host=None, def add_route(
strict_slashes=None, version=None, name=None): self,
handler,
uri,
methods=frozenset({"GET"}),
host=None,
strict_slashes=None,
version=None,
name=None,
):
"""Create a blueprint route from a function. """Create a blueprint route from a function.
:param handler: function for handling uri requests. Accepts function, :param handler: function for handling uri requests. Accepts function,
@ -154,7 +198,7 @@ class Blueprint:
:return: function or class instance :return: function or class instance
""" """
# Handle HTTPMethodView differently # Handle HTTPMethodView differently
if hasattr(handler, 'view_class'): if hasattr(handler, "view_class"):
methods = set() methods = set()
for method in HTTP_METHODS: for method in HTTP_METHODS:
@ -168,13 +212,19 @@ class Blueprint:
if isinstance(handler, CompositionView): if isinstance(handler, CompositionView):
methods = handler.handlers.keys() methods = handler.handlers.keys()
self.route(uri=uri, methods=methods, host=host, self.route(
strict_slashes=strict_slashes, version=version, uri=uri,
name=name)(handler) methods=methods,
host=host,
strict_slashes=strict_slashes,
version=version,
name=name,
)(handler)
return handler return handler
def websocket(self, uri, host=None, strict_slashes=None, version=None, def websocket(
name=None): self, uri, host=None, strict_slashes=None, version=None, name=None
):
"""Create a blueprint websocket route from a decorated function. """Create a blueprint websocket route from a decorated function.
:param uri: endpoint at which the route will be accessible. :param uri: endpoint at which the route will be accessible.
@ -183,14 +233,17 @@ class Blueprint:
strict_slashes = self.strict_slashes strict_slashes = self.strict_slashes
def decorator(handler): def decorator(handler):
route = FutureRoute(handler, uri, [], host, strict_slashes, route = FutureRoute(
False, version, name) handler, uri, [], host, strict_slashes, False, version, name
)
self.websocket_routes.append(route) self.websocket_routes.append(route)
return handler return handler
return decorator return decorator
def add_websocket_route(self, handler, uri, host=None, version=None, def add_websocket_route(
name=None): self, handler, uri, host=None, version=None, name=None
):
"""Create a blueprint websocket route from a function. """Create a blueprint websocket route from a function.
:param handler: function for handling uri requests. Accepts function, :param handler: function for handling uri requests. Accepts function,
@ -206,13 +259,16 @@ class Blueprint:
:param event: Event to listen to. :param event: Event to listen to.
""" """
def decorator(listener): def decorator(listener):
self.listeners[event].append(listener) self.listeners[event].append(listener)
return listener return listener
return decorator return decorator
def middleware(self, *args, **kwargs): def middleware(self, *args, **kwargs):
"""Create a blueprint middleware from a decorated function.""" """Create a blueprint middleware from a decorated function."""
def register_middleware(_middleware): def register_middleware(_middleware):
future_middleware = FutureMiddleware(_middleware, args, kwargs) future_middleware = FutureMiddleware(_middleware, args, kwargs)
self.middlewares.append(future_middleware) self.middlewares.append(future_middleware)
@ -228,10 +284,12 @@ class Blueprint:
def exception(self, *args, **kwargs): def exception(self, *args, **kwargs):
"""Create a blueprint exception from a decorated function.""" """Create a blueprint exception from a decorated function."""
def decorator(handler): def decorator(handler):
exception = FutureException(handler, args, kwargs) exception = FutureException(handler, args, kwargs)
self.exceptions.append(exception) self.exceptions.append(exception)
return handler return handler
return decorator return decorator
def static(self, uri, file_or_directory, *args, **kwargs): def static(self, uri, file_or_directory, *args, **kwargs):
@ -240,12 +298,12 @@ class Blueprint:
:param uri: endpoint at which the route will be accessible. :param uri: endpoint at which the route will be accessible.
:param file_or_directory: Static asset. :param file_or_directory: Static asset.
""" """
name = kwargs.pop('name', 'static') name = kwargs.pop("name", "static")
if not name.startswith(self.name + '.'): if not name.startswith(self.name + "."):
name = '{}.{}'.format(self.name, name) name = "{}.{}".format(self.name, name)
kwargs.update(name=name) kwargs.update(name=name)
strict_slashes = kwargs.get('strict_slashes') strict_slashes = kwargs.get("strict_slashes")
if strict_slashes is None and self.strict_slashes is not None: if strict_slashes is None and self.strict_slashes is not None:
kwargs.update(strict_slashes=self.strict_slashes) kwargs.update(strict_slashes=self.strict_slashes)
@ -253,44 +311,107 @@ class Blueprint:
self.statics.append(static) self.statics.append(static)
# Shorthand method decorators # Shorthand method decorators
def get(self, uri, host=None, strict_slashes=None, version=None, def get(
name=None): self, uri, host=None, strict_slashes=None, version=None, name=None
return self.route(uri, methods=["GET"], host=host, ):
strict_slashes=strict_slashes, version=version, return self.route(
name=name) uri,
methods=["GET"],
host=host,
strict_slashes=strict_slashes,
version=version,
name=name,
)
def post(self, uri, host=None, strict_slashes=None, stream=False, def post(
version=None, name=None): self,
return self.route(uri, methods=["POST"], host=host, uri,
strict_slashes=strict_slashes, stream=stream, host=None,
version=version, name=name) strict_slashes=None,
stream=False,
version=None,
name=None,
):
return self.route(
uri,
methods=["POST"],
host=host,
strict_slashes=strict_slashes,
stream=stream,
version=version,
name=name,
)
def put(self, uri, host=None, strict_slashes=None, stream=False, def put(
version=None, name=None): self,
return self.route(uri, methods=["PUT"], host=host, uri,
strict_slashes=strict_slashes, stream=stream, host=None,
version=version, name=name) strict_slashes=None,
stream=False,
version=None,
name=None,
):
return self.route(
uri,
methods=["PUT"],
host=host,
strict_slashes=strict_slashes,
stream=stream,
version=version,
name=name,
)
def head(self, uri, host=None, strict_slashes=None, version=None, def head(
name=None): self, uri, host=None, strict_slashes=None, version=None, name=None
return self.route(uri, methods=["HEAD"], host=host, ):
strict_slashes=strict_slashes, version=version, return self.route(
name=name) uri,
methods=["HEAD"],
host=host,
strict_slashes=strict_slashes,
version=version,
name=name,
)
def options(self, uri, host=None, strict_slashes=None, version=None, def options(
name=None): self, uri, host=None, strict_slashes=None, version=None, name=None
return self.route(uri, methods=["OPTIONS"], host=host, ):
strict_slashes=strict_slashes, version=version, return self.route(
name=name) uri,
methods=["OPTIONS"],
host=host,
strict_slashes=strict_slashes,
version=version,
name=name,
)
def patch(self, uri, host=None, strict_slashes=None, stream=False, def patch(
version=None, name=None): self,
return self.route(uri, methods=["PATCH"], host=host, uri,
strict_slashes=strict_slashes, stream=stream, host=None,
version=version, name=name) strict_slashes=None,
stream=False,
version=None,
name=None,
):
return self.route(
uri,
methods=["PATCH"],
host=host,
strict_slashes=strict_slashes,
stream=stream,
version=version,
name=name,
)
def delete(self, uri, host=None, strict_slashes=None, version=None, def delete(
name=None): self, uri, host=None, strict_slashes=None, version=None, name=None
return self.route(uri, methods=["DELETE"], host=host, ):
strict_slashes=strict_slashes, version=version, return self.route(
name=name) uri,
methods=["DELETE"],
host=host,
strict_slashes=strict_slashes,
version=version,
name=name,
)

View File

@ -4,7 +4,7 @@ import types
from sanic.exceptions import PyFileError from sanic.exceptions import PyFileError
SANIC_PREFIX = 'SANIC_' SANIC_PREFIX = "SANIC_"
class Config(dict): class Config(dict):
@ -65,9 +65,10 @@ class Config(dict):
""" """
config_file = os.environ.get(variable_name) config_file = os.environ.get(variable_name)
if not config_file: if not config_file:
raise RuntimeError('The environment variable %r is not set and ' raise RuntimeError(
'thus configuration could not be loaded.' % "The environment variable %r is not set and "
variable_name) "thus configuration could not be loaded." % variable_name
)
return self.from_pyfile(config_file) return self.from_pyfile(config_file)
def from_pyfile(self, filename): def from_pyfile(self, filename):
@ -76,14 +77,16 @@ class Config(dict):
:param filename: an absolute path to the config file :param filename: an absolute path to the config file
""" """
module = types.ModuleType('config') module = types.ModuleType("config")
module.__file__ = filename module.__file__ = filename
try: try:
with open(filename) as config_file: with open(filename) as config_file:
exec(compile(config_file.read(), filename, 'exec'), exec(
module.__dict__) compile(config_file.read(), filename, "exec"),
module.__dict__,
)
except IOError as e: except IOError as e:
e.strerror = 'Unable to load configuration file (%s)' % e.strerror e.strerror = "Unable to load configuration file (%s)" % e.strerror
raise raise
except Exception as e: except Exception as e:
raise PyFileError(filename) from e raise PyFileError(filename) from e

View File

@ -1 +1 @@
HTTP_METHODS = ('GET', 'POST', 'PUT', 'HEAD', 'OPTIONS', 'PATCH', 'DELETE') HTTP_METHODS = ("GET", "POST", "PUT", "HEAD", "OPTIONS", "PATCH", "DELETE")

View File

@ -8,14 +8,12 @@ import string
# Straight up copied this section of dark magic from SimpleCookie # Straight up copied this section of dark magic from SimpleCookie
_LegalChars = string.ascii_letters + string.digits + "!#$%&'*+-.^_`|~:" _LegalChars = string.ascii_letters + string.digits + "!#$%&'*+-.^_`|~:"
_UnescapedChars = _LegalChars + ' ()/<=>?@[]{}' _UnescapedChars = _LegalChars + " ()/<=>?@[]{}"
_Translator = {n: '\\%03o' % n _Translator = {
for n in set(range(256)) - set(map(ord, _UnescapedChars))} n: "\\%03o" % n for n in set(range(256)) - set(map(ord, _UnescapedChars))
_Translator.update({ }
ord('"'): '\\"', _Translator.update({ord('"'): '\\"', ord("\\"): "\\\\"})
ord('\\'): '\\\\',
})
def _quote(str): def _quote(str):
@ -30,7 +28,7 @@ def _quote(str):
return '"' + str.translate(_Translator) + '"' return '"' + str.translate(_Translator) + '"'
_is_legal_key = re.compile('[%s]+' % re.escape(_LegalChars)).fullmatch _is_legal_key = re.compile("[%s]+" % re.escape(_LegalChars)).fullmatch
# ------------------------------------------------------------ # # ------------------------------------------------------------ #
# Custom SimpleCookie # Custom SimpleCookie
@ -53,7 +51,7 @@ class CookieJar(dict):
# If this cookie doesn't exist, add it to the header keys # If this cookie doesn't exist, add it to the header keys
if not self.cookie_headers.get(key): if not self.cookie_headers.get(key):
cookie = Cookie(key, value) cookie = Cookie(key, value)
cookie['path'] = '/' cookie["path"] = "/"
self.cookie_headers[key] = self.header_key self.cookie_headers[key] = self.header_key
self.headers.add(self.header_key, cookie) self.headers.add(self.header_key, cookie)
return super().__setitem__(key, cookie) return super().__setitem__(key, cookie)
@ -62,8 +60,8 @@ class CookieJar(dict):
def __delitem__(self, key): def __delitem__(self, key):
if key not in self.cookie_headers: if key not in self.cookie_headers:
self[key] = '' self[key] = ""
self[key]['max-age'] = 0 self[key]["max-age"] = 0
else: else:
cookie_header = self.cookie_headers[key] cookie_header = self.cookie_headers[key]
# remove it from header # remove it from header
@ -77,6 +75,7 @@ class CookieJar(dict):
class Cookie(dict): class Cookie(dict):
"""A stripped down version of Morsel from SimpleCookie #gottagofast""" """A stripped down version of Morsel from SimpleCookie #gottagofast"""
_keys = { _keys = {
"expires": "expires", "expires": "expires",
"path": "Path", "path": "Path",
@ -88,7 +87,7 @@ class Cookie(dict):
"version": "Version", "version": "Version",
"samesite": "SameSite", "samesite": "SameSite",
} }
_flags = {'secure', 'httponly'} _flags = {"secure", "httponly"}
def __init__(self, key, value): def __init__(self, key, value):
if key in self._keys: if key in self._keys:
@ -106,24 +105,27 @@ class Cookie(dict):
return super().__setitem__(key, value) return super().__setitem__(key, value)
def encode(self, encoding): def encode(self, encoding):
output = ['%s=%s' % (self.key, _quote(self.value))] output = ["%s=%s" % (self.key, _quote(self.value))]
for key, value in self.items(): for key, value in self.items():
if key == 'max-age': if key == "max-age":
try: try:
output.append('%s=%d' % (self._keys[key], value)) output.append("%s=%d" % (self._keys[key], value))
except TypeError: except TypeError:
output.append('%s=%s' % (self._keys[key], value)) output.append("%s=%s" % (self._keys[key], value))
elif key == 'expires': elif key == "expires":
try: try:
output.append('%s=%s' % ( output.append(
self._keys[key], "%s=%s"
value.strftime("%a, %d-%b-%Y %T GMT") % (
)) self._keys[key],
value.strftime("%a, %d-%b-%Y %T GMT"),
)
)
except AttributeError: except AttributeError:
output.append('%s=%s' % (self._keys[key], value)) output.append("%s=%s" % (self._keys[key], value))
elif key in self._flags and self[key]: elif key in self._flags and self[key]:
output.append(self._keys[key]) output.append(self._keys[key])
else: else:
output.append('%s=%s' % (self._keys[key], value)) output.append("%s=%s" % (self._keys[key], value))
return "; ".join(output).encode(encoding) return "; ".join(output).encode(encoding)

View File

@ -1,6 +1,6 @@
from sanic.helpers import STATUS_CODES from sanic.helpers import STATUS_CODES
TRACEBACK_STYLE = ''' TRACEBACK_STYLE = """
<style> <style>
body { body {
padding: 20px; padding: 20px;
@ -61,9 +61,9 @@ TRACEBACK_STYLE = '''
font-size: 14px; font-size: 14px;
} }
</style> </style>
''' """
TRACEBACK_WRAPPER_HTML = ''' TRACEBACK_WRAPPER_HTML = """
<html> <html>
<head> <head>
{style} {style}
@ -78,27 +78,27 @@ TRACEBACK_WRAPPER_HTML = '''
</div> </div>
</body> </body>
</html> </html>
''' """
TRACEBACK_WRAPPER_INNER_HTML = ''' TRACEBACK_WRAPPER_INNER_HTML = """
<h1>{exc_name}</h1> <h1>{exc_name}</h1>
<h3><code>{exc_value}</code></h3> <h3><code>{exc_value}</code></h3>
<div class="tb-wrapper"> <div class="tb-wrapper">
<p class="tb-header">Traceback (most recent call last):</p> <p class="tb-header">Traceback (most recent call last):</p>
{frame_html} {frame_html}
</div> </div>
''' """
TRACEBACK_BORDER = ''' TRACEBACK_BORDER = """
<div class="tb-border"> <div class="tb-border">
<b><i> <b><i>
The above exception was the direct cause of the The above exception was the direct cause of the
following exception: following exception:
</i></b> </i></b>
</div> </div>
''' """
TRACEBACK_LINE_HTML = ''' TRACEBACK_LINE_HTML = """
<div class="frame-line"> <div class="frame-line">
<p class="frame-descriptor"> <p class="frame-descriptor">
File {0.filename}, line <i>{0.lineno}</i>, File {0.filename}, line <i>{0.lineno}</i>,
@ -106,15 +106,15 @@ TRACEBACK_LINE_HTML = '''
</p> </p>
<p class="frame-code"><code>{0.line}</code></p> <p class="frame-code"><code>{0.line}</code></p>
</div> </div>
''' """
INTERNAL_SERVER_ERROR_HTML = ''' INTERNAL_SERVER_ERROR_HTML = """
<h1>Internal Server Error</h1> <h1>Internal Server Error</h1>
<p> <p>
The server encountered an internal error and cannot complete The server encountered an internal error and cannot complete
your request. your request.
</p> </p>
''' """
_sanic_exceptions = {} _sanic_exceptions = {}
@ -124,15 +124,16 @@ def add_status_code(code):
""" """
Decorator used for adding exceptions to _sanic_exceptions. Decorator used for adding exceptions to _sanic_exceptions.
""" """
def class_decorator(cls): def class_decorator(cls):
cls.status_code = code cls.status_code = code
_sanic_exceptions[code] = cls _sanic_exceptions[code] = cls
return cls return cls
return class_decorator return class_decorator
class SanicException(Exception): class SanicException(Exception):
def __init__(self, message, status_code=None): def __init__(self, message, status_code=None):
super().__init__(message) super().__init__(message)
@ -156,8 +157,8 @@ class MethodNotSupported(SanicException):
super().__init__(message) super().__init__(message)
self.headers = dict() self.headers = dict()
self.headers["Allow"] = ", ".join(allowed_methods) self.headers["Allow"] = ", ".join(allowed_methods)
if method in ['HEAD', 'PATCH', 'PUT', 'DELETE']: if method in ["HEAD", "PATCH", "PUT", "DELETE"]:
self.headers['Content-Length'] = 0 self.headers["Content-Length"] = 0
@add_status_code(500) @add_status_code(500)
@ -169,6 +170,7 @@ class ServerError(SanicException):
class ServiceUnavailable(SanicException): class ServiceUnavailable(SanicException):
"""The server is currently unavailable (because it is overloaded or """The server is currently unavailable (because it is overloaded or
down for maintenance). Generally, this is a temporary state.""" down for maintenance). Generally, this is a temporary state."""
pass pass
@ -192,6 +194,7 @@ class RequestTimeout(SanicException):
the connection. The socket connection has actually been lost - the Web the connection. The socket connection has actually been lost - the Web
server has 'timed out' on that particular socket connection. server has 'timed out' on that particular socket connection.
""" """
pass pass
@ -209,8 +212,8 @@ class ContentRangeError(SanicException):
def __init__(self, message, content_range): def __init__(self, message, content_range):
super().__init__(message) super().__init__(message)
self.headers = { self.headers = {
'Content-Type': 'text/plain', "Content-Type": "text/plain",
"Content-Range": "bytes */%s" % (content_range.total,) "Content-Range": "bytes */%s" % (content_range.total,),
} }
@ -225,7 +228,7 @@ class InvalidRangeType(ContentRangeError):
class PyFileError(Exception): class PyFileError(Exception):
def __init__(self, file): def __init__(self, file):
super().__init__('could not execute config file %s', file) super().__init__("could not execute config file %s", file)
@add_status_code(401) @add_status_code(401)
@ -263,13 +266,14 @@ class Unauthorized(SanicException):
scheme="Bearer", scheme="Bearer",
realm="Restricted Area") realm="Restricted Area")
""" """
def __init__(self, message, status_code=None, scheme=None, **kwargs): def __init__(self, message, status_code=None, scheme=None, **kwargs):
super().__init__(message, status_code) super().__init__(message, status_code)
# if auth-scheme is specified, set "WWW-Authenticate" header # if auth-scheme is specified, set "WWW-Authenticate" header
if scheme is not None: if scheme is not None:
values = ['{!s}="{!s}"'.format(k, v) for k, v in kwargs.items()] values = ['{!s}="{!s}"'.format(k, v) for k, v in kwargs.items()]
challenge = ', '.join(values) challenge = ", ".join(values)
self.headers = { self.headers = {
"WWW-Authenticate": "{} {}".format(scheme, challenge).rstrip() "WWW-Authenticate": "{} {}".format(scheme, challenge).rstrip()
@ -288,6 +292,6 @@ def abort(status_code, message=None):
if message is None: if message is None:
message = STATUS_CODES.get(status_code) message = STATUS_CODES.get(status_code)
# These are stored as bytes in the STATUS_CODES dict # These are stored as bytes in the STATUS_CODES dict
message = message.decode('utf8') message = message.decode("utf8")
sanic_exception = _sanic_exceptions.get(status_code, SanicException) sanic_exception = _sanic_exceptions.get(status_code, SanicException)
raise sanic_exception(message=message, status_code=status_code) raise sanic_exception(message=message, status_code=status_code)

View File

@ -11,7 +11,8 @@ from sanic.exceptions import (
TRACEBACK_STYLE, TRACEBACK_STYLE,
TRACEBACK_WRAPPER_HTML, TRACEBACK_WRAPPER_HTML,
TRACEBACK_WRAPPER_INNER_HTML, TRACEBACK_WRAPPER_INNER_HTML,
TRACEBACK_BORDER) TRACEBACK_BORDER,
)
from sanic.log import logger from sanic.log import logger
from sanic.response import text, html from sanic.response import text, html
@ -36,7 +37,8 @@ class ErrorHandler:
return TRACEBACK_WRAPPER_INNER_HTML.format( return TRACEBACK_WRAPPER_INNER_HTML.format(
exc_name=exception.__class__.__name__, exc_name=exception.__class__.__name__,
exc_value=exception, exc_value=exception,
frame_html=''.join(frame_html)) frame_html="".join(frame_html),
)
def _render_traceback_html(self, exception, request): def _render_traceback_html(self, exception, request):
exc_type, exc_value, tb = sys.exc_info() exc_type, exc_value, tb = sys.exc_info()
@ -51,7 +53,8 @@ class ErrorHandler:
exc_name=exception.__class__.__name__, exc_name=exception.__class__.__name__,
exc_value=exception, exc_value=exception,
inner_html=TRACEBACK_BORDER.join(reversed(exceptions)), inner_html=TRACEBACK_BORDER.join(reversed(exceptions)),
path=request.path) path=request.path,
)
def add(self, exception, handler): def add(self, exception, handler):
self.handlers.append((exception, handler)) self.handlers.append((exception, handler))
@ -88,17 +91,18 @@ class ErrorHandler:
url = repr(request.url) url = repr(request.url)
except AttributeError: except AttributeError:
url = "unknown" url = "unknown"
response_message = ('Exception raised in exception handler ' response_message = (
'"%s" for uri: %s') "Exception raised in exception handler " '"%s" for uri: %s'
)
logger.exception(response_message, handler.__name__, url) logger.exception(response_message, handler.__name__, url)
if self.debug: if self.debug:
return text(response_message % (handler.__name__, url), 500) return text(response_message % (handler.__name__, url), 500)
else: else:
return text('An error occurred while handling an error', 500) return text("An error occurred while handling an error", 500)
return response return response
def log(self, message, level='error'): def log(self, message, level="error"):
""" """
Deprecated, do not use. Deprecated, do not use.
""" """
@ -110,14 +114,14 @@ class ErrorHandler:
except AttributeError: except AttributeError:
url = "unknown" url = "unknown"
response_message = ('Exception occurred while handling uri: %s') response_message = "Exception occurred while handling uri: %s"
logger.exception(response_message, url) logger.exception(response_message, url)
if issubclass(type(exception), SanicException): if issubclass(type(exception), SanicException):
return text( return text(
'Error: {}'.format(exception), "Error: {}".format(exception),
status=getattr(exception, 'status_code', 500), status=getattr(exception, "status_code", 500),
headers=getattr(exception, 'headers', dict()) headers=getattr(exception, "headers", dict()),
) )
elif self.debug: elif self.debug:
html_output = self._render_traceback_html(exception, request) html_output = self._render_traceback_html(exception, request)
@ -129,32 +133,37 @@ class ErrorHandler:
class ContentRangeHandler: class ContentRangeHandler:
"""Class responsible for parsing request header""" """Class responsible for parsing request header"""
__slots__ = ('start', 'end', 'size', 'total', 'headers')
__slots__ = ("start", "end", "size", "total", "headers")
def __init__(self, request, stats): def __init__(self, request, stats):
self.total = stats.st_size self.total = stats.st_size
_range = request.headers.get('Range') _range = request.headers.get("Range")
if _range is None: if _range is None:
raise HeaderNotFound('Range Header Not Found') raise HeaderNotFound("Range Header Not Found")
unit, _, value = tuple(map(str.strip, _range.partition('='))) unit, _, value = tuple(map(str.strip, _range.partition("=")))
if unit != 'bytes': if unit != "bytes":
raise InvalidRangeType( raise InvalidRangeType(
'%s is not a valid Range Type' % (unit,), self) "%s is not a valid Range Type" % (unit,), self
start_b, _, end_b = tuple(map(str.strip, value.partition('-'))) )
start_b, _, end_b = tuple(map(str.strip, value.partition("-")))
try: try:
self.start = int(start_b) if start_b else None self.start = int(start_b) if start_b else None
except ValueError: except ValueError:
raise ContentRangeError( raise ContentRangeError(
'\'%s\' is invalid for Content Range' % (start_b,), self) "'%s' is invalid for Content Range" % (start_b,), self
)
try: try:
self.end = int(end_b) if end_b else None self.end = int(end_b) if end_b else None
except ValueError: except ValueError:
raise ContentRangeError( raise ContentRangeError(
'\'%s\' is invalid for Content Range' % (end_b,), self) "'%s' is invalid for Content Range" % (end_b,), self
)
if self.end is None: if self.end is None:
if self.start is None: if self.start is None:
raise ContentRangeError( raise ContentRangeError(
'Invalid for Content Range parameters', self) "Invalid for Content Range parameters", self
)
else: else:
# this case represents `Content-Range: bytes 5-` # this case represents `Content-Range: bytes 5-`
self.end = self.total self.end = self.total
@ -165,11 +174,13 @@ class ContentRangeHandler:
self.end = self.total self.end = self.total
if self.start >= self.end: if self.start >= self.end:
raise ContentRangeError( raise ContentRangeError(
'Invalid for Content Range parameters', self) "Invalid for Content Range parameters", self
)
self.size = self.end - self.start self.size = self.end - self.start
self.headers = { self.headers = {
'Content-Range': "bytes %s-%s/%s" % ( "Content-Range": "bytes %s-%s/%s"
self.start, self.end, self.total)} % (self.start, self.end, self.total)
}
def __bool__(self): def __bool__(self):
return self.size > 0 return self.size > 0

View File

@ -1,93 +1,97 @@
"""Defines basics of HTTP standard.""" """Defines basics of HTTP standard."""
STATUS_CODES = { STATUS_CODES = {
100: b'Continue', 100: b"Continue",
101: b'Switching Protocols', 101: b"Switching Protocols",
102: b'Processing', 102: b"Processing",
200: b'OK', 200: b"OK",
201: b'Created', 201: b"Created",
202: b'Accepted', 202: b"Accepted",
203: b'Non-Authoritative Information', 203: b"Non-Authoritative Information",
204: b'No Content', 204: b"No Content",
205: b'Reset Content', 205: b"Reset Content",
206: b'Partial Content', 206: b"Partial Content",
207: b'Multi-Status', 207: b"Multi-Status",
208: b'Already Reported', 208: b"Already Reported",
226: b'IM Used', 226: b"IM Used",
300: b'Multiple Choices', 300: b"Multiple Choices",
301: b'Moved Permanently', 301: b"Moved Permanently",
302: b'Found', 302: b"Found",
303: b'See Other', 303: b"See Other",
304: b'Not Modified', 304: b"Not Modified",
305: b'Use Proxy', 305: b"Use Proxy",
307: b'Temporary Redirect', 307: b"Temporary Redirect",
308: b'Permanent Redirect', 308: b"Permanent Redirect",
400: b'Bad Request', 400: b"Bad Request",
401: b'Unauthorized', 401: b"Unauthorized",
402: b'Payment Required', 402: b"Payment Required",
403: b'Forbidden', 403: b"Forbidden",
404: b'Not Found', 404: b"Not Found",
405: b'Method Not Allowed', 405: b"Method Not Allowed",
406: b'Not Acceptable', 406: b"Not Acceptable",
407: b'Proxy Authentication Required', 407: b"Proxy Authentication Required",
408: b'Request Timeout', 408: b"Request Timeout",
409: b'Conflict', 409: b"Conflict",
410: b'Gone', 410: b"Gone",
411: b'Length Required', 411: b"Length Required",
412: b'Precondition Failed', 412: b"Precondition Failed",
413: b'Request Entity Too Large', 413: b"Request Entity Too Large",
414: b'Request-URI Too Long', 414: b"Request-URI Too Long",
415: b'Unsupported Media Type', 415: b"Unsupported Media Type",
416: b'Requested Range Not Satisfiable', 416: b"Requested Range Not Satisfiable",
417: b'Expectation Failed', 417: b"Expectation Failed",
418: b'I\'m a teapot', 418: b"I'm a teapot",
422: b'Unprocessable Entity', 422: b"Unprocessable Entity",
423: b'Locked', 423: b"Locked",
424: b'Failed Dependency', 424: b"Failed Dependency",
426: b'Upgrade Required', 426: b"Upgrade Required",
428: b'Precondition Required', 428: b"Precondition Required",
429: b'Too Many Requests', 429: b"Too Many Requests",
431: b'Request Header Fields Too Large', 431: b"Request Header Fields Too Large",
451: b'Unavailable For Legal Reasons', 451: b"Unavailable For Legal Reasons",
500: b'Internal Server Error', 500: b"Internal Server Error",
501: b'Not Implemented', 501: b"Not Implemented",
502: b'Bad Gateway', 502: b"Bad Gateway",
503: b'Service Unavailable', 503: b"Service Unavailable",
504: b'Gateway Timeout', 504: b"Gateway Timeout",
505: b'HTTP Version Not Supported', 505: b"HTTP Version Not Supported",
506: b'Variant Also Negotiates', 506: b"Variant Also Negotiates",
507: b'Insufficient Storage', 507: b"Insufficient Storage",
508: b'Loop Detected', 508: b"Loop Detected",
510: b'Not Extended', 510: b"Not Extended",
511: b'Network Authentication Required' 511: b"Network Authentication Required",
} }
# According to https://tools.ietf.org/html/rfc2616#section-7.1 # According to https://tools.ietf.org/html/rfc2616#section-7.1
_ENTITY_HEADERS = frozenset([ _ENTITY_HEADERS = frozenset(
'allow', [
'content-encoding', "allow",
'content-language', "content-encoding",
'content-length', "content-language",
'content-location', "content-length",
'content-md5', "content-location",
'content-range', "content-md5",
'content-type', "content-range",
'expires', "content-type",
'last-modified', "expires",
'extension-header' "last-modified",
]) "extension-header",
]
)
# According to https://tools.ietf.org/html/rfc2616#section-13.5.1 # According to https://tools.ietf.org/html/rfc2616#section-13.5.1
_HOP_BY_HOP_HEADERS = frozenset([ _HOP_BY_HOP_HEADERS = frozenset(
'connection', [
'keep-alive', "connection",
'proxy-authenticate', "keep-alive",
'proxy-authorization', "proxy-authenticate",
'te', "proxy-authorization",
'trailers', "te",
'transfer-encoding', "trailers",
'upgrade' "transfer-encoding",
]) "upgrade",
]
)
def has_message_body(status): def has_message_body(status):
@ -110,8 +114,7 @@ def is_hop_by_hop_header(header):
return header.lower() in _HOP_BY_HOP_HEADERS return header.lower() in _HOP_BY_HOP_HEADERS
def remove_entity_headers(headers, def remove_entity_headers(headers, allowed=("content-location", "expires")):
allowed=('content-location', 'expires')):
""" """
Removes all the entity headers present in the headers given. Removes all the entity headers present in the headers given.
According to RFC 2616 Section 10.3.5, According to RFC 2616 Section 10.3.5,
@ -122,7 +125,9 @@ def remove_entity_headers(headers,
returns the headers without the entity headers returns the headers without the entity headers
""" """
allowed = set([h.lower() for h in allowed]) allowed = set([h.lower() for h in allowed])
headers = {header: value for header, value in headers.items() headers = {
if not is_entity_header(header) header: value
and header.lower() not in allowed} for header, value in headers.items()
if not is_entity_header(header) and header.lower() not in allowed
}
return headers return headers

View File

@ -5,59 +5,54 @@ import sys
LOGGING_CONFIG_DEFAULTS = dict( LOGGING_CONFIG_DEFAULTS = dict(
version=1, version=1,
disable_existing_loggers=False, disable_existing_loggers=False,
loggers={ loggers={
"root": { "root": {"level": "INFO", "handlers": ["console"]},
"level": "INFO",
"handlers": ["console"]
},
"sanic.error": { "sanic.error": {
"level": "INFO", "level": "INFO",
"handlers": ["error_console"], "handlers": ["error_console"],
"propagate": True, "propagate": True,
"qualname": "sanic.error" "qualname": "sanic.error",
}, },
"sanic.access": { "sanic.access": {
"level": "INFO", "level": "INFO",
"handlers": ["access_console"], "handlers": ["access_console"],
"propagate": True, "propagate": True,
"qualname": "sanic.access" "qualname": "sanic.access",
} },
}, },
handlers={ handlers={
"console": { "console": {
"class": "logging.StreamHandler", "class": "logging.StreamHandler",
"formatter": "generic", "formatter": "generic",
"stream": sys.stdout "stream": sys.stdout,
}, },
"error_console": { "error_console": {
"class": "logging.StreamHandler", "class": "logging.StreamHandler",
"formatter": "generic", "formatter": "generic",
"stream": sys.stderr "stream": sys.stderr,
}, },
"access_console": { "access_console": {
"class": "logging.StreamHandler", "class": "logging.StreamHandler",
"formatter": "access", "formatter": "access",
"stream": sys.stdout "stream": sys.stdout,
}, },
}, },
formatters={ formatters={
"generic": { "generic": {
"format": "%(asctime)s [%(process)d] [%(levelname)s] %(message)s", "format": "%(asctime)s [%(process)d] [%(levelname)s] %(message)s",
"datefmt": "[%Y-%m-%d %H:%M:%S %z]", "datefmt": "[%Y-%m-%d %H:%M:%S %z]",
"class": "logging.Formatter" "class": "logging.Formatter",
}, },
"access": { "access": {
"format": "%(asctime)s - (%(name)s)[%(levelname)s][%(host)s]: " + "format": "%(asctime)s - (%(name)s)[%(levelname)s][%(host)s]: "
"%(request)s %(message)s %(status)d %(byte)d", + "%(request)s %(message)s %(status)d %(byte)d",
"datefmt": "[%Y-%m-%d %H:%M:%S %z]", "datefmt": "[%Y-%m-%d %H:%M:%S %z]",
"class": "logging.Formatter" "class": "logging.Formatter",
}, },
} },
) )
logger = logging.getLogger('sanic.root') logger = logging.getLogger("sanic.root")
error_logger = logging.getLogger('sanic.error') error_logger = logging.getLogger("sanic.error")
access_logger = logging.getLogger('sanic.access') access_logger = logging.getLogger("sanic.access")

View File

@ -18,7 +18,7 @@ def _iter_module_files():
for module in list(sys.modules.values()): for module in list(sys.modules.values()):
if module is None: if module is None:
continue continue
filename = getattr(module, '__file__', None) filename = getattr(module, "__file__", None)
if filename: if filename:
old = None old = None
while not os.path.isfile(filename): while not os.path.isfile(filename):
@ -27,7 +27,7 @@ def _iter_module_files():
if filename == old: if filename == old:
break break
else: else:
if filename[-4:] in ('.pyc', '.pyo'): if filename[-4:] in (".pyc", ".pyo"):
filename = filename[:-1] filename = filename[:-1]
yield filename yield filename
@ -45,11 +45,13 @@ def restart_with_reloader():
""" """
args = _get_args_for_reloading() args = _get_args_for_reloading()
new_environ = os.environ.copy() new_environ = os.environ.copy()
new_environ['SANIC_SERVER_RUNNING'] = 'true' new_environ["SANIC_SERVER_RUNNING"] = "true"
cmd = ' '.join(args) cmd = " ".join(args)
worker_process = Process( worker_process = Process(
target=subprocess.call, args=(cmd,), target=subprocess.call,
kwargs=dict(shell=True, env=new_environ)) args=(cmd,),
kwargs=dict(shell=True, env=new_environ),
)
worker_process.start() worker_process.start()
return worker_process return worker_process
@ -67,8 +69,10 @@ def kill_process_children_unix(pid):
children_list_pid = children_list_file.read().split() children_list_pid = children_list_file.read().split()
for child_pid in children_list_pid: for child_pid in children_list_pid:
children_proc_path = "/proc/%s/task/%s/children" % \ children_proc_path = "/proc/%s/task/%s/children" % (
(child_pid, child_pid) child_pid,
child_pid,
)
if not os.path.isfile(children_proc_path): if not os.path.isfile(children_proc_path):
continue continue
with open(children_proc_path) as children_list_file_2: with open(children_proc_path) as children_list_file_2:
@ -90,7 +94,7 @@ def kill_process_children_osx(pid):
:param pid: PID of parent process (process ID) :param pid: PID of parent process (process ID)
:return: Nothing :return: Nothing
""" """
subprocess.run(['pkill', '-P', str(pid)]) subprocess.run(["pkill", "-P", str(pid)])
def kill_process_children(pid): def kill_process_children(pid):
@ -99,12 +103,12 @@ def kill_process_children(pid):
:param pid: PID of parent process (process ID) :param pid: PID of parent process (process ID)
:return: Nothing :return: Nothing
""" """
if sys.platform == 'darwin': if sys.platform == "darwin":
kill_process_children_osx(pid) kill_process_children_osx(pid)
elif sys.platform == 'linux': elif sys.platform == "linux":
kill_process_children_unix(pid) kill_process_children_unix(pid)
else: else:
pass # should signal error here pass # should signal error here
def kill_program_completly(proc): def kill_program_completly(proc):
@ -127,9 +131,11 @@ def watchdog(sleep_interval):
mtimes = {} mtimes = {}
worker_process = restart_with_reloader() worker_process = restart_with_reloader()
signal.signal( signal.signal(
signal.SIGTERM, lambda *args: kill_program_completly(worker_process)) signal.SIGTERM, lambda *args: kill_program_completly(worker_process)
)
signal.signal( signal.signal(
signal.SIGINT, lambda *args: kill_program_completly(worker_process)) signal.SIGINT, lambda *args: kill_program_completly(worker_process)
)
while True: while True:
for filename in _iter_module_files(): for filename in _iter_module_files():
try: try:

View File

@ -10,9 +10,11 @@ try:
from ujson import loads as json_loads from ujson import loads as json_loads
except ImportError: except ImportError:
if sys.version_info[:2] == (3, 5): if sys.version_info[:2] == (3, 5):
def json_loads(data): def json_loads(data):
# on Python 3.5 json.loads only supports str not bytes # on Python 3.5 json.loads only supports str not bytes
return json.loads(data.decode()) return json.loads(data.decode())
else: else:
json_loads = json.loads json_loads = json.loads
@ -43,11 +45,28 @@ class RequestParameters(dict):
class Request(dict): class Request(dict):
"""Properties of an HTTP request such as URL, headers, etc.""" """Properties of an HTTP request such as URL, headers, etc."""
__slots__ = ( __slots__ = (
'app', 'headers', 'version', 'method', '_cookies', 'transport', "app",
'body', 'parsed_json', 'parsed_args', 'parsed_form', 'parsed_files', "headers",
'_ip', '_parsed_url', 'uri_template', 'stream', '_remote_addr', "version",
'_socket', '_port', '__weakref__', 'raw_url' "method",
"_cookies",
"transport",
"body",
"parsed_json",
"parsed_args",
"parsed_form",
"parsed_files",
"_ip",
"_parsed_url",
"uri_template",
"stream",
"_remote_addr",
"_socket",
"_port",
"__weakref__",
"raw_url",
) )
def __init__(self, url_bytes, headers, version, method, transport): def __init__(self, url_bytes, headers, version, method, transport):
@ -73,10 +92,10 @@ class Request(dict):
def __repr__(self): def __repr__(self):
if self.method is None or not self.path: if self.method is None or not self.path:
return '<{0}>'.format(self.__class__.__name__) return "<{0}>".format(self.__class__.__name__)
return '<{0}: {1} {2}>'.format(self.__class__.__name__, return "<{0}: {1} {2}>".format(
self.method, self.__class__.__name__, self.method, self.path
self.path) )
def __bool__(self): def __bool__(self):
if self.transport: if self.transport:
@ -106,8 +125,8 @@ class Request(dict):
:return: token related to request :return: token related to request
""" """
prefixes = ('Bearer', 'Token') prefixes = ("Bearer", "Token")
auth_header = self.headers.get('Authorization') auth_header = self.headers.get("Authorization")
if auth_header is not None: if auth_header is not None:
for prefix in prefixes: for prefix in prefixes:
@ -122,17 +141,20 @@ class Request(dict):
self.parsed_form = RequestParameters() self.parsed_form = RequestParameters()
self.parsed_files = RequestParameters() self.parsed_files = RequestParameters()
content_type = self.headers.get( content_type = self.headers.get(
'Content-Type', DEFAULT_HTTP_CONTENT_TYPE) "Content-Type", DEFAULT_HTTP_CONTENT_TYPE
)
content_type, parameters = parse_header(content_type) content_type, parameters = parse_header(content_type)
try: try:
if content_type == 'application/x-www-form-urlencoded': if content_type == "application/x-www-form-urlencoded":
self.parsed_form = RequestParameters( self.parsed_form = RequestParameters(
parse_qs(self.body.decode('utf-8'))) parse_qs(self.body.decode("utf-8"))
elif content_type == 'multipart/form-data': )
elif content_type == "multipart/form-data":
# TODO: Stream this instead of reading to/from memory # TODO: Stream this instead of reading to/from memory
boundary = parameters['boundary'].encode('utf-8') boundary = parameters["boundary"].encode("utf-8")
self.parsed_form, self.parsed_files = ( self.parsed_form, self.parsed_files = parse_multipart_form(
parse_multipart_form(self.body, boundary)) self.body, boundary
)
except Exception: except Exception:
error_logger.exception("Failed when parsing form") error_logger.exception("Failed when parsing form")
@ -150,7 +172,8 @@ class Request(dict):
if self.parsed_args is None: if self.parsed_args is None:
if self.query_string: if self.query_string:
self.parsed_args = RequestParameters( self.parsed_args = RequestParameters(
parse_qs(self.query_string)) parse_qs(self.query_string)
)
else: else:
self.parsed_args = RequestParameters() self.parsed_args = RequestParameters()
return self.parsed_args return self.parsed_args
@ -162,37 +185,40 @@ class Request(dict):
@property @property
def cookies(self): def cookies(self):
if self._cookies is None: if self._cookies is None:
cookie = self.headers.get('Cookie') cookie = self.headers.get("Cookie")
if cookie is not None: if cookie is not None:
cookies = SimpleCookie() cookies = SimpleCookie()
cookies.load(cookie) cookies.load(cookie)
self._cookies = {name: cookie.value self._cookies = {
for name, cookie in cookies.items()} name: cookie.value for name, cookie in cookies.items()
}
else: else:
self._cookies = {} self._cookies = {}
return self._cookies return self._cookies
@property @property
def ip(self): def ip(self):
if not hasattr(self, '_socket'): if not hasattr(self, "_socket"):
self._get_address() self._get_address()
return self._ip return self._ip
@property @property
def port(self): def port(self):
if not hasattr(self, '_socket'): if not hasattr(self, "_socket"):
self._get_address() self._get_address()
return self._port return self._port
@property @property
def socket(self): def socket(self):
if not hasattr(self, '_socket'): if not hasattr(self, "_socket"):
self._get_address() self._get_address()
return self._socket return self._socket
def _get_address(self): def _get_address(self):
self._socket = self.transport.get_extra_info('peername') or \ self._socket = self.transport.get_extra_info("peername") or (
(None, None) None,
None,
)
self._ip = self._socket[0] self._ip = self._socket[0]
self._port = self._socket[1] self._port = self._socket[1]
@ -202,29 +228,31 @@ class Request(dict):
:return: original client ip. :return: original client ip.
""" """
if not hasattr(self, '_remote_addr'): if not hasattr(self, "_remote_addr"):
forwarded_for = self.headers.get('X-Forwarded-For', '').split(',') forwarded_for = self.headers.get("X-Forwarded-For", "").split(",")
remote_addrs = [ remote_addrs = [
addr for addr in [ addr
addr.strip() for addr in forwarded_for for addr in [addr.strip() for addr in forwarded_for]
] if addr if addr
] ]
if len(remote_addrs) > 0: if len(remote_addrs) > 0:
self._remote_addr = remote_addrs[0] self._remote_addr = remote_addrs[0]
else: else:
self._remote_addr = '' self._remote_addr = ""
return self._remote_addr return self._remote_addr
@property @property
def scheme(self): def scheme(self):
if self.app.websocket_enabled \ if (
and self.headers.get('upgrade') == 'websocket': self.app.websocket_enabled
scheme = 'ws' and self.headers.get("upgrade") == "websocket"
):
scheme = "ws"
else: else:
scheme = 'http' scheme = "http"
if self.transport.get_extra_info('sslcontext'): if self.transport.get_extra_info("sslcontext"):
scheme += 's' scheme += "s"
return scheme return scheme
@ -232,11 +260,11 @@ class Request(dict):
def host(self): def host(self):
# it appears that httptools doesn't return the host # it appears that httptools doesn't return the host
# so pull it from the headers # so pull it from the headers
return self.headers.get('Host', '') return self.headers.get("Host", "")
@property @property
def content_type(self): def content_type(self):
return self.headers.get('Content-Type', DEFAULT_HTTP_CONTENT_TYPE) return self.headers.get("Content-Type", DEFAULT_HTTP_CONTENT_TYPE)
@property @property
def match_info(self): def match_info(self):
@ -245,27 +273,23 @@ class Request(dict):
@property @property
def path(self): def path(self):
return self._parsed_url.path.decode('utf-8') return self._parsed_url.path.decode("utf-8")
@property @property
def query_string(self): def query_string(self):
if self._parsed_url.query: if self._parsed_url.query:
return self._parsed_url.query.decode('utf-8') return self._parsed_url.query.decode("utf-8")
else: else:
return '' return ""
@property @property
def url(self): def url(self):
return urlunparse(( return urlunparse(
self.scheme, (self.scheme, self.host, self.path, None, self.query_string, None)
self.host, )
self.path,
None,
self.query_string,
None))
File = namedtuple('File', ['type', 'body', 'name']) File = namedtuple("File", ["type", "body", "name"])
def parse_multipart_form(body, boundary): def parse_multipart_form(body, boundary):
@ -281,37 +305,38 @@ def parse_multipart_form(body, boundary):
form_parts = body.split(boundary) form_parts = body.split(boundary)
for form_part in form_parts[1:-1]: for form_part in form_parts[1:-1]:
file_name = None file_name = None
content_type = 'text/plain' content_type = "text/plain"
content_charset = 'utf-8' content_charset = "utf-8"
field_name = None field_name = None
line_index = 2 line_index = 2
line_end_index = 0 line_end_index = 0
while not line_end_index == -1: while not line_end_index == -1:
line_end_index = form_part.find(b'\r\n', line_index) line_end_index = form_part.find(b"\r\n", line_index)
form_line = form_part[line_index:line_end_index].decode('utf-8') form_line = form_part[line_index:line_end_index].decode("utf-8")
line_index = line_end_index + 2 line_index = line_end_index + 2
if not form_line: if not form_line:
break break
colon_index = form_line.index(':') colon_index = form_line.index(":")
form_header_field = form_line[0:colon_index].lower() form_header_field = form_line[0:colon_index].lower()
form_header_value, form_parameters = parse_header( form_header_value, form_parameters = parse_header(
form_line[colon_index + 2:]) form_line[colon_index + 2 :]
)
if form_header_field == 'content-disposition': if form_header_field == "content-disposition":
file_name = form_parameters.get('filename') file_name = form_parameters.get("filename")
field_name = form_parameters.get('name') field_name = form_parameters.get("name")
elif form_header_field == 'content-type': elif form_header_field == "content-type":
content_type = form_header_value content_type = form_header_value
content_charset = form_parameters.get('charset', 'utf-8') content_charset = form_parameters.get("charset", "utf-8")
if field_name: if field_name:
post_data = form_part[line_index:-4] post_data = form_part[line_index:-4]
if file_name: if file_name:
form_file = File(type=content_type, form_file = File(
name=file_name, type=content_type, name=file_name, body=post_data
body=post_data) )
if field_name in files: if field_name in files:
files[field_name].append(form_file) files[field_name].append(form_file)
else: else:
@ -323,7 +348,9 @@ def parse_multipart_form(body, boundary):
else: else:
fields[field_name] = [value] fields[field_name] = [value]
else: else:
logger.debug('Form-data field does not have a \'name\' parameter \ logger.debug(
in the Content-Disposition header') "Form-data field does not have a 'name' parameter \
in the Content-Disposition header"
)
return fields, files return fields, files

View File

@ -24,16 +24,18 @@ class BaseHTTPResponse:
return str(data).encode() return str(data).encode()
def _parse_headers(self): def _parse_headers(self):
headers = b'' headers = b""
for name, value in self.headers.items(): for name, value in self.headers.items():
try: try:
headers += ( headers += b"%b: %b\r\n" % (
b'%b: %b\r\n' % ( name.encode(),
name.encode(), value.encode('utf-8'))) value.encode("utf-8"),
)
except AttributeError: except AttributeError:
headers += ( headers += b"%b: %b\r\n" % (
b'%b: %b\r\n' % ( str(name).encode(),
str(name).encode(), str(value).encode('utf-8'))) str(value).encode("utf-8"),
)
return headers return headers
@ -46,12 +48,17 @@ class BaseHTTPResponse:
class StreamingHTTPResponse(BaseHTTPResponse): class StreamingHTTPResponse(BaseHTTPResponse):
__slots__ = ( __slots__ = (
'protocol', 'streaming_fn', 'status', "protocol",
'content_type', 'headers', '_cookies' "streaming_fn",
"status",
"content_type",
"headers",
"_cookies",
) )
def __init__(self, streaming_fn, status=200, headers=None, def __init__(
content_type='text/plain'): self, streaming_fn, status=200, headers=None, content_type="text/plain"
):
self.content_type = content_type self.content_type = content_type
self.streaming_fn = streaming_fn self.streaming_fn = streaming_fn
self.status = status self.status = status
@ -66,61 +73,69 @@ class StreamingHTTPResponse(BaseHTTPResponse):
if type(data) != bytes: if type(data) != bytes:
data = self._encode_body(data) data = self._encode_body(data)
self.protocol.push_data( self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data))
b"%x\r\n%b\r\n" % (len(data), data))
await self.protocol.drain() await self.protocol.drain()
async def stream( async def stream(
self, version="1.1", keep_alive=False, keep_alive_timeout=None): self, version="1.1", keep_alive=False, keep_alive_timeout=None
):
"""Streams headers, runs the `streaming_fn` callback that writes """Streams headers, runs the `streaming_fn` callback that writes
content to the response body, then finalizes the response body. content to the response body, then finalizes the response body.
""" """
headers = self.get_headers( headers = self.get_headers(
version, keep_alive=keep_alive, version,
keep_alive_timeout=keep_alive_timeout) keep_alive=keep_alive,
keep_alive_timeout=keep_alive_timeout,
)
self.protocol.push_data(headers) self.protocol.push_data(headers)
await self.protocol.drain() await self.protocol.drain()
await self.streaming_fn(self) await self.streaming_fn(self)
self.protocol.push_data(b'0\r\n\r\n') self.protocol.push_data(b"0\r\n\r\n")
# no need to await drain here after this write, because it is the # no need to await drain here after this write, because it is the
# very last thing we write and nothing needs to wait for it. # very last thing we write and nothing needs to wait for it.
def get_headers( def get_headers(
self, version="1.1", keep_alive=False, keep_alive_timeout=None): self, version="1.1", keep_alive=False, keep_alive_timeout=None
):
# This is all returned in a kind-of funky way # This is all returned in a kind-of funky way
# We tried to make this as fast as possible in pure python # We tried to make this as fast as possible in pure python
timeout_header = b'' timeout_header = b""
if keep_alive and keep_alive_timeout is not None: if keep_alive and keep_alive_timeout is not None:
timeout_header = b'Keep-Alive: %d\r\n' % keep_alive_timeout timeout_header = b"Keep-Alive: %d\r\n" % keep_alive_timeout
self.headers['Transfer-Encoding'] = 'chunked' self.headers["Transfer-Encoding"] = "chunked"
self.headers.pop('Content-Length', None) self.headers.pop("Content-Length", None)
self.headers['Content-Type'] = self.headers.get( self.headers["Content-Type"] = self.headers.get(
'Content-Type', self.content_type) "Content-Type", self.content_type
)
headers = self._parse_headers() headers = self._parse_headers()
if self.status is 200: if self.status is 200:
status = b'OK' status = b"OK"
else: else:
status = STATUS_CODES.get(self.status) status = STATUS_CODES.get(self.status)
return (b'HTTP/%b %d %b\r\n' return (b"HTTP/%b %d %b\r\n" b"%b" b"%b\r\n") % (
b'%b' version.encode(),
b'%b\r\n') % ( self.status,
version.encode(), status,
self.status, timeout_header,
status, headers,
timeout_header, )
headers
)
class HTTPResponse(BaseHTTPResponse): class HTTPResponse(BaseHTTPResponse):
__slots__ = ('body', 'status', 'content_type', 'headers', '_cookies') __slots__ = ("body", "status", "content_type", "headers", "_cookies")
def __init__(self, body=None, status=200, headers=None, def __init__(
content_type='text/plain', body_bytes=b''): self,
body=None,
status=200,
headers=None,
content_type="text/plain",
body_bytes=b"",
):
self.content_type = content_type self.content_type = content_type
if body is not None: if body is not None:
@ -132,22 +147,23 @@ class HTTPResponse(BaseHTTPResponse):
self.headers = CIMultiDict(headers or {}) self.headers = CIMultiDict(headers or {})
self._cookies = None self._cookies = None
def output( def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None):
self, version="1.1", keep_alive=False, keep_alive_timeout=None):
# This is all returned in a kind-of funky way # This is all returned in a kind-of funky way
# We tried to make this as fast as possible in pure python # We tried to make this as fast as possible in pure python
timeout_header = b'' timeout_header = b""
if keep_alive and keep_alive_timeout is not None: if keep_alive and keep_alive_timeout is not None:
timeout_header = b'Keep-Alive: %d\r\n' % keep_alive_timeout timeout_header = b"Keep-Alive: %d\r\n" % keep_alive_timeout
body = b'' body = b""
if has_message_body(self.status): if has_message_body(self.status):
body = self.body body = self.body
self.headers['Content-Length'] = self.headers.get( self.headers["Content-Length"] = self.headers.get(
'Content-Length', len(self.body)) "Content-Length", len(self.body)
)
self.headers['Content-Type'] = self.headers.get( self.headers["Content-Type"] = self.headers.get(
'Content-Type', self.content_type) "Content-Type", self.content_type
)
if self.status in (304, 412): if self.status in (304, 412):
self.headers = remove_entity_headers(self.headers) self.headers = remove_entity_headers(self.headers)
@ -155,23 +171,21 @@ class HTTPResponse(BaseHTTPResponse):
headers = self._parse_headers() headers = self._parse_headers()
if self.status is 200: if self.status is 200:
status = b'OK' status = b"OK"
else: else:
status = STATUS_CODES.get(self.status, b'UNKNOWN RESPONSE') status = STATUS_CODES.get(self.status, b"UNKNOWN RESPONSE")
return (b'HTTP/%b %d %b\r\n' return (
b'Connection: %b\r\n' b"HTTP/%b %d %b\r\n" b"Connection: %b\r\n" b"%b" b"%b\r\n" b"%b"
b'%b' ) % (
b'%b\r\n' version.encode(),
b'%b') % ( self.status,
version.encode(), status,
self.status, b"keep-alive" if keep_alive else b"close",
status, timeout_header,
b'keep-alive' if keep_alive else b'close', headers,
timeout_header, body,
headers, )
body
)
@property @property
def cookies(self): def cookies(self):
@ -180,9 +194,14 @@ class HTTPResponse(BaseHTTPResponse):
return self._cookies return self._cookies
def json(body, status=200, headers=None, def json(
content_type="application/json", dumps=json_dumps, body,
**kwargs): status=200,
headers=None,
content_type="application/json",
dumps=json_dumps,
**kwargs
):
""" """
Returns response object with body in json format. Returns response object with body in json format.
@ -191,12 +210,17 @@ def json(body, status=200, headers=None,
:param headers: Custom Headers. :param headers: Custom Headers.
:param kwargs: Remaining arguments that are passed to the json encoder. :param kwargs: Remaining arguments that are passed to the json encoder.
""" """
return HTTPResponse(dumps(body, **kwargs), headers=headers, return HTTPResponse(
status=status, content_type=content_type) dumps(body, **kwargs),
headers=headers,
status=status,
content_type=content_type,
)
def text(body, status=200, headers=None, def text(
content_type="text/plain; charset=utf-8"): body, status=200, headers=None, content_type="text/plain; charset=utf-8"
):
""" """
Returns response object with body in text format. Returns response object with body in text format.
@ -206,12 +230,13 @@ def text(body, status=200, headers=None,
:param content_type: the content type (string) of the response :param content_type: the content type (string) of the response
""" """
return HTTPResponse( return HTTPResponse(
body, status=status, headers=headers, body, status=status, headers=headers, content_type=content_type
content_type=content_type) )
def raw(body, status=200, headers=None, def raw(
content_type="application/octet-stream"): body, status=200, headers=None, content_type="application/octet-stream"
):
""" """
Returns response object without encoding the body. Returns response object without encoding the body.
@ -220,8 +245,12 @@ def raw(body, status=200, headers=None,
:param headers: Custom Headers. :param headers: Custom Headers.
:param content_type: the content type (string) of the response. :param content_type: the content type (string) of the response.
""" """
return HTTPResponse(body_bytes=body, status=status, headers=headers, return HTTPResponse(
content_type=content_type) body_bytes=body,
status=status,
headers=headers,
content_type=content_type,
)
def html(body, status=200, headers=None): def html(body, status=200, headers=None):
@ -232,12 +261,22 @@ def html(body, status=200, headers=None):
:param status: Response code. :param status: Response code.
:param headers: Custom Headers. :param headers: Custom Headers.
""" """
return HTTPResponse(body, status=status, headers=headers, return HTTPResponse(
content_type="text/html; charset=utf-8") body,
status=status,
headers=headers,
content_type="text/html; charset=utf-8",
)
async def file(location, status=200, mime_type=None, headers=None, async def file(
filename=None, _range=None): location,
status=200,
mime_type=None,
headers=None,
filename=None,
_range=None,
):
"""Return a response object with file data. """Return a response object with file data.
:param location: Location of file on system. :param location: Location of file on system.
@ -249,28 +288,40 @@ async def file(location, status=200, mime_type=None, headers=None,
headers = headers or {} headers = headers or {}
if filename: if filename:
headers.setdefault( headers.setdefault(
'Content-Disposition', "Content-Disposition", 'attachment; filename="{}"'.format(filename)
'attachment; filename="{}"'.format(filename)) )
filename = filename or path.split(location)[-1] filename = filename or path.split(location)[-1]
async with open_async(location, mode='rb') as _file: async with open_async(location, mode="rb") as _file:
if _range: if _range:
await _file.seek(_range.start) await _file.seek(_range.start)
out_stream = await _file.read(_range.size) out_stream = await _file.read(_range.size)
headers['Content-Range'] = 'bytes %s-%s/%s' % ( headers["Content-Range"] = "bytes %s-%s/%s" % (
_range.start, _range.end, _range.total) _range.start,
_range.end,
_range.total,
)
else: else:
out_stream = await _file.read() out_stream = await _file.read()
mime_type = mime_type or guess_type(filename)[0] or 'text/plain' mime_type = mime_type or guess_type(filename)[0] or "text/plain"
return HTTPResponse(status=status, return HTTPResponse(
headers=headers, status=status,
content_type=mime_type, headers=headers,
body_bytes=out_stream) content_type=mime_type,
body_bytes=out_stream,
)
async def file_stream(location, status=200, chunk_size=4096, mime_type=None, async def file_stream(
headers=None, filename=None, _range=None): location,
status=200,
chunk_size=4096,
mime_type=None,
headers=None,
filename=None,
_range=None,
):
"""Return a streaming response object with file data. """Return a streaming response object with file data.
:param location: Location of file on system. :param location: Location of file on system.
@ -283,11 +334,11 @@ async def file_stream(location, status=200, chunk_size=4096, mime_type=None,
headers = headers or {} headers = headers or {}
if filename: if filename:
headers.setdefault( headers.setdefault(
'Content-Disposition', "Content-Disposition", 'attachment; filename="{}"'.format(filename)
'attachment; filename="{}"'.format(filename)) )
filename = filename or path.split(location)[-1] filename = filename or path.split(location)[-1]
_file = await open_async(location, mode='rb') _file = await open_async(location, mode="rb")
async def _streaming_fn(response): async def _streaming_fn(response):
nonlocal _file, chunk_size nonlocal _file, chunk_size
@ -312,19 +363,27 @@ async def file_stream(location, status=200, chunk_size=4096, mime_type=None,
await _file.close() await _file.close()
return # Returning from this fn closes the stream return # Returning from this fn closes the stream
mime_type = mime_type or guess_type(filename)[0] or 'text/plain' mime_type = mime_type or guess_type(filename)[0] or "text/plain"
if _range: if _range:
headers['Content-Range'] = 'bytes %s-%s/%s' % ( headers["Content-Range"] = "bytes %s-%s/%s" % (
_range.start, _range.end, _range.total) _range.start,
return StreamingHTTPResponse(streaming_fn=_streaming_fn, _range.end,
status=status, _range.total,
headers=headers, )
content_type=mime_type) return StreamingHTTPResponse(
streaming_fn=_streaming_fn,
status=status,
headers=headers,
content_type=mime_type,
)
def stream( def stream(
streaming_fn, status=200, headers=None, streaming_fn,
content_type="text/plain; charset=utf-8"): status=200,
headers=None,
content_type="text/plain; charset=utf-8",
):
"""Accepts an coroutine `streaming_fn` which can be used to """Accepts an coroutine `streaming_fn` which can be used to
write chunks to a streaming response. Returns a `StreamingHTTPResponse`. write chunks to a streaming response. Returns a `StreamingHTTPResponse`.
@ -344,15 +403,13 @@ def stream(
:param headers: Custom Headers. :param headers: Custom Headers.
""" """
return StreamingHTTPResponse( return StreamingHTTPResponse(
streaming_fn, streaming_fn, headers=headers, content_type=content_type, status=status
headers=headers,
content_type=content_type,
status=status
) )
def redirect(to, headers=None, status=302, def redirect(
content_type="text/html; charset=utf-8"): to, headers=None, status=302, content_type="text/html; charset=utf-8"
):
"""Abort execution and cause a 302 redirect (by default). """Abort execution and cause a 302 redirect (by default).
:param to: path or fully qualified URL to redirect to :param to: path or fully qualified URL to redirect to
@ -367,9 +424,8 @@ def redirect(to, headers=None, status=302,
safe_to = quote_plus(to, safe=":/#?&=@[]!$&'()*+,;") safe_to = quote_plus(to, safe=":/#?&=@[]!$&'()*+,;")
# According to RFC 7231, a relative URI is now permitted. # According to RFC 7231, a relative URI is now permitted.
headers['Location'] = safe_to headers["Location"] = safe_to
return HTTPResponse( return HTTPResponse(
status=status, status=status, headers=headers, content_type=content_type
headers=headers, )
content_type=content_type)

View File

@ -9,25 +9,28 @@ from sanic.exceptions import NotFound, MethodNotSupported
from sanic.views import CompositionView from sanic.views import CompositionView
Route = namedtuple( Route = namedtuple(
'Route', "Route", ["handler", "methods", "pattern", "parameters", "name", "uri"]
['handler', 'methods', 'pattern', 'parameters', 'name', 'uri']) )
Parameter = namedtuple('Parameter', ['name', 'cast']) Parameter = namedtuple("Parameter", ["name", "cast"])
REGEX_TYPES = { REGEX_TYPES = {
'string': (str, r'[^/]+'), "string": (str, r"[^/]+"),
'int': (int, r'\d+'), "int": (int, r"\d+"),
'number': (float, r'[0-9\\.]+'), "number": (float, r"[0-9\\.]+"),
'alpha': (str, r'[A-Za-z]+'), "alpha": (str, r"[A-Za-z]+"),
'path': (str, r'[^/].*?'), "path": (str, r"[^/].*?"),
'uuid': (uuid.UUID, r'[A-Fa-f0-9]{8}-[A-Fa-f0-9]{4}-' "uuid": (
r'[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{12}') uuid.UUID,
r"[A-Fa-f0-9]{8}-[A-Fa-f0-9]{4}-"
r"[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{12}",
),
} }
ROUTER_CACHE_SIZE = 1024 ROUTER_CACHE_SIZE = 1024
def url_hash(url): def url_hash(url):
return url.count('/') return url.count("/")
class RouteExists(Exception): class RouteExists(Exception):
@ -68,10 +71,11 @@ class Router:
also be passed in as the type. The argument given to the function will also be passed in as the type. The argument given to the function will
always be a string, independent of the type. always be a string, independent of the type.
""" """
routes_static = None routes_static = None
routes_dynamic = None routes_dynamic = None
routes_always_check = None routes_always_check = None
parameter_pattern = re.compile(r'<(.+?)>') parameter_pattern = re.compile(r"<(.+?)>")
def __init__(self): def __init__(self):
self.routes_all = {} self.routes_all = {}
@ -98,9 +102,9 @@ class Router:
""" """
# We could receive NAME or NAME:PATTERN # We could receive NAME or NAME:PATTERN
name = parameter_string name = parameter_string
pattern = 'string' pattern = "string"
if ':' in parameter_string: if ":" in parameter_string:
name, pattern = parameter_string.split(':', 1) name, pattern = parameter_string.split(":", 1)
if not name: if not name:
raise ValueError( raise ValueError(
"Invalid parameter syntax: {}".format(parameter_string) "Invalid parameter syntax: {}".format(parameter_string)
@ -112,8 +116,16 @@ class Router:
return name, _type, pattern return name, _type, pattern
def add(self, uri, methods, handler, host=None, strict_slashes=False, def add(
version=None, name=None): self,
uri,
methods,
handler,
host=None,
strict_slashes=False,
version=None,
name=None,
):
"""Add a handler to the route list """Add a handler to the route list
:param uri: path to match :param uri: path to match
@ -127,8 +139,8 @@ class Router:
:return: Nothing :return: Nothing
""" """
if version is not None: if version is not None:
version = re.escape(str(version).strip('/').lstrip('v')) version = re.escape(str(version).strip("/").lstrip("v"))
uri = "/".join(["/v{}".format(version), uri.lstrip('/')]) uri = "/".join(["/v{}".format(version), uri.lstrip("/")])
# add regular version # add regular version
self._add(uri, methods, handler, host, name) self._add(uri, methods, handler, host, name)
@ -143,28 +155,26 @@ class Router:
return return
# Add versions with and without trailing / # Add versions with and without trailing /
slashed_methods = self.routes_all.get(uri + '/', frozenset({})) slashed_methods = self.routes_all.get(uri + "/", frozenset({}))
unslashed_methods = self.routes_all.get(uri[:-1], frozenset({})) unslashed_methods = self.routes_all.get(uri[:-1], frozenset({}))
if isinstance(methods, Iterable): if isinstance(methods, Iterable):
_slash_is_missing = all(method in slashed_methods for _slash_is_missing = all(
method in methods) method in slashed_methods for method in methods
_without_slash_is_missing = all(method in unslashed_methods for )
method in methods) _without_slash_is_missing = all(
method in unslashed_methods for method in methods
)
else: else:
_slash_is_missing = methods in slashed_methods _slash_is_missing = methods in slashed_methods
_without_slash_is_missing = methods in unslashed_methods _without_slash_is_missing = methods in unslashed_methods
slash_is_missing = ( slash_is_missing = not uri[-1] == "/" and not _slash_is_missing
not uri[-1] == '/' and not _slash_is_missing
)
without_slash_is_missing = ( without_slash_is_missing = (
uri[-1] == '/' and not uri[-1] == "/" and not _without_slash_is_missing and not uri == "/"
_without_slash_is_missing and not
uri == '/'
) )
# add version with trailing slash # add version with trailing slash
if slash_is_missing: if slash_is_missing:
self._add(uri + '/', methods, handler, host, name) self._add(uri + "/", methods, handler, host, name)
# add version without trailing slash # add version without trailing slash
elif without_slash_is_missing: elif without_slash_is_missing:
self._add(uri[:-1], methods, handler, host, name) self._add(uri[:-1], methods, handler, host, name)
@ -187,8 +197,10 @@ class Router:
else: else:
if not isinstance(host, Iterable): if not isinstance(host, Iterable):
raise ValueError("Expected either string or Iterable of " raise ValueError(
"host strings, not {!r}".format(host)) "Expected either string or Iterable of "
"host strings, not {!r}".format(host)
)
for host_ in host: for host_ in host:
self.add(uri, methods, handler, host_, name) self.add(uri, methods, handler, host_, name)
@ -208,38 +220,39 @@ class Router:
if name in parameter_names: if name in parameter_names:
raise ParameterNameConflicts( raise ParameterNameConflicts(
"Multiple parameter named <{name}> " "Multiple parameter named <{name}> "
"in route uri {uri}".format(name=name, uri=uri)) "in route uri {uri}".format(name=name, uri=uri)
)
parameter_names.add(name) parameter_names.add(name)
parameter = Parameter( parameter = Parameter(name=name, cast=_type)
name=name, cast=_type)
parameters.append(parameter) parameters.append(parameter)
# Mark the whole route as unhashable if it has the hash key in it # Mark the whole route as unhashable if it has the hash key in it
if re.search(r'(^|[^^]){1}/', pattern): if re.search(r"(^|[^^]){1}/", pattern):
properties['unhashable'] = True properties["unhashable"] = True
# Mark the route as unhashable if it matches the hash key # Mark the route as unhashable if it matches the hash key
elif re.search(r'/', pattern): elif re.search(r"/", pattern):
properties['unhashable'] = True properties["unhashable"] = True
return '({})'.format(pattern) return "({})".format(pattern)
pattern_string = re.sub(self.parameter_pattern, add_parameter, uri) pattern_string = re.sub(self.parameter_pattern, add_parameter, uri)
pattern = re.compile(r'^{}$'.format(pattern_string)) pattern = re.compile(r"^{}$".format(pattern_string))
def merge_route(route, methods, handler): def merge_route(route, methods, handler):
# merge to the existing route when possible. # merge to the existing route when possible.
if not route.methods or not methods: if not route.methods or not methods:
# method-unspecified routes are not mergeable. # method-unspecified routes are not mergeable.
raise RouteExists( raise RouteExists("Route already registered: {}".format(uri))
"Route already registered: {}".format(uri))
elif route.methods.intersection(methods): elif route.methods.intersection(methods):
# already existing method is not overloadable. # already existing method is not overloadable.
duplicated = methods.intersection(route.methods) duplicated = methods.intersection(route.methods)
raise RouteExists( raise RouteExists(
"Route already registered: {} [{}]".format( "Route already registered: {} [{}]".format(
uri, ','.join(list(duplicated)))) uri, ",".join(list(duplicated))
)
)
if isinstance(route.handler, CompositionView): if isinstance(route.handler, CompositionView):
view = route.handler view = route.handler
else: else:
@ -247,19 +260,22 @@ class Router:
view.add(route.methods, route.handler) view.add(route.methods, route.handler)
view.add(methods, handler) view.add(methods, handler)
route = route._replace( route = route._replace(
handler=view, methods=methods.union(route.methods)) handler=view, methods=methods.union(route.methods)
)
return route return route
if parameters: if parameters:
# TODO: This is too complex, we need to reduce the complexity # TODO: This is too complex, we need to reduce the complexity
if properties['unhashable']: if properties["unhashable"]:
routes_to_check = self.routes_always_check routes_to_check = self.routes_always_check
ndx, route = self.check_dynamic_route_exists( ndx, route = self.check_dynamic_route_exists(
pattern, routes_to_check, parameters) pattern, routes_to_check, parameters
)
else: else:
routes_to_check = self.routes_dynamic[url_hash(uri)] routes_to_check = self.routes_dynamic[url_hash(uri)]
ndx, route = self.check_dynamic_route_exists( ndx, route = self.check_dynamic_route_exists(
pattern, routes_to_check, parameters) pattern, routes_to_check, parameters
)
if ndx != -1: if ndx != -1:
# Pop the ndx of the route, no dups of the same route # Pop the ndx of the route, no dups of the same route
routes_to_check.pop(ndx) routes_to_check.pop(ndx)
@ -270,35 +286,41 @@ class Router:
# if available # if available
# special prefix for static files # special prefix for static files
is_static = False is_static = False
if name and name.startswith('_static_'): if name and name.startswith("_static_"):
is_static = True is_static = True
name = name.split('_static_', 1)[-1] name = name.split("_static_", 1)[-1]
if hasattr(handler, '__blueprintname__'): if hasattr(handler, "__blueprintname__"):
handler_name = '{}.{}'.format( handler_name = "{}.{}".format(
handler.__blueprintname__, name or handler.__name__) handler.__blueprintname__, name or handler.__name__
)
else: else:
handler_name = name or getattr(handler, '__name__', None) handler_name = name or getattr(handler, "__name__", None)
if route: if route:
route = merge_route(route, methods, handler) route = merge_route(route, methods, handler)
else: else:
route = Route( route = Route(
handler=handler, methods=methods, pattern=pattern, handler=handler,
parameters=parameters, name=handler_name, uri=uri) methods=methods,
pattern=pattern,
parameters=parameters,
name=handler_name,
uri=uri,
)
self.routes_all[uri] = route self.routes_all[uri] = route
if is_static: if is_static:
pair = self.routes_static_files.get(handler_name) pair = self.routes_static_files.get(handler_name)
if not (pair and (pair[0] + '/' == uri or uri + '/' == pair[0])): if not (pair and (pair[0] + "/" == uri or uri + "/" == pair[0])):
self.routes_static_files[handler_name] = (uri, route) self.routes_static_files[handler_name] = (uri, route)
else: else:
pair = self.routes_names.get(handler_name) pair = self.routes_names.get(handler_name)
if not (pair and (pair[0] + '/' == uri or uri + '/' == pair[0])): if not (pair and (pair[0] + "/" == uri or uri + "/" == pair[0])):
self.routes_names[handler_name] = (uri, route) self.routes_names[handler_name] = (uri, route)
if properties['unhashable']: if properties["unhashable"]:
self.routes_always_check.append(route) self.routes_always_check.append(route)
elif parameters: elif parameters:
self.routes_dynamic[url_hash(uri)].append(route) self.routes_dynamic[url_hash(uri)].append(route)
@ -333,8 +355,10 @@ class Router:
if route in self.routes_always_check: if route in self.routes_always_check:
self.routes_always_check.remove(route) self.routes_always_check.remove(route)
elif url_hash(uri) in self.routes_dynamic \ elif (
and route in self.routes_dynamic[url_hash(uri)]: url_hash(uri) in self.routes_dynamic
and route in self.routes_dynamic[url_hash(uri)]
):
self.routes_dynamic[url_hash(uri)].remove(route) self.routes_dynamic[url_hash(uri)].remove(route)
else: else:
self.routes_static.pop(uri) self.routes_static.pop(uri)
@ -353,7 +377,7 @@ class Router:
if not view_name: if not view_name:
return (None, None) return (None, None)
if view_name == 'static' or view_name.endswith('.static'): if view_name == "static" or view_name.endswith(".static"):
return self.routes_static_files.get(name, (None, None)) return self.routes_static_files.get(name, (None, None))
return self.routes_names.get(view_name, (None, None)) return self.routes_names.get(view_name, (None, None))
@ -367,14 +391,15 @@ class Router:
""" """
# No virtual hosts specified; default behavior # No virtual hosts specified; default behavior
if not self.hosts: if not self.hosts:
return self._get(request.path, request.method, '') return self._get(request.path, request.method, "")
# virtual hosts specified; try to match route to the host header # virtual hosts specified; try to match route to the host header
try: try:
return self._get(request.path, request.method, return self._get(
request.headers.get("Host", '')) request.path, request.method, request.headers.get("Host", "")
)
# try default hosts # try default hosts
except NotFound: except NotFound:
return self._get(request.path, request.method, '') return self._get(request.path, request.method, "")
def get_supported_methods(self, url): def get_supported_methods(self, url):
"""Get a list of supported methods for a url and optional host. """Get a list of supported methods for a url and optional host.
@ -384,7 +409,7 @@ class Router:
""" """
route = self.routes_all.get(url) route = self.routes_all.get(url)
# if methods are None then this logic will prevent an error # if methods are None then this logic will prevent an error
return getattr(route, 'methods', None) or frozenset() return getattr(route, "methods", None) or frozenset()
@lru_cache(maxsize=ROUTER_CACHE_SIZE) @lru_cache(maxsize=ROUTER_CACHE_SIZE)
def _get(self, url, method, host): def _get(self, url, method, host):
@ -399,9 +424,10 @@ class Router:
# Check against known static routes # Check against known static routes
route = self.routes_static.get(url) route = self.routes_static.get(url)
method_not_supported = MethodNotSupported( method_not_supported = MethodNotSupported(
'Method {} not allowed for URL {}'.format(method, url), "Method {} not allowed for URL {}".format(method, url),
method=method, method=method,
allowed_methods=self.get_supported_methods(url)) allowed_methods=self.get_supported_methods(url),
)
if route: if route:
if route.methods and method not in route.methods: if route.methods and method not in route.methods:
raise method_not_supported raise method_not_supported
@ -427,13 +453,14 @@ class Router:
# Route was found but the methods didn't match # Route was found but the methods didn't match
if route_found: if route_found:
raise method_not_supported raise method_not_supported
raise NotFound('Requested URL {} not found'.format(url)) raise NotFound("Requested URL {} not found".format(url))
kwargs = {p.name: p.cast(value) kwargs = {
for value, p p.name: p.cast(value)
in zip(match.groups(1), route.parameters)} for value, p in zip(match.groups(1), route.parameters)
}
route_handler = route.handler route_handler = route.handler
if hasattr(route_handler, 'handlers'): if hasattr(route_handler, "handlers"):
route_handler = route_handler.handlers[method] route_handler = route_handler.handlers[method]
return route_handler, [], kwargs, route.uri return route_handler, [], kwargs, route.uri
@ -446,7 +473,8 @@ class Router:
handler = self.get(request)[0] handler = self.get(request)[0]
except (NotFound, MethodNotSupported): except (NotFound, MethodNotSupported):
return False return False
if (hasattr(handler, 'view_class') and if hasattr(handler, "view_class") and hasattr(
hasattr(handler.view_class, request.method.lower())): handler.view_class, request.method.lower()
):
handler = getattr(handler.view_class, request.method.lower()) handler = getattr(handler.view_class, request.method.lower())
return hasattr(handler, 'is_stream') return hasattr(handler, "is_stream")

View File

@ -4,16 +4,8 @@ import traceback
from functools import partial from functools import partial
from inspect import isawaitable from inspect import isawaitable
from multiprocessing import Process from multiprocessing import Process
from signal import ( from signal import SIGTERM, SIGINT, SIG_IGN, signal as signal_func, Signals
SIGTERM, SIGINT, SIG_IGN, from socket import socket, SOL_SOCKET, SO_REUSEADDR
signal as signal_func,
Signals
)
from socket import (
socket,
SOL_SOCKET,
SO_REUSEADDR,
)
from time import time from time import time
from httptools import HttpRequestParser from httptools import HttpRequestParser
@ -22,6 +14,7 @@ from multidict import CIMultiDict
try: try:
import uvloop import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError: except ImportError:
pass pass
@ -30,8 +23,12 @@ from sanic.log import logger, access_logger
from sanic.response import HTTPResponse from sanic.response import HTTPResponse
from sanic.request import Request from sanic.request import Request
from sanic.exceptions import ( from sanic.exceptions import (
RequestTimeout, PayloadTooLarge, InvalidUsage, ServerError, RequestTimeout,
ServiceUnavailable) PayloadTooLarge,
InvalidUsage,
ServerError,
ServiceUnavailable,
)
current_time = None current_time = None
@ -43,27 +40,58 @@ class Signal:
class HttpProtocol(asyncio.Protocol): class HttpProtocol(asyncio.Protocol):
__slots__ = ( __slots__ = (
# event loop, connection # event loop, connection
'loop', 'transport', 'connections', 'signal', "loop",
"transport",
"connections",
"signal",
# request params # request params
'parser', 'request', 'url', 'headers', "parser",
"request",
"url",
"headers",
# request config # request config
'request_handler', 'request_timeout', 'response_timeout', "request_handler",
'keep_alive_timeout', 'request_max_size', 'request_class', "request_timeout",
'is_request_stream', 'router', "response_timeout",
"keep_alive_timeout",
"request_max_size",
"request_class",
"is_request_stream",
"router",
# enable or disable access log purpose # enable or disable access log purpose
'access_log', "access_log",
# connection management # connection management
'_total_request_size', '_request_timeout_handler', "_total_request_size",
'_response_timeout_handler', '_keep_alive_timeout_handler', "_request_timeout_handler",
'_last_request_time', '_last_response_time', '_is_stream_handler', "_response_timeout_handler",
'_not_paused') "_keep_alive_timeout_handler",
"_last_request_time",
"_last_response_time",
"_is_stream_handler",
"_not_paused",
)
def __init__(self, *, loop, request_handler, error_handler, def __init__(
signal=Signal(), connections=set(), request_timeout=60, self,
response_timeout=60, keep_alive_timeout=5, *,
request_max_size=None, request_class=None, access_log=True, loop,
keep_alive=True, is_request_stream=False, router=None, request_handler,
state=None, debug=False, **kwargs): error_handler,
signal=Signal(),
connections=set(),
request_timeout=60,
response_timeout=60,
keep_alive_timeout=5,
request_max_size=None,
request_class=None,
access_log=True,
keep_alive=True,
is_request_stream=False,
router=None,
state=None,
debug=False,
**kwargs
):
self.loop = loop self.loop = loop
self.transport = None self.transport = None
self.request = None self.request = None
@ -93,19 +121,20 @@ class HttpProtocol(asyncio.Protocol):
self._request_handler_task = None self._request_handler_task = None
self._request_stream_task = None self._request_stream_task = None
self._keep_alive = keep_alive self._keep_alive = keep_alive
self._header_fragment = b'' self._header_fragment = b""
self.state = state if state else {} self.state = state if state else {}
if 'requests_count' not in self.state: if "requests_count" not in self.state:
self.state['requests_count'] = 0 self.state["requests_count"] = 0
self._debug = debug self._debug = debug
self._not_paused.set() self._not_paused.set()
@property @property
def keep_alive(self): def keep_alive(self):
return ( return (
self._keep_alive and self._keep_alive
not self.signal.stopped and and not self.signal.stopped
self.parser.should_keep_alive()) and self.parser.should_keep_alive()
)
# -------------------------------------------- # # -------------------------------------------- #
# Connection # Connection
@ -114,7 +143,8 @@ class HttpProtocol(asyncio.Protocol):
def connection_made(self, transport): def connection_made(self, transport):
self.connections.add(self) self.connections.add(self)
self._request_timeout_handler = self.loop.call_later( self._request_timeout_handler = self.loop.call_later(
self.request_timeout, self.request_timeout_callback) self.request_timeout, self.request_timeout_callback
)
self.transport = transport self.transport = transport
self._last_request_time = current_time self._last_request_time = current_time
@ -145,16 +175,15 @@ class HttpProtocol(asyncio.Protocol):
time_elapsed = current_time - self._last_request_time time_elapsed = current_time - self._last_request_time
if time_elapsed < self.request_timeout: if time_elapsed < self.request_timeout:
time_left = self.request_timeout - time_elapsed time_left = self.request_timeout - time_elapsed
self._request_timeout_handler = ( self._request_timeout_handler = self.loop.call_later(
self.loop.call_later(time_left, time_left, self.request_timeout_callback
self.request_timeout_callback)
) )
else: else:
if self._request_stream_task: if self._request_stream_task:
self._request_stream_task.cancel() self._request_stream_task.cancel()
if self._request_handler_task: if self._request_handler_task:
self._request_handler_task.cancel() self._request_handler_task.cancel()
self.write_error(RequestTimeout('Request Timeout')) self.write_error(RequestTimeout("Request Timeout"))
def response_timeout_callback(self): def response_timeout_callback(self):
# Check if elapsed time since response was initiated exceeds our # Check if elapsed time since response was initiated exceeds our
@ -162,16 +191,15 @@ class HttpProtocol(asyncio.Protocol):
time_elapsed = current_time - self._last_request_time time_elapsed = current_time - self._last_request_time
if time_elapsed < self.response_timeout: if time_elapsed < self.response_timeout:
time_left = self.response_timeout - time_elapsed time_left = self.response_timeout - time_elapsed
self._response_timeout_handler = ( self._response_timeout_handler = self.loop.call_later(
self.loop.call_later(time_left, time_left, self.response_timeout_callback
self.response_timeout_callback)
) )
else: else:
if self._request_stream_task: if self._request_stream_task:
self._request_stream_task.cancel() self._request_stream_task.cancel()
if self._request_handler_task: if self._request_handler_task:
self._request_handler_task.cancel() self._request_handler_task.cancel()
self.write_error(ServiceUnavailable('Response Timeout')) self.write_error(ServiceUnavailable("Response Timeout"))
def keep_alive_timeout_callback(self): def keep_alive_timeout_callback(self):
# Check if elapsed time since last response exceeds our configured # Check if elapsed time since last response exceeds our configured
@ -179,12 +207,11 @@ class HttpProtocol(asyncio.Protocol):
time_elapsed = current_time - self._last_response_time time_elapsed = current_time - self._last_response_time
if time_elapsed < self.keep_alive_timeout: if time_elapsed < self.keep_alive_timeout:
time_left = self.keep_alive_timeout - time_elapsed time_left = self.keep_alive_timeout - time_elapsed
self._keep_alive_timeout_handler = ( self._keep_alive_timeout_handler = self.loop.call_later(
self.loop.call_later(time_left, time_left, self.keep_alive_timeout_callback
self.keep_alive_timeout_callback)
) )
else: else:
logger.debug('KeepAlive Timeout. Closing connection.') logger.debug("KeepAlive Timeout. Closing connection.")
self.transport.close() self.transport.close()
self.transport = None self.transport = None
@ -197,7 +224,7 @@ class HttpProtocol(asyncio.Protocol):
# memory limits # memory limits
self._total_request_size += len(data) self._total_request_size += len(data)
if self._total_request_size > self.request_max_size: if self._total_request_size > self.request_max_size:
self.write_error(PayloadTooLarge('Payload Too Large')) self.write_error(PayloadTooLarge("Payload Too Large"))
# Create parser if this is the first time we're receiving data # Create parser if this is the first time we're receiving data
if self.parser is None: if self.parser is None:
@ -206,15 +233,15 @@ class HttpProtocol(asyncio.Protocol):
self.parser = HttpRequestParser(self) self.parser = HttpRequestParser(self)
# requests count # requests count
self.state['requests_count'] = self.state['requests_count'] + 1 self.state["requests_count"] = self.state["requests_count"] + 1
# Parse request chunk or close connection # Parse request chunk or close connection
try: try:
self.parser.feed_data(data) self.parser.feed_data(data)
except HttpParserError: except HttpParserError:
message = 'Bad Request' message = "Bad Request"
if self._debug: if self._debug:
message += '\n' + traceback.format_exc() message += "\n" + traceback.format_exc()
self.write_error(InvalidUsage(message)) self.write_error(InvalidUsage(message))
def on_url(self, url): def on_url(self, url):
@ -227,17 +254,20 @@ class HttpProtocol(asyncio.Protocol):
self._header_fragment += name self._header_fragment += name
if value is not None: if value is not None:
if self._header_fragment == b'Content-Length' \ if (
and int(value) > self.request_max_size: self._header_fragment == b"Content-Length"
self.write_error(PayloadTooLarge('Payload Too Large')) and int(value) > self.request_max_size
):
self.write_error(PayloadTooLarge("Payload Too Large"))
try: try:
value = value.decode() value = value.decode()
except UnicodeDecodeError: except UnicodeDecodeError:
value = value.decode('latin_1') value = value.decode("latin_1")
self.headers.append( self.headers.append(
(self._header_fragment.decode().casefold(), value)) (self._header_fragment.decode().casefold(), value)
)
self._header_fragment = b'' self._header_fragment = b""
def on_headers_complete(self): def on_headers_complete(self):
self.request = self.request_class( self.request = self.request_class(
@ -245,7 +275,7 @@ class HttpProtocol(asyncio.Protocol):
headers=CIMultiDict(self.headers), headers=CIMultiDict(self.headers),
version=self.parser.get_http_version(), version=self.parser.get_http_version(),
method=self.parser.get_method().decode(), method=self.parser.get_method().decode(),
transport=self.transport transport=self.transport,
) )
# Remove any existing KeepAlive handler here, # Remove any existing KeepAlive handler here,
# It will be recreated if required on the new request. # It will be recreated if required on the new request.
@ -254,7 +284,8 @@ class HttpProtocol(asyncio.Protocol):
self._keep_alive_timeout_handler = None self._keep_alive_timeout_handler = None
if self.is_request_stream: if self.is_request_stream:
self._is_stream_handler = self.router.is_stream_handler( self._is_stream_handler = self.router.is_stream_handler(
self.request) self.request
)
if self._is_stream_handler: if self._is_stream_handler:
self.request.stream = asyncio.Queue() self.request.stream = asyncio.Queue()
self.execute_request_handler() self.execute_request_handler()
@ -262,7 +293,8 @@ class HttpProtocol(asyncio.Protocol):
def on_body(self, body): def on_body(self, body):
if self.is_request_stream and self._is_stream_handler: if self.is_request_stream and self._is_stream_handler:
self._request_stream_task = self.loop.create_task( self._request_stream_task = self.loop.create_task(
self.request.stream.put(body)) self.request.stream.put(body)
)
return return
self.request.body.append(body) self.request.body.append(body)
@ -274,47 +306,49 @@ class HttpProtocol(asyncio.Protocol):
self._request_timeout_handler = None self._request_timeout_handler = None
if self.is_request_stream and self._is_stream_handler: if self.is_request_stream and self._is_stream_handler:
self._request_stream_task = self.loop.create_task( self._request_stream_task = self.loop.create_task(
self.request.stream.put(None)) self.request.stream.put(None)
)
return return
self.request.body = b''.join(self.request.body) self.request.body = b"".join(self.request.body)
self.execute_request_handler() self.execute_request_handler()
def execute_request_handler(self): def execute_request_handler(self):
self._response_timeout_handler = self.loop.call_later( self._response_timeout_handler = self.loop.call_later(
self.response_timeout, self.response_timeout_callback) self.response_timeout, self.response_timeout_callback
)
self._last_request_time = current_time self._last_request_time = current_time
self._request_handler_task = self.loop.create_task( self._request_handler_task = self.loop.create_task(
self.request_handler( self.request_handler(
self.request, self.request, self.write_response, self.stream_response
self.write_response, )
self.stream_response)) )
# -------------------------------------------- # # -------------------------------------------- #
# Responding # Responding
# -------------------------------------------- # # -------------------------------------------- #
def log_response(self, response): def log_response(self, response):
if self.access_log: if self.access_log:
extra = { extra = {"status": getattr(response, "status", 0)}
'status': getattr(response, 'status', 0),
}
if isinstance(response, HTTPResponse): if isinstance(response, HTTPResponse):
extra['byte'] = len(response.body) extra["byte"] = len(response.body)
else: else:
extra['byte'] = -1 extra["byte"] = -1
extra['host'] = 'UNKNOWN' extra["host"] = "UNKNOWN"
if self.request is not None: if self.request is not None:
if self.request.ip: if self.request.ip:
extra['host'] = '{0}:{1}'.format(self.request.ip, extra["host"] = "{0}:{1}".format(
self.request.port) self.request.ip, self.request.port
)
extra['request'] = '{0} {1}'.format(self.request.method, extra["request"] = "{0} {1}".format(
self.request.url) self.request.method, self.request.url
)
else: else:
extra['request'] = 'nil' extra["request"] = "nil"
access_logger.info('', extra=extra) access_logger.info("", extra=extra)
def write_response(self, response): def write_response(self, response):
""" """
@ -327,31 +361,37 @@ class HttpProtocol(asyncio.Protocol):
keep_alive = self.keep_alive keep_alive = self.keep_alive
self.transport.write( self.transport.write(
response.output( response.output(
self.request.version, keep_alive, self.request.version, keep_alive, self.keep_alive_timeout
self.keep_alive_timeout)) )
)
self.log_response(response) self.log_response(response)
except AttributeError: except AttributeError:
logger.error('Invalid response object for url %s, ' logger.error(
'Expected Type: HTTPResponse, Actual Type: %s', "Invalid response object for url %s, "
self.url, type(response)) "Expected Type: HTTPResponse, Actual Type: %s",
self.write_error(ServerError('Invalid response type')) self.url,
type(response),
)
self.write_error(ServerError("Invalid response type"))
except RuntimeError: except RuntimeError:
if self._debug: if self._debug:
logger.error('Connection lost before response written @ %s', logger.error(
self.request.ip) "Connection lost before response written @ %s",
self.request.ip,
)
keep_alive = False keep_alive = False
except Exception as e: except Exception as e:
self.bail_out( self.bail_out(
"Writing response failed, connection closed {}".format( "Writing response failed, connection closed {}".format(repr(e))
repr(e))) )
finally: finally:
if not keep_alive: if not keep_alive:
self.transport.close() self.transport.close()
self.transport = None self.transport = None
else: else:
self._keep_alive_timeout_handler = self.loop.call_later( self._keep_alive_timeout_handler = self.loop.call_later(
self.keep_alive_timeout, self.keep_alive_timeout, self.keep_alive_timeout_callback
self.keep_alive_timeout_callback) )
self._last_response_time = current_time self._last_response_time = current_time
self.cleanup() self.cleanup()
@ -375,30 +415,36 @@ class HttpProtocol(asyncio.Protocol):
keep_alive = self.keep_alive keep_alive = self.keep_alive
response.protocol = self response.protocol = self
await response.stream( await response.stream(
self.request.version, keep_alive, self.keep_alive_timeout) self.request.version, keep_alive, self.keep_alive_timeout
)
self.log_response(response) self.log_response(response)
except AttributeError: except AttributeError:
logger.error('Invalid response object for url %s, ' logger.error(
'Expected Type: HTTPResponse, Actual Type: %s', "Invalid response object for url %s, "
self.url, type(response)) "Expected Type: HTTPResponse, Actual Type: %s",
self.write_error(ServerError('Invalid response type')) self.url,
type(response),
)
self.write_error(ServerError("Invalid response type"))
except RuntimeError: except RuntimeError:
if self._debug: if self._debug:
logger.error('Connection lost before response written @ %s', logger.error(
self.request.ip) "Connection lost before response written @ %s",
self.request.ip,
)
keep_alive = False keep_alive = False
except Exception as e: except Exception as e:
self.bail_out( self.bail_out(
"Writing response failed, connection closed {}".format( "Writing response failed, connection closed {}".format(repr(e))
repr(e))) )
finally: finally:
if not keep_alive: if not keep_alive:
self.transport.close() self.transport.close()
self.transport = None self.transport = None
else: else:
self._keep_alive_timeout_handler = self.loop.call_later( self._keep_alive_timeout_handler = self.loop.call_later(
self.keep_alive_timeout, self.keep_alive_timeout, self.keep_alive_timeout_callback
self.keep_alive_timeout_callback) )
self._last_response_time = current_time self._last_response_time = current_time
self.cleanup() self.cleanup()
@ -411,32 +457,37 @@ class HttpProtocol(asyncio.Protocol):
response = None response = None
try: try:
response = self.error_handler.response(self.request, exception) response = self.error_handler.response(self.request, exception)
version = self.request.version if self.request else '1.1' version = self.request.version if self.request else "1.1"
self.transport.write(response.output(version)) self.transport.write(response.output(version))
except RuntimeError: except RuntimeError:
if self._debug: if self._debug:
logger.error('Connection lost before error written @ %s', logger.error(
self.request.ip if self.request else 'Unknown') "Connection lost before error written @ %s",
self.request.ip if self.request else "Unknown",
)
except Exception as e: except Exception as e:
self.bail_out( self.bail_out(
"Writing error failed, connection closed {}".format( "Writing error failed, connection closed {}".format(repr(e)),
repr(e)), from_error=True from_error=True,
) )
finally: finally:
if self.parser and (self.keep_alive if self.parser and (
or getattr(response, 'status', 0) == 408): self.keep_alive or getattr(response, "status", 0) == 408
):
self.log_response(response) self.log_response(response)
try: try:
self.transport.close() self.transport.close()
except AttributeError: except AttributeError:
logger.debug('Connection lost before server could close it.') logger.debug("Connection lost before server could close it.")
def bail_out(self, message, from_error=False): def bail_out(self, message, from_error=False):
if from_error or self.transport.is_closing(): if from_error or self.transport.is_closing():
logger.error("Transport closed @ %s and exception " logger.error(
"experienced during error handling", "Transport closed @ %s and exception "
self.transport.get_extra_info('peername')) "experienced during error handling",
logger.debug('Exception:', exc_info=True) self.transport.get_extra_info("peername"),
)
logger.debug("Exception:", exc_info=True)
else: else:
self.write_error(ServerError(message)) self.write_error(ServerError(message))
logger.error(message) logger.error(message)
@ -497,17 +548,43 @@ def trigger_events(events, loop):
loop.run_until_complete(result) loop.run_until_complete(result)
def serve(host, port, request_handler, error_handler, before_start=None, def serve(
after_start=None, before_stop=None, after_stop=None, debug=False, host,
request_timeout=60, response_timeout=60, keep_alive_timeout=5, port,
ssl=None, sock=None, request_max_size=None, reuse_port=False, request_handler,
loop=None, protocol=HttpProtocol, backlog=100, error_handler,
register_sys_signals=True, run_multiple=False, run_async=False, before_start=None,
connections=None, signal=Signal(), request_class=None, after_start=None,
access_log=True, keep_alive=True, is_request_stream=False, before_stop=None,
router=None, websocket_max_size=None, websocket_max_queue=None, after_stop=None,
websocket_read_limit=2 ** 16, websocket_write_limit=2 ** 16, debug=False,
state=None, graceful_shutdown_timeout=15.0): request_timeout=60,
response_timeout=60,
keep_alive_timeout=5,
ssl=None,
sock=None,
request_max_size=None,
reuse_port=False,
loop=None,
protocol=HttpProtocol,
backlog=100,
register_sys_signals=True,
run_multiple=False,
run_async=False,
connections=None,
signal=Signal(),
request_class=None,
access_log=True,
keep_alive=True,
is_request_stream=False,
router=None,
websocket_max_size=None,
websocket_max_queue=None,
websocket_read_limit=2 ** 16,
websocket_write_limit=2 ** 16,
state=None,
graceful_shutdown_timeout=15.0,
):
"""Start asynchronous HTTP Server on an individual process. """Start asynchronous HTTP Server on an individual process.
:param host: Address to host on :param host: Address to host on
@ -592,7 +669,7 @@ def serve(host, port, request_handler, error_handler, before_start=None,
ssl=ssl, ssl=ssl,
reuse_port=reuse_port, reuse_port=reuse_port,
sock=sock, sock=sock,
backlog=backlog backlog=backlog,
) )
# Instead of pulling time at the end of every request, # Instead of pulling time at the end of every request,
@ -623,11 +700,13 @@ def serve(host, port, request_handler, error_handler, before_start=None,
try: try:
loop.add_signal_handler(_signal, loop.stop) loop.add_signal_handler(_signal, loop.stop)
except NotImplementedError: except NotImplementedError:
logger.warning('Sanic tried to use loop.add_signal_handler ' logger.warning(
'but it is not implemented on this platform.') "Sanic tried to use loop.add_signal_handler "
"but it is not implemented on this platform."
)
pid = os.getpid() pid = os.getpid()
try: try:
logger.info('Starting worker [%s]', pid) logger.info("Starting worker [%s]", pid)
loop.run_forever() loop.run_forever()
finally: finally:
logger.info("Stopping worker [%s]", pid) logger.info("Stopping worker [%s]", pid)
@ -658,9 +737,7 @@ def serve(host, port, request_handler, error_handler, before_start=None,
coros = [] coros = []
for conn in connections: for conn in connections:
if hasattr(conn, "websocket") and conn.websocket: if hasattr(conn, "websocket") and conn.websocket:
coros.append( coros.append(conn.websocket.close_connection())
conn.websocket.close_connection()
)
else: else:
conn.close() conn.close()
@ -681,18 +758,18 @@ def serve_multiple(server_settings, workers):
:param stop_event: if provided, is used as a stop signal :param stop_event: if provided, is used as a stop signal
:return: :return:
""" """
server_settings['reuse_port'] = True server_settings["reuse_port"] = True
server_settings['run_multiple'] = True server_settings["run_multiple"] = True
# Handling when custom socket is not provided. # Handling when custom socket is not provided.
if server_settings.get('sock') is None: if server_settings.get("sock") is None:
sock = socket() sock = socket()
sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
sock.bind((server_settings['host'], server_settings['port'])) sock.bind((server_settings["host"], server_settings["port"]))
sock.set_inheritable(True) sock.set_inheritable(True)
server_settings['sock'] = sock server_settings["sock"] = sock
server_settings['host'] = None server_settings["host"] = None
server_settings['port'] = None server_settings["port"] = None
def sig_handler(signal, frame): def sig_handler(signal, frame):
logger.info("Received signal %s. Shutting down.", Signals(signal).name) logger.info("Received signal %s. Shutting down.", Signals(signal).name)
@ -716,4 +793,4 @@ def serve_multiple(server_settings, workers):
# the above processes will block this until they're stopped # the above processes will block this until they're stopped
for process in processes: for process in processes:
process.terminate() process.terminate()
server_settings.get('sock').close() server_settings.get("sock").close()

View File

@ -16,10 +16,19 @@ from sanic.handlers import ContentRangeHandler
from sanic.response import file, file_stream, HTTPResponse from sanic.response import file, file_stream, HTTPResponse
def register(app, uri, file_or_directory, pattern, def register(
use_modified_since, use_content_range, app,
stream_large_files, name='static', host=None, uri,
strict_slashes=None, content_type=None): file_or_directory,
pattern,
use_modified_since,
use_content_range,
stream_large_files,
name="static",
host=None,
strict_slashes=None,
content_type=None,
):
# TODO: Though sanic is not a file server, I feel like we should at least # TODO: Though sanic is not a file server, I feel like we should at least
# make a good effort here. Modified-since is nice, but we could # make a good effort here. Modified-since is nice, but we could
# also look into etags, expires, and caching # also look into etags, expires, and caching
@ -46,12 +55,12 @@ def register(app, uri, file_or_directory, pattern,
# If we're not trying to match a file directly, # If we're not trying to match a file directly,
# serve from the folder # serve from the folder
if not path.isfile(file_or_directory): if not path.isfile(file_or_directory):
uri += '<file_uri:' + pattern + '>' uri += "<file_uri:" + pattern + ">"
async def _handler(request, file_uri=None): async def _handler(request, file_uri=None):
# Using this to determine if the URL is trying to break out of the path # Using this to determine if the URL is trying to break out of the path
# served. os.path.realpath seems to be very slow # served. os.path.realpath seems to be very slow
if file_uri and '../' in file_uri: if file_uri and "../" in file_uri:
raise InvalidUsage("Invalid URL") raise InvalidUsage("Invalid URL")
# Merge served directory and requested file if provided # Merge served directory and requested file if provided
# Strip all / that in the beginning of the URL to help prevent python # Strip all / that in the beginning of the URL to help prevent python
@ -59,15 +68,16 @@ def register(app, uri, file_or_directory, pattern,
root_path = file_path = file_or_directory root_path = file_path = file_or_directory
if file_uri: if file_uri:
file_path = path.join( file_path = path.join(
file_or_directory, sub('^[/]*', '', file_uri)) file_or_directory, sub("^[/]*", "", file_uri)
)
# URL decode the path sent by the browser otherwise we won't be able to # URL decode the path sent by the browser otherwise we won't be able to
# match filenames which got encoded (filenames with spaces etc) # match filenames which got encoded (filenames with spaces etc)
file_path = path.abspath(unquote(file_path)) file_path = path.abspath(unquote(file_path))
if not file_path.startswith(path.abspath(unquote(root_path))): if not file_path.startswith(path.abspath(unquote(root_path))):
raise FileNotFound('File not found', raise FileNotFound(
path=file_or_directory, "File not found", path=file_or_directory, relative_url=file_uri
relative_url=file_uri) )
try: try:
headers = {} headers = {}
# Check if the client has been sent this file before # Check if the client has been sent this file before
@ -76,29 +86,31 @@ def register(app, uri, file_or_directory, pattern,
if use_modified_since: if use_modified_since:
stats = await stat(file_path) stats = await stat(file_path)
modified_since = strftime( modified_since = strftime(
'%a, %d %b %Y %H:%M:%S GMT', gmtime(stats.st_mtime)) "%a, %d %b %Y %H:%M:%S GMT", gmtime(stats.st_mtime)
if request.headers.get('If-Modified-Since') == modified_since: )
if request.headers.get("If-Modified-Since") == modified_since:
return HTTPResponse(status=304) return HTTPResponse(status=304)
headers['Last-Modified'] = modified_since headers["Last-Modified"] = modified_since
_range = None _range = None
if use_content_range: if use_content_range:
_range = None _range = None
if not stats: if not stats:
stats = await stat(file_path) stats = await stat(file_path)
headers['Accept-Ranges'] = 'bytes' headers["Accept-Ranges"] = "bytes"
headers['Content-Length'] = str(stats.st_size) headers["Content-Length"] = str(stats.st_size)
if request.method != 'HEAD': if request.method != "HEAD":
try: try:
_range = ContentRangeHandler(request, stats) _range = ContentRangeHandler(request, stats)
except HeaderNotFound: except HeaderNotFound:
pass pass
else: else:
del headers['Content-Length'] del headers["Content-Length"]
for key, value in _range.headers.items(): for key, value in _range.headers.items():
headers[key] = value headers[key] = value
headers['Content-Type'] = content_type \ headers["Content-Type"] = (
or guess_type(file_path)[0] or 'text/plain' content_type or guess_type(file_path)[0] or "text/plain"
if request.method == 'HEAD': )
if request.method == "HEAD":
return HTTPResponse(headers=headers) return HTTPResponse(headers=headers)
else: else:
if stream_large_files: if stream_large_files:
@ -110,19 +122,25 @@ def register(app, uri, file_or_directory, pattern,
if not stats: if not stats:
stats = await stat(file_path) stats = await stat(file_path)
if stats.st_size >= threshold: if stats.st_size >= threshold:
return await file_stream(file_path, headers=headers, return await file_stream(
_range=_range) file_path, headers=headers, _range=_range
)
return await file(file_path, headers=headers, _range=_range) return await file(file_path, headers=headers, _range=_range)
except ContentRangeError: except ContentRangeError:
raise raise
except Exception: except Exception:
raise FileNotFound('File not found', raise FileNotFound(
path=file_or_directory, "File not found", path=file_or_directory, relative_url=file_uri
relative_url=file_uri) )
# special prefix for static files # special prefix for static files
if not name.startswith('_static_'): if not name.startswith("_static_"):
name = '_static_{}'.format(name) name = "_static_{}".format(name)
app.route(uri, methods=['GET', 'HEAD'], name=name, host=host, app.route(
strict_slashes=strict_slashes)(_handler) uri,
methods=["GET", "HEAD"],
name=name,
host=host,
strict_slashes=strict_slashes,
)(_handler)

View File

@ -4,7 +4,7 @@ from sanic.exceptions import MethodNotSupported
from sanic.response import text from sanic.response import text
HOST = '127.0.0.1' HOST = "127.0.0.1"
PORT = 42101 PORT = 42101
@ -15,18 +15,22 @@ class SanicTestClient:
async def _local_request(self, method, uri, cookies=None, *args, **kwargs): async def _local_request(self, method, uri, cookies=None, *args, **kwargs):
import aiohttp import aiohttp
if uri.startswith(('http:', 'https:', 'ftp:', 'ftps://' '//')):
if uri.startswith(("http:", "https:", "ftp:", "ftps://" "//")):
url = uri url = uri
else: else:
url = 'http://{host}:{port}{uri}'.format( url = "http://{host}:{port}{uri}".format(
host=HOST, port=self.port, uri=uri) host=HOST, port=self.port, uri=uri
)
logger.info(url) logger.info(url)
conn = aiohttp.TCPConnector(verify_ssl=False) conn = aiohttp.TCPConnector(verify_ssl=False)
async with aiohttp.ClientSession( async with aiohttp.ClientSession(
cookies=cookies, connector=conn) as session: cookies=cookies, connector=conn
async with getattr( ) as session:
session, method.lower())(url, *args, **kwargs) as response: async with getattr(session, method.lower())(
url, *args, **kwargs
) as response:
try: try:
response.text = await response.text() response.text = await response.text()
except UnicodeDecodeError as e: except UnicodeDecodeError as e:
@ -34,50 +38,60 @@ class SanicTestClient:
try: try:
response.json = await response.json() response.json = await response.json()
except (JSONDecodeError, except (
UnicodeDecodeError, JSONDecodeError,
aiohttp.ClientResponseError): UnicodeDecodeError,
aiohttp.ClientResponseError,
):
response.json = None response.json = None
response.body = await response.read() response.body = await response.read()
return response return response
def _sanic_endpoint_test( def _sanic_endpoint_test(
self, method='get', uri='/', gather_request=True, self,
debug=False, server_kwargs={"auto_reload": False}, method="get",
*request_args, **request_kwargs): uri="/",
gather_request=True,
debug=False,
server_kwargs={"auto_reload": False},
*request_args,
**request_kwargs
):
results = [None, None] results = [None, None]
exceptions = [] exceptions = []
if gather_request: if gather_request:
def _collect_request(request): def _collect_request(request):
if results[0] is None: if results[0] is None:
results[0] = request results[0] = request
self.app.request_middleware.appendleft(_collect_request) self.app.request_middleware.appendleft(_collect_request)
@self.app.exception(MethodNotSupported) @self.app.exception(MethodNotSupported)
async def error_handler(request, exception): async def error_handler(request, exception):
if request.method in ['HEAD', 'PATCH', 'PUT', 'DELETE']: if request.method in ["HEAD", "PATCH", "PUT", "DELETE"]:
return text( return text(
'', exception.status_code, headers=exception.headers "", exception.status_code, headers=exception.headers
) )
else: else:
return self.app.error_handler.default(request, exception) return self.app.error_handler.default(request, exception)
@self.app.listener('after_server_start') @self.app.listener("after_server_start")
async def _collect_response(sanic, loop): async def _collect_response(sanic, loop):
try: try:
response = await self._local_request( response = await self._local_request(
method, uri, *request_args, method, uri, *request_args, **request_kwargs
**request_kwargs) )
results[-1] = response results[-1] = response
except Exception as e: except Exception as e:
logger.exception('Exception') logger.exception("Exception")
exceptions.append(e) exceptions.append(e)
self.app.stop() self.app.stop()
self.app.run(host=HOST, debug=debug, port=self.port, **server_kwargs) self.app.run(host=HOST, debug=debug, port=self.port, **server_kwargs)
self.app.listeners['after_server_start'].pop() self.app.listeners["after_server_start"].pop()
if exceptions: if exceptions:
raise ValueError("Exception during request: {}".format(exceptions)) raise ValueError("Exception during request: {}".format(exceptions))
@ -89,31 +103,34 @@ class SanicTestClient:
except BaseException: except BaseException:
raise ValueError( raise ValueError(
"Request and response object expected, got ({})".format( "Request and response object expected, got ({})".format(
results)) results
)
)
else: else:
try: try:
return results[-1] return results[-1]
except BaseException: except BaseException:
raise ValueError( raise ValueError(
"Request object expected, got ({})".format(results)) "Request object expected, got ({})".format(results)
)
def get(self, *args, **kwargs): def get(self, *args, **kwargs):
return self._sanic_endpoint_test('get', *args, **kwargs) return self._sanic_endpoint_test("get", *args, **kwargs)
def post(self, *args, **kwargs): def post(self, *args, **kwargs):
return self._sanic_endpoint_test('post', *args, **kwargs) return self._sanic_endpoint_test("post", *args, **kwargs)
def put(self, *args, **kwargs): def put(self, *args, **kwargs):
return self._sanic_endpoint_test('put', *args, **kwargs) return self._sanic_endpoint_test("put", *args, **kwargs)
def delete(self, *args, **kwargs): def delete(self, *args, **kwargs):
return self._sanic_endpoint_test('delete', *args, **kwargs) return self._sanic_endpoint_test("delete", *args, **kwargs)
def patch(self, *args, **kwargs): def patch(self, *args, **kwargs):
return self._sanic_endpoint_test('patch', *args, **kwargs) return self._sanic_endpoint_test("patch", *args, **kwargs)
def options(self, *args, **kwargs): def options(self, *args, **kwargs):
return self._sanic_endpoint_test('options', *args, **kwargs) return self._sanic_endpoint_test("options", *args, **kwargs)
def head(self, *args, **kwargs): def head(self, *args, **kwargs):
return self._sanic_endpoint_test('head', *args, **kwargs) return self._sanic_endpoint_test("head", *args, **kwargs)

View File

@ -48,6 +48,7 @@ class HTTPMethodView:
"""Return view function for use with the routing system, that """Return view function for use with the routing system, that
dispatches request to appropriate handler method. dispatches request to appropriate handler method.
""" """
def view(*args, **kwargs): def view(*args, **kwargs):
self = view.view_class(*class_args, **class_kwargs) self = view.view_class(*class_args, **class_kwargs)
return self.dispatch_request(*args, **kwargs) return self.dispatch_request(*args, **kwargs)
@ -94,11 +95,13 @@ class CompositionView:
for method in methods: for method in methods:
if method not in HTTP_METHODS: if method not in HTTP_METHODS:
raise InvalidUsage( raise InvalidUsage(
'{} is not a valid HTTP method.'.format(method)) "{} is not a valid HTTP method.".format(method)
)
if method in self.handlers: if method in self.handlers:
raise InvalidUsage( raise InvalidUsage(
'Method {} is already registered.'.format(method)) "Method {} is already registered.".format(method)
)
self.handlers[method] = handler self.handlers[method] = handler
def __call__(self, request, *args, **kwargs): def __call__(self, request, *args, **kwargs):

View File

@ -6,11 +6,16 @@ from websockets import ConnectionClosed # noqa
class WebSocketProtocol(HttpProtocol): class WebSocketProtocol(HttpProtocol):
def __init__(self, *args, websocket_timeout=10, def __init__(
websocket_max_size=None, self,
websocket_max_queue=None, *args,
websocket_read_limit=2 ** 16, websocket_timeout=10,
websocket_write_limit=2 ** 16, **kwargs): websocket_max_size=None,
websocket_max_queue=None,
websocket_read_limit=2 ** 16,
websocket_write_limit=2 ** 16,
**kwargs
):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.websocket = None self.websocket = None
self.websocket_timeout = websocket_timeout self.websocket_timeout = websocket_timeout
@ -63,24 +68,26 @@ class WebSocketProtocol(HttpProtocol):
key = handshake.check_request(request.headers) key = handshake.check_request(request.headers)
handshake.build_response(headers, key) handshake.build_response(headers, key)
except InvalidHandshake: except InvalidHandshake:
raise InvalidUsage('Invalid websocket request') raise InvalidUsage("Invalid websocket request")
subprotocol = None subprotocol = None
if subprotocols and 'Sec-Websocket-Protocol' in request.headers: if subprotocols and "Sec-Websocket-Protocol" in request.headers:
# select a subprotocol # select a subprotocol
client_subprotocols = [p.strip() for p in request.headers[ client_subprotocols = [
'Sec-Websocket-Protocol'].split(',')] p.strip()
for p in request.headers["Sec-Websocket-Protocol"].split(",")
]
for p in client_subprotocols: for p in client_subprotocols:
if p in subprotocols: if p in subprotocols:
subprotocol = p subprotocol = p
headers['Sec-Websocket-Protocol'] = subprotocol headers["Sec-Websocket-Protocol"] = subprotocol
break break
# write the 101 response back to the client # write the 101 response back to the client
rv = b'HTTP/1.1 101 Switching Protocols\r\n' rv = b"HTTP/1.1 101 Switching Protocols\r\n"
for k, v in headers.items(): for k, v in headers.items():
rv += k.encode('utf-8') + b': ' + v.encode('utf-8') + b'\r\n' rv += k.encode("utf-8") + b": " + v.encode("utf-8") + b"\r\n"
rv += b'\r\n' rv += b"\r\n"
request.transport.write(rv) request.transport.write(rv)
# hook up the websocket protocol # hook up the websocket protocol
@ -89,7 +96,7 @@ class WebSocketProtocol(HttpProtocol):
max_size=self.websocket_max_size, max_size=self.websocket_max_size,
max_queue=self.websocket_max_queue, max_queue=self.websocket_max_queue,
read_limit=self.websocket_read_limit, read_limit=self.websocket_read_limit,
write_limit=self.websocket_write_limit write_limit=self.websocket_write_limit,
) )
self.websocket.subprotocol = subprotocol self.websocket.subprotocol = subprotocol
self.websocket.connection_made(request.transport) self.websocket.connection_made(request.transport)

View File

@ -12,6 +12,7 @@ except ImportError:
try: try:
import uvloop import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError: except ImportError:
pass pass
@ -50,36 +51,43 @@ class GunicornWorker(base.Worker):
def run(self): def run(self):
is_debug = self.log.loglevel == logging.DEBUG is_debug = self.log.loglevel == logging.DEBUG
protocol = ( protocol = (
self.websocket_protocol if self.app.callable.websocket_enabled self.websocket_protocol
else self.http_protocol) if self.app.callable.websocket_enabled
else self.http_protocol
)
self._server_settings = self.app.callable._helper( self._server_settings = self.app.callable._helper(
loop=self.loop, loop=self.loop,
debug=is_debug, debug=is_debug,
protocol=protocol, protocol=protocol,
ssl=self.ssl_context, ssl=self.ssl_context,
run_async=True) run_async=True,
self._server_settings['signal'] = self.signal )
self._server_settings.pop('sock') self._server_settings["signal"] = self.signal
trigger_events(self._server_settings.get('before_start', []), self._server_settings.pop("sock")
self.loop) trigger_events(
self._server_settings['before_start'] = () self._server_settings.get("before_start", []), self.loop
)
self._server_settings["before_start"] = ()
self._runner = asyncio.ensure_future(self._run(), loop=self.loop) self._runner = asyncio.ensure_future(self._run(), loop=self.loop)
try: try:
self.loop.run_until_complete(self._runner) self.loop.run_until_complete(self._runner)
self.app.callable.is_running = True self.app.callable.is_running = True
trigger_events(self._server_settings.get('after_start', []), trigger_events(
self.loop) self._server_settings.get("after_start", []), self.loop
)
self.loop.run_until_complete(self._check_alive()) self.loop.run_until_complete(self._check_alive())
trigger_events(self._server_settings.get('before_stop', []), trigger_events(
self.loop) self._server_settings.get("before_stop", []), self.loop
)
self.loop.run_until_complete(self.close()) self.loop.run_until_complete(self.close())
except BaseException: except BaseException:
traceback.print_exc() traceback.print_exc()
finally: finally:
try: try:
trigger_events(self._server_settings.get('after_stop', []), trigger_events(
self.loop) self._server_settings.get("after_stop", []), self.loop
)
except BaseException: except BaseException:
traceback.print_exc() traceback.print_exc()
finally: finally:
@ -90,8 +98,11 @@ class GunicornWorker(base.Worker):
async def close(self): async def close(self):
if self.servers: if self.servers:
# stop accepting connections # stop accepting connections
self.log.info("Stopping server: %s, connections: %s", self.log.info(
self.pid, len(self.connections)) "Stopping server: %s, connections: %s",
self.pid,
len(self.connections),
)
for server in self.servers: for server in self.servers:
server.close() server.close()
await server.wait_closed() await server.wait_closed()
@ -105,8 +116,9 @@ class GunicornWorker(base.Worker):
# gracefully shutdown timeout # gracefully shutdown timeout
start_shutdown = 0 start_shutdown = 0
graceful_shutdown_timeout = self.cfg.graceful_timeout graceful_shutdown_timeout = self.cfg.graceful_timeout
while self.connections and \ while self.connections and (
(start_shutdown < graceful_shutdown_timeout): start_shutdown < graceful_shutdown_timeout
):
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
start_shutdown = start_shutdown + 0.1 start_shutdown = start_shutdown + 0.1
@ -115,9 +127,7 @@ class GunicornWorker(base.Worker):
coros = [] coros = []
for conn in self.connections: for conn in self.connections:
if hasattr(conn, "websocket") and conn.websocket: if hasattr(conn, "websocket") and conn.websocket:
coros.append( coros.append(conn.websocket.close_connection())
conn.websocket.close_connection()
)
else: else:
conn.close() conn.close()
_shutdown = asyncio.gather(*coros, loop=self.loop) _shutdown = asyncio.gather(*coros, loop=self.loop)
@ -148,8 +158,9 @@ class GunicornWorker(base.Worker):
) )
if self.max_requests and req_count > self.max_requests: if self.max_requests and req_count > self.max_requests:
self.alive = False self.alive = False
self.log.info("Max requests exceeded, shutting down: %s", self.log.info(
self) "Max requests exceeded, shutting down: %s", self
)
elif pid == os.getpid() and self.ppid != os.getppid(): elif pid == os.getpid() and self.ppid != os.getppid():
self.alive = False self.alive = False
self.log.info("Parent changed, shutting down: %s", self) self.log.info("Parent changed, shutting down: %s", self)
@ -175,23 +186,29 @@ class GunicornWorker(base.Worker):
def init_signals(self): def init_signals(self):
# Set up signals through the event loop API. # Set up signals through the event loop API.
self.loop.add_signal_handler(signal.SIGQUIT, self.handle_quit, self.loop.add_signal_handler(
signal.SIGQUIT, None) signal.SIGQUIT, self.handle_quit, signal.SIGQUIT, None
)
self.loop.add_signal_handler(signal.SIGTERM, self.handle_exit, self.loop.add_signal_handler(
signal.SIGTERM, None) signal.SIGTERM, self.handle_exit, signal.SIGTERM, None
)
self.loop.add_signal_handler(signal.SIGINT, self.handle_quit, self.loop.add_signal_handler(
signal.SIGINT, None) signal.SIGINT, self.handle_quit, signal.SIGINT, None
)
self.loop.add_signal_handler(signal.SIGWINCH, self.handle_winch, self.loop.add_signal_handler(
signal.SIGWINCH, None) signal.SIGWINCH, self.handle_winch, signal.SIGWINCH, None
)
self.loop.add_signal_handler(signal.SIGUSR1, self.handle_usr1, self.loop.add_signal_handler(
signal.SIGUSR1, None) signal.SIGUSR1, self.handle_usr1, signal.SIGUSR1, None
)
self.loop.add_signal_handler(signal.SIGABRT, self.handle_abort, self.loop.add_signal_handler(
signal.SIGABRT, None) signal.SIGABRT, self.handle_abort, signal.SIGABRT, None
)
# Don't let SIGTERM and SIGUSR1 disturb active requests # Don't let SIGTERM and SIGUSR1 disturb active requests
# by interrupting system calls # by interrupting system calls