New websockets (#2158)

* First attempt at new Websockets implementation based on websockets >= 9.0, with sans-i/o features. Requires more work.

* Update sanic/websocket.py

Co-authored-by: Adam Hopkins <adam@amhopkins.com>

* Update sanic/websocket.py

Co-authored-by: Adam Hopkins <adam@amhopkins.com>

* Update sanic/websocket.py

Co-authored-by: Adam Hopkins <adam@amhopkins.com>

* wip, update websockets code to new Sans/IO API

* Refactored new websockets impl into own modules
Incorporated other suggestions made by team

* Another round of work on the new websockets impl
* Added websocket_timeout support (matching previous/legacy support)
* Lots more comments
* Incorporated suggested changes from previous round of review
* Changed RuntimeError usage to ServerError
* Changed SanicException usage to ServerError
* Removed some redundant asserts
* Change remaining asserts to ServerErrors
* Fixed some timeout handling issues
* Fixed websocket.close() handling, and made it more robust
* Made auto_close task smarter and more error-resilient
* Made fail_connection routine smarter and more error-resilient

* Further new websockets impl fixes
* Update compatibility with Websockets v10
* Track server connection state in a more precise way
* Try to handle the shutdown process more gracefully
* Add a new end_connection() helper, to use as an alterative to close() or fail_connection()
* Kill the auto-close task and keepalive-timeout task when sanic is shutdown
* Deprecate WEBSOCKET_READ_LIMIT and WEBSOCKET_WRITE_LIMIT configs, they are not used in this implementation.

* Change a warning message to debug level
Remove default values for deprecated websocket parameters

* Fix flake8 errors

* Fix a couple of missed failing tests

* remove websocket bench from examples

* Integrate suggestions from code reviews
Use Optional[T] instead of union[T,None]
Fix mypy type logic errors
change "is not None" to truthy checks where appropriate
change "is None" to falsy checks were appropriate
Add more debug logging when debug mode is on
Change to using sanic.logger for debug logging rather than error_logger.

* Fix long line lengths of debug messages
Add some new debug messages when websocket IO is paused and unpaused for flow control
Fix websocket example to use app.static()

* remove unused import in websocket example app

* re-run isort after Flake8 fixes

Co-authored-by: Adam Hopkins <adam@amhopkins.com>
Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
This commit is contained in:
Ashley Sommer 2021-09-29 20:09:23 +10:00 committed by GitHub
parent 595d2c76ac
commit 6ffc4d9756
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 1376 additions and 269 deletions

View File

@ -1,13 +1,14 @@
from sanic import Sanic from sanic import Sanic
from sanic.response import file from sanic.response import redirect
app = Sanic(__name__) app = Sanic(__name__)
@app.route('/') app.static('index.html', "websocket.html")
async def index(request):
return await file('websocket.html')
@app.route('/')
def index(request):
return redirect("index.html")
@app.websocket('/feed') @app.websocket('/feed')
async def feed(request, ws): async def feed(request, ws):

View File

@ -74,9 +74,10 @@ from sanic.router import Router
from sanic.server import AsyncioServer, HttpProtocol from sanic.server import AsyncioServer, HttpProtocol
from sanic.server import Signal as ServerSignal from sanic.server import Signal as ServerSignal
from sanic.server import serve, serve_multiple, serve_single from sanic.server import serve, serve_multiple, serve_single
from sanic.server.protocols.websocket_protocol import WebSocketProtocol
from sanic.server.websockets.impl import ConnectionClosed
from sanic.signals import Signal, SignalRouter from sanic.signals import Signal, SignalRouter
from sanic.touchup import TouchUp, TouchUpMeta from sanic.touchup import TouchUp, TouchUpMeta
from sanic.websocket import ConnectionClosed, WebSocketProtocol
class Sanic(BaseSanic, metaclass=TouchUpMeta): class Sanic(BaseSanic, metaclass=TouchUpMeta):
@ -871,23 +872,11 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
async def _websocket_handler( async def _websocket_handler(
self, handler, request, *args, subprotocols=None, **kwargs self, handler, request, *args, subprotocols=None, **kwargs
): ):
request.app = self
if not getattr(handler, "__blueprintname__", False):
request._name = handler.__name__
else:
request._name = (
getattr(handler, "__blueprintname__", "") + handler.__name__
)
pass
if self.asgi: if self.asgi:
ws = request.transport.get_websocket_connection() ws = request.transport.get_websocket_connection()
await ws.accept(subprotocols) await ws.accept(subprotocols)
else: else:
protocol = request.transport.get_protocol() protocol = request.transport.get_protocol()
protocol.app = self
ws = await protocol.websocket_handshake(request, subprotocols) ws = await protocol.websocket_handshake(request, subprotocols)
# schedule the application handler # schedule the application handler
@ -895,15 +884,19 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
# needs to be cancelled due to the server being stopped # needs to be cancelled due to the server being stopped
fut = ensure_future(handler(request, ws, *args, **kwargs)) fut = ensure_future(handler(request, ws, *args, **kwargs))
self.websocket_tasks.add(fut) self.websocket_tasks.add(fut)
cancelled = False
try: try:
await fut await fut
except Exception as e: except Exception as e:
self.error_handler.log(request, e) self.error_handler.log(request, e)
except (CancelledError, ConnectionClosed): except (CancelledError, ConnectionClosed):
pass cancelled = True
finally: finally:
self.websocket_tasks.remove(fut) self.websocket_tasks.remove(fut)
await ws.close() if cancelled:
ws.end_connection(1000)
else:
await ws.close()
# -------------------------------------------------------------------- # # -------------------------------------------------------------------- #
# Testing # Testing

View File

@ -10,7 +10,7 @@ from sanic.exceptions import ServerError
from sanic.models.asgi import ASGIReceive, ASGIScope, ASGISend, MockTransport from sanic.models.asgi import ASGIReceive, ASGIScope, ASGISend, MockTransport
from sanic.request import Request from sanic.request import Request
from sanic.server import ConnInfo from sanic.server import ConnInfo
from sanic.websocket import WebSocketConnection from sanic.server.websockets.connection import WebSocketConnection
class Lifespan: class Lifespan:

View File

@ -35,12 +35,9 @@ DEFAULT_CONFIG = {
"REQUEST_MAX_SIZE": 100000000, # 100 megabytes "REQUEST_MAX_SIZE": 100000000, # 100 megabytes
"REQUEST_TIMEOUT": 60, # 60 seconds "REQUEST_TIMEOUT": 60, # 60 seconds
"RESPONSE_TIMEOUT": 60, # 60 seconds "RESPONSE_TIMEOUT": 60, # 60 seconds
"WEBSOCKET_MAX_QUEUE": 32,
"WEBSOCKET_MAX_SIZE": 2 ** 20, # 1 megabyte "WEBSOCKET_MAX_SIZE": 2 ** 20, # 1 megabyte
"WEBSOCKET_PING_INTERVAL": 20, "WEBSOCKET_PING_INTERVAL": 20,
"WEBSOCKET_PING_TIMEOUT": 20, "WEBSOCKET_PING_TIMEOUT": 20,
"WEBSOCKET_READ_LIMIT": 2 ** 16,
"WEBSOCKET_WRITE_LIMIT": 2 ** 16,
} }
@ -62,12 +59,9 @@ class Config(dict):
REQUEST_MAX_SIZE: int REQUEST_MAX_SIZE: int
REQUEST_TIMEOUT: int REQUEST_TIMEOUT: int
RESPONSE_TIMEOUT: int RESPONSE_TIMEOUT: int
WEBSOCKET_MAX_QUEUE: int
WEBSOCKET_MAX_SIZE: int WEBSOCKET_MAX_SIZE: int
WEBSOCKET_PING_INTERVAL: int WEBSOCKET_PING_INTERVAL: int
WEBSOCKET_PING_TIMEOUT: int WEBSOCKET_PING_TIMEOUT: int
WEBSOCKET_READ_LIMIT: int
WEBSOCKET_WRITE_LIMIT: int
def __init__( def __init__(
self, self,

View File

@ -121,8 +121,11 @@ class RouteMixin:
"Expected either string or Iterable of host strings, " "Expected either string or Iterable of host strings, "
"not %s" % host "not %s" % host
) )
if isinstance(subprotocols, list):
if isinstance(subprotocols, (list, tuple, set)): # Ordered subprotocols, maintain order
subprotocols = tuple(subprotocols)
elif isinstance(subprotocols, set):
# subprotocol is unordered, keep it unordered
subprotocols = frozenset(subprotocols) subprotocols = frozenset(subprotocols)
route = FutureRoute( route = FutureRoute(

View File

@ -3,7 +3,7 @@ import asyncio
from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union
from sanic.exceptions import InvalidUsage from sanic.exceptions import InvalidUsage
from sanic.websocket import WebSocketConnection from sanic.server.websockets.connection import WebSocketConnection
ASGIScope = MutableMapping[str, Any] ASGIScope = MutableMapping[str, Any]

View File

@ -0,0 +1,165 @@
from typing import TYPE_CHECKING, Optional, Sequence
from websockets.connection import CLOSED, CLOSING, OPEN
from websockets.server import ServerConnection
from sanic.exceptions import ServerError
from sanic.log import error_logger
from sanic.server import HttpProtocol
from ..websockets.impl import WebsocketImplProtocol
if TYPE_CHECKING:
from websockets import http11
class WebSocketProtocol(HttpProtocol):
websocket: Optional[WebsocketImplProtocol]
websocket_timeout: float
websocket_max_size = Optional[int]
websocket_ping_interval = Optional[float]
websocket_ping_timeout = Optional[float]
def __init__(
self,
*args,
websocket_timeout: float = 10.0,
websocket_max_size: Optional[int] = None,
websocket_max_queue: Optional[int] = None, # max_queue is deprecated
websocket_read_limit: Optional[int] = None, # read_limit is deprecated
websocket_write_limit: Optional[int] = None, # write_limit deprecated
websocket_ping_interval: Optional[float] = 20.0,
websocket_ping_timeout: Optional[float] = 20.0,
**kwargs,
):
super().__init__(*args, **kwargs)
self.websocket = None
self.websocket_timeout = websocket_timeout
self.websocket_max_size = websocket_max_size
if websocket_max_queue is not None and websocket_max_queue > 0:
# TODO: Reminder remove this warning in v22.3
error_logger.warning(
DeprecationWarning(
"Websocket no longer uses queueing, so websocket_max_queue"
" is no longer required."
)
)
if websocket_read_limit is not None and websocket_read_limit > 0:
# TODO: Reminder remove this warning in v22.3
error_logger.warning(
DeprecationWarning(
"Websocket no longer uses read buffers, so "
"websocket_read_limit is not required."
)
)
if websocket_write_limit is not None and websocket_write_limit > 0:
# TODO: Reminder remove this warning in v22.3
error_logger.warning(
DeprecationWarning(
"Websocket no longer uses write buffers, so "
"websocket_write_limit is not required."
)
)
self.websocket_ping_interval = websocket_ping_interval
self.websocket_ping_timeout = websocket_ping_timeout
def connection_lost(self, exc):
if self.websocket is not None:
self.websocket.connection_lost(exc)
super().connection_lost(exc)
def data_received(self, data):
if self.websocket is not None:
self.websocket.data_received(data)
else:
# Pass it to HttpProtocol handler first
# That will (hopefully) upgrade it to a websocket.
super().data_received(data)
def eof_received(self) -> Optional[bool]:
if self.websocket is not None:
return self.websocket.eof_received()
else:
return False
def close(self, timeout: Optional[float] = None):
# Called by HttpProtocol at the end of connection_task
# If we've upgraded to websocket, we do our own closing
if self.websocket is not None:
# Note, we don't want to use websocket.close()
# That is used for user's application code to send a
# websocket close packet. This is different.
self.websocket.end_connection(1001)
else:
super().close()
def close_if_idle(self):
# Called by Sanic Server when shutting down
# If we've upgraded to websocket, shut it down
if self.websocket is not None:
if self.websocket.connection.state in (CLOSING, CLOSED):
return True
elif self.websocket.loop is not None:
self.websocket.loop.create_task(self.websocket.close(1001))
else:
self.websocket.end_connection(1001)
else:
return super().close_if_idle()
async def websocket_handshake(
self, request, subprotocols=Optional[Sequence[str]]
):
# let the websockets package do the handshake with the client
try:
if subprotocols is not None:
# subprotocols can be a set or frozenset,
# but ServerConnection needs a list
subprotocols = list(subprotocols)
ws_conn = ServerConnection(
max_size=self.websocket_max_size,
subprotocols=subprotocols,
state=OPEN,
logger=error_logger,
)
resp: "http11.Response" = ws_conn.accept(request)
except Exception:
msg = (
"Failed to open a WebSocket connection.\n"
"See server log for more information.\n"
)
raise ServerError(msg, status_code=500)
if 100 <= resp.status_code <= 299:
rbody = "".join(
[
"HTTP/1.1 ",
str(resp.status_code),
" ",
resp.reason_phrase,
"\r\n",
]
)
rbody += "".join(f"{k}: {v}\r\n" for k, v in resp.headers.items())
if resp.body is not None:
rbody += f"\r\n{resp.body}\r\n\r\n"
else:
rbody += "\r\n"
await super().send(rbody.encode())
else:
raise ServerError(resp.body, resp.status_code)
self.websocket = WebsocketImplProtocol(
ws_conn,
ping_interval=self.websocket_ping_interval,
ping_timeout=self.websocket_ping_timeout,
close_timeout=self.websocket_timeout,
)
loop = (
request.transport.loop
if hasattr(request, "transport")
and hasattr(request.transport, "loop")
else None
)
await self.websocket.connection_made(self, loop=loop)
return self.websocket

View File

@ -175,15 +175,11 @@ def serve(
# Force close non-idle connection after waiting for # Force close non-idle connection after waiting for
# graceful_shutdown_timeout # graceful_shutdown_timeout
coros = []
for conn in connections: for conn in connections:
if hasattr(conn, "websocket") and conn.websocket: if hasattr(conn, "websocket") and conn.websocket:
coros.append(conn.websocket.close_connection()) conn.websocket.fail_connection(code=1001)
else: else:
conn.abort() conn.abort()
_shutdown = asyncio.gather(*coros)
loop.run_until_complete(_shutdown)
loop.run_until_complete(app._server_event("shutdown", "after")) loop.run_until_complete(app._server_event("shutdown", "after"))
remove_unix_socket(unix) remove_unix_socket(unix)
@ -278,9 +274,6 @@ def _build_protocol_kwargs(
if hasattr(protocol, "websocket_handshake"): if hasattr(protocol, "websocket_handshake"):
return { return {
"websocket_max_size": config.WEBSOCKET_MAX_SIZE, "websocket_max_size": config.WEBSOCKET_MAX_SIZE,
"websocket_max_queue": config.WEBSOCKET_MAX_QUEUE,
"websocket_read_limit": config.WEBSOCKET_READ_LIMIT,
"websocket_write_limit": config.WEBSOCKET_WRITE_LIMIT,
"websocket_ping_timeout": config.WEBSOCKET_PING_TIMEOUT, "websocket_ping_timeout": config.WEBSOCKET_PING_TIMEOUT,
"websocket_ping_interval": config.WEBSOCKET_PING_INTERVAL, "websocket_ping_interval": config.WEBSOCKET_PING_INTERVAL,
} }

View File

View File

@ -0,0 +1,82 @@
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
MutableMapping,
Optional,
Union,
)
ASIMessage = MutableMapping[str, Any]
class WebSocketConnection:
"""
This is for ASGI Connections.
It provides an interface similar to WebsocketProtocol, but
sends/receives over an ASGI connection.
"""
# TODO
# - Implement ping/pong
def __init__(
self,
send: Callable[[ASIMessage], Awaitable[None]],
receive: Callable[[], Awaitable[ASIMessage]],
subprotocols: Optional[List[str]] = None,
) -> None:
self._send = send
self._receive = receive
self._subprotocols = subprotocols or []
async def send(self, data: Union[str, bytes], *args, **kwargs) -> None:
message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"}
if isinstance(data, bytes):
message.update({"bytes": data})
else:
message.update({"text": str(data)})
await self._send(message)
async def recv(self, *args, **kwargs) -> Optional[str]:
message = await self._receive()
if message["type"] == "websocket.receive":
return message["text"]
elif message["type"] == "websocket.disconnect":
pass
return None
receive = recv
async def accept(self, subprotocols: Optional[List[str]] = None) -> None:
subprotocol = None
if subprotocols:
for subp in subprotocols:
if subp in self.subprotocols:
subprotocol = subp
break
await self._send(
{
"type": "websocket.accept",
"subprotocol": subprotocol,
}
)
async def close(self, code: int = 1000, reason: str = "") -> None:
pass
@property
def subprotocols(self):
return self._subprotocols
@subprotocols.setter
def subprotocols(self, subprotocols: Optional[List[str]] = None):
self._subprotocols = subprotocols or []

View File

@ -0,0 +1,297 @@
import asyncio
import codecs
from typing import TYPE_CHECKING, AsyncIterator, List, Optional
from websockets.frames import Frame, Opcode
from websockets.typing import Data
from sanic.exceptions import ServerError
if TYPE_CHECKING:
from .impl import WebsocketImplProtocol
UTF8Decoder = codecs.getincrementaldecoder("utf-8")
class WebsocketFrameAssembler:
"""
Assemble a message from frames.
Code borrowed from aaugustin/websockets project:
https://github.com/aaugustin/websockets/blob/6eb98dd8fa5b2c896b9f6be7e8d117708da82a39/src/websockets/sync/messages.py
"""
__slots__ = (
"protocol",
"read_mutex",
"write_mutex",
"message_complete",
"message_fetched",
"get_in_progress",
"decoder",
"completed_queue",
"chunks",
"chunks_queue",
"paused",
"get_id",
"put_id",
)
if TYPE_CHECKING:
protocol: "WebsocketImplProtocol"
read_mutex: asyncio.Lock
write_mutex: asyncio.Lock
message_complete: asyncio.Event
message_fetched: asyncio.Event
completed_queue: asyncio.Queue
get_in_progress: bool
decoder: Optional[codecs.IncrementalDecoder]
# For streaming chunks rather than messages:
chunks: List[Data]
chunks_queue: Optional[asyncio.Queue[Optional[Data]]]
paused: bool
def __init__(self, protocol) -> None:
self.protocol = protocol
self.read_mutex = asyncio.Lock()
self.write_mutex = asyncio.Lock()
self.completed_queue = asyncio.Queue(
maxsize=1
) # type: asyncio.Queue[Data]
# put() sets this event to tell get() that a message can be fetched.
self.message_complete = asyncio.Event()
# get() sets this event to let put()
self.message_fetched = asyncio.Event()
# This flag prevents concurrent calls to get() by user code.
self.get_in_progress = False
# Decoder for text frames, None for binary frames.
self.decoder = None
# Buffer data from frames belonging to the same message.
self.chunks = []
# When switching from "buffering" to "streaming", we use a thread-safe
# queue for transferring frames from the writing thread (library code)
# to the reading thread (user code). We're buffering when chunks_queue
# is None and streaming when it's a Queue. None is a sentinel
# value marking the end of the stream, superseding message_complete.
# Stream data from frames belonging to the same message.
self.chunks_queue = None
# Flag to indicate we've paused the protocol
self.paused = False
async def get(self, timeout: Optional[float] = None) -> Optional[Data]:
"""
Read the next message.
:meth:`get` returns a single :class:`str` or :class:`bytes`.
If the :message was fragmented, :meth:`get` waits until the last frame
is received, then it reassembles the message.
If ``timeout`` is set and elapses before a complete message is
received, :meth:`get` returns ``None``.
"""
async with self.read_mutex:
if timeout is not None and timeout <= 0:
if not self.message_complete.is_set():
return None
if self.get_in_progress:
# This should be guarded against with the read_mutex,
# exception is only here as a failsafe
raise ServerError(
"Called get() on Websocket frame assembler "
"while asynchronous get is already in progress."
)
self.get_in_progress = True
# If the message_complete event isn't set yet, release the lock to
# allow put() to run and eventually set it.
# Locking with get_in_progress ensures only one task can get here.
if timeout is None:
completed = await self.message_complete.wait()
elif timeout <= 0:
completed = self.message_complete.is_set()
else:
try:
await asyncio.wait_for(
self.message_complete.wait(), timeout=timeout
)
except asyncio.TimeoutError:
...
finally:
completed = self.message_complete.is_set()
# Unpause the transport, if its paused
if self.paused:
self.protocol.resume_frames()
self.paused = False
if not self.get_in_progress:
# This should be guarded against with the read_mutex,
# exception is here as a failsafe
raise ServerError(
"State of Websocket frame assembler was modified while an "
"asynchronous get was in progress."
)
self.get_in_progress = False
# Waiting for a complete message timed out.
if not completed:
return None
if not self.message_complete.is_set():
return None
self.message_complete.clear()
joiner: Data = b"" if self.decoder is None else ""
# mypy cannot figure out that chunks have the proper type.
message: Data = joiner.join(self.chunks) # type: ignore
if self.message_fetched.is_set():
# This should be guarded against with the read_mutex,
# and get_in_progress check, this exception is here
# as a failsafe
raise ServerError(
"Websocket get() found a message when "
"state was already fetched."
)
self.message_fetched.set()
self.chunks = []
self.chunks_queue = (
None # this should already be None, but set it here for safety
)
return message
async def get_iter(self) -> AsyncIterator[Data]:
"""
Stream the next message.
Iterating the return value of :meth:`get_iter` yields a :class:`str`
or :class:`bytes` for each frame in the message.
"""
async with self.read_mutex:
if self.get_in_progress:
# This should be guarded against with the read_mutex,
# exception is only here as a failsafe
raise ServerError(
"Called get_iter on Websocket frame assembler "
"while asynchronous get is already in progress."
)
self.get_in_progress = True
chunks = self.chunks
self.chunks = []
self.chunks_queue = asyncio.Queue()
# Sending None in chunk_queue supersedes setting message_complete
# when switching to "streaming". If message is already complete
# when the switch happens, put() didn't send None, so we have to.
if self.message_complete.is_set():
await self.chunks_queue.put(None)
# Locking with get_in_progress ensures only one thread can get here
for c in chunks:
yield c
while True:
chunk = await self.chunks_queue.get()
if chunk is None:
break
yield chunk
# Unpause the transport, if its paused
if self.paused:
self.protocol.resume_frames()
self.paused = False
if not self.get_in_progress:
# This should be guarded against with the read_mutex,
# exception is here as a failsafe
raise ServerError(
"State of Websocket frame assembler was modified while an "
"asynchronous get was in progress."
)
self.get_in_progress = False
if not self.message_complete.is_set():
# This should be guarded against with the read_mutex,
# exception is here as a failsafe
raise ServerError(
"Websocket frame assembler chunks queue ended before "
"message was complete."
)
self.message_complete.clear()
if self.message_fetched.is_set():
# This should be guarded against with the read_mutex,
# and get_in_progress check, this exception is
# here as a failsafe
raise ServerError(
"Websocket get_iter() found a message when state was "
"already fetched."
)
self.message_fetched.set()
self.chunks = (
[]
) # this should already be empty, but set it here for safety
self.chunks_queue = None
async def put(self, frame: Frame) -> None:
"""
Add ``frame`` to the next message.
When ``frame`` is the final frame in a message, :meth:`put` waits
until the message is fetched, either by calling :meth:`get` or by
iterating the return value of :meth:`get_iter`.
:meth:`put` assumes that the stream of frames respects the protocol.
If it doesn't, the behavior is undefined.
"""
async with self.write_mutex:
if frame.opcode is Opcode.TEXT:
self.decoder = UTF8Decoder(errors="strict")
elif frame.opcode is Opcode.BINARY:
self.decoder = None
elif frame.opcode is Opcode.CONT:
pass
else:
# Ignore control frames.
return
data: Data
if self.decoder is not None:
data = self.decoder.decode(frame.data, frame.fin)
else:
data = frame.data
if self.chunks_queue is None:
self.chunks.append(data)
else:
await self.chunks_queue.put(data)
if not frame.fin:
return
if not self.get_in_progress:
# nobody is waiting for this frame, so try to pause subsequent
# frames at the protocol level
self.paused = self.protocol.pause_frames()
# Message is complete. Wait until it's fetched to return.
if self.chunks_queue is not None:
await self.chunks_queue.put(None)
if self.message_complete.is_set():
# This should be guarded against with the write_mutex
raise ServerError(
"Websocket put() got a new message when a message was "
"already in its chamber."
)
self.message_complete.set() # Signal to get() it can serve the
if self.message_fetched.is_set():
# This should be guarded against with the write_mutex
raise ServerError(
"Websocket put() got a new message when the previous "
"message was not yet fetched."
)
# Allow get() to run and eventually set the event.
await self.message_fetched.wait()
self.message_fetched.clear()
self.decoder = None

View File

@ -0,0 +1,789 @@
import asyncio
import random
import struct
from typing import (
AsyncIterator,
Dict,
Iterable,
Mapping,
Optional,
Sequence,
Union,
)
from websockets.connection import CLOSED, CLOSING, OPEN, Event
from websockets.exceptions import ConnectionClosed, ConnectionClosedError
from websockets.frames import OP_PONG, Frame
from websockets.server import ServerConnection
from websockets.typing import Data
from sanic.log import error_logger, logger
from sanic.server.protocols.base_protocol import SanicProtocol
from ...exceptions import ServerError
from .frame import WebsocketFrameAssembler
class WebsocketImplProtocol:
connection: ServerConnection
io_proto: Optional[SanicProtocol]
loop: Optional[asyncio.AbstractEventLoop]
max_queue: int
close_timeout: float
ping_interval: Optional[float]
ping_timeout: Optional[float]
assembler: WebsocketFrameAssembler
# Dict[bytes, asyncio.Future[None]]
pings: Dict[bytes, asyncio.Future]
conn_mutex: asyncio.Lock
recv_lock: asyncio.Lock
process_event_mutex: asyncio.Lock
can_pause: bool
# Optional[asyncio.Future[None]]
data_finished_fut: Optional[asyncio.Future]
# Optional[asyncio.Future[None]]
pause_frame_fut: Optional[asyncio.Future]
# Optional[asyncio.Future[None]]
connection_lost_waiter: Optional[asyncio.Future]
keepalive_ping_task: Optional[asyncio.Task]
auto_closer_task: Optional[asyncio.Task]
def __init__(
self,
connection,
max_queue=None,
ping_interval: Optional[float] = 20,
ping_timeout: Optional[float] = 20,
close_timeout: float = 10,
loop=None,
):
self.connection = connection
self.io_proto = None
self.loop = None
self.max_queue = max_queue
self.close_timeout = close_timeout
self.ping_interval = ping_interval
self.ping_timeout = ping_timeout
self.assembler = WebsocketFrameAssembler(self)
self.pings = {}
self.conn_mutex = asyncio.Lock()
self.recv_lock = asyncio.Lock()
self.process_event_mutex = asyncio.Lock()
self.data_finished_fut = None
self.can_pause = True
self.pause_frame_fut = None
self.keepalive_ping_task = None
self.auto_closer_task = None
self.connection_lost_waiter = None
@property
def subprotocol(self):
return self.connection.subprotocol
def pause_frames(self):
if not self.can_pause:
return False
if self.pause_frame_fut:
logger.debug("Websocket connection already paused.")
return False
if (not self.loop) or (not self.io_proto):
return False
if self.io_proto.transport:
self.io_proto.transport.pause_reading()
self.pause_frame_fut = self.loop.create_future()
logger.debug("Websocket connection paused.")
return True
def resume_frames(self):
if not self.pause_frame_fut:
logger.debug("Websocket connection not paused.")
return False
if (not self.loop) or (not self.io_proto):
logger.debug(
"Websocket attempting to resume reading frames, "
"but connection is gone."
)
return False
if self.io_proto.transport:
self.io_proto.transport.resume_reading()
self.pause_frame_fut.set_result(None)
self.pause_frame_fut = None
logger.debug("Websocket connection unpaused.")
return True
async def connection_made(
self,
io_proto: SanicProtocol,
loop: Optional[asyncio.AbstractEventLoop] = None,
):
if not loop:
try:
loop = getattr(io_proto, "loop")
except AttributeError:
loop = asyncio.get_event_loop()
if not loop:
# This catch is for mypy type checker
# to assert loop is not None here.
raise ServerError("Connection received with no asyncio loop.")
if self.auto_closer_task:
raise ServerError(
"Cannot call connection_made more than once "
"on a websocket connection."
)
self.loop = loop
self.io_proto = io_proto
self.connection_lost_waiter = self.loop.create_future()
self.data_finished_fut = asyncio.shield(self.loop.create_future())
if self.ping_interval:
self.keepalive_ping_task = asyncio.create_task(
self.keepalive_ping()
)
self.auto_closer_task = asyncio.create_task(
self.auto_close_connection()
)
async def wait_for_connection_lost(self, timeout=None) -> bool:
"""
Wait until the TCP connection is closed or ``timeout`` elapses.
If timeout is None, wait forever.
Recommend you should pass in self.close_timeout as timeout
Return ``True`` if the connection is closed and ``False`` otherwise.
"""
if not self.connection_lost_waiter:
return False
if self.connection_lost_waiter.done():
return True
else:
try:
await asyncio.wait_for(
asyncio.shield(self.connection_lost_waiter), timeout
)
return True
except asyncio.TimeoutError:
# Re-check self.connection_lost_waiter.done() synchronously
# because connection_lost() could run between the moment the
# timeout occurs and the moment this coroutine resumes running
return self.connection_lost_waiter.done()
async def process_events(self, events: Sequence[Event]) -> None:
"""
Process a list of incoming events.
"""
# Wrapped in a mutex lock, to prevent other incoming events
# from processing at the same time
async with self.process_event_mutex:
for event in events:
if not isinstance(event, Frame):
# Event is not a frame. Ignore it.
continue
if event.opcode == OP_PONG:
await self.process_pong(event)
else:
await self.assembler.put(event)
async def process_pong(self, frame: "Frame") -> None:
if frame.data in self.pings:
# Acknowledge all pings up to the one matching this pong.
ping_ids = []
for ping_id, ping in self.pings.items():
ping_ids.append(ping_id)
if not ping.done():
ping.set_result(None)
if ping_id == frame.data:
break
else: # noqa
raise ServerError("ping_id is not in self.pings")
# Remove acknowledged pings from self.pings.
for ping_id in ping_ids:
del self.pings[ping_id]
async def keepalive_ping(self) -> None:
"""
Send a Ping frame and wait for a Pong frame at regular intervals.
This coroutine exits when the connection terminates and one of the
following happens:
- :meth:`ping` raises :exc:`ConnectionClosed`, or
- :meth:`auto_close_connection` cancels :attr:`keepalive_ping_task`.
"""
if self.ping_interval is None:
return
try:
while True:
await asyncio.sleep(self.ping_interval)
# ping() raises CancelledError if the connection is closed,
# when auto_close_connection() cancels keepalive_ping_task.
# ping() raises ConnectionClosed if the connection is lost,
# when connection_lost() calls abort_pings().
ping_waiter = await self.ping()
if self.ping_timeout is not None:
try:
await asyncio.wait_for(ping_waiter, self.ping_timeout)
except asyncio.TimeoutError:
error_logger.warning(
"Websocket timed out waiting for pong"
)
self.fail_connection(1011)
break
except asyncio.CancelledError:
# It is expected for this task to be cancelled during during
# normal operation, when the connection is closed.
logger.debug("Websocket keepalive ping task was cancelled.")
except ConnectionClosed:
logger.debug("Websocket closed. Keepalive ping task exiting.")
except Exception as e:
error_logger.warning(
"Unexpected exception in websocket keepalive ping task."
)
logger.debug(str(e))
def _force_disconnect(self) -> bool:
"""
Internal methdod used by end_connection and fail_connection
only when the graceful auto-closer cannot be used
"""
if self.auto_closer_task and not self.auto_closer_task.done():
self.auto_closer_task.cancel()
if self.data_finished_fut and not self.data_finished_fut.done():
self.data_finished_fut.cancel()
self.data_finished_fut = None
if self.keepalive_ping_task and not self.keepalive_ping_task.done():
self.keepalive_ping_task.cancel()
self.keepalive_ping_task = None
if self.loop and self.io_proto and self.io_proto.transport:
self.io_proto.transport.close()
self.loop.call_later(
self.close_timeout, self.io_proto.transport.abort
)
# We were never open, or already closed
return True
def fail_connection(self, code: int = 1006, reason: str = "") -> bool:
"""
Fail the WebSocket Connection
This requires:
1. Stopping all processing of incoming data, which means cancelling
pausing the underlying io protocol. The close code will be 1006
unless a close frame was received earlier.
2. Sending a close frame with an appropriate code if the opening
handshake succeeded and the other side is likely to process it.
3. Closing the connection. :meth:`auto_close_connection` takes care
of this.
(The specification describes these steps in the opposite order.)
"""
if self.io_proto and self.io_proto.transport:
# Stop new data coming in
# In Python Version 3.7: pause_reading is idempotent
# ut can be called when the transport is already paused or closed
self.io_proto.transport.pause_reading()
# Keeping fail_connection() synchronous guarantees it can't
# get stuck and simplifies the implementation of the callers.
# Not draining the write buffer is acceptable in this context.
# clear the send buffer
_ = self.connection.data_to_send()
# If we're not already CLOSED or CLOSING, then send the close.
if self.connection.state is OPEN:
if code in (1000, 1001):
self.connection.send_close(code, reason)
else:
self.connection.fail(code, reason)
try:
data_to_send = self.connection.data_to_send()
while (
len(data_to_send)
and self.io_proto
and self.io_proto.transport
):
frame_data = data_to_send.pop(0)
self.io_proto.transport.write(frame_data)
except Exception:
# sending close frames may fail if the
# transport closes during this period
...
if code == 1006:
# Special case: 1006 consider the transport already closed
self.connection.state = CLOSED
if self.data_finished_fut and not self.data_finished_fut.done():
# We have a graceful auto-closer. Use it to close the connection.
self.data_finished_fut.cancel()
self.data_finished_fut = None
if (not self.auto_closer_task) or self.auto_closer_task.done():
return self._force_disconnect()
return False
def end_connection(self, code=1000, reason=""):
# This is like slightly more graceful form of fail_connection
# Use this instead of close() when you need an immediate
# close and cannot await websocket.close() handshake.
if code == 1006 or not self.io_proto or not self.io_proto.transport:
return self.fail_connection(code, reason)
# Stop new data coming in
# In Python Version 3.7: pause_reading is idempotent
# i.e. it can be called when the transport is already paused or closed.
self.io_proto.transport.pause_reading()
if self.connection.state == OPEN:
data_to_send = self.connection.data_to_send()
self.connection.send_close(code, reason)
data_to_send.extend(self.connection.data_to_send())
try:
while (
len(data_to_send)
and self.io_proto
and self.io_proto.transport
):
frame_data = data_to_send.pop(0)
self.io_proto.transport.write(frame_data)
except Exception:
# sending close frames may fail if the
# transport closes during this period
# But that doesn't matter at this point
...
if self.data_finished_fut and not self.data_finished_fut.done():
# We have the ability to signal the auto-closer
# try to trigger it to auto-close the connection
self.data_finished_fut.cancel()
self.data_finished_fut = None
if (not self.auto_closer_task) or self.auto_closer_task.done():
# Auto-closer is not running, do force disconnect
return self._force_disconnect()
return False
async def auto_close_connection(self) -> None:
"""
Close the WebSocket Connection
When the opening handshake succeeds, :meth:`connection_open` starts
this coroutine in a task. It waits for the data transfer phase to
complete then it closes the TCP connection cleanly.
When the opening handshake fails, :meth:`fail_connection` does the
same. There's no data transfer phase in that case.
"""
try:
# Wait for the data transfer phase to complete.
if self.data_finished_fut:
try:
await self.data_finished_fut
logger.debug(
"Websocket task finished. Closing the connection."
)
except asyncio.CancelledError:
# Cancelled error is called when data phase is cancelled
# if an error occurred or the client closed the connection
logger.debug(
"Websocket handler cancelled. Closing the connection."
)
# Cancel the keepalive ping task.
if self.keepalive_ping_task:
self.keepalive_ping_task.cancel()
self.keepalive_ping_task = None
# Half-close the TCP connection if possible (when there's no TLS).
if (
self.io_proto
and self.io_proto.transport
and self.io_proto.transport.can_write_eof()
):
logger.debug("Websocket half-closing TCP connection")
self.io_proto.transport.write_eof()
if self.connection_lost_waiter:
if await self.wait_for_connection_lost(timeout=0):
return
except asyncio.CancelledError:
...
finally:
# The try/finally ensures that the transport never remains open,
# even if this coroutine is cancelled (for example).
if (not self.io_proto) or (not self.io_proto.transport):
# we were never open, or done. Can't do any finalization.
return
elif (
self.connection_lost_waiter
and self.connection_lost_waiter.done()
):
# connection confirmed closed already, proceed to abort waiter
...
elif self.io_proto.transport.is_closing():
# Connection is already closing (due to half-close above)
# proceed to abort waiter
...
else:
self.io_proto.transport.close()
if not self.connection_lost_waiter:
# Our connection monitor task isn't running.
try:
await asyncio.sleep(self.close_timeout)
except asyncio.CancelledError:
...
if self.io_proto and self.io_proto.transport:
self.io_proto.transport.abort()
else:
if await self.wait_for_connection_lost(
timeout=self.close_timeout
):
# Connection aborted before the timeout expired.
return
error_logger.warning(
"Timeout waiting for TCP connection to close. Aborting"
)
if self.io_proto and self.io_proto.transport:
self.io_proto.transport.abort()
def abort_pings(self) -> None:
"""
Raise ConnectionClosed in pending keepalive pings.
They'll never receive a pong once the connection is closed.
"""
if self.connection.state is not CLOSED:
raise ServerError(
"Webscoket about_pings should only be called "
"after connection state is changed to CLOSED"
)
for ping in self.pings.values():
ping.set_exception(ConnectionClosedError(None, None))
# If the exception is never retrieved, it will be logged when ping
# is garbage-collected. This is confusing for users.
# Given that ping is done (with an exception), canceling it does
# nothing, but it prevents logging the exception.
ping.cancel()
async def close(self, code: int = 1000, reason: str = "") -> None:
"""
Perform the closing handshake.
This is a websocket-protocol level close.
:meth:`close` waits for the other end to complete the handshake and
for the TCP connection to terminate.
:meth:`close` is idempotent: it doesn't do anything once the
connection is closed.
:param code: WebSocket close code
:param reason: WebSocket close reason
"""
if code == 1006:
self.fail_connection(code, reason)
return
async with self.conn_mutex:
if self.connection.state is OPEN:
self.connection.send_close(code, reason)
data_to_send = self.connection.data_to_send()
await self.send_data(data_to_send)
async def recv(self, timeout: Optional[float] = None) -> Optional[Data]:
"""
Receive the next message.
Return a :class:`str` for a text frame and :class:`bytes` for a binary
frame.
When the end of the message stream is reached, :meth:`recv` raises
:exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it
raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal
connection closure and
:exc:`~websockets.exceptions.ConnectionClosedError` after a protocol
error or a network failure.
If ``timeout`` is ``None``, block until a message is received. Else,
if no message is received within ``timeout`` seconds, return ``None``.
Set ``timeout`` to ``0`` to check if a message was already received.
:raises ~websockets.exceptions.ConnectionClosed: when the
connection is closed
:raises ServerError: if two tasks call :meth:`recv` or
:meth:`recv_streaming` concurrently
"""
if self.recv_lock.locked():
raise ServerError(
"cannot call recv while another task is "
"already waiting for the next message"
)
await self.recv_lock.acquire()
if self.connection.state in (CLOSED, CLOSING):
raise ServerError(
"Cannot receive from websocket interface after it is closed."
)
try:
return await self.assembler.get(timeout)
finally:
self.recv_lock.release()
async def recv_burst(self, max_recv=256) -> Sequence[Data]:
"""
Receive the messages which have arrived since last checking.
Return a :class:`list` containing :class:`str` for a text frame
and :class:`bytes` for a binary frame.
When the end of the message stream is reached, :meth:`recv_burst`
raises :exc:`~websockets.exceptions.ConnectionClosed`. Specifically,
it raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a
normal connection closure and
:exc:`~websockets.exceptions.ConnectionClosedError` after a protocol
error or a network failure.
:raises ~websockets.exceptions.ConnectionClosed: when the
connection is closed
:raises ServerError: if two tasks call :meth:`recv_burst` or
:meth:`recv_streaming` concurrently
"""
if self.recv_lock.locked():
raise ServerError(
"cannot call recv_burst while another task is already waiting "
"for the next message"
)
await self.recv_lock.acquire()
if self.connection.state in (CLOSED, CLOSING):
raise ServerError(
"Cannot receive from websocket interface after it is closed."
)
messages = []
try:
# Prevent pausing the transport when we're
# receiving a burst of messages
self.can_pause = False
while True:
m = await self.assembler.get(timeout=0)
if m is None:
# None left in the burst. This is good!
break
messages.append(m)
if len(messages) >= max_recv:
# Too much data in the pipe. Hit our burst limit.
break
# Allow an eventloop iteration for the
# next message to pass into the Assembler
await asyncio.sleep(0)
finally:
self.can_pause = True
self.recv_lock.release()
return messages
async def recv_streaming(self) -> AsyncIterator[Data]:
"""
Receive the next message frame by frame.
Return an iterator of :class:`str` for a text frame and :class:`bytes`
for a binary frame. The iterator should be exhausted, or else the
connection will become unusable.
With the exception of the return value, :meth:`recv_streaming` behaves
like :meth:`recv`.
"""
if self.recv_lock.locked():
raise ServerError(
"Cannot call recv_streaming while another task "
"is already waiting for the next message"
)
await self.recv_lock.acquire()
if self.connection.state in (CLOSED, CLOSING):
raise ServerError(
"Cannot receive from websocket interface after it is closed."
)
try:
self.can_pause = False
async for m in self.assembler.get_iter():
yield m
finally:
self.can_pause = True
self.recv_lock.release()
async def send(self, message: Union[Data, Iterable[Data]]) -> None:
"""
Send a message.
A string (:class:`str`) is sent as a `Text frame`_. A bytestring or
bytes-like object (:class:`bytes`, :class:`bytearray`, or
:class:`memoryview`) is sent as a `Binary frame`_.
.. _Text frame: https://tools.ietf.org/html/rfc6455#section-5.6
.. _Binary frame: https://tools.ietf.org/html/rfc6455#section-5.6
:meth:`send` also accepts an iterable of strings, bytestrings, or
bytes-like objects. In that case the message is fragmented. Each item
is treated as a message fragment and sent in its own frame. All items
must be of the same type, or else :meth:`send` will raise a
:exc:`TypeError` and the connection will be closed.
:meth:`send` rejects dict-like objects because this is often an error.
If you wish to send the keys of a dict-like object as fragments, call
its :meth:`~dict.keys` method and pass the result to :meth:`send`.
:raises TypeError: for unsupported inputs
"""
async with self.conn_mutex:
if self.connection.state in (CLOSED, CLOSING):
raise ServerError(
"Cannot write to websocket interface after it is closed."
)
if (not self.data_finished_fut) or self.data_finished_fut.done():
raise ServerError(
"Cannot write to websocket interface after it is finished."
)
# Unfragmented message -- this case must be handled first because
# strings and bytes-like objects are iterable.
if isinstance(message, str):
self.connection.send_text(message.encode("utf-8"))
await self.send_data(self.connection.data_to_send())
elif isinstance(message, (bytes, bytearray, memoryview)):
self.connection.send_binary(message)
await self.send_data(self.connection.data_to_send())
elif isinstance(message, Mapping):
# Catch a common mistake -- passing a dict to send().
raise TypeError("data is a dict-like object")
elif isinstance(message, Iterable):
# Fragmented message -- regular iterator.
raise NotImplementedError(
"Fragmented websocket messages are not supported."
)
else:
raise TypeError("Websocket data must be bytes, str.")
async def ping(self, data: Optional[Data] = None) -> asyncio.Future:
"""
Send a ping.
Return an :class:`~asyncio.Future` that will be resolved when the
corresponding pong is received. You can ignore it if you don't intend
to wait.
A ping may serve as a keepalive or as a check that the remote endpoint
received all messages up to this point::
await pong_event = ws.ping()
await pong_event # only if you want to wait for the pong
By default, the ping contains four random bytes. This payload may be
overridden with the optional ``data`` argument which must be a string
(which will be encoded to UTF-8) or a bytes-like object.
"""
async with self.conn_mutex:
if self.connection.state in (CLOSED, CLOSING):
raise ServerError(
"Cannot send a ping when the websocket interface "
"is closed."
)
if (not self.io_proto) or (not self.io_proto.loop):
raise ServerError(
"Cannot send a ping when the websocket has no I/O "
"protocol attached."
)
if data is not None:
if isinstance(data, str):
data = data.encode("utf-8")
elif isinstance(data, (bytearray, memoryview)):
data = bytes(data)
# Protect against duplicates if a payload is explicitly set.
if data in self.pings:
raise ValueError(
"already waiting for a pong with the same data"
)
# Generate a unique random payload otherwise.
while data is None or data in self.pings:
data = struct.pack("!I", random.getrandbits(32))
self.pings[data] = self.io_proto.loop.create_future()
self.connection.send_ping(data)
await self.send_data(self.connection.data_to_send())
return asyncio.shield(self.pings[data])
async def pong(self, data: Data = b"") -> None:
"""
Send a pong.
An unsolicited pong may serve as a unidirectional heartbeat.
The payload may be set with the optional ``data`` argument which must
be a string (which will be encoded to UTF-8) or a bytes-like object.
"""
async with self.conn_mutex:
if self.connection.state in (CLOSED, CLOSING):
# Cannot send pong after transport is shutting down
return
if isinstance(data, str):
data = data.encode("utf-8")
elif isinstance(data, (bytearray, memoryview)):
data = bytes(data)
self.connection.send_pong(data)
await self.send_data(self.connection.data_to_send())
async def send_data(self, data_to_send):
for data in data_to_send:
if data:
await self.io_proto.send(data)
else:
# Send an EOF - We don't actually send it,
# just trigger to autoclose the connection
if (
self.auto_closer_task
and not self.auto_closer_task.done()
and self.data_finished_fut
and not self.data_finished_fut.done()
):
# Auto-close the connection
self.data_finished_fut.set_result(None)
else:
# This will fail the connection appropriately
SanicProtocol.close(self.io_proto, timeout=1.0)
async def async_data_received(self, data_to_send, events_to_process):
if self.connection.state == OPEN and len(data_to_send) > 0:
# receiving data can generate data to send (eg, pong for a ping)
# send connection.data_to_send()
await self.send_data(data_to_send)
if len(events_to_process) > 0:
await self.process_events(events_to_process)
def data_received(self, data):
self.connection.receive_data(data)
data_to_send = self.connection.data_to_send()
events_to_process = self.connection.events_received()
if len(data_to_send) > 0 or len(events_to_process) > 0:
asyncio.create_task(
self.async_data_received(data_to_send, events_to_process)
)
async def async_eof_received(self, data_to_send, events_to_process):
# receiving EOF can generate data to send
# send connection.data_to_send()
if self.connection.state == OPEN:
await self.send_data(data_to_send)
if len(events_to_process) > 0:
await self.process_events(events_to_process)
if (
self.auto_closer_task
and not self.auto_closer_task.done()
and self.data_finished_fut
and not self.data_finished_fut.done()
):
# Auto-close the connection
self.data_finished_fut.set_result(None)
else:
# This will fail the connection appropriately
SanicProtocol.close(self.io_proto, timeout=1.0)
def eof_received(self) -> Optional[bool]:
self.connection.receive_eof()
data_to_send = self.connection.data_to_send()
events_to_process = self.connection.events_received()
if len(data_to_send) > 0 or len(events_to_process) > 0:
asyncio.create_task(
self.async_eof_received(data_to_send, events_to_process)
)
return False
def connection_lost(self, exc):
"""
The WebSocket Connection is Closed.
"""
if not self.connection.state == CLOSED:
# signal to the websocket connection handler
# we've lost the connection
self.connection.fail(code=1006)
self.connection.state = CLOSED
self.abort_pings()
if self.connection_lost_waiter:
self.connection_lost_waiter.set_result(None)

View File

@ -1,205 +0,0 @@
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
MutableMapping,
Optional,
Union,
)
from httptools import HttpParserUpgrade # type: ignore
from websockets import ( # type: ignore
ConnectionClosed,
InvalidHandshake,
WebSocketCommonProtocol,
)
# Despite the "legacy" namespace, the primary maintainer of websockets
# committed to maintaining backwards-compatibility until 2026 and will
# consider extending it if sanic continues depending on this module.
from websockets.legacy import handshake
from sanic.exceptions import InvalidUsage
from sanic.server import HttpProtocol
__all__ = ["ConnectionClosed", "WebSocketProtocol", "WebSocketConnection"]
ASIMessage = MutableMapping[str, Any]
class WebSocketProtocol(HttpProtocol):
def __init__(
self,
*args,
websocket_timeout=10,
websocket_max_size=None,
websocket_max_queue=None,
websocket_read_limit=2 ** 16,
websocket_write_limit=2 ** 16,
websocket_ping_interval=20,
websocket_ping_timeout=20,
**kwargs,
):
super().__init__(*args, **kwargs)
self.websocket = None
# self.app = None
self.websocket_timeout = websocket_timeout
self.websocket_max_size = websocket_max_size
self.websocket_max_queue = websocket_max_queue
self.websocket_read_limit = websocket_read_limit
self.websocket_write_limit = websocket_write_limit
self.websocket_ping_interval = websocket_ping_interval
self.websocket_ping_timeout = websocket_ping_timeout
# timeouts make no sense for websocket routes
def request_timeout_callback(self):
if self.websocket is None:
super().request_timeout_callback()
def response_timeout_callback(self):
if self.websocket is None:
super().response_timeout_callback()
def keep_alive_timeout_callback(self):
if self.websocket is None:
super().keep_alive_timeout_callback()
def connection_lost(self, exc):
if self.websocket is not None:
self.websocket.connection_lost(exc)
super().connection_lost(exc)
def data_received(self, data):
if self.websocket is not None:
# pass the data to the websocket protocol
self.websocket.data_received(data)
else:
try:
super().data_received(data)
except HttpParserUpgrade:
# this is okay, it just indicates we've got an upgrade request
pass
def write_response(self, response):
if self.websocket is not None:
# websocket requests do not write a response
self.transport.close()
else:
super().write_response(response)
async def websocket_handshake(self, request, subprotocols=None):
# let the websockets package do the handshake with the client
headers = {}
try:
key = handshake.check_request(request.headers)
handshake.build_response(headers, key)
except InvalidHandshake:
raise InvalidUsage("Invalid websocket request")
subprotocol = None
if subprotocols and "Sec-Websocket-Protocol" in request.headers:
# select a subprotocol
client_subprotocols = [
p.strip()
for p in request.headers["Sec-Websocket-Protocol"].split(",")
]
for p in client_subprotocols:
if p in subprotocols:
subprotocol = p
headers["Sec-Websocket-Protocol"] = subprotocol
break
# write the 101 response back to the client
rv = b"HTTP/1.1 101 Switching Protocols\r\n"
for k, v in headers.items():
rv += k.encode("utf-8") + b": " + v.encode("utf-8") + b"\r\n"
rv += b"\r\n"
request.transport.write(rv)
# hook up the websocket protocol
self.websocket = WebSocketCommonProtocol(
close_timeout=self.websocket_timeout,
max_size=self.websocket_max_size,
max_queue=self.websocket_max_queue,
read_limit=self.websocket_read_limit,
write_limit=self.websocket_write_limit,
ping_interval=self.websocket_ping_interval,
ping_timeout=self.websocket_ping_timeout,
)
# we use WebSocketCommonProtocol because we don't want the handshake
# logic from WebSocketServerProtocol; however, we must tell it that
# we're running on the server side
self.websocket.is_client = False
self.websocket.side = "server"
self.websocket.subprotocol = subprotocol
self.websocket.connection_made(request.transport)
self.websocket.connection_open()
return self.websocket
class WebSocketConnection:
# TODO
# - Implement ping/pong
def __init__(
self,
send: Callable[[ASIMessage], Awaitable[None]],
receive: Callable[[], Awaitable[ASIMessage]],
subprotocols: Optional[List[str]] = None,
) -> None:
self._send = send
self._receive = receive
self._subprotocols = subprotocols or []
async def send(self, data: Union[str, bytes], *args, **kwargs) -> None:
message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"}
if isinstance(data, bytes):
message.update({"bytes": data})
else:
message.update({"text": str(data)})
await self._send(message)
async def recv(self, *args, **kwargs) -> Optional[str]:
message = await self._receive()
if message["type"] == "websocket.receive":
return message["text"]
elif message["type"] == "websocket.disconnect":
pass
return None
receive = recv
async def accept(self, subprotocols: Optional[List[str]] = None) -> None:
subprotocol = None
if subprotocols:
for subp in subprotocols:
if subp in self.subprotocols:
subprotocol = subp
break
await self._send(
{
"type": "websocket.accept",
"subprotocol": subprotocol,
}
)
async def close(self) -> None:
pass
@property
def subprotocols(self):
return self._subprotocols
@subprotocols.setter
def subprotocols(self, subprotocols: Optional[List[str]] = None):
self._subprotocols = subprotocols or []

View File

@ -9,7 +9,7 @@ from gunicorn.workers import base # type: ignore
from sanic.log import logger from sanic.log import logger
from sanic.server import HttpProtocol, Signal, serve from sanic.server import HttpProtocol, Signal, serve
from sanic.websocket import WebSocketProtocol from sanic.server.protocols.websocket_protocol import WebSocketProtocol
try: try:
@ -142,14 +142,11 @@ class GunicornWorker(base.Worker):
# Force close non-idle connection after waiting for # Force close non-idle connection after waiting for
# graceful_shutdown_timeout # graceful_shutdown_timeout
coros = []
for conn in self.connections: for conn in self.connections:
if hasattr(conn, "websocket") and conn.websocket: if hasattr(conn, "websocket") and conn.websocket:
coros.append(conn.websocket.close_connection()) conn.websocket.fail_connection(code=1001)
else: else:
conn.close() conn.abort()
_shutdown = asyncio.gather(*coros, loop=self.loop)
await _shutdown
async def _run(self): async def _run(self):
for sock in self.sockets: for sock in self.sockets:

View File

@ -88,7 +88,7 @@ requirements = [
uvloop, uvloop,
ujson, ujson,
"aiofiles>=0.6.0", "aiofiles>=0.6.0",
"websockets>=9.0", "websockets>=10.0",
"multidict>=5.0,<6.0", "multidict>=5.0,<6.0",
] ]

View File

@ -178,9 +178,6 @@ def test_app_enable_websocket(app, websocket_enabled, enable):
@patch("sanic.app.WebSocketProtocol") @patch("sanic.app.WebSocketProtocol")
def test_app_websocket_parameters(websocket_protocol_mock, app): def test_app_websocket_parameters(websocket_protocol_mock, app):
app.config.WEBSOCKET_MAX_SIZE = 44 app.config.WEBSOCKET_MAX_SIZE = 44
app.config.WEBSOCKET_MAX_QUEUE = 45
app.config.WEBSOCKET_READ_LIMIT = 46
app.config.WEBSOCKET_WRITE_LIMIT = 47
app.config.WEBSOCKET_PING_TIMEOUT = 48 app.config.WEBSOCKET_PING_TIMEOUT = 48
app.config.WEBSOCKET_PING_INTERVAL = 50 app.config.WEBSOCKET_PING_INTERVAL = 50
@ -197,11 +194,6 @@ def test_app_websocket_parameters(websocket_protocol_mock, app):
websocket_protocol_call_args = websocket_protocol_mock.call_args websocket_protocol_call_args = websocket_protocol_mock.call_args
ws_kwargs = websocket_protocol_call_args[1] ws_kwargs = websocket_protocol_call_args[1]
assert ws_kwargs["websocket_max_size"] == app.config.WEBSOCKET_MAX_SIZE assert ws_kwargs["websocket_max_size"] == app.config.WEBSOCKET_MAX_SIZE
assert ws_kwargs["websocket_max_queue"] == app.config.WEBSOCKET_MAX_QUEUE
assert ws_kwargs["websocket_read_limit"] == app.config.WEBSOCKET_READ_LIMIT
assert (
ws_kwargs["websocket_write_limit"] == app.config.WEBSOCKET_WRITE_LIMIT
)
assert ( assert (
ws_kwargs["websocket_ping_timeout"] ws_kwargs["websocket_ping_timeout"]
== app.config.WEBSOCKET_PING_TIMEOUT == app.config.WEBSOCKET_PING_TIMEOUT

View File

@ -10,7 +10,7 @@ from sanic.asgi import MockTransport
from sanic.exceptions import Forbidden, InvalidUsage, ServiceUnavailable from sanic.exceptions import Forbidden, InvalidUsage, ServiceUnavailable
from sanic.request import Request from sanic.request import Request
from sanic.response import json, text from sanic.response import json, text
from sanic.websocket import WebSocketConnection from sanic.server.websockets.connection import WebSocketConnection
@pytest.fixture @pytest.fixture

View File

@ -16,6 +16,7 @@ from sanic.exceptions import (
abort, abort,
) )
from sanic.response import text from sanic.response import text
from websockets.version import version as websockets_version
class SanicExceptionTestException(Exception): class SanicExceptionTestException(Exception):
@ -260,9 +261,14 @@ def test_exception_in_ws_logged(caplog):
with caplog.at_level(logging.INFO): with caplog.at_level(logging.INFO):
app.test_client.websocket("/feed") app.test_client.websocket("/feed")
# Websockets v10.0 and above output an additional
assert caplog.record_tuples[1][0] == "sanic.error" # INFO message when a ws connection is accepted
assert caplog.record_tuples[1][1] == logging.ERROR ws_version_parts = websockets_version.split(".")
ws_major = int(ws_version_parts[0])
record_index = 2 if ws_major >= 10 else 1
assert caplog.record_tuples[record_index][0] == "sanic.error"
assert caplog.record_tuples[record_index][1] == logging.ERROR
assert ( assert (
"Exception occurred while handling uri:" in caplog.record_tuples[1][2] "Exception occurred while handling uri:"
in caplog.record_tuples[record_index][2]
) )

View File

@ -674,16 +674,16 @@ async def test_websocket_route_asgi(app, url):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"subprotocols,expected", "subprotocols,expected",
( (
(["bar"], "bar"), (["one"], "one"),
(["bar", "foo"], "bar"), (["three", "one"], "one"),
(["baz"], None), (["tree"], None),
(None, None), (None, None),
), ),
) )
def test_websocket_route_with_subprotocols(app, subprotocols, expected): def test_websocket_route_with_subprotocols(app, subprotocols, expected):
results = "unset" results = []
@app.websocket("/ws", subprotocols=["foo", "bar"]) @app.websocket("/ws", subprotocols=["zero", "one", "two", "three"])
async def handler(request, ws): async def handler(request, ws):
nonlocal results nonlocal results
results = ws.subprotocol results = ws.subprotocol

View File

@ -175,7 +175,7 @@ def test_worker_close(worker):
worker.wsgi = mock.Mock() worker.wsgi = mock.Mock()
conn = mock.Mock() conn = mock.Mock()
conn.websocket = mock.Mock() conn.websocket = mock.Mock()
conn.websocket.close_connection = mock.Mock(wraps=_a_noop) conn.websocket.fail_connection = mock.Mock(wraps=_a_noop)
worker.connections = set([conn]) worker.connections = set([conn])
worker.log = mock.Mock() worker.log = mock.Mock()
worker.loop = loop worker.loop = loop
@ -190,5 +190,5 @@ def test_worker_close(worker):
loop.run_until_complete(_close) loop.run_until_complete(_close)
assert worker.signal.stopped assert worker.signal.stopped
assert conn.websocket.close_connection.called assert conn.websocket.fail_connection.called
assert len(worker.servers) == 0 assert len(worker.servers) == 0