Some fixes to the new Websockets impl (#2248)
* First attempt at new Websockets implementation based on websockets >= 9.0, with sans-i/o features. Requires more work. * Update sanic/websocket.py Co-authored-by: Adam Hopkins <adam@amhopkins.com> * Update sanic/websocket.py Co-authored-by: Adam Hopkins <adam@amhopkins.com> * Update sanic/websocket.py Co-authored-by: Adam Hopkins <adam@amhopkins.com> * wip, update websockets code to new Sans/IO API * Refactored new websockets impl into own modules Incorporated other suggestions made by team * Another round of work on the new websockets impl * Added websocket_timeout support (matching previous/legacy support) * Lots more comments * Incorporated suggested changes from previous round of review * Changed RuntimeError usage to ServerError * Changed SanicException usage to ServerError * Removed some redundant asserts * Change remaining asserts to ServerErrors * Fixed some timeout handling issues * Fixed websocket.close() handling, and made it more robust * Made auto_close task smarter and more error-resilient * Made fail_connection routine smarter and more error-resilient * Further new websockets impl fixes * Update compatibility with Websockets v10 * Track server connection state in a more precise way * Try to handle the shutdown process more gracefully * Add a new end_connection() helper, to use as an alterative to close() or fail_connection() * Kill the auto-close task and keepalive-timeout task when sanic is shutdown * Deprecate WEBSOCKET_READ_LIMIT and WEBSOCKET_WRITE_LIMIT configs, they are not used in this implementation. * Change a warning message to debug level Remove default values for deprecated websocket parameters * Fix flake8 errors * Fix a couple of missed failing tests * remove websocket bench from examples * Integrate suggestions from code reviews Use Optional[T] instead of union[T,None] Fix mypy type logic errors change "is not None" to truthy checks where appropriate change "is None" to falsy checks were appropriate Add more debug logging when debug mode is on Change to using sanic.logger for debug logging rather than error_logger. * Fix long line lengths of debug messages Add some new debug messages when websocket IO is paused and unpaused for flow control Fix websocket example to use app.static() * remove unused import in websocket example app * re-run isort after Flake8 fixes * Some fixes to the new Websockets impl Will throw WebsocketClosed exception instead of ServerException now when attempting to read or write to closed websocket, this makes it easier to catch The various ws.recv() methods now have the ability to raise CancelledError into your websocket handler Fix a niche close-socket negotiation bug Fix bug where http protocol thought the websocket never sent any response. Allow data to still send in some cases after websocket enters CLOSING state. Fix some badly formatted and badly placed comments * allow eof_received to send back data too, if the connection is in CLOSING state Co-authored-by: Adam Hopkins <adam@amhopkins.com> Co-authored-by: Adam Hopkins <admhpkns@gmail.com>
This commit is contained in:
parent
cf1d2148ac
commit
f7abf3db1b
|
@ -237,6 +237,11 @@ class InvalidSignal(SanicException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class WebsocketClosed(SanicException):
|
||||||
|
quiet = True
|
||||||
|
message = "Client has closed the websocket connection"
|
||||||
|
|
||||||
|
|
||||||
def abort(status_code: int, message: Optional[Union[str, bytes]] = None):
|
def abort(status_code: int, message: Optional[Union[str, bytes]] = None):
|
||||||
"""
|
"""
|
||||||
Raise an exception based on SanicException. Returns the HTTP response
|
Raise an exception based on SanicException. Returns the HTTP response
|
||||||
|
|
|
@ -187,12 +187,12 @@ class Http(metaclass=TouchUpMeta):
|
||||||
if self.response:
|
if self.response:
|
||||||
self.response.stream = None
|
self.response.stream = None
|
||||||
|
|
||||||
self.init_for_request()
|
|
||||||
|
|
||||||
# Exit and disconnect if no more requests can be taken
|
# Exit and disconnect if no more requests can be taken
|
||||||
if self.stage is not Stage.IDLE or not self.keep_alive:
|
if self.stage is not Stage.IDLE or not self.keep_alive:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
self.init_for_request()
|
||||||
|
|
||||||
# Wait for the next request
|
# Wait for the next request
|
||||||
if not self.recv_buffer:
|
if not self.recv_buffer:
|
||||||
await self._receive_more()
|
await self._receive_more()
|
||||||
|
|
|
@ -109,7 +109,12 @@ class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta):
|
||||||
except Exception:
|
except Exception:
|
||||||
error_logger.exception("protocol.connection_task uncaught")
|
error_logger.exception("protocol.connection_task uncaught")
|
||||||
finally:
|
finally:
|
||||||
if self.app.debug and self._http and self.transport:
|
if (
|
||||||
|
self.app.debug
|
||||||
|
and self._http
|
||||||
|
and self.transport
|
||||||
|
and not self._http.upgrade_websocket
|
||||||
|
):
|
||||||
ip = self.transport.get_extra_info("peername")
|
ip = self.transport.get_extra_info("peername")
|
||||||
error_logger.error(
|
error_logger.error(
|
||||||
"Connection lost before response written"
|
"Connection lost before response written"
|
||||||
|
|
|
@ -148,7 +148,6 @@ class WebSocketProtocol(HttpProtocol):
|
||||||
await super().send(rbody.encode())
|
await super().send(rbody.encode())
|
||||||
else:
|
else:
|
||||||
raise ServerError(resp.body, resp.status_code)
|
raise ServerError(resp.body, resp.status_code)
|
||||||
|
|
||||||
self.websocket = WebsocketImplProtocol(
|
self.websocket = WebsocketImplProtocol(
|
||||||
ws_conn,
|
ws_conn,
|
||||||
ping_interval=self.websocket_ping_interval,
|
ping_interval=self.websocket_ping_interval,
|
||||||
|
|
|
@ -161,10 +161,8 @@ class WebsocketFrameAssembler:
|
||||||
)
|
)
|
||||||
self.message_fetched.set()
|
self.message_fetched.set()
|
||||||
self.chunks = []
|
self.chunks = []
|
||||||
self.chunks_queue = (
|
# this should already be None, but set it here for safety
|
||||||
None # this should already be None, but set it here for safety
|
self.chunks_queue = None
|
||||||
)
|
|
||||||
|
|
||||||
return message
|
return message
|
||||||
|
|
||||||
async def get_iter(self) -> AsyncIterator[Data]:
|
async def get_iter(self) -> AsyncIterator[Data]:
|
||||||
|
@ -193,7 +191,7 @@ class WebsocketFrameAssembler:
|
||||||
if self.message_complete.is_set():
|
if self.message_complete.is_set():
|
||||||
await self.chunks_queue.put(None)
|
await self.chunks_queue.put(None)
|
||||||
|
|
||||||
# Locking with get_in_progress ensures only one thread can get here
|
# Locking with get_in_progress ensures only one task can get here
|
||||||
for c in chunks:
|
for c in chunks:
|
||||||
yield c
|
yield c
|
||||||
while True:
|
while True:
|
||||||
|
@ -232,9 +230,8 @@ class WebsocketFrameAssembler:
|
||||||
)
|
)
|
||||||
|
|
||||||
self.message_fetched.set()
|
self.message_fetched.set()
|
||||||
self.chunks = (
|
# this should already be empty, but set it here for safety
|
||||||
[]
|
self.chunks = []
|
||||||
) # this should already be empty, but set it here for safety
|
|
||||||
self.chunks_queue = None
|
self.chunks_queue = None
|
||||||
|
|
||||||
async def put(self, frame: Frame) -> None:
|
async def put(self, frame: Frame) -> None:
|
||||||
|
|
|
@ -14,14 +14,14 @@ from typing import (
|
||||||
|
|
||||||
from websockets.connection import CLOSED, CLOSING, OPEN, Event
|
from websockets.connection import CLOSED, CLOSING, OPEN, Event
|
||||||
from websockets.exceptions import ConnectionClosed, ConnectionClosedError
|
from websockets.exceptions import ConnectionClosed, ConnectionClosedError
|
||||||
from websockets.frames import OP_PONG, Frame
|
from websockets.frames import Frame, Opcode
|
||||||
from websockets.server import ServerConnection
|
from websockets.server import ServerConnection
|
||||||
from websockets.typing import Data
|
from websockets.typing import Data
|
||||||
|
|
||||||
from sanic.log import error_logger, logger
|
from sanic.log import error_logger, logger
|
||||||
from sanic.server.protocols.base_protocol import SanicProtocol
|
from sanic.server.protocols.base_protocol import SanicProtocol
|
||||||
|
|
||||||
from ...exceptions import ServerError
|
from ...exceptions import ServerError, WebsocketClosed
|
||||||
from .frame import WebsocketFrameAssembler
|
from .frame import WebsocketFrameAssembler
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,6 +38,7 @@ class WebsocketImplProtocol:
|
||||||
pings: Dict[bytes, asyncio.Future]
|
pings: Dict[bytes, asyncio.Future]
|
||||||
conn_mutex: asyncio.Lock
|
conn_mutex: asyncio.Lock
|
||||||
recv_lock: asyncio.Lock
|
recv_lock: asyncio.Lock
|
||||||
|
recv_cancel: Optional[asyncio.Future]
|
||||||
process_event_mutex: asyncio.Lock
|
process_event_mutex: asyncio.Lock
|
||||||
can_pause: bool
|
can_pause: bool
|
||||||
# Optional[asyncio.Future[None]]
|
# Optional[asyncio.Future[None]]
|
||||||
|
@ -69,6 +70,7 @@ class WebsocketImplProtocol:
|
||||||
self.pings = {}
|
self.pings = {}
|
||||||
self.conn_mutex = asyncio.Lock()
|
self.conn_mutex = asyncio.Lock()
|
||||||
self.recv_lock = asyncio.Lock()
|
self.recv_lock = asyncio.Lock()
|
||||||
|
self.recv_cancel = None
|
||||||
self.process_event_mutex = asyncio.Lock()
|
self.process_event_mutex = asyncio.Lock()
|
||||||
self.data_finished_fut = None
|
self.data_finished_fut = None
|
||||||
self.can_pause = True
|
self.can_pause = True
|
||||||
|
@ -180,12 +182,15 @@ class WebsocketImplProtocol:
|
||||||
if not isinstance(event, Frame):
|
if not isinstance(event, Frame):
|
||||||
# Event is not a frame. Ignore it.
|
# Event is not a frame. Ignore it.
|
||||||
continue
|
continue
|
||||||
if event.opcode == OP_PONG:
|
if event.opcode == Opcode.PONG:
|
||||||
await self.process_pong(event)
|
await self.process_pong(event)
|
||||||
|
elif event.opcode == Opcode.CLOSE:
|
||||||
|
if self.recv_cancel:
|
||||||
|
self.recv_cancel.cancel()
|
||||||
else:
|
else:
|
||||||
await self.assembler.put(event)
|
await self.assembler.put(event)
|
||||||
|
|
||||||
async def process_pong(self, frame: "Frame") -> None:
|
async def process_pong(self, frame: Frame) -> None:
|
||||||
if frame.data in self.pings:
|
if frame.data in self.pings:
|
||||||
# Acknowledge all pings up to the one matching this pong.
|
# Acknowledge all pings up to the one matching this pong.
|
||||||
ping_ids = []
|
ping_ids = []
|
||||||
|
@ -237,7 +242,7 @@ class WebsocketImplProtocol:
|
||||||
# It is expected for this task to be cancelled during during
|
# It is expected for this task to be cancelled during during
|
||||||
# normal operation, when the connection is closed.
|
# normal operation, when the connection is closed.
|
||||||
logger.debug("Websocket keepalive ping task was cancelled.")
|
logger.debug("Websocket keepalive ping task was cancelled.")
|
||||||
except ConnectionClosed:
|
except (ConnectionClosed, WebsocketClosed):
|
||||||
logger.debug("Websocket closed. Keepalive ping task exiting.")
|
logger.debug("Websocket closed. Keepalive ping task exiting.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_logger.warning(
|
error_logger.warning(
|
||||||
|
@ -495,6 +500,7 @@ class WebsocketImplProtocol:
|
||||||
Set ``timeout`` to ``0`` to check if a message was already received.
|
Set ``timeout`` to ``0`` to check if a message was already received.
|
||||||
:raises ~websockets.exceptions.ConnectionClosed: when the
|
:raises ~websockets.exceptions.ConnectionClosed: when the
|
||||||
connection is closed
|
connection is closed
|
||||||
|
:raises asyncio.CancelledError: if the websocket closes while waiting
|
||||||
:raises ServerError: if two tasks call :meth:`recv` or
|
:raises ServerError: if two tasks call :meth:`recv` or
|
||||||
:meth:`recv_streaming` concurrently
|
:meth:`recv_streaming` concurrently
|
||||||
"""
|
"""
|
||||||
|
@ -505,13 +511,28 @@ class WebsocketImplProtocol:
|
||||||
"already waiting for the next message"
|
"already waiting for the next message"
|
||||||
)
|
)
|
||||||
await self.recv_lock.acquire()
|
await self.recv_lock.acquire()
|
||||||
if self.connection.state in (CLOSED, CLOSING):
|
if self.connection.state is CLOSED:
|
||||||
raise ServerError(
|
self.recv_lock.release()
|
||||||
|
raise WebsocketClosed(
|
||||||
"Cannot receive from websocket interface after it is closed."
|
"Cannot receive from websocket interface after it is closed."
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
return await self.assembler.get(timeout)
|
self.recv_cancel = asyncio.Future()
|
||||||
|
done, pending = await asyncio.wait(
|
||||||
|
(self.recv_cancel, self.assembler.get(timeout)),
|
||||||
|
return_when=asyncio.FIRST_COMPLETED,
|
||||||
|
)
|
||||||
|
done_task = next(iter(done))
|
||||||
|
if done_task is self.recv_cancel:
|
||||||
|
# recv was cancelled
|
||||||
|
for p in pending:
|
||||||
|
p.cancel()
|
||||||
|
raise asyncio.CancelledError()
|
||||||
|
else:
|
||||||
|
self.recv_cancel.cancel()
|
||||||
|
return done_task.result()
|
||||||
finally:
|
finally:
|
||||||
|
self.recv_cancel = None
|
||||||
self.recv_lock.release()
|
self.recv_lock.release()
|
||||||
|
|
||||||
async def recv_burst(self, max_recv=256) -> Sequence[Data]:
|
async def recv_burst(self, max_recv=256) -> Sequence[Data]:
|
||||||
|
@ -537,8 +558,9 @@ class WebsocketImplProtocol:
|
||||||
"for the next message"
|
"for the next message"
|
||||||
)
|
)
|
||||||
await self.recv_lock.acquire()
|
await self.recv_lock.acquire()
|
||||||
if self.connection.state in (CLOSED, CLOSING):
|
if self.connection.state is CLOSED:
|
||||||
raise ServerError(
|
self.recv_lock.release()
|
||||||
|
raise WebsocketClosed(
|
||||||
"Cannot receive from websocket interface after it is closed."
|
"Cannot receive from websocket interface after it is closed."
|
||||||
)
|
)
|
||||||
messages = []
|
messages = []
|
||||||
|
@ -546,8 +568,19 @@ class WebsocketImplProtocol:
|
||||||
# Prevent pausing the transport when we're
|
# Prevent pausing the transport when we're
|
||||||
# receiving a burst of messages
|
# receiving a burst of messages
|
||||||
self.can_pause = False
|
self.can_pause = False
|
||||||
|
self.recv_cancel = asyncio.Future()
|
||||||
while True:
|
while True:
|
||||||
m = await self.assembler.get(timeout=0)
|
done, pending = await asyncio.wait(
|
||||||
|
(self.recv_cancel, self.assembler.get(timeout=0)),
|
||||||
|
return_when=asyncio.FIRST_COMPLETED,
|
||||||
|
)
|
||||||
|
done_task = next(iter(done))
|
||||||
|
if done_task is self.recv_cancel:
|
||||||
|
# recv_burst was cancelled
|
||||||
|
for p in pending:
|
||||||
|
p.cancel()
|
||||||
|
raise asyncio.CancelledError()
|
||||||
|
m = done_task.result()
|
||||||
if m is None:
|
if m is None:
|
||||||
# None left in the burst. This is good!
|
# None left in the burst. This is good!
|
||||||
break
|
break
|
||||||
|
@ -558,7 +591,9 @@ class WebsocketImplProtocol:
|
||||||
# Allow an eventloop iteration for the
|
# Allow an eventloop iteration for the
|
||||||
# next message to pass into the Assembler
|
# next message to pass into the Assembler
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
self.recv_cancel.cancel()
|
||||||
finally:
|
finally:
|
||||||
|
self.recv_cancel = None
|
||||||
self.can_pause = True
|
self.can_pause = True
|
||||||
self.recv_lock.release()
|
self.recv_lock.release()
|
||||||
return messages
|
return messages
|
||||||
|
@ -578,16 +613,25 @@ class WebsocketImplProtocol:
|
||||||
"is already waiting for the next message"
|
"is already waiting for the next message"
|
||||||
)
|
)
|
||||||
await self.recv_lock.acquire()
|
await self.recv_lock.acquire()
|
||||||
if self.connection.state in (CLOSED, CLOSING):
|
if self.connection.state is CLOSED:
|
||||||
raise ServerError(
|
self.recv_lock.release()
|
||||||
|
raise WebsocketClosed(
|
||||||
"Cannot receive from websocket interface after it is closed."
|
"Cannot receive from websocket interface after it is closed."
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
|
cancelled = False
|
||||||
|
self.recv_cancel = asyncio.Future()
|
||||||
self.can_pause = False
|
self.can_pause = False
|
||||||
async for m in self.assembler.get_iter():
|
async for m in self.assembler.get_iter():
|
||||||
|
if self.recv_cancel.done():
|
||||||
|
cancelled = True
|
||||||
|
break
|
||||||
yield m
|
yield m
|
||||||
|
if cancelled:
|
||||||
|
raise asyncio.CancelledError()
|
||||||
finally:
|
finally:
|
||||||
self.can_pause = True
|
self.can_pause = True
|
||||||
|
self.recv_cancel = None
|
||||||
self.recv_lock.release()
|
self.recv_lock.release()
|
||||||
|
|
||||||
async def send(self, message: Union[Data, Iterable[Data]]) -> None:
|
async def send(self, message: Union[Data, Iterable[Data]]) -> None:
|
||||||
|
@ -611,7 +655,7 @@ class WebsocketImplProtocol:
|
||||||
async with self.conn_mutex:
|
async with self.conn_mutex:
|
||||||
|
|
||||||
if self.connection.state in (CLOSED, CLOSING):
|
if self.connection.state in (CLOSED, CLOSING):
|
||||||
raise ServerError(
|
raise WebsocketClosed(
|
||||||
"Cannot write to websocket interface after it is closed."
|
"Cannot write to websocket interface after it is closed."
|
||||||
)
|
)
|
||||||
if (not self.data_finished_fut) or self.data_finished_fut.done():
|
if (not self.data_finished_fut) or self.data_finished_fut.done():
|
||||||
|
@ -658,7 +702,7 @@ class WebsocketImplProtocol:
|
||||||
"""
|
"""
|
||||||
async with self.conn_mutex:
|
async with self.conn_mutex:
|
||||||
if self.connection.state in (CLOSED, CLOSING):
|
if self.connection.state in (CLOSED, CLOSING):
|
||||||
raise ServerError(
|
raise WebsocketClosed(
|
||||||
"Cannot send a ping when the websocket interface "
|
"Cannot send a ping when the websocket interface "
|
||||||
"is closed."
|
"is closed."
|
||||||
)
|
)
|
||||||
|
@ -728,7 +772,7 @@ class WebsocketImplProtocol:
|
||||||
SanicProtocol.close(self.io_proto, timeout=1.0)
|
SanicProtocol.close(self.io_proto, timeout=1.0)
|
||||||
|
|
||||||
async def async_data_received(self, data_to_send, events_to_process):
|
async def async_data_received(self, data_to_send, events_to_process):
|
||||||
if self.connection.state == OPEN and len(data_to_send) > 0:
|
if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0:
|
||||||
# receiving data can generate data to send (eg, pong for a ping)
|
# receiving data can generate data to send (eg, pong for a ping)
|
||||||
# send connection.data_to_send()
|
# send connection.data_to_send()
|
||||||
await self.send_data(data_to_send)
|
await self.send_data(data_to_send)
|
||||||
|
@ -747,11 +791,12 @@ class WebsocketImplProtocol:
|
||||||
async def async_eof_received(self, data_to_send, events_to_process):
|
async def async_eof_received(self, data_to_send, events_to_process):
|
||||||
# receiving EOF can generate data to send
|
# receiving EOF can generate data to send
|
||||||
# send connection.data_to_send()
|
# send connection.data_to_send()
|
||||||
if self.connection.state == OPEN:
|
if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0:
|
||||||
await self.send_data(data_to_send)
|
await self.send_data(data_to_send)
|
||||||
if len(events_to_process) > 0:
|
if len(events_to_process) > 0:
|
||||||
await self.process_events(events_to_process)
|
await self.process_events(events_to_process)
|
||||||
|
if self.recv_cancel:
|
||||||
|
self.recv_cancel.cancel()
|
||||||
if (
|
if (
|
||||||
self.auto_closer_task
|
self.auto_closer_task
|
||||||
and not self.auto_closer_task.done()
|
and not self.auto_closer_task.done()
|
||||||
|
@ -760,6 +805,7 @@ class WebsocketImplProtocol:
|
||||||
):
|
):
|
||||||
# Auto-close the connection
|
# Auto-close the connection
|
||||||
self.data_finished_fut.set_result(None)
|
self.data_finished_fut.set_result(None)
|
||||||
|
# Cancel the running handler if its waiting
|
||||||
else:
|
else:
|
||||||
# This will fail the connection appropriately
|
# This will fail the connection appropriately
|
||||||
SanicProtocol.close(self.io_proto, timeout=1.0)
|
SanicProtocol.close(self.io_proto, timeout=1.0)
|
||||||
|
@ -768,10 +814,9 @@ class WebsocketImplProtocol:
|
||||||
self.connection.receive_eof()
|
self.connection.receive_eof()
|
||||||
data_to_send = self.connection.data_to_send()
|
data_to_send = self.connection.data_to_send()
|
||||||
events_to_process = self.connection.events_received()
|
events_to_process = self.connection.events_received()
|
||||||
if len(data_to_send) > 0 or len(events_to_process) > 0:
|
asyncio.create_task(
|
||||||
asyncio.create_task(
|
self.async_eof_received(data_to_send, events_to_process)
|
||||||
self.async_eof_received(data_to_send, events_to_process)
|
)
|
||||||
)
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def connection_lost(self, exc):
|
def connection_lost(self, exc):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user