Added server start/stop listeners and reverse ordering on response middleware to blueprints

This commit is contained in:
Channel Cat 2016-10-21 04:11:18 -07:00
parent 8f6e5a1263
commit a5614f6880
6 changed files with 157 additions and 21 deletions

View File

@ -80,3 +80,26 @@ Exceptions can also be applied exclusively to blueprints globally.
def ignore_404s(request, exception): def ignore_404s(request, exception):
return text("Yep, I totally found the page: {}".format(request.url)) return text("Yep, I totally found the page: {}".format(request.url))
``` ```
## Start and Stop
Blueprints and run functions during the start and stop process of the server.
If running in multiprocessor mode (more than 1 worker), these are triggered
Available events are:
* before_server_start - Executed before the server begins to accept connections
* after_server_start - Executed after the server begins to accept connections
* before_server_stop - Executed before the server stops accepting connections
* after_server_stop - Executed after the server is stopped and all requests are complete
```python
bp = Blueprint('my_blueprint')
@bp.listen('before_server_start')
async def setup_connection():
global database
database = mysql.connect(host='127.0.0.1'...)
@bp.listen('after_server_stop')
async def close_connection():
await database.close()
```

View File

@ -1,3 +1,6 @@
from collections import defaultdict
class BlueprintSetup: class BlueprintSetup:
""" """
""" """
@ -22,7 +25,7 @@ class BlueprintSetup:
if self.url_prefix: if self.url_prefix:
uri = self.url_prefix + uri uri = self.url_prefix + uri
self.app.router.add(uri, methods, handler) self.app.route(uri=uri, methods=methods)(handler)
def add_exception(self, handler, *args, **kwargs): def add_exception(self, handler, *args, **kwargs):
""" """
@ -42,9 +45,15 @@ class BlueprintSetup:
class Blueprint: class Blueprint:
def __init__(self, name, url_prefix=None): def __init__(self, name, url_prefix=None):
"""
Creates a new blueprint
:param name: Unique name of the blueprint
:param url_prefix: URL to be prefixed before all route URLs
"""
self.name = name self.name = name
self.url_prefix = url_prefix self.url_prefix = url_prefix
self.deferred_functions = [] self.deferred_functions = []
self.listeners = defaultdict(list)
def record(self, func): def record(self, func):
""" """
@ -73,6 +82,14 @@ class Blueprint:
return handler return handler
return decorator return decorator
def listener(self, event):
"""
"""
def decorator(listener):
self.listeners[event].append(listener)
return listener
return decorator
def middleware(self, *args, **kwargs): def middleware(self, *args, **kwargs):
""" """
""" """

View File

@ -1,4 +1,5 @@
from asyncio import get_event_loop from asyncio import get_event_loop
from functools import partial
from inspect import isawaitable from inspect import isawaitable
from multiprocessing import Process, Event from multiprocessing import Process, Event
from signal import signal, SIGTERM, SIGINT from signal import signal, SIGTERM, SIGINT
@ -24,6 +25,8 @@ class Sanic:
self.response_middleware = [] self.response_middleware = []
self.blueprints = {} self.blueprints = {}
self._blueprint_order = [] self._blueprint_order = []
self.loop = None
self.debug = None
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
# Registration # Registration
@ -71,7 +74,7 @@ class Sanic:
if attach_to == 'request': if attach_to == 'request':
self.request_middleware.append(middleware) self.request_middleware.append(middleware)
if attach_to == 'response': if attach_to == 'response':
self.response_middleware.append(middleware) self.response_middleware.insert(0, middleware)
return middleware return middleware
# Detect which way this was called, @middleware or @middleware('AT') # Detect which way this was called, @middleware or @middleware('AT')
@ -102,6 +105,9 @@ class Sanic:
# Request Handling # Request Handling
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
def converted_response_type(self, response):
pass
async def handle_request(self, request, response_callback): async def handle_request(self, request, response_callback):
""" """
Takes a request from the HTTP Server and returns a response object to Takes a request from the HTTP Server and returns a response object to
@ -113,7 +119,10 @@ class Sanic:
:return: Nothing :return: Nothing
""" """
try: try:
# Middleware process_request # -------------------------------------------- #
# Request Middleware
# -------------------------------------------- #
response = False response = False
# The if improves speed. I don't know why # The if improves speed. I don't know why
if self.request_middleware: if self.request_middleware:
@ -126,6 +135,10 @@ class Sanic:
# No middleware results # No middleware results
if not response: if not response:
# -------------------------------------------- #
# Execute Handler
# -------------------------------------------- #
# Fetch handler from router # Fetch handler from router
handler, args, kwargs = self.router.get(request) handler, args, kwargs = self.router.get(request)
if handler is None: if handler is None:
@ -138,7 +151,10 @@ class Sanic:
if isawaitable(response): if isawaitable(response):
response = await response response = await response
# Middleware process_response # -------------------------------------------- #
# Response Middleware
# -------------------------------------------- #
if self.response_middleware: if self.response_middleware:
for middleware in self.response_middleware: for middleware in self.response_middleware:
_response = middleware(request, response) _response = middleware(request, response)
@ -149,6 +165,10 @@ class Sanic:
break break
except Exception as e: except Exception as e:
# -------------------------------------------- #
# Response Generation Failed
# -------------------------------------------- #
try: try:
response = self.error_handler.response(request, e) response = self.error_handler.response(request, e)
if isawaitable(response): if isawaitable(response):
@ -168,18 +188,23 @@ class Sanic:
# Execution # Execution
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
def run(self, host="127.0.0.1", port=8000, debug=False, after_start=None, def run(self, host="127.0.0.1", port=8000, debug=False, before_start=None,
before_stop=None, sock=None, workers=1, loop=None): after_start=None, before_stop=None, after_stop=None, sock=None,
workers=1, loop=None):
""" """
Runs the HTTP Server and listens until keyboard interrupt or term Runs the HTTP Server and listens until keyboard interrupt or term
signal. On termination, drains connections before closing. signal. On termination, drains connections before closing.
:param host: Address to host on :param host: Address to host on
:param port: Port to host on :param port: Port to host on
:param debug: Enables debug output (slows server) :param debug: Enables debug output (slows server)
:param before_start: Function to be executed before the server starts
accepting connections
:param after_start: Function to be executed after the server starts :param after_start: Function to be executed after the server starts
listening accepting connections
:param before_stop: Function to be executed when a stop signal is :param before_stop: Function to be executed when a stop signal is
received before it is respected received before it is respected
:param after_stop: Function to be executed when all requests are
complete
:param sock: Socket for the server to accept connections from :param sock: Socket for the server to accept connections from
:param workers: Number of processes :param workers: Number of processes
received before it is respected received before it is respected
@ -188,6 +213,7 @@ class Sanic:
""" """
self.error_handler.debug = True self.error_handler.debug = True
self.debug = debug self.debug = debug
self.loop = loop
server_settings = { server_settings = {
'host': host, 'host': host,
@ -197,11 +223,32 @@ class Sanic:
'request_handler': self.handle_request, 'request_handler': self.handle_request,
'request_timeout': self.config.REQUEST_TIMEOUT, 'request_timeout': self.config.REQUEST_TIMEOUT,
'request_max_size': self.config.REQUEST_MAX_SIZE, 'request_max_size': self.config.REQUEST_MAX_SIZE,
'after_start': after_start,
'before_stop': before_stop,
'loop': loop 'loop': loop
} }
# -------------------------------------------- #
# Register start/stop events
# -------------------------------------------- #
for event_name, settings_name, args, reverse in (
("before_server_start", "before_start", before_start, False),
("after_server_start", "after_start", after_start, False),
("before_server_stop", "before_stop", before_stop, True),
("after_server_stop", "after_stop", after_stop, True),
):
listeners = []
for blueprint in self.blueprints.values():
listeners += blueprint.listeners[event_name]
if args:
if type(args) is not list:
args = [args]
listeners += args
if reverse:
listeners.reverse()
# Prepend sanic to the arguments when listeners are triggered
listeners = [partial(listener, self) for listener in listeners]
server_settings[settings_name] = listeners
if debug: if debug:
log.setLevel(logging.DEBUG) log.setLevel(logging.DEBUG)
log.debug(self.config.LOGO) log.debug(self.config.LOGO)

View File

@ -157,7 +157,22 @@ class HttpProtocol(asyncio.Protocol):
return False return False
def serve(host, port, request_handler, after_start=None, before_stop=None, def trigger_events(events, loop):
"""
:param events: one or more sync or async functions to execute
:param loop: event loop
"""
if events:
if type(events) is not list:
events = [events]
for event in events:
result = event(loop)
if isawaitable(result):
loop.run_until_complete(result)
def serve(host, port, request_handler, before_start=None, after_start=None,
before_stop=None, after_stop=None,
debug=False, request_timeout=60, sock=None, debug=False, request_timeout=60, sock=None,
request_max_size=None, reuse_port=False, loop=None): request_max_size=None, reuse_port=False, loop=None):
""" """
@ -183,6 +198,8 @@ def serve(host, port, request_handler, after_start=None, before_stop=None,
if debug: if debug:
loop.set_debug(debug) loop.set_debug(debug)
trigger_events(before_start, loop)
connections = {} connections = {}
signal = Signal() signal = Signal()
server_coroutine = loop.create_server(lambda: HttpProtocol( server_coroutine = loop.create_server(lambda: HttpProtocol(
@ -193,17 +210,14 @@ def serve(host, port, request_handler, after_start=None, before_stop=None,
request_timeout=request_timeout, request_timeout=request_timeout,
request_max_size=request_max_size, request_max_size=request_max_size,
), host, port, reuse_port=reuse_port, sock=sock) ), host, port, reuse_port=reuse_port, sock=sock)
try: try:
http_server = loop.run_until_complete(server_coroutine) http_server = loop.run_until_complete(server_coroutine)
except Exception as e: except Exception as e:
log.exception("Unable to start server") log.exception("Unable to start server")
return return
# Run the on_start function if provided trigger_events(after_start, loop)
if after_start:
result = after_start(loop)
if isawaitable(result):
loop.run_until_complete(result)
# Register signals for graceful termination # Register signals for graceful termination
for _signal in (SIGINT, SIGTERM): for _signal in (SIGINT, SIGTERM):
@ -215,10 +229,7 @@ def serve(host, port, request_handler, after_start=None, before_stop=None,
log.info("Stop requested, draining connections...") log.info("Stop requested, draining connections...")
# Run the on_stop function if provided # Run the on_stop function if provided
if before_stop: trigger_events(before_stop, loop)
result = before_stop(loop)
if isawaitable(result):
loop.run_until_complete(result)
# Wait for event loop to finish and all connections to drain # Wait for event loop to finish and all connections to drain
http_server.close() http_server.close()
@ -232,4 +243,6 @@ def serve(host, port, request_handler, after_start=None, before_stop=None,
while connections: while connections:
loop.run_until_complete(asyncio.sleep(0.1)) loop.run_until_complete(asyncio.sleep(0.1))
trigger_events(after_stop, loop)
loop.close() loop.close()

View File

@ -24,7 +24,7 @@ def sanic_endpoint_test(app, method='get', uri='/', gather_request=True,
def _collect_request(request): def _collect_request(request):
results.append(request) results.append(request)
async def _collect_response(loop): async def _collect_response(sanic, loop):
try: try:
response = await local_request(method, uri, *request_args, response = await local_request(method, uri, *request_args,
**request_kwargs) **request_kwargs)

View File

@ -109,3 +109,39 @@ def test_bp_exception_handler():
request, response = sanic_endpoint_test(app, uri='/3') request, response = sanic_endpoint_test(app, uri='/3')
assert response.status == 200 assert response.status == 200
def test_bp_listeners():
app = Sanic('test_middleware')
blueprint = Blueprint('test_middleware')
order = []
@blueprint.listener('before_server_start')
def handler_1(sanic, loop):
order.append(1)
@blueprint.listener('after_server_start')
def handler_2(sanic, loop):
order.append(2)
@blueprint.listener('after_server_start')
def handler_3(sanic, loop):
order.append(3)
@blueprint.listener('before_server_stop')
def handler_4(sanic, loop):
order.append(5)
@blueprint.listener('before_server_stop')
def handler_5(sanic, loop):
order.append(4)
@blueprint.listener('after_server_stop')
def handler_6(sanic, loop):
order.append(6)
app.register_blueprint(blueprint)
request, response = sanic_endpoint_test(app, uri='/')
assert order == [1,2,3,4,5,6]