Fix pickle error when attempting to pickle an application which contains websocket routes. (#1853)

Moves the websocket_handler subfunction out to a class-level method, which can be more easily pickled by the built-in python Pickler.
Also includes a similar fix for the add_task deferred task scheduler subfunction.

Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
This commit is contained in:
Ashley Sommer 2020-06-28 18:05:06 +10:00 committed by GitHub
parent 83511a0ba7
commit 761eef7d96
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 62 additions and 58 deletions

View File

@ -117,24 +117,12 @@ class Sanic:
:param task: future, couroutine or awaitable
"""
try:
if callable(task):
try:
self.loop.create_task(task(self))
except TypeError:
self.loop.create_task(task())
else:
self.loop.create_task(task)
loop = self.loop # Will raise SanicError if loop is not started
self._loop_add_task(task, self, loop)
except SanicException:
@self.listener("before_server_start")
def run(app, loop):
if callable(task):
try:
loop.create_task(task(self))
except TypeError:
loop.create_task(task())
else:
loop.create_task(task)
self.listener("before_server_start")(
partial(self._loop_add_task, task)
)
# Decorator
def listener(self, event):
@ -499,42 +487,12 @@ class Sanic:
routes, handler = handler
else:
routes = []
async def websocket_handler(request, *args, **kwargs):
request.app = self
if not getattr(handler, "__blueprintname__", False):
request.endpoint = handler.__name__
else:
request.endpoint = (
getattr(handler, "__blueprintname__", "")
+ handler.__name__
)
pass
if self.asgi:
ws = request.transport.get_websocket_connection()
else:
protocol = request.transport.get_protocol()
protocol.app = self
ws = await protocol.websocket_handshake(
request, subprotocols
)
# schedule the application handler
# its future is kept in self.websocket_tasks in case it
# needs to be cancelled due to the server being stopped
fut = ensure_future(handler(request, ws, *args, **kwargs))
self.websocket_tasks.add(fut)
try:
await fut
except (CancelledError, ConnectionClosed):
pass
finally:
self.websocket_tasks.remove(fut)
await ws.close()
websocket_handler = partial(
self._websocket_handler, handler, subprotocols=subprotocols
)
websocket_handler.__name__ = (
"websocket_handler_" + handler.__name__
)
routes.extend(
self.router.add(
uri=uri,
@ -598,10 +556,7 @@ class Sanic:
if not self.websocket_enabled:
# if the server is stopped, we want to cancel any ongoing
# websocket tasks, to allow the server to exit promptly
@self.listener("before_server_stop")
def cancel_websocket_tasks(app, loop):
for task in self.websocket_tasks:
task.cancel()
self.listener("before_server_stop")(self._cancel_websocket_tasks)
self.websocket_enabled = enable
@ -1425,6 +1380,55 @@ class Sanic:
parts = [self.name, *parts]
return ".".join(parts)
@classmethod
def _loop_add_task(cls, task, app, loop):
if callable(task):
try:
loop.create_task(task(app))
except TypeError:
loop.create_task(task())
else:
loop.create_task(task)
@classmethod
def _cancel_websocket_tasks(cls, app, loop):
for task in app.websocket_tasks:
task.cancel()
async def _websocket_handler(
self, handler, request, *args, subprotocols=None, **kwargs
):
request.app = self
if not getattr(handler, "__blueprintname__", False):
request.endpoint = handler.__name__
else:
request.endpoint = (
getattr(handler, "__blueprintname__", "") + handler.__name__
)
pass
if self.asgi:
ws = request.transport.get_websocket_connection()
else:
protocol = request.transport.get_protocol()
protocol.app = self
ws = await protocol.websocket_handshake(request, subprotocols)
# schedule the application handler
# its future is kept in self.websocket_tasks in case it
# needs to be cancelled due to the server being stopped
fut = ensure_future(handler(request, ws, *args, **kwargs))
self.websocket_tasks.add(fut)
try:
await fut
except (CancelledError, ConnectionClosed):
pass
finally:
self.websocket_tasks.remove(fut)
await ws.close()
# -------------------------------------------------------------------- #
# ASGI
# -------------------------------------------------------------------- #

View File

@ -927,7 +927,7 @@ def serve_multiple(server_settings, workers):
signal_func(SIGINT, lambda s, f: sig_handler(s, f))
signal_func(SIGTERM, lambda s, f: sig_handler(s, f))
mp = multiprocessing.get_context("fork")
mp = multiprocessing.get_context("spawn")
for _ in range(workers):
process = mp.Process(target=serve, kwargs=server_settings)