diff --git a/docs/blueprints.md b/docs/blueprints.md index 7a4567ee..e80a125e 100644 --- a/docs/blueprints.md +++ b/docs/blueprints.md @@ -80,3 +80,26 @@ Exceptions can also be applied exclusively to blueprints globally. def ignore_404s(request, exception): return text("Yep, I totally found the page: {}".format(request.url)) ``` + +## Start and Stop +Blueprints and run functions during the start and stop process of the server. +If running in multiprocessor mode (more than 1 worker), these are triggered +Available events are: + + * before_server_start - Executed before the server begins to accept connections + * after_server_start - Executed after the server begins to accept connections + * before_server_stop - Executed before the server stops accepting connections + * after_server_stop - Executed after the server is stopped and all requests are complete + +```python +bp = Blueprint('my_blueprint') + +@bp.listen('before_server_start') +async def setup_connection(): + global database + database = mysql.connect(host='127.0.0.1'...) + +@bp.listen('after_server_stop') +async def close_connection(): + await database.close() +``` diff --git a/sanic/blueprints.py b/sanic/blueprints.py index f1aa2afc..37cfa1c3 100644 --- a/sanic/blueprints.py +++ b/sanic/blueprints.py @@ -1,3 +1,6 @@ +from collections import defaultdict + + class BlueprintSetup: """ """ @@ -22,7 +25,7 @@ class BlueprintSetup: if self.url_prefix: uri = self.url_prefix + uri - self.app.router.add(uri, methods, handler) + self.app.route(uri=uri, methods=methods)(handler) def add_exception(self, handler, *args, **kwargs): """ @@ -42,9 +45,15 @@ class BlueprintSetup: class Blueprint: def __init__(self, name, url_prefix=None): + """ + Creates a new blueprint + :param name: Unique name of the blueprint + :param url_prefix: URL to be prefixed before all route URLs + """ self.name = name self.url_prefix = url_prefix self.deferred_functions = [] + self.listeners = defaultdict(list) def record(self, func): """ @@ -73,6 +82,14 @@ class Blueprint: return handler return decorator + def listener(self, event): + """ + """ + def decorator(listener): + self.listeners[event].append(listener) + return listener + return decorator + def middleware(self, *args, **kwargs): """ """ diff --git a/sanic/sanic.py b/sanic/sanic.py index 310a2175..f8189d77 100644 --- a/sanic/sanic.py +++ b/sanic/sanic.py @@ -1,4 +1,5 @@ from asyncio import get_event_loop +from functools import partial from inspect import isawaitable from multiprocessing import Process, Event from signal import signal, SIGTERM, SIGINT @@ -24,6 +25,8 @@ class Sanic: self.response_middleware = [] self.blueprints = {} self._blueprint_order = [] + self.loop = None + self.debug = None # -------------------------------------------------------------------- # # Registration @@ -71,7 +74,7 @@ class Sanic: if attach_to == 'request': self.request_middleware.append(middleware) if attach_to == 'response': - self.response_middleware.append(middleware) + self.response_middleware.insert(0, middleware) return middleware # Detect which way this was called, @middleware or @middleware('AT') @@ -102,6 +105,9 @@ class Sanic: # Request Handling # -------------------------------------------------------------------- # + def converted_response_type(self, response): + pass + async def handle_request(self, request, response_callback): """ Takes a request from the HTTP Server and returns a response object to @@ -113,7 +119,10 @@ class Sanic: :return: Nothing """ try: - # Middleware process_request + # -------------------------------------------- # + # Request Middleware + # -------------------------------------------- # + response = False # The if improves speed. I don't know why if self.request_middleware: @@ -126,6 +135,10 @@ class Sanic: # No middleware results if not response: + # -------------------------------------------- # + # Execute Handler + # -------------------------------------------- # + # Fetch handler from router handler, args, kwargs = self.router.get(request) if handler is None: @@ -138,7 +151,10 @@ class Sanic: if isawaitable(response): response = await response - # Middleware process_response + # -------------------------------------------- # + # Response Middleware + # -------------------------------------------- # + if self.response_middleware: for middleware in self.response_middleware: _response = middleware(request, response) @@ -149,6 +165,10 @@ class Sanic: break except Exception as e: + # -------------------------------------------- # + # Response Generation Failed + # -------------------------------------------- # + try: response = self.error_handler.response(request, e) if isawaitable(response): @@ -168,18 +188,23 @@ class Sanic: # Execution # -------------------------------------------------------------------- # - def run(self, host="127.0.0.1", port=8000, debug=False, after_start=None, - before_stop=None, sock=None, workers=1, loop=None): + def run(self, host="127.0.0.1", port=8000, debug=False, before_start=None, + after_start=None, before_stop=None, after_stop=None, sock=None, + workers=1, loop=None): """ Runs the HTTP Server and listens until keyboard interrupt or term signal. On termination, drains connections before closing. :param host: Address to host on :param port: Port to host on :param debug: Enables debug output (slows server) + :param before_start: Function to be executed before the server starts + accepting connections :param after_start: Function to be executed after the server starts - listening + accepting connections :param before_stop: Function to be executed when a stop signal is received before it is respected + :param after_stop: Function to be executed when all requests are + complete :param sock: Socket for the server to accept connections from :param workers: Number of processes received before it is respected @@ -188,6 +213,7 @@ class Sanic: """ self.error_handler.debug = True self.debug = debug + self.loop = loop server_settings = { 'host': host, @@ -197,11 +223,32 @@ class Sanic: 'request_handler': self.handle_request, 'request_timeout': self.config.REQUEST_TIMEOUT, 'request_max_size': self.config.REQUEST_MAX_SIZE, - 'after_start': after_start, - 'before_stop': before_stop, 'loop': loop } + # -------------------------------------------- # + # Register start/stop events + # -------------------------------------------- # + + for event_name, settings_name, args, reverse in ( + ("before_server_start", "before_start", before_start, False), + ("after_server_start", "after_start", after_start, False), + ("before_server_stop", "before_stop", before_stop, True), + ("after_server_stop", "after_stop", after_stop, True), + ): + listeners = [] + for blueprint in self.blueprints.values(): + listeners += blueprint.listeners[event_name] + if args: + if type(args) is not list: + args = [args] + listeners += args + if reverse: + listeners.reverse() + # Prepend sanic to the arguments when listeners are triggered + listeners = [partial(listener, self) for listener in listeners] + server_settings[settings_name] = listeners + if debug: log.setLevel(logging.DEBUG) log.debug(self.config.LOGO) diff --git a/sanic/server.py b/sanic/server.py index 63563269..eddff48c 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -157,7 +157,22 @@ class HttpProtocol(asyncio.Protocol): return False -def serve(host, port, request_handler, after_start=None, before_stop=None, +def trigger_events(events, loop): + """ + :param events: one or more sync or async functions to execute + :param loop: event loop + """ + if events: + if type(events) is not list: + events = [events] + for event in events: + result = event(loop) + if isawaitable(result): + loop.run_until_complete(result) + + +def serve(host, port, request_handler, before_start=None, after_start=None, + before_stop=None, after_stop=None, debug=False, request_timeout=60, sock=None, request_max_size=None, reuse_port=False, loop=None): """ @@ -183,6 +198,8 @@ def serve(host, port, request_handler, after_start=None, before_stop=None, if debug: loop.set_debug(debug) + trigger_events(before_start, loop) + connections = {} signal = Signal() server_coroutine = loop.create_server(lambda: HttpProtocol( @@ -193,17 +210,14 @@ def serve(host, port, request_handler, after_start=None, before_stop=None, request_timeout=request_timeout, request_max_size=request_max_size, ), host, port, reuse_port=reuse_port, sock=sock) + try: http_server = loop.run_until_complete(server_coroutine) except Exception as e: log.exception("Unable to start server") return - # Run the on_start function if provided - if after_start: - result = after_start(loop) - if isawaitable(result): - loop.run_until_complete(result) + trigger_events(after_start, loop) # Register signals for graceful termination for _signal in (SIGINT, SIGTERM): @@ -215,10 +229,7 @@ def serve(host, port, request_handler, after_start=None, before_stop=None, log.info("Stop requested, draining connections...") # Run the on_stop function if provided - if before_stop: - result = before_stop(loop) - if isawaitable(result): - loop.run_until_complete(result) + trigger_events(before_stop, loop) # Wait for event loop to finish and all connections to drain http_server.close() @@ -232,4 +243,6 @@ def serve(host, port, request_handler, after_start=None, before_stop=None, while connections: loop.run_until_complete(asyncio.sleep(0.1)) + trigger_events(after_stop, loop) + loop.close() diff --git a/sanic/utils.py b/sanic/utils.py index c39f03ab..e731112e 100644 --- a/sanic/utils.py +++ b/sanic/utils.py @@ -24,7 +24,7 @@ def sanic_endpoint_test(app, method='get', uri='/', gather_request=True, def _collect_request(request): results.append(request) - async def _collect_response(loop): + async def _collect_response(sanic, loop): try: response = await local_request(method, uri, *request_args, **request_kwargs) diff --git a/tests/test_blueprints.py b/tests/test_blueprints.py index 8068160f..1b88795d 100644 --- a/tests/test_blueprints.py +++ b/tests/test_blueprints.py @@ -108,4 +108,40 @@ def test_bp_exception_handler(): assert response.text == 'OK' request, response = sanic_endpoint_test(app, uri='/3') - assert response.status == 200 \ No newline at end of file + assert response.status == 200 + +def test_bp_listeners(): + app = Sanic('test_middleware') + blueprint = Blueprint('test_middleware') + + order = [] + + @blueprint.listener('before_server_start') + def handler_1(sanic, loop): + order.append(1) + + @blueprint.listener('after_server_start') + def handler_2(sanic, loop): + order.append(2) + + @blueprint.listener('after_server_start') + def handler_3(sanic, loop): + order.append(3) + + @blueprint.listener('before_server_stop') + def handler_4(sanic, loop): + order.append(5) + + @blueprint.listener('before_server_stop') + def handler_5(sanic, loop): + order.append(4) + + @blueprint.listener('after_server_stop') + def handler_6(sanic, loop): + order.append(6) + + app.register_blueprint(blueprint) + + request, response = sanic_endpoint_test(app, uri='/') + + assert order == [1,2,3,4,5,6] \ No newline at end of file