diff --git a/sanic/exceptions.py b/sanic/exceptions.py index 61077cea..1bb06f1d 100644 --- a/sanic/exceptions.py +++ b/sanic/exceptions.py @@ -237,6 +237,11 @@ class InvalidSignal(SanicException): pass +class WebsocketClosed(SanicException): + quiet = True + message = "Client has closed the websocket connection" + + def abort(status_code: int, message: Optional[Union[str, bytes]] = None): """ Raise an exception based on SanicException. Returns the HTTP response diff --git a/sanic/http.py b/sanic/http.py index 3fafd83a..d30e4c82 100644 --- a/sanic/http.py +++ b/sanic/http.py @@ -187,12 +187,12 @@ class Http(metaclass=TouchUpMeta): if self.response: self.response.stream = None - self.init_for_request() - # Exit and disconnect if no more requests can be taken if self.stage is not Stage.IDLE or not self.keep_alive: break + self.init_for_request() + # Wait for the next request if not self.recv_buffer: await self._receive_more() diff --git a/sanic/server/protocols/http_protocol.py b/sanic/server/protocols/http_protocol.py index c609694e..409f5e4b 100644 --- a/sanic/server/protocols/http_protocol.py +++ b/sanic/server/protocols/http_protocol.py @@ -109,7 +109,12 @@ class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta): except Exception: error_logger.exception("protocol.connection_task uncaught") finally: - if self.app.debug and self._http and self.transport: + if ( + self.app.debug + and self._http + and self.transport + and not self._http.upgrade_websocket + ): ip = self.transport.get_extra_info("peername") error_logger.error( "Connection lost before response written" diff --git a/sanic/server/protocols/websocket_protocol.py b/sanic/server/protocols/websocket_protocol.py index 628945d2..457f1cd0 100644 --- a/sanic/server/protocols/websocket_protocol.py +++ b/sanic/server/protocols/websocket_protocol.py @@ -148,7 +148,6 @@ class WebSocketProtocol(HttpProtocol): await super().send(rbody.encode()) else: raise ServerError(resp.body, resp.status_code) - self.websocket = WebsocketImplProtocol( ws_conn, ping_interval=self.websocket_ping_interval, diff --git a/sanic/server/websockets/frame.py b/sanic/server/websockets/frame.py index b4af72b1..fef27db1 100644 --- a/sanic/server/websockets/frame.py +++ b/sanic/server/websockets/frame.py @@ -161,10 +161,8 @@ class WebsocketFrameAssembler: ) self.message_fetched.set() self.chunks = [] - self.chunks_queue = ( - None # this should already be None, but set it here for safety - ) - + # this should already be None, but set it here for safety + self.chunks_queue = None return message async def get_iter(self) -> AsyncIterator[Data]: @@ -193,7 +191,7 @@ class WebsocketFrameAssembler: if self.message_complete.is_set(): await self.chunks_queue.put(None) - # Locking with get_in_progress ensures only one thread can get here + # Locking with get_in_progress ensures only one task can get here for c in chunks: yield c while True: @@ -232,9 +230,8 @@ class WebsocketFrameAssembler: ) self.message_fetched.set() - self.chunks = ( - [] - ) # this should already be empty, but set it here for safety + # this should already be empty, but set it here for safety + self.chunks = [] self.chunks_queue = None async def put(self, frame: Frame) -> None: diff --git a/sanic/server/websockets/impl.py b/sanic/server/websockets/impl.py index a2778c57..ed0d7fed 100644 --- a/sanic/server/websockets/impl.py +++ b/sanic/server/websockets/impl.py @@ -14,14 +14,14 @@ from typing import ( from websockets.connection import CLOSED, CLOSING, OPEN, Event from websockets.exceptions import ConnectionClosed, ConnectionClosedError -from websockets.frames import OP_PONG, Frame +from websockets.frames import Frame, Opcode from websockets.server import ServerConnection from websockets.typing import Data from sanic.log import error_logger, logger from sanic.server.protocols.base_protocol import SanicProtocol -from ...exceptions import ServerError +from ...exceptions import ServerError, WebsocketClosed from .frame import WebsocketFrameAssembler @@ -38,6 +38,7 @@ class WebsocketImplProtocol: pings: Dict[bytes, asyncio.Future] conn_mutex: asyncio.Lock recv_lock: asyncio.Lock + recv_cancel: Optional[asyncio.Future] process_event_mutex: asyncio.Lock can_pause: bool # Optional[asyncio.Future[None]] @@ -69,6 +70,7 @@ class WebsocketImplProtocol: self.pings = {} self.conn_mutex = asyncio.Lock() self.recv_lock = asyncio.Lock() + self.recv_cancel = None self.process_event_mutex = asyncio.Lock() self.data_finished_fut = None self.can_pause = True @@ -180,12 +182,15 @@ class WebsocketImplProtocol: if not isinstance(event, Frame): # Event is not a frame. Ignore it. continue - if event.opcode == OP_PONG: + if event.opcode == Opcode.PONG: await self.process_pong(event) + elif event.opcode == Opcode.CLOSE: + if self.recv_cancel: + self.recv_cancel.cancel() else: await self.assembler.put(event) - async def process_pong(self, frame: "Frame") -> None: + async def process_pong(self, frame: Frame) -> None: if frame.data in self.pings: # Acknowledge all pings up to the one matching this pong. ping_ids = [] @@ -237,7 +242,7 @@ class WebsocketImplProtocol: # It is expected for this task to be cancelled during during # normal operation, when the connection is closed. logger.debug("Websocket keepalive ping task was cancelled.") - except ConnectionClosed: + except (ConnectionClosed, WebsocketClosed): logger.debug("Websocket closed. Keepalive ping task exiting.") except Exception as e: error_logger.warning( @@ -495,6 +500,7 @@ class WebsocketImplProtocol: Set ``timeout`` to ``0`` to check if a message was already received. :raises ~websockets.exceptions.ConnectionClosed: when the connection is closed + :raises asyncio.CancelledError: if the websocket closes while waiting :raises ServerError: if two tasks call :meth:`recv` or :meth:`recv_streaming` concurrently """ @@ -505,13 +511,28 @@ class WebsocketImplProtocol: "already waiting for the next message" ) await self.recv_lock.acquire() - if self.connection.state in (CLOSED, CLOSING): - raise ServerError( + if self.connection.state is CLOSED: + self.recv_lock.release() + raise WebsocketClosed( "Cannot receive from websocket interface after it is closed." ) try: - return await self.assembler.get(timeout) + self.recv_cancel = asyncio.Future() + done, pending = await asyncio.wait( + (self.recv_cancel, self.assembler.get(timeout)), + return_when=asyncio.FIRST_COMPLETED, + ) + done_task = next(iter(done)) + if done_task is self.recv_cancel: + # recv was cancelled + for p in pending: + p.cancel() + raise asyncio.CancelledError() + else: + self.recv_cancel.cancel() + return done_task.result() finally: + self.recv_cancel = None self.recv_lock.release() async def recv_burst(self, max_recv=256) -> Sequence[Data]: @@ -537,8 +558,9 @@ class WebsocketImplProtocol: "for the next message" ) await self.recv_lock.acquire() - if self.connection.state in (CLOSED, CLOSING): - raise ServerError( + if self.connection.state is CLOSED: + self.recv_lock.release() + raise WebsocketClosed( "Cannot receive from websocket interface after it is closed." ) messages = [] @@ -546,8 +568,19 @@ class WebsocketImplProtocol: # Prevent pausing the transport when we're # receiving a burst of messages self.can_pause = False + self.recv_cancel = asyncio.Future() while True: - m = await self.assembler.get(timeout=0) + done, pending = await asyncio.wait( + (self.recv_cancel, self.assembler.get(timeout=0)), + return_when=asyncio.FIRST_COMPLETED, + ) + done_task = next(iter(done)) + if done_task is self.recv_cancel: + # recv_burst was cancelled + for p in pending: + p.cancel() + raise asyncio.CancelledError() + m = done_task.result() if m is None: # None left in the burst. This is good! break @@ -558,7 +591,9 @@ class WebsocketImplProtocol: # Allow an eventloop iteration for the # next message to pass into the Assembler await asyncio.sleep(0) + self.recv_cancel.cancel() finally: + self.recv_cancel = None self.can_pause = True self.recv_lock.release() return messages @@ -578,16 +613,25 @@ class WebsocketImplProtocol: "is already waiting for the next message" ) await self.recv_lock.acquire() - if self.connection.state in (CLOSED, CLOSING): - raise ServerError( + if self.connection.state is CLOSED: + self.recv_lock.release() + raise WebsocketClosed( "Cannot receive from websocket interface after it is closed." ) try: + cancelled = False + self.recv_cancel = asyncio.Future() self.can_pause = False async for m in self.assembler.get_iter(): + if self.recv_cancel.done(): + cancelled = True + break yield m + if cancelled: + raise asyncio.CancelledError() finally: self.can_pause = True + self.recv_cancel = None self.recv_lock.release() async def send(self, message: Union[Data, Iterable[Data]]) -> None: @@ -611,7 +655,7 @@ class WebsocketImplProtocol: async with self.conn_mutex: if self.connection.state in (CLOSED, CLOSING): - raise ServerError( + raise WebsocketClosed( "Cannot write to websocket interface after it is closed." ) if (not self.data_finished_fut) or self.data_finished_fut.done(): @@ -658,7 +702,7 @@ class WebsocketImplProtocol: """ async with self.conn_mutex: if self.connection.state in (CLOSED, CLOSING): - raise ServerError( + raise WebsocketClosed( "Cannot send a ping when the websocket interface " "is closed." ) @@ -728,7 +772,7 @@ class WebsocketImplProtocol: SanicProtocol.close(self.io_proto, timeout=1.0) async def async_data_received(self, data_to_send, events_to_process): - if self.connection.state == OPEN and len(data_to_send) > 0: + if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0: # receiving data can generate data to send (eg, pong for a ping) # send connection.data_to_send() await self.send_data(data_to_send) @@ -747,11 +791,12 @@ class WebsocketImplProtocol: async def async_eof_received(self, data_to_send, events_to_process): # receiving EOF can generate data to send # send connection.data_to_send() - if self.connection.state == OPEN: + if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0: await self.send_data(data_to_send) if len(events_to_process) > 0: await self.process_events(events_to_process) - + if self.recv_cancel: + self.recv_cancel.cancel() if ( self.auto_closer_task and not self.auto_closer_task.done() @@ -760,6 +805,7 @@ class WebsocketImplProtocol: ): # Auto-close the connection self.data_finished_fut.set_result(None) + # Cancel the running handler if its waiting else: # This will fail the connection appropriately SanicProtocol.close(self.io_proto, timeout=1.0) @@ -768,10 +814,9 @@ class WebsocketImplProtocol: self.connection.receive_eof() data_to_send = self.connection.data_to_send() events_to_process = self.connection.events_received() - if len(data_to_send) > 0 or len(events_to_process) > 0: - asyncio.create_task( - self.async_eof_received(data_to_send, events_to_process) - ) + asyncio.create_task( + self.async_eof_received(data_to_send, events_to_process) + ) return False def connection_lost(self, exc):