Add compatibility with websockets 11.0. (#2609)
Co-authored-by: Adam Hopkins <adam@amhopkins.com>
This commit is contained in:
parent
beae35f921
commit
4c14910d5b
|
@ -1,7 +1,13 @@
|
||||||
from typing import TYPE_CHECKING, Optional, Sequence, cast
|
from typing import TYPE_CHECKING, Optional, Sequence, cast
|
||||||
|
|
||||||
from websockets.connection import CLOSED, CLOSING, OPEN
|
|
||||||
from websockets.server import ServerConnection
|
try: # websockets < 11.0
|
||||||
|
from websockets.connection import State
|
||||||
|
from websockets.server import ServerConnection as ServerProtocol
|
||||||
|
except ImportError: # websockets >= 11.0
|
||||||
|
from websockets.protocol import State # type: ignore
|
||||||
|
from websockets.server import ServerProtocol # type: ignore
|
||||||
|
|
||||||
from websockets.typing import Subprotocol
|
from websockets.typing import Subprotocol
|
||||||
|
|
||||||
from sanic.exceptions import ServerError
|
from sanic.exceptions import ServerError
|
||||||
|
@ -15,6 +21,11 @@ if TYPE_CHECKING:
|
||||||
from websockets import http11
|
from websockets import http11
|
||||||
|
|
||||||
|
|
||||||
|
OPEN = State.OPEN
|
||||||
|
CLOSING = State.CLOSING
|
||||||
|
CLOSED = State.CLOSED
|
||||||
|
|
||||||
|
|
||||||
class WebSocketProtocol(HttpProtocol):
|
class WebSocketProtocol(HttpProtocol):
|
||||||
__slots__ = (
|
__slots__ = (
|
||||||
"websocket",
|
"websocket",
|
||||||
|
@ -74,7 +85,7 @@ class WebSocketProtocol(HttpProtocol):
|
||||||
# Called by Sanic Server when shutting down
|
# Called by Sanic Server when shutting down
|
||||||
# If we've upgraded to websocket, shut it down
|
# If we've upgraded to websocket, shut it down
|
||||||
if self.websocket is not None:
|
if self.websocket is not None:
|
||||||
if self.websocket.connection.state in (CLOSING, CLOSED):
|
if self.websocket.ws_proto.state in (CLOSING, CLOSED):
|
||||||
return True
|
return True
|
||||||
elif self.websocket.loop is not None:
|
elif self.websocket.loop is not None:
|
||||||
self.websocket.loop.create_task(self.websocket.close(1001))
|
self.websocket.loop.create_task(self.websocket.close(1001))
|
||||||
|
@ -90,7 +101,7 @@ class WebSocketProtocol(HttpProtocol):
|
||||||
try:
|
try:
|
||||||
if subprotocols is not None:
|
if subprotocols is not None:
|
||||||
# subprotocols can be a set or frozenset,
|
# subprotocols can be a set or frozenset,
|
||||||
# but ServerConnection needs a list
|
# but ServerProtocol needs a list
|
||||||
subprotocols = cast(
|
subprotocols = cast(
|
||||||
Optional[Sequence[Subprotocol]],
|
Optional[Sequence[Subprotocol]],
|
||||||
list(
|
list(
|
||||||
|
@ -100,13 +111,13 @@ class WebSocketProtocol(HttpProtocol):
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
ws_conn = ServerConnection(
|
ws_proto = ServerProtocol(
|
||||||
max_size=self.websocket_max_size,
|
max_size=self.websocket_max_size,
|
||||||
subprotocols=subprotocols,
|
subprotocols=subprotocols,
|
||||||
state=OPEN,
|
state=OPEN,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
)
|
)
|
||||||
resp: "http11.Response" = ws_conn.accept(request)
|
resp: "http11.Response" = ws_proto.accept(request)
|
||||||
except Exception:
|
except Exception:
|
||||||
msg = (
|
msg = (
|
||||||
"Failed to open a WebSocket connection.\n"
|
"Failed to open a WebSocket connection.\n"
|
||||||
|
@ -129,7 +140,7 @@ class WebSocketProtocol(HttpProtocol):
|
||||||
else:
|
else:
|
||||||
raise ServerError(resp.body, resp.status_code)
|
raise ServerError(resp.body, resp.status_code)
|
||||||
self.websocket = WebsocketImplProtocol(
|
self.websocket = WebsocketImplProtocol(
|
||||||
ws_conn,
|
ws_proto,
|
||||||
ping_interval=self.websocket_ping_interval,
|
ping_interval=self.websocket_ping_interval,
|
||||||
ping_timeout=self.websocket_ping_timeout,
|
ping_timeout=self.websocket_ping_timeout,
|
||||||
close_timeout=self.websocket_timeout,
|
close_timeout=self.websocket_timeout,
|
||||||
|
|
|
@ -12,25 +12,37 @@ from typing import (
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
from websockets.connection import CLOSED, CLOSING, OPEN, Event
|
|
||||||
from websockets.exceptions import (
|
from websockets.exceptions import (
|
||||||
ConnectionClosed,
|
ConnectionClosed,
|
||||||
ConnectionClosedError,
|
ConnectionClosedError,
|
||||||
ConnectionClosedOK,
|
ConnectionClosedOK,
|
||||||
)
|
)
|
||||||
from websockets.frames import Frame, Opcode
|
from websockets.frames import Frame, Opcode
|
||||||
from websockets.server import ServerConnection
|
|
||||||
|
|
||||||
|
try: # websockets < 11.0
|
||||||
|
from websockets.connection import Event, State
|
||||||
|
from websockets.server import ServerConnection as ServerProtocol
|
||||||
|
except ImportError: # websockets >= 11.0
|
||||||
|
from websockets.protocol import Event, State # type: ignore
|
||||||
|
from websockets.server import ServerProtocol # type: ignore
|
||||||
|
|
||||||
from websockets.typing import Data
|
from websockets.typing import Data
|
||||||
|
|
||||||
from sanic.log import error_logger, logger
|
from sanic.log import deprecation, error_logger, logger
|
||||||
from sanic.server.protocols.base_protocol import SanicProtocol
|
from sanic.server.protocols.base_protocol import SanicProtocol
|
||||||
|
|
||||||
from ...exceptions import ServerError, WebsocketClosed
|
from ...exceptions import ServerError, WebsocketClosed
|
||||||
from .frame import WebsocketFrameAssembler
|
from .frame import WebsocketFrameAssembler
|
||||||
|
|
||||||
|
|
||||||
|
OPEN = State.OPEN
|
||||||
|
CLOSING = State.CLOSING
|
||||||
|
CLOSED = State.CLOSED
|
||||||
|
|
||||||
|
|
||||||
class WebsocketImplProtocol:
|
class WebsocketImplProtocol:
|
||||||
connection: ServerConnection
|
ws_proto: ServerProtocol
|
||||||
io_proto: Optional[SanicProtocol]
|
io_proto: Optional[SanicProtocol]
|
||||||
loop: Optional[asyncio.AbstractEventLoop]
|
loop: Optional[asyncio.AbstractEventLoop]
|
||||||
max_queue: int
|
max_queue: int
|
||||||
|
@ -56,14 +68,14 @@ class WebsocketImplProtocol:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
connection,
|
ws_proto,
|
||||||
max_queue=None,
|
max_queue=None,
|
||||||
ping_interval: Optional[float] = 20,
|
ping_interval: Optional[float] = 20,
|
||||||
ping_timeout: Optional[float] = 20,
|
ping_timeout: Optional[float] = 20,
|
||||||
close_timeout: float = 10,
|
close_timeout: float = 10,
|
||||||
loop=None,
|
loop=None,
|
||||||
):
|
):
|
||||||
self.connection = connection
|
self.ws_proto = ws_proto
|
||||||
self.io_proto = None
|
self.io_proto = None
|
||||||
self.loop = None
|
self.loop = None
|
||||||
self.max_queue = max_queue
|
self.max_queue = max_queue
|
||||||
|
@ -85,7 +97,16 @@ class WebsocketImplProtocol:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def subprotocol(self):
|
def subprotocol(self):
|
||||||
return self.connection.subprotocol
|
return self.ws_proto.subprotocol
|
||||||
|
|
||||||
|
@property
|
||||||
|
def connection(self):
|
||||||
|
deprecation(
|
||||||
|
"The connection property has been deprecated and will be removed. "
|
||||||
|
"Please use the ws_proto property instead going forward.",
|
||||||
|
22.6,
|
||||||
|
)
|
||||||
|
return self.ws_proto
|
||||||
|
|
||||||
def pause_frames(self):
|
def pause_frames(self):
|
||||||
if not self.can_pause:
|
if not self.can_pause:
|
||||||
|
@ -299,15 +320,15 @@ class WebsocketImplProtocol:
|
||||||
# Not draining the write buffer is acceptable in this context.
|
# Not draining the write buffer is acceptable in this context.
|
||||||
|
|
||||||
# clear the send buffer
|
# clear the send buffer
|
||||||
_ = self.connection.data_to_send()
|
_ = self.ws_proto.data_to_send()
|
||||||
# If we're not already CLOSED or CLOSING, then send the close.
|
# If we're not already CLOSED or CLOSING, then send the close.
|
||||||
if self.connection.state is OPEN:
|
if self.ws_proto.state is OPEN:
|
||||||
if code in (1000, 1001):
|
if code in (1000, 1001):
|
||||||
self.connection.send_close(code, reason)
|
self.ws_proto.send_close(code, reason)
|
||||||
else:
|
else:
|
||||||
self.connection.fail(code, reason)
|
self.ws_proto.fail(code, reason)
|
||||||
try:
|
try:
|
||||||
data_to_send = self.connection.data_to_send()
|
data_to_send = self.ws_proto.data_to_send()
|
||||||
while (
|
while (
|
||||||
len(data_to_send)
|
len(data_to_send)
|
||||||
and self.io_proto
|
and self.io_proto
|
||||||
|
@ -321,7 +342,7 @@ class WebsocketImplProtocol:
|
||||||
...
|
...
|
||||||
if code == 1006:
|
if code == 1006:
|
||||||
# Special case: 1006 consider the transport already closed
|
# Special case: 1006 consider the transport already closed
|
||||||
self.connection.state = CLOSED
|
self.ws_proto.state = CLOSED
|
||||||
if self.data_finished_fut and not self.data_finished_fut.done():
|
if self.data_finished_fut and not self.data_finished_fut.done():
|
||||||
# We have a graceful auto-closer. Use it to close the connection.
|
# We have a graceful auto-closer. Use it to close the connection.
|
||||||
self.data_finished_fut.cancel()
|
self.data_finished_fut.cancel()
|
||||||
|
@ -342,10 +363,10 @@ class WebsocketImplProtocol:
|
||||||
# In Python Version 3.7: pause_reading is idempotent
|
# In Python Version 3.7: pause_reading is idempotent
|
||||||
# i.e. it can be called when the transport is already paused or closed.
|
# i.e. it can be called when the transport is already paused or closed.
|
||||||
self.io_proto.transport.pause_reading()
|
self.io_proto.transport.pause_reading()
|
||||||
if self.connection.state == OPEN:
|
if self.ws_proto.state == OPEN:
|
||||||
data_to_send = self.connection.data_to_send()
|
data_to_send = self.ws_proto.data_to_send()
|
||||||
self.connection.send_close(code, reason)
|
self.ws_proto.send_close(code, reason)
|
||||||
data_to_send.extend(self.connection.data_to_send())
|
data_to_send.extend(self.ws_proto.data_to_send())
|
||||||
try:
|
try:
|
||||||
while (
|
while (
|
||||||
len(data_to_send)
|
len(data_to_send)
|
||||||
|
@ -454,7 +475,7 @@ class WebsocketImplProtocol:
|
||||||
Raise ConnectionClosed in pending keepalive pings.
|
Raise ConnectionClosed in pending keepalive pings.
|
||||||
They'll never receive a pong once the connection is closed.
|
They'll never receive a pong once the connection is closed.
|
||||||
"""
|
"""
|
||||||
if self.connection.state is not CLOSED:
|
if self.ws_proto.state is not CLOSED:
|
||||||
raise ServerError(
|
raise ServerError(
|
||||||
"Webscoket about_pings should only be called "
|
"Webscoket about_pings should only be called "
|
||||||
"after connection state is changed to CLOSED"
|
"after connection state is changed to CLOSED"
|
||||||
|
@ -483,9 +504,9 @@ class WebsocketImplProtocol:
|
||||||
self.fail_connection(code, reason)
|
self.fail_connection(code, reason)
|
||||||
return
|
return
|
||||||
async with self.conn_mutex:
|
async with self.conn_mutex:
|
||||||
if self.connection.state is OPEN:
|
if self.ws_proto.state is OPEN:
|
||||||
self.connection.send_close(code, reason)
|
self.ws_proto.send_close(code, reason)
|
||||||
data_to_send = self.connection.data_to_send()
|
data_to_send = self.ws_proto.data_to_send()
|
||||||
await self.send_data(data_to_send)
|
await self.send_data(data_to_send)
|
||||||
|
|
||||||
async def recv(self, timeout: Optional[float] = None) -> Optional[Data]:
|
async def recv(self, timeout: Optional[float] = None) -> Optional[Data]:
|
||||||
|
@ -515,7 +536,7 @@ class WebsocketImplProtocol:
|
||||||
"already waiting for the next message"
|
"already waiting for the next message"
|
||||||
)
|
)
|
||||||
await self.recv_lock.acquire()
|
await self.recv_lock.acquire()
|
||||||
if self.connection.state is CLOSED:
|
if self.ws_proto.state is CLOSED:
|
||||||
self.recv_lock.release()
|
self.recv_lock.release()
|
||||||
raise WebsocketClosed(
|
raise WebsocketClosed(
|
||||||
"Cannot receive from websocket interface after it is closed."
|
"Cannot receive from websocket interface after it is closed."
|
||||||
|
@ -566,7 +587,7 @@ class WebsocketImplProtocol:
|
||||||
"for the next message"
|
"for the next message"
|
||||||
)
|
)
|
||||||
await self.recv_lock.acquire()
|
await self.recv_lock.acquire()
|
||||||
if self.connection.state is CLOSED:
|
if self.ws_proto.state is CLOSED:
|
||||||
self.recv_lock.release()
|
self.recv_lock.release()
|
||||||
raise WebsocketClosed(
|
raise WebsocketClosed(
|
||||||
"Cannot receive from websocket interface after it is closed."
|
"Cannot receive from websocket interface after it is closed."
|
||||||
|
@ -625,7 +646,7 @@ class WebsocketImplProtocol:
|
||||||
"is already waiting for the next message"
|
"is already waiting for the next message"
|
||||||
)
|
)
|
||||||
await self.recv_lock.acquire()
|
await self.recv_lock.acquire()
|
||||||
if self.connection.state is CLOSED:
|
if self.ws_proto.state is CLOSED:
|
||||||
self.recv_lock.release()
|
self.recv_lock.release()
|
||||||
raise WebsocketClosed(
|
raise WebsocketClosed(
|
||||||
"Cannot receive from websocket interface after it is closed."
|
"Cannot receive from websocket interface after it is closed."
|
||||||
|
@ -666,7 +687,7 @@ class WebsocketImplProtocol:
|
||||||
"""
|
"""
|
||||||
async with self.conn_mutex:
|
async with self.conn_mutex:
|
||||||
|
|
||||||
if self.connection.state in (CLOSED, CLOSING):
|
if self.ws_proto.state in (CLOSED, CLOSING):
|
||||||
raise WebsocketClosed(
|
raise WebsocketClosed(
|
||||||
"Cannot write to websocket interface after it is closed."
|
"Cannot write to websocket interface after it is closed."
|
||||||
)
|
)
|
||||||
|
@ -679,12 +700,12 @@ class WebsocketImplProtocol:
|
||||||
# strings and bytes-like objects are iterable.
|
# strings and bytes-like objects are iterable.
|
||||||
|
|
||||||
if isinstance(message, str):
|
if isinstance(message, str):
|
||||||
self.connection.send_text(message.encode("utf-8"))
|
self.ws_proto.send_text(message.encode("utf-8"))
|
||||||
await self.send_data(self.connection.data_to_send())
|
await self.send_data(self.ws_proto.data_to_send())
|
||||||
|
|
||||||
elif isinstance(message, (bytes, bytearray, memoryview)):
|
elif isinstance(message, (bytes, bytearray, memoryview)):
|
||||||
self.connection.send_binary(message)
|
self.ws_proto.send_binary(message)
|
||||||
await self.send_data(self.connection.data_to_send())
|
await self.send_data(self.ws_proto.data_to_send())
|
||||||
|
|
||||||
elif isinstance(message, Mapping):
|
elif isinstance(message, Mapping):
|
||||||
# Catch a common mistake -- passing a dict to send().
|
# Catch a common mistake -- passing a dict to send().
|
||||||
|
@ -713,7 +734,7 @@ class WebsocketImplProtocol:
|
||||||
(which will be encoded to UTF-8) or a bytes-like object.
|
(which will be encoded to UTF-8) or a bytes-like object.
|
||||||
"""
|
"""
|
||||||
async with self.conn_mutex:
|
async with self.conn_mutex:
|
||||||
if self.connection.state in (CLOSED, CLOSING):
|
if self.ws_proto.state in (CLOSED, CLOSING):
|
||||||
raise WebsocketClosed(
|
raise WebsocketClosed(
|
||||||
"Cannot send a ping when the websocket interface "
|
"Cannot send a ping when the websocket interface "
|
||||||
"is closed."
|
"is closed."
|
||||||
|
@ -741,8 +762,8 @@ class WebsocketImplProtocol:
|
||||||
|
|
||||||
self.pings[data] = self.io_proto.loop.create_future()
|
self.pings[data] = self.io_proto.loop.create_future()
|
||||||
|
|
||||||
self.connection.send_ping(data)
|
self.ws_proto.send_ping(data)
|
||||||
await self.send_data(self.connection.data_to_send())
|
await self.send_data(self.ws_proto.data_to_send())
|
||||||
|
|
||||||
return asyncio.shield(self.pings[data])
|
return asyncio.shield(self.pings[data])
|
||||||
|
|
||||||
|
@ -754,15 +775,15 @@ class WebsocketImplProtocol:
|
||||||
be a string (which will be encoded to UTF-8) or a bytes-like object.
|
be a string (which will be encoded to UTF-8) or a bytes-like object.
|
||||||
"""
|
"""
|
||||||
async with self.conn_mutex:
|
async with self.conn_mutex:
|
||||||
if self.connection.state in (CLOSED, CLOSING):
|
if self.ws_proto.state in (CLOSED, CLOSING):
|
||||||
# Cannot send pong after transport is shutting down
|
# Cannot send pong after transport is shutting down
|
||||||
return
|
return
|
||||||
if isinstance(data, str):
|
if isinstance(data, str):
|
||||||
data = data.encode("utf-8")
|
data = data.encode("utf-8")
|
||||||
elif isinstance(data, (bytearray, memoryview)):
|
elif isinstance(data, (bytearray, memoryview)):
|
||||||
data = bytes(data)
|
data = bytes(data)
|
||||||
self.connection.send_pong(data)
|
self.ws_proto.send_pong(data)
|
||||||
await self.send_data(self.connection.data_to_send())
|
await self.send_data(self.ws_proto.data_to_send())
|
||||||
|
|
||||||
async def send_data(self, data_to_send):
|
async def send_data(self, data_to_send):
|
||||||
for data in data_to_send:
|
for data in data_to_send:
|
||||||
|
@ -784,7 +805,7 @@ class WebsocketImplProtocol:
|
||||||
SanicProtocol.close(self.io_proto, timeout=1.0)
|
SanicProtocol.close(self.io_proto, timeout=1.0)
|
||||||
|
|
||||||
async def async_data_received(self, data_to_send, events_to_process):
|
async def async_data_received(self, data_to_send, events_to_process):
|
||||||
if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0:
|
if self.ws_proto.state in (OPEN, CLOSING) and len(data_to_send) > 0:
|
||||||
# receiving data can generate data to send (eg, pong for a ping)
|
# receiving data can generate data to send (eg, pong for a ping)
|
||||||
# send connection.data_to_send()
|
# send connection.data_to_send()
|
||||||
await self.send_data(data_to_send)
|
await self.send_data(data_to_send)
|
||||||
|
@ -792,9 +813,9 @@ class WebsocketImplProtocol:
|
||||||
await self.process_events(events_to_process)
|
await self.process_events(events_to_process)
|
||||||
|
|
||||||
def data_received(self, data):
|
def data_received(self, data):
|
||||||
self.connection.receive_data(data)
|
self.ws_proto.receive_data(data)
|
||||||
data_to_send = self.connection.data_to_send()
|
data_to_send = self.ws_proto.data_to_send()
|
||||||
events_to_process = self.connection.events_received()
|
events_to_process = self.ws_proto.events_received()
|
||||||
if len(data_to_send) > 0 or len(events_to_process) > 0:
|
if len(data_to_send) > 0 or len(events_to_process) > 0:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
self.async_data_received(data_to_send, events_to_process)
|
self.async_data_received(data_to_send, events_to_process)
|
||||||
|
@ -803,7 +824,7 @@ class WebsocketImplProtocol:
|
||||||
async def async_eof_received(self, data_to_send, events_to_process):
|
async def async_eof_received(self, data_to_send, events_to_process):
|
||||||
# receiving EOF can generate data to send
|
# receiving EOF can generate data to send
|
||||||
# send connection.data_to_send()
|
# send connection.data_to_send()
|
||||||
if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0:
|
if self.ws_proto.state in (OPEN, CLOSING) and len(data_to_send) > 0:
|
||||||
await self.send_data(data_to_send)
|
await self.send_data(data_to_send)
|
||||||
if len(events_to_process) > 0:
|
if len(events_to_process) > 0:
|
||||||
await self.process_events(events_to_process)
|
await self.process_events(events_to_process)
|
||||||
|
@ -823,9 +844,9 @@ class WebsocketImplProtocol:
|
||||||
SanicProtocol.close(self.io_proto, timeout=1.0)
|
SanicProtocol.close(self.io_proto, timeout=1.0)
|
||||||
|
|
||||||
def eof_received(self) -> Optional[bool]:
|
def eof_received(self) -> Optional[bool]:
|
||||||
self.connection.receive_eof()
|
self.ws_proto.receive_eof()
|
||||||
data_to_send = self.connection.data_to_send()
|
data_to_send = self.ws_proto.data_to_send()
|
||||||
events_to_process = self.connection.events_received()
|
events_to_process = self.ws_proto.events_received()
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
self.async_eof_received(data_to_send, events_to_process)
|
self.async_eof_received(data_to_send, events_to_process)
|
||||||
)
|
)
|
||||||
|
@ -835,11 +856,11 @@ class WebsocketImplProtocol:
|
||||||
"""
|
"""
|
||||||
The WebSocket Connection is Closed.
|
The WebSocket Connection is Closed.
|
||||||
"""
|
"""
|
||||||
if not self.connection.state == CLOSED:
|
if not self.ws_proto.state == CLOSED:
|
||||||
# signal to the websocket connection handler
|
# signal to the websocket connection handler
|
||||||
# we've lost the connection
|
# we've lost the connection
|
||||||
self.connection.fail(code=1006)
|
self.ws_proto.fail(code=1006)
|
||||||
self.connection.state = CLOSED
|
self.ws_proto.state = CLOSED
|
||||||
|
|
||||||
self.abort_pings()
|
self.abort_pings()
|
||||||
if self.connection_lost_waiter:
|
if self.connection_lost_waiter:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user