diff --git a/sanic/sanic.py b/sanic/sanic.py index a3f49197..3dab3e47 100644 --- a/sanic/sanic.py +++ b/sanic/sanic.py @@ -3,7 +3,6 @@ from collections import deque from functools import partial from inspect import isawaitable, stack, getmodulename from multiprocessing import Process, Event -from select import select from signal import signal, SIGTERM, SIGINT from traceback import format_exc import logging @@ -16,6 +15,8 @@ from .router import Router from .server import serve, HttpProtocol from .static import register as static_register from .exceptions import ServerError +from socket import socket, SOL_SOCKET, SO_REUSEADDR +from os import set_inheritable class Sanic: @@ -39,6 +40,8 @@ class Sanic: self._blueprint_order = [] self.loop = None self.debug = None + self.sock = None + self.processes = None # Register alternative method names self.go_fast = self.run @@ -242,7 +245,8 @@ class Sanic: def run(self, host="127.0.0.1", port=8000, debug=False, before_start=None, after_start=None, before_stop=None, after_stop=None, sock=None, - workers=1, loop=None, protocol=HttpProtocol, backlog=100): + workers=1, loop=None, protocol=HttpProtocol, backlog=100, + stop_event=None): """ Runs the HTTP Server and listens until keyboard interrupt or term signal. On termination, drains connections before closing. @@ -318,7 +322,7 @@ class Sanic: else: log.info('Spinning up {} workers...'.format(workers)) - self.serve_multiple(server_settings, workers) + self.serve_multiple(server_settings, workers, stop_event) except Exception as e: log.exception( @@ -330,10 +334,13 @@ class Sanic: """ This kills the Sanic """ + if self.processes is not None: + for process in self.processes: + process.terminate() + self.sock.close() get_event_loop().stop() - @staticmethod - def serve_multiple(server_settings, workers, stop_event=None): + def serve_multiple(self, server_settings, workers, stop_event=None): """ Starts multiple server processes simultaneously. Stops on interrupt and terminate signals, and drains connections when complete. @@ -345,25 +352,28 @@ class Sanic: server_settings['reuse_port'] = True # Create a stop event to be triggered by a signal - if not stop_event: + if stop_event is None: stop_event = Event() signal(SIGINT, lambda s, f: stop_event.set()) signal(SIGTERM, lambda s, f: stop_event.set()) - processes = [] + self.sock = socket() + self.sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) + self.sock.bind((server_settings['host'], server_settings['port'])) + set_inheritable(self.sock.fileno(), True) + server_settings['sock'] = self.sock + server_settings['host'] = None + server_settings['port'] = None + + self.processes = [] for _ in range(workers): process = Process(target=serve, kwargs=server_settings) + process.daemon = True process.start() - processes.append(process) + self.processes.append(process) - # Infinitely wait for the stop event - try: - select(stop_event) - except: - pass - - log.info('Spinning down workers...') - for process in processes: - process.terminate() - for process in processes: + for process in self.processes: process.join() + + # the above processes will block this until they're stopped + self.stop() diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index cc967ef1..e39c3d24 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -2,6 +2,8 @@ from multiprocessing import Array, Event, Process from time import sleep, time from ujson import loads as json_loads +import pytest + from sanic import Sanic from sanic.response import json from sanic.utils import local_request, HOST, PORT @@ -13,8 +15,9 @@ from sanic.utils import local_request, HOST, PORT # TODO: Figure out why this freezes on pytest but not when # executed via interpreter - -def skip_test_multiprocessing(): +@pytest.mark.skip( + reason="Freezes with pytest not on interpreter") +def test_multiprocessing(): app = Sanic('test_json') response = Array('c', 50) @@ -52,7 +55,8 @@ def skip_test_multiprocessing(): assert results.get('test') == True - +@pytest.mark.skip( + reason="Freezes with pytest not on interpreter") def test_drain_connections(): app = Sanic('test_json')