Add compatibility with websockets 11.0. (#2609)

Co-authored-by: Adam Hopkins <adam@amhopkins.com>
This commit is contained in:
Aymeric Augustin 2022-11-29 10:45:18 +01:00 committed by GitHub
parent beae35f921
commit 4c14910d5b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 85 additions and 53 deletions

View File

@ -1,7 +1,13 @@
from typing import TYPE_CHECKING, Optional, Sequence, cast from typing import TYPE_CHECKING, Optional, Sequence, cast
from websockets.connection import CLOSED, CLOSING, OPEN
from websockets.server import ServerConnection try: # websockets < 11.0
from websockets.connection import State
from websockets.server import ServerConnection as ServerProtocol
except ImportError: # websockets >= 11.0
from websockets.protocol import State # type: ignore
from websockets.server import ServerProtocol # type: ignore
from websockets.typing import Subprotocol from websockets.typing import Subprotocol
from sanic.exceptions import ServerError from sanic.exceptions import ServerError
@ -15,6 +21,11 @@ if TYPE_CHECKING:
from websockets import http11 from websockets import http11
OPEN = State.OPEN
CLOSING = State.CLOSING
CLOSED = State.CLOSED
class WebSocketProtocol(HttpProtocol): class WebSocketProtocol(HttpProtocol):
__slots__ = ( __slots__ = (
"websocket", "websocket",
@ -74,7 +85,7 @@ class WebSocketProtocol(HttpProtocol):
# Called by Sanic Server when shutting down # Called by Sanic Server when shutting down
# If we've upgraded to websocket, shut it down # If we've upgraded to websocket, shut it down
if self.websocket is not None: if self.websocket is not None:
if self.websocket.connection.state in (CLOSING, CLOSED): if self.websocket.ws_proto.state in (CLOSING, CLOSED):
return True return True
elif self.websocket.loop is not None: elif self.websocket.loop is not None:
self.websocket.loop.create_task(self.websocket.close(1001)) self.websocket.loop.create_task(self.websocket.close(1001))
@ -90,7 +101,7 @@ class WebSocketProtocol(HttpProtocol):
try: try:
if subprotocols is not None: if subprotocols is not None:
# subprotocols can be a set or frozenset, # subprotocols can be a set or frozenset,
# but ServerConnection needs a list # but ServerProtocol needs a list
subprotocols = cast( subprotocols = cast(
Optional[Sequence[Subprotocol]], Optional[Sequence[Subprotocol]],
list( list(
@ -100,13 +111,13 @@ class WebSocketProtocol(HttpProtocol):
] ]
), ),
) )
ws_conn = ServerConnection( ws_proto = ServerProtocol(
max_size=self.websocket_max_size, max_size=self.websocket_max_size,
subprotocols=subprotocols, subprotocols=subprotocols,
state=OPEN, state=OPEN,
logger=logger, logger=logger,
) )
resp: "http11.Response" = ws_conn.accept(request) resp: "http11.Response" = ws_proto.accept(request)
except Exception: except Exception:
msg = ( msg = (
"Failed to open a WebSocket connection.\n" "Failed to open a WebSocket connection.\n"
@ -129,7 +140,7 @@ class WebSocketProtocol(HttpProtocol):
else: else:
raise ServerError(resp.body, resp.status_code) raise ServerError(resp.body, resp.status_code)
self.websocket = WebsocketImplProtocol( self.websocket = WebsocketImplProtocol(
ws_conn, ws_proto,
ping_interval=self.websocket_ping_interval, ping_interval=self.websocket_ping_interval,
ping_timeout=self.websocket_ping_timeout, ping_timeout=self.websocket_ping_timeout,
close_timeout=self.websocket_timeout, close_timeout=self.websocket_timeout,

View File

@ -12,25 +12,37 @@ from typing import (
Union, Union,
) )
from websockets.connection import CLOSED, CLOSING, OPEN, Event
from websockets.exceptions import ( from websockets.exceptions import (
ConnectionClosed, ConnectionClosed,
ConnectionClosedError, ConnectionClosedError,
ConnectionClosedOK, ConnectionClosedOK,
) )
from websockets.frames import Frame, Opcode from websockets.frames import Frame, Opcode
from websockets.server import ServerConnection
try: # websockets < 11.0
from websockets.connection import Event, State
from websockets.server import ServerConnection as ServerProtocol
except ImportError: # websockets >= 11.0
from websockets.protocol import Event, State # type: ignore
from websockets.server import ServerProtocol # type: ignore
from websockets.typing import Data from websockets.typing import Data
from sanic.log import error_logger, logger from sanic.log import deprecation, error_logger, logger
from sanic.server.protocols.base_protocol import SanicProtocol from sanic.server.protocols.base_protocol import SanicProtocol
from ...exceptions import ServerError, WebsocketClosed from ...exceptions import ServerError, WebsocketClosed
from .frame import WebsocketFrameAssembler from .frame import WebsocketFrameAssembler
OPEN = State.OPEN
CLOSING = State.CLOSING
CLOSED = State.CLOSED
class WebsocketImplProtocol: class WebsocketImplProtocol:
connection: ServerConnection ws_proto: ServerProtocol
io_proto: Optional[SanicProtocol] io_proto: Optional[SanicProtocol]
loop: Optional[asyncio.AbstractEventLoop] loop: Optional[asyncio.AbstractEventLoop]
max_queue: int max_queue: int
@ -56,14 +68,14 @@ class WebsocketImplProtocol:
def __init__( def __init__(
self, self,
connection, ws_proto,
max_queue=None, max_queue=None,
ping_interval: Optional[float] = 20, ping_interval: Optional[float] = 20,
ping_timeout: Optional[float] = 20, ping_timeout: Optional[float] = 20,
close_timeout: float = 10, close_timeout: float = 10,
loop=None, loop=None,
): ):
self.connection = connection self.ws_proto = ws_proto
self.io_proto = None self.io_proto = None
self.loop = None self.loop = None
self.max_queue = max_queue self.max_queue = max_queue
@ -85,7 +97,16 @@ class WebsocketImplProtocol:
@property @property
def subprotocol(self): def subprotocol(self):
return self.connection.subprotocol return self.ws_proto.subprotocol
@property
def connection(self):
deprecation(
"The connection property has been deprecated and will be removed. "
"Please use the ws_proto property instead going forward.",
22.6,
)
return self.ws_proto
def pause_frames(self): def pause_frames(self):
if not self.can_pause: if not self.can_pause:
@ -299,15 +320,15 @@ class WebsocketImplProtocol:
# Not draining the write buffer is acceptable in this context. # Not draining the write buffer is acceptable in this context.
# clear the send buffer # clear the send buffer
_ = self.connection.data_to_send() _ = self.ws_proto.data_to_send()
# If we're not already CLOSED or CLOSING, then send the close. # If we're not already CLOSED or CLOSING, then send the close.
if self.connection.state is OPEN: if self.ws_proto.state is OPEN:
if code in (1000, 1001): if code in (1000, 1001):
self.connection.send_close(code, reason) self.ws_proto.send_close(code, reason)
else: else:
self.connection.fail(code, reason) self.ws_proto.fail(code, reason)
try: try:
data_to_send = self.connection.data_to_send() data_to_send = self.ws_proto.data_to_send()
while ( while (
len(data_to_send) len(data_to_send)
and self.io_proto and self.io_proto
@ -321,7 +342,7 @@ class WebsocketImplProtocol:
... ...
if code == 1006: if code == 1006:
# Special case: 1006 consider the transport already closed # Special case: 1006 consider the transport already closed
self.connection.state = CLOSED self.ws_proto.state = CLOSED
if self.data_finished_fut and not self.data_finished_fut.done(): if self.data_finished_fut and not self.data_finished_fut.done():
# We have a graceful auto-closer. Use it to close the connection. # We have a graceful auto-closer. Use it to close the connection.
self.data_finished_fut.cancel() self.data_finished_fut.cancel()
@ -342,10 +363,10 @@ class WebsocketImplProtocol:
# In Python Version 3.7: pause_reading is idempotent # In Python Version 3.7: pause_reading is idempotent
# i.e. it can be called when the transport is already paused or closed. # i.e. it can be called when the transport is already paused or closed.
self.io_proto.transport.pause_reading() self.io_proto.transport.pause_reading()
if self.connection.state == OPEN: if self.ws_proto.state == OPEN:
data_to_send = self.connection.data_to_send() data_to_send = self.ws_proto.data_to_send()
self.connection.send_close(code, reason) self.ws_proto.send_close(code, reason)
data_to_send.extend(self.connection.data_to_send()) data_to_send.extend(self.ws_proto.data_to_send())
try: try:
while ( while (
len(data_to_send) len(data_to_send)
@ -454,7 +475,7 @@ class WebsocketImplProtocol:
Raise ConnectionClosed in pending keepalive pings. Raise ConnectionClosed in pending keepalive pings.
They'll never receive a pong once the connection is closed. They'll never receive a pong once the connection is closed.
""" """
if self.connection.state is not CLOSED: if self.ws_proto.state is not CLOSED:
raise ServerError( raise ServerError(
"Webscoket about_pings should only be called " "Webscoket about_pings should only be called "
"after connection state is changed to CLOSED" "after connection state is changed to CLOSED"
@ -483,9 +504,9 @@ class WebsocketImplProtocol:
self.fail_connection(code, reason) self.fail_connection(code, reason)
return return
async with self.conn_mutex: async with self.conn_mutex:
if self.connection.state is OPEN: if self.ws_proto.state is OPEN:
self.connection.send_close(code, reason) self.ws_proto.send_close(code, reason)
data_to_send = self.connection.data_to_send() data_to_send = self.ws_proto.data_to_send()
await self.send_data(data_to_send) await self.send_data(data_to_send)
async def recv(self, timeout: Optional[float] = None) -> Optional[Data]: async def recv(self, timeout: Optional[float] = None) -> Optional[Data]:
@ -515,7 +536,7 @@ class WebsocketImplProtocol:
"already waiting for the next message" "already waiting for the next message"
) )
await self.recv_lock.acquire() await self.recv_lock.acquire()
if self.connection.state is CLOSED: if self.ws_proto.state is CLOSED:
self.recv_lock.release() self.recv_lock.release()
raise WebsocketClosed( raise WebsocketClosed(
"Cannot receive from websocket interface after it is closed." "Cannot receive from websocket interface after it is closed."
@ -566,7 +587,7 @@ class WebsocketImplProtocol:
"for the next message" "for the next message"
) )
await self.recv_lock.acquire() await self.recv_lock.acquire()
if self.connection.state is CLOSED: if self.ws_proto.state is CLOSED:
self.recv_lock.release() self.recv_lock.release()
raise WebsocketClosed( raise WebsocketClosed(
"Cannot receive from websocket interface after it is closed." "Cannot receive from websocket interface after it is closed."
@ -625,7 +646,7 @@ class WebsocketImplProtocol:
"is already waiting for the next message" "is already waiting for the next message"
) )
await self.recv_lock.acquire() await self.recv_lock.acquire()
if self.connection.state is CLOSED: if self.ws_proto.state is CLOSED:
self.recv_lock.release() self.recv_lock.release()
raise WebsocketClosed( raise WebsocketClosed(
"Cannot receive from websocket interface after it is closed." "Cannot receive from websocket interface after it is closed."
@ -666,7 +687,7 @@ class WebsocketImplProtocol:
""" """
async with self.conn_mutex: async with self.conn_mutex:
if self.connection.state in (CLOSED, CLOSING): if self.ws_proto.state in (CLOSED, CLOSING):
raise WebsocketClosed( raise WebsocketClosed(
"Cannot write to websocket interface after it is closed." "Cannot write to websocket interface after it is closed."
) )
@ -679,12 +700,12 @@ class WebsocketImplProtocol:
# strings and bytes-like objects are iterable. # strings and bytes-like objects are iterable.
if isinstance(message, str): if isinstance(message, str):
self.connection.send_text(message.encode("utf-8")) self.ws_proto.send_text(message.encode("utf-8"))
await self.send_data(self.connection.data_to_send()) await self.send_data(self.ws_proto.data_to_send())
elif isinstance(message, (bytes, bytearray, memoryview)): elif isinstance(message, (bytes, bytearray, memoryview)):
self.connection.send_binary(message) self.ws_proto.send_binary(message)
await self.send_data(self.connection.data_to_send()) await self.send_data(self.ws_proto.data_to_send())
elif isinstance(message, Mapping): elif isinstance(message, Mapping):
# Catch a common mistake -- passing a dict to send(). # Catch a common mistake -- passing a dict to send().
@ -713,7 +734,7 @@ class WebsocketImplProtocol:
(which will be encoded to UTF-8) or a bytes-like object. (which will be encoded to UTF-8) or a bytes-like object.
""" """
async with self.conn_mutex: async with self.conn_mutex:
if self.connection.state in (CLOSED, CLOSING): if self.ws_proto.state in (CLOSED, CLOSING):
raise WebsocketClosed( raise WebsocketClosed(
"Cannot send a ping when the websocket interface " "Cannot send a ping when the websocket interface "
"is closed." "is closed."
@ -741,8 +762,8 @@ class WebsocketImplProtocol:
self.pings[data] = self.io_proto.loop.create_future() self.pings[data] = self.io_proto.loop.create_future()
self.connection.send_ping(data) self.ws_proto.send_ping(data)
await self.send_data(self.connection.data_to_send()) await self.send_data(self.ws_proto.data_to_send())
return asyncio.shield(self.pings[data]) return asyncio.shield(self.pings[data])
@ -754,15 +775,15 @@ class WebsocketImplProtocol:
be a string (which will be encoded to UTF-8) or a bytes-like object. be a string (which will be encoded to UTF-8) or a bytes-like object.
""" """
async with self.conn_mutex: async with self.conn_mutex:
if self.connection.state in (CLOSED, CLOSING): if self.ws_proto.state in (CLOSED, CLOSING):
# Cannot send pong after transport is shutting down # Cannot send pong after transport is shutting down
return return
if isinstance(data, str): if isinstance(data, str):
data = data.encode("utf-8") data = data.encode("utf-8")
elif isinstance(data, (bytearray, memoryview)): elif isinstance(data, (bytearray, memoryview)):
data = bytes(data) data = bytes(data)
self.connection.send_pong(data) self.ws_proto.send_pong(data)
await self.send_data(self.connection.data_to_send()) await self.send_data(self.ws_proto.data_to_send())
async def send_data(self, data_to_send): async def send_data(self, data_to_send):
for data in data_to_send: for data in data_to_send:
@ -784,7 +805,7 @@ class WebsocketImplProtocol:
SanicProtocol.close(self.io_proto, timeout=1.0) SanicProtocol.close(self.io_proto, timeout=1.0)
async def async_data_received(self, data_to_send, events_to_process): async def async_data_received(self, data_to_send, events_to_process):
if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0: if self.ws_proto.state in (OPEN, CLOSING) and len(data_to_send) > 0:
# receiving data can generate data to send (eg, pong for a ping) # receiving data can generate data to send (eg, pong for a ping)
# send connection.data_to_send() # send connection.data_to_send()
await self.send_data(data_to_send) await self.send_data(data_to_send)
@ -792,9 +813,9 @@ class WebsocketImplProtocol:
await self.process_events(events_to_process) await self.process_events(events_to_process)
def data_received(self, data): def data_received(self, data):
self.connection.receive_data(data) self.ws_proto.receive_data(data)
data_to_send = self.connection.data_to_send() data_to_send = self.ws_proto.data_to_send()
events_to_process = self.connection.events_received() events_to_process = self.ws_proto.events_received()
if len(data_to_send) > 0 or len(events_to_process) > 0: if len(data_to_send) > 0 or len(events_to_process) > 0:
asyncio.create_task( asyncio.create_task(
self.async_data_received(data_to_send, events_to_process) self.async_data_received(data_to_send, events_to_process)
@ -803,7 +824,7 @@ class WebsocketImplProtocol:
async def async_eof_received(self, data_to_send, events_to_process): async def async_eof_received(self, data_to_send, events_to_process):
# receiving EOF can generate data to send # receiving EOF can generate data to send
# send connection.data_to_send() # send connection.data_to_send()
if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0: if self.ws_proto.state in (OPEN, CLOSING) and len(data_to_send) > 0:
await self.send_data(data_to_send) await self.send_data(data_to_send)
if len(events_to_process) > 0: if len(events_to_process) > 0:
await self.process_events(events_to_process) await self.process_events(events_to_process)
@ -823,9 +844,9 @@ class WebsocketImplProtocol:
SanicProtocol.close(self.io_proto, timeout=1.0) SanicProtocol.close(self.io_proto, timeout=1.0)
def eof_received(self) -> Optional[bool]: def eof_received(self) -> Optional[bool]:
self.connection.receive_eof() self.ws_proto.receive_eof()
data_to_send = self.connection.data_to_send() data_to_send = self.ws_proto.data_to_send()
events_to_process = self.connection.events_received() events_to_process = self.ws_proto.events_received()
asyncio.create_task( asyncio.create_task(
self.async_eof_received(data_to_send, events_to_process) self.async_eof_received(data_to_send, events_to_process)
) )
@ -835,11 +856,11 @@ class WebsocketImplProtocol:
""" """
The WebSocket Connection is Closed. The WebSocket Connection is Closed.
""" """
if not self.connection.state == CLOSED: if not self.ws_proto.state == CLOSED:
# signal to the websocket connection handler # signal to the websocket connection handler
# we've lost the connection # we've lost the connection
self.connection.fail(code=1006) self.ws_proto.fail(code=1006)
self.connection.state = CLOSED self.ws_proto.state = CLOSED
self.abort_pings() self.abort_pings()
if self.connection_lost_waiter: if self.connection_lost_waiter: