Some fixes to the new Websockets impl (#2248)

* First attempt at new Websockets implementation based on websockets >= 9.0, with sans-i/o features. Requires more work.

* Update sanic/websocket.py

Co-authored-by: Adam Hopkins <adam@amhopkins.com>

* Update sanic/websocket.py

Co-authored-by: Adam Hopkins <adam@amhopkins.com>

* Update sanic/websocket.py

Co-authored-by: Adam Hopkins <adam@amhopkins.com>

* wip, update websockets code to new Sans/IO API

* Refactored new websockets impl into own modules
Incorporated other suggestions made by team

* Another round of work on the new websockets impl
* Added websocket_timeout support (matching previous/legacy support)
* Lots more comments
* Incorporated suggested changes from previous round of review
* Changed RuntimeError usage to ServerError
* Changed SanicException usage to ServerError
* Removed some redundant asserts
* Change remaining asserts to ServerErrors
* Fixed some timeout handling issues
* Fixed websocket.close() handling, and made it more robust
* Made auto_close task smarter and more error-resilient
* Made fail_connection routine smarter and more error-resilient

* Further new websockets impl fixes
* Update compatibility with Websockets v10
* Track server connection state in a more precise way
* Try to handle the shutdown process more gracefully
* Add a new end_connection() helper, to use as an alterative to close() or fail_connection()
* Kill the auto-close task and keepalive-timeout task when sanic is shutdown
* Deprecate WEBSOCKET_READ_LIMIT and WEBSOCKET_WRITE_LIMIT configs, they are not used in this implementation.

* Change a warning message to debug level
Remove default values for deprecated websocket parameters

* Fix flake8 errors

* Fix a couple of missed failing tests

* remove websocket bench from examples

* Integrate suggestions from code reviews
Use Optional[T] instead of union[T,None]
Fix mypy type logic errors
change "is not None" to truthy checks where appropriate
change "is None" to falsy checks were appropriate
Add more debug logging when debug mode is on
Change to using sanic.logger for debug logging rather than error_logger.

* Fix long line lengths of debug messages
Add some new debug messages when websocket IO is paused and unpaused for flow control
Fix websocket example to use app.static()

* remove unused import in websocket example app

* re-run isort after Flake8 fixes

* Some fixes to the new Websockets impl
Will throw WebsocketClosed exception instead of ServerException now when attempting to read or write to closed websocket, this makes it easier to catch
The various ws.recv() methods now have the ability to raise CancelledError into your websocket handler
Fix a niche close-socket negotiation bug
Fix bug where http protocol thought the websocket never sent any response.
Allow data to still send in some cases after websocket enters CLOSING state.
Fix some badly formatted and badly placed comments

* allow eof_received to send back data too, if the connection is in CLOSING state

Co-authored-by: Adam Hopkins <adam@amhopkins.com>
Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
This commit is contained in:
Ashley Sommer 2021-10-01 05:20:57 +10:00 committed by GitHub
parent cf1d2148ac
commit f7abf3db1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 85 additions and 34 deletions

View File

@ -237,6 +237,11 @@ class InvalidSignal(SanicException):
pass pass
class WebsocketClosed(SanicException):
quiet = True
message = "Client has closed the websocket connection"
def abort(status_code: int, message: Optional[Union[str, bytes]] = None): def abort(status_code: int, message: Optional[Union[str, bytes]] = None):
""" """
Raise an exception based on SanicException. Returns the HTTP response Raise an exception based on SanicException. Returns the HTTP response

View File

@ -187,12 +187,12 @@ class Http(metaclass=TouchUpMeta):
if self.response: if self.response:
self.response.stream = None self.response.stream = None
self.init_for_request()
# Exit and disconnect if no more requests can be taken # Exit and disconnect if no more requests can be taken
if self.stage is not Stage.IDLE or not self.keep_alive: if self.stage is not Stage.IDLE or not self.keep_alive:
break break
self.init_for_request()
# Wait for the next request # Wait for the next request
if not self.recv_buffer: if not self.recv_buffer:
await self._receive_more() await self._receive_more()

View File

@ -109,7 +109,12 @@ class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta):
except Exception: except Exception:
error_logger.exception("protocol.connection_task uncaught") error_logger.exception("protocol.connection_task uncaught")
finally: finally:
if self.app.debug and self._http and self.transport: if (
self.app.debug
and self._http
and self.transport
and not self._http.upgrade_websocket
):
ip = self.transport.get_extra_info("peername") ip = self.transport.get_extra_info("peername")
error_logger.error( error_logger.error(
"Connection lost before response written" "Connection lost before response written"

View File

@ -148,7 +148,6 @@ class WebSocketProtocol(HttpProtocol):
await super().send(rbody.encode()) await super().send(rbody.encode())
else: else:
raise ServerError(resp.body, resp.status_code) raise ServerError(resp.body, resp.status_code)
self.websocket = WebsocketImplProtocol( self.websocket = WebsocketImplProtocol(
ws_conn, ws_conn,
ping_interval=self.websocket_ping_interval, ping_interval=self.websocket_ping_interval,

View File

@ -161,10 +161,8 @@ class WebsocketFrameAssembler:
) )
self.message_fetched.set() self.message_fetched.set()
self.chunks = [] self.chunks = []
self.chunks_queue = ( # this should already be None, but set it here for safety
None # this should already be None, but set it here for safety self.chunks_queue = None
)
return message return message
async def get_iter(self) -> AsyncIterator[Data]: async def get_iter(self) -> AsyncIterator[Data]:
@ -193,7 +191,7 @@ class WebsocketFrameAssembler:
if self.message_complete.is_set(): if self.message_complete.is_set():
await self.chunks_queue.put(None) await self.chunks_queue.put(None)
# Locking with get_in_progress ensures only one thread can get here # Locking with get_in_progress ensures only one task can get here
for c in chunks: for c in chunks:
yield c yield c
while True: while True:
@ -232,9 +230,8 @@ class WebsocketFrameAssembler:
) )
self.message_fetched.set() self.message_fetched.set()
self.chunks = ( # this should already be empty, but set it here for safety
[] self.chunks = []
) # this should already be empty, but set it here for safety
self.chunks_queue = None self.chunks_queue = None
async def put(self, frame: Frame) -> None: async def put(self, frame: Frame) -> None:

View File

@ -14,14 +14,14 @@ from typing import (
from websockets.connection import CLOSED, CLOSING, OPEN, Event from websockets.connection import CLOSED, CLOSING, OPEN, Event
from websockets.exceptions import ConnectionClosed, ConnectionClosedError from websockets.exceptions import ConnectionClosed, ConnectionClosedError
from websockets.frames import OP_PONG, Frame from websockets.frames import Frame, Opcode
from websockets.server import ServerConnection from websockets.server import ServerConnection
from websockets.typing import Data from websockets.typing import Data
from sanic.log import error_logger, logger from sanic.log import error_logger, logger
from sanic.server.protocols.base_protocol import SanicProtocol from sanic.server.protocols.base_protocol import SanicProtocol
from ...exceptions import ServerError from ...exceptions import ServerError, WebsocketClosed
from .frame import WebsocketFrameAssembler from .frame import WebsocketFrameAssembler
@ -38,6 +38,7 @@ class WebsocketImplProtocol:
pings: Dict[bytes, asyncio.Future] pings: Dict[bytes, asyncio.Future]
conn_mutex: asyncio.Lock conn_mutex: asyncio.Lock
recv_lock: asyncio.Lock recv_lock: asyncio.Lock
recv_cancel: Optional[asyncio.Future]
process_event_mutex: asyncio.Lock process_event_mutex: asyncio.Lock
can_pause: bool can_pause: bool
# Optional[asyncio.Future[None]] # Optional[asyncio.Future[None]]
@ -69,6 +70,7 @@ class WebsocketImplProtocol:
self.pings = {} self.pings = {}
self.conn_mutex = asyncio.Lock() self.conn_mutex = asyncio.Lock()
self.recv_lock = asyncio.Lock() self.recv_lock = asyncio.Lock()
self.recv_cancel = None
self.process_event_mutex = asyncio.Lock() self.process_event_mutex = asyncio.Lock()
self.data_finished_fut = None self.data_finished_fut = None
self.can_pause = True self.can_pause = True
@ -180,12 +182,15 @@ class WebsocketImplProtocol:
if not isinstance(event, Frame): if not isinstance(event, Frame):
# Event is not a frame. Ignore it. # Event is not a frame. Ignore it.
continue continue
if event.opcode == OP_PONG: if event.opcode == Opcode.PONG:
await self.process_pong(event) await self.process_pong(event)
elif event.opcode == Opcode.CLOSE:
if self.recv_cancel:
self.recv_cancel.cancel()
else: else:
await self.assembler.put(event) await self.assembler.put(event)
async def process_pong(self, frame: "Frame") -> None: async def process_pong(self, frame: Frame) -> None:
if frame.data in self.pings: if frame.data in self.pings:
# Acknowledge all pings up to the one matching this pong. # Acknowledge all pings up to the one matching this pong.
ping_ids = [] ping_ids = []
@ -237,7 +242,7 @@ class WebsocketImplProtocol:
# It is expected for this task to be cancelled during during # It is expected for this task to be cancelled during during
# normal operation, when the connection is closed. # normal operation, when the connection is closed.
logger.debug("Websocket keepalive ping task was cancelled.") logger.debug("Websocket keepalive ping task was cancelled.")
except ConnectionClosed: except (ConnectionClosed, WebsocketClosed):
logger.debug("Websocket closed. Keepalive ping task exiting.") logger.debug("Websocket closed. Keepalive ping task exiting.")
except Exception as e: except Exception as e:
error_logger.warning( error_logger.warning(
@ -495,6 +500,7 @@ class WebsocketImplProtocol:
Set ``timeout`` to ``0`` to check if a message was already received. Set ``timeout`` to ``0`` to check if a message was already received.
:raises ~websockets.exceptions.ConnectionClosed: when the :raises ~websockets.exceptions.ConnectionClosed: when the
connection is closed connection is closed
:raises asyncio.CancelledError: if the websocket closes while waiting
:raises ServerError: if two tasks call :meth:`recv` or :raises ServerError: if two tasks call :meth:`recv` or
:meth:`recv_streaming` concurrently :meth:`recv_streaming` concurrently
""" """
@ -505,13 +511,28 @@ class WebsocketImplProtocol:
"already waiting for the next message" "already waiting for the next message"
) )
await self.recv_lock.acquire() await self.recv_lock.acquire()
if self.connection.state in (CLOSED, CLOSING): if self.connection.state is CLOSED:
raise ServerError( self.recv_lock.release()
raise WebsocketClosed(
"Cannot receive from websocket interface after it is closed." "Cannot receive from websocket interface after it is closed."
) )
try: try:
return await self.assembler.get(timeout) self.recv_cancel = asyncio.Future()
done, pending = await asyncio.wait(
(self.recv_cancel, self.assembler.get(timeout)),
return_when=asyncio.FIRST_COMPLETED,
)
done_task = next(iter(done))
if done_task is self.recv_cancel:
# recv was cancelled
for p in pending:
p.cancel()
raise asyncio.CancelledError()
else:
self.recv_cancel.cancel()
return done_task.result()
finally: finally:
self.recv_cancel = None
self.recv_lock.release() self.recv_lock.release()
async def recv_burst(self, max_recv=256) -> Sequence[Data]: async def recv_burst(self, max_recv=256) -> Sequence[Data]:
@ -537,8 +558,9 @@ class WebsocketImplProtocol:
"for the next message" "for the next message"
) )
await self.recv_lock.acquire() await self.recv_lock.acquire()
if self.connection.state in (CLOSED, CLOSING): if self.connection.state is CLOSED:
raise ServerError( self.recv_lock.release()
raise WebsocketClosed(
"Cannot receive from websocket interface after it is closed." "Cannot receive from websocket interface after it is closed."
) )
messages = [] messages = []
@ -546,8 +568,19 @@ class WebsocketImplProtocol:
# Prevent pausing the transport when we're # Prevent pausing the transport when we're
# receiving a burst of messages # receiving a burst of messages
self.can_pause = False self.can_pause = False
self.recv_cancel = asyncio.Future()
while True: while True:
m = await self.assembler.get(timeout=0) done, pending = await asyncio.wait(
(self.recv_cancel, self.assembler.get(timeout=0)),
return_when=asyncio.FIRST_COMPLETED,
)
done_task = next(iter(done))
if done_task is self.recv_cancel:
# recv_burst was cancelled
for p in pending:
p.cancel()
raise asyncio.CancelledError()
m = done_task.result()
if m is None: if m is None:
# None left in the burst. This is good! # None left in the burst. This is good!
break break
@ -558,7 +591,9 @@ class WebsocketImplProtocol:
# Allow an eventloop iteration for the # Allow an eventloop iteration for the
# next message to pass into the Assembler # next message to pass into the Assembler
await asyncio.sleep(0) await asyncio.sleep(0)
self.recv_cancel.cancel()
finally: finally:
self.recv_cancel = None
self.can_pause = True self.can_pause = True
self.recv_lock.release() self.recv_lock.release()
return messages return messages
@ -578,16 +613,25 @@ class WebsocketImplProtocol:
"is already waiting for the next message" "is already waiting for the next message"
) )
await self.recv_lock.acquire() await self.recv_lock.acquire()
if self.connection.state in (CLOSED, CLOSING): if self.connection.state is CLOSED:
raise ServerError( self.recv_lock.release()
raise WebsocketClosed(
"Cannot receive from websocket interface after it is closed." "Cannot receive from websocket interface after it is closed."
) )
try: try:
cancelled = False
self.recv_cancel = asyncio.Future()
self.can_pause = False self.can_pause = False
async for m in self.assembler.get_iter(): async for m in self.assembler.get_iter():
if self.recv_cancel.done():
cancelled = True
break
yield m yield m
if cancelled:
raise asyncio.CancelledError()
finally: finally:
self.can_pause = True self.can_pause = True
self.recv_cancel = None
self.recv_lock.release() self.recv_lock.release()
async def send(self, message: Union[Data, Iterable[Data]]) -> None: async def send(self, message: Union[Data, Iterable[Data]]) -> None:
@ -611,7 +655,7 @@ class WebsocketImplProtocol:
async with self.conn_mutex: async with self.conn_mutex:
if self.connection.state in (CLOSED, CLOSING): if self.connection.state in (CLOSED, CLOSING):
raise ServerError( raise WebsocketClosed(
"Cannot write to websocket interface after it is closed." "Cannot write to websocket interface after it is closed."
) )
if (not self.data_finished_fut) or self.data_finished_fut.done(): if (not self.data_finished_fut) or self.data_finished_fut.done():
@ -658,7 +702,7 @@ class WebsocketImplProtocol:
""" """
async with self.conn_mutex: async with self.conn_mutex:
if self.connection.state in (CLOSED, CLOSING): if self.connection.state in (CLOSED, CLOSING):
raise ServerError( raise WebsocketClosed(
"Cannot send a ping when the websocket interface " "Cannot send a ping when the websocket interface "
"is closed." "is closed."
) )
@ -728,7 +772,7 @@ class WebsocketImplProtocol:
SanicProtocol.close(self.io_proto, timeout=1.0) SanicProtocol.close(self.io_proto, timeout=1.0)
async def async_data_received(self, data_to_send, events_to_process): async def async_data_received(self, data_to_send, events_to_process):
if self.connection.state == OPEN and len(data_to_send) > 0: if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0:
# receiving data can generate data to send (eg, pong for a ping) # receiving data can generate data to send (eg, pong for a ping)
# send connection.data_to_send() # send connection.data_to_send()
await self.send_data(data_to_send) await self.send_data(data_to_send)
@ -747,11 +791,12 @@ class WebsocketImplProtocol:
async def async_eof_received(self, data_to_send, events_to_process): async def async_eof_received(self, data_to_send, events_to_process):
# receiving EOF can generate data to send # receiving EOF can generate data to send
# send connection.data_to_send() # send connection.data_to_send()
if self.connection.state == OPEN: if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0:
await self.send_data(data_to_send) await self.send_data(data_to_send)
if len(events_to_process) > 0: if len(events_to_process) > 0:
await self.process_events(events_to_process) await self.process_events(events_to_process)
if self.recv_cancel:
self.recv_cancel.cancel()
if ( if (
self.auto_closer_task self.auto_closer_task
and not self.auto_closer_task.done() and not self.auto_closer_task.done()
@ -760,6 +805,7 @@ class WebsocketImplProtocol:
): ):
# Auto-close the connection # Auto-close the connection
self.data_finished_fut.set_result(None) self.data_finished_fut.set_result(None)
# Cancel the running handler if its waiting
else: else:
# This will fail the connection appropriately # This will fail the connection appropriately
SanicProtocol.close(self.io_proto, timeout=1.0) SanicProtocol.close(self.io_proto, timeout=1.0)
@ -768,10 +814,9 @@ class WebsocketImplProtocol:
self.connection.receive_eof() self.connection.receive_eof()
data_to_send = self.connection.data_to_send() data_to_send = self.connection.data_to_send()
events_to_process = self.connection.events_received() events_to_process = self.connection.events_received()
if len(data_to_send) > 0 or len(events_to_process) > 0: asyncio.create_task(
asyncio.create_task( self.async_eof_received(data_to_send, events_to_process)
self.async_eof_received(data_to_send, events_to_process) )
)
return False return False
def connection_lost(self, exc): def connection_lost(self, exc):