Some fixes to the new Websockets impl (#2248)
* First attempt at new Websockets implementation based on websockets >= 9.0, with sans-i/o features. Requires more work. * Update sanic/websocket.py Co-authored-by: Adam Hopkins <adam@amhopkins.com> * Update sanic/websocket.py Co-authored-by: Adam Hopkins <adam@amhopkins.com> * Update sanic/websocket.py Co-authored-by: Adam Hopkins <adam@amhopkins.com> * wip, update websockets code to new Sans/IO API * Refactored new websockets impl into own modules Incorporated other suggestions made by team * Another round of work on the new websockets impl * Added websocket_timeout support (matching previous/legacy support) * Lots more comments * Incorporated suggested changes from previous round of review * Changed RuntimeError usage to ServerError * Changed SanicException usage to ServerError * Removed some redundant asserts * Change remaining asserts to ServerErrors * Fixed some timeout handling issues * Fixed websocket.close() handling, and made it more robust * Made auto_close task smarter and more error-resilient * Made fail_connection routine smarter and more error-resilient * Further new websockets impl fixes * Update compatibility with Websockets v10 * Track server connection state in a more precise way * Try to handle the shutdown process more gracefully * Add a new end_connection() helper, to use as an alterative to close() or fail_connection() * Kill the auto-close task and keepalive-timeout task when sanic is shutdown * Deprecate WEBSOCKET_READ_LIMIT and WEBSOCKET_WRITE_LIMIT configs, they are not used in this implementation. * Change a warning message to debug level Remove default values for deprecated websocket parameters * Fix flake8 errors * Fix a couple of missed failing tests * remove websocket bench from examples * Integrate suggestions from code reviews Use Optional[T] instead of union[T,None] Fix mypy type logic errors change "is not None" to truthy checks where appropriate change "is None" to falsy checks were appropriate Add more debug logging when debug mode is on Change to using sanic.logger for debug logging rather than error_logger. * Fix long line lengths of debug messages Add some new debug messages when websocket IO is paused and unpaused for flow control Fix websocket example to use app.static() * remove unused import in websocket example app * re-run isort after Flake8 fixes * Some fixes to the new Websockets impl Will throw WebsocketClosed exception instead of ServerException now when attempting to read or write to closed websocket, this makes it easier to catch The various ws.recv() methods now have the ability to raise CancelledError into your websocket handler Fix a niche close-socket negotiation bug Fix bug where http protocol thought the websocket never sent any response. Allow data to still send in some cases after websocket enters CLOSING state. Fix some badly formatted and badly placed comments * allow eof_received to send back data too, if the connection is in CLOSING state Co-authored-by: Adam Hopkins <adam@amhopkins.com> Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
This commit is contained in:
		| @@ -237,6 +237,11 @@ class InvalidSignal(SanicException): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class WebsocketClosed(SanicException): | ||||
|     quiet = True | ||||
|     message = "Client has closed the websocket connection" | ||||
|  | ||||
|  | ||||
| def abort(status_code: int, message: Optional[Union[str, bytes]] = None): | ||||
|     """ | ||||
|     Raise an exception based on SanicException. Returns the HTTP response | ||||
|   | ||||
| @@ -187,12 +187,12 @@ class Http(metaclass=TouchUpMeta): | ||||
|                 if self.response: | ||||
|                     self.response.stream = None | ||||
|  | ||||
|             self.init_for_request() | ||||
|  | ||||
|             # Exit and disconnect if no more requests can be taken | ||||
|             if self.stage is not Stage.IDLE or not self.keep_alive: | ||||
|                 break | ||||
|  | ||||
|             self.init_for_request() | ||||
|  | ||||
|             # Wait for the next request | ||||
|             if not self.recv_buffer: | ||||
|                 await self._receive_more() | ||||
|   | ||||
| @@ -109,7 +109,12 @@ class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta): | ||||
|         except Exception: | ||||
|             error_logger.exception("protocol.connection_task uncaught") | ||||
|         finally: | ||||
|             if self.app.debug and self._http and self.transport: | ||||
|             if ( | ||||
|                 self.app.debug | ||||
|                 and self._http | ||||
|                 and self.transport | ||||
|                 and not self._http.upgrade_websocket | ||||
|             ): | ||||
|                 ip = self.transport.get_extra_info("peername") | ||||
|                 error_logger.error( | ||||
|                     "Connection lost before response written" | ||||
|   | ||||
| @@ -148,7 +148,6 @@ class WebSocketProtocol(HttpProtocol): | ||||
|             await super().send(rbody.encode()) | ||||
|         else: | ||||
|             raise ServerError(resp.body, resp.status_code) | ||||
|  | ||||
|         self.websocket = WebsocketImplProtocol( | ||||
|             ws_conn, | ||||
|             ping_interval=self.websocket_ping_interval, | ||||
|   | ||||
| @@ -161,10 +161,8 @@ class WebsocketFrameAssembler: | ||||
|                 ) | ||||
|             self.message_fetched.set() | ||||
|             self.chunks = [] | ||||
|             self.chunks_queue = ( | ||||
|                 None  # this should already be None, but set it here for safety | ||||
|             ) | ||||
|  | ||||
|             # this should already be None, but set it here for safety | ||||
|             self.chunks_queue = None | ||||
|             return message | ||||
|  | ||||
|     async def get_iter(self) -> AsyncIterator[Data]: | ||||
| @@ -193,7 +191,7 @@ class WebsocketFrameAssembler: | ||||
|             if self.message_complete.is_set(): | ||||
|                 await self.chunks_queue.put(None) | ||||
|  | ||||
|             # Locking with get_in_progress ensures only one thread can get here | ||||
|             # Locking with get_in_progress ensures only one task can get here | ||||
|             for c in chunks: | ||||
|                 yield c | ||||
|             while True: | ||||
| @@ -232,9 +230,8 @@ class WebsocketFrameAssembler: | ||||
|                 ) | ||||
|  | ||||
|             self.message_fetched.set() | ||||
|             self.chunks = ( | ||||
|                 [] | ||||
|             )  # this should already be empty, but set it here for safety | ||||
|             # this should already be empty, but set it here for safety | ||||
|             self.chunks = [] | ||||
|             self.chunks_queue = None | ||||
|  | ||||
|     async def put(self, frame: Frame) -> None: | ||||
|   | ||||
| @@ -14,14 +14,14 @@ from typing import ( | ||||
|  | ||||
| from websockets.connection import CLOSED, CLOSING, OPEN, Event | ||||
| from websockets.exceptions import ConnectionClosed, ConnectionClosedError | ||||
| from websockets.frames import OP_PONG, Frame | ||||
| from websockets.frames import Frame, Opcode | ||||
| from websockets.server import ServerConnection | ||||
| from websockets.typing import Data | ||||
|  | ||||
| from sanic.log import error_logger, logger | ||||
| from sanic.server.protocols.base_protocol import SanicProtocol | ||||
|  | ||||
| from ...exceptions import ServerError | ||||
| from ...exceptions import ServerError, WebsocketClosed | ||||
| from .frame import WebsocketFrameAssembler | ||||
|  | ||||
|  | ||||
| @@ -38,6 +38,7 @@ class WebsocketImplProtocol: | ||||
|     pings: Dict[bytes, asyncio.Future] | ||||
|     conn_mutex: asyncio.Lock | ||||
|     recv_lock: asyncio.Lock | ||||
|     recv_cancel: Optional[asyncio.Future] | ||||
|     process_event_mutex: asyncio.Lock | ||||
|     can_pause: bool | ||||
|     # Optional[asyncio.Future[None]] | ||||
| @@ -69,6 +70,7 @@ class WebsocketImplProtocol: | ||||
|         self.pings = {} | ||||
|         self.conn_mutex = asyncio.Lock() | ||||
|         self.recv_lock = asyncio.Lock() | ||||
|         self.recv_cancel = None | ||||
|         self.process_event_mutex = asyncio.Lock() | ||||
|         self.data_finished_fut = None | ||||
|         self.can_pause = True | ||||
| @@ -180,12 +182,15 @@ class WebsocketImplProtocol: | ||||
|                 if not isinstance(event, Frame): | ||||
|                     # Event is not a frame. Ignore it. | ||||
|                     continue | ||||
|                 if event.opcode == OP_PONG: | ||||
|                 if event.opcode == Opcode.PONG: | ||||
|                     await self.process_pong(event) | ||||
|                 elif event.opcode == Opcode.CLOSE: | ||||
|                     if self.recv_cancel: | ||||
|                         self.recv_cancel.cancel() | ||||
|                 else: | ||||
|                     await self.assembler.put(event) | ||||
|  | ||||
|     async def process_pong(self, frame: "Frame") -> None: | ||||
|     async def process_pong(self, frame: Frame) -> None: | ||||
|         if frame.data in self.pings: | ||||
|             # Acknowledge all pings up to the one matching this pong. | ||||
|             ping_ids = [] | ||||
| @@ -237,7 +242,7 @@ class WebsocketImplProtocol: | ||||
|             # It is expected for this task to be cancelled during during | ||||
|             # normal operation, when the connection is closed. | ||||
|             logger.debug("Websocket keepalive ping task was cancelled.") | ||||
|         except ConnectionClosed: | ||||
|         except (ConnectionClosed, WebsocketClosed): | ||||
|             logger.debug("Websocket closed. Keepalive ping task exiting.") | ||||
|         except Exception as e: | ||||
|             error_logger.warning( | ||||
| @@ -495,6 +500,7 @@ class WebsocketImplProtocol: | ||||
|         Set ``timeout`` to ``0`` to check if a message was already received. | ||||
|         :raises ~websockets.exceptions.ConnectionClosed: when the | ||||
|             connection is closed | ||||
|         :raises asyncio.CancelledError: if the websocket closes while waiting | ||||
|         :raises ServerError: if two tasks call :meth:`recv` or | ||||
|             :meth:`recv_streaming` concurrently | ||||
|         """ | ||||
| @@ -505,13 +511,28 @@ class WebsocketImplProtocol: | ||||
|                 "already waiting for the next message" | ||||
|             ) | ||||
|         await self.recv_lock.acquire() | ||||
|         if self.connection.state in (CLOSED, CLOSING): | ||||
|             raise ServerError( | ||||
|         if self.connection.state is CLOSED: | ||||
|             self.recv_lock.release() | ||||
|             raise WebsocketClosed( | ||||
|                 "Cannot receive from websocket interface after it is closed." | ||||
|             ) | ||||
|         try: | ||||
|             return await self.assembler.get(timeout) | ||||
|             self.recv_cancel = asyncio.Future() | ||||
|             done, pending = await asyncio.wait( | ||||
|                 (self.recv_cancel, self.assembler.get(timeout)), | ||||
|                 return_when=asyncio.FIRST_COMPLETED, | ||||
|             ) | ||||
|             done_task = next(iter(done)) | ||||
|             if done_task is self.recv_cancel: | ||||
|                 # recv was cancelled | ||||
|                 for p in pending: | ||||
|                     p.cancel() | ||||
|                 raise asyncio.CancelledError() | ||||
|             else: | ||||
|                 self.recv_cancel.cancel() | ||||
|                 return done_task.result() | ||||
|         finally: | ||||
|             self.recv_cancel = None | ||||
|             self.recv_lock.release() | ||||
|  | ||||
|     async def recv_burst(self, max_recv=256) -> Sequence[Data]: | ||||
| @@ -537,8 +558,9 @@ class WebsocketImplProtocol: | ||||
|                 "for the next message" | ||||
|             ) | ||||
|         await self.recv_lock.acquire() | ||||
|         if self.connection.state in (CLOSED, CLOSING): | ||||
|             raise ServerError( | ||||
|         if self.connection.state is CLOSED: | ||||
|             self.recv_lock.release() | ||||
|             raise WebsocketClosed( | ||||
|                 "Cannot receive from websocket interface after it is closed." | ||||
|             ) | ||||
|         messages = [] | ||||
| @@ -546,8 +568,19 @@ class WebsocketImplProtocol: | ||||
|             # Prevent pausing the transport when we're | ||||
|             # receiving a burst of messages | ||||
|             self.can_pause = False | ||||
|             self.recv_cancel = asyncio.Future() | ||||
|             while True: | ||||
|                 m = await self.assembler.get(timeout=0) | ||||
|                 done, pending = await asyncio.wait( | ||||
|                     (self.recv_cancel, self.assembler.get(timeout=0)), | ||||
|                     return_when=asyncio.FIRST_COMPLETED, | ||||
|                 ) | ||||
|                 done_task = next(iter(done)) | ||||
|                 if done_task is self.recv_cancel: | ||||
|                     # recv_burst was cancelled | ||||
|                     for p in pending: | ||||
|                         p.cancel() | ||||
|                     raise asyncio.CancelledError() | ||||
|                 m = done_task.result() | ||||
|                 if m is None: | ||||
|                     # None left in the burst. This is good! | ||||
|                     break | ||||
| @@ -558,7 +591,9 @@ class WebsocketImplProtocol: | ||||
|                 # Allow an eventloop iteration for the | ||||
|                 # next message to pass into the Assembler | ||||
|                 await asyncio.sleep(0) | ||||
|             self.recv_cancel.cancel() | ||||
|         finally: | ||||
|             self.recv_cancel = None | ||||
|             self.can_pause = True | ||||
|             self.recv_lock.release() | ||||
|         return messages | ||||
| @@ -578,16 +613,25 @@ class WebsocketImplProtocol: | ||||
|                 "is already waiting for the next message" | ||||
|             ) | ||||
|         await self.recv_lock.acquire() | ||||
|         if self.connection.state in (CLOSED, CLOSING): | ||||
|             raise ServerError( | ||||
|         if self.connection.state is CLOSED: | ||||
|             self.recv_lock.release() | ||||
|             raise WebsocketClosed( | ||||
|                 "Cannot receive from websocket interface after it is closed." | ||||
|             ) | ||||
|         try: | ||||
|             cancelled = False | ||||
|             self.recv_cancel = asyncio.Future() | ||||
|             self.can_pause = False | ||||
|             async for m in self.assembler.get_iter(): | ||||
|                 if self.recv_cancel.done(): | ||||
|                     cancelled = True | ||||
|                     break | ||||
|                 yield m | ||||
|             if cancelled: | ||||
|                 raise asyncio.CancelledError() | ||||
|         finally: | ||||
|             self.can_pause = True | ||||
|             self.recv_cancel = None | ||||
|             self.recv_lock.release() | ||||
|  | ||||
|     async def send(self, message: Union[Data, Iterable[Data]]) -> None: | ||||
| @@ -611,7 +655,7 @@ class WebsocketImplProtocol: | ||||
|         async with self.conn_mutex: | ||||
|  | ||||
|             if self.connection.state in (CLOSED, CLOSING): | ||||
|                 raise ServerError( | ||||
|                 raise WebsocketClosed( | ||||
|                     "Cannot write to websocket interface after it is closed." | ||||
|                 ) | ||||
|             if (not self.data_finished_fut) or self.data_finished_fut.done(): | ||||
| @@ -658,7 +702,7 @@ class WebsocketImplProtocol: | ||||
|         """ | ||||
|         async with self.conn_mutex: | ||||
|             if self.connection.state in (CLOSED, CLOSING): | ||||
|                 raise ServerError( | ||||
|                 raise WebsocketClosed( | ||||
|                     "Cannot send a ping when the websocket interface " | ||||
|                     "is closed." | ||||
|                 ) | ||||
| @@ -728,7 +772,7 @@ class WebsocketImplProtocol: | ||||
|                     SanicProtocol.close(self.io_proto, timeout=1.0) | ||||
|  | ||||
|     async def async_data_received(self, data_to_send, events_to_process): | ||||
|         if self.connection.state == OPEN and len(data_to_send) > 0: | ||||
|         if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0: | ||||
|             # receiving data can generate data to send (eg, pong for a ping) | ||||
|             # send connection.data_to_send() | ||||
|             await self.send_data(data_to_send) | ||||
| @@ -747,11 +791,12 @@ class WebsocketImplProtocol: | ||||
|     async def async_eof_received(self, data_to_send, events_to_process): | ||||
|         # receiving EOF can generate data to send | ||||
|         # send connection.data_to_send() | ||||
|         if self.connection.state == OPEN: | ||||
|         if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0: | ||||
|             await self.send_data(data_to_send) | ||||
|         if len(events_to_process) > 0: | ||||
|             await self.process_events(events_to_process) | ||||
|  | ||||
|         if self.recv_cancel: | ||||
|             self.recv_cancel.cancel() | ||||
|         if ( | ||||
|             self.auto_closer_task | ||||
|             and not self.auto_closer_task.done() | ||||
| @@ -760,6 +805,7 @@ class WebsocketImplProtocol: | ||||
|         ): | ||||
|             # Auto-close the connection | ||||
|             self.data_finished_fut.set_result(None) | ||||
|             # Cancel the running handler if its waiting | ||||
|         else: | ||||
|             # This will fail the connection appropriately | ||||
|             SanicProtocol.close(self.io_proto, timeout=1.0) | ||||
| @@ -768,10 +814,9 @@ class WebsocketImplProtocol: | ||||
|         self.connection.receive_eof() | ||||
|         data_to_send = self.connection.data_to_send() | ||||
|         events_to_process = self.connection.events_received() | ||||
|         if len(data_to_send) > 0 or len(events_to_process) > 0: | ||||
|             asyncio.create_task( | ||||
|                 self.async_eof_received(data_to_send, events_to_process) | ||||
|             ) | ||||
|         asyncio.create_task( | ||||
|             self.async_eof_received(data_to_send, events_to_process) | ||||
|         ) | ||||
|         return False | ||||
|  | ||||
|     def connection_lost(self, exc): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Ashley Sommer
					Ashley Sommer