Move serve_multiple, fix tests (#357)

* Move serve_multiple, remove stop_events, fix tests

Moves serve_multiple out of the app, removes stop_event (adds a
deprecation warning, but it also wasn't doing anything), and fixes the
multiprocessing tests so that they don't freeze pytest's runner.

Other notes:

Also moves around some imports so that they are better optimized as
well.

* Re-add in stop_event, maybe it wasn't so bad!

* Get rid of unused warnings import
This commit is contained in:
Eli Uriegas 2017-01-27 19:34:21 -06:00 committed by GitHub
parent fad9fbca6f
commit 59242df7d6
4 changed files with 94 additions and 144 deletions

View File

@ -1,22 +1,18 @@
import logging
from asyncio import get_event_loop from asyncio import get_event_loop
from collections import deque from collections import deque
from functools import partial from functools import partial
from inspect import isawaitable, stack, getmodulename from inspect import isawaitable, stack, getmodulename
from multiprocessing import Process, Event
from signal import signal, SIGTERM, SIGINT
from traceback import format_exc from traceback import format_exc
import logging
from .config import Config from .config import Config
from .exceptions import Handler from .exceptions import Handler
from .exceptions import ServerError
from .log import log from .log import log
from .response import HTTPResponse from .response import HTTPResponse
from .router import Router from .router import Router
from .server import serve, HttpProtocol from .server import serve, serve_multiple, HttpProtocol
from .static import register as static_register from .static import register as static_register
from .exceptions import ServerError
from socket import socket, SOL_SOCKET, SO_REUSEADDR
from os import set_inheritable
class Sanic: class Sanic:
@ -358,9 +354,7 @@ class Sanic:
if workers == 1: if workers == 1:
serve(**server_settings) serve(**server_settings)
else: else:
log.info('Spinning up {} workers...'.format(workers)) serve_multiple(server_settings, workers, stop_event)
self.serve_multiple(server_settings, workers, stop_event)
except Exception as e: except Exception as e:
log.exception( log.exception(
@ -369,13 +363,7 @@ class Sanic:
log.info("Server Stopped") log.info("Server Stopped")
def stop(self): def stop(self):
""" """This kills the Sanic"""
This kills the Sanic
"""
if self.processes is not None:
for process in self.processes:
process.terminate()
self.sock.close()
get_event_loop().stop() get_event_loop().stop()
async def create_server(self, host="127.0.0.1", port=8000, debug=False, async def create_server(self, host="127.0.0.1", port=8000, debug=False,
@ -414,8 +402,7 @@ class Sanic:
("before_server_start", "before_start", before_start, False), ("before_server_start", "before_start", before_start, False),
("after_server_start", "after_start", after_start, False), ("after_server_start", "after_start", after_start, False),
("before_server_stop", "before_stop", before_stop, True), ("before_server_stop", "before_stop", before_stop, True),
("after_server_stop", "after_stop", after_stop, True), ("after_server_stop", "after_stop", after_stop, True)):
):
listeners = [] listeners = []
for blueprint in self.blueprints.values(): for blueprint in self.blueprints.values():
listeners += blueprint.listeners[event_name] listeners += blueprint.listeners[event_name]
@ -438,46 +425,3 @@ class Sanic:
log.info('Goin\' Fast @ {}://{}:{}'.format(proto, host, port)) log.info('Goin\' Fast @ {}://{}:{}'.format(proto, host, port))
return await serve(**server_settings) return await serve(**server_settings)
def serve_multiple(self, server_settings, workers, stop_event=None):
    """
    Starts multiple server processes simultaneously. Stops on interrupt
    and terminate signals, and drains connections when complete.

    :param server_settings: kw arguments to be passed to the serve function
    :param workers: number of workers to launch
    :param stop_event: if provided, is used as a stop signal
    :return: None (blocks until all worker processes have exited)
    """
    if server_settings.get('loop', None) is not None:
        # NOTE(review): logging methods don't accept a warning category;
        # the trailing DeprecationWarning is treated as a %-format arg
        # and will break message formatting — confirm intent.
        log.warning("Passing a loop will be deprecated in version 0.4.0"
                    " https://github.com/channelcat/sanic/pull/335"
                    " has more information.", DeprecationWarning)
    # Workers share one listening socket, so the port must allow reuse.
    server_settings['reuse_port'] = True

    # Create a stop event to be triggered by a signal
    if stop_event is None:
        stop_event = Event()
    signal(SIGINT, lambda s, f: stop_event.set())
    signal(SIGTERM, lambda s, f: stop_event.set())

    # Bind one socket in the parent and mark it inheritable so every
    # forked worker can accept() on the same port.
    self.sock = socket()
    self.sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    self.sock.bind((server_settings['host'], server_settings['port']))
    set_inheritable(self.sock.fileno(), True)
    server_settings['sock'] = self.sock
    server_settings['host'] = None
    server_settings['port'] = None

    self.processes = []
    for _ in range(workers):
        process = Process(target=serve, kwargs=server_settings)
        process.daemon = True
        process.start()
        self.processes.append(process)

    for process in self.processes:
        process.join()

    # the above processes will block this until they're stopped
    self.stop()

View File

@ -1,11 +1,18 @@
import asyncio import asyncio
import os
import traceback import traceback
from functools import partial from functools import partial
from inspect import isawaitable from inspect import isawaitable
from signal import SIGINT, SIGTERM from multiprocessing import Process, Event
from os import set_inheritable
from signal import SIGTERM, SIGINT
from signal import signal as signal_func
from socket import socket, SOL_SOCKET, SO_REUSEADDR
from time import time from time import time
from httptools import HttpRequestParser from httptools import HttpRequestParser
from httptools.parser.errors import HttpParserError from httptools.parser.errors import HttpParserError
from .exceptions import ServerError from .exceptions import ServerError
try: try:
@ -17,7 +24,6 @@ from .log import log
from .request import Request from .request import Request
from .exceptions import RequestTimeout, PayloadTooLarge, InvalidUsage from .exceptions import RequestTimeout, PayloadTooLarge, InvalidUsage
current_time = None current_time = None
@ -31,6 +37,7 @@ class CIDict(dict):
This does not maintain the inputted case when calling items() or keys() This does not maintain the inputted case when calling items() or keys()
in favor of speed, since headers are case insensitive in favor of speed, since headers are case insensitive
""" """
def get(self, key, default=None): def get(self, key, default=None):
return super().get(key.casefold(), default) return super().get(key.casefold(), default)
@ -56,7 +63,7 @@ class HttpProtocol(asyncio.Protocol):
'_total_request_size', '_timeout_handler', '_last_communication_time') '_total_request_size', '_timeout_handler', '_last_communication_time')
def __init__(self, *, loop, request_handler, error_handler, def __init__(self, *, loop, request_handler, error_handler,
signal=Signal(), connections={}, request_timeout=60, signal=Signal(), connections=set(), request_timeout=60,
request_max_size=None): request_max_size=None):
self.loop = loop self.loop = loop
self.transport = None self.transport = None
@ -328,7 +335,7 @@ def serve(host, port, request_handler, error_handler, before_start=None,
try: try:
http_server = loop.run_until_complete(server_coroutine) http_server = loop.run_until_complete(server_coroutine)
except Exception: except:
log.exception("Unable to start server") log.exception("Unable to start server")
return return
@ -339,10 +346,12 @@ def serve(host, port, request_handler, error_handler, before_start=None,
for _signal in (SIGINT, SIGTERM): for _signal in (SIGINT, SIGTERM):
loop.add_signal_handler(_signal, loop.stop) loop.add_signal_handler(_signal, loop.stop)
pid = os.getpid()
try: try:
log.info('Starting worker [{}]'.format(pid))
loop.run_forever() loop.run_forever()
finally: finally:
log.info("Stop requested, draining connections...") log.info("Stopping worker [{}]".format(pid))
# Run the on_stop function if provided # Run the on_stop function if provided
trigger_events(before_stop, loop) trigger_events(before_stop, loop)
@ -362,3 +371,51 @@ def serve(host, port, request_handler, error_handler, before_start=None,
trigger_events(after_stop, loop) trigger_events(after_stop, loop)
loop.close() loop.close()
def serve_multiple(server_settings, workers, stop_event=None):
    """
    Starts multiple server processes simultaneously. Stops on interrupt
    and terminate signals, and drains connections when complete.

    One listening socket is bound in the parent (with SO_REUSEADDR) and
    marked inheritable, so every forked worker accepts on the same port.

    :param server_settings: kw arguments to be passed to the serve function
    :param workers: number of worker processes to launch
    :param stop_event: if provided, is used as a stop signal
    :return: None (blocks until all worker processes have exited)
    """
    if server_settings.get('loop', None) is not None:
        # logging methods take lazy %-format args, not a warning
        # category; passing DeprecationWarning as an extra positional
        # argument would break message formatting at emit time.
        log.warning("Passing a loop will be deprecated in version 0.4.0"
                    " https://github.com/channelcat/sanic/pull/335"
                    " has more information.")
    server_settings['reuse_port'] = True

    # Bind the shared socket before forking and make it inheritable so
    # each worker process can accept() on it.
    sock = socket()
    sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    sock.bind((server_settings['host'], server_settings['port']))
    set_inheritable(sock.fileno(), True)
    server_settings['sock'] = sock
    server_settings['host'] = None
    server_settings['port'] = None

    # Create a stop event to be triggered by a signal
    if stop_event is None:
        stop_event = Event()
    signal_func(SIGINT, lambda s, f: stop_event.set())
    signal_func(SIGTERM, lambda s, f: stop_event.set())

    processes = []
    for _ in range(workers):
        process = Process(target=serve, kwargs=server_settings)
        process.daemon = True
        process.start()
        processes.append(process)

    # join() blocks here until the workers exit (e.g. after a signal)
    for process in processes:
        process.join()

    # the above processes will block this until they're stopped
    for process in processes:
        process.terminate()
    sock.close()
    asyncio.get_event_loop().stop()

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import uuid
from sanic.response import text from sanic.response import text
from sanic import Sanic from sanic import Sanic
from io import StringIO from io import StringIO
@ -9,10 +10,11 @@ logging_format = '''module: %(module)s; \
function: %(funcName)s(); \ function: %(funcName)s(); \
message: %(message)s''' message: %(message)s'''
def test_log(): def test_log():
log_stream = StringIO() log_stream = StringIO()
for handler in logging.root.handlers[:]: for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler) logging.root.removeHandler(handler)
logging.basicConfig( logging.basicConfig(
format=logging_format, format=logging_format,
level=logging.DEBUG, level=logging.DEBUG,
@ -20,14 +22,16 @@ def test_log():
) )
log = logging.getLogger() log = logging.getLogger()
app = Sanic('test_logging') app = Sanic('test_logging')
rand_string = str(uuid.uuid4())
@app.route('/') @app.route('/')
def handler(request): def handler(request):
log.info('hello world') log.info(rand_string)
return text('hello') return text('hello')
request, response = sanic_endpoint_test(app) request, response = sanic_endpoint_test(app)
log_text = log_stream.getvalue().strip().split('\n')[-3] log_text = log_stream.getvalue()
assert log_text == "module: test_logging; function: handler(); message: hello world" assert rand_string in log_text
if __name__ =="__main__": if __name__ == "__main__":
test_log() test_log()

View File

@ -1,81 +1,26 @@
from multiprocessing import Array, Event, Process import multiprocessing
from time import sleep, time import random
from ujson import loads as json_loads import signal
import pytest
from sanic import Sanic from sanic import Sanic
from sanic.response import json from sanic.utils import HOST, PORT
from sanic.utils import local_request, HOST, PORT
# ------------------------------------------------------------ #
# GET
# ------------------------------------------------------------ #
# TODO: Figure out why this freezes on pytest but not when
# executed via interpreter
@pytest.mark.skip(
reason="Freezes with pytest not on interpreter")
def test_multiprocessing(): def test_multiprocessing():
app = Sanic('test_json') """Tests that the number of children we produce is correct"""
# Selects a number at random so we can spot check
num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1))
app = Sanic('test_multiprocessing')
process_list = set()
response = Array('c', 50) def stop_on_alarm(*args):
@app.route('/') for process in multiprocessing.active_children():
async def handler(request): process_list.add(process.pid)
return json({"test": True}) process.terminate()
stop_event = Event() signal.signal(signal.SIGALRM, stop_on_alarm)
async def after_start(*args, **kwargs): signal.alarm(1)
http_response = await local_request('get', '/') app.run(HOST, PORT, workers=num_workers)
response.value = http_response.text.encode()
stop_event.set()
def rescue_crew(): assert len(process_list) == num_workers
sleep(5)
stop_event.set()
rescue_process = Process(target=rescue_crew)
rescue_process.start()
app.serve_multiple({
'host': HOST,
'port': PORT,
'after_start': after_start,
'request_handler': app.handle_request,
'request_max_size': 100000,
}, workers=2, stop_event=stop_event)
rescue_process.terminate()
try:
results = json_loads(response.value)
except:
raise ValueError("Expected JSON response but got '{}'".format(response))
assert results.get('test') == True
@pytest.mark.skip(
reason="Freezes with pytest not on interpreter")
def test_drain_connections():
app = Sanic('test_json')
@app.route('/')
async def handler(request):
return json({"test": True})
stop_event = Event()
async def after_start(*args, **kwargs):
http_response = await local_request('get', '/')
stop_event.set()
start = time()
app.serve_multiple({
'host': HOST,
'port': PORT,
'after_start': after_start,
'request_handler': app.handle_request,
}, workers=2, stop_event=stop_event)
end = time()
assert end - start < 0.05