Hello world!
') +``` + +Then with curl: + +```bash +curl localhost/v1/html +``` diff --git a/sanic/app.py b/sanic/app.py index ff680d9c..f1e8be7e 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -113,7 +113,7 @@ class Sanic: # Decorator def route(self, uri, methods=frozenset({'GET'}), host=None, - strict_slashes=False, stream=False): + strict_slashes=False, stream=False, version=None): """Decorate a function to be registered as a route :param uri: path of the URL @@ -136,42 +136,49 @@ class Sanic: if stream: handler.is_stream = stream self.router.add(uri=uri, methods=methods, handler=handler, - host=host, strict_slashes=strict_slashes) + host=host, strict_slashes=strict_slashes, + version=version) return handler return response # Shorthand method decorators - def get(self, uri, host=None, strict_slashes=False): + def get(self, uri, host=None, strict_slashes=False, version=None): return self.route(uri, methods=frozenset({"GET"}), host=host, - strict_slashes=strict_slashes) + strict_slashes=strict_slashes, version=version) - def post(self, uri, host=None, strict_slashes=False, stream=False): + def post(self, uri, host=None, strict_slashes=False, stream=False, + version=None): return self.route(uri, methods=frozenset({"POST"}), host=host, - strict_slashes=strict_slashes, stream=stream) + strict_slashes=strict_slashes, stream=stream, + version=version) - def put(self, uri, host=None, strict_slashes=False, stream=False): + def put(self, uri, host=None, strict_slashes=False, stream=False, + version=None): return self.route(uri, methods=frozenset({"PUT"}), host=host, - strict_slashes=strict_slashes, stream=stream) + strict_slashes=strict_slashes, stream=stream, + version=version) - def head(self, uri, host=None, strict_slashes=False): + def head(self, uri, host=None, strict_slashes=False, version=None): return self.route(uri, methods=frozenset({"HEAD"}), host=host, - strict_slashes=strict_slashes) + strict_slashes=strict_slashes, version=version) - def options(self, uri, host=None, strict_slashes=False): + def options(self, uri, host=None, strict_slashes=False, version=None): return self.route(uri, methods=frozenset({"OPTIONS"}), host=host, - strict_slashes=strict_slashes) + strict_slashes=strict_slashes, version=version) - def patch(self, uri, host=None, strict_slashes=False, stream=False): + def patch(self, uri, host=None, strict_slashes=False, stream=False, + version=None): return self.route(uri, methods=frozenset({"PATCH"}), host=host, - strict_slashes=strict_slashes, stream=stream) + strict_slashes=strict_slashes, stream=stream, + version=version) - def delete(self, uri, host=None, strict_slashes=False): + def delete(self, uri, host=None, strict_slashes=False, version=None): return self.route(uri, methods=frozenset({"DELETE"}), host=host, - strict_slashes=strict_slashes) + strict_slashes=strict_slashes, version=version) def add_route(self, handler, uri, methods=frozenset({'GET'}), host=None, - strict_slashes=False): + strict_slashes=False, version=None): """A helper method to register class instance or functions as a handler to the application url routes. @@ -204,7 +211,8 @@ class Sanic: break self.route(uri=uri, methods=methods, host=host, - strict_slashes=strict_slashes, stream=stream)(handler) + strict_slashes=strict_slashes, stream=stream, + version=version)(handler) return handler # Decorator @@ -701,7 +709,8 @@ class Sanic: 'backlog': backlog, 'has_log': has_log, 'websocket_max_size': self.config.WEBSOCKET_MAX_SIZE, - 'websocket_max_queue': self.config.WEBSOCKET_MAX_QUEUE + 'websocket_max_queue': self.config.WEBSOCKET_MAX_QUEUE, + 'graceful_shutdown_timeout': self.config.GRACEFUL_SHUTDOWN_TIMEOUT } # -------------------------------------------- # diff --git a/sanic/blueprints.py b/sanic/blueprints.py index b3866cbd..0e97903b 100644 --- a/sanic/blueprints.py +++ b/sanic/blueprints.py @@ -4,8 +4,8 @@ from sanic.constants import HTTP_METHODS from sanic.views import CompositionView FutureRoute = namedtuple('Route', - ['handler', 'uri', 'methods', - 'host', 'strict_slashes', 'stream']) + ['handler', 'uri', 'methods', 'host', + 'strict_slashes', 'stream', 'version']) FutureListener = namedtuple('Listener', ['handler', 'uri', 'methods', 'host']) FutureMiddleware = namedtuple('Route', ['middleware', 'args', 'kwargs']) FutureException = namedtuple('Route', ['handler', 'args', 'kwargs']) @@ -14,7 +14,7 @@ FutureStatic = namedtuple('Route', class Blueprint: - def __init__(self, name, url_prefix=None, host=None): + def __init__(self, name, url_prefix=None, host=None, version=None): """Create a new blueprint :param name: unique name of the blueprint @@ -30,6 +30,7 @@ class Blueprint: self.listeners = defaultdict(list) self.middlewares = [] self.statics = [] + self.version = version def register(self, app, options): """Register the blueprint to the sanic app.""" @@ -43,12 +44,16 @@ class Blueprint: future.handler.__blueprintname__ = self.name # Prepend the blueprint URI prefix if available uri = url_prefix + future.uri if url_prefix else future.uri + + version = future.version or self.version + app.route( uri=uri[1:] if uri.startswith('//') else uri, methods=future.methods, host=future.host or self.host, strict_slashes=future.strict_slashes, - stream=future.stream + stream=future.stream, + version=version )(future.handler) for future in self.websocket_routes: @@ -89,7 +94,7 @@ class Blueprint: app.listener(event)(listener) def route(self, uri, methods=frozenset({'GET'}), host=None, - strict_slashes=False, stream=False): + strict_slashes=False, stream=False, version=None): """Create a blueprint route from a decorated function. :param uri: endpoint at which the route will be accessible. @@ -97,13 +102,13 @@ class Blueprint: """ def decorator(handler): route = FutureRoute( - handler, uri, methods, host, strict_slashes, stream) + handler, uri, methods, host, strict_slashes, stream, version) self.routes.append(route) return handler return decorator def add_route(self, handler, uri, methods=frozenset({'GET'}), host=None, - strict_slashes=False): + strict_slashes=False, version=None): """Create a blueprint route from a function. :param handler: function for handling uri requests. Accepts function, @@ -125,21 +130,22 @@ class Blueprint: methods = handler.handlers.keys() self.route(uri=uri, methods=methods, host=host, - strict_slashes=strict_slashes)(handler) + strict_slashes=strict_slashes, version=version)(handler) return handler - def websocket(self, uri, host=None, strict_slashes=False): + def websocket(self, uri, host=None, strict_slashes=False, version=None): """Create a blueprint websocket route from a decorated function. :param uri: endpoint at which the route will be accessible. """ def decorator(handler): - route = FutureRoute(handler, uri, [], host, strict_slashes, False) + route = FutureRoute(handler, uri, [], host, strict_slashes, + False, version) self.websocket_routes.append(route) return handler return decorator - def add_websocket_route(self, handler, uri, host=None): + def add_websocket_route(self, handler, uri, host=None, version=None): """Create a blueprint websocket route from a function. :param handler: function for handling uri requests. Accepts function, @@ -147,7 +153,7 @@ class Blueprint: :param uri: endpoint at which the route will be accessible. :return: function or class instance """ - self.websocket(uri=uri, host=host)(handler) + self.websocket(uri=uri, host=host, version=version)(handler) return handler def listener(self, event): @@ -193,30 +199,36 @@ class Blueprint: self.statics.append(static) # Shorthand method decorators - def get(self, uri, host=None, strict_slashes=False): + def get(self, uri, host=None, strict_slashes=False, version=None): return self.route(uri, methods=["GET"], host=host, - strict_slashes=strict_slashes) + strict_slashes=strict_slashes, version=version) - def post(self, uri, host=None, strict_slashes=False, stream=False): + def post(self, uri, host=None, strict_slashes=False, stream=False, + version=None): return self.route(uri, methods=["POST"], host=host, - strict_slashes=strict_slashes, stream=stream) + strict_slashes=strict_slashes, stream=stream, + version=version) - def put(self, uri, host=None, strict_slashes=False, stream=False): + def put(self, uri, host=None, strict_slashes=False, stream=False, + version=None): return self.route(uri, methods=["PUT"], host=host, - strict_slashes=strict_slashes, stream=stream) + strict_slashes=strict_slashes, stream=stream, + version=version) - def head(self, uri, host=None, strict_slashes=False): + def head(self, uri, host=None, strict_slashes=False, version=None): return self.route(uri, methods=["HEAD"], host=host, - strict_slashes=strict_slashes) + strict_slashes=strict_slashes, version=version) - def options(self, uri, host=None, strict_slashes=False): + def options(self, uri, host=None, strict_slashes=False, version=None): return self.route(uri, methods=["OPTIONS"], host=host, - strict_slashes=strict_slashes) + strict_slashes=strict_slashes, version=version) - def patch(self, uri, host=None, strict_slashes=False, stream=False): + def patch(self, uri, host=None, strict_slashes=False, stream=False, + version=None): return self.route(uri, methods=["PATCH"], host=host, - strict_slashes=strict_slashes, stream=stream) + strict_slashes=strict_slashes, stream=stream, + version=version) - def delete(self, uri, host=None, strict_slashes=False): + def delete(self, uri, host=None, strict_slashes=False, version=None): return self.route(uri, methods=["DELETE"], host=host, - strict_slashes=strict_slashes) + strict_slashes=strict_slashes, version=version) diff --git a/sanic/config.py b/sanic/config.py index e3563bc1..6ffcf7a1 100644 --- a/sanic/config.py +++ b/sanic/config.py @@ -12,11 +12,12 @@ _address_dict = { 'Windows': ('localhost', 514), 'Darwin': '/var/run/syslog', 'Linux': '/dev/log', - 'FreeBSD': '/dev/log' + 'FreeBSD': '/var/run/log' } LOGGING = { 'version': 1, + 'disable_existing_loggers': False, 'filters': { 'accessFilter': { '()': DefaultFilter, @@ -127,6 +128,7 @@ class Config(dict): self.KEEP_ALIVE = keep_alive self.WEBSOCKET_MAX_SIZE = 2 ** 20 # 1 megabytes self.WEBSOCKET_MAX_QUEUE = 32 + self.GRACEFUL_SHUTDOWN_TIMEOUT = 15.0 # 15 sec if load_env: self.load_environment_vars() @@ -195,10 +197,16 @@ class Config(dict): def load_environment_vars(self): """ - Looks for any SANIC_ prefixed environment variables and applies + Looks for any ``SANIC_`` prefixed environment variables and applies them to the configuration if present. """ for k, v in os.environ.items(): if k.startswith(SANIC_PREFIX): _, config_key = k.split(SANIC_PREFIX, 1) - self[config_key] = v + try: + self[config_key] = int(v) + except ValueError: + try: + self[config_key] = float(v) + except ValueError: + self[config_key] = v diff --git a/sanic/exceptions.py b/sanic/exceptions.py index d05342fa..21ab2a94 100644 --- a/sanic/exceptions.py +++ b/sanic/exceptions.py @@ -194,6 +194,11 @@ class ContentRangeError(SanicException): } +@add_status_code(403) +class Forbidden(SanicException): + pass + + class InvalidRangeType(ContentRangeError): pass @@ -205,8 +210,8 @@ class Unauthorized(SanicException): :param scheme: Name of the authentication scheme to be used. :param challenge: A dict containing values to add to the WWW-Authenticate - header that is generated. This is especially useful when dealing with the - Digest scheme. (optional) + header that is generated. This is especially useful when dealing with + the Digest scheme. (optional) Examples:: @@ -227,7 +232,6 @@ class Unauthorized(SanicException): # With a Bearer auth-scheme, realm is optional: challenge = {"realm": "Restricted Area"} raise Unauthorized("Auth required.", "Bearer", challenge) - """ pass @@ -249,9 +253,10 @@ def abort(status_code, message=None): """ Raise an exception based on SanicException. Returns the HTTP response message appropriate for the given status code, unless provided. + :param status_code: The HTTP status code to return. :param message: The HTTP response body. Defaults to the messages - in response.py for the given status code. + in response.py for the given status code. """ if message is None: message = COMMON_STATUS_CODES.get(status_code, diff --git a/sanic/request.py b/sanic/request.py index 3cc9c10b..27ff011e 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -45,7 +45,7 @@ class Request(dict): __slots__ = ( 'app', 'headers', 'version', 'method', '_cookies', 'transport', 'body', 'parsed_json', 'parsed_args', 'parsed_form', 'parsed_files', - '_ip', '_parsed_url', 'uri_template', 'stream' + '_ip', '_parsed_url', 'uri_template', 'stream', '_remote_addr' ) def __init__(self, url_bytes, headers, version, method, transport): @@ -86,11 +86,15 @@ class Request(dict): :return: token related to request """ + prefixes = ('Bearer', 'Token') auth_header = self.headers.get('Authorization') - if auth_header is not None and 'Token ' in auth_header: - return auth_header.partition('Token ')[-1] - else: - return auth_header + + if auth_header is not None: + for prefix in prefixes: + if prefix in auth_header: + return auth_header.partition(prefix)[-1].strip() + + return auth_header @property def form(self): @@ -138,7 +142,7 @@ class Request(dict): @property def cookies(self): if self._cookies is None: - cookie = self.headers.get('Cookie') or self.headers.get('cookie') + cookie = self.headers.get('Cookie') if cookie is not None: cookies = SimpleCookie() cookies.load(cookie) @@ -155,6 +159,25 @@ class Request(dict): (None, None)) return self._ip + @property + def remote_addr(self): + """Attempt to return the original client ip based on X-Forwarded-For. + + :return: original client ip. + """ + if not hasattr(self, '_remote_addr'): + forwarded_for = self.headers.get('X-Forwarded-For', '').split(',') + remote_addrs = [ + addr for addr in [ + addr.strip() for addr in forwarded_for + ] if addr + ] + if len(remote_addrs) > 0: + self._remote_addr = remote_addrs[0] + else: + self._remote_addr = '' + return self._remote_addr + @property def scheme(self): if self.app.websocket_enabled \ @@ -234,15 +257,15 @@ def parse_multipart_form(body, boundary): break colon_index = form_line.index(':') - form_header_field = form_line[0:colon_index] + form_header_field = form_line[0:colon_index].lower() form_header_value, form_parameters = parse_header( form_line[colon_index + 2:]) - if form_header_field == 'Content-Disposition': + if form_header_field == 'content-disposition': if 'filename' in form_parameters: file_name = form_parameters['filename'] field_name = form_parameters.get('name') - elif form_header_field == 'Content-Type': + elif form_header_field == 'content-type': file_type = form_header_value post_data = form_part[line_index:-4] diff --git a/sanic/response.py b/sanic/response.py index ea233d9a..f4fb1ea6 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -237,6 +237,7 @@ def json(body, status=200, headers=None, content_type="application/json", **kwargs): """ Returns response object with body in json format. + :param body: Response data to be serialized. :param status: Response code. :param headers: Custom Headers. @@ -250,6 +251,7 @@ def text(body, status=200, headers=None, content_type="text/plain; charset=utf-8"): """ Returns response object with body in text format. + :param body: Response data to be encoded. :param status: Response code. :param headers: Custom Headers. @@ -264,6 +266,7 @@ def raw(body, status=200, headers=None, content_type="application/octet-stream"): """ Returns response object without encoding the body. + :param body: Response data. :param status: Response code. :param headers: Custom Headers. @@ -276,6 +279,7 @@ def raw(body, status=200, headers=None, def html(body, status=200, headers=None): """ Returns response object with body in html format. + :param body: Response data to be encoded. :param status: Response code. :param headers: Custom Headers. diff --git a/sanic/router.py b/sanic/router.py index 691f1388..efc48f37 100644 --- a/sanic/router.py +++ b/sanic/router.py @@ -98,8 +98,25 @@ class Router: return name, _type, pattern - def add(self, uri, methods, handler, host=None, strict_slashes=False): + def add(self, uri, methods, handler, host=None, strict_slashes=False, + version=None): + """Add a handler to the route list + :param uri: path to match + :param methods: sequence of accepted method names. If none are + provided, any method is allowed + :param handler: request handler function. + When executed, it should provide a response object. + :param strict_slashes: strict to trailing slash + :param version: current version of the route or blueprint. See + docs for further details. + :return: Nothing + """ + if version is not None: + if uri.startswith('/'): + uri = "/".join(["/v{}".format(str(version)), uri[1:]]) + else: + uri = "/".join(["/v{}".format(str(version)), uri]) # add regular version self._add(uri, methods, handler, host) diff --git a/sanic/server.py b/sanic/server.py index f3106226..2ee48688 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -75,7 +75,7 @@ class HttpProtocol(asyncio.Protocol): signal=Signal(), connections=set(), request_timeout=60, request_max_size=None, request_class=None, has_log=True, keep_alive=True, is_request_stream=False, router=None, - **kwargs): + state=None, debug=False, **kwargs): self.loop = loop self.transport = None self.request = None @@ -99,12 +99,18 @@ class HttpProtocol(asyncio.Protocol): self._request_handler_task = None self._request_stream_task = None self._keep_alive = keep_alive + self._header_fragment = b'' + self.state = state if state else {} + if 'requests_count' not in self.state: + self.state['requests_count'] = 0 + self._debug = debug @property def keep_alive(self): - return (self._keep_alive - and not self.signal.stopped - and self.parser.should_keep_alive()) + return ( + self._keep_alive and + not self.signal.stopped and + self.parser.should_keep_alive()) # -------------------------------------------- # # Connection @@ -154,22 +160,39 @@ class HttpProtocol(asyncio.Protocol): self.headers = [] self.parser = HttpRequestParser(self) + # requests count + self.state['requests_count'] = self.state['requests_count'] + 1 + # Parse request chunk or close connection try: self.parser.feed_data(data) except HttpParserError: - exception = InvalidUsage('Bad Request') + message = 'Bad Request' + if self._debug: + message += '\n' + traceback.format_exc() + exception = InvalidUsage(message) self.write_error(exception) def on_url(self, url): - self.url = url + if not self.url: + self.url = url + else: + self.url += url def on_header(self, name, value): - if name == b'Content-Length' and int(value) > self.request_max_size: - exception = PayloadTooLarge('Payload Too Large') - self.write_error(exception) + self._header_fragment += name - self.headers.append((name.decode().casefold(), value.decode())) + if value is not None: + if self._header_fragment == b'Content-Length' \ + and int(value) > self.request_max_size: + exception = PayloadTooLarge('Payload Too Large') + self.write_error(exception) + + self.headers.append( + (self._header_fragment.decode().casefold(), + value.decode())) + + self._header_fragment = b'' def on_headers_complete(self): self.request = self.request_class( @@ -357,6 +380,14 @@ class HttpProtocol(asyncio.Protocol): return True return False + def close(self): + """ + Force close the connection. + """ + if self.transport is not None: + self.transport.close() + self.transport = None + def update_current_time(loop): """Cache the current time, since it is needed at the end of every @@ -389,7 +420,8 @@ def serve(host, port, request_handler, error_handler, before_start=None, register_sys_signals=True, run_async=False, connections=None, signal=Signal(), request_class=None, has_log=True, keep_alive=True, is_request_stream=False, router=None, websocket_max_size=None, - websocket_max_queue=None): + websocket_max_queue=None, state=None, + graceful_shutdown_timeout=15.0): """Start asynchronous HTTP Server on an individual process. :param host: Address to host on @@ -427,8 +459,6 @@ def serve(host, port, request_handler, error_handler, before_start=None, if debug: loop.set_debug(debug) - trigger_events(before_start, loop) - connections = connections if connections is not None else set() server = partial( protocol, @@ -445,7 +475,9 @@ def serve(host, port, request_handler, error_handler, before_start=None, is_request_stream=is_request_stream, router=router, websocket_max_size=websocket_max_size, - websocket_max_queue=websocket_max_queue + websocket_max_queue=websocket_max_queue, + state=state, + debug=debug, ) server_coroutine = loop.create_server( @@ -457,6 +489,7 @@ def serve(host, port, request_handler, error_handler, before_start=None, sock=sock, backlog=backlog ) + # Instead of pulling time at the end of every request, # pull it once per minute loop.call_soon(partial(update_current_time, loop)) @@ -464,6 +497,8 @@ def serve(host, port, request_handler, error_handler, before_start=None, if run_async: return server_coroutine + trigger_events(before_start, loop) + try: http_server = loop.run_until_complete(server_coroutine) except: @@ -499,8 +534,26 @@ def serve(host, port, request_handler, error_handler, before_start=None, for connection in connections: connection.close_if_idle() - while connections: + # Gracefully shutdown timeout. + # We should provide graceful_shutdown_timeout, + # instead of letting connection hangs forever. + # Let's roughly calcucate time. + start_shutdown = 0 + while connections and (start_shutdown < graceful_shutdown_timeout): loop.run_until_complete(asyncio.sleep(0.1)) + start_shutdown = start_shutdown + 0.1 + + # Force close non-idle connection after waiting for + # graceful_shutdown_timeout + coros = [] + for conn in connections: + if hasattr(conn, "websocket") and conn.websocket: + coros.append(conn.websocket.close_connection(force=True)) + else: + conn.close() + + _shutdown = asyncio.gather(*coros, loop=loop) + loop.run_until_complete(_shutdown) trigger_events(after_stop, loop) diff --git a/sanic/worker.py b/sanic/worker.py index 1d3e384b..9f950c34 100644 --- a/sanic/worker.py +++ b/sanic/worker.py @@ -3,6 +3,7 @@ import sys import signal import asyncio import logging +import traceback try: import ssl @@ -29,7 +30,7 @@ class GunicornWorker(base.Worker): self.ssl_context = self._create_ssl_context(cfg) else: self.ssl_context = None - self.servers = [] + self.servers = {} self.connections = set() self.exit_code = 0 self.signal = Signal() @@ -69,10 +70,16 @@ class GunicornWorker(base.Worker): trigger_events(self._server_settings.get('before_stop', []), self.loop) self.loop.run_until_complete(self.close()) + except: + traceback.print_exc() finally: - trigger_events(self._server_settings.get('after_stop', []), - self.loop) - self.loop.close() + try: + trigger_events(self._server_settings.get('after_stop', []), + self.loop) + except: + traceback.print_exc() + finally: + self.loop.close() sys.exit(self.exit_code) @@ -91,16 +98,37 @@ class GunicornWorker(base.Worker): for conn in self.connections: conn.close_if_idle() - while self.connections: + # gracefully shutdown timeout + start_shutdown = 0 + graceful_shutdown_timeout = self.cfg.graceful_timeout + while self.connections and \ + (start_shutdown < graceful_shutdown_timeout): await asyncio.sleep(0.1) + start_shutdown = start_shutdown + 0.1 + + # Force close non-idle connection after waiting for + # graceful_shutdown_timeout + coros = [] + for conn in self.connections: + if hasattr(conn, "websocket") and conn.websocket: + coros.append(conn.websocket.close_connection(force=True)) + else: + conn.close() + _shutdown = asyncio.gather(*coros, loop=self.loop) + await _shutdown async def _run(self): for sock in self.sockets: - self.servers.append(await serve( + state = dict(requests_count=0) + self._server_settings["host"] = None + self._server_settings["port"] = None + server = await serve( sock=sock, connections=self.connections, + state=state, **self._server_settings - )) + ) + self.servers[server] = state async def _check_alive(self): # If our parent changed then we shut down. @@ -109,7 +137,15 @@ class GunicornWorker(base.Worker): while self.alive: self.notify() - if pid == os.getpid() and self.ppid != os.getppid(): + req_count = sum( + self.servers[srv]["requests_count"] for srv in self.servers + ) + if self.max_requests and req_count > self.max_requests: + self.alive = False + self.log.info( + "Max requests exceeded, shutting down: %s", self + ) + elif pid == os.getpid() and self.ppid != os.getppid(): self.alive = False self.log.info("Parent changed, shutting down: %s", self) else: @@ -166,3 +202,4 @@ class GunicornWorker(base.Worker): self.alive = False self.exit_code = 1 self.cfg.worker_abort(self) + sys.exit(1) diff --git a/tests/test_blueprints.py b/tests/test_blueprints.py index 9ab387be..5cb356c2 100644 --- a/tests/test_blueprints.py +++ b/tests/test_blueprints.py @@ -1,16 +1,42 @@ import asyncio import inspect +import pytest from sanic import Sanic from sanic.blueprints import Blueprint from sanic.response import json, text from sanic.exceptions import NotFound, ServerError, InvalidUsage +from sanic.constants import HTTP_METHODS # ------------------------------------------------------------ # # GET # ------------------------------------------------------------ # +@pytest.mark.parametrize('method', HTTP_METHODS) +def test_versioned_routes_get(method): + app = Sanic('test_shorhand_routes_get') + bp = Blueprint('test_text') + + method = method.lower() + + func = getattr(bp, method) + if callable(func): + @func('/{}'.format(method), version=1) + def handler(request): + return text('OK') + else: + print(func) + raise + + app.blueprint(bp) + + client_method = getattr(app.test_client, method) + + request, response = client_method('/v1/{}'.format(method)) + assert response.status == 200 + + def test_bp(): app = Sanic('test_text') bp = Blueprint('test_text') diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index db1fc246..620e7891 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -4,7 +4,7 @@ from bs4 import BeautifulSoup from sanic import Sanic from sanic.response import text from sanic.exceptions import InvalidUsage, ServerError, NotFound, Unauthorized -from sanic.exceptions import abort +from sanic.exceptions import Forbidden, abort class SanicExceptionTestException(Exception): @@ -27,6 +27,10 @@ def exception_app(): def handler_404(request): raise NotFound("OK") + @app.route('/403') + def handler_403(request): + raise Forbidden("Forbidden") + @app.route('/401/basic') def handler_401_basic(request): raise Unauthorized("Unauthorized", "Basic", {"realm": "Sanic"}) @@ -113,6 +117,12 @@ def test_not_found_exception(exception_app): assert response.status == 404 +def test_forbidden_exception(exception_app): + """Test the built-in Forbidden exception""" + request, response = exception_app.test_client.get('/403') + assert response.status == 403 + + def test_unauthorized_exception(exception_app): """Test the built-in Unauthorized exception""" request, response = exception_app.test_client.get('/401/basic') diff --git a/tests/test_requests.py b/tests/test_requests.py index 2351a3b0..f0696c7f 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -175,7 +175,7 @@ def test_token(): token = 'a1d895e0-553a-421a-8e22-5ff8ecb48cbf' headers = { 'content-type': 'application/json', - 'Authorization': 'Bearer Token {}'.format(token) + 'Authorization': 'Bearer {}'.format(token) } request, response = app.test_client.get('/', headers=headers) @@ -211,6 +211,32 @@ def test_content_type(): assert response.text == 'application/json' +def test_remote_addr(): + app = Sanic('test_content_type') + + @app.route('/') + async def handler(request): + return text(request.remote_addr) + + headers = { + 'X-Forwarded-For': '127.0.0.1, 127.0.1.2' + } + request, response = app.test_client.get('/', headers=headers) + assert request.remote_addr == '127.0.0.1' + assert response.text == '127.0.0.1' + + request, response = app.test_client.get('/') + assert request.remote_addr == '' + assert response.text == '' + + headers = { + 'X-Forwarded-For': '127.0.0.1, , ,,127.0.1.2' + } + request, response = app.test_client.get('/', headers=headers) + assert request.remote_addr == '127.0.0.1' + assert response.text == '127.0.0.1' + + def test_match_info(): app = Sanic('test_match_info') @@ -260,19 +286,26 @@ def test_post_form_urlencoded(): assert request.form.get('test') == 'OK' -def test_post_form_multipart_form_data(): +@pytest.mark.parametrize( + 'payload', [ + '------sanic\r\n' \ + 'Content-Disposition: form-data; name="test"\r\n' \ + '\r\n' \ + 'OK\r\n' \ + '------sanic--\r\n', + '------sanic\r\n' \ + 'content-disposition: form-data; name="test"\r\n' \ + '\r\n' \ + 'OK\r\n' \ + '------sanic--\r\n', + ]) +def test_post_form_multipart_form_data(payload): app = Sanic('test_post_form_multipart_form_data') @app.route('/', methods=['POST']) async def handler(request): return text('OK') - payload = '------sanic\r\n' \ - 'Content-Disposition: form-data; name="test"\r\n' \ - '\r\n' \ - 'OK\r\n' \ - '------sanic--\r\n' - headers = {'content-type': 'multipart/form-data; boundary=----sanic'} request, response = app.test_client.post(data=payload, headers=headers) diff --git a/tests/test_routes.py b/tests/test_routes.py index 4afb4a9c..04a682a0 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -4,12 +4,33 @@ import pytest from sanic import Sanic from sanic.response import text from sanic.router import RouteExists, RouteDoesNotExist +from sanic.constants import HTTP_METHODS # ------------------------------------------------------------ # # UTF-8 # ------------------------------------------------------------ # +@pytest.mark.parametrize('method', HTTP_METHODS) +def test_versioned_routes_get(method): + app = Sanic('test_shorhand_routes_get') + + method = method.lower() + + func = getattr(app, method) + if callable(func): + @func('/{}'.format(method), version=1) + def handler(request): + return text('OK') + else: + print(func) + raise + + client_method = getattr(app.test_client, method) + + request, response = client_method('/v1/{}'.format(method)) + assert response.status== 200 + def test_shorthand_routes_get(): app = Sanic('test_shorhand_routes_get') diff --git a/tests/test_worker.py b/tests/test_worker.py index 2c1a0123..e2b301ec 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -3,7 +3,11 @@ import json import shlex import subprocess import urllib.request - +from unittest import mock +from sanic.worker import GunicornWorker +from sanic.app import Sanic +import asyncio +import logging import pytest @@ -20,3 +24,112 @@ def test_gunicorn_worker(gunicorn_worker): with urllib.request.urlopen('http://localhost:1337/') as f: res = json.loads(f.read(100).decode()) assert res['test'] + + +class GunicornTestWorker(GunicornWorker): + + def __init__(self): + self.app = mock.Mock() + self.app.callable = Sanic("test_gunicorn_worker") + self.servers = {} + self.exit_code = 0 + self.cfg = mock.Mock() + self.notify = mock.Mock() + + +@pytest.fixture +def worker(): + return GunicornTestWorker() + + +def test_worker_init_process(worker): + with mock.patch('sanic.worker.asyncio') as mock_asyncio: + try: + worker.init_process() + except TypeError: + pass + + assert mock_asyncio.get_event_loop.return_value.close.called + assert mock_asyncio.new_event_loop.called + assert mock_asyncio.set_event_loop.called + + +def test_worker_init_signals(worker): + worker.loop = mock.Mock() + worker.init_signals() + assert worker.loop.add_signal_handler.called + + +def test_handle_abort(worker): + with mock.patch('sanic.worker.sys') as mock_sys: + worker.handle_abort(object(), object()) + assert not worker.alive + assert worker.exit_code == 1 + mock_sys.exit.assert_called_with(1) + + +def test_handle_quit(worker): + worker.handle_quit(object(), object()) + assert not worker.alive + assert worker.exit_code == 0 + + +def test_run_max_requests_exceeded(worker): + loop = asyncio.new_event_loop() + worker.ppid = 1 + worker.alive = True + sock = mock.Mock() + sock.cfg_addr = ('localhost', 8080) + worker.sockets = [sock] + worker.wsgi = mock.Mock() + worker.connections = set() + worker.log = mock.Mock() + worker.loop = loop + worker.servers = { + "server1": {"requests_count": 14}, + "server2": {"requests_count": 15}, + } + worker.max_requests = 10 + worker._run = mock.Mock(wraps=asyncio.coroutine(lambda *a, **kw: None)) + + # exceeding request count + _runner = asyncio.ensure_future(worker._check_alive(), loop=loop) + loop.run_until_complete(_runner) + + assert worker.alive == False + worker.notify.assert_called_with() + worker.log.info.assert_called_with("Max requests exceeded, shutting down: %s", + worker) + +def test_worker_close(worker): + loop = asyncio.new_event_loop() + asyncio.sleep = mock.Mock(wraps=asyncio.coroutine(lambda *a, **kw: None)) + worker.ppid = 1 + worker.pid = 2 + worker.cfg.graceful_timeout = 1.0 + worker.signal = mock.Mock() + worker.signal.stopped = False + worker.wsgi = mock.Mock() + conn = mock.Mock() + conn.websocket = mock.Mock() + conn.websocket.close_connection = mock.Mock( + wraps=asyncio.coroutine(lambda *a, **kw: None) + ) + worker.connections = set([conn]) + worker.log = mock.Mock() + worker.loop = loop + server = mock.Mock() + server.close = mock.Mock(wraps=lambda *a, **kw: None) + server.wait_closed = mock.Mock(wraps=asyncio.coroutine(lambda *a, **kw: None)) + worker.servers = { + server: {"requests_count": 14}, + } + worker.max_requests = 10 + + # close worker + _close = asyncio.ensure_future(worker.close(), loop=loop) + loop.run_until_complete(_close) + + assert worker.signal.stopped == True + conn.websocket.close_connection.assert_called_with(force=True) + assert len(worker.servers) == 0