{exc_value}
Traceback (most recent call last):
{frame_html}File {0.filename}, line {0.lineno}, @@ -106,15 +106,15 @@ TRACEBACK_LINE_HTML = '''
{0.line}
The server encountered an internal error and cannot complete your request.
-''' +""" _sanic_exceptions = {} @@ -124,15 +124,16 @@ def add_status_code(code): """ Decorator used for adding exceptions to _sanic_exceptions. """ + def class_decorator(cls): cls.status_code = code _sanic_exceptions[code] = cls return cls + return class_decorator class SanicException(Exception): - def __init__(self, message, status_code=None): super().__init__(message) @@ -156,8 +157,8 @@ class MethodNotSupported(SanicException): super().__init__(message) self.headers = dict() self.headers["Allow"] = ", ".join(allowed_methods) - if method in ['HEAD', 'PATCH', 'PUT', 'DELETE']: - self.headers['Content-Length'] = 0 + if method in ["HEAD", "PATCH", "PUT", "DELETE"]: + self.headers["Content-Length"] = 0 @add_status_code(500) @@ -169,6 +170,7 @@ class ServerError(SanicException): class ServiceUnavailable(SanicException): """The server is currently unavailable (because it is overloaded or down for maintenance). Generally, this is a temporary state.""" + pass @@ -192,6 +194,7 @@ class RequestTimeout(SanicException): the connection. The socket connection has actually been lost - the Web server has 'timed out' on that particular socket connection. """ + pass @@ -209,8 +212,8 @@ class ContentRangeError(SanicException): def __init__(self, message, content_range): super().__init__(message) self.headers = { - 'Content-Type': 'text/plain', - "Content-Range": "bytes */%s" % (content_range.total,) + "Content-Type": "text/plain", + "Content-Range": "bytes */%s" % (content_range.total,), } @@ -225,7 +228,7 @@ class InvalidRangeType(ContentRangeError): class PyFileError(Exception): def __init__(self, file): - super().__init__('could not execute config file %s', file) + super().__init__("could not execute config file %s", file) @add_status_code(401) @@ -263,13 +266,14 @@ class Unauthorized(SanicException): scheme="Bearer", realm="Restricted Area") """ + def __init__(self, message, status_code=None, scheme=None, **kwargs): super().__init__(message, status_code) # if auth-scheme is specified, set "WWW-Authenticate" header if scheme is not None: values = ['{!s}="{!s}"'.format(k, v) for k, v in kwargs.items()] - challenge = ', '.join(values) + challenge = ", ".join(values) self.headers = { "WWW-Authenticate": "{} {}".format(scheme, challenge).rstrip() @@ -288,6 +292,6 @@ def abort(status_code, message=None): if message is None: message = STATUS_CODES.get(status_code) # These are stored as bytes in the STATUS_CODES dict - message = message.decode('utf8') + message = message.decode("utf8") sanic_exception = _sanic_exceptions.get(status_code, SanicException) raise sanic_exception(message=message, status_code=status_code) diff --git a/sanic/handlers.py b/sanic/handlers.py index 30923e27..ec6de82e 100644 --- a/sanic/handlers.py +++ b/sanic/handlers.py @@ -11,7 +11,8 @@ from sanic.exceptions import ( TRACEBACK_STYLE, TRACEBACK_WRAPPER_HTML, TRACEBACK_WRAPPER_INNER_HTML, - TRACEBACK_BORDER) + TRACEBACK_BORDER, +) from sanic.log import logger from sanic.response import text, html @@ -36,7 +37,8 @@ class ErrorHandler: return TRACEBACK_WRAPPER_INNER_HTML.format( exc_name=exception.__class__.__name__, exc_value=exception, - frame_html=''.join(frame_html)) + frame_html="".join(frame_html), + ) def _render_traceback_html(self, exception, request): exc_type, exc_value, tb = sys.exc_info() @@ -51,7 +53,8 @@ class ErrorHandler: exc_name=exception.__class__.__name__, exc_value=exception, inner_html=TRACEBACK_BORDER.join(reversed(exceptions)), - path=request.path) + path=request.path, + ) def add(self, exception, handler): self.handlers.append((exception, handler)) @@ -88,17 +91,18 @@ class ErrorHandler: url = repr(request.url) except AttributeError: url = "unknown" - response_message = ('Exception raised in exception handler ' - '"%s" for uri: %s') + response_message = ( + "Exception raised in exception handler " '"%s" for uri: %s' + ) logger.exception(response_message, handler.__name__, url) if self.debug: return text(response_message % (handler.__name__, url), 500) else: - return text('An error occurred while handling an error', 500) + return text("An error occurred while handling an error", 500) return response - def log(self, message, level='error'): + def log(self, message, level="error"): """ Deprecated, do not use. """ @@ -110,14 +114,14 @@ class ErrorHandler: except AttributeError: url = "unknown" - response_message = ('Exception occurred while handling uri: %s') + response_message = "Exception occurred while handling uri: %s" logger.exception(response_message, url) if issubclass(type(exception), SanicException): return text( - 'Error: {}'.format(exception), - status=getattr(exception, 'status_code', 500), - headers=getattr(exception, 'headers', dict()) + "Error: {}".format(exception), + status=getattr(exception, "status_code", 500), + headers=getattr(exception, "headers", dict()), ) elif self.debug: html_output = self._render_traceback_html(exception, request) @@ -129,32 +133,37 @@ class ErrorHandler: class ContentRangeHandler: """Class responsible for parsing request header""" - __slots__ = ('start', 'end', 'size', 'total', 'headers') + + __slots__ = ("start", "end", "size", "total", "headers") def __init__(self, request, stats): self.total = stats.st_size - _range = request.headers.get('Range') + _range = request.headers.get("Range") if _range is None: - raise HeaderNotFound('Range Header Not Found') - unit, _, value = tuple(map(str.strip, _range.partition('='))) - if unit != 'bytes': + raise HeaderNotFound("Range Header Not Found") + unit, _, value = tuple(map(str.strip, _range.partition("="))) + if unit != "bytes": raise InvalidRangeType( - '%s is not a valid Range Type' % (unit,), self) - start_b, _, end_b = tuple(map(str.strip, value.partition('-'))) + "%s is not a valid Range Type" % (unit,), self + ) + start_b, _, end_b = tuple(map(str.strip, value.partition("-"))) try: self.start = int(start_b) if start_b else None except ValueError: raise ContentRangeError( - '\'%s\' is invalid for Content Range' % (start_b,), self) + "'%s' is invalid for Content Range" % (start_b,), self + ) try: self.end = int(end_b) if end_b else None except ValueError: raise ContentRangeError( - '\'%s\' is invalid for Content Range' % (end_b,), self) + "'%s' is invalid for Content Range" % (end_b,), self + ) if self.end is None: if self.start is None: raise ContentRangeError( - 'Invalid for Content Range parameters', self) + "Invalid for Content Range parameters", self + ) else: # this case represents `Content-Range: bytes 5-` self.end = self.total @@ -165,11 +174,13 @@ class ContentRangeHandler: self.end = self.total if self.start >= self.end: raise ContentRangeError( - 'Invalid for Content Range parameters', self) + "Invalid for Content Range parameters", self + ) self.size = self.end - self.start self.headers = { - 'Content-Range': "bytes %s-%s/%s" % ( - self.start, self.end, self.total)} + "Content-Range": "bytes %s-%s/%s" + % (self.start, self.end, self.total) + } def __bool__(self): return self.size > 0 diff --git a/sanic/helpers.py b/sanic/helpers.py index 482253b3..d992f92e 100644 --- a/sanic/helpers.py +++ b/sanic/helpers.py @@ -1,93 +1,97 @@ """Defines basics of HTTP standard.""" STATUS_CODES = { - 100: b'Continue', - 101: b'Switching Protocols', - 102: b'Processing', - 200: b'OK', - 201: b'Created', - 202: b'Accepted', - 203: b'Non-Authoritative Information', - 204: b'No Content', - 205: b'Reset Content', - 206: b'Partial Content', - 207: b'Multi-Status', - 208: b'Already Reported', - 226: b'IM Used', - 300: b'Multiple Choices', - 301: b'Moved Permanently', - 302: b'Found', - 303: b'See Other', - 304: b'Not Modified', - 305: b'Use Proxy', - 307: b'Temporary Redirect', - 308: b'Permanent Redirect', - 400: b'Bad Request', - 401: b'Unauthorized', - 402: b'Payment Required', - 403: b'Forbidden', - 404: b'Not Found', - 405: b'Method Not Allowed', - 406: b'Not Acceptable', - 407: b'Proxy Authentication Required', - 408: b'Request Timeout', - 409: b'Conflict', - 410: b'Gone', - 411: b'Length Required', - 412: b'Precondition Failed', - 413: b'Request Entity Too Large', - 414: b'Request-URI Too Long', - 415: b'Unsupported Media Type', - 416: b'Requested Range Not Satisfiable', - 417: b'Expectation Failed', - 418: b'I\'m a teapot', - 422: b'Unprocessable Entity', - 423: b'Locked', - 424: b'Failed Dependency', - 426: b'Upgrade Required', - 428: b'Precondition Required', - 429: b'Too Many Requests', - 431: b'Request Header Fields Too Large', - 451: b'Unavailable For Legal Reasons', - 500: b'Internal Server Error', - 501: b'Not Implemented', - 502: b'Bad Gateway', - 503: b'Service Unavailable', - 504: b'Gateway Timeout', - 505: b'HTTP Version Not Supported', - 506: b'Variant Also Negotiates', - 507: b'Insufficient Storage', - 508: b'Loop Detected', - 510: b'Not Extended', - 511: b'Network Authentication Required' + 100: b"Continue", + 101: b"Switching Protocols", + 102: b"Processing", + 200: b"OK", + 201: b"Created", + 202: b"Accepted", + 203: b"Non-Authoritative Information", + 204: b"No Content", + 205: b"Reset Content", + 206: b"Partial Content", + 207: b"Multi-Status", + 208: b"Already Reported", + 226: b"IM Used", + 300: b"Multiple Choices", + 301: b"Moved Permanently", + 302: b"Found", + 303: b"See Other", + 304: b"Not Modified", + 305: b"Use Proxy", + 307: b"Temporary Redirect", + 308: b"Permanent Redirect", + 400: b"Bad Request", + 401: b"Unauthorized", + 402: b"Payment Required", + 403: b"Forbidden", + 404: b"Not Found", + 405: b"Method Not Allowed", + 406: b"Not Acceptable", + 407: b"Proxy Authentication Required", + 408: b"Request Timeout", + 409: b"Conflict", + 410: b"Gone", + 411: b"Length Required", + 412: b"Precondition Failed", + 413: b"Request Entity Too Large", + 414: b"Request-URI Too Long", + 415: b"Unsupported Media Type", + 416: b"Requested Range Not Satisfiable", + 417: b"Expectation Failed", + 418: b"I'm a teapot", + 422: b"Unprocessable Entity", + 423: b"Locked", + 424: b"Failed Dependency", + 426: b"Upgrade Required", + 428: b"Precondition Required", + 429: b"Too Many Requests", + 431: b"Request Header Fields Too Large", + 451: b"Unavailable For Legal Reasons", + 500: b"Internal Server Error", + 501: b"Not Implemented", + 502: b"Bad Gateway", + 503: b"Service Unavailable", + 504: b"Gateway Timeout", + 505: b"HTTP Version Not Supported", + 506: b"Variant Also Negotiates", + 507: b"Insufficient Storage", + 508: b"Loop Detected", + 510: b"Not Extended", + 511: b"Network Authentication Required", } # According to https://tools.ietf.org/html/rfc2616#section-7.1 -_ENTITY_HEADERS = frozenset([ - 'allow', - 'content-encoding', - 'content-language', - 'content-length', - 'content-location', - 'content-md5', - 'content-range', - 'content-type', - 'expires', - 'last-modified', - 'extension-header' -]) +_ENTITY_HEADERS = frozenset( + [ + "allow", + "content-encoding", + "content-language", + "content-length", + "content-location", + "content-md5", + "content-range", + "content-type", + "expires", + "last-modified", + "extension-header", + ] +) # According to https://tools.ietf.org/html/rfc2616#section-13.5.1 -_HOP_BY_HOP_HEADERS = frozenset([ - 'connection', - 'keep-alive', - 'proxy-authenticate', - 'proxy-authorization', - 'te', - 'trailers', - 'transfer-encoding', - 'upgrade' -]) +_HOP_BY_HOP_HEADERS = frozenset( + [ + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "transfer-encoding", + "upgrade", + ] +) def has_message_body(status): @@ -110,8 +114,7 @@ def is_hop_by_hop_header(header): return header.lower() in _HOP_BY_HOP_HEADERS -def remove_entity_headers(headers, - allowed=('content-location', 'expires')): +def remove_entity_headers(headers, allowed=("content-location", "expires")): """ Removes all the entity headers present in the headers given. According to RFC 2616 Section 10.3.5, @@ -122,7 +125,9 @@ def remove_entity_headers(headers, returns the headers without the entity headers """ allowed = set([h.lower() for h in allowed]) - headers = {header: value for header, value in headers.items() - if not is_entity_header(header) - and header.lower() not in allowed} + headers = { + header: value + for header, value in headers.items() + if not is_entity_header(header) and header.lower() not in allowed + } return headers diff --git a/sanic/log.py b/sanic/log.py index 67381ce6..cb8ca524 100644 --- a/sanic/log.py +++ b/sanic/log.py @@ -5,59 +5,54 @@ import sys LOGGING_CONFIG_DEFAULTS = dict( version=1, disable_existing_loggers=False, - loggers={ - "root": { - "level": "INFO", - "handlers": ["console"] - }, + "root": {"level": "INFO", "handlers": ["console"]}, "sanic.error": { "level": "INFO", "handlers": ["error_console"], "propagate": True, - "qualname": "sanic.error" + "qualname": "sanic.error", }, - "sanic.access": { "level": "INFO", "handlers": ["access_console"], "propagate": True, - "qualname": "sanic.access" - } + "qualname": "sanic.access", + }, }, handlers={ "console": { "class": "logging.StreamHandler", "formatter": "generic", - "stream": sys.stdout + "stream": sys.stdout, }, "error_console": { "class": "logging.StreamHandler", "formatter": "generic", - "stream": sys.stderr + "stream": sys.stderr, }, "access_console": { "class": "logging.StreamHandler", "formatter": "access", - "stream": sys.stdout + "stream": sys.stdout, }, }, formatters={ "generic": { "format": "%(asctime)s [%(process)d] [%(levelname)s] %(message)s", "datefmt": "[%Y-%m-%d %H:%M:%S %z]", - "class": "logging.Formatter" + "class": "logging.Formatter", }, "access": { - "format": "%(asctime)s - (%(name)s)[%(levelname)s][%(host)s]: " + - "%(request)s %(message)s %(status)d %(byte)d", + "format": "%(asctime)s - (%(name)s)[%(levelname)s][%(host)s]: " + + "%(request)s %(message)s %(status)d %(byte)d", "datefmt": "[%Y-%m-%d %H:%M:%S %z]", - "class": "logging.Formatter" + "class": "logging.Formatter", }, - } + }, ) -logger = logging.getLogger('sanic.root') -error_logger = logging.getLogger('sanic.error') -access_logger = logging.getLogger('sanic.access') +logger = logging.getLogger("sanic.root") +error_logger = logging.getLogger("sanic.error") +access_logger = logging.getLogger("sanic.access") diff --git a/sanic/reloader_helpers.py b/sanic/reloader_helpers.py index e0cb42e0..70bc2fcc 100644 --- a/sanic/reloader_helpers.py +++ b/sanic/reloader_helpers.py @@ -18,7 +18,7 @@ def _iter_module_files(): for module in list(sys.modules.values()): if module is None: continue - filename = getattr(module, '__file__', None) + filename = getattr(module, "__file__", None) if filename: old = None while not os.path.isfile(filename): @@ -27,7 +27,7 @@ def _iter_module_files(): if filename == old: break else: - if filename[-4:] in ('.pyc', '.pyo'): + if filename[-4:] in (".pyc", ".pyo"): filename = filename[:-1] yield filename @@ -45,11 +45,13 @@ def restart_with_reloader(): """ args = _get_args_for_reloading() new_environ = os.environ.copy() - new_environ['SANIC_SERVER_RUNNING'] = 'true' - cmd = ' '.join(args) + new_environ["SANIC_SERVER_RUNNING"] = "true" + cmd = " ".join(args) worker_process = Process( - target=subprocess.call, args=(cmd,), - kwargs=dict(shell=True, env=new_environ)) + target=subprocess.call, + args=(cmd,), + kwargs=dict(shell=True, env=new_environ), + ) worker_process.start() return worker_process @@ -67,8 +69,10 @@ def kill_process_children_unix(pid): children_list_pid = children_list_file.read().split() for child_pid in children_list_pid: - children_proc_path = "/proc/%s/task/%s/children" % \ - (child_pid, child_pid) + children_proc_path = "/proc/%s/task/%s/children" % ( + child_pid, + child_pid, + ) if not os.path.isfile(children_proc_path): continue with open(children_proc_path) as children_list_file_2: @@ -90,7 +94,7 @@ def kill_process_children_osx(pid): :param pid: PID of parent process (process ID) :return: Nothing """ - subprocess.run(['pkill', '-P', str(pid)]) + subprocess.run(["pkill", "-P", str(pid)]) def kill_process_children(pid): @@ -99,12 +103,12 @@ def kill_process_children(pid): :param pid: PID of parent process (process ID) :return: Nothing """ - if sys.platform == 'darwin': + if sys.platform == "darwin": kill_process_children_osx(pid) - elif sys.platform == 'linux': + elif sys.platform == "linux": kill_process_children_unix(pid) else: - pass # should signal error here + pass # should signal error here def kill_program_completly(proc): @@ -127,9 +131,11 @@ def watchdog(sleep_interval): mtimes = {} worker_process = restart_with_reloader() signal.signal( - signal.SIGTERM, lambda *args: kill_program_completly(worker_process)) + signal.SIGTERM, lambda *args: kill_program_completly(worker_process) + ) signal.signal( - signal.SIGINT, lambda *args: kill_program_completly(worker_process)) + signal.SIGINT, lambda *args: kill_program_completly(worker_process) + ) while True: for filename in _iter_module_files(): try: diff --git a/sanic/request.py b/sanic/request.py index 70240207..fec43bc9 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -10,9 +10,11 @@ try: from ujson import loads as json_loads except ImportError: if sys.version_info[:2] == (3, 5): + def json_loads(data): # on Python 3.5 json.loads only supports str not bytes return json.loads(data.decode()) + else: json_loads = json.loads @@ -43,11 +45,28 @@ class RequestParameters(dict): class Request(dict): """Properties of an HTTP request such as URL, headers, etc.""" + __slots__ = ( - 'app', 'headers', 'version', 'method', '_cookies', 'transport', - 'body', 'parsed_json', 'parsed_args', 'parsed_form', 'parsed_files', - '_ip', '_parsed_url', 'uri_template', 'stream', '_remote_addr', - '_socket', '_port', '__weakref__', 'raw_url' + "app", + "headers", + "version", + "method", + "_cookies", + "transport", + "body", + "parsed_json", + "parsed_args", + "parsed_form", + "parsed_files", + "_ip", + "_parsed_url", + "uri_template", + "stream", + "_remote_addr", + "_socket", + "_port", + "__weakref__", + "raw_url", ) def __init__(self, url_bytes, headers, version, method, transport): @@ -73,10 +92,10 @@ class Request(dict): def __repr__(self): if self.method is None or not self.path: - return '<{0}>'.format(self.__class__.__name__) - return '<{0}: {1} {2}>'.format(self.__class__.__name__, - self.method, - self.path) + return "<{0}>".format(self.__class__.__name__) + return "<{0}: {1} {2}>".format( + self.__class__.__name__, self.method, self.path + ) def __bool__(self): if self.transport: @@ -106,8 +125,8 @@ class Request(dict): :return: token related to request """ - prefixes = ('Bearer', 'Token') - auth_header = self.headers.get('Authorization') + prefixes = ("Bearer", "Token") + auth_header = self.headers.get("Authorization") if auth_header is not None: for prefix in prefixes: @@ -122,17 +141,20 @@ class Request(dict): self.parsed_form = RequestParameters() self.parsed_files = RequestParameters() content_type = self.headers.get( - 'Content-Type', DEFAULT_HTTP_CONTENT_TYPE) + "Content-Type", DEFAULT_HTTP_CONTENT_TYPE + ) content_type, parameters = parse_header(content_type) try: - if content_type == 'application/x-www-form-urlencoded': + if content_type == "application/x-www-form-urlencoded": self.parsed_form = RequestParameters( - parse_qs(self.body.decode('utf-8'))) - elif content_type == 'multipart/form-data': + parse_qs(self.body.decode("utf-8")) + ) + elif content_type == "multipart/form-data": # TODO: Stream this instead of reading to/from memory - boundary = parameters['boundary'].encode('utf-8') - self.parsed_form, self.parsed_files = ( - parse_multipart_form(self.body, boundary)) + boundary = parameters["boundary"].encode("utf-8") + self.parsed_form, self.parsed_files = parse_multipart_form( + self.body, boundary + ) except Exception: error_logger.exception("Failed when parsing form") @@ -150,7 +172,8 @@ class Request(dict): if self.parsed_args is None: if self.query_string: self.parsed_args = RequestParameters( - parse_qs(self.query_string)) + parse_qs(self.query_string) + ) else: self.parsed_args = RequestParameters() return self.parsed_args @@ -162,37 +185,40 @@ class Request(dict): @property def cookies(self): if self._cookies is None: - cookie = self.headers.get('Cookie') + cookie = self.headers.get("Cookie") if cookie is not None: cookies = SimpleCookie() cookies.load(cookie) - self._cookies = {name: cookie.value - for name, cookie in cookies.items()} + self._cookies = { + name: cookie.value for name, cookie in cookies.items() + } else: self._cookies = {} return self._cookies @property def ip(self): - if not hasattr(self, '_socket'): + if not hasattr(self, "_socket"): self._get_address() return self._ip @property def port(self): - if not hasattr(self, '_socket'): + if not hasattr(self, "_socket"): self._get_address() return self._port @property def socket(self): - if not hasattr(self, '_socket'): + if not hasattr(self, "_socket"): self._get_address() return self._socket def _get_address(self): - self._socket = self.transport.get_extra_info('peername') or \ - (None, None) + self._socket = self.transport.get_extra_info("peername") or ( + None, + None, + ) self._ip = self._socket[0] self._port = self._socket[1] @@ -202,29 +228,31 @@ class Request(dict): :return: original client ip. """ - if not hasattr(self, '_remote_addr'): - forwarded_for = self.headers.get('X-Forwarded-For', '').split(',') + if not hasattr(self, "_remote_addr"): + forwarded_for = self.headers.get("X-Forwarded-For", "").split(",") remote_addrs = [ - addr for addr in [ - addr.strip() for addr in forwarded_for - ] if addr - ] + addr + for addr in [addr.strip() for addr in forwarded_for] + if addr + ] if len(remote_addrs) > 0: self._remote_addr = remote_addrs[0] else: - self._remote_addr = '' + self._remote_addr = "" return self._remote_addr @property def scheme(self): - if self.app.websocket_enabled \ - and self.headers.get('upgrade') == 'websocket': - scheme = 'ws' + if ( + self.app.websocket_enabled + and self.headers.get("upgrade") == "websocket" + ): + scheme = "ws" else: - scheme = 'http' + scheme = "http" - if self.transport.get_extra_info('sslcontext'): - scheme += 's' + if self.transport.get_extra_info("sslcontext"): + scheme += "s" return scheme @@ -232,11 +260,11 @@ class Request(dict): def host(self): # it appears that httptools doesn't return the host # so pull it from the headers - return self.headers.get('Host', '') + return self.headers.get("Host", "") @property def content_type(self): - return self.headers.get('Content-Type', DEFAULT_HTTP_CONTENT_TYPE) + return self.headers.get("Content-Type", DEFAULT_HTTP_CONTENT_TYPE) @property def match_info(self): @@ -245,27 +273,23 @@ class Request(dict): @property def path(self): - return self._parsed_url.path.decode('utf-8') + return self._parsed_url.path.decode("utf-8") @property def query_string(self): if self._parsed_url.query: - return self._parsed_url.query.decode('utf-8') + return self._parsed_url.query.decode("utf-8") else: - return '' + return "" @property def url(self): - return urlunparse(( - self.scheme, - self.host, - self.path, - None, - self.query_string, - None)) + return urlunparse( + (self.scheme, self.host, self.path, None, self.query_string, None) + ) -File = namedtuple('File', ['type', 'body', 'name']) +File = namedtuple("File", ["type", "body", "name"]) def parse_multipart_form(body, boundary): @@ -281,37 +305,38 @@ def parse_multipart_form(body, boundary): form_parts = body.split(boundary) for form_part in form_parts[1:-1]: file_name = None - content_type = 'text/plain' - content_charset = 'utf-8' + content_type = "text/plain" + content_charset = "utf-8" field_name = None line_index = 2 line_end_index = 0 while not line_end_index == -1: - line_end_index = form_part.find(b'\r\n', line_index) - form_line = form_part[line_index:line_end_index].decode('utf-8') + line_end_index = form_part.find(b"\r\n", line_index) + form_line = form_part[line_index:line_end_index].decode("utf-8") line_index = line_end_index + 2 if not form_line: break - colon_index = form_line.index(':') + colon_index = form_line.index(":") form_header_field = form_line[0:colon_index].lower() form_header_value, form_parameters = parse_header( - form_line[colon_index + 2:]) + form_line[colon_index + 2 :] + ) - if form_header_field == 'content-disposition': - file_name = form_parameters.get('filename') - field_name = form_parameters.get('name') - elif form_header_field == 'content-type': + if form_header_field == "content-disposition": + file_name = form_parameters.get("filename") + field_name = form_parameters.get("name") + elif form_header_field == "content-type": content_type = form_header_value - content_charset = form_parameters.get('charset', 'utf-8') + content_charset = form_parameters.get("charset", "utf-8") if field_name: post_data = form_part[line_index:-4] if file_name: - form_file = File(type=content_type, - name=file_name, - body=post_data) + form_file = File( + type=content_type, name=file_name, body=post_data + ) if field_name in files: files[field_name].append(form_file) else: @@ -323,7 +348,9 @@ def parse_multipart_form(body, boundary): else: fields[field_name] = [value] else: - logger.debug('Form-data field does not have a \'name\' parameter \ - in the Content-Disposition header') + logger.debug( + "Form-data field does not have a 'name' parameter \ + in the Content-Disposition header" + ) return fields, files diff --git a/sanic/response.py b/sanic/response.py index ed1df0f4..cf9aaba2 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -24,16 +24,18 @@ class BaseHTTPResponse: return str(data).encode() def _parse_headers(self): - headers = b'' + headers = b"" for name, value in self.headers.items(): try: - headers += ( - b'%b: %b\r\n' % ( - name.encode(), value.encode('utf-8'))) + headers += b"%b: %b\r\n" % ( + name.encode(), + value.encode("utf-8"), + ) except AttributeError: - headers += ( - b'%b: %b\r\n' % ( - str(name).encode(), str(value).encode('utf-8'))) + headers += b"%b: %b\r\n" % ( + str(name).encode(), + str(value).encode("utf-8"), + ) return headers @@ -46,12 +48,17 @@ class BaseHTTPResponse: class StreamingHTTPResponse(BaseHTTPResponse): __slots__ = ( - 'protocol', 'streaming_fn', 'status', - 'content_type', 'headers', '_cookies' + "protocol", + "streaming_fn", + "status", + "content_type", + "headers", + "_cookies", ) - def __init__(self, streaming_fn, status=200, headers=None, - content_type='text/plain'): + def __init__( + self, streaming_fn, status=200, headers=None, content_type="text/plain" + ): self.content_type = content_type self.streaming_fn = streaming_fn self.status = status @@ -66,61 +73,69 @@ class StreamingHTTPResponse(BaseHTTPResponse): if type(data) != bytes: data = self._encode_body(data) - self.protocol.push_data( - b"%x\r\n%b\r\n" % (len(data), data)) + self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data)) await self.protocol.drain() async def stream( - self, version="1.1", keep_alive=False, keep_alive_timeout=None): + self, version="1.1", keep_alive=False, keep_alive_timeout=None + ): """Streams headers, runs the `streaming_fn` callback that writes content to the response body, then finalizes the response body. """ headers = self.get_headers( - version, keep_alive=keep_alive, - keep_alive_timeout=keep_alive_timeout) + version, + keep_alive=keep_alive, + keep_alive_timeout=keep_alive_timeout, + ) self.protocol.push_data(headers) await self.protocol.drain() await self.streaming_fn(self) - self.protocol.push_data(b'0\r\n\r\n') + self.protocol.push_data(b"0\r\n\r\n") # no need to await drain here after this write, because it is the # very last thing we write and nothing needs to wait for it. def get_headers( - self, version="1.1", keep_alive=False, keep_alive_timeout=None): + self, version="1.1", keep_alive=False, keep_alive_timeout=None + ): # This is all returned in a kind-of funky way # We tried to make this as fast as possible in pure python - timeout_header = b'' + timeout_header = b"" if keep_alive and keep_alive_timeout is not None: - timeout_header = b'Keep-Alive: %d\r\n' % keep_alive_timeout + timeout_header = b"Keep-Alive: %d\r\n" % keep_alive_timeout - self.headers['Transfer-Encoding'] = 'chunked' - self.headers.pop('Content-Length', None) - self.headers['Content-Type'] = self.headers.get( - 'Content-Type', self.content_type) + self.headers["Transfer-Encoding"] = "chunked" + self.headers.pop("Content-Length", None) + self.headers["Content-Type"] = self.headers.get( + "Content-Type", self.content_type + ) headers = self._parse_headers() if self.status is 200: - status = b'OK' + status = b"OK" else: status = STATUS_CODES.get(self.status) - return (b'HTTP/%b %d %b\r\n' - b'%b' - b'%b\r\n') % ( - version.encode(), - self.status, - status, - timeout_header, - headers - ) + return (b"HTTP/%b %d %b\r\n" b"%b" b"%b\r\n") % ( + version.encode(), + self.status, + status, + timeout_header, + headers, + ) class HTTPResponse(BaseHTTPResponse): - __slots__ = ('body', 'status', 'content_type', 'headers', '_cookies') + __slots__ = ("body", "status", "content_type", "headers", "_cookies") - def __init__(self, body=None, status=200, headers=None, - content_type='text/plain', body_bytes=b''): + def __init__( + self, + body=None, + status=200, + headers=None, + content_type="text/plain", + body_bytes=b"", + ): self.content_type = content_type if body is not None: @@ -132,22 +147,23 @@ class HTTPResponse(BaseHTTPResponse): self.headers = CIMultiDict(headers or {}) self._cookies = None - def output( - self, version="1.1", keep_alive=False, keep_alive_timeout=None): + def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None): # This is all returned in a kind-of funky way # We tried to make this as fast as possible in pure python - timeout_header = b'' + timeout_header = b"" if keep_alive and keep_alive_timeout is not None: - timeout_header = b'Keep-Alive: %d\r\n' % keep_alive_timeout + timeout_header = b"Keep-Alive: %d\r\n" % keep_alive_timeout - body = b'' + body = b"" if has_message_body(self.status): body = self.body - self.headers['Content-Length'] = self.headers.get( - 'Content-Length', len(self.body)) + self.headers["Content-Length"] = self.headers.get( + "Content-Length", len(self.body) + ) - self.headers['Content-Type'] = self.headers.get( - 'Content-Type', self.content_type) + self.headers["Content-Type"] = self.headers.get( + "Content-Type", self.content_type + ) if self.status in (304, 412): self.headers = remove_entity_headers(self.headers) @@ -155,23 +171,21 @@ class HTTPResponse(BaseHTTPResponse): headers = self._parse_headers() if self.status is 200: - status = b'OK' + status = b"OK" else: - status = STATUS_CODES.get(self.status, b'UNKNOWN RESPONSE') + status = STATUS_CODES.get(self.status, b"UNKNOWN RESPONSE") - return (b'HTTP/%b %d %b\r\n' - b'Connection: %b\r\n' - b'%b' - b'%b\r\n' - b'%b') % ( - version.encode(), - self.status, - status, - b'keep-alive' if keep_alive else b'close', - timeout_header, - headers, - body - ) + return ( + b"HTTP/%b %d %b\r\n" b"Connection: %b\r\n" b"%b" b"%b\r\n" b"%b" + ) % ( + version.encode(), + self.status, + status, + b"keep-alive" if keep_alive else b"close", + timeout_header, + headers, + body, + ) @property def cookies(self): @@ -180,9 +194,14 @@ class HTTPResponse(BaseHTTPResponse): return self._cookies -def json(body, status=200, headers=None, - content_type="application/json", dumps=json_dumps, - **kwargs): +def json( + body, + status=200, + headers=None, + content_type="application/json", + dumps=json_dumps, + **kwargs +): """ Returns response object with body in json format. @@ -191,12 +210,17 @@ def json(body, status=200, headers=None, :param headers: Custom Headers. :param kwargs: Remaining arguments that are passed to the json encoder. """ - return HTTPResponse(dumps(body, **kwargs), headers=headers, - status=status, content_type=content_type) + return HTTPResponse( + dumps(body, **kwargs), + headers=headers, + status=status, + content_type=content_type, + ) -def text(body, status=200, headers=None, - content_type="text/plain; charset=utf-8"): +def text( + body, status=200, headers=None, content_type="text/plain; charset=utf-8" +): """ Returns response object with body in text format. @@ -206,12 +230,13 @@ def text(body, status=200, headers=None, :param content_type: the content type (string) of the response """ return HTTPResponse( - body, status=status, headers=headers, - content_type=content_type) + body, status=status, headers=headers, content_type=content_type + ) -def raw(body, status=200, headers=None, - content_type="application/octet-stream"): +def raw( + body, status=200, headers=None, content_type="application/octet-stream" +): """ Returns response object without encoding the body. @@ -220,8 +245,12 @@ def raw(body, status=200, headers=None, :param headers: Custom Headers. :param content_type: the content type (string) of the response. """ - return HTTPResponse(body_bytes=body, status=status, headers=headers, - content_type=content_type) + return HTTPResponse( + body_bytes=body, + status=status, + headers=headers, + content_type=content_type, + ) def html(body, status=200, headers=None): @@ -232,12 +261,22 @@ def html(body, status=200, headers=None): :param status: Response code. :param headers: Custom Headers. """ - return HTTPResponse(body, status=status, headers=headers, - content_type="text/html; charset=utf-8") + return HTTPResponse( + body, + status=status, + headers=headers, + content_type="text/html; charset=utf-8", + ) -async def file(location, status=200, mime_type=None, headers=None, - filename=None, _range=None): +async def file( + location, + status=200, + mime_type=None, + headers=None, + filename=None, + _range=None, +): """Return a response object with file data. :param location: Location of file on system. @@ -249,28 +288,40 @@ async def file(location, status=200, mime_type=None, headers=None, headers = headers or {} if filename: headers.setdefault( - 'Content-Disposition', - 'attachment; filename="{}"'.format(filename)) + "Content-Disposition", 'attachment; filename="{}"'.format(filename) + ) filename = filename or path.split(location)[-1] - async with open_async(location, mode='rb') as _file: + async with open_async(location, mode="rb") as _file: if _range: await _file.seek(_range.start) out_stream = await _file.read(_range.size) - headers['Content-Range'] = 'bytes %s-%s/%s' % ( - _range.start, _range.end, _range.total) + headers["Content-Range"] = "bytes %s-%s/%s" % ( + _range.start, + _range.end, + _range.total, + ) else: out_stream = await _file.read() - mime_type = mime_type or guess_type(filename)[0] or 'text/plain' - return HTTPResponse(status=status, - headers=headers, - content_type=mime_type, - body_bytes=out_stream) + mime_type = mime_type or guess_type(filename)[0] or "text/plain" + return HTTPResponse( + status=status, + headers=headers, + content_type=mime_type, + body_bytes=out_stream, + ) -async def file_stream(location, status=200, chunk_size=4096, mime_type=None, - headers=None, filename=None, _range=None): +async def file_stream( + location, + status=200, + chunk_size=4096, + mime_type=None, + headers=None, + filename=None, + _range=None, +): """Return a streaming response object with file data. :param location: Location of file on system. @@ -283,11 +334,11 @@ async def file_stream(location, status=200, chunk_size=4096, mime_type=None, headers = headers or {} if filename: headers.setdefault( - 'Content-Disposition', - 'attachment; filename="{}"'.format(filename)) + "Content-Disposition", 'attachment; filename="{}"'.format(filename) + ) filename = filename or path.split(location)[-1] - _file = await open_async(location, mode='rb') + _file = await open_async(location, mode="rb") async def _streaming_fn(response): nonlocal _file, chunk_size @@ -312,19 +363,27 @@ async def file_stream(location, status=200, chunk_size=4096, mime_type=None, await _file.close() return # Returning from this fn closes the stream - mime_type = mime_type or guess_type(filename)[0] or 'text/plain' + mime_type = mime_type or guess_type(filename)[0] or "text/plain" if _range: - headers['Content-Range'] = 'bytes %s-%s/%s' % ( - _range.start, _range.end, _range.total) - return StreamingHTTPResponse(streaming_fn=_streaming_fn, - status=status, - headers=headers, - content_type=mime_type) + headers["Content-Range"] = "bytes %s-%s/%s" % ( + _range.start, + _range.end, + _range.total, + ) + return StreamingHTTPResponse( + streaming_fn=_streaming_fn, + status=status, + headers=headers, + content_type=mime_type, + ) def stream( - streaming_fn, status=200, headers=None, - content_type="text/plain; charset=utf-8"): + streaming_fn, + status=200, + headers=None, + content_type="text/plain; charset=utf-8", +): """Accepts an coroutine `streaming_fn` which can be used to write chunks to a streaming response. Returns a `StreamingHTTPResponse`. @@ -344,15 +403,13 @@ def stream( :param headers: Custom Headers. """ return StreamingHTTPResponse( - streaming_fn, - headers=headers, - content_type=content_type, - status=status + streaming_fn, headers=headers, content_type=content_type, status=status ) -def redirect(to, headers=None, status=302, - content_type="text/html; charset=utf-8"): +def redirect( + to, headers=None, status=302, content_type="text/html; charset=utf-8" +): """Abort execution and cause a 302 redirect (by default). :param to: path or fully qualified URL to redirect to @@ -367,9 +424,8 @@ def redirect(to, headers=None, status=302, safe_to = quote_plus(to, safe=":/#?&=@[]!$&'()*+,;") # According to RFC 7231, a relative URI is now permitted. - headers['Location'] = safe_to + headers["Location"] = safe_to return HTTPResponse( - status=status, - headers=headers, - content_type=content_type) + status=status, headers=headers, content_type=content_type + ) diff --git a/sanic/router.py b/sanic/router.py index 8ddba1a3..038c53bf 100644 --- a/sanic/router.py +++ b/sanic/router.py @@ -9,25 +9,28 @@ from sanic.exceptions import NotFound, MethodNotSupported from sanic.views import CompositionView Route = namedtuple( - 'Route', - ['handler', 'methods', 'pattern', 'parameters', 'name', 'uri']) -Parameter = namedtuple('Parameter', ['name', 'cast']) + "Route", ["handler", "methods", "pattern", "parameters", "name", "uri"] +) +Parameter = namedtuple("Parameter", ["name", "cast"]) REGEX_TYPES = { - 'string': (str, r'[^/]+'), - 'int': (int, r'\d+'), - 'number': (float, r'[0-9\\.]+'), - 'alpha': (str, r'[A-Za-z]+'), - 'path': (str, r'[^/].*?'), - 'uuid': (uuid.UUID, r'[A-Fa-f0-9]{8}-[A-Fa-f0-9]{4}-' - r'[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{12}') + "string": (str, r"[^/]+"), + "int": (int, r"\d+"), + "number": (float, r"[0-9\\.]+"), + "alpha": (str, r"[A-Za-z]+"), + "path": (str, r"[^/].*?"), + "uuid": ( + uuid.UUID, + r"[A-Fa-f0-9]{8}-[A-Fa-f0-9]{4}-" + r"[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{12}", + ), } ROUTER_CACHE_SIZE = 1024 def url_hash(url): - return url.count('/') + return url.count("/") class RouteExists(Exception): @@ -68,10 +71,11 @@ class Router: also be passed in as the type. The argument given to the function will always be a string, independent of the type. """ + routes_static = None routes_dynamic = None routes_always_check = None - parameter_pattern = re.compile(r'<(.+?)>') + parameter_pattern = re.compile(r"<(.+?)>") def __init__(self): self.routes_all = {} @@ -98,9 +102,9 @@ class Router: """ # We could receive NAME or NAME:PATTERN name = parameter_string - pattern = 'string' - if ':' in parameter_string: - name, pattern = parameter_string.split(':', 1) + pattern = "string" + if ":" in parameter_string: + name, pattern = parameter_string.split(":", 1) if not name: raise ValueError( "Invalid parameter syntax: {}".format(parameter_string) @@ -112,8 +116,16 @@ class Router: return name, _type, pattern - def add(self, uri, methods, handler, host=None, strict_slashes=False, - version=None, name=None): + def add( + self, + uri, + methods, + handler, + host=None, + strict_slashes=False, + version=None, + name=None, + ): """Add a handler to the route list :param uri: path to match @@ -127,8 +139,8 @@ class Router: :return: Nothing """ if version is not None: - version = re.escape(str(version).strip('/').lstrip('v')) - uri = "/".join(["/v{}".format(version), uri.lstrip('/')]) + version = re.escape(str(version).strip("/").lstrip("v")) + uri = "/".join(["/v{}".format(version), uri.lstrip("/")]) # add regular version self._add(uri, methods, handler, host, name) @@ -143,28 +155,26 @@ class Router: return # Add versions with and without trailing / - slashed_methods = self.routes_all.get(uri + '/', frozenset({})) + slashed_methods = self.routes_all.get(uri + "/", frozenset({})) unslashed_methods = self.routes_all.get(uri[:-1], frozenset({})) if isinstance(methods, Iterable): - _slash_is_missing = all(method in slashed_methods for - method in methods) - _without_slash_is_missing = all(method in unslashed_methods for - method in methods) + _slash_is_missing = all( + method in slashed_methods for method in methods + ) + _without_slash_is_missing = all( + method in unslashed_methods for method in methods + ) else: _slash_is_missing = methods in slashed_methods _without_slash_is_missing = methods in unslashed_methods - slash_is_missing = ( - not uri[-1] == '/' and not _slash_is_missing - ) + slash_is_missing = not uri[-1] == "/" and not _slash_is_missing without_slash_is_missing = ( - uri[-1] == '/' and not - _without_slash_is_missing and not - uri == '/' + uri[-1] == "/" and not _without_slash_is_missing and not uri == "/" ) # add version with trailing slash if slash_is_missing: - self._add(uri + '/', methods, handler, host, name) + self._add(uri + "/", methods, handler, host, name) # add version without trailing slash elif without_slash_is_missing: self._add(uri[:-1], methods, handler, host, name) @@ -187,8 +197,10 @@ class Router: else: if not isinstance(host, Iterable): - raise ValueError("Expected either string or Iterable of " - "host strings, not {!r}".format(host)) + raise ValueError( + "Expected either string or Iterable of " + "host strings, not {!r}".format(host) + ) for host_ in host: self.add(uri, methods, handler, host_, name) @@ -208,38 +220,39 @@ class Router: if name in parameter_names: raise ParameterNameConflicts( - "Multiple parameter named <{name}> " - "in route uri {uri}".format(name=name, uri=uri)) + "Multiple parameter named <{name}> " + "in route uri {uri}".format(name=name, uri=uri) + ) parameter_names.add(name) - parameter = Parameter( - name=name, cast=_type) + parameter = Parameter(name=name, cast=_type) parameters.append(parameter) # Mark the whole route as unhashable if it has the hash key in it - if re.search(r'(^|[^^]){1}/', pattern): - properties['unhashable'] = True + if re.search(r"(^|[^^]){1}/", pattern): + properties["unhashable"] = True # Mark the route as unhashable if it matches the hash key - elif re.search(r'/', pattern): - properties['unhashable'] = True + elif re.search(r"/", pattern): + properties["unhashable"] = True - return '({})'.format(pattern) + return "({})".format(pattern) pattern_string = re.sub(self.parameter_pattern, add_parameter, uri) - pattern = re.compile(r'^{}$'.format(pattern_string)) + pattern = re.compile(r"^{}$".format(pattern_string)) def merge_route(route, methods, handler): # merge to the existing route when possible. if not route.methods or not methods: # method-unspecified routes are not mergeable. - raise RouteExists( - "Route already registered: {}".format(uri)) + raise RouteExists("Route already registered: {}".format(uri)) elif route.methods.intersection(methods): # already existing method is not overloadable. duplicated = methods.intersection(route.methods) raise RouteExists( "Route already registered: {} [{}]".format( - uri, ','.join(list(duplicated)))) + uri, ",".join(list(duplicated)) + ) + ) if isinstance(route.handler, CompositionView): view = route.handler else: @@ -247,19 +260,22 @@ class Router: view.add(route.methods, route.handler) view.add(methods, handler) route = route._replace( - handler=view, methods=methods.union(route.methods)) + handler=view, methods=methods.union(route.methods) + ) return route if parameters: # TODO: This is too complex, we need to reduce the complexity - if properties['unhashable']: + if properties["unhashable"]: routes_to_check = self.routes_always_check ndx, route = self.check_dynamic_route_exists( - pattern, routes_to_check, parameters) + pattern, routes_to_check, parameters + ) else: routes_to_check = self.routes_dynamic[url_hash(uri)] ndx, route = self.check_dynamic_route_exists( - pattern, routes_to_check, parameters) + pattern, routes_to_check, parameters + ) if ndx != -1: # Pop the ndx of the route, no dups of the same route routes_to_check.pop(ndx) @@ -270,35 +286,41 @@ class Router: # if available # special prefix for static files is_static = False - if name and name.startswith('_static_'): + if name and name.startswith("_static_"): is_static = True - name = name.split('_static_', 1)[-1] + name = name.split("_static_", 1)[-1] - if hasattr(handler, '__blueprintname__'): - handler_name = '{}.{}'.format( - handler.__blueprintname__, name or handler.__name__) + if hasattr(handler, "__blueprintname__"): + handler_name = "{}.{}".format( + handler.__blueprintname__, name or handler.__name__ + ) else: - handler_name = name or getattr(handler, '__name__', None) + handler_name = name or getattr(handler, "__name__", None) if route: route = merge_route(route, methods, handler) else: route = Route( - handler=handler, methods=methods, pattern=pattern, - parameters=parameters, name=handler_name, uri=uri) + handler=handler, + methods=methods, + pattern=pattern, + parameters=parameters, + name=handler_name, + uri=uri, + ) self.routes_all[uri] = route if is_static: pair = self.routes_static_files.get(handler_name) - if not (pair and (pair[0] + '/' == uri or uri + '/' == pair[0])): + if not (pair and (pair[0] + "/" == uri or uri + "/" == pair[0])): self.routes_static_files[handler_name] = (uri, route) else: pair = self.routes_names.get(handler_name) - if not (pair and (pair[0] + '/' == uri or uri + '/' == pair[0])): + if not (pair and (pair[0] + "/" == uri or uri + "/" == pair[0])): self.routes_names[handler_name] = (uri, route) - if properties['unhashable']: + if properties["unhashable"]: self.routes_always_check.append(route) elif parameters: self.routes_dynamic[url_hash(uri)].append(route) @@ -333,8 +355,10 @@ class Router: if route in self.routes_always_check: self.routes_always_check.remove(route) - elif url_hash(uri) in self.routes_dynamic \ - and route in self.routes_dynamic[url_hash(uri)]: + elif ( + url_hash(uri) in self.routes_dynamic + and route in self.routes_dynamic[url_hash(uri)] + ): self.routes_dynamic[url_hash(uri)].remove(route) else: self.routes_static.pop(uri) @@ -353,7 +377,7 @@ class Router: if not view_name: return (None, None) - if view_name == 'static' or view_name.endswith('.static'): + if view_name == "static" or view_name.endswith(".static"): return self.routes_static_files.get(name, (None, None)) return self.routes_names.get(view_name, (None, None)) @@ -367,14 +391,15 @@ class Router: """ # No virtual hosts specified; default behavior if not self.hosts: - return self._get(request.path, request.method, '') + return self._get(request.path, request.method, "") # virtual hosts specified; try to match route to the host header try: - return self._get(request.path, request.method, - request.headers.get("Host", '')) + return self._get( + request.path, request.method, request.headers.get("Host", "") + ) # try default hosts except NotFound: - return self._get(request.path, request.method, '') + return self._get(request.path, request.method, "") def get_supported_methods(self, url): """Get a list of supported methods for a url and optional host. @@ -384,7 +409,7 @@ class Router: """ route = self.routes_all.get(url) # if methods are None then this logic will prevent an error - return getattr(route, 'methods', None) or frozenset() + return getattr(route, "methods", None) or frozenset() @lru_cache(maxsize=ROUTER_CACHE_SIZE) def _get(self, url, method, host): @@ -399,9 +424,10 @@ class Router: # Check against known static routes route = self.routes_static.get(url) method_not_supported = MethodNotSupported( - 'Method {} not allowed for URL {}'.format(method, url), + "Method {} not allowed for URL {}".format(method, url), method=method, - allowed_methods=self.get_supported_methods(url)) + allowed_methods=self.get_supported_methods(url), + ) if route: if route.methods and method not in route.methods: raise method_not_supported @@ -427,13 +453,14 @@ class Router: # Route was found but the methods didn't match if route_found: raise method_not_supported - raise NotFound('Requested URL {} not found'.format(url)) + raise NotFound("Requested URL {} not found".format(url)) - kwargs = {p.name: p.cast(value) - for value, p - in zip(match.groups(1), route.parameters)} + kwargs = { + p.name: p.cast(value) + for value, p in zip(match.groups(1), route.parameters) + } route_handler = route.handler - if hasattr(route_handler, 'handlers'): + if hasattr(route_handler, "handlers"): route_handler = route_handler.handlers[method] return route_handler, [], kwargs, route.uri @@ -446,7 +473,8 @@ class Router: handler = self.get(request)[0] except (NotFound, MethodNotSupported): return False - if (hasattr(handler, 'view_class') and - hasattr(handler.view_class, request.method.lower())): + if hasattr(handler, "view_class") and hasattr( + handler.view_class, request.method.lower() + ): handler = getattr(handler.view_class, request.method.lower()) - return hasattr(handler, 'is_stream') + return hasattr(handler, "is_stream") diff --git a/sanic/server.py b/sanic/server.py index e5069875..2f743b8a 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -4,16 +4,8 @@ import traceback from functools import partial from inspect import isawaitable from multiprocessing import Process -from signal import ( - SIGTERM, SIGINT, SIG_IGN, - signal as signal_func, - Signals -) -from socket import ( - socket, - SOL_SOCKET, - SO_REUSEADDR, -) +from signal import SIGTERM, SIGINT, SIG_IGN, signal as signal_func, Signals +from socket import socket, SOL_SOCKET, SO_REUSEADDR from time import time from httptools import HttpRequestParser @@ -22,6 +14,7 @@ from multidict import CIMultiDict try: import uvloop + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) except ImportError: pass @@ -30,8 +23,12 @@ from sanic.log import logger, access_logger from sanic.response import HTTPResponse from sanic.request import Request from sanic.exceptions import ( - RequestTimeout, PayloadTooLarge, InvalidUsage, ServerError, - ServiceUnavailable) + RequestTimeout, + PayloadTooLarge, + InvalidUsage, + ServerError, + ServiceUnavailable, +) current_time = None @@ -43,27 +40,58 @@ class Signal: class HttpProtocol(asyncio.Protocol): __slots__ = ( # event loop, connection - 'loop', 'transport', 'connections', 'signal', + "loop", + "transport", + "connections", + "signal", # request params - 'parser', 'request', 'url', 'headers', + "parser", + "request", + "url", + "headers", # request config - 'request_handler', 'request_timeout', 'response_timeout', - 'keep_alive_timeout', 'request_max_size', 'request_class', - 'is_request_stream', 'router', + "request_handler", + "request_timeout", + "response_timeout", + "keep_alive_timeout", + "request_max_size", + "request_class", + "is_request_stream", + "router", # enable or disable access log purpose - 'access_log', + "access_log", # connection management - '_total_request_size', '_request_timeout_handler', - '_response_timeout_handler', '_keep_alive_timeout_handler', - '_last_request_time', '_last_response_time', '_is_stream_handler', - '_not_paused') + "_total_request_size", + "_request_timeout_handler", + "_response_timeout_handler", + "_keep_alive_timeout_handler", + "_last_request_time", + "_last_response_time", + "_is_stream_handler", + "_not_paused", + ) - def __init__(self, *, loop, request_handler, error_handler, - signal=Signal(), connections=set(), request_timeout=60, - response_timeout=60, keep_alive_timeout=5, - request_max_size=None, request_class=None, access_log=True, - keep_alive=True, is_request_stream=False, router=None, - state=None, debug=False, **kwargs): + def __init__( + self, + *, + loop, + request_handler, + error_handler, + signal=Signal(), + connections=set(), + request_timeout=60, + response_timeout=60, + keep_alive_timeout=5, + request_max_size=None, + request_class=None, + access_log=True, + keep_alive=True, + is_request_stream=False, + router=None, + state=None, + debug=False, + **kwargs + ): self.loop = loop self.transport = None self.request = None @@ -93,19 +121,20 @@ class HttpProtocol(asyncio.Protocol): self._request_handler_task = None self._request_stream_task = None self._keep_alive = keep_alive - self._header_fragment = b'' + self._header_fragment = b"" self.state = state if state else {} - if 'requests_count' not in self.state: - self.state['requests_count'] = 0 + if "requests_count" not in self.state: + self.state["requests_count"] = 0 self._debug = debug self._not_paused.set() @property def keep_alive(self): return ( - self._keep_alive and - not self.signal.stopped and - self.parser.should_keep_alive()) + self._keep_alive + and not self.signal.stopped + and self.parser.should_keep_alive() + ) # -------------------------------------------- # # Connection @@ -114,7 +143,8 @@ class HttpProtocol(asyncio.Protocol): def connection_made(self, transport): self.connections.add(self) self._request_timeout_handler = self.loop.call_later( - self.request_timeout, self.request_timeout_callback) + self.request_timeout, self.request_timeout_callback + ) self.transport = transport self._last_request_time = current_time @@ -145,16 +175,15 @@ class HttpProtocol(asyncio.Protocol): time_elapsed = current_time - self._last_request_time if time_elapsed < self.request_timeout: time_left = self.request_timeout - time_elapsed - self._request_timeout_handler = ( - self.loop.call_later(time_left, - self.request_timeout_callback) + self._request_timeout_handler = self.loop.call_later( + time_left, self.request_timeout_callback ) else: if self._request_stream_task: self._request_stream_task.cancel() if self._request_handler_task: self._request_handler_task.cancel() - self.write_error(RequestTimeout('Request Timeout')) + self.write_error(RequestTimeout("Request Timeout")) def response_timeout_callback(self): # Check if elapsed time since response was initiated exceeds our @@ -162,16 +191,15 @@ class HttpProtocol(asyncio.Protocol): time_elapsed = current_time - self._last_request_time if time_elapsed < self.response_timeout: time_left = self.response_timeout - time_elapsed - self._response_timeout_handler = ( - self.loop.call_later(time_left, - self.response_timeout_callback) + self._response_timeout_handler = self.loop.call_later( + time_left, self.response_timeout_callback ) else: if self._request_stream_task: self._request_stream_task.cancel() if self._request_handler_task: self._request_handler_task.cancel() - self.write_error(ServiceUnavailable('Response Timeout')) + self.write_error(ServiceUnavailable("Response Timeout")) def keep_alive_timeout_callback(self): # Check if elapsed time since last response exceeds our configured @@ -179,12 +207,11 @@ class HttpProtocol(asyncio.Protocol): time_elapsed = current_time - self._last_response_time if time_elapsed < self.keep_alive_timeout: time_left = self.keep_alive_timeout - time_elapsed - self._keep_alive_timeout_handler = ( - self.loop.call_later(time_left, - self.keep_alive_timeout_callback) + self._keep_alive_timeout_handler = self.loop.call_later( + time_left, self.keep_alive_timeout_callback ) else: - logger.debug('KeepAlive Timeout. Closing connection.') + logger.debug("KeepAlive Timeout. Closing connection.") self.transport.close() self.transport = None @@ -197,7 +224,7 @@ class HttpProtocol(asyncio.Protocol): # memory limits self._total_request_size += len(data) if self._total_request_size > self.request_max_size: - self.write_error(PayloadTooLarge('Payload Too Large')) + self.write_error(PayloadTooLarge("Payload Too Large")) # Create parser if this is the first time we're receiving data if self.parser is None: @@ -206,15 +233,15 @@ class HttpProtocol(asyncio.Protocol): self.parser = HttpRequestParser(self) # requests count - self.state['requests_count'] = self.state['requests_count'] + 1 + self.state["requests_count"] = self.state["requests_count"] + 1 # Parse request chunk or close connection try: self.parser.feed_data(data) except HttpParserError: - message = 'Bad Request' + message = "Bad Request" if self._debug: - message += '\n' + traceback.format_exc() + message += "\n" + traceback.format_exc() self.write_error(InvalidUsage(message)) def on_url(self, url): @@ -227,17 +254,20 @@ class HttpProtocol(asyncio.Protocol): self._header_fragment += name if value is not None: - if self._header_fragment == b'Content-Length' \ - and int(value) > self.request_max_size: - self.write_error(PayloadTooLarge('Payload Too Large')) + if ( + self._header_fragment == b"Content-Length" + and int(value) > self.request_max_size + ): + self.write_error(PayloadTooLarge("Payload Too Large")) try: value = value.decode() except UnicodeDecodeError: - value = value.decode('latin_1') + value = value.decode("latin_1") self.headers.append( - (self._header_fragment.decode().casefold(), value)) + (self._header_fragment.decode().casefold(), value) + ) - self._header_fragment = b'' + self._header_fragment = b"" def on_headers_complete(self): self.request = self.request_class( @@ -245,7 +275,7 @@ class HttpProtocol(asyncio.Protocol): headers=CIMultiDict(self.headers), version=self.parser.get_http_version(), method=self.parser.get_method().decode(), - transport=self.transport + transport=self.transport, ) # Remove any existing KeepAlive handler here, # It will be recreated if required on the new request. @@ -254,7 +284,8 @@ class HttpProtocol(asyncio.Protocol): self._keep_alive_timeout_handler = None if self.is_request_stream: self._is_stream_handler = self.router.is_stream_handler( - self.request) + self.request + ) if self._is_stream_handler: self.request.stream = asyncio.Queue() self.execute_request_handler() @@ -262,7 +293,8 @@ class HttpProtocol(asyncio.Protocol): def on_body(self, body): if self.is_request_stream and self._is_stream_handler: self._request_stream_task = self.loop.create_task( - self.request.stream.put(body)) + self.request.stream.put(body) + ) return self.request.body.append(body) @@ -274,47 +306,49 @@ class HttpProtocol(asyncio.Protocol): self._request_timeout_handler = None if self.is_request_stream and self._is_stream_handler: self._request_stream_task = self.loop.create_task( - self.request.stream.put(None)) + self.request.stream.put(None) + ) return - self.request.body = b''.join(self.request.body) + self.request.body = b"".join(self.request.body) self.execute_request_handler() def execute_request_handler(self): self._response_timeout_handler = self.loop.call_later( - self.response_timeout, self.response_timeout_callback) + self.response_timeout, self.response_timeout_callback + ) self._last_request_time = current_time self._request_handler_task = self.loop.create_task( self.request_handler( - self.request, - self.write_response, - self.stream_response)) + self.request, self.write_response, self.stream_response + ) + ) # -------------------------------------------- # # Responding # -------------------------------------------- # def log_response(self, response): if self.access_log: - extra = { - 'status': getattr(response, 'status', 0), - } + extra = {"status": getattr(response, "status", 0)} if isinstance(response, HTTPResponse): - extra['byte'] = len(response.body) + extra["byte"] = len(response.body) else: - extra['byte'] = -1 + extra["byte"] = -1 - extra['host'] = 'UNKNOWN' + extra["host"] = "UNKNOWN" if self.request is not None: if self.request.ip: - extra['host'] = '{0}:{1}'.format(self.request.ip, - self.request.port) + extra["host"] = "{0}:{1}".format( + self.request.ip, self.request.port + ) - extra['request'] = '{0} {1}'.format(self.request.method, - self.request.url) + extra["request"] = "{0} {1}".format( + self.request.method, self.request.url + ) else: - extra['request'] = 'nil' + extra["request"] = "nil" - access_logger.info('', extra=extra) + access_logger.info("", extra=extra) def write_response(self, response): """ @@ -327,31 +361,37 @@ class HttpProtocol(asyncio.Protocol): keep_alive = self.keep_alive self.transport.write( response.output( - self.request.version, keep_alive, - self.keep_alive_timeout)) + self.request.version, keep_alive, self.keep_alive_timeout + ) + ) self.log_response(response) except AttributeError: - logger.error('Invalid response object for url %s, ' - 'Expected Type: HTTPResponse, Actual Type: %s', - self.url, type(response)) - self.write_error(ServerError('Invalid response type')) + logger.error( + "Invalid response object for url %s, " + "Expected Type: HTTPResponse, Actual Type: %s", + self.url, + type(response), + ) + self.write_error(ServerError("Invalid response type")) except RuntimeError: if self._debug: - logger.error('Connection lost before response written @ %s', - self.request.ip) + logger.error( + "Connection lost before response written @ %s", + self.request.ip, + ) keep_alive = False except Exception as e: self.bail_out( - "Writing response failed, connection closed {}".format( - repr(e))) + "Writing response failed, connection closed {}".format(repr(e)) + ) finally: if not keep_alive: self.transport.close() self.transport = None else: self._keep_alive_timeout_handler = self.loop.call_later( - self.keep_alive_timeout, - self.keep_alive_timeout_callback) + self.keep_alive_timeout, self.keep_alive_timeout_callback + ) self._last_response_time = current_time self.cleanup() @@ -375,30 +415,36 @@ class HttpProtocol(asyncio.Protocol): keep_alive = self.keep_alive response.protocol = self await response.stream( - self.request.version, keep_alive, self.keep_alive_timeout) + self.request.version, keep_alive, self.keep_alive_timeout + ) self.log_response(response) except AttributeError: - logger.error('Invalid response object for url %s, ' - 'Expected Type: HTTPResponse, Actual Type: %s', - self.url, type(response)) - self.write_error(ServerError('Invalid response type')) + logger.error( + "Invalid response object for url %s, " + "Expected Type: HTTPResponse, Actual Type: %s", + self.url, + type(response), + ) + self.write_error(ServerError("Invalid response type")) except RuntimeError: if self._debug: - logger.error('Connection lost before response written @ %s', - self.request.ip) + logger.error( + "Connection lost before response written @ %s", + self.request.ip, + ) keep_alive = False except Exception as e: self.bail_out( - "Writing response failed, connection closed {}".format( - repr(e))) + "Writing response failed, connection closed {}".format(repr(e)) + ) finally: if not keep_alive: self.transport.close() self.transport = None else: self._keep_alive_timeout_handler = self.loop.call_later( - self.keep_alive_timeout, - self.keep_alive_timeout_callback) + self.keep_alive_timeout, self.keep_alive_timeout_callback + ) self._last_response_time = current_time self.cleanup() @@ -411,32 +457,37 @@ class HttpProtocol(asyncio.Protocol): response = None try: response = self.error_handler.response(self.request, exception) - version = self.request.version if self.request else '1.1' + version = self.request.version if self.request else "1.1" self.transport.write(response.output(version)) except RuntimeError: if self._debug: - logger.error('Connection lost before error written @ %s', - self.request.ip if self.request else 'Unknown') + logger.error( + "Connection lost before error written @ %s", + self.request.ip if self.request else "Unknown", + ) except Exception as e: self.bail_out( - "Writing error failed, connection closed {}".format( - repr(e)), from_error=True + "Writing error failed, connection closed {}".format(repr(e)), + from_error=True, ) finally: - if self.parser and (self.keep_alive - or getattr(response, 'status', 0) == 408): + if self.parser and ( + self.keep_alive or getattr(response, "status", 0) == 408 + ): self.log_response(response) try: self.transport.close() except AttributeError: - logger.debug('Connection lost before server could close it.') + logger.debug("Connection lost before server could close it.") def bail_out(self, message, from_error=False): if from_error or self.transport.is_closing(): - logger.error("Transport closed @ %s and exception " - "experienced during error handling", - self.transport.get_extra_info('peername')) - logger.debug('Exception:', exc_info=True) + logger.error( + "Transport closed @ %s and exception " + "experienced during error handling", + self.transport.get_extra_info("peername"), + ) + logger.debug("Exception:", exc_info=True) else: self.write_error(ServerError(message)) logger.error(message) @@ -497,17 +548,43 @@ def trigger_events(events, loop): loop.run_until_complete(result) -def serve(host, port, request_handler, error_handler, before_start=None, - after_start=None, before_stop=None, after_stop=None, debug=False, - request_timeout=60, response_timeout=60, keep_alive_timeout=5, - ssl=None, sock=None, request_max_size=None, reuse_port=False, - loop=None, protocol=HttpProtocol, backlog=100, - register_sys_signals=True, run_multiple=False, run_async=False, - connections=None, signal=Signal(), request_class=None, - access_log=True, keep_alive=True, is_request_stream=False, - router=None, websocket_max_size=None, websocket_max_queue=None, - websocket_read_limit=2 ** 16, websocket_write_limit=2 ** 16, - state=None, graceful_shutdown_timeout=15.0): +def serve( + host, + port, + request_handler, + error_handler, + before_start=None, + after_start=None, + before_stop=None, + after_stop=None, + debug=False, + request_timeout=60, + response_timeout=60, + keep_alive_timeout=5, + ssl=None, + sock=None, + request_max_size=None, + reuse_port=False, + loop=None, + protocol=HttpProtocol, + backlog=100, + register_sys_signals=True, + run_multiple=False, + run_async=False, + connections=None, + signal=Signal(), + request_class=None, + access_log=True, + keep_alive=True, + is_request_stream=False, + router=None, + websocket_max_size=None, + websocket_max_queue=None, + websocket_read_limit=2 ** 16, + websocket_write_limit=2 ** 16, + state=None, + graceful_shutdown_timeout=15.0, +): """Start asynchronous HTTP Server on an individual process. :param host: Address to host on @@ -592,7 +669,7 @@ def serve(host, port, request_handler, error_handler, before_start=None, ssl=ssl, reuse_port=reuse_port, sock=sock, - backlog=backlog + backlog=backlog, ) # Instead of pulling time at the end of every request, @@ -623,11 +700,13 @@ def serve(host, port, request_handler, error_handler, before_start=None, try: loop.add_signal_handler(_signal, loop.stop) except NotImplementedError: - logger.warning('Sanic tried to use loop.add_signal_handler ' - 'but it is not implemented on this platform.') + logger.warning( + "Sanic tried to use loop.add_signal_handler " + "but it is not implemented on this platform." + ) pid = os.getpid() try: - logger.info('Starting worker [%s]', pid) + logger.info("Starting worker [%s]", pid) loop.run_forever() finally: logger.info("Stopping worker [%s]", pid) @@ -658,9 +737,7 @@ def serve(host, port, request_handler, error_handler, before_start=None, coros = [] for conn in connections: if hasattr(conn, "websocket") and conn.websocket: - coros.append( - conn.websocket.close_connection() - ) + coros.append(conn.websocket.close_connection()) else: conn.close() @@ -681,18 +758,18 @@ def serve_multiple(server_settings, workers): :param stop_event: if provided, is used as a stop signal :return: """ - server_settings['reuse_port'] = True - server_settings['run_multiple'] = True + server_settings["reuse_port"] = True + server_settings["run_multiple"] = True # Handling when custom socket is not provided. - if server_settings.get('sock') is None: + if server_settings.get("sock") is None: sock = socket() sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) - sock.bind((server_settings['host'], server_settings['port'])) + sock.bind((server_settings["host"], server_settings["port"])) sock.set_inheritable(True) - server_settings['sock'] = sock - server_settings['host'] = None - server_settings['port'] = None + server_settings["sock"] = sock + server_settings["host"] = None + server_settings["port"] = None def sig_handler(signal, frame): logger.info("Received signal %s. Shutting down.", Signals(signal).name) @@ -716,4 +793,4 @@ def serve_multiple(server_settings, workers): # the above processes will block this until they're stopped for process in processes: process.terminate() - server_settings.get('sock').close() + server_settings.get("sock").close() diff --git a/sanic/static.py b/sanic/static.py index 07831390..9aaa9750 100644 --- a/sanic/static.py +++ b/sanic/static.py @@ -16,10 +16,19 @@ from sanic.handlers import ContentRangeHandler from sanic.response import file, file_stream, HTTPResponse -def register(app, uri, file_or_directory, pattern, - use_modified_since, use_content_range, - stream_large_files, name='static', host=None, - strict_slashes=None, content_type=None): +def register( + app, + uri, + file_or_directory, + pattern, + use_modified_since, + use_content_range, + stream_large_files, + name="static", + host=None, + strict_slashes=None, + content_type=None, +): # TODO: Though sanic is not a file server, I feel like we should at least # make a good effort here. Modified-since is nice, but we could # also look into etags, expires, and caching @@ -46,12 +55,12 @@ def register(app, uri, file_or_directory, pattern, # If we're not trying to match a file directly, # serve from the folder if not path.isfile(file_or_directory): - uri += '