diff --git a/examples/websocket.py b/examples/websocket.py index 9cba083c..92f71375 100644 --- a/examples/websocket.py +++ b/examples/websocket.py @@ -1,13 +1,14 @@ from sanic import Sanic -from sanic.response import file +from sanic.response import redirect app = Sanic(__name__) -@app.route('/') -async def index(request): - return await file('websocket.html') +app.static('index.html', "websocket.html") +@app.route('/') +def index(request): + return redirect("index.html") @app.websocket('/feed') async def feed(request, ws): diff --git a/sanic/app.py b/sanic/app.py index 634ba665..38d4b1d2 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -74,9 +74,10 @@ from sanic.router import Router from sanic.server import AsyncioServer, HttpProtocol from sanic.server import Signal as ServerSignal from sanic.server import serve, serve_multiple, serve_single +from sanic.server.protocols.websocket_protocol import WebSocketProtocol +from sanic.server.websockets.impl import ConnectionClosed from sanic.signals import Signal, SignalRouter from sanic.touchup import TouchUp, TouchUpMeta -from sanic.websocket import ConnectionClosed, WebSocketProtocol class Sanic(BaseSanic, metaclass=TouchUpMeta): @@ -871,23 +872,11 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): async def _websocket_handler( self, handler, request, *args, subprotocols=None, **kwargs ): - request.app = self - if not getattr(handler, "__blueprintname__", False): - request._name = handler.__name__ - else: - request._name = ( - getattr(handler, "__blueprintname__", "") + handler.__name__ - ) - - pass - if self.asgi: ws = request.transport.get_websocket_connection() await ws.accept(subprotocols) else: protocol = request.transport.get_protocol() - protocol.app = self - ws = await protocol.websocket_handshake(request, subprotocols) # schedule the application handler @@ -895,15 +884,19 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta): # needs to be cancelled due to the server being stopped fut = ensure_future(handler(request, ws, *args, **kwargs)) self.websocket_tasks.add(fut) + cancelled = False try: await fut except Exception as e: self.error_handler.log(request, e) except (CancelledError, ConnectionClosed): - pass + cancelled = True finally: self.websocket_tasks.remove(fut) - await ws.close() + if cancelled: + ws.end_connection(1000) + else: + await ws.close() # -------------------------------------------------------------------- # # Testing diff --git a/sanic/asgi.py b/sanic/asgi.py index 13d4f87c..55c18d5c 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -10,7 +10,7 @@ from sanic.exceptions import ServerError from sanic.models.asgi import ASGIReceive, ASGIScope, ASGISend, MockTransport from sanic.request import Request from sanic.server import ConnInfo -from sanic.websocket import WebSocketConnection +from sanic.server.websockets.connection import WebSocketConnection class Lifespan: diff --git a/sanic/config.py b/sanic/config.py index 27699f80..2a90c5fb 100644 --- a/sanic/config.py +++ b/sanic/config.py @@ -35,12 +35,9 @@ DEFAULT_CONFIG = { "REQUEST_MAX_SIZE": 100000000, # 100 megabytes "REQUEST_TIMEOUT": 60, # 60 seconds "RESPONSE_TIMEOUT": 60, # 60 seconds - "WEBSOCKET_MAX_QUEUE": 32, "WEBSOCKET_MAX_SIZE": 2 ** 20, # 1 megabyte "WEBSOCKET_PING_INTERVAL": 20, "WEBSOCKET_PING_TIMEOUT": 20, - "WEBSOCKET_READ_LIMIT": 2 ** 16, - "WEBSOCKET_WRITE_LIMIT": 2 ** 16, } @@ -62,12 +59,9 @@ class Config(dict): REQUEST_MAX_SIZE: int REQUEST_TIMEOUT: int RESPONSE_TIMEOUT: int - WEBSOCKET_MAX_QUEUE: int WEBSOCKET_MAX_SIZE: int WEBSOCKET_PING_INTERVAL: int WEBSOCKET_PING_TIMEOUT: int - WEBSOCKET_READ_LIMIT: int - WEBSOCKET_WRITE_LIMIT: int def __init__( self, diff --git a/sanic/mixins/routes.py b/sanic/mixins/routes.py index 788d7017..6881f291 100644 --- a/sanic/mixins/routes.py +++ b/sanic/mixins/routes.py @@ -121,8 +121,11 @@ class RouteMixin: "Expected either string or Iterable of host strings, " "not %s" % host ) - - if isinstance(subprotocols, (list, tuple, set)): + if isinstance(subprotocols, list): + # Ordered subprotocols, maintain order + subprotocols = tuple(subprotocols) + elif isinstance(subprotocols, set): + # subprotocol is unordered, keep it unordered subprotocols = frozenset(subprotocols) route = FutureRoute( diff --git a/sanic/models/asgi.py b/sanic/models/asgi.py index 595b0553..1b707ebc 100644 --- a/sanic/models/asgi.py +++ b/sanic/models/asgi.py @@ -3,7 +3,7 @@ import asyncio from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union from sanic.exceptions import InvalidUsage -from sanic.websocket import WebSocketConnection +from sanic.server.websockets.connection import WebSocketConnection ASGIScope = MutableMapping[str, Any] diff --git a/sanic/server/protocols/websocket_protocol.py b/sanic/server/protocols/websocket_protocol.py new file mode 100644 index 00000000..628945d2 --- /dev/null +++ b/sanic/server/protocols/websocket_protocol.py @@ -0,0 +1,165 @@ +from typing import TYPE_CHECKING, Optional, Sequence + +from websockets.connection import CLOSED, CLOSING, OPEN +from websockets.server import ServerConnection + +from sanic.exceptions import ServerError +from sanic.log import error_logger +from sanic.server import HttpProtocol + +from ..websockets.impl import WebsocketImplProtocol + + +if TYPE_CHECKING: + from websockets import http11 + + +class WebSocketProtocol(HttpProtocol): + + websocket: Optional[WebsocketImplProtocol] + websocket_timeout: float + websocket_max_size = Optional[int] + websocket_ping_interval = Optional[float] + websocket_ping_timeout = Optional[float] + + def __init__( + self, + *args, + websocket_timeout: float = 10.0, + websocket_max_size: Optional[int] = None, + websocket_max_queue: Optional[int] = None, # max_queue is deprecated + websocket_read_limit: Optional[int] = None, # read_limit is deprecated + websocket_write_limit: Optional[int] = None, # write_limit deprecated + websocket_ping_interval: Optional[float] = 20.0, + websocket_ping_timeout: Optional[float] = 20.0, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.websocket = None + self.websocket_timeout = websocket_timeout + self.websocket_max_size = websocket_max_size + if websocket_max_queue is not None and websocket_max_queue > 0: + # TODO: Reminder remove this warning in v22.3 + error_logger.warning( + DeprecationWarning( + "Websocket no longer uses queueing, so websocket_max_queue" + " is no longer required." + ) + ) + if websocket_read_limit is not None and websocket_read_limit > 0: + # TODO: Reminder remove this warning in v22.3 + error_logger.warning( + DeprecationWarning( + "Websocket no longer uses read buffers, so " + "websocket_read_limit is not required." + ) + ) + if websocket_write_limit is not None and websocket_write_limit > 0: + # TODO: Reminder remove this warning in v22.3 + error_logger.warning( + DeprecationWarning( + "Websocket no longer uses write buffers, so " + "websocket_write_limit is not required." + ) + ) + self.websocket_ping_interval = websocket_ping_interval + self.websocket_ping_timeout = websocket_ping_timeout + + def connection_lost(self, exc): + if self.websocket is not None: + self.websocket.connection_lost(exc) + super().connection_lost(exc) + + def data_received(self, data): + if self.websocket is not None: + self.websocket.data_received(data) + else: + # Pass it to HttpProtocol handler first + # That will (hopefully) upgrade it to a websocket. + super().data_received(data) + + def eof_received(self) -> Optional[bool]: + if self.websocket is not None: + return self.websocket.eof_received() + else: + return False + + def close(self, timeout: Optional[float] = None): + # Called by HttpProtocol at the end of connection_task + # If we've upgraded to websocket, we do our own closing + if self.websocket is not None: + # Note, we don't want to use websocket.close() + # That is used for user's application code to send a + # websocket close packet. This is different. + self.websocket.end_connection(1001) + else: + super().close() + + def close_if_idle(self): + # Called by Sanic Server when shutting down + # If we've upgraded to websocket, shut it down + if self.websocket is not None: + if self.websocket.connection.state in (CLOSING, CLOSED): + return True + elif self.websocket.loop is not None: + self.websocket.loop.create_task(self.websocket.close(1001)) + else: + self.websocket.end_connection(1001) + else: + return super().close_if_idle() + + async def websocket_handshake( + self, request, subprotocols=Optional[Sequence[str]] + ): + # let the websockets package do the handshake with the client + try: + if subprotocols is not None: + # subprotocols can be a set or frozenset, + # but ServerConnection needs a list + subprotocols = list(subprotocols) + ws_conn = ServerConnection( + max_size=self.websocket_max_size, + subprotocols=subprotocols, + state=OPEN, + logger=error_logger, + ) + resp: "http11.Response" = ws_conn.accept(request) + except Exception: + msg = ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ) + raise ServerError(msg, status_code=500) + if 100 <= resp.status_code <= 299: + rbody = "".join( + [ + "HTTP/1.1 ", + str(resp.status_code), + " ", + resp.reason_phrase, + "\r\n", + ] + ) + rbody += "".join(f"{k}: {v}\r\n" for k, v in resp.headers.items()) + if resp.body is not None: + rbody += f"\r\n{resp.body}\r\n\r\n" + else: + rbody += "\r\n" + await super().send(rbody.encode()) + else: + raise ServerError(resp.body, resp.status_code) + + self.websocket = WebsocketImplProtocol( + ws_conn, + ping_interval=self.websocket_ping_interval, + ping_timeout=self.websocket_ping_timeout, + close_timeout=self.websocket_timeout, + ) + loop = ( + request.transport.loop + if hasattr(request, "transport") + and hasattr(request.transport, "loop") + else None + ) + await self.websocket.connection_made(self, loop=loop) + return self.websocket diff --git a/sanic/server/runners.py b/sanic/server/runners.py index 39d8f736..f0bebb03 100644 --- a/sanic/server/runners.py +++ b/sanic/server/runners.py @@ -175,15 +175,11 @@ def serve( # Force close non-idle connection after waiting for # graceful_shutdown_timeout - coros = [] for conn in connections: if hasattr(conn, "websocket") and conn.websocket: - coros.append(conn.websocket.close_connection()) + conn.websocket.fail_connection(code=1001) else: conn.abort() - - _shutdown = asyncio.gather(*coros) - loop.run_until_complete(_shutdown) loop.run_until_complete(app._server_event("shutdown", "after")) remove_unix_socket(unix) @@ -278,9 +274,6 @@ def _build_protocol_kwargs( if hasattr(protocol, "websocket_handshake"): return { "websocket_max_size": config.WEBSOCKET_MAX_SIZE, - "websocket_max_queue": config.WEBSOCKET_MAX_QUEUE, - "websocket_read_limit": config.WEBSOCKET_READ_LIMIT, - "websocket_write_limit": config.WEBSOCKET_WRITE_LIMIT, "websocket_ping_timeout": config.WEBSOCKET_PING_TIMEOUT, "websocket_ping_interval": config.WEBSOCKET_PING_INTERVAL, } diff --git a/sanic/server/websockets/__init__.py b/sanic/server/websockets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sanic/server/websockets/connection.py b/sanic/server/websockets/connection.py new file mode 100644 index 00000000..c53a65a5 --- /dev/null +++ b/sanic/server/websockets/connection.py @@ -0,0 +1,82 @@ +from typing import ( + Any, + Awaitable, + Callable, + Dict, + List, + MutableMapping, + Optional, + Union, +) + + +ASIMessage = MutableMapping[str, Any] + + +class WebSocketConnection: + """ + This is for ASGI Connections. + It provides an interface similar to WebsocketProtocol, but + sends/receives over an ASGI connection. + """ + + # TODO + # - Implement ping/pong + + def __init__( + self, + send: Callable[[ASIMessage], Awaitable[None]], + receive: Callable[[], Awaitable[ASIMessage]], + subprotocols: Optional[List[str]] = None, + ) -> None: + self._send = send + self._receive = receive + self._subprotocols = subprotocols or [] + + async def send(self, data: Union[str, bytes], *args, **kwargs) -> None: + message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"} + + if isinstance(data, bytes): + message.update({"bytes": data}) + else: + message.update({"text": str(data)}) + + await self._send(message) + + async def recv(self, *args, **kwargs) -> Optional[str]: + message = await self._receive() + + if message["type"] == "websocket.receive": + return message["text"] + elif message["type"] == "websocket.disconnect": + pass + + return None + + receive = recv + + async def accept(self, subprotocols: Optional[List[str]] = None) -> None: + subprotocol = None + if subprotocols: + for subp in subprotocols: + if subp in self.subprotocols: + subprotocol = subp + break + + await self._send( + { + "type": "websocket.accept", + "subprotocol": subprotocol, + } + ) + + async def close(self, code: int = 1000, reason: str = "") -> None: + pass + + @property + def subprotocols(self): + return self._subprotocols + + @subprotocols.setter + def subprotocols(self, subprotocols: Optional[List[str]] = None): + self._subprotocols = subprotocols or [] diff --git a/sanic/server/websockets/frame.py b/sanic/server/websockets/frame.py new file mode 100644 index 00000000..b4af72b1 --- /dev/null +++ b/sanic/server/websockets/frame.py @@ -0,0 +1,297 @@ +import asyncio +import codecs + +from typing import TYPE_CHECKING, AsyncIterator, List, Optional + +from websockets.frames import Frame, Opcode +from websockets.typing import Data + +from sanic.exceptions import ServerError + + +if TYPE_CHECKING: + from .impl import WebsocketImplProtocol + +UTF8Decoder = codecs.getincrementaldecoder("utf-8") + + +class WebsocketFrameAssembler: + """ + Assemble a message from frames. + Code borrowed from aaugustin/websockets project: + https://github.com/aaugustin/websockets/blob/6eb98dd8fa5b2c896b9f6be7e8d117708da82a39/src/websockets/sync/messages.py + """ + + __slots__ = ( + "protocol", + "read_mutex", + "write_mutex", + "message_complete", + "message_fetched", + "get_in_progress", + "decoder", + "completed_queue", + "chunks", + "chunks_queue", + "paused", + "get_id", + "put_id", + ) + if TYPE_CHECKING: + protocol: "WebsocketImplProtocol" + read_mutex: asyncio.Lock + write_mutex: asyncio.Lock + message_complete: asyncio.Event + message_fetched: asyncio.Event + completed_queue: asyncio.Queue + get_in_progress: bool + decoder: Optional[codecs.IncrementalDecoder] + # For streaming chunks rather than messages: + chunks: List[Data] + chunks_queue: Optional[asyncio.Queue[Optional[Data]]] + paused: bool + + def __init__(self, protocol) -> None: + + self.protocol = protocol + + self.read_mutex = asyncio.Lock() + self.write_mutex = asyncio.Lock() + + self.completed_queue = asyncio.Queue( + maxsize=1 + ) # type: asyncio.Queue[Data] + + # put() sets this event to tell get() that a message can be fetched. + self.message_complete = asyncio.Event() + # get() sets this event to let put() + self.message_fetched = asyncio.Event() + + # This flag prevents concurrent calls to get() by user code. + self.get_in_progress = False + + # Decoder for text frames, None for binary frames. + self.decoder = None + + # Buffer data from frames belonging to the same message. + self.chunks = [] + + # When switching from "buffering" to "streaming", we use a thread-safe + # queue for transferring frames from the writing thread (library code) + # to the reading thread (user code). We're buffering when chunks_queue + # is None and streaming when it's a Queue. None is a sentinel + # value marking the end of the stream, superseding message_complete. + + # Stream data from frames belonging to the same message. + self.chunks_queue = None + + # Flag to indicate we've paused the protocol + self.paused = False + + async def get(self, timeout: Optional[float] = None) -> Optional[Data]: + """ + Read the next message. + :meth:`get` returns a single :class:`str` or :class:`bytes`. + If the :message was fragmented, :meth:`get` waits until the last frame + is received, then it reassembles the message. + If ``timeout`` is set and elapses before a complete message is + received, :meth:`get` returns ``None``. + """ + async with self.read_mutex: + if timeout is not None and timeout <= 0: + if not self.message_complete.is_set(): + return None + if self.get_in_progress: + # This should be guarded against with the read_mutex, + # exception is only here as a failsafe + raise ServerError( + "Called get() on Websocket frame assembler " + "while asynchronous get is already in progress." + ) + self.get_in_progress = True + + # If the message_complete event isn't set yet, release the lock to + # allow put() to run and eventually set it. + # Locking with get_in_progress ensures only one task can get here. + if timeout is None: + completed = await self.message_complete.wait() + elif timeout <= 0: + completed = self.message_complete.is_set() + else: + try: + await asyncio.wait_for( + self.message_complete.wait(), timeout=timeout + ) + except asyncio.TimeoutError: + ... + finally: + completed = self.message_complete.is_set() + + # Unpause the transport, if its paused + if self.paused: + self.protocol.resume_frames() + self.paused = False + if not self.get_in_progress: + # This should be guarded against with the read_mutex, + # exception is here as a failsafe + raise ServerError( + "State of Websocket frame assembler was modified while an " + "asynchronous get was in progress." + ) + self.get_in_progress = False + + # Waiting for a complete message timed out. + if not completed: + return None + if not self.message_complete.is_set(): + return None + + self.message_complete.clear() + + joiner: Data = b"" if self.decoder is None else "" + # mypy cannot figure out that chunks have the proper type. + message: Data = joiner.join(self.chunks) # type: ignore + if self.message_fetched.is_set(): + # This should be guarded against with the read_mutex, + # and get_in_progress check, this exception is here + # as a failsafe + raise ServerError( + "Websocket get() found a message when " + "state was already fetched." + ) + self.message_fetched.set() + self.chunks = [] + self.chunks_queue = ( + None # this should already be None, but set it here for safety + ) + + return message + + async def get_iter(self) -> AsyncIterator[Data]: + """ + Stream the next message. + Iterating the return value of :meth:`get_iter` yields a :class:`str` + or :class:`bytes` for each frame in the message. + """ + async with self.read_mutex: + if self.get_in_progress: + # This should be guarded against with the read_mutex, + # exception is only here as a failsafe + raise ServerError( + "Called get_iter on Websocket frame assembler " + "while asynchronous get is already in progress." + ) + self.get_in_progress = True + + chunks = self.chunks + self.chunks = [] + self.chunks_queue = asyncio.Queue() + + # Sending None in chunk_queue supersedes setting message_complete + # when switching to "streaming". If message is already complete + # when the switch happens, put() didn't send None, so we have to. + if self.message_complete.is_set(): + await self.chunks_queue.put(None) + + # Locking with get_in_progress ensures only one thread can get here + for c in chunks: + yield c + while True: + chunk = await self.chunks_queue.get() + if chunk is None: + break + yield chunk + + # Unpause the transport, if its paused + if self.paused: + self.protocol.resume_frames() + self.paused = False + if not self.get_in_progress: + # This should be guarded against with the read_mutex, + # exception is here as a failsafe + raise ServerError( + "State of Websocket frame assembler was modified while an " + "asynchronous get was in progress." + ) + self.get_in_progress = False + if not self.message_complete.is_set(): + # This should be guarded against with the read_mutex, + # exception is here as a failsafe + raise ServerError( + "Websocket frame assembler chunks queue ended before " + "message was complete." + ) + self.message_complete.clear() + if self.message_fetched.is_set(): + # This should be guarded against with the read_mutex, + # and get_in_progress check, this exception is + # here as a failsafe + raise ServerError( + "Websocket get_iter() found a message when state was " + "already fetched." + ) + + self.message_fetched.set() + self.chunks = ( + [] + ) # this should already be empty, but set it here for safety + self.chunks_queue = None + + async def put(self, frame: Frame) -> None: + """ + Add ``frame`` to the next message. + When ``frame`` is the final frame in a message, :meth:`put` waits + until the message is fetched, either by calling :meth:`get` or by + iterating the return value of :meth:`get_iter`. + :meth:`put` assumes that the stream of frames respects the protocol. + If it doesn't, the behavior is undefined. + """ + + async with self.write_mutex: + if frame.opcode is Opcode.TEXT: + self.decoder = UTF8Decoder(errors="strict") + elif frame.opcode is Opcode.BINARY: + self.decoder = None + elif frame.opcode is Opcode.CONT: + pass + else: + # Ignore control frames. + return + data: Data + if self.decoder is not None: + data = self.decoder.decode(frame.data, frame.fin) + else: + data = frame.data + if self.chunks_queue is None: + self.chunks.append(data) + else: + await self.chunks_queue.put(data) + + if not frame.fin: + return + if not self.get_in_progress: + # nobody is waiting for this frame, so try to pause subsequent + # frames at the protocol level + self.paused = self.protocol.pause_frames() + # Message is complete. Wait until it's fetched to return. + + if self.chunks_queue is not None: + await self.chunks_queue.put(None) + if self.message_complete.is_set(): + # This should be guarded against with the write_mutex + raise ServerError( + "Websocket put() got a new message when a message was " + "already in its chamber." + ) + self.message_complete.set() # Signal to get() it can serve the + if self.message_fetched.is_set(): + # This should be guarded against with the write_mutex + raise ServerError( + "Websocket put() got a new message when the previous " + "message was not yet fetched." + ) + + # Allow get() to run and eventually set the event. + await self.message_fetched.wait() + self.message_fetched.clear() + self.decoder = None diff --git a/sanic/server/websockets/impl.py b/sanic/server/websockets/impl.py new file mode 100644 index 00000000..a2778c57 --- /dev/null +++ b/sanic/server/websockets/impl.py @@ -0,0 +1,789 @@ +import asyncio +import random +import struct + +from typing import ( + AsyncIterator, + Dict, + Iterable, + Mapping, + Optional, + Sequence, + Union, +) + +from websockets.connection import CLOSED, CLOSING, OPEN, Event +from websockets.exceptions import ConnectionClosed, ConnectionClosedError +from websockets.frames import OP_PONG, Frame +from websockets.server import ServerConnection +from websockets.typing import Data + +from sanic.log import error_logger, logger +from sanic.server.protocols.base_protocol import SanicProtocol + +from ...exceptions import ServerError +from .frame import WebsocketFrameAssembler + + +class WebsocketImplProtocol: + connection: ServerConnection + io_proto: Optional[SanicProtocol] + loop: Optional[asyncio.AbstractEventLoop] + max_queue: int + close_timeout: float + ping_interval: Optional[float] + ping_timeout: Optional[float] + assembler: WebsocketFrameAssembler + # Dict[bytes, asyncio.Future[None]] + pings: Dict[bytes, asyncio.Future] + conn_mutex: asyncio.Lock + recv_lock: asyncio.Lock + process_event_mutex: asyncio.Lock + can_pause: bool + # Optional[asyncio.Future[None]] + data_finished_fut: Optional[asyncio.Future] + # Optional[asyncio.Future[None]] + pause_frame_fut: Optional[asyncio.Future] + # Optional[asyncio.Future[None]] + connection_lost_waiter: Optional[asyncio.Future] + keepalive_ping_task: Optional[asyncio.Task] + auto_closer_task: Optional[asyncio.Task] + + def __init__( + self, + connection, + max_queue=None, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, + close_timeout: float = 10, + loop=None, + ): + self.connection = connection + self.io_proto = None + self.loop = None + self.max_queue = max_queue + self.close_timeout = close_timeout + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.assembler = WebsocketFrameAssembler(self) + self.pings = {} + self.conn_mutex = asyncio.Lock() + self.recv_lock = asyncio.Lock() + self.process_event_mutex = asyncio.Lock() + self.data_finished_fut = None + self.can_pause = True + self.pause_frame_fut = None + self.keepalive_ping_task = None + self.auto_closer_task = None + self.connection_lost_waiter = None + + @property + def subprotocol(self): + return self.connection.subprotocol + + def pause_frames(self): + if not self.can_pause: + return False + if self.pause_frame_fut: + logger.debug("Websocket connection already paused.") + return False + if (not self.loop) or (not self.io_proto): + return False + if self.io_proto.transport: + self.io_proto.transport.pause_reading() + self.pause_frame_fut = self.loop.create_future() + logger.debug("Websocket connection paused.") + return True + + def resume_frames(self): + if not self.pause_frame_fut: + logger.debug("Websocket connection not paused.") + return False + if (not self.loop) or (not self.io_proto): + logger.debug( + "Websocket attempting to resume reading frames, " + "but connection is gone." + ) + return False + if self.io_proto.transport: + self.io_proto.transport.resume_reading() + self.pause_frame_fut.set_result(None) + self.pause_frame_fut = None + logger.debug("Websocket connection unpaused.") + return True + + async def connection_made( + self, + io_proto: SanicProtocol, + loop: Optional[asyncio.AbstractEventLoop] = None, + ): + if not loop: + try: + loop = getattr(io_proto, "loop") + except AttributeError: + loop = asyncio.get_event_loop() + if not loop: + # This catch is for mypy type checker + # to assert loop is not None here. + raise ServerError("Connection received with no asyncio loop.") + if self.auto_closer_task: + raise ServerError( + "Cannot call connection_made more than once " + "on a websocket connection." + ) + self.loop = loop + self.io_proto = io_proto + self.connection_lost_waiter = self.loop.create_future() + self.data_finished_fut = asyncio.shield(self.loop.create_future()) + + if self.ping_interval: + self.keepalive_ping_task = asyncio.create_task( + self.keepalive_ping() + ) + self.auto_closer_task = asyncio.create_task( + self.auto_close_connection() + ) + + async def wait_for_connection_lost(self, timeout=None) -> bool: + """ + Wait until the TCP connection is closed or ``timeout`` elapses. + If timeout is None, wait forever. + Recommend you should pass in self.close_timeout as timeout + + Return ``True`` if the connection is closed and ``False`` otherwise. + + """ + if not self.connection_lost_waiter: + return False + if self.connection_lost_waiter.done(): + return True + else: + try: + await asyncio.wait_for( + asyncio.shield(self.connection_lost_waiter), timeout + ) + return True + except asyncio.TimeoutError: + # Re-check self.connection_lost_waiter.done() synchronously + # because connection_lost() could run between the moment the + # timeout occurs and the moment this coroutine resumes running + return self.connection_lost_waiter.done() + + async def process_events(self, events: Sequence[Event]) -> None: + """ + Process a list of incoming events. + """ + # Wrapped in a mutex lock, to prevent other incoming events + # from processing at the same time + async with self.process_event_mutex: + for event in events: + if not isinstance(event, Frame): + # Event is not a frame. Ignore it. + continue + if event.opcode == OP_PONG: + await self.process_pong(event) + else: + await self.assembler.put(event) + + async def process_pong(self, frame: "Frame") -> None: + if frame.data in self.pings: + # Acknowledge all pings up to the one matching this pong. + ping_ids = [] + for ping_id, ping in self.pings.items(): + ping_ids.append(ping_id) + if not ping.done(): + ping.set_result(None) + if ping_id == frame.data: + break + else: # noqa + raise ServerError("ping_id is not in self.pings") + # Remove acknowledged pings from self.pings. + for ping_id in ping_ids: + del self.pings[ping_id] + + async def keepalive_ping(self) -> None: + """ + Send a Ping frame and wait for a Pong frame at regular intervals. + This coroutine exits when the connection terminates and one of the + following happens: + - :meth:`ping` raises :exc:`ConnectionClosed`, or + - :meth:`auto_close_connection` cancels :attr:`keepalive_ping_task`. + """ + if self.ping_interval is None: + return + + try: + while True: + await asyncio.sleep(self.ping_interval) + + # ping() raises CancelledError if the connection is closed, + # when auto_close_connection() cancels keepalive_ping_task. + + # ping() raises ConnectionClosed if the connection is lost, + # when connection_lost() calls abort_pings(). + + ping_waiter = await self.ping() + + if self.ping_timeout is not None: + try: + await asyncio.wait_for(ping_waiter, self.ping_timeout) + except asyncio.TimeoutError: + error_logger.warning( + "Websocket timed out waiting for pong" + ) + self.fail_connection(1011) + break + except asyncio.CancelledError: + # It is expected for this task to be cancelled during during + # normal operation, when the connection is closed. + logger.debug("Websocket keepalive ping task was cancelled.") + except ConnectionClosed: + logger.debug("Websocket closed. Keepalive ping task exiting.") + except Exception as e: + error_logger.warning( + "Unexpected exception in websocket keepalive ping task." + ) + logger.debug(str(e)) + + def _force_disconnect(self) -> bool: + """ + Internal methdod used by end_connection and fail_connection + only when the graceful auto-closer cannot be used + """ + if self.auto_closer_task and not self.auto_closer_task.done(): + self.auto_closer_task.cancel() + if self.data_finished_fut and not self.data_finished_fut.done(): + self.data_finished_fut.cancel() + self.data_finished_fut = None + if self.keepalive_ping_task and not self.keepalive_ping_task.done(): + self.keepalive_ping_task.cancel() + self.keepalive_ping_task = None + if self.loop and self.io_proto and self.io_proto.transport: + self.io_proto.transport.close() + self.loop.call_later( + self.close_timeout, self.io_proto.transport.abort + ) + # We were never open, or already closed + return True + + def fail_connection(self, code: int = 1006, reason: str = "") -> bool: + """ + Fail the WebSocket Connection + This requires: + 1. Stopping all processing of incoming data, which means cancelling + pausing the underlying io protocol. The close code will be 1006 + unless a close frame was received earlier. + 2. Sending a close frame with an appropriate code if the opening + handshake succeeded and the other side is likely to process it. + 3. Closing the connection. :meth:`auto_close_connection` takes care + of this. + (The specification describes these steps in the opposite order.) + """ + if self.io_proto and self.io_proto.transport: + # Stop new data coming in + # In Python Version 3.7: pause_reading is idempotent + # ut can be called when the transport is already paused or closed + self.io_proto.transport.pause_reading() + + # Keeping fail_connection() synchronous guarantees it can't + # get stuck and simplifies the implementation of the callers. + # Not draining the write buffer is acceptable in this context. + + # clear the send buffer + _ = self.connection.data_to_send() + # If we're not already CLOSED or CLOSING, then send the close. + if self.connection.state is OPEN: + if code in (1000, 1001): + self.connection.send_close(code, reason) + else: + self.connection.fail(code, reason) + try: + data_to_send = self.connection.data_to_send() + while ( + len(data_to_send) + and self.io_proto + and self.io_proto.transport + ): + frame_data = data_to_send.pop(0) + self.io_proto.transport.write(frame_data) + except Exception: + # sending close frames may fail if the + # transport closes during this period + ... + if code == 1006: + # Special case: 1006 consider the transport already closed + self.connection.state = CLOSED + if self.data_finished_fut and not self.data_finished_fut.done(): + # We have a graceful auto-closer. Use it to close the connection. + self.data_finished_fut.cancel() + self.data_finished_fut = None + if (not self.auto_closer_task) or self.auto_closer_task.done(): + return self._force_disconnect() + return False + + def end_connection(self, code=1000, reason=""): + # This is like slightly more graceful form of fail_connection + # Use this instead of close() when you need an immediate + # close and cannot await websocket.close() handshake. + + if code == 1006 or not self.io_proto or not self.io_proto.transport: + return self.fail_connection(code, reason) + + # Stop new data coming in + # In Python Version 3.7: pause_reading is idempotent + # i.e. it can be called when the transport is already paused or closed. + self.io_proto.transport.pause_reading() + if self.connection.state == OPEN: + data_to_send = self.connection.data_to_send() + self.connection.send_close(code, reason) + data_to_send.extend(self.connection.data_to_send()) + try: + while ( + len(data_to_send) + and self.io_proto + and self.io_proto.transport + ): + frame_data = data_to_send.pop(0) + self.io_proto.transport.write(frame_data) + except Exception: + # sending close frames may fail if the + # transport closes during this period + # But that doesn't matter at this point + ... + if self.data_finished_fut and not self.data_finished_fut.done(): + # We have the ability to signal the auto-closer + # try to trigger it to auto-close the connection + self.data_finished_fut.cancel() + self.data_finished_fut = None + if (not self.auto_closer_task) or self.auto_closer_task.done(): + # Auto-closer is not running, do force disconnect + return self._force_disconnect() + return False + + async def auto_close_connection(self) -> None: + """ + Close the WebSocket Connection + When the opening handshake succeeds, :meth:`connection_open` starts + this coroutine in a task. It waits for the data transfer phase to + complete then it closes the TCP connection cleanly. + When the opening handshake fails, :meth:`fail_connection` does the + same. There's no data transfer phase in that case. + """ + try: + # Wait for the data transfer phase to complete. + if self.data_finished_fut: + try: + await self.data_finished_fut + logger.debug( + "Websocket task finished. Closing the connection." + ) + except asyncio.CancelledError: + # Cancelled error is called when data phase is cancelled + # if an error occurred or the client closed the connection + logger.debug( + "Websocket handler cancelled. Closing the connection." + ) + + # Cancel the keepalive ping task. + if self.keepalive_ping_task: + self.keepalive_ping_task.cancel() + self.keepalive_ping_task = None + + # Half-close the TCP connection if possible (when there's no TLS). + if ( + self.io_proto + and self.io_proto.transport + and self.io_proto.transport.can_write_eof() + ): + logger.debug("Websocket half-closing TCP connection") + self.io_proto.transport.write_eof() + if self.connection_lost_waiter: + if await self.wait_for_connection_lost(timeout=0): + return + except asyncio.CancelledError: + ... + finally: + # The try/finally ensures that the transport never remains open, + # even if this coroutine is cancelled (for example). + if (not self.io_proto) or (not self.io_proto.transport): + # we were never open, or done. Can't do any finalization. + return + elif ( + self.connection_lost_waiter + and self.connection_lost_waiter.done() + ): + # connection confirmed closed already, proceed to abort waiter + ... + elif self.io_proto.transport.is_closing(): + # Connection is already closing (due to half-close above) + # proceed to abort waiter + ... + else: + self.io_proto.transport.close() + if not self.connection_lost_waiter: + # Our connection monitor task isn't running. + try: + await asyncio.sleep(self.close_timeout) + except asyncio.CancelledError: + ... + if self.io_proto and self.io_proto.transport: + self.io_proto.transport.abort() + else: + if await self.wait_for_connection_lost( + timeout=self.close_timeout + ): + # Connection aborted before the timeout expired. + return + error_logger.warning( + "Timeout waiting for TCP connection to close. Aborting" + ) + if self.io_proto and self.io_proto.transport: + self.io_proto.transport.abort() + + def abort_pings(self) -> None: + """ + Raise ConnectionClosed in pending keepalive pings. + They'll never receive a pong once the connection is closed. + """ + if self.connection.state is not CLOSED: + raise ServerError( + "Webscoket about_pings should only be called " + "after connection state is changed to CLOSED" + ) + + for ping in self.pings.values(): + ping.set_exception(ConnectionClosedError(None, None)) + # If the exception is never retrieved, it will be logged when ping + # is garbage-collected. This is confusing for users. + # Given that ping is done (with an exception), canceling it does + # nothing, but it prevents logging the exception. + ping.cancel() + + async def close(self, code: int = 1000, reason: str = "") -> None: + """ + Perform the closing handshake. + This is a websocket-protocol level close. + :meth:`close` waits for the other end to complete the handshake and + for the TCP connection to terminate. + :meth:`close` is idempotent: it doesn't do anything once the + connection is closed. + :param code: WebSocket close code + :param reason: WebSocket close reason + """ + if code == 1006: + self.fail_connection(code, reason) + return + async with self.conn_mutex: + if self.connection.state is OPEN: + self.connection.send_close(code, reason) + data_to_send = self.connection.data_to_send() + await self.send_data(data_to_send) + + async def recv(self, timeout: Optional[float] = None) -> Optional[Data]: + """ + Receive the next message. + Return a :class:`str` for a text frame and :class:`bytes` for a binary + frame. + When the end of the message stream is reached, :meth:`recv` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + If ``timeout`` is ``None``, block until a message is received. Else, + if no message is received within ``timeout`` seconds, return ``None``. + Set ``timeout`` to ``0`` to check if a message was already received. + :raises ~websockets.exceptions.ConnectionClosed: when the + connection is closed + :raises ServerError: if two tasks call :meth:`recv` or + :meth:`recv_streaming` concurrently + """ + + if self.recv_lock.locked(): + raise ServerError( + "cannot call recv while another task is " + "already waiting for the next message" + ) + await self.recv_lock.acquire() + if self.connection.state in (CLOSED, CLOSING): + raise ServerError( + "Cannot receive from websocket interface after it is closed." + ) + try: + return await self.assembler.get(timeout) + finally: + self.recv_lock.release() + + async def recv_burst(self, max_recv=256) -> Sequence[Data]: + """ + Receive the messages which have arrived since last checking. + Return a :class:`list` containing :class:`str` for a text frame + and :class:`bytes` for a binary frame. + When the end of the message stream is reached, :meth:`recv_burst` + raises :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, + it raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a + normal connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + :raises ~websockets.exceptions.ConnectionClosed: when the + connection is closed + :raises ServerError: if two tasks call :meth:`recv_burst` or + :meth:`recv_streaming` concurrently + """ + + if self.recv_lock.locked(): + raise ServerError( + "cannot call recv_burst while another task is already waiting " + "for the next message" + ) + await self.recv_lock.acquire() + if self.connection.state in (CLOSED, CLOSING): + raise ServerError( + "Cannot receive from websocket interface after it is closed." + ) + messages = [] + try: + # Prevent pausing the transport when we're + # receiving a burst of messages + self.can_pause = False + while True: + m = await self.assembler.get(timeout=0) + if m is None: + # None left in the burst. This is good! + break + messages.append(m) + if len(messages) >= max_recv: + # Too much data in the pipe. Hit our burst limit. + break + # Allow an eventloop iteration for the + # next message to pass into the Assembler + await asyncio.sleep(0) + finally: + self.can_pause = True + self.recv_lock.release() + return messages + + async def recv_streaming(self) -> AsyncIterator[Data]: + """ + Receive the next message frame by frame. + Return an iterator of :class:`str` for a text frame and :class:`bytes` + for a binary frame. The iterator should be exhausted, or else the + connection will become unusable. + With the exception of the return value, :meth:`recv_streaming` behaves + like :meth:`recv`. + """ + if self.recv_lock.locked(): + raise ServerError( + "Cannot call recv_streaming while another task " + "is already waiting for the next message" + ) + await self.recv_lock.acquire() + if self.connection.state in (CLOSED, CLOSING): + raise ServerError( + "Cannot receive from websocket interface after it is closed." + ) + try: + self.can_pause = False + async for m in self.assembler.get_iter(): + yield m + finally: + self.can_pause = True + self.recv_lock.release() + + async def send(self, message: Union[Data, Iterable[Data]]) -> None: + """ + Send a message. + A string (:class:`str`) is sent as a `Text frame`_. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a `Binary frame`_. + .. _Text frame: https://tools.ietf.org/html/rfc6455#section-5.6 + .. _Binary frame: https://tools.ietf.org/html/rfc6455#section-5.6 + :meth:`send` also accepts an iterable of strings, bytestrings, or + bytes-like objects. In that case the message is fragmented. Each item + is treated as a message fragment and sent in its own frame. All items + must be of the same type, or else :meth:`send` will raise a + :exc:`TypeError` and the connection will be closed. + :meth:`send` rejects dict-like objects because this is often an error. + If you wish to send the keys of a dict-like object as fragments, call + its :meth:`~dict.keys` method and pass the result to :meth:`send`. + :raises TypeError: for unsupported inputs + """ + async with self.conn_mutex: + + if self.connection.state in (CLOSED, CLOSING): + raise ServerError( + "Cannot write to websocket interface after it is closed." + ) + if (not self.data_finished_fut) or self.data_finished_fut.done(): + raise ServerError( + "Cannot write to websocket interface after it is finished." + ) + + # Unfragmented message -- this case must be handled first because + # strings and bytes-like objects are iterable. + + if isinstance(message, str): + self.connection.send_text(message.encode("utf-8")) + await self.send_data(self.connection.data_to_send()) + + elif isinstance(message, (bytes, bytearray, memoryview)): + self.connection.send_binary(message) + await self.send_data(self.connection.data_to_send()) + + elif isinstance(message, Mapping): + # Catch a common mistake -- passing a dict to send(). + raise TypeError("data is a dict-like object") + + elif isinstance(message, Iterable): + # Fragmented message -- regular iterator. + raise NotImplementedError( + "Fragmented websocket messages are not supported." + ) + else: + raise TypeError("Websocket data must be bytes, str.") + + async def ping(self, data: Optional[Data] = None) -> asyncio.Future: + """ + Send a ping. + Return an :class:`~asyncio.Future` that will be resolved when the + corresponding pong is received. You can ignore it if you don't intend + to wait. + A ping may serve as a keepalive or as a check that the remote endpoint + received all messages up to this point:: + await pong_event = ws.ping() + await pong_event # only if you want to wait for the pong + By default, the ping contains four random bytes. This payload may be + overridden with the optional ``data`` argument which must be a string + (which will be encoded to UTF-8) or a bytes-like object. + """ + async with self.conn_mutex: + if self.connection.state in (CLOSED, CLOSING): + raise ServerError( + "Cannot send a ping when the websocket interface " + "is closed." + ) + if (not self.io_proto) or (not self.io_proto.loop): + raise ServerError( + "Cannot send a ping when the websocket has no I/O " + "protocol attached." + ) + if data is not None: + if isinstance(data, str): + data = data.encode("utf-8") + elif isinstance(data, (bytearray, memoryview)): + data = bytes(data) + + # Protect against duplicates if a payload is explicitly set. + if data in self.pings: + raise ValueError( + "already waiting for a pong with the same data" + ) + + # Generate a unique random payload otherwise. + while data is None or data in self.pings: + data = struct.pack("!I", random.getrandbits(32)) + + self.pings[data] = self.io_proto.loop.create_future() + + self.connection.send_ping(data) + await self.send_data(self.connection.data_to_send()) + + return asyncio.shield(self.pings[data]) + + async def pong(self, data: Data = b"") -> None: + """ + Send a pong. + An unsolicited pong may serve as a unidirectional heartbeat. + The payload may be set with the optional ``data`` argument which must + be a string (which will be encoded to UTF-8) or a bytes-like object. + """ + async with self.conn_mutex: + if self.connection.state in (CLOSED, CLOSING): + # Cannot send pong after transport is shutting down + return + if isinstance(data, str): + data = data.encode("utf-8") + elif isinstance(data, (bytearray, memoryview)): + data = bytes(data) + self.connection.send_pong(data) + await self.send_data(self.connection.data_to_send()) + + async def send_data(self, data_to_send): + for data in data_to_send: + if data: + await self.io_proto.send(data) + else: + # Send an EOF - We don't actually send it, + # just trigger to autoclose the connection + if ( + self.auto_closer_task + and not self.auto_closer_task.done() + and self.data_finished_fut + and not self.data_finished_fut.done() + ): + # Auto-close the connection + self.data_finished_fut.set_result(None) + else: + # This will fail the connection appropriately + SanicProtocol.close(self.io_proto, timeout=1.0) + + async def async_data_received(self, data_to_send, events_to_process): + if self.connection.state == OPEN and len(data_to_send) > 0: + # receiving data can generate data to send (eg, pong for a ping) + # send connection.data_to_send() + await self.send_data(data_to_send) + if len(events_to_process) > 0: + await self.process_events(events_to_process) + + def data_received(self, data): + self.connection.receive_data(data) + data_to_send = self.connection.data_to_send() + events_to_process = self.connection.events_received() + if len(data_to_send) > 0 or len(events_to_process) > 0: + asyncio.create_task( + self.async_data_received(data_to_send, events_to_process) + ) + + async def async_eof_received(self, data_to_send, events_to_process): + # receiving EOF can generate data to send + # send connection.data_to_send() + if self.connection.state == OPEN: + await self.send_data(data_to_send) + if len(events_to_process) > 0: + await self.process_events(events_to_process) + + if ( + self.auto_closer_task + and not self.auto_closer_task.done() + and self.data_finished_fut + and not self.data_finished_fut.done() + ): + # Auto-close the connection + self.data_finished_fut.set_result(None) + else: + # This will fail the connection appropriately + SanicProtocol.close(self.io_proto, timeout=1.0) + + def eof_received(self) -> Optional[bool]: + self.connection.receive_eof() + data_to_send = self.connection.data_to_send() + events_to_process = self.connection.events_received() + if len(data_to_send) > 0 or len(events_to_process) > 0: + asyncio.create_task( + self.async_eof_received(data_to_send, events_to_process) + ) + return False + + def connection_lost(self, exc): + """ + The WebSocket Connection is Closed. + """ + if not self.connection.state == CLOSED: + # signal to the websocket connection handler + # we've lost the connection + self.connection.fail(code=1006) + self.connection.state = CLOSED + + self.abort_pings() + if self.connection_lost_waiter: + self.connection_lost_waiter.set_result(None) diff --git a/sanic/websocket.py b/sanic/websocket.py deleted file mode 100644 index b5600ed7..00000000 --- a/sanic/websocket.py +++ /dev/null @@ -1,205 +0,0 @@ -from typing import ( - Any, - Awaitable, - Callable, - Dict, - List, - MutableMapping, - Optional, - Union, -) - -from httptools import HttpParserUpgrade # type: ignore -from websockets import ( # type: ignore - ConnectionClosed, - InvalidHandshake, - WebSocketCommonProtocol, -) - -# Despite the "legacy" namespace, the primary maintainer of websockets -# committed to maintaining backwards-compatibility until 2026 and will -# consider extending it if sanic continues depending on this module. -from websockets.legacy import handshake - -from sanic.exceptions import InvalidUsage -from sanic.server import HttpProtocol - - -__all__ = ["ConnectionClosed", "WebSocketProtocol", "WebSocketConnection"] - -ASIMessage = MutableMapping[str, Any] - - -class WebSocketProtocol(HttpProtocol): - def __init__( - self, - *args, - websocket_timeout=10, - websocket_max_size=None, - websocket_max_queue=None, - websocket_read_limit=2 ** 16, - websocket_write_limit=2 ** 16, - websocket_ping_interval=20, - websocket_ping_timeout=20, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.websocket = None - # self.app = None - self.websocket_timeout = websocket_timeout - self.websocket_max_size = websocket_max_size - self.websocket_max_queue = websocket_max_queue - self.websocket_read_limit = websocket_read_limit - self.websocket_write_limit = websocket_write_limit - self.websocket_ping_interval = websocket_ping_interval - self.websocket_ping_timeout = websocket_ping_timeout - - # timeouts make no sense for websocket routes - def request_timeout_callback(self): - if self.websocket is None: - super().request_timeout_callback() - - def response_timeout_callback(self): - if self.websocket is None: - super().response_timeout_callback() - - def keep_alive_timeout_callback(self): - if self.websocket is None: - super().keep_alive_timeout_callback() - - def connection_lost(self, exc): - if self.websocket is not None: - self.websocket.connection_lost(exc) - super().connection_lost(exc) - - def data_received(self, data): - if self.websocket is not None: - # pass the data to the websocket protocol - self.websocket.data_received(data) - else: - try: - super().data_received(data) - except HttpParserUpgrade: - # this is okay, it just indicates we've got an upgrade request - pass - - def write_response(self, response): - if self.websocket is not None: - # websocket requests do not write a response - self.transport.close() - else: - super().write_response(response) - - async def websocket_handshake(self, request, subprotocols=None): - # let the websockets package do the handshake with the client - headers = {} - - try: - key = handshake.check_request(request.headers) - handshake.build_response(headers, key) - except InvalidHandshake: - raise InvalidUsage("Invalid websocket request") - - subprotocol = None - if subprotocols and "Sec-Websocket-Protocol" in request.headers: - # select a subprotocol - client_subprotocols = [ - p.strip() - for p in request.headers["Sec-Websocket-Protocol"].split(",") - ] - for p in client_subprotocols: - if p in subprotocols: - subprotocol = p - headers["Sec-Websocket-Protocol"] = subprotocol - break - - # write the 101 response back to the client - rv = b"HTTP/1.1 101 Switching Protocols\r\n" - for k, v in headers.items(): - rv += k.encode("utf-8") + b": " + v.encode("utf-8") + b"\r\n" - rv += b"\r\n" - request.transport.write(rv) - - # hook up the websocket protocol - self.websocket = WebSocketCommonProtocol( - close_timeout=self.websocket_timeout, - max_size=self.websocket_max_size, - max_queue=self.websocket_max_queue, - read_limit=self.websocket_read_limit, - write_limit=self.websocket_write_limit, - ping_interval=self.websocket_ping_interval, - ping_timeout=self.websocket_ping_timeout, - ) - # we use WebSocketCommonProtocol because we don't want the handshake - # logic from WebSocketServerProtocol; however, we must tell it that - # we're running on the server side - self.websocket.is_client = False - self.websocket.side = "server" - self.websocket.subprotocol = subprotocol - self.websocket.connection_made(request.transport) - self.websocket.connection_open() - return self.websocket - - -class WebSocketConnection: - - # TODO - # - Implement ping/pong - - def __init__( - self, - send: Callable[[ASIMessage], Awaitable[None]], - receive: Callable[[], Awaitable[ASIMessage]], - subprotocols: Optional[List[str]] = None, - ) -> None: - self._send = send - self._receive = receive - self._subprotocols = subprotocols or [] - - async def send(self, data: Union[str, bytes], *args, **kwargs) -> None: - message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"} - - if isinstance(data, bytes): - message.update({"bytes": data}) - else: - message.update({"text": str(data)}) - - await self._send(message) - - async def recv(self, *args, **kwargs) -> Optional[str]: - message = await self._receive() - - if message["type"] == "websocket.receive": - return message["text"] - elif message["type"] == "websocket.disconnect": - pass - - return None - - receive = recv - - async def accept(self, subprotocols: Optional[List[str]] = None) -> None: - subprotocol = None - if subprotocols: - for subp in subprotocols: - if subp in self.subprotocols: - subprotocol = subp - break - - await self._send( - { - "type": "websocket.accept", - "subprotocol": subprotocol, - } - ) - - async def close(self) -> None: - pass - - @property - def subprotocols(self): - return self._subprotocols - - @subprotocols.setter - def subprotocols(self, subprotocols: Optional[List[str]] = None): - self._subprotocols = subprotocols or [] diff --git a/sanic/worker.py b/sanic/worker.py index 51bee6c2..a3bc29b8 100644 --- a/sanic/worker.py +++ b/sanic/worker.py @@ -9,7 +9,7 @@ from gunicorn.workers import base # type: ignore from sanic.log import logger from sanic.server import HttpProtocol, Signal, serve -from sanic.websocket import WebSocketProtocol +from sanic.server.protocols.websocket_protocol import WebSocketProtocol try: @@ -142,14 +142,11 @@ class GunicornWorker(base.Worker): # Force close non-idle connection after waiting for # graceful_shutdown_timeout - coros = [] for conn in self.connections: if hasattr(conn, "websocket") and conn.websocket: - coros.append(conn.websocket.close_connection()) + conn.websocket.fail_connection(code=1001) else: - conn.close() - _shutdown = asyncio.gather(*coros, loop=self.loop) - await _shutdown + conn.abort() async def _run(self): for sock in self.sockets: diff --git a/setup.py b/setup.py index d65703ac..ebfe85c9 100644 --- a/setup.py +++ b/setup.py @@ -88,7 +88,7 @@ requirements = [ uvloop, ujson, "aiofiles>=0.6.0", - "websockets>=9.0", + "websockets>=10.0", "multidict>=5.0,<6.0", ] diff --git a/tests/test_app.py b/tests/test_app.py index 9598d54f..196f34f5 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -178,9 +178,6 @@ def test_app_enable_websocket(app, websocket_enabled, enable): @patch("sanic.app.WebSocketProtocol") def test_app_websocket_parameters(websocket_protocol_mock, app): app.config.WEBSOCKET_MAX_SIZE = 44 - app.config.WEBSOCKET_MAX_QUEUE = 45 - app.config.WEBSOCKET_READ_LIMIT = 46 - app.config.WEBSOCKET_WRITE_LIMIT = 47 app.config.WEBSOCKET_PING_TIMEOUT = 48 app.config.WEBSOCKET_PING_INTERVAL = 50 @@ -197,11 +194,6 @@ def test_app_websocket_parameters(websocket_protocol_mock, app): websocket_protocol_call_args = websocket_protocol_mock.call_args ws_kwargs = websocket_protocol_call_args[1] assert ws_kwargs["websocket_max_size"] == app.config.WEBSOCKET_MAX_SIZE - assert ws_kwargs["websocket_max_queue"] == app.config.WEBSOCKET_MAX_QUEUE - assert ws_kwargs["websocket_read_limit"] == app.config.WEBSOCKET_READ_LIMIT - assert ( - ws_kwargs["websocket_write_limit"] == app.config.WEBSOCKET_WRITE_LIMIT - ) assert ( ws_kwargs["websocket_ping_timeout"] == app.config.WEBSOCKET_PING_TIMEOUT diff --git a/tests/test_asgi.py b/tests/test_asgi.py index 0745c2ed..3d464a4f 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -10,7 +10,7 @@ from sanic.asgi import MockTransport from sanic.exceptions import Forbidden, InvalidUsage, ServiceUnavailable from sanic.request import Request from sanic.response import json, text -from sanic.websocket import WebSocketConnection +from sanic.server.websockets.connection import WebSocketConnection @pytest.fixture diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 1ccd5547..29797e1e 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -16,6 +16,7 @@ from sanic.exceptions import ( abort, ) from sanic.response import text +from websockets.version import version as websockets_version class SanicExceptionTestException(Exception): @@ -260,9 +261,14 @@ def test_exception_in_ws_logged(caplog): with caplog.at_level(logging.INFO): app.test_client.websocket("/feed") - - assert caplog.record_tuples[1][0] == "sanic.error" - assert caplog.record_tuples[1][1] == logging.ERROR + # Websockets v10.0 and above output an additional + # INFO message when a ws connection is accepted + ws_version_parts = websockets_version.split(".") + ws_major = int(ws_version_parts[0]) + record_index = 2 if ws_major >= 10 else 1 + assert caplog.record_tuples[record_index][0] == "sanic.error" + assert caplog.record_tuples[record_index][1] == logging.ERROR assert ( - "Exception occurred while handling uri:" in caplog.record_tuples[1][2] + "Exception occurred while handling uri:" + in caplog.record_tuples[record_index][2] ) diff --git a/tests/test_routes.py b/tests/test_routes.py index 9af615b5..520ab5be 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -674,16 +674,16 @@ async def test_websocket_route_asgi(app, url): @pytest.mark.parametrize( "subprotocols,expected", ( - (["bar"], "bar"), - (["bar", "foo"], "bar"), - (["baz"], None), + (["one"], "one"), + (["three", "one"], "one"), + (["tree"], None), (None, None), ), ) def test_websocket_route_with_subprotocols(app, subprotocols, expected): - results = "unset" + results = [] - @app.websocket("/ws", subprotocols=["foo", "bar"]) + @app.websocket("/ws", subprotocols=["zero", "one", "two", "three"]) async def handler(request, ws): nonlocal results results = ws.subprotocol diff --git a/tests/test_worker.py b/tests/test_worker.py index 252bdb36..3850b8a6 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -175,7 +175,7 @@ def test_worker_close(worker): worker.wsgi = mock.Mock() conn = mock.Mock() conn.websocket = mock.Mock() - conn.websocket.close_connection = mock.Mock(wraps=_a_noop) + conn.websocket.fail_connection = mock.Mock(wraps=_a_noop) worker.connections = set([conn]) worker.log = mock.Mock() worker.loop = loop @@ -190,5 +190,5 @@ def test_worker_close(worker): loop.run_until_complete(_close) assert worker.signal.stopped - assert conn.websocket.close_connection.called + assert conn.websocket.fail_connection.called assert len(worker.servers) == 0