New websockets (#2158)

* First attempt at new Websockets implementation based on websockets >= 9.0, with sans-i/o features. Requires more work.
* Update sanic/websocket.py
  Co-authored-by: Adam Hopkins <adam@amhopkins.com>
* Update sanic/websocket.py
  Co-authored-by: Adam Hopkins <adam@amhopkins.com>
* Update sanic/websocket.py
  Co-authored-by: Adam Hopkins <adam@amhopkins.com>
* wip, update websockets code to new Sans/IO API
* Refactored new websockets impl into own modules
  Incorporated other suggestions made by team
* Another round of work on the new websockets impl
* Added websocket_timeout support (matching previous/legacy support)
* Lots more comments
* Incorporated suggested changes from previous round of review
* Changed RuntimeError usage to ServerError
* Changed SanicException usage to ServerError
* Removed some redundant asserts
* Change remaining asserts to ServerErrors
* Fixed some timeout handling issues
* Fixed websocket.close() handling, and made it more robust
* Made auto_close task smarter and more error-resilient
* Made fail_connection routine smarter and more error-resilient
* Further new websockets impl fixes
* Update compatibility with Websockets v10
* Track server connection state in a more precise way
* Try to handle the shutdown process more gracefully
* Add a new end_connection() helper, to use as an alternative to close() or fail_connection()
* Kill the auto-close task and keepalive-timeout task when sanic is shut down
* Deprecate WEBSOCKET_READ_LIMIT and WEBSOCKET_WRITE_LIMIT configs; they are not used in this implementation.
* Change a warning message to debug level
  Remove default values for deprecated websocket parameters
* Fix flake8 errors
* Fix a couple of missed failing tests
* remove websocket bench from examples
* Integrate suggestions from code reviews
  Use Optional[T] instead of Union[T, None]
  Fix mypy type logic errors
  Change "is not None" to truthy checks where appropriate
  Change "is None" to falsy checks where appropriate
  Add more debug logging when debug mode is on
  Change to using sanic.logger for debug logging rather than error_logger.
* Fix long line lengths of debug messages
  Add some new debug messages when websocket IO is paused and unpaused for flow control
  Fix websocket example to use app.static()
* remove unused import in websocket example app
* re-run isort after Flake8 fixes

Co-authored-by: Adam Hopkins <adam@amhopkins.com>
Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
parent 595d2c76ac
commit 6ffc4d9756
@@ -1,13 +1,14 @@
 from sanic import Sanic
-from sanic.response import file
+from sanic.response import redirect

 app = Sanic(__name__)

+app.static('index.html', "websocket.html")
+

 @app.route('/')
-async def index(request):
-    return await file('websocket.html')
+def index(request):
+    return redirect("index.html")


 @app.websocket('/feed')
 async def feed(request, ws):
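The hunk above stops at the feed() handler's signature. For context, a minimal echo body for such a handler, as it is commonly written in Sanic websocket examples, could look like the sketch below; the loop body and the run() call are assumptions for illustration and are not part of this diff.

from sanic import Sanic

app = Sanic(__name__)


@app.websocket('/feed')
async def feed(request, ws):
    # Echo every received message straight back to the client (assumed body).
    while True:
        data = await ws.recv()
        await ws.send(data)


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8000)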
sanic/app.py (23 changed lines)
@@ -74,9 +74,10 @@ from sanic.router import Router
 from sanic.server import AsyncioServer, HttpProtocol
 from sanic.server import Signal as ServerSignal
 from sanic.server import serve, serve_multiple, serve_single
+from sanic.server.protocols.websocket_protocol import WebSocketProtocol
+from sanic.server.websockets.impl import ConnectionClosed
 from sanic.signals import Signal, SignalRouter
 from sanic.touchup import TouchUp, TouchUpMeta
-from sanic.websocket import ConnectionClosed, WebSocketProtocol


 class Sanic(BaseSanic, metaclass=TouchUpMeta):
@@ -871,23 +872,11 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
     async def _websocket_handler(
         self, handler, request, *args, subprotocols=None, **kwargs
     ):
-        request.app = self
-        if not getattr(handler, "__blueprintname__", False):
-            request._name = handler.__name__
-        else:
-            request._name = (
-                getattr(handler, "__blueprintname__", "") + handler.__name__
-            )
-
-        pass
-
         if self.asgi:
             ws = request.transport.get_websocket_connection()
             await ws.accept(subprotocols)
         else:
             protocol = request.transport.get_protocol()
-            protocol.app = self
-
             ws = await protocol.websocket_handshake(request, subprotocols)

         # schedule the application handler
@@ -895,15 +884,19 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
         # needs to be cancelled due to the server being stopped
         fut = ensure_future(handler(request, ws, *args, **kwargs))
         self.websocket_tasks.add(fut)
+        cancelled = False
         try:
             await fut
         except Exception as e:
             self.error_handler.log(request, e)
         except (CancelledError, ConnectionClosed):
-            pass
+            cancelled = True
         finally:
             self.websocket_tasks.remove(fut)
-            await ws.close()
+            if cancelled:
+                ws.end_connection(1000)
+            else:
+                await ws.close()

     # -------------------------------------------------------------------- #
     # Testing
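The second hunk changes how the wrapped handler task is finalized: cancellation (server shutdown or a closed connection) now takes the synchronous end_connection(1000) path, while normal completion still awaits the full close() handshake. A minimal standalone sketch of that pattern, with a hypothetical ws object standing in for the websocket implementation, is:

import asyncio


async def finish_websocket(handler_task, ws):
    # Mirrors the shape of the new _websocket_handler finalization:
    # remember whether the handler was cancelled and pick the close path.
    cancelled = False
    try:
        await handler_task
    except asyncio.CancelledError:
        cancelled = True
    finally:
        if cancelled:
            # Synchronous, safe to call while the loop is shutting down.
            ws.end_connection(1000)
        else:
            # Normal completion: run the full websocket close handshake.
            await ws.close()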
@@ -10,7 +10,7 @@ from sanic.exceptions import ServerError
 from sanic.models.asgi import ASGIReceive, ASGIScope, ASGISend, MockTransport
 from sanic.request import Request
 from sanic.server import ConnInfo
-from sanic.websocket import WebSocketConnection
+from sanic.server.websockets.connection import WebSocketConnection


 class Lifespan:
@@ -35,12 +35,9 @@ DEFAULT_CONFIG = {
     "REQUEST_MAX_SIZE": 100000000,  # 100 megabytes
     "REQUEST_TIMEOUT": 60,  # 60 seconds
     "RESPONSE_TIMEOUT": 60,  # 60 seconds
-    "WEBSOCKET_MAX_QUEUE": 32,
     "WEBSOCKET_MAX_SIZE": 2 ** 20,  # 1 megabyte
     "WEBSOCKET_PING_INTERVAL": 20,
     "WEBSOCKET_PING_TIMEOUT": 20,
-    "WEBSOCKET_READ_LIMIT": 2 ** 16,
-    "WEBSOCKET_WRITE_LIMIT": 2 ** 16,
 }

@@ -62,12 +59,9 @@ class Config(dict):
     REQUEST_MAX_SIZE: int
     REQUEST_TIMEOUT: int
     RESPONSE_TIMEOUT: int
-    WEBSOCKET_MAX_QUEUE: int
     WEBSOCKET_MAX_SIZE: int
     WEBSOCKET_PING_INTERVAL: int
     WEBSOCKET_PING_TIMEOUT: int
-    WEBSOCKET_READ_LIMIT: int
-    WEBSOCKET_WRITE_LIMIT: int

     def __init__(
         self,
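Only the queue/read/write limits are removed; the remaining websocket settings keep working as before. A small usage sketch (the values shown are illustrative, not defaults prescribed by this commit):

from sanic import Sanic

app = Sanic("ConfiguredApp")

# Settings that remain after this commit:
app.config.WEBSOCKET_MAX_SIZE = 2 ** 20      # max incoming message size (bytes)
app.config.WEBSOCKET_PING_INTERVAL = 20      # seconds between keepalive pings
app.config.WEBSOCKET_PING_TIMEOUT = 20       # seconds to wait for the pong

# WEBSOCKET_MAX_QUEUE, WEBSOCKET_READ_LIMIT and WEBSOCKET_WRITE_LIMIT are
# removed here and no longer have any effect in the new implementation.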
@@ -121,8 +121,11 @@ class RouteMixin:
                     "Expected either string or Iterable of host strings, "
                     "not %s" % host
                 )
-        if isinstance(subprotocols, (list, tuple, set)):
+        if isinstance(subprotocols, list):
+            # Ordered subprotocols, maintain order
+            subprotocols = tuple(subprotocols)
+        elif isinstance(subprotocols, set):
+            # subprotocol is unordered, keep it unordered
             subprotocols = frozenset(subprotocols)

         route = FutureRoute(
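In practice the subprotocols argument is passed when registering a websocket route: a list (or tuple) keeps the client preference order, while a set is treated as unordered. A brief sketch (the route and subprotocol names are made up for illustration):

from sanic import Sanic

app = Sanic("SubprotocolExample")


@app.websocket("/feed", subprotocols=["graphql-ws", "chat"])
async def feed(request, ws):
    # On the server-side implementation added by this commit, ws.subprotocol
    # holds whichever subprotocol was negotiated with the client (or None).
    await ws.send(f"negotiated subprotocol: {ws.subprotocol}")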
@@ -3,7 +3,7 @@ import asyncio
 from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union

 from sanic.exceptions import InvalidUsage
-from sanic.websocket import WebSocketConnection
+from sanic.server.websockets.connection import WebSocketConnection


 ASGIScope = MutableMapping[str, Any]
sanic/server/protocols/websocket_protocol.py (new file, 165 lines)
@@ -0,0 +1,165 @@
from typing import TYPE_CHECKING, Optional, Sequence

from websockets.connection import CLOSED, CLOSING, OPEN
from websockets.server import ServerConnection

from sanic.exceptions import ServerError
from sanic.log import error_logger
from sanic.server import HttpProtocol

from ..websockets.impl import WebsocketImplProtocol


if TYPE_CHECKING:
    from websockets import http11


class WebSocketProtocol(HttpProtocol):

    websocket: Optional[WebsocketImplProtocol]
    websocket_timeout: float
    websocket_max_size = Optional[int]
    websocket_ping_interval = Optional[float]
    websocket_ping_timeout = Optional[float]

    def __init__(
        self,
        *args,
        websocket_timeout: float = 10.0,
        websocket_max_size: Optional[int] = None,
        websocket_max_queue: Optional[int] = None,  # max_queue is deprecated
        websocket_read_limit: Optional[int] = None,  # read_limit is deprecated
        websocket_write_limit: Optional[int] = None,  # write_limit deprecated
        websocket_ping_interval: Optional[float] = 20.0,
        websocket_ping_timeout: Optional[float] = 20.0,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.websocket = None
        self.websocket_timeout = websocket_timeout
        self.websocket_max_size = websocket_max_size
        if websocket_max_queue is not None and websocket_max_queue > 0:
            # TODO: Reminder remove this warning in v22.3
            error_logger.warning(
                DeprecationWarning(
                    "Websocket no longer uses queueing, so websocket_max_queue"
                    " is no longer required."
                )
            )
        if websocket_read_limit is not None and websocket_read_limit > 0:
            # TODO: Reminder remove this warning in v22.3
            error_logger.warning(
                DeprecationWarning(
                    "Websocket no longer uses read buffers, so "
                    "websocket_read_limit is not required."
                )
            )
        if websocket_write_limit is not None and websocket_write_limit > 0:
            # TODO: Reminder remove this warning in v22.3
            error_logger.warning(
                DeprecationWarning(
                    "Websocket no longer uses write buffers, so "
                    "websocket_write_limit is not required."
                )
            )
        self.websocket_ping_interval = websocket_ping_interval
        self.websocket_ping_timeout = websocket_ping_timeout

    def connection_lost(self, exc):
        if self.websocket is not None:
            self.websocket.connection_lost(exc)
        super().connection_lost(exc)

    def data_received(self, data):
        if self.websocket is not None:
            self.websocket.data_received(data)
        else:
            # Pass it to HttpProtocol handler first
            # That will (hopefully) upgrade it to a websocket.
            super().data_received(data)

    def eof_received(self) -> Optional[bool]:
        if self.websocket is not None:
            return self.websocket.eof_received()
        else:
            return False

    def close(self, timeout: Optional[float] = None):
        # Called by HttpProtocol at the end of connection_task
        # If we've upgraded to websocket, we do our own closing
        if self.websocket is not None:
            # Note, we don't want to use websocket.close()
            # That is used for user's application code to send a
            # websocket close packet. This is different.
            self.websocket.end_connection(1001)
        else:
            super().close()

    def close_if_idle(self):
        # Called by Sanic Server when shutting down
        # If we've upgraded to websocket, shut it down
        if self.websocket is not None:
            if self.websocket.connection.state in (CLOSING, CLOSED):
                return True
            elif self.websocket.loop is not None:
                self.websocket.loop.create_task(self.websocket.close(1001))
            else:
                self.websocket.end_connection(1001)
        else:
            return super().close_if_idle()

    async def websocket_handshake(
        self, request, subprotocols=Optional[Sequence[str]]
    ):
        # let the websockets package do the handshake with the client
        try:
            if subprotocols is not None:
                # subprotocols can be a set or frozenset,
                # but ServerConnection needs a list
                subprotocols = list(subprotocols)
            ws_conn = ServerConnection(
                max_size=self.websocket_max_size,
                subprotocols=subprotocols,
                state=OPEN,
                logger=error_logger,
            )
            resp: "http11.Response" = ws_conn.accept(request)
        except Exception:
            msg = (
                "Failed to open a WebSocket connection.\n"
                "See server log for more information.\n"
            )
            raise ServerError(msg, status_code=500)
        if 100 <= resp.status_code <= 299:
            rbody = "".join(
                [
                    "HTTP/1.1 ",
                    str(resp.status_code),
                    " ",
                    resp.reason_phrase,
                    "\r\n",
                ]
            )
            rbody += "".join(f"{k}: {v}\r\n" for k, v in resp.headers.items())
            if resp.body is not None:
                rbody += f"\r\n{resp.body}\r\n\r\n"
            else:
                rbody += "\r\n"
            await super().send(rbody.encode())
        else:
            raise ServerError(resp.body, resp.status_code)

        self.websocket = WebsocketImplProtocol(
            ws_conn,
            ping_interval=self.websocket_ping_interval,
            ping_timeout=self.websocket_ping_timeout,
            close_timeout=self.websocket_timeout,
        )
        loop = (
            request.transport.loop
            if hasattr(request, "transport")
            and hasattr(request.transport, "loop")
            else None
        )
        await self.websocket.connection_made(self, loop=loop)
        return self.websocket
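Sanic selects this protocol automatically once a websocket route is registered, but it can also be passed explicitly to run(). The snippet below is an illustration of wiring it up, not part of the diff.

from sanic import Sanic
from sanic.server.protocols.websocket_protocol import WebSocketProtocol

app = Sanic("ProtocolExample")


@app.websocket("/feed")
async def feed(request, ws):
    # Trivial echo handler for demonstration purposes.
    await ws.send(await ws.recv())


if __name__ == "__main__":
    # protocol= is optional; Sanic falls back to WebSocketProtocol whenever
    # websocket routes exist.
    app.run(host="0.0.0.0", port=8000, protocol=WebSocketProtocol)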
@@ -175,15 +175,11 @@ def serve(

     # Force close non-idle connection after waiting for
     # graceful_shutdown_timeout
-    coros = []
     for conn in connections:
         if hasattr(conn, "websocket") and conn.websocket:
-            coros.append(conn.websocket.close_connection())
+            conn.websocket.fail_connection(code=1001)
         else:
             conn.abort()
-
-    _shutdown = asyncio.gather(*coros)
-    loop.run_until_complete(_shutdown)
     loop.run_until_complete(app._server_event("shutdown", "after"))

     remove_unix_socket(unix)
@@ -278,9 +274,6 @@ def _build_protocol_kwargs(
     if hasattr(protocol, "websocket_handshake"):
         return {
             "websocket_max_size": config.WEBSOCKET_MAX_SIZE,
-            "websocket_max_queue": config.WEBSOCKET_MAX_QUEUE,
-            "websocket_read_limit": config.WEBSOCKET_READ_LIMIT,
-            "websocket_write_limit": config.WEBSOCKET_WRITE_LIMIT,
             "websocket_ping_timeout": config.WEBSOCKET_PING_TIMEOUT,
             "websocket_ping_interval": config.WEBSOCKET_PING_INTERVAL,
         }
sanic/server/websockets/__init__.py (new, empty file)

sanic/server/websockets/connection.py (new file, 82 lines)
@@ -0,0 +1,82 @@
from typing import (
    Any,
    Awaitable,
    Callable,
    Dict,
    List,
    MutableMapping,
    Optional,
    Union,
)


ASIMessage = MutableMapping[str, Any]


class WebSocketConnection:
    """
    This is for ASGI Connections.
    It provides an interface similar to WebsocketProtocol, but
    sends/receives over an ASGI connection.
    """

    # TODO
    # - Implement ping/pong

    def __init__(
        self,
        send: Callable[[ASIMessage], Awaitable[None]],
        receive: Callable[[], Awaitable[ASIMessage]],
        subprotocols: Optional[List[str]] = None,
    ) -> None:
        self._send = send
        self._receive = receive
        self._subprotocols = subprotocols or []

    async def send(self, data: Union[str, bytes], *args, **kwargs) -> None:
        message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"}

        if isinstance(data, bytes):
            message.update({"bytes": data})
        else:
            message.update({"text": str(data)})

        await self._send(message)

    async def recv(self, *args, **kwargs) -> Optional[str]:
        message = await self._receive()

        if message["type"] == "websocket.receive":
            return message["text"]
        elif message["type"] == "websocket.disconnect":
            pass

        return None

    receive = recv

    async def accept(self, subprotocols: Optional[List[str]] = None) -> None:
        subprotocol = None
        if subprotocols:
            for subp in subprotocols:
                if subp in self.subprotocols:
                    subprotocol = subp
                    break

        await self._send(
            {
                "type": "websocket.accept",
                "subprotocol": subprotocol,
            }
        )

    async def close(self, code: int = 1000, reason: str = "") -> None:
        pass

    @property
    def subprotocols(self):
        return self._subprotocols

    @subprotocols.setter
    def subprotocols(self, subprotocols: Optional[List[str]] = None):
        self._subprotocols = subprotocols or []
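Under ASGI, this class simply translates send()/recv() calls into ASGI websocket messages. A self-contained sketch that drives it with stub ASGI callables follows; the stub functions and the message payloads are invented for illustration.

import asyncio

from sanic.server.websockets.connection import WebSocketConnection


async def demo() -> None:
    outbox = []

    async def asgi_send(message):
        # Stand-in for the ASGI server's send callable.
        outbox.append(message)

    async def asgi_receive():
        # Stand-in for the ASGI server's receive callable.
        return {"type": "websocket.receive", "text": "hello"}

    ws = WebSocketConnection(asgi_send, asgi_receive, subprotocols=["chat"])
    await ws.accept(subprotocols=["chat"])  # queues a websocket.accept message
    text = await ws.recv()                  # -> "hello"
    await ws.send(text)                     # queues a websocket.send message
    print(text, outbox)


asyncio.run(demo())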
sanic/server/websockets/frame.py (new file, 297 lines)
@@ -0,0 +1,297 @@
import asyncio
import codecs

from typing import TYPE_CHECKING, AsyncIterator, List, Optional

from websockets.frames import Frame, Opcode
from websockets.typing import Data

from sanic.exceptions import ServerError


if TYPE_CHECKING:
    from .impl import WebsocketImplProtocol

UTF8Decoder = codecs.getincrementaldecoder("utf-8")


class WebsocketFrameAssembler:
    """
    Assemble a message from frames.
    Code borrowed from aaugustin/websockets project:
    https://github.com/aaugustin/websockets/blob/6eb98dd8fa5b2c896b9f6be7e8d117708da82a39/src/websockets/sync/messages.py
    """

    __slots__ = (
        "protocol",
        "read_mutex",
        "write_mutex",
        "message_complete",
        "message_fetched",
        "get_in_progress",
        "decoder",
        "completed_queue",
        "chunks",
        "chunks_queue",
        "paused",
        "get_id",
        "put_id",
    )
    if TYPE_CHECKING:
        protocol: "WebsocketImplProtocol"
        read_mutex: asyncio.Lock
        write_mutex: asyncio.Lock
        message_complete: asyncio.Event
        message_fetched: asyncio.Event
        completed_queue: asyncio.Queue
        get_in_progress: bool
        decoder: Optional[codecs.IncrementalDecoder]
        # For streaming chunks rather than messages:
        chunks: List[Data]
        chunks_queue: Optional[asyncio.Queue[Optional[Data]]]
        paused: bool

    def __init__(self, protocol) -> None:

        self.protocol = protocol

        self.read_mutex = asyncio.Lock()
        self.write_mutex = asyncio.Lock()

        self.completed_queue = asyncio.Queue(
            maxsize=1
        )  # type: asyncio.Queue[Data]

        # put() sets this event to tell get() that a message can be fetched.
        self.message_complete = asyncio.Event()
        # get() sets this event to let put()
        self.message_fetched = asyncio.Event()

        # This flag prevents concurrent calls to get() by user code.
        self.get_in_progress = False

        # Decoder for text frames, None for binary frames.
        self.decoder = None

        # Buffer data from frames belonging to the same message.
        self.chunks = []

        # When switching from "buffering" to "streaming", we use a thread-safe
        # queue for transferring frames from the writing thread (library code)
        # to the reading thread (user code). We're buffering when chunks_queue
        # is None and streaming when it's a Queue. None is a sentinel
        # value marking the end of the stream, superseding message_complete.

        # Stream data from frames belonging to the same message.
        self.chunks_queue = None

        # Flag to indicate we've paused the protocol
        self.paused = False

    async def get(self, timeout: Optional[float] = None) -> Optional[Data]:
        """
        Read the next message.
        :meth:`get` returns a single :class:`str` or :class:`bytes`.
        If the :message was fragmented, :meth:`get` waits until the last frame
        is received, then it reassembles the message.
        If ``timeout`` is set and elapses before a complete message is
        received, :meth:`get` returns ``None``.
        """
        async with self.read_mutex:
            if timeout is not None and timeout <= 0:
                if not self.message_complete.is_set():
                    return None
            if self.get_in_progress:
                # This should be guarded against with the read_mutex,
                # exception is only here as a failsafe
                raise ServerError(
                    "Called get() on Websocket frame assembler "
                    "while asynchronous get is already in progress."
                )
            self.get_in_progress = True

            # If the message_complete event isn't set yet, release the lock to
            # allow put() to run and eventually set it.
            # Locking with get_in_progress ensures only one task can get here.
            if timeout is None:
                completed = await self.message_complete.wait()
            elif timeout <= 0:
                completed = self.message_complete.is_set()
            else:
                try:
                    await asyncio.wait_for(
                        self.message_complete.wait(), timeout=timeout
                    )
                except asyncio.TimeoutError:
                    ...
                finally:
                    completed = self.message_complete.is_set()

            # Unpause the transport, if its paused
            if self.paused:
                self.protocol.resume_frames()
                self.paused = False
            if not self.get_in_progress:
                # This should be guarded against with the read_mutex,
                # exception is here as a failsafe
                raise ServerError(
                    "State of Websocket frame assembler was modified while an "
                    "asynchronous get was in progress."
                )
            self.get_in_progress = False

            # Waiting for a complete message timed out.
            if not completed:
                return None
            if not self.message_complete.is_set():
                return None

            self.message_complete.clear()

            joiner: Data = b"" if self.decoder is None else ""
            # mypy cannot figure out that chunks have the proper type.
            message: Data = joiner.join(self.chunks)  # type: ignore
            if self.message_fetched.is_set():
                # This should be guarded against with the read_mutex,
                # and get_in_progress check, this exception is here
                # as a failsafe
                raise ServerError(
                    "Websocket get() found a message when "
                    "state was already fetched."
                )
            self.message_fetched.set()
            self.chunks = []
            self.chunks_queue = (
                None  # this should already be None, but set it here for safety
            )

            return message

    async def get_iter(self) -> AsyncIterator[Data]:
        """
        Stream the next message.
        Iterating the return value of :meth:`get_iter` yields a :class:`str`
        or :class:`bytes` for each frame in the message.
        """
        async with self.read_mutex:
            if self.get_in_progress:
                # This should be guarded against with the read_mutex,
                # exception is only here as a failsafe
                raise ServerError(
                    "Called get_iter on Websocket frame assembler "
                    "while asynchronous get is already in progress."
                )
            self.get_in_progress = True

            chunks = self.chunks
            self.chunks = []
            self.chunks_queue = asyncio.Queue()

            # Sending None in chunk_queue supersedes setting message_complete
            # when switching to "streaming". If message is already complete
            # when the switch happens, put() didn't send None, so we have to.
            if self.message_complete.is_set():
                await self.chunks_queue.put(None)

            # Locking with get_in_progress ensures only one thread can get here
            for c in chunks:
                yield c
            while True:
                chunk = await self.chunks_queue.get()
                if chunk is None:
                    break
                yield chunk

            # Unpause the transport, if its paused
            if self.paused:
                self.protocol.resume_frames()
                self.paused = False
            if not self.get_in_progress:
                # This should be guarded against with the read_mutex,
                # exception is here as a failsafe
                raise ServerError(
                    "State of Websocket frame assembler was modified while an "
                    "asynchronous get was in progress."
                )
            self.get_in_progress = False
            if not self.message_complete.is_set():
                # This should be guarded against with the read_mutex,
                # exception is here as a failsafe
                raise ServerError(
                    "Websocket frame assembler chunks queue ended before "
                    "message was complete."
                )
            self.message_complete.clear()
            if self.message_fetched.is_set():
                # This should be guarded against with the read_mutex,
                # and get_in_progress check, this exception is
                # here as a failsafe
                raise ServerError(
                    "Websocket get_iter() found a message when state was "
                    "already fetched."
                )

            self.message_fetched.set()
            self.chunks = (
                []
            )  # this should already be empty, but set it here for safety
            self.chunks_queue = None

    async def put(self, frame: Frame) -> None:
        """
        Add ``frame`` to the next message.
        When ``frame`` is the final frame in a message, :meth:`put` waits
        until the message is fetched, either by calling :meth:`get` or by
        iterating the return value of :meth:`get_iter`.
        :meth:`put` assumes that the stream of frames respects the protocol.
        If it doesn't, the behavior is undefined.
        """

        async with self.write_mutex:
            if frame.opcode is Opcode.TEXT:
                self.decoder = UTF8Decoder(errors="strict")
            elif frame.opcode is Opcode.BINARY:
                self.decoder = None
            elif frame.opcode is Opcode.CONT:
                pass
            else:
                # Ignore control frames.
                return
            data: Data
            if self.decoder is not None:
                data = self.decoder.decode(frame.data, frame.fin)
            else:
                data = frame.data
            if self.chunks_queue is None:
                self.chunks.append(data)
            else:
                await self.chunks_queue.put(data)

            if not frame.fin:
                return
            if not self.get_in_progress:
                # nobody is waiting for this frame, so try to pause subsequent
                # frames at the protocol level
                self.paused = self.protocol.pause_frames()
            # Message is complete. Wait until it's fetched to return.

            if self.chunks_queue is not None:
                await self.chunks_queue.put(None)
            if self.message_complete.is_set():
                # This should be guarded against with the write_mutex
                raise ServerError(
                    "Websocket put() got a new message when a message was "
                    "already in its chamber."
                )
            self.message_complete.set()  # Signal to get() it can serve the
            if self.message_fetched.is_set():
                # This should be guarded against with the write_mutex
                raise ServerError(
                    "Websocket put() got a new message when the previous "
                    "message was not yet fetched."
                )

            # Allow get() to run and eventually set the event.
            await self.message_fetched.wait()
            self.message_fetched.clear()
            self.decoder = None
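put() and get() are meant to run in different tasks: put() parks a completed message until a reader fetches it. The sketch below exercises that hand-off with a stub protocol object; the stub class and the Frame(...) constructor call (websockets >= 10 dataclass form) are assumptions for illustration, not part of this diff.

import asyncio

from websockets.frames import Frame, Opcode

from sanic.server.websockets.frame import WebsocketFrameAssembler


class StubProtocol:
    # Minimal stand-in for WebsocketImplProtocol: the assembler only calls
    # pause_frames()/resume_frames() on it for flow control.
    def pause_frames(self):
        return False

    def resume_frames(self):
        return True


async def demo() -> None:
    assembler = WebsocketFrameAssembler(StubProtocol())

    async def writer():
        # fin=True marks the final frame; put() waits until get() fetches it.
        await assembler.put(Frame(Opcode.TEXT, b"hello", fin=True))

    writer_task = asyncio.create_task(writer())
    print(await assembler.get())  # -> "hello"
    await writer_task


asyncio.run(demo())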
sanic/server/websockets/impl.py (new file, 789 lines)
@ -0,0 +1,789 @@
|
||||||
|
import asyncio
|
||||||
|
import random
|
||||||
|
import struct
|
||||||
|
|
||||||
|
from typing import (
|
||||||
|
AsyncIterator,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from websockets.connection import CLOSED, CLOSING, OPEN, Event
|
||||||
|
from websockets.exceptions import ConnectionClosed, ConnectionClosedError
|
||||||
|
from websockets.frames import OP_PONG, Frame
|
||||||
|
from websockets.server import ServerConnection
|
||||||
|
from websockets.typing import Data
|
||||||
|
|
||||||
|
from sanic.log import error_logger, logger
|
||||||
|
from sanic.server.protocols.base_protocol import SanicProtocol
|
||||||
|
|
||||||
|
from ...exceptions import ServerError
|
||||||
|
from .frame import WebsocketFrameAssembler
|
||||||
|
|
||||||
|
|
||||||
|
class WebsocketImplProtocol:
|
||||||
|
connection: ServerConnection
|
||||||
|
io_proto: Optional[SanicProtocol]
|
||||||
|
loop: Optional[asyncio.AbstractEventLoop]
|
||||||
|
max_queue: int
|
||||||
|
close_timeout: float
|
||||||
|
ping_interval: Optional[float]
|
||||||
|
ping_timeout: Optional[float]
|
||||||
|
assembler: WebsocketFrameAssembler
|
||||||
|
# Dict[bytes, asyncio.Future[None]]
|
||||||
|
pings: Dict[bytes, asyncio.Future]
|
||||||
|
conn_mutex: asyncio.Lock
|
||||||
|
recv_lock: asyncio.Lock
|
||||||
|
process_event_mutex: asyncio.Lock
|
||||||
|
can_pause: bool
|
||||||
|
# Optional[asyncio.Future[None]]
|
||||||
|
data_finished_fut: Optional[asyncio.Future]
|
||||||
|
# Optional[asyncio.Future[None]]
|
||||||
|
pause_frame_fut: Optional[asyncio.Future]
|
||||||
|
# Optional[asyncio.Future[None]]
|
||||||
|
connection_lost_waiter: Optional[asyncio.Future]
|
||||||
|
keepalive_ping_task: Optional[asyncio.Task]
|
||||||
|
auto_closer_task: Optional[asyncio.Task]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
connection,
|
||||||
|
max_queue=None,
|
||||||
|
ping_interval: Optional[float] = 20,
|
||||||
|
ping_timeout: Optional[float] = 20,
|
||||||
|
close_timeout: float = 10,
|
||||||
|
loop=None,
|
||||||
|
):
|
||||||
|
self.connection = connection
|
||||||
|
self.io_proto = None
|
||||||
|
self.loop = None
|
||||||
|
self.max_queue = max_queue
|
||||||
|
self.close_timeout = close_timeout
|
||||||
|
self.ping_interval = ping_interval
|
||||||
|
self.ping_timeout = ping_timeout
|
||||||
|
self.assembler = WebsocketFrameAssembler(self)
|
||||||
|
self.pings = {}
|
||||||
|
self.conn_mutex = asyncio.Lock()
|
||||||
|
self.recv_lock = asyncio.Lock()
|
||||||
|
self.process_event_mutex = asyncio.Lock()
|
||||||
|
self.data_finished_fut = None
|
||||||
|
self.can_pause = True
|
||||||
|
self.pause_frame_fut = None
|
||||||
|
self.keepalive_ping_task = None
|
||||||
|
self.auto_closer_task = None
|
||||||
|
self.connection_lost_waiter = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def subprotocol(self):
|
||||||
|
return self.connection.subprotocol
|
||||||
|
|
||||||
|
def pause_frames(self):
|
||||||
|
if not self.can_pause:
|
||||||
|
return False
|
||||||
|
if self.pause_frame_fut:
|
||||||
|
logger.debug("Websocket connection already paused.")
|
||||||
|
return False
|
||||||
|
if (not self.loop) or (not self.io_proto):
|
||||||
|
return False
|
||||||
|
if self.io_proto.transport:
|
||||||
|
self.io_proto.transport.pause_reading()
|
||||||
|
self.pause_frame_fut = self.loop.create_future()
|
||||||
|
logger.debug("Websocket connection paused.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def resume_frames(self):
|
||||||
|
if not self.pause_frame_fut:
|
||||||
|
logger.debug("Websocket connection not paused.")
|
||||||
|
return False
|
||||||
|
if (not self.loop) or (not self.io_proto):
|
||||||
|
logger.debug(
|
||||||
|
"Websocket attempting to resume reading frames, "
|
||||||
|
"but connection is gone."
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
if self.io_proto.transport:
|
||||||
|
self.io_proto.transport.resume_reading()
|
||||||
|
self.pause_frame_fut.set_result(None)
|
||||||
|
self.pause_frame_fut = None
|
||||||
|
logger.debug("Websocket connection unpaused.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def connection_made(
|
||||||
|
self,
|
||||||
|
io_proto: SanicProtocol,
|
||||||
|
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||||
|
):
|
||||||
|
if not loop:
|
||||||
|
try:
|
||||||
|
loop = getattr(io_proto, "loop")
|
||||||
|
except AttributeError:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
if not loop:
|
||||||
|
# This catch is for mypy type checker
|
||||||
|
# to assert loop is not None here.
|
||||||
|
raise ServerError("Connection received with no asyncio loop.")
|
||||||
|
if self.auto_closer_task:
|
||||||
|
raise ServerError(
|
||||||
|
"Cannot call connection_made more than once "
|
||||||
|
"on a websocket connection."
|
||||||
|
)
|
||||||
|
self.loop = loop
|
||||||
|
self.io_proto = io_proto
|
||||||
|
self.connection_lost_waiter = self.loop.create_future()
|
||||||
|
self.data_finished_fut = asyncio.shield(self.loop.create_future())
|
||||||
|
|
||||||
|
if self.ping_interval:
|
||||||
|
self.keepalive_ping_task = asyncio.create_task(
|
||||||
|
self.keepalive_ping()
|
||||||
|
)
|
||||||
|
self.auto_closer_task = asyncio.create_task(
|
||||||
|
self.auto_close_connection()
|
||||||
|
)
|
||||||
|
|
||||||
|
async def wait_for_connection_lost(self, timeout=None) -> bool:
|
||||||
|
"""
|
||||||
|
Wait until the TCP connection is closed or ``timeout`` elapses.
|
||||||
|
If timeout is None, wait forever.
|
||||||
|
Recommend you should pass in self.close_timeout as timeout
|
||||||
|
|
||||||
|
Return ``True`` if the connection is closed and ``False`` otherwise.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not self.connection_lost_waiter:
|
||||||
|
return False
|
||||||
|
if self.connection_lost_waiter.done():
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
asyncio.shield(self.connection_lost_waiter), timeout
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# Re-check self.connection_lost_waiter.done() synchronously
|
||||||
|
# because connection_lost() could run between the moment the
|
||||||
|
# timeout occurs and the moment this coroutine resumes running
|
||||||
|
return self.connection_lost_waiter.done()
|
||||||
|
|
||||||
|
async def process_events(self, events: Sequence[Event]) -> None:
|
||||||
|
"""
|
||||||
|
Process a list of incoming events.
|
||||||
|
"""
|
||||||
|
# Wrapped in a mutex lock, to prevent other incoming events
|
||||||
|
# from processing at the same time
|
||||||
|
async with self.process_event_mutex:
|
||||||
|
for event in events:
|
||||||
|
if not isinstance(event, Frame):
|
||||||
|
# Event is not a frame. Ignore it.
|
||||||
|
continue
|
||||||
|
if event.opcode == OP_PONG:
|
||||||
|
await self.process_pong(event)
|
||||||
|
else:
|
||||||
|
await self.assembler.put(event)
|
||||||
|
|
||||||
|
async def process_pong(self, frame: "Frame") -> None:
|
||||||
|
if frame.data in self.pings:
|
||||||
|
# Acknowledge all pings up to the one matching this pong.
|
||||||
|
ping_ids = []
|
||||||
|
for ping_id, ping in self.pings.items():
|
||||||
|
ping_ids.append(ping_id)
|
||||||
|
if not ping.done():
|
||||||
|
ping.set_result(None)
|
||||||
|
if ping_id == frame.data:
|
||||||
|
break
|
||||||
|
else: # noqa
|
||||||
|
raise ServerError("ping_id is not in self.pings")
|
||||||
|
# Remove acknowledged pings from self.pings.
|
||||||
|
for ping_id in ping_ids:
|
||||||
|
del self.pings[ping_id]
|
||||||
|
|
||||||
|
async def keepalive_ping(self) -> None:
|
||||||
|
"""
|
||||||
|
Send a Ping frame and wait for a Pong frame at regular intervals.
|
||||||
|
This coroutine exits when the connection terminates and one of the
|
||||||
|
following happens:
|
||||||
|
- :meth:`ping` raises :exc:`ConnectionClosed`, or
|
||||||
|
- :meth:`auto_close_connection` cancels :attr:`keepalive_ping_task`.
|
||||||
|
"""
|
||||||
|
if self.ping_interval is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(self.ping_interval)
|
||||||
|
|
||||||
|
# ping() raises CancelledError if the connection is closed,
|
||||||
|
# when auto_close_connection() cancels keepalive_ping_task.
|
||||||
|
|
||||||
|
# ping() raises ConnectionClosed if the connection is lost,
|
||||||
|
# when connection_lost() calls abort_pings().
|
||||||
|
|
||||||
|
ping_waiter = await self.ping()
|
||||||
|
|
||||||
|
if self.ping_timeout is not None:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(ping_waiter, self.ping_timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
error_logger.warning(
|
||||||
|
"Websocket timed out waiting for pong"
|
||||||
|
)
|
||||||
|
self.fail_connection(1011)
|
||||||
|
break
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# It is expected for this task to be cancelled during during
|
||||||
|
# normal operation, when the connection is closed.
|
||||||
|
logger.debug("Websocket keepalive ping task was cancelled.")
|
||||||
|
except ConnectionClosed:
|
||||||
|
logger.debug("Websocket closed. Keepalive ping task exiting.")
|
||||||
|
except Exception as e:
|
||||||
|
error_logger.warning(
|
||||||
|
"Unexpected exception in websocket keepalive ping task."
|
||||||
|
)
|
||||||
|
logger.debug(str(e))
|
||||||
|
|
||||||
|
def _force_disconnect(self) -> bool:
|
||||||
|
"""
|
||||||
|
Internal methdod used by end_connection and fail_connection
|
||||||
|
only when the graceful auto-closer cannot be used
|
||||||
|
"""
|
||||||
|
if self.auto_closer_task and not self.auto_closer_task.done():
|
||||||
|
self.auto_closer_task.cancel()
|
||||||
|
if self.data_finished_fut and not self.data_finished_fut.done():
|
||||||
|
self.data_finished_fut.cancel()
|
||||||
|
self.data_finished_fut = None
|
||||||
|
if self.keepalive_ping_task and not self.keepalive_ping_task.done():
|
||||||
|
self.keepalive_ping_task.cancel()
|
||||||
|
self.keepalive_ping_task = None
|
||||||
|
if self.loop and self.io_proto and self.io_proto.transport:
|
||||||
|
self.io_proto.transport.close()
|
||||||
|
self.loop.call_later(
|
||||||
|
self.close_timeout, self.io_proto.transport.abort
|
||||||
|
)
|
||||||
|
# We were never open, or already closed
|
||||||
|
return True
|
||||||
|
|
||||||
|
def fail_connection(self, code: int = 1006, reason: str = "") -> bool:
|
||||||
|
"""
|
||||||
|
Fail the WebSocket Connection
|
||||||
|
This requires:
|
||||||
|
1. Stopping all processing of incoming data, which means cancelling
|
||||||
|
pausing the underlying io protocol. The close code will be 1006
|
||||||
|
unless a close frame was received earlier.
|
||||||
|
2. Sending a close frame with an appropriate code if the opening
|
||||||
|
handshake succeeded and the other side is likely to process it.
|
||||||
|
3. Closing the connection. :meth:`auto_close_connection` takes care
|
||||||
|
of this.
|
||||||
|
(The specification describes these steps in the opposite order.)
|
||||||
|
"""
|
||||||
|
if self.io_proto and self.io_proto.transport:
|
||||||
|
# Stop new data coming in
|
||||||
|
# In Python Version 3.7: pause_reading is idempotent
|
||||||
|
# ut can be called when the transport is already paused or closed
|
||||||
|
self.io_proto.transport.pause_reading()
|
||||||
|
|
||||||
|
# Keeping fail_connection() synchronous guarantees it can't
|
||||||
|
# get stuck and simplifies the implementation of the callers.
|
||||||
|
# Not draining the write buffer is acceptable in this context.
|
||||||
|
|
||||||
|
# clear the send buffer
|
||||||
|
_ = self.connection.data_to_send()
|
||||||
|
# If we're not already CLOSED or CLOSING, then send the close.
|
||||||
|
if self.connection.state is OPEN:
|
||||||
|
if code in (1000, 1001):
|
||||||
|
self.connection.send_close(code, reason)
|
||||||
|
else:
|
||||||
|
self.connection.fail(code, reason)
|
||||||
|
try:
|
||||||
|
data_to_send = self.connection.data_to_send()
|
||||||
|
while (
|
||||||
|
len(data_to_send)
|
||||||
|
and self.io_proto
|
||||||
|
and self.io_proto.transport
|
||||||
|
):
|
||||||
|
frame_data = data_to_send.pop(0)
|
||||||
|
self.io_proto.transport.write(frame_data)
|
||||||
|
except Exception:
|
||||||
|
# sending close frames may fail if the
|
||||||
|
# transport closes during this period
|
||||||
|
...
|
||||||
|
if code == 1006:
|
||||||
|
# Special case: 1006 consider the transport already closed
|
||||||
|
self.connection.state = CLOSED
|
||||||
|
if self.data_finished_fut and not self.data_finished_fut.done():
|
||||||
|
# We have a graceful auto-closer. Use it to close the connection.
|
||||||
|
self.data_finished_fut.cancel()
|
||||||
|
self.data_finished_fut = None
|
||||||
|
if (not self.auto_closer_task) or self.auto_closer_task.done():
|
||||||
|
return self._force_disconnect()
|
||||||
|
return False
|
||||||
|
|
||||||
|
def end_connection(self, code=1000, reason=""):
|
||||||
|
# This is like slightly more graceful form of fail_connection
|
||||||
|
# Use this instead of close() when you need an immediate
|
||||||
|
# close and cannot await websocket.close() handshake.
|
||||||
|
|
||||||
|
if code == 1006 or not self.io_proto or not self.io_proto.transport:
|
||||||
|
return self.fail_connection(code, reason)
|
||||||
|
|
||||||
|
# Stop new data coming in
|
||||||
|
# In Python Version 3.7: pause_reading is idempotent
|
||||||
|
# i.e. it can be called when the transport is already paused or closed.
|
||||||
|
self.io_proto.transport.pause_reading()
|
||||||
|
if self.connection.state == OPEN:
|
||||||
|
data_to_send = self.connection.data_to_send()
|
||||||
|
self.connection.send_close(code, reason)
|
||||||
|
data_to_send.extend(self.connection.data_to_send())
|
||||||
|
try:
|
||||||
|
while (
|
||||||
|
len(data_to_send)
|
||||||
|
and self.io_proto
|
||||||
|
and self.io_proto.transport
|
||||||
|
):
|
||||||
|
frame_data = data_to_send.pop(0)
|
||||||
|
self.io_proto.transport.write(frame_data)
|
||||||
|
except Exception:
|
||||||
|
# sending close frames may fail if the
|
||||||
|
# transport closes during this period
|
||||||
|
# But that doesn't matter at this point
|
||||||
|
...
|
||||||
|
if self.data_finished_fut and not self.data_finished_fut.done():
|
||||||
|
# We have the ability to signal the auto-closer
|
||||||
|
# try to trigger it to auto-close the connection
|
||||||
|
self.data_finished_fut.cancel()
|
||||||
|
self.data_finished_fut = None
|
||||||
|
if (not self.auto_closer_task) or self.auto_closer_task.done():
|
||||||
|
# Auto-closer is not running, do force disconnect
|
||||||
|
return self._force_disconnect()
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def auto_close_connection(self) -> None:
|
||||||
|
"""
|
||||||
|
Close the WebSocket Connection
|
||||||
|
When the opening handshake succeeds, :meth:`connection_open` starts
|
||||||
|
this coroutine in a task. It waits for the data transfer phase to
|
||||||
|
complete then it closes the TCP connection cleanly.
|
||||||
|
When the opening handshake fails, :meth:`fail_connection` does the
|
||||||
|
same. There's no data transfer phase in that case.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Wait for the data transfer phase to complete.
|
||||||
|
if self.data_finished_fut:
|
||||||
|
try:
|
||||||
|
await self.data_finished_fut
|
||||||
|
logger.debug(
|
||||||
|
"Websocket task finished. Closing the connection."
|
||||||
|
)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Cancelled error is called when data phase is cancelled
|
||||||
|
# if an error occurred or the client closed the connection
|
||||||
|
logger.debug(
|
||||||
|
"Websocket handler cancelled. Closing the connection."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cancel the keepalive ping task.
|
||||||
|
if self.keepalive_ping_task:
|
||||||
|
self.keepalive_ping_task.cancel()
|
||||||
|
self.keepalive_ping_task = None
|
||||||
|
|
||||||
|
# Half-close the TCP connection if possible (when there's no TLS).
|
||||||
|
if (
|
||||||
|
self.io_proto
|
||||||
|
and self.io_proto.transport
|
||||||
|
and self.io_proto.transport.can_write_eof()
|
||||||
|
):
|
||||||
|
logger.debug("Websocket half-closing TCP connection")
|
||||||
|
self.io_proto.transport.write_eof()
|
||||||
|
if self.connection_lost_waiter:
|
||||||
|
if await self.wait_for_connection_lost(timeout=0):
|
||||||
|
return
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
...
|
||||||
|
finally:
|
||||||
|
# The try/finally ensures that the transport never remains open,
|
||||||
|
# even if this coroutine is cancelled (for example).
|
||||||
|
if (not self.io_proto) or (not self.io_proto.transport):
|
||||||
|
# we were never open, or done. Can't do any finalization.
|
||||||
|
return
|
||||||
|
elif (
|
||||||
|
self.connection_lost_waiter
|
||||||
|
and self.connection_lost_waiter.done()
|
||||||
|
):
|
||||||
|
# connection confirmed closed already, proceed to abort waiter
|
||||||
|
...
|
||||||
|
elif self.io_proto.transport.is_closing():
|
||||||
|
# Connection is already closing (due to half-close above)
|
||||||
|
# proceed to abort waiter
|
||||||
|
...
|
||||||
|
else:
|
||||||
|
self.io_proto.transport.close()
|
||||||
|
if not self.connection_lost_waiter:
|
||||||
|
# Our connection monitor task isn't running.
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(self.close_timeout)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
...
|
||||||
|
if self.io_proto and self.io_proto.transport:
|
||||||
|
self.io_proto.transport.abort()
|
||||||
|
else:
|
||||||
|
if await self.wait_for_connection_lost(
|
||||||
|
timeout=self.close_timeout
|
||||||
|
):
|
||||||
|
# Connection aborted before the timeout expired.
|
||||||
|
return
|
||||||
|
error_logger.warning(
|
||||||
|
"Timeout waiting for TCP connection to close. Aborting"
|
||||||
|
)
|
||||||
|
if self.io_proto and self.io_proto.transport:
|
||||||
|
self.io_proto.transport.abort()
|
||||||
|
|
||||||
|
def abort_pings(self) -> None:
|
||||||
|
"""
|
||||||
|
Raise ConnectionClosed in pending keepalive pings.
|
||||||
|
They'll never receive a pong once the connection is closed.
|
||||||
|
"""
|
||||||
|
if self.connection.state is not CLOSED:
|
||||||
|
raise ServerError(
|
||||||
|
"Webscoket about_pings should only be called "
|
||||||
|
"after connection state is changed to CLOSED"
|
||||||
|
)
|
||||||
|
|
||||||
|
for ping in self.pings.values():
|
||||||
|
ping.set_exception(ConnectionClosedError(None, None))
|
||||||
|
# If the exception is never retrieved, it will be logged when ping
|
||||||
|
# is garbage-collected. This is confusing for users.
|
||||||
|
# Given that ping is done (with an exception), canceling it does
|
||||||
|
# nothing, but it prevents logging the exception.
|
||||||
|
ping.cancel()
|
||||||
|
|
||||||
|
async def close(self, code: int = 1000, reason: str = "") -> None:
|
||||||
|
"""
|
||||||
|
Perform the closing handshake.
|
||||||
|
This is a websocket-protocol level close.
|
||||||
|
:meth:`close` waits for the other end to complete the handshake and
|
||||||
|
for the TCP connection to terminate.
|
||||||
|
:meth:`close` is idempotent: it doesn't do anything once the
|
||||||
|
connection is closed.
|
||||||
|
:param code: WebSocket close code
|
||||||
|
:param reason: WebSocket close reason
|
||||||
|
"""
|
||||||
|
if code == 1006:
|
||||||
|
self.fail_connection(code, reason)
|
||||||
|
return
|
||||||
|
async with self.conn_mutex:
|
||||||
|
if self.connection.state is OPEN:
|
||||||
|
self.connection.send_close(code, reason)
|
||||||
|
data_to_send = self.connection.data_to_send()
|
||||||
|
await self.send_data(data_to_send)
|
||||||
|
|
||||||
|
async def recv(self, timeout: Optional[float] = None) -> Optional[Data]:
|
||||||
|
"""
|
||||||
|
Receive the next message.
|
||||||
|
Return a :class:`str` for a text frame and :class:`bytes` for a binary
|
||||||
|
frame.
|
||||||
|
When the end of the message stream is reached, :meth:`recv` raises
|
||||||
|
:exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it
|
||||||
|
raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal
|
||||||
|
connection closure and
|
||||||
|
:exc:`~websockets.exceptions.ConnectionClosedError` after a protocol
|
||||||
|
error or a network failure.
|
||||||
|
If ``timeout`` is ``None``, block until a message is received. Else,
|
||||||
|
if no message is received within ``timeout`` seconds, return ``None``.
|
||||||
|
Set ``timeout`` to ``0`` to check if a message was already received.
|
||||||
|
:raises ~websockets.exceptions.ConnectionClosed: when the
|
||||||
|
connection is closed
|
||||||
|
:raises ServerError: if two tasks call :meth:`recv` or
|
||||||
|
:meth:`recv_streaming` concurrently
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.recv_lock.locked():
|
||||||
|
raise ServerError(
|
||||||
|
"cannot call recv while another task is "
|
||||||
|
"already waiting for the next message"
|
||||||
|
)
|
||||||
|
await self.recv_lock.acquire()
|
||||||
|
if self.connection.state in (CLOSED, CLOSING):
|
||||||
|
raise ServerError(
|
||||||
|
"Cannot receive from websocket interface after it is closed."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
return await self.assembler.get(timeout)
|
||||||
|
finally:
|
||||||
|
self.recv_lock.release()
|
||||||
|
|
||||||
|
async def recv_burst(self, max_recv=256) -> Sequence[Data]:
|
||||||
|
"""
|
||||||
|
Receive the messages which have arrived since last checking.
|
||||||
|
Return a :class:`list` containing :class:`str` for a text frame
|
||||||
|
and :class:`bytes` for a binary frame.
|
||||||
|
When the end of the message stream is reached, :meth:`recv_burst`
|
||||||
|
raises :exc:`~websockets.exceptions.ConnectionClosed`. Specifically,
|
||||||
|
it raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a
|
||||||
|
normal connection closure and
|
||||||
|
:exc:`~websockets.exceptions.ConnectionClosedError` after a protocol
|
||||||
|
error or a network failure.
|
||||||
|
:raises ~websockets.exceptions.ConnectionClosed: when the
|
||||||
|
connection is closed
|
||||||
|
:raises ServerError: if two tasks call :meth:`recv_burst` or
|
||||||
|
:meth:`recv_streaming` concurrently
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.recv_lock.locked():
|
||||||
|
raise ServerError(
|
||||||
|
"cannot call recv_burst while another task is already waiting "
|
||||||
|
"for the next message"
|
||||||
|
)
|
||||||
|
await self.recv_lock.acquire()
|
||||||
|
if self.connection.state in (CLOSED, CLOSING):
|
||||||
|
raise ServerError(
|
||||||
|
"Cannot receive from websocket interface after it is closed."
|
||||||
|
)
|
||||||
|
messages = []
|
||||||
|
try:
|
||||||
|
# Prevent pausing the transport when we're
|
||||||
|
# receiving a burst of messages
|
||||||
|
self.can_pause = False
|
||||||
|
while True:
|
||||||
|
m = await self.assembler.get(timeout=0)
|
||||||
|
if m is None:
|
||||||
|
# None left in the burst. This is good!
|
||||||
|
break
|
||||||
|
messages.append(m)
|
||||||
|
if len(messages) >= max_recv:
|
||||||
|
# Too much data in the pipe. Hit our burst limit.
|
||||||
|
break
|
||||||
|
# Allow an eventloop iteration for the
|
||||||
|
# next message to pass into the Assembler
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
finally:
|
||||||
|
self.can_pause = True
|
||||||
|
self.recv_lock.release()
|
||||||
|
return messages
|
||||||
|
|
||||||
|
    async def recv_streaming(self) -> AsyncIterator[Data]:
        """
        Receive the next message frame by frame.

        Return an iterator of :class:`str` for a text frame and :class:`bytes`
        for a binary frame. The iterator should be exhausted, or else the
        connection will become unusable.

        With the exception of the return value, :meth:`recv_streaming` behaves
        like :meth:`recv`.
        """
        if self.recv_lock.locked():
            raise ServerError(
                "Cannot call recv_streaming while another task "
                "is already waiting for the next message"
            )
        await self.recv_lock.acquire()
        if self.connection.state in (CLOSED, CLOSING):
            raise ServerError(
                "Cannot receive from websocket interface after it is closed."
            )
        try:
            self.can_pause = False
            async for m in self.assembler.get_iter():
                yield m
        finally:
            self.can_pause = True
            self.recv_lock.release()

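A sketch of frame-by-frame consumption with recv_streaming(); the route and the byte counting are illustrative.

from sanic import Sanic

app = Sanic("StreamExample")


@app.websocket("/upload")
async def upload(request, ws):
    # Consume a single message frame by frame instead of assembling it in
    # memory; exhaust the iterator before receiving anything else.
    total = 0
    async for frame in ws.recv_streaming():
        total += len(frame)
    await ws.send(f"received {total} units")
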
    async def send(self, message: Union[Data, Iterable[Data]]) -> None:
        """
        Send a message.

        A string (:class:`str`) is sent as a `Text frame`_. A bytestring or
        bytes-like object (:class:`bytes`, :class:`bytearray`, or
        :class:`memoryview`) is sent as a `Binary frame`_.

        .. _Text frame: https://tools.ietf.org/html/rfc6455#section-5.6
        .. _Binary frame: https://tools.ietf.org/html/rfc6455#section-5.6

        :meth:`send` also accepts an iterable of strings, bytestrings, or
        bytes-like objects, which would normally be sent as a fragmented
        message. Fragmented sends are not yet supported by this
        implementation, so passing an iterable raises
        :exc:`NotImplementedError`.

        :meth:`send` rejects dict-like objects because this is often an
        error; serialize them explicitly (for example with ``json.dumps``)
        before sending.

        :raises TypeError: for unsupported inputs
        """
        async with self.conn_mutex:

            if self.connection.state in (CLOSED, CLOSING):
                raise ServerError(
                    "Cannot write to websocket interface after it is closed."
                )
            if (not self.data_finished_fut) or self.data_finished_fut.done():
                raise ServerError(
                    "Cannot write to websocket interface after it is finished."
                )

            # Unfragmented message -- this case must be handled first because
            # strings and bytes-like objects are iterable.

            if isinstance(message, str):
                self.connection.send_text(message.encode("utf-8"))
                await self.send_data(self.connection.data_to_send())

            elif isinstance(message, (bytes, bytearray, memoryview)):
                self.connection.send_binary(message)
                await self.send_data(self.connection.data_to_send())

            elif isinstance(message, Mapping):
                # Catch a common mistake -- passing a dict to send().
                raise TypeError("data is a dict-like object")

            elif isinstance(message, Iterable):
                # Fragmented message -- regular iterator.
                raise NotImplementedError(
                    "Fragmented websocket messages are not supported."
                )
            else:
                raise TypeError("Websocket data must be bytes or str.")

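A sketch of the accepted send() inputs; dict-like payloads must be serialized first (json is just one option), and the route is illustrative.

import json

from sanic import Sanic

app = Sanic("SendExample")


@app.websocket("/push")
async def push(request, ws):
    await ws.send("plain text")                     # str -> text frame
    await ws.send(b"\x00\x01\x02")                  # bytes -> binary frame
    await ws.send(json.dumps({"event": "update"}))  # dicts: serialize first
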
    async def ping(self, data: Optional[Data] = None) -> asyncio.Future:
        """
        Send a ping.

        Return an :class:`~asyncio.Future` that will be resolved when the
        corresponding pong is received. You can ignore it if you don't intend
        to wait.

        A ping may serve as a keepalive or as a check that the remote endpoint
        received all messages up to this point::

            pong_event = await ws.ping()
            await pong_event  # only if you want to wait for the pong

        By default, the ping contains four random bytes. This payload may be
        overridden with the optional ``data`` argument, which must be a string
        (which will be encoded to UTF-8) or a bytes-like object.
        """
        async with self.conn_mutex:
            if self.connection.state in (CLOSED, CLOSING):
                raise ServerError(
                    "Cannot send a ping when the websocket interface "
                    "is closed."
                )
            if (not self.io_proto) or (not self.io_proto.loop):
                raise ServerError(
                    "Cannot send a ping when the websocket has no I/O "
                    "protocol attached."
                )
            if data is not None:
                if isinstance(data, str):
                    data = data.encode("utf-8")
                elif isinstance(data, (bytearray, memoryview)):
                    data = bytes(data)

            # Protect against duplicates if a payload is explicitly set.
            if data in self.pings:
                raise ValueError(
                    "already waiting for a pong with the same data"
                )

            # Generate a unique random payload otherwise.
            while data is None or data in self.pings:
                data = struct.pack("!I", random.getrandbits(32))

            self.pings[data] = self.io_proto.loop.create_future()

            self.connection.send_ping(data)
            await self.send_data(self.connection.data_to_send())

            return asyncio.shield(self.pings[data])

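A keepalive sketch built on ping(); the payload, the 10-second wait, and falling back to close() (the close handling mentioned in this change) are illustrative choices, not part of this diff.

import asyncio

from sanic import Sanic

app = Sanic("PingExample")


@app.websocket("/keepalive")
async def keepalive(request, ws):
    # ping() returns a future that resolves when the matching pong arrives.
    pong_received = await ws.ping(b"liveness-check")
    try:
        await asyncio.wait_for(pong_received, timeout=10)
    except asyncio.TimeoutError:
        await ws.close()
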
    async def pong(self, data: Data = b"") -> None:
        """
        Send a pong.

        An unsolicited pong may serve as a unidirectional heartbeat.

        The payload may be set with the optional ``data`` argument, which must
        be a string (which will be encoded to UTF-8) or a bytes-like object.
        """
        async with self.conn_mutex:
            if self.connection.state in (CLOSED, CLOSING):
                # Cannot send a pong after the transport is shutting down
                return
            if isinstance(data, str):
                data = data.encode("utf-8")
            elif isinstance(data, (bytearray, memoryview)):
                data = bytes(data)
            self.connection.send_pong(data)
            await self.send_data(self.connection.data_to_send())

    async def send_data(self, data_to_send):
        for data in data_to_send:
            if data:
                await self.io_proto.send(data)
            else:
                # An empty chunk signals EOF. We don't actually send it;
                # we just trigger the connection to auto-close.
                if (
                    self.auto_closer_task
                    and not self.auto_closer_task.done()
                    and self.data_finished_fut
                    and not self.data_finished_fut.done()
                ):
                    # Auto-close the connection
                    self.data_finished_fut.set_result(None)
                else:
                    # This will fail the connection appropriately
                    SanicProtocol.close(self.io_proto, timeout=1.0)

    async def async_data_received(self, data_to_send, events_to_process):
        if self.connection.state == OPEN and len(data_to_send) > 0:
            # Receiving data can generate data to send (e.g. a pong for a
            # ping); flush connection.data_to_send() to the transport.
            await self.send_data(data_to_send)
        if len(events_to_process) > 0:
            await self.process_events(events_to_process)

    def data_received(self, data):
        self.connection.receive_data(data)
        data_to_send = self.connection.data_to_send()
        events_to_process = self.connection.events_received()
        if len(data_to_send) > 0 or len(events_to_process) > 0:
            asyncio.create_task(
                self.async_data_received(data_to_send, events_to_process)
            )

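For readers new to the Sans-I/O pattern that data_received() implements, a condensed sketch: the protocol object only parses and produces bytes, while the caller flushes outgoing data and reacts to events. Here connection stands in for the Sans-I/O protocol object used above and transport_write for an async transport send; both are assumptions supplied by the caller.

async def pump(connection, transport_write, raw: bytes) -> list:
    # Feed raw bytes into the Sans-I/O protocol object...
    connection.receive_data(raw)
    # ...flush anything it wants to send (e.g. an automatic pong)...
    for out in connection.data_to_send():
        if out:
            await transport_write(out)
    # ...and hand the parsed events back for the handler to process.
    return connection.events_received()
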
    async def async_eof_received(self, data_to_send, events_to_process):
        # Receiving EOF can generate data to send;
        # flush connection.data_to_send() to the transport.
        if self.connection.state == OPEN:
            await self.send_data(data_to_send)
        if len(events_to_process) > 0:
            await self.process_events(events_to_process)

        if (
            self.auto_closer_task
            and not self.auto_closer_task.done()
            and self.data_finished_fut
            and not self.data_finished_fut.done()
        ):
            # Auto-close the connection
            self.data_finished_fut.set_result(None)
        else:
            # This will fail the connection appropriately
            SanicProtocol.close(self.io_proto, timeout=1.0)

    def eof_received(self) -> Optional[bool]:
        self.connection.receive_eof()
        data_to_send = self.connection.data_to_send()
        events_to_process = self.connection.events_received()
        if len(data_to_send) > 0 or len(events_to_process) > 0:
            asyncio.create_task(
                self.async_eof_received(data_to_send, events_to_process)
            )
        return False

    def connection_lost(self, exc):
        """
        The WebSocket connection is closed.
        """
        if self.connection.state != CLOSED:
            # Signal to the websocket connection handler
            # that we've lost the connection
            self.connection.fail(code=1006)
            self.connection.state = CLOSED

        self.abort_pings()
        if self.connection_lost_waiter:
            self.connection_lost_waiter.set_result(None)

@@ -1,205 +0,0 @@
from typing import (
    Any,
    Awaitable,
    Callable,
    Dict,
    List,
    MutableMapping,
    Optional,
    Union,
)

from httptools import HttpParserUpgrade  # type: ignore
from websockets import (  # type: ignore
    ConnectionClosed,
    InvalidHandshake,
    WebSocketCommonProtocol,
)

# Despite the "legacy" namespace, the primary maintainer of websockets
# committed to maintaining backwards-compatibility until 2026 and will
# consider extending it if sanic continues depending on this module.
from websockets.legacy import handshake

from sanic.exceptions import InvalidUsage
from sanic.server import HttpProtocol


__all__ = ["ConnectionClosed", "WebSocketProtocol", "WebSocketConnection"]

ASIMessage = MutableMapping[str, Any]


class WebSocketProtocol(HttpProtocol):
    def __init__(
        self,
        *args,
        websocket_timeout=10,
        websocket_max_size=None,
        websocket_max_queue=None,
        websocket_read_limit=2 ** 16,
        websocket_write_limit=2 ** 16,
        websocket_ping_interval=20,
        websocket_ping_timeout=20,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.websocket = None
        # self.app = None
        self.websocket_timeout = websocket_timeout
        self.websocket_max_size = websocket_max_size
        self.websocket_max_queue = websocket_max_queue
        self.websocket_read_limit = websocket_read_limit
        self.websocket_write_limit = websocket_write_limit
        self.websocket_ping_interval = websocket_ping_interval
        self.websocket_ping_timeout = websocket_ping_timeout

    # timeouts make no sense for websocket routes
    def request_timeout_callback(self):
        if self.websocket is None:
            super().request_timeout_callback()

    def response_timeout_callback(self):
        if self.websocket is None:
            super().response_timeout_callback()

    def keep_alive_timeout_callback(self):
        if self.websocket is None:
            super().keep_alive_timeout_callback()

    def connection_lost(self, exc):
        if self.websocket is not None:
            self.websocket.connection_lost(exc)
        super().connection_lost(exc)

    def data_received(self, data):
        if self.websocket is not None:
            # pass the data to the websocket protocol
            self.websocket.data_received(data)
        else:
            try:
                super().data_received(data)
            except HttpParserUpgrade:
                # this is okay, it just indicates we've got an upgrade request
                pass

    def write_response(self, response):
        if self.websocket is not None:
            # websocket requests do not write a response
            self.transport.close()
        else:
            super().write_response(response)

    async def websocket_handshake(self, request, subprotocols=None):
        # let the websockets package do the handshake with the client
        headers = {}

        try:
            key = handshake.check_request(request.headers)
            handshake.build_response(headers, key)
        except InvalidHandshake:
            raise InvalidUsage("Invalid websocket request")

        subprotocol = None
        if subprotocols and "Sec-Websocket-Protocol" in request.headers:
            # select a subprotocol
            client_subprotocols = [
                p.strip()
                for p in request.headers["Sec-Websocket-Protocol"].split(",")
            ]
            for p in client_subprotocols:
                if p in subprotocols:
                    subprotocol = p
                    headers["Sec-Websocket-Protocol"] = subprotocol
                    break

        # write the 101 response back to the client
        rv = b"HTTP/1.1 101 Switching Protocols\r\n"
        for k, v in headers.items():
            rv += k.encode("utf-8") + b": " + v.encode("utf-8") + b"\r\n"
        rv += b"\r\n"
        request.transport.write(rv)

        # hook up the websocket protocol
        self.websocket = WebSocketCommonProtocol(
            close_timeout=self.websocket_timeout,
            max_size=self.websocket_max_size,
            max_queue=self.websocket_max_queue,
            read_limit=self.websocket_read_limit,
            write_limit=self.websocket_write_limit,
            ping_interval=self.websocket_ping_interval,
            ping_timeout=self.websocket_ping_timeout,
        )
        # we use WebSocketCommonProtocol because we don't want the handshake
        # logic from WebSocketServerProtocol; however, we must tell it that
        # we're running on the server side
        self.websocket.is_client = False
        self.websocket.side = "server"
        self.websocket.subprotocol = subprotocol
        self.websocket.connection_made(request.transport)
        self.websocket.connection_open()
        return self.websocket

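For context on what the handshake helpers above compute: the Sec-WebSocket-Accept value is the base64-encoded SHA-1 of the client's Sec-WebSocket-Key joined with the RFC 6455 GUID. A small standalone check, independent of this change:

import base64
import hashlib

RFC6455_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


def accept_key(sec_websocket_key: str) -> str:
    digest = hashlib.sha1(
        (sec_websocket_key + RFC6455_GUID).encode("ascii")
    ).digest()
    return base64.b64encode(digest).decode("ascii")


# Key/accept pair taken from RFC 6455, section 1.3
assert accept_key("dGhlIHNhbXBsZSBub25jZQ==") == "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
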
class WebSocketConnection:

    # TODO
    # - Implement ping/pong

    def __init__(
        self,
        send: Callable[[ASIMessage], Awaitable[None]],
        receive: Callable[[], Awaitable[ASIMessage]],
        subprotocols: Optional[List[str]] = None,
    ) -> None:
        self._send = send
        self._receive = receive
        self._subprotocols = subprotocols or []

    async def send(self, data: Union[str, bytes], *args, **kwargs) -> None:
        message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"}

        if isinstance(data, bytes):
            message.update({"bytes": data})
        else:
            message.update({"text": str(data)})

        await self._send(message)

    async def recv(self, *args, **kwargs) -> Optional[str]:
        message = await self._receive()

        if message["type"] == "websocket.receive":
            return message["text"]
        elif message["type"] == "websocket.disconnect":
            pass

        return None

    receive = recv

    async def accept(self, subprotocols: Optional[List[str]] = None) -> None:
        subprotocol = None
        if subprotocols:
            for subp in subprotocols:
                if subp in self.subprotocols:
                    subprotocol = subp
                    break

        await self._send(
            {
                "type": "websocket.accept",
                "subprotocol": subprotocol,
            }
        )

    async def close(self) -> None:
        pass

    @property
    def subprotocols(self):
        return self._subprotocols

    @subprotocols.setter
    def subprotocols(self, subprotocols: Optional[List[str]] = None):
        self._subprotocols = subprotocols or []

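The ASGI shim removed here survives as WebSocketConnection under sanic.server.websockets.connection (see the updated test import below). A rough sketch of driving it with hand-rolled ASGI callables; the queue wiring and subprotocol value are illustrative.

import asyncio

from sanic.server.websockets.connection import WebSocketConnection


async def demo():
    sent = []
    inbox = asyncio.Queue()
    await inbox.put({"type": "websocket.receive", "text": "hi"})

    async def send(message):
        sent.append(message)

    async def receive():
        return await inbox.get()

    ws = WebSocketConnection(send, receive, subprotocols=["graphql-ws"])
    await ws.accept()
    assert await ws.recv() == "hi"
    await ws.send("hello")
    assert sent[-1] == {"type": "websocket.send", "text": "hello"}


asyncio.run(demo())
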
@@ -9,7 +9,7 @@ from gunicorn.workers import base  # type: ignore

 from sanic.log import logger
 from sanic.server import HttpProtocol, Signal, serve
-from sanic.websocket import WebSocketProtocol
+from sanic.server.protocols.websocket_protocol import WebSocketProtocol


 try:
@@ -142,14 +142,11 @@ class GunicornWorker(base.Worker):

         # Force close non-idle connection after waiting for
         # graceful_shutdown_timeout
-        coros = []
         for conn in self.connections:
             if hasattr(conn, "websocket") and conn.websocket:
-                coros.append(conn.websocket.close_connection())
+                conn.websocket.fail_connection(code=1001)
             else:
-                conn.close()
-        _shutdown = asyncio.gather(*coros, loop=self.loop)
-        await _shutdown
+                conn.abort()

     async def _run(self):
         for sock in self.sockets:
setup.py
@@ -88,7 +88,7 @@ requirements = [
     uvloop,
     ujson,
     "aiofiles>=0.6.0",
-    "websockets>=9.0",
+    "websockets>=10.0",
     "multidict>=5.0,<6.0",
 ]

@@ -178,9 +178,6 @@ def test_app_enable_websocket(app, websocket_enabled, enable):
 @patch("sanic.app.WebSocketProtocol")
 def test_app_websocket_parameters(websocket_protocol_mock, app):
     app.config.WEBSOCKET_MAX_SIZE = 44
-    app.config.WEBSOCKET_MAX_QUEUE = 45
-    app.config.WEBSOCKET_READ_LIMIT = 46
-    app.config.WEBSOCKET_WRITE_LIMIT = 47
     app.config.WEBSOCKET_PING_TIMEOUT = 48
     app.config.WEBSOCKET_PING_INTERVAL = 50

@@ -197,11 +194,6 @@ def test_app_websocket_parameters(websocket_protocol_mock, app):
     websocket_protocol_call_args = websocket_protocol_mock.call_args
     ws_kwargs = websocket_protocol_call_args[1]
     assert ws_kwargs["websocket_max_size"] == app.config.WEBSOCKET_MAX_SIZE
-    assert ws_kwargs["websocket_max_queue"] == app.config.WEBSOCKET_MAX_QUEUE
-    assert ws_kwargs["websocket_read_limit"] == app.config.WEBSOCKET_READ_LIMIT
-    assert (
-        ws_kwargs["websocket_write_limit"] == app.config.WEBSOCKET_WRITE_LIMIT
-    )
     assert (
         ws_kwargs["websocket_ping_timeout"]
         == app.config.WEBSOCKET_PING_TIMEOUT
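With the read and write limits deprecated by this change, the websocket config knobs still exercised in this test can be set as usual; the values below are illustrative, not recommendations.

from sanic import Sanic

app = Sanic("ConfigExample")
app.config.WEBSOCKET_MAX_SIZE = 2 ** 20    # max incoming message size, bytes
app.config.WEBSOCKET_PING_INTERVAL = 20    # seconds between automatic pings
app.config.WEBSOCKET_PING_TIMEOUT = 20     # seconds to wait for a pong
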
@@ -10,7 +10,7 @@ from sanic.asgi import MockTransport
 from sanic.exceptions import Forbidden, InvalidUsage, ServiceUnavailable
 from sanic.request import Request
 from sanic.response import json, text
-from sanic.websocket import WebSocketConnection
+from sanic.server.websockets.connection import WebSocketConnection


 @pytest.fixture
@@ -16,6 +16,7 @@ from sanic.exceptions import (
     abort,
 )
 from sanic.response import text
+from websockets.version import version as websockets_version


 class SanicExceptionTestException(Exception):
@@ -260,9 +261,14 @@ def test_exception_in_ws_logged(caplog):

     with caplog.at_level(logging.INFO):
         app.test_client.websocket("/feed")

-    assert caplog.record_tuples[1][0] == "sanic.error"
-    assert caplog.record_tuples[1][1] == logging.ERROR
+    # Websockets v10.0 and above output an additional
+    # INFO message when a ws connection is accepted
+    ws_version_parts = websockets_version.split(".")
+    ws_major = int(ws_version_parts[0])
+    record_index = 2 if ws_major >= 10 else 1
+    assert caplog.record_tuples[record_index][0] == "sanic.error"
+    assert caplog.record_tuples[record_index][1] == logging.ERROR
     assert (
-        "Exception occurred while handling uri:" in caplog.record_tuples[1][2]
+        "Exception occurred while handling uri:"
+        in caplog.record_tuples[record_index][2]
     )
@@ -674,16 +674,16 @@ async def test_websocket_route_asgi(app, url):
 @pytest.mark.parametrize(
     "subprotocols,expected",
     (
-        (["bar"], "bar"),
-        (["bar", "foo"], "bar"),
-        (["baz"], None),
+        (["one"], "one"),
+        (["three", "one"], "one"),
+        (["tree"], None),
         (None, None),
     ),
 )
 def test_websocket_route_with_subprotocols(app, subprotocols, expected):
-    results = "unset"
+    results = []

-    @app.websocket("/ws", subprotocols=["foo", "bar"])
+    @app.websocket("/ws", subprotocols=["zero", "one", "two", "three"])
     async def handler(request, ws):
         nonlocal results
         results = ws.subprotocol

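The client-side view of the negotiation exercised above, sketched with the websockets client library; the URL is illustrative, and per the parametrization above a server offering ["zero", "one", "two", "three"] is expected to settle on "one" for this request.

import asyncio

import websockets


async def main():
    async with websockets.connect(
        "ws://localhost:8000/ws", subprotocols=["three", "one"]
    ) as ws:
        print(ws.subprotocol)  # "one", per the expectation in the test above


asyncio.run(main())
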
@@ -175,7 +175,7 @@ def test_worker_close(worker):
     worker.wsgi = mock.Mock()
     conn = mock.Mock()
     conn.websocket = mock.Mock()
-    conn.websocket.close_connection = mock.Mock(wraps=_a_noop)
+    conn.websocket.fail_connection = mock.Mock(wraps=_a_noop)
     worker.connections = set([conn])
     worker.log = mock.Mock()
     worker.loop = loop
@@ -190,5 +190,5 @@ def test_worker_close(worker):
     loop.run_until_complete(_close)

     assert worker.signal.stopped
-    assert conn.websocket.close_connection.called
+    assert conn.websocket.fail_connection.called
     assert len(worker.servers) == 0