# Listing metadata captured with this file: 290 lines, 11 KiB, Python.
import asyncio
import codecs

from typing import TYPE_CHECKING, AsyncIterator, List, Optional

from websockets.frames import Frame, Opcode
from websockets.typing import Data

from sanic.exceptions import ServerError

# The protocol type is only needed for annotations, so it is imported only
# when a type checker runs (presumably to avoid an import cycle with .impl).
if TYPE_CHECKING:
    from .impl import WebsocketImplProtocol

# Incremental UTF-8 decoder class. A fresh instance per text message lets
# characters whose bytes are split across frames decode correctly.
UTF8Decoder = codecs.lookup("utf-8").incrementaldecoder

class WebsocketFrameAssembler:
    """
    Assemble a message from frames.

    Code borrowed from aaugustin/websockets project:
    https://github.com/aaugustin/websockets/blob/6eb98dd8fa5b2c896b9f6be7e8d117708da82a39/src/websockets/sync/messages.py

    Library code feeds frames in with :meth:`put`. User code reads whole
    messages with :meth:`get` or streams them frame by frame with
    :meth:`get_iter`.

    At most one complete message is held at a time. After the final frame
    of a message, :meth:`put` does not return until that message has been
    fetched.
    """

    __slots__ = (
        "protocol",
        "read_mutex",
        "write_mutex",
        "message_complete",
        "message_fetched",
        "get_in_progress",
        "decoder",
        "completed_queue",
        "chunks",
        "chunks_queue",
        "paused",
        "get_id",
        "put_id",
    )

    if TYPE_CHECKING:
        protocol: "WebsocketImplProtocol"
        read_mutex: asyncio.Lock
        write_mutex: asyncio.Lock
        message_complete: asyncio.Event
        message_fetched: asyncio.Event
        completed_queue: asyncio.Queue
        get_in_progress: bool
        decoder: Optional[codecs.IncrementalDecoder]
        # For streaming chunks rather than messages:
        chunks: List[Data]
        chunks_queue: Optional[asyncio.Queue[Optional[Data]]]
        paused: bool

    def __init__(self, protocol) -> None:
        self.protocol = protocol

        # read_mutex serializes readers (get/get_iter).
        # write_mutex serializes the writer (put).
        self.read_mutex = asyncio.Lock()
        self.write_mutex = asyncio.Lock()

        # NOTE(review): not referenced anywhere in this class; kept so the
        # attribute stays available to any external code that touches it.
        self.completed_queue = asyncio.Queue(
            maxsize=1
        )  # type: asyncio.Queue[Data]

        # put() sets this event to tell get() that a message can be fetched.
        self.message_complete = asyncio.Event()
        # get()/get_iter() set this event to tell put() that the message was
        # fetched, so put() can return and accept the next message.
        self.message_fetched = asyncio.Event()

        # This flag prevents concurrent calls to get() by user code.
        self.get_in_progress = False

        # Decoder for text frames, None for binary frames.
        self.decoder = None

        # Buffer data from frames belonging to the same message.
        self.chunks = []

        # When switching from "buffering" to "streaming", an asyncio.Queue
        # transfers frames from the writer (library code, put()) to the
        # reader (user code, get_iter()). We're buffering when chunks_queue
        # is None and streaming when it's a Queue. None is a sentinel value
        # marking the end of the stream, superseding message_complete.

        # Stream data from frames belonging to the same message.
        self.chunks_queue = None

        # Flag to indicate we've paused the protocol
        self.paused = False

    async def get(self, timeout: Optional[float] = None) -> Optional[Data]:
        """
        Read the next message.

        :meth:`get` returns a single :class:`str` or :class:`bytes`.

        If the message was fragmented, :meth:`get` waits until the last frame
        is received, then it reassembles the message.

        If ``timeout`` is set and elapses before a complete message is
        received, :meth:`get` returns ``None``. A ``timeout`` of zero or less
        only checks whether a message is already complete.

        If the call is cancelled while waiting, the in-progress flag is
        cleared before the cancellation propagates. Later calls to
        :meth:`get` or :meth:`get_iter` are therefore not rejected.

        :raises ServerError: if the assembler's internal state is
            inconsistent (failsafe checks that should never trigger).
        """
        completed: bool
        async with self.read_mutex:
            if timeout is not None and timeout <= 0:
                if not self.message_complete.is_set():
                    return None
            if self.get_in_progress:
                # This should be guarded against with the read_mutex,
                # exception is only here as a failsafe
                raise ServerError(
                    "Called get() on Websocket frame assembler "
                    "while asynchronous get is already in progress."
                )
            self.get_in_progress = True

            # Awaiting message_complete yields control. put() only takes
            # write_mutex (not read_mutex), so it can run meanwhile and
            # eventually set the event.
            # Locking with get_in_progress ensures only one task can get here.
            try:
                if timeout is None:
                    completed = await self.message_complete.wait()
                elif timeout <= 0:
                    completed = self.message_complete.is_set()
                else:
                    try:
                        await asyncio.wait_for(
                            self.message_complete.wait(), timeout=timeout
                        )
                    except asyncio.TimeoutError:
                        pass
                    # Re-check the event rather than trusting the timeout:
                    # the message may have completed as the timeout expired.
                    completed = self.message_complete.is_set()
            except BaseException:
                # Cancelled (or otherwise interrupted) while waiting. Clear
                # the flag so later get()/get_iter() calls don't fail the
                # in-progress check above, then let the exception propagate.
                self.get_in_progress = False
                raise

            # Unpause the transport, if it's paused
            if self.paused:
                self.protocol.resume_frames()
                self.paused = False
            if not self.get_in_progress:  # no cov
                # This should be guarded against with the read_mutex,
                # exception is here as a failsafe
                raise ServerError(
                    "State of Websocket frame assembler was modified while an "
                    "asynchronous get was in progress."
                )
            self.get_in_progress = False

            # Waiting for a complete message timed out.
            if not completed:
                return None
            if not self.message_complete.is_set():
                return None

            self.message_complete.clear()

            # put() decoded text frames to str; binary frames stay bytes.
            joiner: Data = b"" if self.decoder is None else ""
            # mypy cannot figure out that chunks have the proper type.
            message: Data = joiner.join(self.chunks)  # type: ignore
            if self.message_fetched.is_set():
                # This should be guarded against with the read_mutex,
                # and get_in_progress check, this exception is here
                # as a failsafe
                raise ServerError(
                    "Websocket get() found a message when "
                    "state was already fetched."
                )
            # Let put() return so the next message can be assembled.
            self.message_fetched.set()
            self.chunks = []
            # this should already be None, but set it here for safety
            self.chunks_queue = None
            return message

    async def get_iter(self) -> AsyncIterator[Data]:
        """
        Stream the next message.

        Iterating the return value of :meth:`get_iter` yields a :class:`str`
        or :class:`bytes` for each frame in the message.

        ``read_mutex`` is held for the whole iteration, so other readers wait
        until the message has been fully consumed.

        NOTE(review): suppose the caller stops iterating before the end of
        the message. The clean-up after the loop never runs, so
        ``get_in_progress`` stays ``True``. Later calls to :meth:`get` or
        :meth:`get_iter` then raise :class:`ServerError`.

        :raises ServerError: if the assembler's internal state is
            inconsistent (failsafe checks that should never trigger).
        """
        async with self.read_mutex:
            if self.get_in_progress:
                # This should be guarded against with the read_mutex,
                # exception is only here as a failsafe
                raise ServerError(
                    "Called get_iter on Websocket frame assembler "
                    "while asynchronous get is already in progress."
                )
            self.get_in_progress = True

            # Switch from "buffering" to "streaming". Hand over whatever has
            # been buffered so far; put() routes later frames to the queue.
            chunks = self.chunks
            self.chunks = []
            self.chunks_queue = asyncio.Queue()

            # Sending None in chunk_queue supersedes setting message_complete
            # when switching to "streaming". If message is already complete
            # when the switch happens, put() didn't send None, so we have to.
            if self.message_complete.is_set():
                await self.chunks_queue.put(None)

            # Locking with get_in_progress ensures only one task can get here
            for c in chunks:
                yield c
            while True:
                chunk = await self.chunks_queue.get()
                if chunk is None:
                    break
                yield chunk

            # Unpause the transport, if it's paused
            if self.paused:
                self.protocol.resume_frames()
                self.paused = False
            if not self.get_in_progress:  # no cov
                # This should be guarded against with the read_mutex,
                # exception is here as a failsafe
                raise ServerError(
                    "State of Websocket frame assembler was modified while an "
                    "asynchronous get was in progress."
                )
            self.get_in_progress = False
            if not self.message_complete.is_set():  # no cov
                # This should be guarded against with the read_mutex,
                # exception is here as a failsafe
                raise ServerError(
                    "Websocket frame assembler chunks queue ended before "
                    "message was complete."
                )
            self.message_complete.clear()
            if self.message_fetched.is_set():  # no cov
                # This should be guarded against with the read_mutex,
                # and get_in_progress check, this exception is
                # here as a failsafe
                raise ServerError(
                    "Websocket get_iter() found a message when state was "
                    "already fetched."
                )

            # Let put() return so the next message can be assembled.
            self.message_fetched.set()
            # this should already be empty, but set it here for safety
            self.chunks = []
            self.chunks_queue = None

    async def put(self, frame: Frame) -> None:
        """
        Add ``frame`` to the next message.

        When ``frame`` is the final frame in a message, :meth:`put` waits
        until the message is fetched, either by calling :meth:`get` or by
        iterating the return value of :meth:`get_iter`.

        :meth:`put` assumes that the stream of frames respects the protocol.
        If it doesn't, the behavior is undefined.

        Frames whose opcode is not TEXT, BINARY or CONT (i.e. control frames)
        are ignored.

        :raises ServerError: if the assembler's internal state is
            inconsistent (failsafe checks that should never trigger).
        """
        async with self.write_mutex:
            if frame.opcode is Opcode.TEXT:
                # A new text message: decode incrementally so characters
                # whose bytes span several frames are handled correctly.
                self.decoder = UTF8Decoder(errors="strict")
            elif frame.opcode is Opcode.BINARY:
                self.decoder = None
            elif frame.opcode is Opcode.CONT:
                # Continuation: keep the decoder chosen by the first frame.
                pass
            else:
                # Ignore control frames.
                return

            data: Data
            if self.decoder is not None:
                # final=frame.fin makes the strict decoder raise if the
                # message ends in the middle of a character.
                data = self.decoder.decode(frame.data, frame.fin)
            else:
                data = frame.data

            # Buffer the data, or stream it if get_iter() is consuming.
            if self.chunks_queue is None:
                self.chunks.append(data)
            else:
                await self.chunks_queue.put(data)

            if not frame.fin:
                return

            # Message is complete.
            if not self.get_in_progress:
                # nobody is waiting for this frame, so try to pause subsequent
                # frames at the protocol level
                self.paused = self.protocol.pause_frames()

            if self.chunks_queue is not None:
                # End-of-stream sentinel for get_iter().
                await self.chunks_queue.put(None)
            if self.message_complete.is_set():
                # This should be guarded against with the write_mutex
                raise ServerError(
                    "Websocket put() got a new message when a message was "
                    "already in its chamber."
                )
            # Signal to get() that it can serve the message.
            self.message_complete.set()
            if self.message_fetched.is_set():
                # This should be guarded against with the write_mutex
                raise ServerError(
                    "Websocket put() got a new message when the previous "
                    "message was not yet fetched."
                )

            # Wait until get()/get_iter() has fetched the message before
            # returning. write_mutex stays held, so further puts wait too.
            await self.message_fetched.wait()
            self.message_fetched.clear()
            self.decoder = None