290 lines
11 KiB
Python

import asyncio
import codecs
from typing import TYPE_CHECKING, AsyncIterator, List, Optional
from websockets.frames import Frame, Opcode
from websockets.typing import Data
from sanic.exceptions import ServerError
if TYPE_CHECKING:
from .impl import WebsocketImplProtocol
# Incremental UTF-8 decoder factory. WebsocketFrameAssembler.put() creates a
# fresh decoder from it whenever a TEXT frame starts a new message.
UTF8Decoder = codecs.getincrementaldecoder("utf-8")
class WebsocketFrameAssembler:
    """
    Assemble a message from frames.

    Frames are fed in with :meth:`put`. A complete message is read with
    :meth:`get`, or streamed fragment by fragment with :meth:`get_iter`.
    :meth:`put` does not return from the final frame of a message until
    that message has been fetched by one of the two readers.

    Code borrowed from aaugustin/websockets project:
    https://github.com/aaugustin/websockets/blob/6eb98dd8fa5b2c896b9f6be7e8d117708da82a39/src/websockets/sync/messages.py
    """

    __slots__ = (
        "protocol",
        "read_mutex",
        "write_mutex",
        "message_complete",
        "message_fetched",
        "get_in_progress",
        "decoder",
        "completed_queue",
        "chunks",
        "chunks_queue",
        "paused",
        "get_id",
        "put_id",
    )
    if TYPE_CHECKING:
        protocol: "WebsocketImplProtocol"
        read_mutex: asyncio.Lock
        write_mutex: asyncio.Lock
        message_complete: asyncio.Event
        message_fetched: asyncio.Event
        completed_queue: asyncio.Queue
        get_in_progress: bool
        decoder: Optional[codecs.IncrementalDecoder]
        # For streaming chunks rather than messages:
        chunks: List[Data]
        chunks_queue: Optional[asyncio.Queue[Optional[Data]]]
        paused: bool

    def __init__(self, protocol) -> None:
        """
        Create an idle assembler bound to ``protocol``.

        ``protocol`` must provide ``pause_frames()`` and ``resume_frames()``;
        they are called from :meth:`put` and from the readers respectively.
        """
        self.protocol = protocol
        # Serialise readers (get / get_iter) and writers (put) separately.
        self.read_mutex = asyncio.Lock()
        self.write_mutex = asyncio.Lock()
        self.completed_queue = asyncio.Queue(
            maxsize=1
        )  # type: asyncio.Queue[Data]
        # put() sets this event to tell get() that a message can be fetched.
        self.message_complete = asyncio.Event()
        # get() sets this event to let put() know the message was fetched.
        self.message_fetched = asyncio.Event()
        # This flag prevents concurrent calls to get() by user code.
        self.get_in_progress = False
        # Decoder for text frames, None for binary frames.
        self.decoder = None
        # Buffer data from frames belonging to the same message.
        self.chunks = []
        # When switching from "buffering" to "streaming", we use an asyncio
        # queue for transferring frames from the writing task (library code)
        # to the reading task (user code). We're buffering when chunks_queue
        # is None and streaming when it's a Queue. None is a sentinel
        # value marking the end of the stream, superseding message_complete.
        # Stream data from frames belonging to the same message.
        self.chunks_queue = None
        # Flag to indicate we've paused the protocol
        self.paused = False

    async def get(self, timeout: Optional[float] = None) -> Optional[Data]:
        """
        Read the next message.

        :meth:`get` returns a single :class:`str` or :class:`bytes`.
        If the message was fragmented, :meth:`get` waits until the last frame
        is received, then it reassembles the message.

        If ``timeout`` is set and elapses before a complete message is
        received, :meth:`get` returns ``None``. A ``timeout`` of zero or less
        never waits: it returns the message if one is already complete and
        ``None`` otherwise.

        If the wait is interrupted (for example the calling task is
        cancelled), the in-progress flag is released before the exception
        propagates, so a later call can still read the next message.

        :raises ServerError: if another get is already in progress or the
            assembler state is inconsistent.
        """
        completed: bool
        async with self.read_mutex:
            if timeout is not None and timeout <= 0:
                if not self.message_complete.is_set():
                    return None
            if self.get_in_progress:
                # This should be guarded against with the read_mutex,
                # exception is only here as a failsafe
                raise ServerError(
                    "Called get() on Websocket frame assembler "
                    "while asynchronous get is already in progress."
                )
            self.get_in_progress = True

            # Wait for put() to set message_complete; it runs under a
            # separate mutex so it is not blocked by read_mutex.
            # Locking with get_in_progress ensures only one task can get here.
            try:
                if timeout is None:
                    completed = await self.message_complete.wait()
                elif timeout <= 0:
                    completed = self.message_complete.is_set()
                else:
                    try:
                        await asyncio.wait_for(
                            self.message_complete.wait(), timeout=timeout
                        )
                    except asyncio.TimeoutError:
                        # Not an error: fall through and report whatever
                        # state the event is in now.
                        pass
                    completed = self.message_complete.is_set()
            except BaseException:
                # The wait was cancelled or otherwise interrupted. Without
                # this reset the flag would stay set forever and every later
                # get()/get_iter() would fail the failsafe check above.
                self.get_in_progress = False
                raise

            # Unpause the transport, if it's paused
            if self.paused:
                self.protocol.resume_frames()
                self.paused = False
            if not self.get_in_progress:  # no cov
                # This should be guarded against with the read_mutex,
                # exception is here as a failsafe
                raise ServerError(
                    "State of Websocket frame assembler was modified while an "
                    "asynchronous get was in progress."
                )
            self.get_in_progress = False

            # Waiting for a complete message timed out.
            if not completed:
                return None
            if not self.message_complete.is_set():
                return None

            self.message_complete.clear()

            # Text messages are buffered as str chunks, binary as bytes.
            joiner: Data = b"" if self.decoder is None else ""
            # mypy cannot figure out that chunks have the proper type.
            message: Data = joiner.join(self.chunks)  # type: ignore
            if self.message_fetched.is_set():
                # This should be guarded against with the read_mutex,
                # and get_in_progress check, this exception is here
                # as a failsafe
                raise ServerError(
                    "Websocket get() found a message when "
                    "state was already fetched."
                )
            # Let put() return from the final frame of this message.
            self.message_fetched.set()
            self.chunks = []
            # this should already be None, but set it here for safety
            self.chunks_queue = None
            return message

    async def get_iter(self) -> AsyncIterator[Data]:
        """
        Stream the next message.

        Iterating the return value of :meth:`get_iter` yields a :class:`str`
        or :class:`bytes` for each frame in the message.

        :raises ServerError: if another get is already in progress or the
            assembler state is inconsistent.
        """
        async with self.read_mutex:
            if self.get_in_progress:
                # This should be guarded against with the read_mutex,
                # exception is only here as a failsafe
                raise ServerError(
                    "Called get_iter on Websocket frame assembler "
                    "while asynchronous get is already in progress."
                )
            self.get_in_progress = True

            # Switch from buffering to streaming: take what has been buffered
            # so far and route all further frames through the queue.
            chunks = self.chunks
            self.chunks = []
            self.chunks_queue = asyncio.Queue()

            # Sending None in chunk_queue supersedes setting message_complete
            # when switching to "streaming". If message is already complete
            # when the switch happens, put() didn't send None, so we have to.
            if self.message_complete.is_set():
                await self.chunks_queue.put(None)

            # NOTE(review): if the consumer stops iterating early, or the task
            # is cancelled while waiting on the queue, get_in_progress stays
            # set and chunks_queue is not reset. Confirm callers always drain
            # the iterator.
            # Locking with get_in_progress ensures only one task can get here
            for c in chunks:
                yield c
            while True:
                chunk = await self.chunks_queue.get()
                if chunk is None:
                    # End-of-message sentinel.
                    break
                yield chunk

            # Unpause the transport, if it's paused
            if self.paused:
                self.protocol.resume_frames()
                self.paused = False
            if not self.get_in_progress:  # no cov
                # This should be guarded against with the read_mutex,
                # exception is here as a failsafe
                raise ServerError(
                    "State of Websocket frame assembler was modified while an "
                    "asynchronous get was in progress."
                )
            self.get_in_progress = False
            if not self.message_complete.is_set():  # no cov
                # This should be guarded against with the read_mutex,
                # exception is here as a failsafe
                raise ServerError(
                    "Websocket frame assembler chunks queue ended before "
                    "message was complete."
                )
            self.message_complete.clear()
            if self.message_fetched.is_set():  # no cov
                # This should be guarded against with the read_mutex,
                # and get_in_progress check, this exception is
                # here as a failsafe
                raise ServerError(
                    "Websocket get_iter() found a message when state was "
                    "already fetched."
                )

            # Let put() return from the final frame of this message.
            self.message_fetched.set()
            # this should already be empty, but set it here for safety
            self.chunks = []
            self.chunks_queue = None

    async def put(self, frame: Frame) -> None:
        """
        Add ``frame`` to the next message.

        When ``frame`` is the final frame in a message, :meth:`put` waits
        until the message is fetched, either by calling :meth:`get` or by
        iterating the return value of :meth:`get_iter`.

        Control frames (anything other than TEXT, BINARY and CONT) are
        ignored.

        :meth:`put` assumes that the stream of frames respects the protocol.
        If it doesn't, the behavior is undefined.

        :raises ServerError: if a previous message is still marked complete
            or fetched when a new final frame arrives.
        """
        async with self.write_mutex:
            if frame.opcode is Opcode.TEXT:
                # A new text message starts: decode it incrementally.
                self.decoder = UTF8Decoder(errors="strict")
            elif frame.opcode is Opcode.BINARY:
                self.decoder = None
            elif frame.opcode is Opcode.CONT:
                # Continuation: keep the decoder chosen by the first frame.
                pass
            else:
                # Ignore control frames.
                return

            data: Data
            if self.decoder is not None:
                data = self.decoder.decode(frame.data, frame.fin)
            else:
                data = frame.data

            # Buffer while nobody streams, otherwise hand over to get_iter().
            if self.chunks_queue is None:
                self.chunks.append(data)
            else:
                await self.chunks_queue.put(data)

            if not frame.fin:
                return

            if not self.get_in_progress:
                # nobody is waiting for this frame, so try to pause subsequent
                # frames at the protocol level
                self.paused = self.protocol.pause_frames()

            # Message is complete. Wait until it's fetched to return.
            if self.chunks_queue is not None:
                await self.chunks_queue.put(None)
            if self.message_complete.is_set():
                # This should be guarded against with the write_mutex
                raise ServerError(
                    "Websocket put() got a new message when a message was "
                    "already in its chamber."
                )
            # Signal to get() it can serve the message.
            self.message_complete.set()
            if self.message_fetched.is_set():
                # This should be guarded against with the write_mutex
                raise ServerError(
                    "Websocket put() got a new message when the previous "
                    "message was not yet fetched."
                )

            # Allow get() to run and eventually set the event.
            await self.message_fetched.wait()
            self.message_fetched.clear()
            self.decoder = None