From aa9bf04dfe22c9cc1aae63e5832c93f7b7b0296c Mon Sep 17 00:00:00 2001 From: Yun Xu Date: Sat, 13 Oct 2018 17:55:33 -0700 Subject: [PATCH] run black against sanic module --- pyproject.toml | 2 + sanic/__init__.py | 4 +- sanic/__main__.py | 51 +-- sanic/app.py | 655 +++++++++++++++++++++++++------------- sanic/blueprints.py | 297 ++++++++++++----- sanic/config.py | 19 +- sanic/constants.py | 2 +- sanic/cookies.py | 48 +-- sanic/exceptions.py | 44 +-- sanic/handlers.py | 59 ++-- sanic/helpers.py | 179 ++++++----- sanic/log.py | 35 +- sanic/reloader_helpers.py | 34 +- sanic/request.py | 163 ++++++---- sanic/response.py | 280 +++++++++------- sanic/router.py | 186 ++++++----- sanic/server.py | 355 +++++++++++++-------- sanic/static.py | 76 +++-- sanic/testing.py | 75 +++-- sanic/views.py | 7 +- sanic/websocket.py | 35 +- sanic/worker.py | 87 +++-- 22 files changed, 1657 insertions(+), 1036 deletions(-) create mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..a8f43fef --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +[tool.black] +line-length = 79 diff --git a/sanic/__init__.py b/sanic/__init__.py index 51c8268e..786b6647 100644 --- a/sanic/__init__.py +++ b/sanic/__init__.py @@ -1,6 +1,6 @@ from sanic.app import Sanic from sanic.blueprints import Blueprint -__version__ = '0.8.3' +__version__ = "0.8.3" -__all__ = ['Sanic', 'Blueprint'] +__all__ = ["Sanic", "Blueprint"] diff --git a/sanic/__main__.py b/sanic/__main__.py index ebea4ce9..ede743fd 100644 --- a/sanic/__main__.py +++ b/sanic/__main__.py @@ -5,16 +5,18 @@ from sanic.log import logger from sanic.app import Sanic if __name__ == "__main__": - parser = ArgumentParser(prog='sanic') - parser.add_argument('--host', dest='host', type=str, default='127.0.0.1') - parser.add_argument('--port', dest='port', type=int, default=8000) - parser.add_argument('--cert', dest='cert', type=str, - help='location of certificate for SSL') - parser.add_argument('--key', dest='key', type=str, - help='location of keyfile for SSL.') - parser.add_argument('--workers', dest='workers', type=int, default=1, ) - parser.add_argument('--debug', dest='debug', action="store_true") - parser.add_argument('module') + parser = ArgumentParser(prog="sanic") + parser.add_argument("--host", dest="host", type=str, default="127.0.0.1") + parser.add_argument("--port", dest="port", type=int, default=8000) + parser.add_argument( + "--cert", dest="cert", type=str, help="location of certificate for SSL" + ) + parser.add_argument( + "--key", dest="key", type=str, help="location of keyfile for SSL." + ) + parser.add_argument("--workers", dest="workers", type=int, default=1) + parser.add_argument("--debug", dest="debug", action="store_true") + parser.add_argument("module") args = parser.parse_args() try: @@ -25,20 +27,29 @@ if __name__ == "__main__": module = import_module(module_name) app = getattr(module, app_name, None) if not isinstance(app, Sanic): - raise ValueError("Module is not a Sanic app, it is a {}. " - "Perhaps you meant {}.app?" - .format(type(app).__name__, args.module)) + raise ValueError( + "Module is not a Sanic app, it is a {}. " + "Perhaps you meant {}.app?".format( + type(app).__name__, args.module + ) + ) if args.cert is not None or args.key is not None: - ssl = {'cert': args.cert, 'key': args.key} + ssl = {"cert": args.cert, "key": args.key} else: ssl = None - app.run(host=args.host, port=args.port, - workers=args.workers, debug=args.debug, ssl=ssl) + app.run( + host=args.host, + port=args.port, + workers=args.workers, + debug=args.debug, + ssl=ssl, + ) except ImportError as e: - logger.error("No module named {} found.\n" - " Example File: project/sanic_server.py -> app\n" - " Example Module: project.sanic_server.app" - .format(e.name)) + logger.error( + "No module named {} found.\n" + " Example File: project/sanic_server.py -> app\n" + " Example Module: project.sanic_server.app".format(e.name) + ) except ValueError as e: logger.exception("Failed to run app") diff --git a/sanic/app.py b/sanic/app.py index 5af21751..9fa2e4ce 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -27,10 +27,17 @@ import sanic.reloader_helpers as reloader_helpers class Sanic: - def __init__(self, name=None, router=None, error_handler=None, - load_env=True, request_class=None, - strict_slashes=False, log_config=None, - configure_logging=True): + def __init__( + self, + name=None, + router=None, + error_handler=None, + load_env=True, + request_class=None, + strict_slashes=False, + log_config=None, + configure_logging=True, + ): # Get name from previous stack frame if name is None: @@ -71,8 +78,9 @@ class Sanic: """ if not self.is_running: raise SanicException( - 'Loop can only be retrieved after the app has started ' - 'running. Not supported with `create_server` function') + "Loop can only be retrieved after the app has started " + "running. Not supported with `create_server` function" + ) return get_event_loop() # -------------------------------------------------------------------- # @@ -96,7 +104,8 @@ class Sanic: else: self.loop.create_task(task) except SanicException: - @self.listener('before_server_start') + + @self.listener("before_server_start") def run(app, loop): if callable(task): try: @@ -133,8 +142,16 @@ class Sanic: return self.listener(event)(listener) # Decorator - def route(self, uri, methods=frozenset({'GET'}), host=None, - strict_slashes=None, stream=False, version=None, name=None): + def route( + self, + uri, + methods=frozenset({"GET"}), + host=None, + strict_slashes=None, + stream=False, + version=None, + name=None, + ): """Decorate a function to be registered as a route :param uri: path of the URL @@ -149,8 +166,8 @@ class Sanic: # Fix case where the user did not prefix the URL with a / # and will probably get confused as to why it's not working - if not uri.startswith('/'): - uri = '/' + uri + if not uri.startswith("/"): + uri = "/" + uri if stream: self.is_request_stream = True @@ -164,63 +181,141 @@ class Sanic: if stream: handler.is_stream = stream - self.router.add(uri=uri, methods=methods, handler=handler, - host=host, strict_slashes=strict_slashes, - version=version, name=name) + self.router.add( + uri=uri, + methods=methods, + handler=handler, + host=host, + strict_slashes=strict_slashes, + version=version, + name=name, + ) return handler else: raise ValueError( - 'Required parameter `request` missing ' - 'in the {0}() route?'.format( - handler.__name__)) + "Required parameter `request` missing " + "in the {0}() route?".format(handler.__name__) + ) return response # Shorthand method decorators - def get(self, uri, host=None, strict_slashes=None, version=None, - name=None): - return self.route(uri, methods=frozenset({"GET"}), host=host, - strict_slashes=strict_slashes, version=version, - name=name) + def get( + self, uri, host=None, strict_slashes=None, version=None, name=None + ): + return self.route( + uri, + methods=frozenset({"GET"}), + host=host, + strict_slashes=strict_slashes, + version=version, + name=name, + ) - def post(self, uri, host=None, strict_slashes=None, stream=False, - version=None, name=None): - return self.route(uri, methods=frozenset({"POST"}), host=host, - strict_slashes=strict_slashes, stream=stream, - version=version, name=name) + def post( + self, + uri, + host=None, + strict_slashes=None, + stream=False, + version=None, + name=None, + ): + return self.route( + uri, + methods=frozenset({"POST"}), + host=host, + strict_slashes=strict_slashes, + stream=stream, + version=version, + name=name, + ) - def put(self, uri, host=None, strict_slashes=None, stream=False, - version=None, name=None): - return self.route(uri, methods=frozenset({"PUT"}), host=host, - strict_slashes=strict_slashes, stream=stream, - version=version, name=name) + def put( + self, + uri, + host=None, + strict_slashes=None, + stream=False, + version=None, + name=None, + ): + return self.route( + uri, + methods=frozenset({"PUT"}), + host=host, + strict_slashes=strict_slashes, + stream=stream, + version=version, + name=name, + ) - def head(self, uri, host=None, strict_slashes=None, version=None, - name=None): - return self.route(uri, methods=frozenset({"HEAD"}), host=host, - strict_slashes=strict_slashes, version=version, - name=name) + def head( + self, uri, host=None, strict_slashes=None, version=None, name=None + ): + return self.route( + uri, + methods=frozenset({"HEAD"}), + host=host, + strict_slashes=strict_slashes, + version=version, + name=name, + ) - def options(self, uri, host=None, strict_slashes=None, version=None, - name=None): - return self.route(uri, methods=frozenset({"OPTIONS"}), host=host, - strict_slashes=strict_slashes, version=version, - name=name) + def options( + self, uri, host=None, strict_slashes=None, version=None, name=None + ): + return self.route( + uri, + methods=frozenset({"OPTIONS"}), + host=host, + strict_slashes=strict_slashes, + version=version, + name=name, + ) - def patch(self, uri, host=None, strict_slashes=None, stream=False, - version=None, name=None): - return self.route(uri, methods=frozenset({"PATCH"}), host=host, - strict_slashes=strict_slashes, stream=stream, - version=version, name=name) + def patch( + self, + uri, + host=None, + strict_slashes=None, + stream=False, + version=None, + name=None, + ): + return self.route( + uri, + methods=frozenset({"PATCH"}), + host=host, + strict_slashes=strict_slashes, + stream=stream, + version=version, + name=name, + ) - def delete(self, uri, host=None, strict_slashes=None, version=None, - name=None): - return self.route(uri, methods=frozenset({"DELETE"}), host=host, - strict_slashes=strict_slashes, version=version, - name=name) + def delete( + self, uri, host=None, strict_slashes=None, version=None, name=None + ): + return self.route( + uri, + methods=frozenset({"DELETE"}), + host=host, + strict_slashes=strict_slashes, + version=version, + name=name, + ) - def add_route(self, handler, uri, methods=frozenset({'GET'}), host=None, - strict_slashes=None, version=None, name=None, stream=False): + def add_route( + self, + handler, + uri, + methods=frozenset({"GET"}), + host=None, + strict_slashes=None, + version=None, + name=None, + stream=False, + ): """A helper method to register class instance or functions as a handler to the application url routes. @@ -237,35 +332,42 @@ class Sanic: :return: function or class instance """ # Handle HTTPMethodView differently - if hasattr(handler, 'view_class'): + if hasattr(handler, "view_class"): methods = set() for method in HTTP_METHODS: _handler = getattr(handler.view_class, method.lower(), None) if _handler: methods.add(method) - if hasattr(_handler, 'is_stream'): + if hasattr(_handler, "is_stream"): stream = True # handle composition view differently if isinstance(handler, CompositionView): methods = handler.handlers.keys() for _handler in handler.handlers.values(): - if hasattr(_handler, 'is_stream'): + if hasattr(_handler, "is_stream"): stream = True break if strict_slashes is None: strict_slashes = self.strict_slashes - self.route(uri=uri, methods=methods, host=host, - strict_slashes=strict_slashes, stream=stream, - version=version, name=name)(handler) + self.route( + uri=uri, + methods=methods, + host=host, + strict_slashes=strict_slashes, + stream=stream, + version=version, + name=name, + )(handler) return handler # Decorator - def websocket(self, uri, host=None, strict_slashes=None, - subprotocols=None, name=None): + def websocket( + self, uri, host=None, strict_slashes=None, subprotocols=None, name=None + ): """Decorate a function to be registered as a websocket route :param uri: path of the URL :param subprotocols: optional list of strings with the supported @@ -277,8 +379,8 @@ class Sanic: # Fix case where the user did not prefix the URL with a / # and will probably get confused as to why it's not working - if not uri.startswith('/'): - uri = '/' + uri + if not uri.startswith("/"): + uri = "/" + uri if strict_slashes is None: strict_slashes = self.strict_slashes @@ -307,21 +409,38 @@ class Sanic: self.websocket_tasks.remove(fut) await ws.close() - self.router.add(uri=uri, handler=websocket_handler, - methods=frozenset({'GET'}), host=host, - strict_slashes=strict_slashes, name=name) + self.router.add( + uri=uri, + handler=websocket_handler, + methods=frozenset({"GET"}), + host=host, + strict_slashes=strict_slashes, + name=name, + ) return handler return response - def add_websocket_route(self, handler, uri, host=None, - strict_slashes=None, subprotocols=None, name=None): + def add_websocket_route( + self, + handler, + uri, + host=None, + strict_slashes=None, + subprotocols=None, + name=None, + ): """A helper method to register a function as a websocket route.""" if strict_slashes is None: strict_slashes = self.strict_slashes - return self.websocket(uri, host=host, strict_slashes=strict_slashes, - subprotocols=subprotocols, name=name)(handler) + return self.websocket( + uri, + host=host, + strict_slashes=strict_slashes, + subprotocols=subprotocols, + name=name, + )(handler) def enable_websocket(self, enable=True): """Enable or disable the support for websocket. @@ -332,7 +451,7 @@ class Sanic: if not self.websocket_enabled: # if the server is stopped, we want to cancel any ongoing # websocket tasks, to allow the server to exit promptly - @self.listener('before_server_stop') + @self.listener("before_server_stop") def cancel_websocket_tasks(app, loop): for task in self.websocket_tasks: task.cancel() @@ -361,10 +480,10 @@ class Sanic: return response - def register_middleware(self, middleware, attach_to='request'): - if attach_to == 'request': + def register_middleware(self, middleware, attach_to="request"): + if attach_to == "request": self.request_middleware.append(middleware) - if attach_to == 'response': + if attach_to == "response": self.response_middleware.appendleft(middleware) return middleware @@ -379,21 +498,40 @@ class Sanic: return self.register_middleware(middleware_or_request) else: - return partial(self.register_middleware, - attach_to=middleware_or_request) + return partial( + self.register_middleware, attach_to=middleware_or_request + ) # Static Files - def static(self, uri, file_or_directory, pattern=r'/?.+', - use_modified_since=True, use_content_range=False, - stream_large_files=False, name='static', host=None, - strict_slashes=None, content_type=None): + def static( + self, + uri, + file_or_directory, + pattern=r"/?.+", + use_modified_since=True, + use_content_range=False, + stream_large_files=False, + name="static", + host=None, + strict_slashes=None, + content_type=None, + ): """Register a root to serve files from. The input can either be a file or a directory. See """ - static_register(self, uri, file_or_directory, pattern, - use_modified_since, use_content_range, - stream_large_files, name, host, strict_slashes, - content_type) + static_register( + self, + uri, + file_or_directory, + pattern, + use_modified_since, + use_content_range, + stream_large_files, + name, + host, + strict_slashes, + content_type, + ) def blueprint(self, blueprint, **options): """Register a blueprint on the application. @@ -407,10 +545,10 @@ class Sanic: self.blueprint(item, **options) return if blueprint.name in self.blueprints: - assert self.blueprints[blueprint.name] is blueprint, \ - 'A blueprint with the name "%s" is already registered. ' \ - 'Blueprint names must be unique.' % \ - (blueprint.name,) + assert self.blueprints[blueprint.name] is blueprint, ( + 'A blueprint with the name "%s" is already registered. ' + "Blueprint names must be unique." % (blueprint.name,) + ) else: self.blueprints[blueprint.name] = blueprint self._blueprint_order.append(blueprint) @@ -419,11 +557,13 @@ class Sanic: def register_blueprint(self, *args, **kwargs): # TODO: deprecate 1.0 if self.debug: - warnings.simplefilter('default') - warnings.warn("Use of register_blueprint will be deprecated in " - "version 1.0. Please use the blueprint method" - " instead", - DeprecationWarning) + warnings.simplefilter("default") + warnings.warn( + "Use of register_blueprint will be deprecated in " + "version 1.0. Please use the blueprint method" + " instead", + DeprecationWarning, + ) return self.blueprint(*args, **kwargs) def url_for(self, view_name: str, **kwargs): @@ -449,67 +589,66 @@ class Sanic: # find the route by the supplied view name kw = {} # special static files url_for - if view_name == 'static': - kw.update(name=kwargs.pop('name', 'static')) - elif view_name.endswith('.static'): # blueprint.static - kwargs.pop('name', None) + if view_name == "static": + kw.update(name=kwargs.pop("name", "static")) + elif view_name.endswith(".static"): # blueprint.static + kwargs.pop("name", None) kw.update(name=view_name) uri, route = self.router.find_route_by_view_name(view_name, **kw) if not (uri and route): - raise URLBuildError('Endpoint with name `{}` was not found'.format( - view_name)) + raise URLBuildError( + "Endpoint with name `{}` was not found".format(view_name) + ) - if view_name == 'static' or view_name.endswith('.static'): - filename = kwargs.pop('filename', None) + if view_name == "static" or view_name.endswith(".static"): + filename = kwargs.pop("filename", None) # it's static folder - if '?@[]{}' +_UnescapedChars = _LegalChars + " ()/<=>?@[]{}" -_Translator = {n: '\\%03o' % n - for n in set(range(256)) - set(map(ord, _UnescapedChars))} -_Translator.update({ - ord('"'): '\\"', - ord('\\'): '\\\\', -}) +_Translator = { + n: "\\%03o" % n for n in set(range(256)) - set(map(ord, _UnescapedChars)) +} +_Translator.update({ord('"'): '\\"', ord("\\"): "\\\\"}) def _quote(str): @@ -30,7 +28,7 @@ def _quote(str): return '"' + str.translate(_Translator) + '"' -_is_legal_key = re.compile('[%s]+' % re.escape(_LegalChars)).fullmatch +_is_legal_key = re.compile("[%s]+" % re.escape(_LegalChars)).fullmatch # ------------------------------------------------------------ # # Custom SimpleCookie @@ -53,7 +51,7 @@ class CookieJar(dict): # If this cookie doesn't exist, add it to the header keys if not self.cookie_headers.get(key): cookie = Cookie(key, value) - cookie['path'] = '/' + cookie["path"] = "/" self.cookie_headers[key] = self.header_key self.headers.add(self.header_key, cookie) return super().__setitem__(key, cookie) @@ -62,8 +60,8 @@ class CookieJar(dict): def __delitem__(self, key): if key not in self.cookie_headers: - self[key] = '' - self[key]['max-age'] = 0 + self[key] = "" + self[key]["max-age"] = 0 else: cookie_header = self.cookie_headers[key] # remove it from header @@ -77,6 +75,7 @@ class CookieJar(dict): class Cookie(dict): """A stripped down version of Morsel from SimpleCookie #gottagofast""" + _keys = { "expires": "expires", "path": "Path", @@ -88,7 +87,7 @@ class Cookie(dict): "version": "Version", "samesite": "SameSite", } - _flags = {'secure', 'httponly'} + _flags = {"secure", "httponly"} def __init__(self, key, value): if key in self._keys: @@ -106,24 +105,27 @@ class Cookie(dict): return super().__setitem__(key, value) def encode(self, encoding): - output = ['%s=%s' % (self.key, _quote(self.value))] + output = ["%s=%s" % (self.key, _quote(self.value))] for key, value in self.items(): - if key == 'max-age': + if key == "max-age": try: - output.append('%s=%d' % (self._keys[key], value)) + output.append("%s=%d" % (self._keys[key], value)) except TypeError: - output.append('%s=%s' % (self._keys[key], value)) - elif key == 'expires': + output.append("%s=%s" % (self._keys[key], value)) + elif key == "expires": try: - output.append('%s=%s' % ( - self._keys[key], - value.strftime("%a, %d-%b-%Y %T GMT") - )) + output.append( + "%s=%s" + % ( + self._keys[key], + value.strftime("%a, %d-%b-%Y %T GMT"), + ) + ) except AttributeError: - output.append('%s=%s' % (self._keys[key], value)) + output.append("%s=%s" % (self._keys[key], value)) elif key in self._flags and self[key]: output.append(self._keys[key]) else: - output.append('%s=%s' % (self._keys[key], value)) + output.append("%s=%s" % (self._keys[key], value)) return "; ".join(output).encode(encoding) diff --git a/sanic/exceptions.py b/sanic/exceptions.py index b535b38f..6e9323a9 100644 --- a/sanic/exceptions.py +++ b/sanic/exceptions.py @@ -1,6 +1,6 @@ from sanic.helpers import STATUS_CODES -TRACEBACK_STYLE = ''' +TRACEBACK_STYLE = """ -''' +""" -TRACEBACK_WRAPPER_HTML = ''' +TRACEBACK_WRAPPER_HTML = """ {style} @@ -78,27 +78,27 @@ TRACEBACK_WRAPPER_HTML = ''' -''' +""" -TRACEBACK_WRAPPER_INNER_HTML = ''' +TRACEBACK_WRAPPER_INNER_HTML = """

{exc_name}

{exc_value}

Traceback (most recent call last):

{frame_html}
-''' +""" -TRACEBACK_BORDER = ''' +TRACEBACK_BORDER = """
The above exception was the direct cause of the following exception:
-''' +""" -TRACEBACK_LINE_HTML = ''' +TRACEBACK_LINE_HTML = """

File {0.filename}, line {0.lineno}, @@ -106,15 +106,15 @@ TRACEBACK_LINE_HTML = '''

{0.line}

-''' +""" -INTERNAL_SERVER_ERROR_HTML = ''' +INTERNAL_SERVER_ERROR_HTML = """

Internal Server Error

The server encountered an internal error and cannot complete your request.

-''' +""" _sanic_exceptions = {} @@ -124,15 +124,16 @@ def add_status_code(code): """ Decorator used for adding exceptions to _sanic_exceptions. """ + def class_decorator(cls): cls.status_code = code _sanic_exceptions[code] = cls return cls + return class_decorator class SanicException(Exception): - def __init__(self, message, status_code=None): super().__init__(message) @@ -156,8 +157,8 @@ class MethodNotSupported(SanicException): super().__init__(message) self.headers = dict() self.headers["Allow"] = ", ".join(allowed_methods) - if method in ['HEAD', 'PATCH', 'PUT', 'DELETE']: - self.headers['Content-Length'] = 0 + if method in ["HEAD", "PATCH", "PUT", "DELETE"]: + self.headers["Content-Length"] = 0 @add_status_code(500) @@ -169,6 +170,7 @@ class ServerError(SanicException): class ServiceUnavailable(SanicException): """The server is currently unavailable (because it is overloaded or down for maintenance). Generally, this is a temporary state.""" + pass @@ -192,6 +194,7 @@ class RequestTimeout(SanicException): the connection. The socket connection has actually been lost - the Web server has 'timed out' on that particular socket connection. """ + pass @@ -209,8 +212,8 @@ class ContentRangeError(SanicException): def __init__(self, message, content_range): super().__init__(message) self.headers = { - 'Content-Type': 'text/plain', - "Content-Range": "bytes */%s" % (content_range.total,) + "Content-Type": "text/plain", + "Content-Range": "bytes */%s" % (content_range.total,), } @@ -225,7 +228,7 @@ class InvalidRangeType(ContentRangeError): class PyFileError(Exception): def __init__(self, file): - super().__init__('could not execute config file %s', file) + super().__init__("could not execute config file %s", file) @add_status_code(401) @@ -263,13 +266,14 @@ class Unauthorized(SanicException): scheme="Bearer", realm="Restricted Area") """ + def __init__(self, message, status_code=None, scheme=None, **kwargs): super().__init__(message, status_code) # if auth-scheme is specified, set "WWW-Authenticate" header if scheme is not None: values = ['{!s}="{!s}"'.format(k, v) for k, v in kwargs.items()] - challenge = ', '.join(values) + challenge = ", ".join(values) self.headers = { "WWW-Authenticate": "{} {}".format(scheme, challenge).rstrip() @@ -288,6 +292,6 @@ def abort(status_code, message=None): if message is None: message = STATUS_CODES.get(status_code) # These are stored as bytes in the STATUS_CODES dict - message = message.decode('utf8') + message = message.decode("utf8") sanic_exception = _sanic_exceptions.get(status_code, SanicException) raise sanic_exception(message=message, status_code=status_code) diff --git a/sanic/handlers.py b/sanic/handlers.py index 30923e27..ec6de82e 100644 --- a/sanic/handlers.py +++ b/sanic/handlers.py @@ -11,7 +11,8 @@ from sanic.exceptions import ( TRACEBACK_STYLE, TRACEBACK_WRAPPER_HTML, TRACEBACK_WRAPPER_INNER_HTML, - TRACEBACK_BORDER) + TRACEBACK_BORDER, +) from sanic.log import logger from sanic.response import text, html @@ -36,7 +37,8 @@ class ErrorHandler: return TRACEBACK_WRAPPER_INNER_HTML.format( exc_name=exception.__class__.__name__, exc_value=exception, - frame_html=''.join(frame_html)) + frame_html="".join(frame_html), + ) def _render_traceback_html(self, exception, request): exc_type, exc_value, tb = sys.exc_info() @@ -51,7 +53,8 @@ class ErrorHandler: exc_name=exception.__class__.__name__, exc_value=exception, inner_html=TRACEBACK_BORDER.join(reversed(exceptions)), - path=request.path) + path=request.path, + ) def add(self, exception, handler): self.handlers.append((exception, handler)) @@ -88,17 +91,18 @@ class ErrorHandler: url = repr(request.url) except AttributeError: url = "unknown" - response_message = ('Exception raised in exception handler ' - '"%s" for uri: %s') + response_message = ( + "Exception raised in exception handler " '"%s" for uri: %s' + ) logger.exception(response_message, handler.__name__, url) if self.debug: return text(response_message % (handler.__name__, url), 500) else: - return text('An error occurred while handling an error', 500) + return text("An error occurred while handling an error", 500) return response - def log(self, message, level='error'): + def log(self, message, level="error"): """ Deprecated, do not use. """ @@ -110,14 +114,14 @@ class ErrorHandler: except AttributeError: url = "unknown" - response_message = ('Exception occurred while handling uri: %s') + response_message = "Exception occurred while handling uri: %s" logger.exception(response_message, url) if issubclass(type(exception), SanicException): return text( - 'Error: {}'.format(exception), - status=getattr(exception, 'status_code', 500), - headers=getattr(exception, 'headers', dict()) + "Error: {}".format(exception), + status=getattr(exception, "status_code", 500), + headers=getattr(exception, "headers", dict()), ) elif self.debug: html_output = self._render_traceback_html(exception, request) @@ -129,32 +133,37 @@ class ErrorHandler: class ContentRangeHandler: """Class responsible for parsing request header""" - __slots__ = ('start', 'end', 'size', 'total', 'headers') + + __slots__ = ("start", "end", "size", "total", "headers") def __init__(self, request, stats): self.total = stats.st_size - _range = request.headers.get('Range') + _range = request.headers.get("Range") if _range is None: - raise HeaderNotFound('Range Header Not Found') - unit, _, value = tuple(map(str.strip, _range.partition('='))) - if unit != 'bytes': + raise HeaderNotFound("Range Header Not Found") + unit, _, value = tuple(map(str.strip, _range.partition("="))) + if unit != "bytes": raise InvalidRangeType( - '%s is not a valid Range Type' % (unit,), self) - start_b, _, end_b = tuple(map(str.strip, value.partition('-'))) + "%s is not a valid Range Type" % (unit,), self + ) + start_b, _, end_b = tuple(map(str.strip, value.partition("-"))) try: self.start = int(start_b) if start_b else None except ValueError: raise ContentRangeError( - '\'%s\' is invalid for Content Range' % (start_b,), self) + "'%s' is invalid for Content Range" % (start_b,), self + ) try: self.end = int(end_b) if end_b else None except ValueError: raise ContentRangeError( - '\'%s\' is invalid for Content Range' % (end_b,), self) + "'%s' is invalid for Content Range" % (end_b,), self + ) if self.end is None: if self.start is None: raise ContentRangeError( - 'Invalid for Content Range parameters', self) + "Invalid for Content Range parameters", self + ) else: # this case represents `Content-Range: bytes 5-` self.end = self.total @@ -165,11 +174,13 @@ class ContentRangeHandler: self.end = self.total if self.start >= self.end: raise ContentRangeError( - 'Invalid for Content Range parameters', self) + "Invalid for Content Range parameters", self + ) self.size = self.end - self.start self.headers = { - 'Content-Range': "bytes %s-%s/%s" % ( - self.start, self.end, self.total)} + "Content-Range": "bytes %s-%s/%s" + % (self.start, self.end, self.total) + } def __bool__(self): return self.size > 0 diff --git a/sanic/helpers.py b/sanic/helpers.py index 482253b3..d992f92e 100644 --- a/sanic/helpers.py +++ b/sanic/helpers.py @@ -1,93 +1,97 @@ """Defines basics of HTTP standard.""" STATUS_CODES = { - 100: b'Continue', - 101: b'Switching Protocols', - 102: b'Processing', - 200: b'OK', - 201: b'Created', - 202: b'Accepted', - 203: b'Non-Authoritative Information', - 204: b'No Content', - 205: b'Reset Content', - 206: b'Partial Content', - 207: b'Multi-Status', - 208: b'Already Reported', - 226: b'IM Used', - 300: b'Multiple Choices', - 301: b'Moved Permanently', - 302: b'Found', - 303: b'See Other', - 304: b'Not Modified', - 305: b'Use Proxy', - 307: b'Temporary Redirect', - 308: b'Permanent Redirect', - 400: b'Bad Request', - 401: b'Unauthorized', - 402: b'Payment Required', - 403: b'Forbidden', - 404: b'Not Found', - 405: b'Method Not Allowed', - 406: b'Not Acceptable', - 407: b'Proxy Authentication Required', - 408: b'Request Timeout', - 409: b'Conflict', - 410: b'Gone', - 411: b'Length Required', - 412: b'Precondition Failed', - 413: b'Request Entity Too Large', - 414: b'Request-URI Too Long', - 415: b'Unsupported Media Type', - 416: b'Requested Range Not Satisfiable', - 417: b'Expectation Failed', - 418: b'I\'m a teapot', - 422: b'Unprocessable Entity', - 423: b'Locked', - 424: b'Failed Dependency', - 426: b'Upgrade Required', - 428: b'Precondition Required', - 429: b'Too Many Requests', - 431: b'Request Header Fields Too Large', - 451: b'Unavailable For Legal Reasons', - 500: b'Internal Server Error', - 501: b'Not Implemented', - 502: b'Bad Gateway', - 503: b'Service Unavailable', - 504: b'Gateway Timeout', - 505: b'HTTP Version Not Supported', - 506: b'Variant Also Negotiates', - 507: b'Insufficient Storage', - 508: b'Loop Detected', - 510: b'Not Extended', - 511: b'Network Authentication Required' + 100: b"Continue", + 101: b"Switching Protocols", + 102: b"Processing", + 200: b"OK", + 201: b"Created", + 202: b"Accepted", + 203: b"Non-Authoritative Information", + 204: b"No Content", + 205: b"Reset Content", + 206: b"Partial Content", + 207: b"Multi-Status", + 208: b"Already Reported", + 226: b"IM Used", + 300: b"Multiple Choices", + 301: b"Moved Permanently", + 302: b"Found", + 303: b"See Other", + 304: b"Not Modified", + 305: b"Use Proxy", + 307: b"Temporary Redirect", + 308: b"Permanent Redirect", + 400: b"Bad Request", + 401: b"Unauthorized", + 402: b"Payment Required", + 403: b"Forbidden", + 404: b"Not Found", + 405: b"Method Not Allowed", + 406: b"Not Acceptable", + 407: b"Proxy Authentication Required", + 408: b"Request Timeout", + 409: b"Conflict", + 410: b"Gone", + 411: b"Length Required", + 412: b"Precondition Failed", + 413: b"Request Entity Too Large", + 414: b"Request-URI Too Long", + 415: b"Unsupported Media Type", + 416: b"Requested Range Not Satisfiable", + 417: b"Expectation Failed", + 418: b"I'm a teapot", + 422: b"Unprocessable Entity", + 423: b"Locked", + 424: b"Failed Dependency", + 426: b"Upgrade Required", + 428: b"Precondition Required", + 429: b"Too Many Requests", + 431: b"Request Header Fields Too Large", + 451: b"Unavailable For Legal Reasons", + 500: b"Internal Server Error", + 501: b"Not Implemented", + 502: b"Bad Gateway", + 503: b"Service Unavailable", + 504: b"Gateway Timeout", + 505: b"HTTP Version Not Supported", + 506: b"Variant Also Negotiates", + 507: b"Insufficient Storage", + 508: b"Loop Detected", + 510: b"Not Extended", + 511: b"Network Authentication Required", } # According to https://tools.ietf.org/html/rfc2616#section-7.1 -_ENTITY_HEADERS = frozenset([ - 'allow', - 'content-encoding', - 'content-language', - 'content-length', - 'content-location', - 'content-md5', - 'content-range', - 'content-type', - 'expires', - 'last-modified', - 'extension-header' -]) +_ENTITY_HEADERS = frozenset( + [ + "allow", + "content-encoding", + "content-language", + "content-length", + "content-location", + "content-md5", + "content-range", + "content-type", + "expires", + "last-modified", + "extension-header", + ] +) # According to https://tools.ietf.org/html/rfc2616#section-13.5.1 -_HOP_BY_HOP_HEADERS = frozenset([ - 'connection', - 'keep-alive', - 'proxy-authenticate', - 'proxy-authorization', - 'te', - 'trailers', - 'transfer-encoding', - 'upgrade' -]) +_HOP_BY_HOP_HEADERS = frozenset( + [ + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "transfer-encoding", + "upgrade", + ] +) def has_message_body(status): @@ -110,8 +114,7 @@ def is_hop_by_hop_header(header): return header.lower() in _HOP_BY_HOP_HEADERS -def remove_entity_headers(headers, - allowed=('content-location', 'expires')): +def remove_entity_headers(headers, allowed=("content-location", "expires")): """ Removes all the entity headers present in the headers given. According to RFC 2616 Section 10.3.5, @@ -122,7 +125,9 @@ def remove_entity_headers(headers, returns the headers without the entity headers """ allowed = set([h.lower() for h in allowed]) - headers = {header: value for header, value in headers.items() - if not is_entity_header(header) - and header.lower() not in allowed} + headers = { + header: value + for header, value in headers.items() + if not is_entity_header(header) and header.lower() not in allowed + } return headers diff --git a/sanic/log.py b/sanic/log.py index 67381ce6..cb8ca524 100644 --- a/sanic/log.py +++ b/sanic/log.py @@ -5,59 +5,54 @@ import sys LOGGING_CONFIG_DEFAULTS = dict( version=1, disable_existing_loggers=False, - loggers={ - "root": { - "level": "INFO", - "handlers": ["console"] - }, + "root": {"level": "INFO", "handlers": ["console"]}, "sanic.error": { "level": "INFO", "handlers": ["error_console"], "propagate": True, - "qualname": "sanic.error" + "qualname": "sanic.error", }, - "sanic.access": { "level": "INFO", "handlers": ["access_console"], "propagate": True, - "qualname": "sanic.access" - } + "qualname": "sanic.access", + }, }, handlers={ "console": { "class": "logging.StreamHandler", "formatter": "generic", - "stream": sys.stdout + "stream": sys.stdout, }, "error_console": { "class": "logging.StreamHandler", "formatter": "generic", - "stream": sys.stderr + "stream": sys.stderr, }, "access_console": { "class": "logging.StreamHandler", "formatter": "access", - "stream": sys.stdout + "stream": sys.stdout, }, }, formatters={ "generic": { "format": "%(asctime)s [%(process)d] [%(levelname)s] %(message)s", "datefmt": "[%Y-%m-%d %H:%M:%S %z]", - "class": "logging.Formatter" + "class": "logging.Formatter", }, "access": { - "format": "%(asctime)s - (%(name)s)[%(levelname)s][%(host)s]: " + - "%(request)s %(message)s %(status)d %(byte)d", + "format": "%(asctime)s - (%(name)s)[%(levelname)s][%(host)s]: " + + "%(request)s %(message)s %(status)d %(byte)d", "datefmt": "[%Y-%m-%d %H:%M:%S %z]", - "class": "logging.Formatter" + "class": "logging.Formatter", }, - } + }, ) -logger = logging.getLogger('sanic.root') -error_logger = logging.getLogger('sanic.error') -access_logger = logging.getLogger('sanic.access') +logger = logging.getLogger("sanic.root") +error_logger = logging.getLogger("sanic.error") +access_logger = logging.getLogger("sanic.access") diff --git a/sanic/reloader_helpers.py b/sanic/reloader_helpers.py index e0cb42e0..70bc2fcc 100644 --- a/sanic/reloader_helpers.py +++ b/sanic/reloader_helpers.py @@ -18,7 +18,7 @@ def _iter_module_files(): for module in list(sys.modules.values()): if module is None: continue - filename = getattr(module, '__file__', None) + filename = getattr(module, "__file__", None) if filename: old = None while not os.path.isfile(filename): @@ -27,7 +27,7 @@ def _iter_module_files(): if filename == old: break else: - if filename[-4:] in ('.pyc', '.pyo'): + if filename[-4:] in (".pyc", ".pyo"): filename = filename[:-1] yield filename @@ -45,11 +45,13 @@ def restart_with_reloader(): """ args = _get_args_for_reloading() new_environ = os.environ.copy() - new_environ['SANIC_SERVER_RUNNING'] = 'true' - cmd = ' '.join(args) + new_environ["SANIC_SERVER_RUNNING"] = "true" + cmd = " ".join(args) worker_process = Process( - target=subprocess.call, args=(cmd,), - kwargs=dict(shell=True, env=new_environ)) + target=subprocess.call, + args=(cmd,), + kwargs=dict(shell=True, env=new_environ), + ) worker_process.start() return worker_process @@ -67,8 +69,10 @@ def kill_process_children_unix(pid): children_list_pid = children_list_file.read().split() for child_pid in children_list_pid: - children_proc_path = "/proc/%s/task/%s/children" % \ - (child_pid, child_pid) + children_proc_path = "/proc/%s/task/%s/children" % ( + child_pid, + child_pid, + ) if not os.path.isfile(children_proc_path): continue with open(children_proc_path) as children_list_file_2: @@ -90,7 +94,7 @@ def kill_process_children_osx(pid): :param pid: PID of parent process (process ID) :return: Nothing """ - subprocess.run(['pkill', '-P', str(pid)]) + subprocess.run(["pkill", "-P", str(pid)]) def kill_process_children(pid): @@ -99,12 +103,12 @@ def kill_process_children(pid): :param pid: PID of parent process (process ID) :return: Nothing """ - if sys.platform == 'darwin': + if sys.platform == "darwin": kill_process_children_osx(pid) - elif sys.platform == 'linux': + elif sys.platform == "linux": kill_process_children_unix(pid) else: - pass # should signal error here + pass # should signal error here def kill_program_completly(proc): @@ -127,9 +131,11 @@ def watchdog(sleep_interval): mtimes = {} worker_process = restart_with_reloader() signal.signal( - signal.SIGTERM, lambda *args: kill_program_completly(worker_process)) + signal.SIGTERM, lambda *args: kill_program_completly(worker_process) + ) signal.signal( - signal.SIGINT, lambda *args: kill_program_completly(worker_process)) + signal.SIGINT, lambda *args: kill_program_completly(worker_process) + ) while True: for filename in _iter_module_files(): try: diff --git a/sanic/request.py b/sanic/request.py index 70240207..fec43bc9 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -10,9 +10,11 @@ try: from ujson import loads as json_loads except ImportError: if sys.version_info[:2] == (3, 5): + def json_loads(data): # on Python 3.5 json.loads only supports str not bytes return json.loads(data.decode()) + else: json_loads = json.loads @@ -43,11 +45,28 @@ class RequestParameters(dict): class Request(dict): """Properties of an HTTP request such as URL, headers, etc.""" + __slots__ = ( - 'app', 'headers', 'version', 'method', '_cookies', 'transport', - 'body', 'parsed_json', 'parsed_args', 'parsed_form', 'parsed_files', - '_ip', '_parsed_url', 'uri_template', 'stream', '_remote_addr', - '_socket', '_port', '__weakref__', 'raw_url' + "app", + "headers", + "version", + "method", + "_cookies", + "transport", + "body", + "parsed_json", + "parsed_args", + "parsed_form", + "parsed_files", + "_ip", + "_parsed_url", + "uri_template", + "stream", + "_remote_addr", + "_socket", + "_port", + "__weakref__", + "raw_url", ) def __init__(self, url_bytes, headers, version, method, transport): @@ -73,10 +92,10 @@ class Request(dict): def __repr__(self): if self.method is None or not self.path: - return '<{0}>'.format(self.__class__.__name__) - return '<{0}: {1} {2}>'.format(self.__class__.__name__, - self.method, - self.path) + return "<{0}>".format(self.__class__.__name__) + return "<{0}: {1} {2}>".format( + self.__class__.__name__, self.method, self.path + ) def __bool__(self): if self.transport: @@ -106,8 +125,8 @@ class Request(dict): :return: token related to request """ - prefixes = ('Bearer', 'Token') - auth_header = self.headers.get('Authorization') + prefixes = ("Bearer", "Token") + auth_header = self.headers.get("Authorization") if auth_header is not None: for prefix in prefixes: @@ -122,17 +141,20 @@ class Request(dict): self.parsed_form = RequestParameters() self.parsed_files = RequestParameters() content_type = self.headers.get( - 'Content-Type', DEFAULT_HTTP_CONTENT_TYPE) + "Content-Type", DEFAULT_HTTP_CONTENT_TYPE + ) content_type, parameters = parse_header(content_type) try: - if content_type == 'application/x-www-form-urlencoded': + if content_type == "application/x-www-form-urlencoded": self.parsed_form = RequestParameters( - parse_qs(self.body.decode('utf-8'))) - elif content_type == 'multipart/form-data': + parse_qs(self.body.decode("utf-8")) + ) + elif content_type == "multipart/form-data": # TODO: Stream this instead of reading to/from memory - boundary = parameters['boundary'].encode('utf-8') - self.parsed_form, self.parsed_files = ( - parse_multipart_form(self.body, boundary)) + boundary = parameters["boundary"].encode("utf-8") + self.parsed_form, self.parsed_files = parse_multipart_form( + self.body, boundary + ) except Exception: error_logger.exception("Failed when parsing form") @@ -150,7 +172,8 @@ class Request(dict): if self.parsed_args is None: if self.query_string: self.parsed_args = RequestParameters( - parse_qs(self.query_string)) + parse_qs(self.query_string) + ) else: self.parsed_args = RequestParameters() return self.parsed_args @@ -162,37 +185,40 @@ class Request(dict): @property def cookies(self): if self._cookies is None: - cookie = self.headers.get('Cookie') + cookie = self.headers.get("Cookie") if cookie is not None: cookies = SimpleCookie() cookies.load(cookie) - self._cookies = {name: cookie.value - for name, cookie in cookies.items()} + self._cookies = { + name: cookie.value for name, cookie in cookies.items() + } else: self._cookies = {} return self._cookies @property def ip(self): - if not hasattr(self, '_socket'): + if not hasattr(self, "_socket"): self._get_address() return self._ip @property def port(self): - if not hasattr(self, '_socket'): + if not hasattr(self, "_socket"): self._get_address() return self._port @property def socket(self): - if not hasattr(self, '_socket'): + if not hasattr(self, "_socket"): self._get_address() return self._socket def _get_address(self): - self._socket = self.transport.get_extra_info('peername') or \ - (None, None) + self._socket = self.transport.get_extra_info("peername") or ( + None, + None, + ) self._ip = self._socket[0] self._port = self._socket[1] @@ -202,29 +228,31 @@ class Request(dict): :return: original client ip. """ - if not hasattr(self, '_remote_addr'): - forwarded_for = self.headers.get('X-Forwarded-For', '').split(',') + if not hasattr(self, "_remote_addr"): + forwarded_for = self.headers.get("X-Forwarded-For", "").split(",") remote_addrs = [ - addr for addr in [ - addr.strip() for addr in forwarded_for - ] if addr - ] + addr + for addr in [addr.strip() for addr in forwarded_for] + if addr + ] if len(remote_addrs) > 0: self._remote_addr = remote_addrs[0] else: - self._remote_addr = '' + self._remote_addr = "" return self._remote_addr @property def scheme(self): - if self.app.websocket_enabled \ - and self.headers.get('upgrade') == 'websocket': - scheme = 'ws' + if ( + self.app.websocket_enabled + and self.headers.get("upgrade") == "websocket" + ): + scheme = "ws" else: - scheme = 'http' + scheme = "http" - if self.transport.get_extra_info('sslcontext'): - scheme += 's' + if self.transport.get_extra_info("sslcontext"): + scheme += "s" return scheme @@ -232,11 +260,11 @@ class Request(dict): def host(self): # it appears that httptools doesn't return the host # so pull it from the headers - return self.headers.get('Host', '') + return self.headers.get("Host", "") @property def content_type(self): - return self.headers.get('Content-Type', DEFAULT_HTTP_CONTENT_TYPE) + return self.headers.get("Content-Type", DEFAULT_HTTP_CONTENT_TYPE) @property def match_info(self): @@ -245,27 +273,23 @@ class Request(dict): @property def path(self): - return self._parsed_url.path.decode('utf-8') + return self._parsed_url.path.decode("utf-8") @property def query_string(self): if self._parsed_url.query: - return self._parsed_url.query.decode('utf-8') + return self._parsed_url.query.decode("utf-8") else: - return '' + return "" @property def url(self): - return urlunparse(( - self.scheme, - self.host, - self.path, - None, - self.query_string, - None)) + return urlunparse( + (self.scheme, self.host, self.path, None, self.query_string, None) + ) -File = namedtuple('File', ['type', 'body', 'name']) +File = namedtuple("File", ["type", "body", "name"]) def parse_multipart_form(body, boundary): @@ -281,37 +305,38 @@ def parse_multipart_form(body, boundary): form_parts = body.split(boundary) for form_part in form_parts[1:-1]: file_name = None - content_type = 'text/plain' - content_charset = 'utf-8' + content_type = "text/plain" + content_charset = "utf-8" field_name = None line_index = 2 line_end_index = 0 while not line_end_index == -1: - line_end_index = form_part.find(b'\r\n', line_index) - form_line = form_part[line_index:line_end_index].decode('utf-8') + line_end_index = form_part.find(b"\r\n", line_index) + form_line = form_part[line_index:line_end_index].decode("utf-8") line_index = line_end_index + 2 if not form_line: break - colon_index = form_line.index(':') + colon_index = form_line.index(":") form_header_field = form_line[0:colon_index].lower() form_header_value, form_parameters = parse_header( - form_line[colon_index + 2:]) + form_line[colon_index + 2 :] + ) - if form_header_field == 'content-disposition': - file_name = form_parameters.get('filename') - field_name = form_parameters.get('name') - elif form_header_field == 'content-type': + if form_header_field == "content-disposition": + file_name = form_parameters.get("filename") + field_name = form_parameters.get("name") + elif form_header_field == "content-type": content_type = form_header_value - content_charset = form_parameters.get('charset', 'utf-8') + content_charset = form_parameters.get("charset", "utf-8") if field_name: post_data = form_part[line_index:-4] if file_name: - form_file = File(type=content_type, - name=file_name, - body=post_data) + form_file = File( + type=content_type, name=file_name, body=post_data + ) if field_name in files: files[field_name].append(form_file) else: @@ -323,7 +348,9 @@ def parse_multipart_form(body, boundary): else: fields[field_name] = [value] else: - logger.debug('Form-data field does not have a \'name\' parameter \ - in the Content-Disposition header') + logger.debug( + "Form-data field does not have a 'name' parameter \ + in the Content-Disposition header" + ) return fields, files diff --git a/sanic/response.py b/sanic/response.py index ed1df0f4..cf9aaba2 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -24,16 +24,18 @@ class BaseHTTPResponse: return str(data).encode() def _parse_headers(self): - headers = b'' + headers = b"" for name, value in self.headers.items(): try: - headers += ( - b'%b: %b\r\n' % ( - name.encode(), value.encode('utf-8'))) + headers += b"%b: %b\r\n" % ( + name.encode(), + value.encode("utf-8"), + ) except AttributeError: - headers += ( - b'%b: %b\r\n' % ( - str(name).encode(), str(value).encode('utf-8'))) + headers += b"%b: %b\r\n" % ( + str(name).encode(), + str(value).encode("utf-8"), + ) return headers @@ -46,12 +48,17 @@ class BaseHTTPResponse: class StreamingHTTPResponse(BaseHTTPResponse): __slots__ = ( - 'protocol', 'streaming_fn', 'status', - 'content_type', 'headers', '_cookies' + "protocol", + "streaming_fn", + "status", + "content_type", + "headers", + "_cookies", ) - def __init__(self, streaming_fn, status=200, headers=None, - content_type='text/plain'): + def __init__( + self, streaming_fn, status=200, headers=None, content_type="text/plain" + ): self.content_type = content_type self.streaming_fn = streaming_fn self.status = status @@ -66,61 +73,69 @@ class StreamingHTTPResponse(BaseHTTPResponse): if type(data) != bytes: data = self._encode_body(data) - self.protocol.push_data( - b"%x\r\n%b\r\n" % (len(data), data)) + self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data)) await self.protocol.drain() async def stream( - self, version="1.1", keep_alive=False, keep_alive_timeout=None): + self, version="1.1", keep_alive=False, keep_alive_timeout=None + ): """Streams headers, runs the `streaming_fn` callback that writes content to the response body, then finalizes the response body. """ headers = self.get_headers( - version, keep_alive=keep_alive, - keep_alive_timeout=keep_alive_timeout) + version, + keep_alive=keep_alive, + keep_alive_timeout=keep_alive_timeout, + ) self.protocol.push_data(headers) await self.protocol.drain() await self.streaming_fn(self) - self.protocol.push_data(b'0\r\n\r\n') + self.protocol.push_data(b"0\r\n\r\n") # no need to await drain here after this write, because it is the # very last thing we write and nothing needs to wait for it. def get_headers( - self, version="1.1", keep_alive=False, keep_alive_timeout=None): + self, version="1.1", keep_alive=False, keep_alive_timeout=None + ): # This is all returned in a kind-of funky way # We tried to make this as fast as possible in pure python - timeout_header = b'' + timeout_header = b"" if keep_alive and keep_alive_timeout is not None: - timeout_header = b'Keep-Alive: %d\r\n' % keep_alive_timeout + timeout_header = b"Keep-Alive: %d\r\n" % keep_alive_timeout - self.headers['Transfer-Encoding'] = 'chunked' - self.headers.pop('Content-Length', None) - self.headers['Content-Type'] = self.headers.get( - 'Content-Type', self.content_type) + self.headers["Transfer-Encoding"] = "chunked" + self.headers.pop("Content-Length", None) + self.headers["Content-Type"] = self.headers.get( + "Content-Type", self.content_type + ) headers = self._parse_headers() if self.status is 200: - status = b'OK' + status = b"OK" else: status = STATUS_CODES.get(self.status) - return (b'HTTP/%b %d %b\r\n' - b'%b' - b'%b\r\n') % ( - version.encode(), - self.status, - status, - timeout_header, - headers - ) + return (b"HTTP/%b %d %b\r\n" b"%b" b"%b\r\n") % ( + version.encode(), + self.status, + status, + timeout_header, + headers, + ) class HTTPResponse(BaseHTTPResponse): - __slots__ = ('body', 'status', 'content_type', 'headers', '_cookies') + __slots__ = ("body", "status", "content_type", "headers", "_cookies") - def __init__(self, body=None, status=200, headers=None, - content_type='text/plain', body_bytes=b''): + def __init__( + self, + body=None, + status=200, + headers=None, + content_type="text/plain", + body_bytes=b"", + ): self.content_type = content_type if body is not None: @@ -132,22 +147,23 @@ class HTTPResponse(BaseHTTPResponse): self.headers = CIMultiDict(headers or {}) self._cookies = None - def output( - self, version="1.1", keep_alive=False, keep_alive_timeout=None): + def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None): # This is all returned in a kind-of funky way # We tried to make this as fast as possible in pure python - timeout_header = b'' + timeout_header = b"" if keep_alive and keep_alive_timeout is not None: - timeout_header = b'Keep-Alive: %d\r\n' % keep_alive_timeout + timeout_header = b"Keep-Alive: %d\r\n" % keep_alive_timeout - body = b'' + body = b"" if has_message_body(self.status): body = self.body - self.headers['Content-Length'] = self.headers.get( - 'Content-Length', len(self.body)) + self.headers["Content-Length"] = self.headers.get( + "Content-Length", len(self.body) + ) - self.headers['Content-Type'] = self.headers.get( - 'Content-Type', self.content_type) + self.headers["Content-Type"] = self.headers.get( + "Content-Type", self.content_type + ) if self.status in (304, 412): self.headers = remove_entity_headers(self.headers) @@ -155,23 +171,21 @@ class HTTPResponse(BaseHTTPResponse): headers = self._parse_headers() if self.status is 200: - status = b'OK' + status = b"OK" else: - status = STATUS_CODES.get(self.status, b'UNKNOWN RESPONSE') + status = STATUS_CODES.get(self.status, b"UNKNOWN RESPONSE") - return (b'HTTP/%b %d %b\r\n' - b'Connection: %b\r\n' - b'%b' - b'%b\r\n' - b'%b') % ( - version.encode(), - self.status, - status, - b'keep-alive' if keep_alive else b'close', - timeout_header, - headers, - body - ) + return ( + b"HTTP/%b %d %b\r\n" b"Connection: %b\r\n" b"%b" b"%b\r\n" b"%b" + ) % ( + version.encode(), + self.status, + status, + b"keep-alive" if keep_alive else b"close", + timeout_header, + headers, + body, + ) @property def cookies(self): @@ -180,9 +194,14 @@ class HTTPResponse(BaseHTTPResponse): return self._cookies -def json(body, status=200, headers=None, - content_type="application/json", dumps=json_dumps, - **kwargs): +def json( + body, + status=200, + headers=None, + content_type="application/json", + dumps=json_dumps, + **kwargs +): """ Returns response object with body in json format. @@ -191,12 +210,17 @@ def json(body, status=200, headers=None, :param headers: Custom Headers. :param kwargs: Remaining arguments that are passed to the json encoder. """ - return HTTPResponse(dumps(body, **kwargs), headers=headers, - status=status, content_type=content_type) + return HTTPResponse( + dumps(body, **kwargs), + headers=headers, + status=status, + content_type=content_type, + ) -def text(body, status=200, headers=None, - content_type="text/plain; charset=utf-8"): +def text( + body, status=200, headers=None, content_type="text/plain; charset=utf-8" +): """ Returns response object with body in text format. @@ -206,12 +230,13 @@ def text(body, status=200, headers=None, :param content_type: the content type (string) of the response """ return HTTPResponse( - body, status=status, headers=headers, - content_type=content_type) + body, status=status, headers=headers, content_type=content_type + ) -def raw(body, status=200, headers=None, - content_type="application/octet-stream"): +def raw( + body, status=200, headers=None, content_type="application/octet-stream" +): """ Returns response object without encoding the body. @@ -220,8 +245,12 @@ def raw(body, status=200, headers=None, :param headers: Custom Headers. :param content_type: the content type (string) of the response. """ - return HTTPResponse(body_bytes=body, status=status, headers=headers, - content_type=content_type) + return HTTPResponse( + body_bytes=body, + status=status, + headers=headers, + content_type=content_type, + ) def html(body, status=200, headers=None): @@ -232,12 +261,22 @@ def html(body, status=200, headers=None): :param status: Response code. :param headers: Custom Headers. """ - return HTTPResponse(body, status=status, headers=headers, - content_type="text/html; charset=utf-8") + return HTTPResponse( + body, + status=status, + headers=headers, + content_type="text/html; charset=utf-8", + ) -async def file(location, status=200, mime_type=None, headers=None, - filename=None, _range=None): +async def file( + location, + status=200, + mime_type=None, + headers=None, + filename=None, + _range=None, +): """Return a response object with file data. :param location: Location of file on system. @@ -249,28 +288,40 @@ async def file(location, status=200, mime_type=None, headers=None, headers = headers or {} if filename: headers.setdefault( - 'Content-Disposition', - 'attachment; filename="{}"'.format(filename)) + "Content-Disposition", 'attachment; filename="{}"'.format(filename) + ) filename = filename or path.split(location)[-1] - async with open_async(location, mode='rb') as _file: + async with open_async(location, mode="rb") as _file: if _range: await _file.seek(_range.start) out_stream = await _file.read(_range.size) - headers['Content-Range'] = 'bytes %s-%s/%s' % ( - _range.start, _range.end, _range.total) + headers["Content-Range"] = "bytes %s-%s/%s" % ( + _range.start, + _range.end, + _range.total, + ) else: out_stream = await _file.read() - mime_type = mime_type or guess_type(filename)[0] or 'text/plain' - return HTTPResponse(status=status, - headers=headers, - content_type=mime_type, - body_bytes=out_stream) + mime_type = mime_type or guess_type(filename)[0] or "text/plain" + return HTTPResponse( + status=status, + headers=headers, + content_type=mime_type, + body_bytes=out_stream, + ) -async def file_stream(location, status=200, chunk_size=4096, mime_type=None, - headers=None, filename=None, _range=None): +async def file_stream( + location, + status=200, + chunk_size=4096, + mime_type=None, + headers=None, + filename=None, + _range=None, +): """Return a streaming response object with file data. :param location: Location of file on system. @@ -283,11 +334,11 @@ async def file_stream(location, status=200, chunk_size=4096, mime_type=None, headers = headers or {} if filename: headers.setdefault( - 'Content-Disposition', - 'attachment; filename="{}"'.format(filename)) + "Content-Disposition", 'attachment; filename="{}"'.format(filename) + ) filename = filename or path.split(location)[-1] - _file = await open_async(location, mode='rb') + _file = await open_async(location, mode="rb") async def _streaming_fn(response): nonlocal _file, chunk_size @@ -312,19 +363,27 @@ async def file_stream(location, status=200, chunk_size=4096, mime_type=None, await _file.close() return # Returning from this fn closes the stream - mime_type = mime_type or guess_type(filename)[0] or 'text/plain' + mime_type = mime_type or guess_type(filename)[0] or "text/plain" if _range: - headers['Content-Range'] = 'bytes %s-%s/%s' % ( - _range.start, _range.end, _range.total) - return StreamingHTTPResponse(streaming_fn=_streaming_fn, - status=status, - headers=headers, - content_type=mime_type) + headers["Content-Range"] = "bytes %s-%s/%s" % ( + _range.start, + _range.end, + _range.total, + ) + return StreamingHTTPResponse( + streaming_fn=_streaming_fn, + status=status, + headers=headers, + content_type=mime_type, + ) def stream( - streaming_fn, status=200, headers=None, - content_type="text/plain; charset=utf-8"): + streaming_fn, + status=200, + headers=None, + content_type="text/plain; charset=utf-8", +): """Accepts an coroutine `streaming_fn` which can be used to write chunks to a streaming response. Returns a `StreamingHTTPResponse`. @@ -344,15 +403,13 @@ def stream( :param headers: Custom Headers. """ return StreamingHTTPResponse( - streaming_fn, - headers=headers, - content_type=content_type, - status=status + streaming_fn, headers=headers, content_type=content_type, status=status ) -def redirect(to, headers=None, status=302, - content_type="text/html; charset=utf-8"): +def redirect( + to, headers=None, status=302, content_type="text/html; charset=utf-8" +): """Abort execution and cause a 302 redirect (by default). :param to: path or fully qualified URL to redirect to @@ -367,9 +424,8 @@ def redirect(to, headers=None, status=302, safe_to = quote_plus(to, safe=":/#?&=@[]!$&'()*+,;") # According to RFC 7231, a relative URI is now permitted. - headers['Location'] = safe_to + headers["Location"] = safe_to return HTTPResponse( - status=status, - headers=headers, - content_type=content_type) + status=status, headers=headers, content_type=content_type + ) diff --git a/sanic/router.py b/sanic/router.py index 8ddba1a3..038c53bf 100644 --- a/sanic/router.py +++ b/sanic/router.py @@ -9,25 +9,28 @@ from sanic.exceptions import NotFound, MethodNotSupported from sanic.views import CompositionView Route = namedtuple( - 'Route', - ['handler', 'methods', 'pattern', 'parameters', 'name', 'uri']) -Parameter = namedtuple('Parameter', ['name', 'cast']) + "Route", ["handler", "methods", "pattern", "parameters", "name", "uri"] +) +Parameter = namedtuple("Parameter", ["name", "cast"]) REGEX_TYPES = { - 'string': (str, r'[^/]+'), - 'int': (int, r'\d+'), - 'number': (float, r'[0-9\\.]+'), - 'alpha': (str, r'[A-Za-z]+'), - 'path': (str, r'[^/].*?'), - 'uuid': (uuid.UUID, r'[A-Fa-f0-9]{8}-[A-Fa-f0-9]{4}-' - r'[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{12}') + "string": (str, r"[^/]+"), + "int": (int, r"\d+"), + "number": (float, r"[0-9\\.]+"), + "alpha": (str, r"[A-Za-z]+"), + "path": (str, r"[^/].*?"), + "uuid": ( + uuid.UUID, + r"[A-Fa-f0-9]{8}-[A-Fa-f0-9]{4}-" + r"[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{12}", + ), } ROUTER_CACHE_SIZE = 1024 def url_hash(url): - return url.count('/') + return url.count("/") class RouteExists(Exception): @@ -68,10 +71,11 @@ class Router: also be passed in as the type. The argument given to the function will always be a string, independent of the type. """ + routes_static = None routes_dynamic = None routes_always_check = None - parameter_pattern = re.compile(r'<(.+?)>') + parameter_pattern = re.compile(r"<(.+?)>") def __init__(self): self.routes_all = {} @@ -98,9 +102,9 @@ class Router: """ # We could receive NAME or NAME:PATTERN name = parameter_string - pattern = 'string' - if ':' in parameter_string: - name, pattern = parameter_string.split(':', 1) + pattern = "string" + if ":" in parameter_string: + name, pattern = parameter_string.split(":", 1) if not name: raise ValueError( "Invalid parameter syntax: {}".format(parameter_string) @@ -112,8 +116,16 @@ class Router: return name, _type, pattern - def add(self, uri, methods, handler, host=None, strict_slashes=False, - version=None, name=None): + def add( + self, + uri, + methods, + handler, + host=None, + strict_slashes=False, + version=None, + name=None, + ): """Add a handler to the route list :param uri: path to match @@ -127,8 +139,8 @@ class Router: :return: Nothing """ if version is not None: - version = re.escape(str(version).strip('/').lstrip('v')) - uri = "/".join(["/v{}".format(version), uri.lstrip('/')]) + version = re.escape(str(version).strip("/").lstrip("v")) + uri = "/".join(["/v{}".format(version), uri.lstrip("/")]) # add regular version self._add(uri, methods, handler, host, name) @@ -143,28 +155,26 @@ class Router: return # Add versions with and without trailing / - slashed_methods = self.routes_all.get(uri + '/', frozenset({})) + slashed_methods = self.routes_all.get(uri + "/", frozenset({})) unslashed_methods = self.routes_all.get(uri[:-1], frozenset({})) if isinstance(methods, Iterable): - _slash_is_missing = all(method in slashed_methods for - method in methods) - _without_slash_is_missing = all(method in unslashed_methods for - method in methods) + _slash_is_missing = all( + method in slashed_methods for method in methods + ) + _without_slash_is_missing = all( + method in unslashed_methods for method in methods + ) else: _slash_is_missing = methods in slashed_methods _without_slash_is_missing = methods in unslashed_methods - slash_is_missing = ( - not uri[-1] == '/' and not _slash_is_missing - ) + slash_is_missing = not uri[-1] == "/" and not _slash_is_missing without_slash_is_missing = ( - uri[-1] == '/' and not - _without_slash_is_missing and not - uri == '/' + uri[-1] == "/" and not _without_slash_is_missing and not uri == "/" ) # add version with trailing slash if slash_is_missing: - self._add(uri + '/', methods, handler, host, name) + self._add(uri + "/", methods, handler, host, name) # add version without trailing slash elif without_slash_is_missing: self._add(uri[:-1], methods, handler, host, name) @@ -187,8 +197,10 @@ class Router: else: if not isinstance(host, Iterable): - raise ValueError("Expected either string or Iterable of " - "host strings, not {!r}".format(host)) + raise ValueError( + "Expected either string or Iterable of " + "host strings, not {!r}".format(host) + ) for host_ in host: self.add(uri, methods, handler, host_, name) @@ -208,38 +220,39 @@ class Router: if name in parameter_names: raise ParameterNameConflicts( - "Multiple parameter named <{name}> " - "in route uri {uri}".format(name=name, uri=uri)) + "Multiple parameter named <{name}> " + "in route uri {uri}".format(name=name, uri=uri) + ) parameter_names.add(name) - parameter = Parameter( - name=name, cast=_type) + parameter = Parameter(name=name, cast=_type) parameters.append(parameter) # Mark the whole route as unhashable if it has the hash key in it - if re.search(r'(^|[^^]){1}/', pattern): - properties['unhashable'] = True + if re.search(r"(^|[^^]){1}/", pattern): + properties["unhashable"] = True # Mark the route as unhashable if it matches the hash key - elif re.search(r'/', pattern): - properties['unhashable'] = True + elif re.search(r"/", pattern): + properties["unhashable"] = True - return '({})'.format(pattern) + return "({})".format(pattern) pattern_string = re.sub(self.parameter_pattern, add_parameter, uri) - pattern = re.compile(r'^{}$'.format(pattern_string)) + pattern = re.compile(r"^{}$".format(pattern_string)) def merge_route(route, methods, handler): # merge to the existing route when possible. if not route.methods or not methods: # method-unspecified routes are not mergeable. - raise RouteExists( - "Route already registered: {}".format(uri)) + raise RouteExists("Route already registered: {}".format(uri)) elif route.methods.intersection(methods): # already existing method is not overloadable. duplicated = methods.intersection(route.methods) raise RouteExists( "Route already registered: {} [{}]".format( - uri, ','.join(list(duplicated)))) + uri, ",".join(list(duplicated)) + ) + ) if isinstance(route.handler, CompositionView): view = route.handler else: @@ -247,19 +260,22 @@ class Router: view.add(route.methods, route.handler) view.add(methods, handler) route = route._replace( - handler=view, methods=methods.union(route.methods)) + handler=view, methods=methods.union(route.methods) + ) return route if parameters: # TODO: This is too complex, we need to reduce the complexity - if properties['unhashable']: + if properties["unhashable"]: routes_to_check = self.routes_always_check ndx, route = self.check_dynamic_route_exists( - pattern, routes_to_check, parameters) + pattern, routes_to_check, parameters + ) else: routes_to_check = self.routes_dynamic[url_hash(uri)] ndx, route = self.check_dynamic_route_exists( - pattern, routes_to_check, parameters) + pattern, routes_to_check, parameters + ) if ndx != -1: # Pop the ndx of the route, no dups of the same route routes_to_check.pop(ndx) @@ -270,35 +286,41 @@ class Router: # if available # special prefix for static files is_static = False - if name and name.startswith('_static_'): + if name and name.startswith("_static_"): is_static = True - name = name.split('_static_', 1)[-1] + name = name.split("_static_", 1)[-1] - if hasattr(handler, '__blueprintname__'): - handler_name = '{}.{}'.format( - handler.__blueprintname__, name or handler.__name__) + if hasattr(handler, "__blueprintname__"): + handler_name = "{}.{}".format( + handler.__blueprintname__, name or handler.__name__ + ) else: - handler_name = name or getattr(handler, '__name__', None) + handler_name = name or getattr(handler, "__name__", None) if route: route = merge_route(route, methods, handler) else: route = Route( - handler=handler, methods=methods, pattern=pattern, - parameters=parameters, name=handler_name, uri=uri) + handler=handler, + methods=methods, + pattern=pattern, + parameters=parameters, + name=handler_name, + uri=uri, + ) self.routes_all[uri] = route if is_static: pair = self.routes_static_files.get(handler_name) - if not (pair and (pair[0] + '/' == uri or uri + '/' == pair[0])): + if not (pair and (pair[0] + "/" == uri or uri + "/" == pair[0])): self.routes_static_files[handler_name] = (uri, route) else: pair = self.routes_names.get(handler_name) - if not (pair and (pair[0] + '/' == uri or uri + '/' == pair[0])): + if not (pair and (pair[0] + "/" == uri or uri + "/" == pair[0])): self.routes_names[handler_name] = (uri, route) - if properties['unhashable']: + if properties["unhashable"]: self.routes_always_check.append(route) elif parameters: self.routes_dynamic[url_hash(uri)].append(route) @@ -333,8 +355,10 @@ class Router: if route in self.routes_always_check: self.routes_always_check.remove(route) - elif url_hash(uri) in self.routes_dynamic \ - and route in self.routes_dynamic[url_hash(uri)]: + elif ( + url_hash(uri) in self.routes_dynamic + and route in self.routes_dynamic[url_hash(uri)] + ): self.routes_dynamic[url_hash(uri)].remove(route) else: self.routes_static.pop(uri) @@ -353,7 +377,7 @@ class Router: if not view_name: return (None, None) - if view_name == 'static' or view_name.endswith('.static'): + if view_name == "static" or view_name.endswith(".static"): return self.routes_static_files.get(name, (None, None)) return self.routes_names.get(view_name, (None, None)) @@ -367,14 +391,15 @@ class Router: """ # No virtual hosts specified; default behavior if not self.hosts: - return self._get(request.path, request.method, '') + return self._get(request.path, request.method, "") # virtual hosts specified; try to match route to the host header try: - return self._get(request.path, request.method, - request.headers.get("Host", '')) + return self._get( + request.path, request.method, request.headers.get("Host", "") + ) # try default hosts except NotFound: - return self._get(request.path, request.method, '') + return self._get(request.path, request.method, "") def get_supported_methods(self, url): """Get a list of supported methods for a url and optional host. @@ -384,7 +409,7 @@ class Router: """ route = self.routes_all.get(url) # if methods are None then this logic will prevent an error - return getattr(route, 'methods', None) or frozenset() + return getattr(route, "methods", None) or frozenset() @lru_cache(maxsize=ROUTER_CACHE_SIZE) def _get(self, url, method, host): @@ -399,9 +424,10 @@ class Router: # Check against known static routes route = self.routes_static.get(url) method_not_supported = MethodNotSupported( - 'Method {} not allowed for URL {}'.format(method, url), + "Method {} not allowed for URL {}".format(method, url), method=method, - allowed_methods=self.get_supported_methods(url)) + allowed_methods=self.get_supported_methods(url), + ) if route: if route.methods and method not in route.methods: raise method_not_supported @@ -427,13 +453,14 @@ class Router: # Route was found but the methods didn't match if route_found: raise method_not_supported - raise NotFound('Requested URL {} not found'.format(url)) + raise NotFound("Requested URL {} not found".format(url)) - kwargs = {p.name: p.cast(value) - for value, p - in zip(match.groups(1), route.parameters)} + kwargs = { + p.name: p.cast(value) + for value, p in zip(match.groups(1), route.parameters) + } route_handler = route.handler - if hasattr(route_handler, 'handlers'): + if hasattr(route_handler, "handlers"): route_handler = route_handler.handlers[method] return route_handler, [], kwargs, route.uri @@ -446,7 +473,8 @@ class Router: handler = self.get(request)[0] except (NotFound, MethodNotSupported): return False - if (hasattr(handler, 'view_class') and - hasattr(handler.view_class, request.method.lower())): + if hasattr(handler, "view_class") and hasattr( + handler.view_class, request.method.lower() + ): handler = getattr(handler.view_class, request.method.lower()) - return hasattr(handler, 'is_stream') + return hasattr(handler, "is_stream") diff --git a/sanic/server.py b/sanic/server.py index e5069875..2f743b8a 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -4,16 +4,8 @@ import traceback from functools import partial from inspect import isawaitable from multiprocessing import Process -from signal import ( - SIGTERM, SIGINT, SIG_IGN, - signal as signal_func, - Signals -) -from socket import ( - socket, - SOL_SOCKET, - SO_REUSEADDR, -) +from signal import SIGTERM, SIGINT, SIG_IGN, signal as signal_func, Signals +from socket import socket, SOL_SOCKET, SO_REUSEADDR from time import time from httptools import HttpRequestParser @@ -22,6 +14,7 @@ from multidict import CIMultiDict try: import uvloop + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) except ImportError: pass @@ -30,8 +23,12 @@ from sanic.log import logger, access_logger from sanic.response import HTTPResponse from sanic.request import Request from sanic.exceptions import ( - RequestTimeout, PayloadTooLarge, InvalidUsage, ServerError, - ServiceUnavailable) + RequestTimeout, + PayloadTooLarge, + InvalidUsage, + ServerError, + ServiceUnavailable, +) current_time = None @@ -43,27 +40,58 @@ class Signal: class HttpProtocol(asyncio.Protocol): __slots__ = ( # event loop, connection - 'loop', 'transport', 'connections', 'signal', + "loop", + "transport", + "connections", + "signal", # request params - 'parser', 'request', 'url', 'headers', + "parser", + "request", + "url", + "headers", # request config - 'request_handler', 'request_timeout', 'response_timeout', - 'keep_alive_timeout', 'request_max_size', 'request_class', - 'is_request_stream', 'router', + "request_handler", + "request_timeout", + "response_timeout", + "keep_alive_timeout", + "request_max_size", + "request_class", + "is_request_stream", + "router", # enable or disable access log purpose - 'access_log', + "access_log", # connection management - '_total_request_size', '_request_timeout_handler', - '_response_timeout_handler', '_keep_alive_timeout_handler', - '_last_request_time', '_last_response_time', '_is_stream_handler', - '_not_paused') + "_total_request_size", + "_request_timeout_handler", + "_response_timeout_handler", + "_keep_alive_timeout_handler", + "_last_request_time", + "_last_response_time", + "_is_stream_handler", + "_not_paused", + ) - def __init__(self, *, loop, request_handler, error_handler, - signal=Signal(), connections=set(), request_timeout=60, - response_timeout=60, keep_alive_timeout=5, - request_max_size=None, request_class=None, access_log=True, - keep_alive=True, is_request_stream=False, router=None, - state=None, debug=False, **kwargs): + def __init__( + self, + *, + loop, + request_handler, + error_handler, + signal=Signal(), + connections=set(), + request_timeout=60, + response_timeout=60, + keep_alive_timeout=5, + request_max_size=None, + request_class=None, + access_log=True, + keep_alive=True, + is_request_stream=False, + router=None, + state=None, + debug=False, + **kwargs + ): self.loop = loop self.transport = None self.request = None @@ -93,19 +121,20 @@ class HttpProtocol(asyncio.Protocol): self._request_handler_task = None self._request_stream_task = None self._keep_alive = keep_alive - self._header_fragment = b'' + self._header_fragment = b"" self.state = state if state else {} - if 'requests_count' not in self.state: - self.state['requests_count'] = 0 + if "requests_count" not in self.state: + self.state["requests_count"] = 0 self._debug = debug self._not_paused.set() @property def keep_alive(self): return ( - self._keep_alive and - not self.signal.stopped and - self.parser.should_keep_alive()) + self._keep_alive + and not self.signal.stopped + and self.parser.should_keep_alive() + ) # -------------------------------------------- # # Connection @@ -114,7 +143,8 @@ class HttpProtocol(asyncio.Protocol): def connection_made(self, transport): self.connections.add(self) self._request_timeout_handler = self.loop.call_later( - self.request_timeout, self.request_timeout_callback) + self.request_timeout, self.request_timeout_callback + ) self.transport = transport self._last_request_time = current_time @@ -145,16 +175,15 @@ class HttpProtocol(asyncio.Protocol): time_elapsed = current_time - self._last_request_time if time_elapsed < self.request_timeout: time_left = self.request_timeout - time_elapsed - self._request_timeout_handler = ( - self.loop.call_later(time_left, - self.request_timeout_callback) + self._request_timeout_handler = self.loop.call_later( + time_left, self.request_timeout_callback ) else: if self._request_stream_task: self._request_stream_task.cancel() if self._request_handler_task: self._request_handler_task.cancel() - self.write_error(RequestTimeout('Request Timeout')) + self.write_error(RequestTimeout("Request Timeout")) def response_timeout_callback(self): # Check if elapsed time since response was initiated exceeds our @@ -162,16 +191,15 @@ class HttpProtocol(asyncio.Protocol): time_elapsed = current_time - self._last_request_time if time_elapsed < self.response_timeout: time_left = self.response_timeout - time_elapsed - self._response_timeout_handler = ( - self.loop.call_later(time_left, - self.response_timeout_callback) + self._response_timeout_handler = self.loop.call_later( + time_left, self.response_timeout_callback ) else: if self._request_stream_task: self._request_stream_task.cancel() if self._request_handler_task: self._request_handler_task.cancel() - self.write_error(ServiceUnavailable('Response Timeout')) + self.write_error(ServiceUnavailable("Response Timeout")) def keep_alive_timeout_callback(self): # Check if elapsed time since last response exceeds our configured @@ -179,12 +207,11 @@ class HttpProtocol(asyncio.Protocol): time_elapsed = current_time - self._last_response_time if time_elapsed < self.keep_alive_timeout: time_left = self.keep_alive_timeout - time_elapsed - self._keep_alive_timeout_handler = ( - self.loop.call_later(time_left, - self.keep_alive_timeout_callback) + self._keep_alive_timeout_handler = self.loop.call_later( + time_left, self.keep_alive_timeout_callback ) else: - logger.debug('KeepAlive Timeout. Closing connection.') + logger.debug("KeepAlive Timeout. Closing connection.") self.transport.close() self.transport = None @@ -197,7 +224,7 @@ class HttpProtocol(asyncio.Protocol): # memory limits self._total_request_size += len(data) if self._total_request_size > self.request_max_size: - self.write_error(PayloadTooLarge('Payload Too Large')) + self.write_error(PayloadTooLarge("Payload Too Large")) # Create parser if this is the first time we're receiving data if self.parser is None: @@ -206,15 +233,15 @@ class HttpProtocol(asyncio.Protocol): self.parser = HttpRequestParser(self) # requests count - self.state['requests_count'] = self.state['requests_count'] + 1 + self.state["requests_count"] = self.state["requests_count"] + 1 # Parse request chunk or close connection try: self.parser.feed_data(data) except HttpParserError: - message = 'Bad Request' + message = "Bad Request" if self._debug: - message += '\n' + traceback.format_exc() + message += "\n" + traceback.format_exc() self.write_error(InvalidUsage(message)) def on_url(self, url): @@ -227,17 +254,20 @@ class HttpProtocol(asyncio.Protocol): self._header_fragment += name if value is not None: - if self._header_fragment == b'Content-Length' \ - and int(value) > self.request_max_size: - self.write_error(PayloadTooLarge('Payload Too Large')) + if ( + self._header_fragment == b"Content-Length" + and int(value) > self.request_max_size + ): + self.write_error(PayloadTooLarge("Payload Too Large")) try: value = value.decode() except UnicodeDecodeError: - value = value.decode('latin_1') + value = value.decode("latin_1") self.headers.append( - (self._header_fragment.decode().casefold(), value)) + (self._header_fragment.decode().casefold(), value) + ) - self._header_fragment = b'' + self._header_fragment = b"" def on_headers_complete(self): self.request = self.request_class( @@ -245,7 +275,7 @@ class HttpProtocol(asyncio.Protocol): headers=CIMultiDict(self.headers), version=self.parser.get_http_version(), method=self.parser.get_method().decode(), - transport=self.transport + transport=self.transport, ) # Remove any existing KeepAlive handler here, # It will be recreated if required on the new request. @@ -254,7 +284,8 @@ class HttpProtocol(asyncio.Protocol): self._keep_alive_timeout_handler = None if self.is_request_stream: self._is_stream_handler = self.router.is_stream_handler( - self.request) + self.request + ) if self._is_stream_handler: self.request.stream = asyncio.Queue() self.execute_request_handler() @@ -262,7 +293,8 @@ class HttpProtocol(asyncio.Protocol): def on_body(self, body): if self.is_request_stream and self._is_stream_handler: self._request_stream_task = self.loop.create_task( - self.request.stream.put(body)) + self.request.stream.put(body) + ) return self.request.body.append(body) @@ -274,47 +306,49 @@ class HttpProtocol(asyncio.Protocol): self._request_timeout_handler = None if self.is_request_stream and self._is_stream_handler: self._request_stream_task = self.loop.create_task( - self.request.stream.put(None)) + self.request.stream.put(None) + ) return - self.request.body = b''.join(self.request.body) + self.request.body = b"".join(self.request.body) self.execute_request_handler() def execute_request_handler(self): self._response_timeout_handler = self.loop.call_later( - self.response_timeout, self.response_timeout_callback) + self.response_timeout, self.response_timeout_callback + ) self._last_request_time = current_time self._request_handler_task = self.loop.create_task( self.request_handler( - self.request, - self.write_response, - self.stream_response)) + self.request, self.write_response, self.stream_response + ) + ) # -------------------------------------------- # # Responding # -------------------------------------------- # def log_response(self, response): if self.access_log: - extra = { - 'status': getattr(response, 'status', 0), - } + extra = {"status": getattr(response, "status", 0)} if isinstance(response, HTTPResponse): - extra['byte'] = len(response.body) + extra["byte"] = len(response.body) else: - extra['byte'] = -1 + extra["byte"] = -1 - extra['host'] = 'UNKNOWN' + extra["host"] = "UNKNOWN" if self.request is not None: if self.request.ip: - extra['host'] = '{0}:{1}'.format(self.request.ip, - self.request.port) + extra["host"] = "{0}:{1}".format( + self.request.ip, self.request.port + ) - extra['request'] = '{0} {1}'.format(self.request.method, - self.request.url) + extra["request"] = "{0} {1}".format( + self.request.method, self.request.url + ) else: - extra['request'] = 'nil' + extra["request"] = "nil" - access_logger.info('', extra=extra) + access_logger.info("", extra=extra) def write_response(self, response): """ @@ -327,31 +361,37 @@ class HttpProtocol(asyncio.Protocol): keep_alive = self.keep_alive self.transport.write( response.output( - self.request.version, keep_alive, - self.keep_alive_timeout)) + self.request.version, keep_alive, self.keep_alive_timeout + ) + ) self.log_response(response) except AttributeError: - logger.error('Invalid response object for url %s, ' - 'Expected Type: HTTPResponse, Actual Type: %s', - self.url, type(response)) - self.write_error(ServerError('Invalid response type')) + logger.error( + "Invalid response object for url %s, " + "Expected Type: HTTPResponse, Actual Type: %s", + self.url, + type(response), + ) + self.write_error(ServerError("Invalid response type")) except RuntimeError: if self._debug: - logger.error('Connection lost before response written @ %s', - self.request.ip) + logger.error( + "Connection lost before response written @ %s", + self.request.ip, + ) keep_alive = False except Exception as e: self.bail_out( - "Writing response failed, connection closed {}".format( - repr(e))) + "Writing response failed, connection closed {}".format(repr(e)) + ) finally: if not keep_alive: self.transport.close() self.transport = None else: self._keep_alive_timeout_handler = self.loop.call_later( - self.keep_alive_timeout, - self.keep_alive_timeout_callback) + self.keep_alive_timeout, self.keep_alive_timeout_callback + ) self._last_response_time = current_time self.cleanup() @@ -375,30 +415,36 @@ class HttpProtocol(asyncio.Protocol): keep_alive = self.keep_alive response.protocol = self await response.stream( - self.request.version, keep_alive, self.keep_alive_timeout) + self.request.version, keep_alive, self.keep_alive_timeout + ) self.log_response(response) except AttributeError: - logger.error('Invalid response object for url %s, ' - 'Expected Type: HTTPResponse, Actual Type: %s', - self.url, type(response)) - self.write_error(ServerError('Invalid response type')) + logger.error( + "Invalid response object for url %s, " + "Expected Type: HTTPResponse, Actual Type: %s", + self.url, + type(response), + ) + self.write_error(ServerError("Invalid response type")) except RuntimeError: if self._debug: - logger.error('Connection lost before response written @ %s', - self.request.ip) + logger.error( + "Connection lost before response written @ %s", + self.request.ip, + ) keep_alive = False except Exception as e: self.bail_out( - "Writing response failed, connection closed {}".format( - repr(e))) + "Writing response failed, connection closed {}".format(repr(e)) + ) finally: if not keep_alive: self.transport.close() self.transport = None else: self._keep_alive_timeout_handler = self.loop.call_later( - self.keep_alive_timeout, - self.keep_alive_timeout_callback) + self.keep_alive_timeout, self.keep_alive_timeout_callback + ) self._last_response_time = current_time self.cleanup() @@ -411,32 +457,37 @@ class HttpProtocol(asyncio.Protocol): response = None try: response = self.error_handler.response(self.request, exception) - version = self.request.version if self.request else '1.1' + version = self.request.version if self.request else "1.1" self.transport.write(response.output(version)) except RuntimeError: if self._debug: - logger.error('Connection lost before error written @ %s', - self.request.ip if self.request else 'Unknown') + logger.error( + "Connection lost before error written @ %s", + self.request.ip if self.request else "Unknown", + ) except Exception as e: self.bail_out( - "Writing error failed, connection closed {}".format( - repr(e)), from_error=True + "Writing error failed, connection closed {}".format(repr(e)), + from_error=True, ) finally: - if self.parser and (self.keep_alive - or getattr(response, 'status', 0) == 408): + if self.parser and ( + self.keep_alive or getattr(response, "status", 0) == 408 + ): self.log_response(response) try: self.transport.close() except AttributeError: - logger.debug('Connection lost before server could close it.') + logger.debug("Connection lost before server could close it.") def bail_out(self, message, from_error=False): if from_error or self.transport.is_closing(): - logger.error("Transport closed @ %s and exception " - "experienced during error handling", - self.transport.get_extra_info('peername')) - logger.debug('Exception:', exc_info=True) + logger.error( + "Transport closed @ %s and exception " + "experienced during error handling", + self.transport.get_extra_info("peername"), + ) + logger.debug("Exception:", exc_info=True) else: self.write_error(ServerError(message)) logger.error(message) @@ -497,17 +548,43 @@ def trigger_events(events, loop): loop.run_until_complete(result) -def serve(host, port, request_handler, error_handler, before_start=None, - after_start=None, before_stop=None, after_stop=None, debug=False, - request_timeout=60, response_timeout=60, keep_alive_timeout=5, - ssl=None, sock=None, request_max_size=None, reuse_port=False, - loop=None, protocol=HttpProtocol, backlog=100, - register_sys_signals=True, run_multiple=False, run_async=False, - connections=None, signal=Signal(), request_class=None, - access_log=True, keep_alive=True, is_request_stream=False, - router=None, websocket_max_size=None, websocket_max_queue=None, - websocket_read_limit=2 ** 16, websocket_write_limit=2 ** 16, - state=None, graceful_shutdown_timeout=15.0): +def serve( + host, + port, + request_handler, + error_handler, + before_start=None, + after_start=None, + before_stop=None, + after_stop=None, + debug=False, + request_timeout=60, + response_timeout=60, + keep_alive_timeout=5, + ssl=None, + sock=None, + request_max_size=None, + reuse_port=False, + loop=None, + protocol=HttpProtocol, + backlog=100, + register_sys_signals=True, + run_multiple=False, + run_async=False, + connections=None, + signal=Signal(), + request_class=None, + access_log=True, + keep_alive=True, + is_request_stream=False, + router=None, + websocket_max_size=None, + websocket_max_queue=None, + websocket_read_limit=2 ** 16, + websocket_write_limit=2 ** 16, + state=None, + graceful_shutdown_timeout=15.0, +): """Start asynchronous HTTP Server on an individual process. :param host: Address to host on @@ -592,7 +669,7 @@ def serve(host, port, request_handler, error_handler, before_start=None, ssl=ssl, reuse_port=reuse_port, sock=sock, - backlog=backlog + backlog=backlog, ) # Instead of pulling time at the end of every request, @@ -623,11 +700,13 @@ def serve(host, port, request_handler, error_handler, before_start=None, try: loop.add_signal_handler(_signal, loop.stop) except NotImplementedError: - logger.warning('Sanic tried to use loop.add_signal_handler ' - 'but it is not implemented on this platform.') + logger.warning( + "Sanic tried to use loop.add_signal_handler " + "but it is not implemented on this platform." + ) pid = os.getpid() try: - logger.info('Starting worker [%s]', pid) + logger.info("Starting worker [%s]", pid) loop.run_forever() finally: logger.info("Stopping worker [%s]", pid) @@ -658,9 +737,7 @@ def serve(host, port, request_handler, error_handler, before_start=None, coros = [] for conn in connections: if hasattr(conn, "websocket") and conn.websocket: - coros.append( - conn.websocket.close_connection() - ) + coros.append(conn.websocket.close_connection()) else: conn.close() @@ -681,18 +758,18 @@ def serve_multiple(server_settings, workers): :param stop_event: if provided, is used as a stop signal :return: """ - server_settings['reuse_port'] = True - server_settings['run_multiple'] = True + server_settings["reuse_port"] = True + server_settings["run_multiple"] = True # Handling when custom socket is not provided. - if server_settings.get('sock') is None: + if server_settings.get("sock") is None: sock = socket() sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) - sock.bind((server_settings['host'], server_settings['port'])) + sock.bind((server_settings["host"], server_settings["port"])) sock.set_inheritable(True) - server_settings['sock'] = sock - server_settings['host'] = None - server_settings['port'] = None + server_settings["sock"] = sock + server_settings["host"] = None + server_settings["port"] = None def sig_handler(signal, frame): logger.info("Received signal %s. Shutting down.", Signals(signal).name) @@ -716,4 +793,4 @@ def serve_multiple(server_settings, workers): # the above processes will block this until they're stopped for process in processes: process.terminate() - server_settings.get('sock').close() + server_settings.get("sock").close() diff --git a/sanic/static.py b/sanic/static.py index 07831390..9aaa9750 100644 --- a/sanic/static.py +++ b/sanic/static.py @@ -16,10 +16,19 @@ from sanic.handlers import ContentRangeHandler from sanic.response import file, file_stream, HTTPResponse -def register(app, uri, file_or_directory, pattern, - use_modified_since, use_content_range, - stream_large_files, name='static', host=None, - strict_slashes=None, content_type=None): +def register( + app, + uri, + file_or_directory, + pattern, + use_modified_since, + use_content_range, + stream_large_files, + name="static", + host=None, + strict_slashes=None, + content_type=None, +): # TODO: Though sanic is not a file server, I feel like we should at least # make a good effort here. Modified-since is nice, but we could # also look into etags, expires, and caching @@ -46,12 +55,12 @@ def register(app, uri, file_or_directory, pattern, # If we're not trying to match a file directly, # serve from the folder if not path.isfile(file_or_directory): - uri += '' + uri += "" async def _handler(request, file_uri=None): # Using this to determine if the URL is trying to break out of the path # served. os.path.realpath seems to be very slow - if file_uri and '../' in file_uri: + if file_uri and "../" in file_uri: raise InvalidUsage("Invalid URL") # Merge served directory and requested file if provided # Strip all / that in the beginning of the URL to help prevent python @@ -59,15 +68,16 @@ def register(app, uri, file_or_directory, pattern, root_path = file_path = file_or_directory if file_uri: file_path = path.join( - file_or_directory, sub('^[/]*', '', file_uri)) + file_or_directory, sub("^[/]*", "", file_uri) + ) # URL decode the path sent by the browser otherwise we won't be able to # match filenames which got encoded (filenames with spaces etc) file_path = path.abspath(unquote(file_path)) if not file_path.startswith(path.abspath(unquote(root_path))): - raise FileNotFound('File not found', - path=file_or_directory, - relative_url=file_uri) + raise FileNotFound( + "File not found", path=file_or_directory, relative_url=file_uri + ) try: headers = {} # Check if the client has been sent this file before @@ -76,29 +86,31 @@ def register(app, uri, file_or_directory, pattern, if use_modified_since: stats = await stat(file_path) modified_since = strftime( - '%a, %d %b %Y %H:%M:%S GMT', gmtime(stats.st_mtime)) - if request.headers.get('If-Modified-Since') == modified_since: + "%a, %d %b %Y %H:%M:%S GMT", gmtime(stats.st_mtime) + ) + if request.headers.get("If-Modified-Since") == modified_since: return HTTPResponse(status=304) - headers['Last-Modified'] = modified_since + headers["Last-Modified"] = modified_since _range = None if use_content_range: _range = None if not stats: stats = await stat(file_path) - headers['Accept-Ranges'] = 'bytes' - headers['Content-Length'] = str(stats.st_size) - if request.method != 'HEAD': + headers["Accept-Ranges"] = "bytes" + headers["Content-Length"] = str(stats.st_size) + if request.method != "HEAD": try: _range = ContentRangeHandler(request, stats) except HeaderNotFound: pass else: - del headers['Content-Length'] + del headers["Content-Length"] for key, value in _range.headers.items(): headers[key] = value - headers['Content-Type'] = content_type \ - or guess_type(file_path)[0] or 'text/plain' - if request.method == 'HEAD': + headers["Content-Type"] = ( + content_type or guess_type(file_path)[0] or "text/plain" + ) + if request.method == "HEAD": return HTTPResponse(headers=headers) else: if stream_large_files: @@ -110,19 +122,25 @@ def register(app, uri, file_or_directory, pattern, if not stats: stats = await stat(file_path) if stats.st_size >= threshold: - return await file_stream(file_path, headers=headers, - _range=_range) + return await file_stream( + file_path, headers=headers, _range=_range + ) return await file(file_path, headers=headers, _range=_range) except ContentRangeError: raise except Exception: - raise FileNotFound('File not found', - path=file_or_directory, - relative_url=file_uri) + raise FileNotFound( + "File not found", path=file_or_directory, relative_url=file_uri + ) # special prefix for static files - if not name.startswith('_static_'): - name = '_static_{}'.format(name) + if not name.startswith("_static_"): + name = "_static_{}".format(name) - app.route(uri, methods=['GET', 'HEAD'], name=name, host=host, - strict_slashes=strict_slashes)(_handler) + app.route( + uri, + methods=["GET", "HEAD"], + name=name, + host=host, + strict_slashes=strict_slashes, + )(_handler) diff --git a/sanic/testing.py b/sanic/testing.py index e9bf2b5d..563de273 100644 --- a/sanic/testing.py +++ b/sanic/testing.py @@ -4,7 +4,7 @@ from sanic.exceptions import MethodNotSupported from sanic.response import text -HOST = '127.0.0.1' +HOST = "127.0.0.1" PORT = 42101 @@ -15,18 +15,22 @@ class SanicTestClient: async def _local_request(self, method, uri, cookies=None, *args, **kwargs): import aiohttp - if uri.startswith(('http:', 'https:', 'ftp:', 'ftps://' '//')): + + if uri.startswith(("http:", "https:", "ftp:", "ftps://" "//")): url = uri else: - url = 'http://{host}:{port}{uri}'.format( - host=HOST, port=self.port, uri=uri) + url = "http://{host}:{port}{uri}".format( + host=HOST, port=self.port, uri=uri + ) logger.info(url) conn = aiohttp.TCPConnector(verify_ssl=False) async with aiohttp.ClientSession( - cookies=cookies, connector=conn) as session: - async with getattr( - session, method.lower())(url, *args, **kwargs) as response: + cookies=cookies, connector=conn + ) as session: + async with getattr(session, method.lower())( + url, *args, **kwargs + ) as response: try: response.text = await response.text() except UnicodeDecodeError as e: @@ -34,50 +38,60 @@ class SanicTestClient: try: response.json = await response.json() - except (JSONDecodeError, - UnicodeDecodeError, - aiohttp.ClientResponseError): + except ( + JSONDecodeError, + UnicodeDecodeError, + aiohttp.ClientResponseError, + ): response.json = None response.body = await response.read() return response def _sanic_endpoint_test( - self, method='get', uri='/', gather_request=True, - debug=False, server_kwargs={"auto_reload": False}, - *request_args, **request_kwargs): + self, + method="get", + uri="/", + gather_request=True, + debug=False, + server_kwargs={"auto_reload": False}, + *request_args, + **request_kwargs + ): results = [None, None] exceptions = [] if gather_request: + def _collect_request(request): if results[0] is None: results[0] = request + self.app.request_middleware.appendleft(_collect_request) @self.app.exception(MethodNotSupported) async def error_handler(request, exception): - if request.method in ['HEAD', 'PATCH', 'PUT', 'DELETE']: + if request.method in ["HEAD", "PATCH", "PUT", "DELETE"]: return text( - '', exception.status_code, headers=exception.headers + "", exception.status_code, headers=exception.headers ) else: return self.app.error_handler.default(request, exception) - @self.app.listener('after_server_start') + @self.app.listener("after_server_start") async def _collect_response(sanic, loop): try: response = await self._local_request( - method, uri, *request_args, - **request_kwargs) + method, uri, *request_args, **request_kwargs + ) results[-1] = response except Exception as e: - logger.exception('Exception') + logger.exception("Exception") exceptions.append(e) self.app.stop() self.app.run(host=HOST, debug=debug, port=self.port, **server_kwargs) - self.app.listeners['after_server_start'].pop() + self.app.listeners["after_server_start"].pop() if exceptions: raise ValueError("Exception during request: {}".format(exceptions)) @@ -89,31 +103,34 @@ class SanicTestClient: except BaseException: raise ValueError( "Request and response object expected, got ({})".format( - results)) + results + ) + ) else: try: return results[-1] except BaseException: raise ValueError( - "Request object expected, got ({})".format(results)) + "Request object expected, got ({})".format(results) + ) def get(self, *args, **kwargs): - return self._sanic_endpoint_test('get', *args, **kwargs) + return self._sanic_endpoint_test("get", *args, **kwargs) def post(self, *args, **kwargs): - return self._sanic_endpoint_test('post', *args, **kwargs) + return self._sanic_endpoint_test("post", *args, **kwargs) def put(self, *args, **kwargs): - return self._sanic_endpoint_test('put', *args, **kwargs) + return self._sanic_endpoint_test("put", *args, **kwargs) def delete(self, *args, **kwargs): - return self._sanic_endpoint_test('delete', *args, **kwargs) + return self._sanic_endpoint_test("delete", *args, **kwargs) def patch(self, *args, **kwargs): - return self._sanic_endpoint_test('patch', *args, **kwargs) + return self._sanic_endpoint_test("patch", *args, **kwargs) def options(self, *args, **kwargs): - return self._sanic_endpoint_test('options', *args, **kwargs) + return self._sanic_endpoint_test("options", *args, **kwargs) def head(self, *args, **kwargs): - return self._sanic_endpoint_test('head', *args, **kwargs) + return self._sanic_endpoint_test("head", *args, **kwargs) diff --git a/sanic/views.py b/sanic/views.py index f47f2044..29470237 100644 --- a/sanic/views.py +++ b/sanic/views.py @@ -48,6 +48,7 @@ class HTTPMethodView: """Return view function for use with the routing system, that dispatches request to appropriate handler method. """ + def view(*args, **kwargs): self = view.view_class(*class_args, **class_kwargs) return self.dispatch_request(*args, **kwargs) @@ -94,11 +95,13 @@ class CompositionView: for method in methods: if method not in HTTP_METHODS: raise InvalidUsage( - '{} is not a valid HTTP method.'.format(method)) + "{} is not a valid HTTP method.".format(method) + ) if method in self.handlers: raise InvalidUsage( - 'Method {} is already registered.'.format(method)) + "Method {} is already registered.".format(method) + ) self.handlers[method] = handler def __call__(self, request, *args, **kwargs): diff --git a/sanic/websocket.py b/sanic/websocket.py index 9ccf9fdf..74e40322 100644 --- a/sanic/websocket.py +++ b/sanic/websocket.py @@ -6,11 +6,16 @@ from websockets import ConnectionClosed # noqa class WebSocketProtocol(HttpProtocol): - def __init__(self, *args, websocket_timeout=10, - websocket_max_size=None, - websocket_max_queue=None, - websocket_read_limit=2 ** 16, - websocket_write_limit=2 ** 16, **kwargs): + def __init__( + self, + *args, + websocket_timeout=10, + websocket_max_size=None, + websocket_max_queue=None, + websocket_read_limit=2 ** 16, + websocket_write_limit=2 ** 16, + **kwargs + ): super().__init__(*args, **kwargs) self.websocket = None self.websocket_timeout = websocket_timeout @@ -63,24 +68,26 @@ class WebSocketProtocol(HttpProtocol): key = handshake.check_request(request.headers) handshake.build_response(headers, key) except InvalidHandshake: - raise InvalidUsage('Invalid websocket request') + raise InvalidUsage("Invalid websocket request") subprotocol = None - if subprotocols and 'Sec-Websocket-Protocol' in request.headers: + if subprotocols and "Sec-Websocket-Protocol" in request.headers: # select a subprotocol - client_subprotocols = [p.strip() for p in request.headers[ - 'Sec-Websocket-Protocol'].split(',')] + client_subprotocols = [ + p.strip() + for p in request.headers["Sec-Websocket-Protocol"].split(",") + ] for p in client_subprotocols: if p in subprotocols: subprotocol = p - headers['Sec-Websocket-Protocol'] = subprotocol + headers["Sec-Websocket-Protocol"] = subprotocol break # write the 101 response back to the client - rv = b'HTTP/1.1 101 Switching Protocols\r\n' + rv = b"HTTP/1.1 101 Switching Protocols\r\n" for k, v in headers.items(): - rv += k.encode('utf-8') + b': ' + v.encode('utf-8') + b'\r\n' - rv += b'\r\n' + rv += k.encode("utf-8") + b": " + v.encode("utf-8") + b"\r\n" + rv += b"\r\n" request.transport.write(rv) # hook up the websocket protocol @@ -89,7 +96,7 @@ class WebSocketProtocol(HttpProtocol): max_size=self.websocket_max_size, max_queue=self.websocket_max_queue, read_limit=self.websocket_read_limit, - write_limit=self.websocket_write_limit + write_limit=self.websocket_write_limit, ) self.websocket.subprotocol = subprotocol self.websocket.connection_made(request.transport) diff --git a/sanic/worker.py b/sanic/worker.py index d367a7c3..ac854b94 100644 --- a/sanic/worker.py +++ b/sanic/worker.py @@ -12,6 +12,7 @@ except ImportError: try: import uvloop + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) except ImportError: pass @@ -50,36 +51,43 @@ class GunicornWorker(base.Worker): def run(self): is_debug = self.log.loglevel == logging.DEBUG protocol = ( - self.websocket_protocol if self.app.callable.websocket_enabled - else self.http_protocol) + self.websocket_protocol + if self.app.callable.websocket_enabled + else self.http_protocol + ) self._server_settings = self.app.callable._helper( loop=self.loop, debug=is_debug, protocol=protocol, ssl=self.ssl_context, - run_async=True) - self._server_settings['signal'] = self.signal - self._server_settings.pop('sock') - trigger_events(self._server_settings.get('before_start', []), - self.loop) - self._server_settings['before_start'] = () + run_async=True, + ) + self._server_settings["signal"] = self.signal + self._server_settings.pop("sock") + trigger_events( + self._server_settings.get("before_start", []), self.loop + ) + self._server_settings["before_start"] = () self._runner = asyncio.ensure_future(self._run(), loop=self.loop) try: self.loop.run_until_complete(self._runner) self.app.callable.is_running = True - trigger_events(self._server_settings.get('after_start', []), - self.loop) + trigger_events( + self._server_settings.get("after_start", []), self.loop + ) self.loop.run_until_complete(self._check_alive()) - trigger_events(self._server_settings.get('before_stop', []), - self.loop) + trigger_events( + self._server_settings.get("before_stop", []), self.loop + ) self.loop.run_until_complete(self.close()) except BaseException: traceback.print_exc() finally: try: - trigger_events(self._server_settings.get('after_stop', []), - self.loop) + trigger_events( + self._server_settings.get("after_stop", []), self.loop + ) except BaseException: traceback.print_exc() finally: @@ -90,8 +98,11 @@ class GunicornWorker(base.Worker): async def close(self): if self.servers: # stop accepting connections - self.log.info("Stopping server: %s, connections: %s", - self.pid, len(self.connections)) + self.log.info( + "Stopping server: %s, connections: %s", + self.pid, + len(self.connections), + ) for server in self.servers: server.close() await server.wait_closed() @@ -105,8 +116,9 @@ class GunicornWorker(base.Worker): # gracefully shutdown timeout start_shutdown = 0 graceful_shutdown_timeout = self.cfg.graceful_timeout - while self.connections and \ - (start_shutdown < graceful_shutdown_timeout): + while self.connections and ( + start_shutdown < graceful_shutdown_timeout + ): await asyncio.sleep(0.1) start_shutdown = start_shutdown + 0.1 @@ -115,9 +127,7 @@ class GunicornWorker(base.Worker): coros = [] for conn in self.connections: if hasattr(conn, "websocket") and conn.websocket: - coros.append( - conn.websocket.close_connection() - ) + coros.append(conn.websocket.close_connection()) else: conn.close() _shutdown = asyncio.gather(*coros, loop=self.loop) @@ -148,8 +158,9 @@ class GunicornWorker(base.Worker): ) if self.max_requests and req_count > self.max_requests: self.alive = False - self.log.info("Max requests exceeded, shutting down: %s", - self) + self.log.info( + "Max requests exceeded, shutting down: %s", self + ) elif pid == os.getpid() and self.ppid != os.getppid(): self.alive = False self.log.info("Parent changed, shutting down: %s", self) @@ -175,23 +186,29 @@ class GunicornWorker(base.Worker): def init_signals(self): # Set up signals through the event loop API. - self.loop.add_signal_handler(signal.SIGQUIT, self.handle_quit, - signal.SIGQUIT, None) + self.loop.add_signal_handler( + signal.SIGQUIT, self.handle_quit, signal.SIGQUIT, None + ) - self.loop.add_signal_handler(signal.SIGTERM, self.handle_exit, - signal.SIGTERM, None) + self.loop.add_signal_handler( + signal.SIGTERM, self.handle_exit, signal.SIGTERM, None + ) - self.loop.add_signal_handler(signal.SIGINT, self.handle_quit, - signal.SIGINT, None) + self.loop.add_signal_handler( + signal.SIGINT, self.handle_quit, signal.SIGINT, None + ) - self.loop.add_signal_handler(signal.SIGWINCH, self.handle_winch, - signal.SIGWINCH, None) + self.loop.add_signal_handler( + signal.SIGWINCH, self.handle_winch, signal.SIGWINCH, None + ) - self.loop.add_signal_handler(signal.SIGUSR1, self.handle_usr1, - signal.SIGUSR1, None) + self.loop.add_signal_handler( + signal.SIGUSR1, self.handle_usr1, signal.SIGUSR1, None + ) - self.loop.add_signal_handler(signal.SIGABRT, self.handle_abort, - signal.SIGABRT, None) + self.loop.add_signal_handler( + signal.SIGABRT, self.handle_abort, signal.SIGABRT, None + ) # Don't let SIGTERM and SIGUSR1 disturb active requests # by interrupting system calls