From f7abf3db1bd4e79cd5121327359fc9021fab7ff3 Mon Sep 17 00:00:00 2001 From: Ashley Sommer Date: Fri, 1 Oct 2021 05:20:57 +1000 Subject: [PATCH] Some fixes to the new Websockets impl (#2248) * First attempt at new Websockets implementation based on websockets >= 9.0, with sans-i/o features. Requires more work. * Update sanic/websocket.py Co-authored-by: Adam Hopkins * Update sanic/websocket.py Co-authored-by: Adam Hopkins * Update sanic/websocket.py Co-authored-by: Adam Hopkins * wip, update websockets code to new Sans/IO API * Refactored new websockets impl into own modules Incorporated other suggestions made by team * Another round of work on the new websockets impl * Added websocket_timeout support (matching previous/legacy support) * Lots more comments * Incorporated suggested changes from previous round of review * Changed RuntimeError usage to ServerError * Changed SanicException usage to ServerError * Removed some redundant asserts * Change remaining asserts to ServerErrors * Fixed some timeout handling issues * Fixed websocket.close() handling, and made it more robust * Made auto_close task smarter and more error-resilient * Made fail_connection routine smarter and more error-resilient * Further new websockets impl fixes * Update compatibility with Websockets v10 * Track server connection state in a more precise way * Try to handle the shutdown process more gracefully * Add a new end_connection() helper, to use as an alterative to close() or fail_connection() * Kill the auto-close task and keepalive-timeout task when sanic is shutdown * Deprecate WEBSOCKET_READ_LIMIT and WEBSOCKET_WRITE_LIMIT configs, they are not used in this implementation. * Change a warning message to debug level Remove default values for deprecated websocket parameters * Fix flake8 errors * Fix a couple of missed failing tests * remove websocket bench from examples * Integrate suggestions from code reviews Use Optional[T] instead of union[T,None] Fix mypy type logic errors change "is not None" to truthy checks where appropriate change "is None" to falsy checks were appropriate Add more debug logging when debug mode is on Change to using sanic.logger for debug logging rather than error_logger. * Fix long line lengths of debug messages Add some new debug messages when websocket IO is paused and unpaused for flow control Fix websocket example to use app.static() * remove unused import in websocket example app * re-run isort after Flake8 fixes * Some fixes to the new Websockets impl Will throw WebsocketClosed exception instead of ServerException now when attempting to read or write to closed websocket, this makes it easier to catch The various ws.recv() methods now have the ability to raise CancelledError into your websocket handler Fix a niche close-socket negotiation bug Fix bug where http protocol thought the websocket never sent any response. Allow data to still send in some cases after websocket enters CLOSING state. Fix some badly formatted and badly placed comments * allow eof_received to send back data too, if the connection is in CLOSING state Co-authored-by: Adam Hopkins Co-authored-by: Adam Hopkins --- sanic/exceptions.py | 5 ++ sanic/http.py | 4 +- sanic/server/protocols/http_protocol.py | 7 +- sanic/server/protocols/websocket_protocol.py | 1 - sanic/server/websockets/frame.py | 13 ++- sanic/server/websockets/impl.py | 89 +++++++++++++++----- 6 files changed, 85 insertions(+), 34 deletions(-) diff --git a/sanic/exceptions.py b/sanic/exceptions.py index 61077cea..1bb06f1d 100644 --- a/sanic/exceptions.py +++ b/sanic/exceptions.py @@ -237,6 +237,11 @@ class InvalidSignal(SanicException): pass +class WebsocketClosed(SanicException): + quiet = True + message = "Client has closed the websocket connection" + + def abort(status_code: int, message: Optional[Union[str, bytes]] = None): """ Raise an exception based on SanicException. Returns the HTTP response diff --git a/sanic/http.py b/sanic/http.py index 3fafd83a..d30e4c82 100644 --- a/sanic/http.py +++ b/sanic/http.py @@ -187,12 +187,12 @@ class Http(metaclass=TouchUpMeta): if self.response: self.response.stream = None - self.init_for_request() - # Exit and disconnect if no more requests can be taken if self.stage is not Stage.IDLE or not self.keep_alive: break + self.init_for_request() + # Wait for the next request if not self.recv_buffer: await self._receive_more() diff --git a/sanic/server/protocols/http_protocol.py b/sanic/server/protocols/http_protocol.py index c609694e..409f5e4b 100644 --- a/sanic/server/protocols/http_protocol.py +++ b/sanic/server/protocols/http_protocol.py @@ -109,7 +109,12 @@ class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta): except Exception: error_logger.exception("protocol.connection_task uncaught") finally: - if self.app.debug and self._http and self.transport: + if ( + self.app.debug + and self._http + and self.transport + and not self._http.upgrade_websocket + ): ip = self.transport.get_extra_info("peername") error_logger.error( "Connection lost before response written" diff --git a/sanic/server/protocols/websocket_protocol.py b/sanic/server/protocols/websocket_protocol.py index 628945d2..457f1cd0 100644 --- a/sanic/server/protocols/websocket_protocol.py +++ b/sanic/server/protocols/websocket_protocol.py @@ -148,7 +148,6 @@ class WebSocketProtocol(HttpProtocol): await super().send(rbody.encode()) else: raise ServerError(resp.body, resp.status_code) - self.websocket = WebsocketImplProtocol( ws_conn, ping_interval=self.websocket_ping_interval, diff --git a/sanic/server/websockets/frame.py b/sanic/server/websockets/frame.py index b4af72b1..fef27db1 100644 --- a/sanic/server/websockets/frame.py +++ b/sanic/server/websockets/frame.py @@ -161,10 +161,8 @@ class WebsocketFrameAssembler: ) self.message_fetched.set() self.chunks = [] - self.chunks_queue = ( - None # this should already be None, but set it here for safety - ) - + # this should already be None, but set it here for safety + self.chunks_queue = None return message async def get_iter(self) -> AsyncIterator[Data]: @@ -193,7 +191,7 @@ class WebsocketFrameAssembler: if self.message_complete.is_set(): await self.chunks_queue.put(None) - # Locking with get_in_progress ensures only one thread can get here + # Locking with get_in_progress ensures only one task can get here for c in chunks: yield c while True: @@ -232,9 +230,8 @@ class WebsocketFrameAssembler: ) self.message_fetched.set() - self.chunks = ( - [] - ) # this should already be empty, but set it here for safety + # this should already be empty, but set it here for safety + self.chunks = [] self.chunks_queue = None async def put(self, frame: Frame) -> None: diff --git a/sanic/server/websockets/impl.py b/sanic/server/websockets/impl.py index a2778c57..ed0d7fed 100644 --- a/sanic/server/websockets/impl.py +++ b/sanic/server/websockets/impl.py @@ -14,14 +14,14 @@ from typing import ( from websockets.connection import CLOSED, CLOSING, OPEN, Event from websockets.exceptions import ConnectionClosed, ConnectionClosedError -from websockets.frames import OP_PONG, Frame +from websockets.frames import Frame, Opcode from websockets.server import ServerConnection from websockets.typing import Data from sanic.log import error_logger, logger from sanic.server.protocols.base_protocol import SanicProtocol -from ...exceptions import ServerError +from ...exceptions import ServerError, WebsocketClosed from .frame import WebsocketFrameAssembler @@ -38,6 +38,7 @@ class WebsocketImplProtocol: pings: Dict[bytes, asyncio.Future] conn_mutex: asyncio.Lock recv_lock: asyncio.Lock + recv_cancel: Optional[asyncio.Future] process_event_mutex: asyncio.Lock can_pause: bool # Optional[asyncio.Future[None]] @@ -69,6 +70,7 @@ class WebsocketImplProtocol: self.pings = {} self.conn_mutex = asyncio.Lock() self.recv_lock = asyncio.Lock() + self.recv_cancel = None self.process_event_mutex = asyncio.Lock() self.data_finished_fut = None self.can_pause = True @@ -180,12 +182,15 @@ class WebsocketImplProtocol: if not isinstance(event, Frame): # Event is not a frame. Ignore it. continue - if event.opcode == OP_PONG: + if event.opcode == Opcode.PONG: await self.process_pong(event) + elif event.opcode == Opcode.CLOSE: + if self.recv_cancel: + self.recv_cancel.cancel() else: await self.assembler.put(event) - async def process_pong(self, frame: "Frame") -> None: + async def process_pong(self, frame: Frame) -> None: if frame.data in self.pings: # Acknowledge all pings up to the one matching this pong. ping_ids = [] @@ -237,7 +242,7 @@ class WebsocketImplProtocol: # It is expected for this task to be cancelled during during # normal operation, when the connection is closed. logger.debug("Websocket keepalive ping task was cancelled.") - except ConnectionClosed: + except (ConnectionClosed, WebsocketClosed): logger.debug("Websocket closed. Keepalive ping task exiting.") except Exception as e: error_logger.warning( @@ -495,6 +500,7 @@ class WebsocketImplProtocol: Set ``timeout`` to ``0`` to check if a message was already received. :raises ~websockets.exceptions.ConnectionClosed: when the connection is closed + :raises asyncio.CancelledError: if the websocket closes while waiting :raises ServerError: if two tasks call :meth:`recv` or :meth:`recv_streaming` concurrently """ @@ -505,13 +511,28 @@ class WebsocketImplProtocol: "already waiting for the next message" ) await self.recv_lock.acquire() - if self.connection.state in (CLOSED, CLOSING): - raise ServerError( + if self.connection.state is CLOSED: + self.recv_lock.release() + raise WebsocketClosed( "Cannot receive from websocket interface after it is closed." ) try: - return await self.assembler.get(timeout) + self.recv_cancel = asyncio.Future() + done, pending = await asyncio.wait( + (self.recv_cancel, self.assembler.get(timeout)), + return_when=asyncio.FIRST_COMPLETED, + ) + done_task = next(iter(done)) + if done_task is self.recv_cancel: + # recv was cancelled + for p in pending: + p.cancel() + raise asyncio.CancelledError() + else: + self.recv_cancel.cancel() + return done_task.result() finally: + self.recv_cancel = None self.recv_lock.release() async def recv_burst(self, max_recv=256) -> Sequence[Data]: @@ -537,8 +558,9 @@ class WebsocketImplProtocol: "for the next message" ) await self.recv_lock.acquire() - if self.connection.state in (CLOSED, CLOSING): - raise ServerError( + if self.connection.state is CLOSED: + self.recv_lock.release() + raise WebsocketClosed( "Cannot receive from websocket interface after it is closed." ) messages = [] @@ -546,8 +568,19 @@ class WebsocketImplProtocol: # Prevent pausing the transport when we're # receiving a burst of messages self.can_pause = False + self.recv_cancel = asyncio.Future() while True: - m = await self.assembler.get(timeout=0) + done, pending = await asyncio.wait( + (self.recv_cancel, self.assembler.get(timeout=0)), + return_when=asyncio.FIRST_COMPLETED, + ) + done_task = next(iter(done)) + if done_task is self.recv_cancel: + # recv_burst was cancelled + for p in pending: + p.cancel() + raise asyncio.CancelledError() + m = done_task.result() if m is None: # None left in the burst. This is good! break @@ -558,7 +591,9 @@ class WebsocketImplProtocol: # Allow an eventloop iteration for the # next message to pass into the Assembler await asyncio.sleep(0) + self.recv_cancel.cancel() finally: + self.recv_cancel = None self.can_pause = True self.recv_lock.release() return messages @@ -578,16 +613,25 @@ class WebsocketImplProtocol: "is already waiting for the next message" ) await self.recv_lock.acquire() - if self.connection.state in (CLOSED, CLOSING): - raise ServerError( + if self.connection.state is CLOSED: + self.recv_lock.release() + raise WebsocketClosed( "Cannot receive from websocket interface after it is closed." ) try: + cancelled = False + self.recv_cancel = asyncio.Future() self.can_pause = False async for m in self.assembler.get_iter(): + if self.recv_cancel.done(): + cancelled = True + break yield m + if cancelled: + raise asyncio.CancelledError() finally: self.can_pause = True + self.recv_cancel = None self.recv_lock.release() async def send(self, message: Union[Data, Iterable[Data]]) -> None: @@ -611,7 +655,7 @@ class WebsocketImplProtocol: async with self.conn_mutex: if self.connection.state in (CLOSED, CLOSING): - raise ServerError( + raise WebsocketClosed( "Cannot write to websocket interface after it is closed." ) if (not self.data_finished_fut) or self.data_finished_fut.done(): @@ -658,7 +702,7 @@ class WebsocketImplProtocol: """ async with self.conn_mutex: if self.connection.state in (CLOSED, CLOSING): - raise ServerError( + raise WebsocketClosed( "Cannot send a ping when the websocket interface " "is closed." ) @@ -728,7 +772,7 @@ class WebsocketImplProtocol: SanicProtocol.close(self.io_proto, timeout=1.0) async def async_data_received(self, data_to_send, events_to_process): - if self.connection.state == OPEN and len(data_to_send) > 0: + if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0: # receiving data can generate data to send (eg, pong for a ping) # send connection.data_to_send() await self.send_data(data_to_send) @@ -747,11 +791,12 @@ class WebsocketImplProtocol: async def async_eof_received(self, data_to_send, events_to_process): # receiving EOF can generate data to send # send connection.data_to_send() - if self.connection.state == OPEN: + if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0: await self.send_data(data_to_send) if len(events_to_process) > 0: await self.process_events(events_to_process) - + if self.recv_cancel: + self.recv_cancel.cancel() if ( self.auto_closer_task and not self.auto_closer_task.done() @@ -760,6 +805,7 @@ class WebsocketImplProtocol: ): # Auto-close the connection self.data_finished_fut.set_result(None) + # Cancel the running handler if its waiting else: # This will fail the connection appropriately SanicProtocol.close(self.io_proto, timeout=1.0) @@ -768,10 +814,9 @@ class WebsocketImplProtocol: self.connection.receive_eof() data_to_send = self.connection.data_to_send() events_to_process = self.connection.events_received() - if len(data_to_send) > 0 or len(events_to_process) > 0: - asyncio.create_task( - self.async_eof_received(data_to_send, events_to_process) - ) + asyncio.create_task( + self.async_eof_received(data_to_send, events_to_process) + ) return False def connection_lost(self, exc):