Add compatibility with websockets 11.0. (#2609)
Co-authored-by: Adam Hopkins <adam@amhopkins.com>
This commit is contained in:
parent
beae35f921
commit
4c14910d5b
|
@ -1,7 +1,13 @@
|
|||
from typing import TYPE_CHECKING, Optional, Sequence, cast
|
||||
|
||||
from websockets.connection import CLOSED, CLOSING, OPEN
|
||||
from websockets.server import ServerConnection
|
||||
|
||||
try: # websockets < 11.0
|
||||
from websockets.connection import State
|
||||
from websockets.server import ServerConnection as ServerProtocol
|
||||
except ImportError: # websockets >= 11.0
|
||||
from websockets.protocol import State # type: ignore
|
||||
from websockets.server import ServerProtocol # type: ignore
|
||||
|
||||
from websockets.typing import Subprotocol
|
||||
|
||||
from sanic.exceptions import ServerError
|
||||
|
@ -15,6 +21,11 @@ if TYPE_CHECKING:
|
|||
from websockets import http11
|
||||
|
||||
|
||||
OPEN = State.OPEN
|
||||
CLOSING = State.CLOSING
|
||||
CLOSED = State.CLOSED
|
||||
|
||||
|
||||
class WebSocketProtocol(HttpProtocol):
|
||||
__slots__ = (
|
||||
"websocket",
|
||||
|
@ -74,7 +85,7 @@ class WebSocketProtocol(HttpProtocol):
|
|||
# Called by Sanic Server when shutting down
|
||||
# If we've upgraded to websocket, shut it down
|
||||
if self.websocket is not None:
|
||||
if self.websocket.connection.state in (CLOSING, CLOSED):
|
||||
if self.websocket.ws_proto.state in (CLOSING, CLOSED):
|
||||
return True
|
||||
elif self.websocket.loop is not None:
|
||||
self.websocket.loop.create_task(self.websocket.close(1001))
|
||||
|
@ -90,7 +101,7 @@ class WebSocketProtocol(HttpProtocol):
|
|||
try:
|
||||
if subprotocols is not None:
|
||||
# subprotocols can be a set or frozenset,
|
||||
# but ServerConnection needs a list
|
||||
# but ServerProtocol needs a list
|
||||
subprotocols = cast(
|
||||
Optional[Sequence[Subprotocol]],
|
||||
list(
|
||||
|
@ -100,13 +111,13 @@ class WebSocketProtocol(HttpProtocol):
|
|||
]
|
||||
),
|
||||
)
|
||||
ws_conn = ServerConnection(
|
||||
ws_proto = ServerProtocol(
|
||||
max_size=self.websocket_max_size,
|
||||
subprotocols=subprotocols,
|
||||
state=OPEN,
|
||||
logger=logger,
|
||||
)
|
||||
resp: "http11.Response" = ws_conn.accept(request)
|
||||
resp: "http11.Response" = ws_proto.accept(request)
|
||||
except Exception:
|
||||
msg = (
|
||||
"Failed to open a WebSocket connection.\n"
|
||||
|
@ -129,7 +140,7 @@ class WebSocketProtocol(HttpProtocol):
|
|||
else:
|
||||
raise ServerError(resp.body, resp.status_code)
|
||||
self.websocket = WebsocketImplProtocol(
|
||||
ws_conn,
|
||||
ws_proto,
|
||||
ping_interval=self.websocket_ping_interval,
|
||||
ping_timeout=self.websocket_ping_timeout,
|
||||
close_timeout=self.websocket_timeout,
|
||||
|
|
|
@ -12,25 +12,37 @@ from typing import (
|
|||
Union,
|
||||
)
|
||||
|
||||
from websockets.connection import CLOSED, CLOSING, OPEN, Event
|
||||
from websockets.exceptions import (
|
||||
ConnectionClosed,
|
||||
ConnectionClosedError,
|
||||
ConnectionClosedOK,
|
||||
)
|
||||
from websockets.frames import Frame, Opcode
|
||||
from websockets.server import ServerConnection
|
||||
|
||||
|
||||
try: # websockets < 11.0
|
||||
from websockets.connection import Event, State
|
||||
from websockets.server import ServerConnection as ServerProtocol
|
||||
except ImportError: # websockets >= 11.0
|
||||
from websockets.protocol import Event, State # type: ignore
|
||||
from websockets.server import ServerProtocol # type: ignore
|
||||
|
||||
from websockets.typing import Data
|
||||
|
||||
from sanic.log import error_logger, logger
|
||||
from sanic.log import deprecation, error_logger, logger
|
||||
from sanic.server.protocols.base_protocol import SanicProtocol
|
||||
|
||||
from ...exceptions import ServerError, WebsocketClosed
|
||||
from .frame import WebsocketFrameAssembler
|
||||
|
||||
|
||||
OPEN = State.OPEN
|
||||
CLOSING = State.CLOSING
|
||||
CLOSED = State.CLOSED
|
||||
|
||||
|
||||
class WebsocketImplProtocol:
|
||||
connection: ServerConnection
|
||||
ws_proto: ServerProtocol
|
||||
io_proto: Optional[SanicProtocol]
|
||||
loop: Optional[asyncio.AbstractEventLoop]
|
||||
max_queue: int
|
||||
|
@ -56,14 +68,14 @@ class WebsocketImplProtocol:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
connection,
|
||||
ws_proto,
|
||||
max_queue=None,
|
||||
ping_interval: Optional[float] = 20,
|
||||
ping_timeout: Optional[float] = 20,
|
||||
close_timeout: float = 10,
|
||||
loop=None,
|
||||
):
|
||||
self.connection = connection
|
||||
self.ws_proto = ws_proto
|
||||
self.io_proto = None
|
||||
self.loop = None
|
||||
self.max_queue = max_queue
|
||||
|
@ -85,7 +97,16 @@ class WebsocketImplProtocol:
|
|||
|
||||
@property
|
||||
def subprotocol(self):
|
||||
return self.connection.subprotocol
|
||||
return self.ws_proto.subprotocol
|
||||
|
||||
@property
|
||||
def connection(self):
|
||||
deprecation(
|
||||
"The connection property has been deprecated and will be removed. "
|
||||
"Please use the ws_proto property instead going forward.",
|
||||
22.6,
|
||||
)
|
||||
return self.ws_proto
|
||||
|
||||
def pause_frames(self):
|
||||
if not self.can_pause:
|
||||
|
@ -299,15 +320,15 @@ class WebsocketImplProtocol:
|
|||
# Not draining the write buffer is acceptable in this context.
|
||||
|
||||
# clear the send buffer
|
||||
_ = self.connection.data_to_send()
|
||||
_ = self.ws_proto.data_to_send()
|
||||
# If we're not already CLOSED or CLOSING, then send the close.
|
||||
if self.connection.state is OPEN:
|
||||
if self.ws_proto.state is OPEN:
|
||||
if code in (1000, 1001):
|
||||
self.connection.send_close(code, reason)
|
||||
self.ws_proto.send_close(code, reason)
|
||||
else:
|
||||
self.connection.fail(code, reason)
|
||||
self.ws_proto.fail(code, reason)
|
||||
try:
|
||||
data_to_send = self.connection.data_to_send()
|
||||
data_to_send = self.ws_proto.data_to_send()
|
||||
while (
|
||||
len(data_to_send)
|
||||
and self.io_proto
|
||||
|
@ -321,7 +342,7 @@ class WebsocketImplProtocol:
|
|||
...
|
||||
if code == 1006:
|
||||
# Special case: 1006 consider the transport already closed
|
||||
self.connection.state = CLOSED
|
||||
self.ws_proto.state = CLOSED
|
||||
if self.data_finished_fut and not self.data_finished_fut.done():
|
||||
# We have a graceful auto-closer. Use it to close the connection.
|
||||
self.data_finished_fut.cancel()
|
||||
|
@ -342,10 +363,10 @@ class WebsocketImplProtocol:
|
|||
# In Python Version 3.7: pause_reading is idempotent
|
||||
# i.e. it can be called when the transport is already paused or closed.
|
||||
self.io_proto.transport.pause_reading()
|
||||
if self.connection.state == OPEN:
|
||||
data_to_send = self.connection.data_to_send()
|
||||
self.connection.send_close(code, reason)
|
||||
data_to_send.extend(self.connection.data_to_send())
|
||||
if self.ws_proto.state == OPEN:
|
||||
data_to_send = self.ws_proto.data_to_send()
|
||||
self.ws_proto.send_close(code, reason)
|
||||
data_to_send.extend(self.ws_proto.data_to_send())
|
||||
try:
|
||||
while (
|
||||
len(data_to_send)
|
||||
|
@ -454,7 +475,7 @@ class WebsocketImplProtocol:
|
|||
Raise ConnectionClosed in pending keepalive pings.
|
||||
They'll never receive a pong once the connection is closed.
|
||||
"""
|
||||
if self.connection.state is not CLOSED:
|
||||
if self.ws_proto.state is not CLOSED:
|
||||
raise ServerError(
|
||||
"Webscoket about_pings should only be called "
|
||||
"after connection state is changed to CLOSED"
|
||||
|
@ -483,9 +504,9 @@ class WebsocketImplProtocol:
|
|||
self.fail_connection(code, reason)
|
||||
return
|
||||
async with self.conn_mutex:
|
||||
if self.connection.state is OPEN:
|
||||
self.connection.send_close(code, reason)
|
||||
data_to_send = self.connection.data_to_send()
|
||||
if self.ws_proto.state is OPEN:
|
||||
self.ws_proto.send_close(code, reason)
|
||||
data_to_send = self.ws_proto.data_to_send()
|
||||
await self.send_data(data_to_send)
|
||||
|
||||
async def recv(self, timeout: Optional[float] = None) -> Optional[Data]:
|
||||
|
@ -515,7 +536,7 @@ class WebsocketImplProtocol:
|
|||
"already waiting for the next message"
|
||||
)
|
||||
await self.recv_lock.acquire()
|
||||
if self.connection.state is CLOSED:
|
||||
if self.ws_proto.state is CLOSED:
|
||||
self.recv_lock.release()
|
||||
raise WebsocketClosed(
|
||||
"Cannot receive from websocket interface after it is closed."
|
||||
|
@ -566,7 +587,7 @@ class WebsocketImplProtocol:
|
|||
"for the next message"
|
||||
)
|
||||
await self.recv_lock.acquire()
|
||||
if self.connection.state is CLOSED:
|
||||
if self.ws_proto.state is CLOSED:
|
||||
self.recv_lock.release()
|
||||
raise WebsocketClosed(
|
||||
"Cannot receive from websocket interface after it is closed."
|
||||
|
@ -625,7 +646,7 @@ class WebsocketImplProtocol:
|
|||
"is already waiting for the next message"
|
||||
)
|
||||
await self.recv_lock.acquire()
|
||||
if self.connection.state is CLOSED:
|
||||
if self.ws_proto.state is CLOSED:
|
||||
self.recv_lock.release()
|
||||
raise WebsocketClosed(
|
||||
"Cannot receive from websocket interface after it is closed."
|
||||
|
@ -666,7 +687,7 @@ class WebsocketImplProtocol:
|
|||
"""
|
||||
async with self.conn_mutex:
|
||||
|
||||
if self.connection.state in (CLOSED, CLOSING):
|
||||
if self.ws_proto.state in (CLOSED, CLOSING):
|
||||
raise WebsocketClosed(
|
||||
"Cannot write to websocket interface after it is closed."
|
||||
)
|
||||
|
@ -679,12 +700,12 @@ class WebsocketImplProtocol:
|
|||
# strings and bytes-like objects are iterable.
|
||||
|
||||
if isinstance(message, str):
|
||||
self.connection.send_text(message.encode("utf-8"))
|
||||
await self.send_data(self.connection.data_to_send())
|
||||
self.ws_proto.send_text(message.encode("utf-8"))
|
||||
await self.send_data(self.ws_proto.data_to_send())
|
||||
|
||||
elif isinstance(message, (bytes, bytearray, memoryview)):
|
||||
self.connection.send_binary(message)
|
||||
await self.send_data(self.connection.data_to_send())
|
||||
self.ws_proto.send_binary(message)
|
||||
await self.send_data(self.ws_proto.data_to_send())
|
||||
|
||||
elif isinstance(message, Mapping):
|
||||
# Catch a common mistake -- passing a dict to send().
|
||||
|
@ -713,7 +734,7 @@ class WebsocketImplProtocol:
|
|||
(which will be encoded to UTF-8) or a bytes-like object.
|
||||
"""
|
||||
async with self.conn_mutex:
|
||||
if self.connection.state in (CLOSED, CLOSING):
|
||||
if self.ws_proto.state in (CLOSED, CLOSING):
|
||||
raise WebsocketClosed(
|
||||
"Cannot send a ping when the websocket interface "
|
||||
"is closed."
|
||||
|
@ -741,8 +762,8 @@ class WebsocketImplProtocol:
|
|||
|
||||
self.pings[data] = self.io_proto.loop.create_future()
|
||||
|
||||
self.connection.send_ping(data)
|
||||
await self.send_data(self.connection.data_to_send())
|
||||
self.ws_proto.send_ping(data)
|
||||
await self.send_data(self.ws_proto.data_to_send())
|
||||
|
||||
return asyncio.shield(self.pings[data])
|
||||
|
||||
|
@ -754,15 +775,15 @@ class WebsocketImplProtocol:
|
|||
be a string (which will be encoded to UTF-8) or a bytes-like object.
|
||||
"""
|
||||
async with self.conn_mutex:
|
||||
if self.connection.state in (CLOSED, CLOSING):
|
||||
if self.ws_proto.state in (CLOSED, CLOSING):
|
||||
# Cannot send pong after transport is shutting down
|
||||
return
|
||||
if isinstance(data, str):
|
||||
data = data.encode("utf-8")
|
||||
elif isinstance(data, (bytearray, memoryview)):
|
||||
data = bytes(data)
|
||||
self.connection.send_pong(data)
|
||||
await self.send_data(self.connection.data_to_send())
|
||||
self.ws_proto.send_pong(data)
|
||||
await self.send_data(self.ws_proto.data_to_send())
|
||||
|
||||
async def send_data(self, data_to_send):
|
||||
for data in data_to_send:
|
||||
|
@ -784,7 +805,7 @@ class WebsocketImplProtocol:
|
|||
SanicProtocol.close(self.io_proto, timeout=1.0)
|
||||
|
||||
async def async_data_received(self, data_to_send, events_to_process):
|
||||
if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0:
|
||||
if self.ws_proto.state in (OPEN, CLOSING) and len(data_to_send) > 0:
|
||||
# receiving data can generate data to send (eg, pong for a ping)
|
||||
# send connection.data_to_send()
|
||||
await self.send_data(data_to_send)
|
||||
|
@ -792,9 +813,9 @@ class WebsocketImplProtocol:
|
|||
await self.process_events(events_to_process)
|
||||
|
||||
def data_received(self, data):
|
||||
self.connection.receive_data(data)
|
||||
data_to_send = self.connection.data_to_send()
|
||||
events_to_process = self.connection.events_received()
|
||||
self.ws_proto.receive_data(data)
|
||||
data_to_send = self.ws_proto.data_to_send()
|
||||
events_to_process = self.ws_proto.events_received()
|
||||
if len(data_to_send) > 0 or len(events_to_process) > 0:
|
||||
asyncio.create_task(
|
||||
self.async_data_received(data_to_send, events_to_process)
|
||||
|
@ -803,7 +824,7 @@ class WebsocketImplProtocol:
|
|||
async def async_eof_received(self, data_to_send, events_to_process):
|
||||
# receiving EOF can generate data to send
|
||||
# send connection.data_to_send()
|
||||
if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0:
|
||||
if self.ws_proto.state in (OPEN, CLOSING) and len(data_to_send) > 0:
|
||||
await self.send_data(data_to_send)
|
||||
if len(events_to_process) > 0:
|
||||
await self.process_events(events_to_process)
|
||||
|
@ -823,9 +844,9 @@ class WebsocketImplProtocol:
|
|||
SanicProtocol.close(self.io_proto, timeout=1.0)
|
||||
|
||||
def eof_received(self) -> Optional[bool]:
|
||||
self.connection.receive_eof()
|
||||
data_to_send = self.connection.data_to_send()
|
||||
events_to_process = self.connection.events_received()
|
||||
self.ws_proto.receive_eof()
|
||||
data_to_send = self.ws_proto.data_to_send()
|
||||
events_to_process = self.ws_proto.events_received()
|
||||
asyncio.create_task(
|
||||
self.async_eof_received(data_to_send, events_to_process)
|
||||
)
|
||||
|
@ -835,11 +856,11 @@ class WebsocketImplProtocol:
|
|||
"""
|
||||
The WebSocket Connection is Closed.
|
||||
"""
|
||||
if not self.connection.state == CLOSED:
|
||||
if not self.ws_proto.state == CLOSED:
|
||||
# signal to the websocket connection handler
|
||||
# we've lost the connection
|
||||
self.connection.fail(code=1006)
|
||||
self.connection.state = CLOSED
|
||||
self.ws_proto.fail(code=1006)
|
||||
self.ws_proto.state = CLOSED
|
||||
|
||||
self.abort_pings()
|
||||
if self.connection_lost_waiter:
|
||||
|
|
Loading…
Reference in New Issue
Block a user