Fix pickle error when attempting to pickle an application which contains websocket routes. (#1853)
Moves the websocket_handler subfunction out to a class-level method, which can be more easily pickled by the built-in python Pickler. Also includes a similar fix for the add_task deferred task scheduler subfunction. Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
This commit is contained in:
parent
83511a0ba7
commit
761eef7d96
118
sanic/app.py
118
sanic/app.py
@ -117,24 +117,12 @@ class Sanic:
|
|||||||
:param task: future, couroutine or awaitable
|
:param task: future, couroutine or awaitable
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if callable(task):
|
loop = self.loop # Will raise SanicError if loop is not started
|
||||||
try:
|
self._loop_add_task(task, self, loop)
|
||||||
self.loop.create_task(task(self))
|
|
||||||
except TypeError:
|
|
||||||
self.loop.create_task(task())
|
|
||||||
else:
|
|
||||||
self.loop.create_task(task)
|
|
||||||
except SanicException:
|
except SanicException:
|
||||||
|
self.listener("before_server_start")(
|
||||||
@self.listener("before_server_start")
|
partial(self._loop_add_task, task)
|
||||||
def run(app, loop):
|
)
|
||||||
if callable(task):
|
|
||||||
try:
|
|
||||||
loop.create_task(task(self))
|
|
||||||
except TypeError:
|
|
||||||
loop.create_task(task())
|
|
||||||
else:
|
|
||||||
loop.create_task(task)
|
|
||||||
|
|
||||||
# Decorator
|
# Decorator
|
||||||
def listener(self, event):
|
def listener(self, event):
|
||||||
@ -499,42 +487,12 @@ class Sanic:
|
|||||||
routes, handler = handler
|
routes, handler = handler
|
||||||
else:
|
else:
|
||||||
routes = []
|
routes = []
|
||||||
|
websocket_handler = partial(
|
||||||
async def websocket_handler(request, *args, **kwargs):
|
self._websocket_handler, handler, subprotocols=subprotocols
|
||||||
request.app = self
|
)
|
||||||
if not getattr(handler, "__blueprintname__", False):
|
websocket_handler.__name__ = (
|
||||||
request.endpoint = handler.__name__
|
"websocket_handler_" + handler.__name__
|
||||||
else:
|
)
|
||||||
request.endpoint = (
|
|
||||||
getattr(handler, "__blueprintname__", "")
|
|
||||||
+ handler.__name__
|
|
||||||
)
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
if self.asgi:
|
|
||||||
ws = request.transport.get_websocket_connection()
|
|
||||||
else:
|
|
||||||
protocol = request.transport.get_protocol()
|
|
||||||
protocol.app = self
|
|
||||||
|
|
||||||
ws = await protocol.websocket_handshake(
|
|
||||||
request, subprotocols
|
|
||||||
)
|
|
||||||
|
|
||||||
# schedule the application handler
|
|
||||||
# its future is kept in self.websocket_tasks in case it
|
|
||||||
# needs to be cancelled due to the server being stopped
|
|
||||||
fut = ensure_future(handler(request, ws, *args, **kwargs))
|
|
||||||
self.websocket_tasks.add(fut)
|
|
||||||
try:
|
|
||||||
await fut
|
|
||||||
except (CancelledError, ConnectionClosed):
|
|
||||||
pass
|
|
||||||
finally:
|
|
||||||
self.websocket_tasks.remove(fut)
|
|
||||||
await ws.close()
|
|
||||||
|
|
||||||
routes.extend(
|
routes.extend(
|
||||||
self.router.add(
|
self.router.add(
|
||||||
uri=uri,
|
uri=uri,
|
||||||
@ -598,10 +556,7 @@ class Sanic:
|
|||||||
if not self.websocket_enabled:
|
if not self.websocket_enabled:
|
||||||
# if the server is stopped, we want to cancel any ongoing
|
# if the server is stopped, we want to cancel any ongoing
|
||||||
# websocket tasks, to allow the server to exit promptly
|
# websocket tasks, to allow the server to exit promptly
|
||||||
@self.listener("before_server_stop")
|
self.listener("before_server_stop")(self._cancel_websocket_tasks)
|
||||||
def cancel_websocket_tasks(app, loop):
|
|
||||||
for task in self.websocket_tasks:
|
|
||||||
task.cancel()
|
|
||||||
|
|
||||||
self.websocket_enabled = enable
|
self.websocket_enabled = enable
|
||||||
|
|
||||||
@ -1425,6 +1380,55 @@ class Sanic:
|
|||||||
parts = [self.name, *parts]
|
parts = [self.name, *parts]
|
||||||
return ".".join(parts)
|
return ".".join(parts)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _loop_add_task(cls, task, app, loop):
|
||||||
|
if callable(task):
|
||||||
|
try:
|
||||||
|
loop.create_task(task(app))
|
||||||
|
except TypeError:
|
||||||
|
loop.create_task(task())
|
||||||
|
else:
|
||||||
|
loop.create_task(task)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _cancel_websocket_tasks(cls, app, loop):
|
||||||
|
for task in app.websocket_tasks:
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
async def _websocket_handler(
|
||||||
|
self, handler, request, *args, subprotocols=None, **kwargs
|
||||||
|
):
|
||||||
|
request.app = self
|
||||||
|
if not getattr(handler, "__blueprintname__", False):
|
||||||
|
request.endpoint = handler.__name__
|
||||||
|
else:
|
||||||
|
request.endpoint = (
|
||||||
|
getattr(handler, "__blueprintname__", "") + handler.__name__
|
||||||
|
)
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
if self.asgi:
|
||||||
|
ws = request.transport.get_websocket_connection()
|
||||||
|
else:
|
||||||
|
protocol = request.transport.get_protocol()
|
||||||
|
protocol.app = self
|
||||||
|
|
||||||
|
ws = await protocol.websocket_handshake(request, subprotocols)
|
||||||
|
|
||||||
|
# schedule the application handler
|
||||||
|
# its future is kept in self.websocket_tasks in case it
|
||||||
|
# needs to be cancelled due to the server being stopped
|
||||||
|
fut = ensure_future(handler(request, ws, *args, **kwargs))
|
||||||
|
self.websocket_tasks.add(fut)
|
||||||
|
try:
|
||||||
|
await fut
|
||||||
|
except (CancelledError, ConnectionClosed):
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
self.websocket_tasks.remove(fut)
|
||||||
|
await ws.close()
|
||||||
|
|
||||||
# -------------------------------------------------------------------- #
|
# -------------------------------------------------------------------- #
|
||||||
# ASGI
|
# ASGI
|
||||||
# -------------------------------------------------------------------- #
|
# -------------------------------------------------------------------- #
|
||||||
|
@ -927,7 +927,7 @@ def serve_multiple(server_settings, workers):
|
|||||||
|
|
||||||
signal_func(SIGINT, lambda s, f: sig_handler(s, f))
|
signal_func(SIGINT, lambda s, f: sig_handler(s, f))
|
||||||
signal_func(SIGTERM, lambda s, f: sig_handler(s, f))
|
signal_func(SIGTERM, lambda s, f: sig_handler(s, f))
|
||||||
mp = multiprocessing.get_context("fork")
|
mp = multiprocessing.get_context("spawn")
|
||||||
|
|
||||||
for _ in range(workers):
|
for _ in range(workers):
|
||||||
process = mp.Process(target=serve, kwargs=server_settings)
|
process = mp.Process(target=serve, kwargs=server_settings)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user