import asyncio
import codecs
from typing import TYPE_CHECKING, AsyncIterator, List, Optional

from websockets.frames import Frame, Opcode
from websockets.typing import Data

from sanic.exceptions import ServerError

if TYPE_CHECKING:
    from .impl import WebsocketImplProtocol

UTF8Decoder = codecs.getincrementaldecoder("utf-8")


class WebsocketFrameAssembler:
    """
    Assemble a message from frames.
    Code borrowed from aaugustin/websockets project:
    https://github.com/aaugustin/websockets/blob/6eb98dd8fa5b2c896b9f6be7e8d117708da82a39/src/websockets/sync/messages.py
    """

    __slots__ = (
        "protocol",
        "read_mutex",
        "write_mutex",
        "message_complete",
        "message_fetched",
        "get_in_progress",
        "decoder",
        "completed_queue",
        "chunks",
        "chunks_queue",
        "paused",
        "get_id",
        "put_id",
    )
    if TYPE_CHECKING:
        protocol: "WebsocketImplProtocol"
        read_mutex: asyncio.Lock
        write_mutex: asyncio.Lock
        message_complete: asyncio.Event
        message_fetched: asyncio.Event
        completed_queue: asyncio.Queue
        get_in_progress: bool
        decoder: Optional[codecs.IncrementalDecoder]
        # For streaming chunks rather than messages:
        chunks: List[Data]
        chunks_queue: Optional[asyncio.Queue[Optional[Data]]]
        paused: bool

    def __init__(self, protocol) -> None:
        self.protocol = protocol

        self.read_mutex = asyncio.Lock()
        self.write_mutex = asyncio.Lock()

        self.completed_queue = asyncio.Queue(maxsize=1)  # type: asyncio.Queue[Data]

        # put() sets this event to tell get() that a message can be fetched.
        self.message_complete = asyncio.Event()
        # get() sets this event to let put()
        self.message_fetched = asyncio.Event()

        # This flag prevents concurrent calls to get() by user code.
        self.get_in_progress = False

        # Decoder for text frames, None for binary frames.
        self.decoder = None

        # Buffer data from frames belonging to the same message.
        self.chunks = []

        # When switching from "buffering" to "streaming", we use a thread-safe
        # queue for transferring frames from the writing thread (library code)
        # to the reading thread (user code). We're buffering when chunks_queue
        # is None and streaming when it's a Queue. None is a sentinel
        # value marking the end of the stream, superseding message_complete.

        # Stream data from frames belonging to the same message.
        self.chunks_queue = None

        # Flag to indicate we've paused the protocol
        self.paused = False

    async def get(self, timeout: Optional[float] = None) -> Optional[Data]:
        """
        Read the next message.
        :meth:`get` returns a single :class:`str` or :class:`bytes`.
        If the :message was fragmented, :meth:`get` waits until the last frame
        is received, then it reassembles the message.
        If ``timeout`` is set and elapses before a complete message is
        received, :meth:`get` returns ``None``.
        """
        completed: bool
        async with self.read_mutex:
            if timeout is not None and timeout <= 0:
                if not self.message_complete.is_set():
                    return None
            if self.get_in_progress:
                # This should be guarded against with the read_mutex,
                # exception is only here as a failsafe
                raise ServerError(
                    "Called get() on Websocket frame assembler "
                    "while asynchronous get is already in progress."
                )
            self.get_in_progress = True

            # If the message_complete event isn't set yet, release the lock to
            # allow put() to run and eventually set it.
            # Locking with get_in_progress ensures only one task can get here.
            if timeout is None:
                completed = await self.message_complete.wait()
            elif timeout <= 0:
                completed = self.message_complete.is_set()
            else:
                try:
                    await asyncio.wait_for(
                        self.message_complete.wait(), timeout=timeout
                    )
                except asyncio.TimeoutError:
                    ...
                finally:
                    completed = self.message_complete.is_set()

            # Unpause the transport, if its paused
            if self.paused:
                self.protocol.resume_frames()
                self.paused = False
            if not self.get_in_progress:  # no cov
                # This should be guarded against with the read_mutex,
                # exception is here as a failsafe
                raise ServerError(
                    "State of Websocket frame assembler was modified while an "
                    "asynchronous get was in progress."
                )
            self.get_in_progress = False

            # Waiting for a complete message timed out.
            if not completed:
                return None
            if not self.message_complete.is_set():
                return None

            self.message_complete.clear()

            joiner: Data = b"" if self.decoder is None else ""
            # mypy cannot figure out that chunks have the proper type.
            message: Data = joiner.join(self.chunks)  # type: ignore
            if self.message_fetched.is_set():
                # This should be guarded against with the read_mutex,
                # and get_in_progress check, this exception is here
                # as a failsafe
                raise ServerError(
                    "Websocket get() found a message when " "state was already fetched."
                )
            self.message_fetched.set()
            self.chunks = []
            # this should already be None, but set it here for safety
            self.chunks_queue = None
            return message

    async def get_iter(self) -> AsyncIterator[Data]:
        """
        Stream the next message.
        Iterating the return value of :meth:`get_iter` yields a :class:`str`
        or :class:`bytes` for each frame in the message.
        """
        async with self.read_mutex:
            if self.get_in_progress:
                # This should be guarded against with the read_mutex,
                # exception is only here as a failsafe
                raise ServerError(
                    "Called get_iter on Websocket frame assembler "
                    "while asynchronous get is already in progress."
                )
            self.get_in_progress = True

            chunks = self.chunks
            self.chunks = []
            self.chunks_queue = asyncio.Queue()

            # Sending None in chunk_queue supersedes setting message_complete
            # when switching to "streaming". If message is already complete
            # when the switch happens, put() didn't send None, so we have to.
            if self.message_complete.is_set():
                await self.chunks_queue.put(None)

            # Locking with get_in_progress ensures only one task can get here
            for c in chunks:
                yield c
            while True:
                chunk = await self.chunks_queue.get()
                if chunk is None:
                    break
                yield chunk

            # Unpause the transport, if its paused
            if self.paused:
                self.protocol.resume_frames()
                self.paused = False
            if not self.get_in_progress:  # no cov
                # This should be guarded against with the read_mutex,
                # exception is here as a failsafe
                raise ServerError(
                    "State of Websocket frame assembler was modified while an "
                    "asynchronous get was in progress."
                )
            self.get_in_progress = False
            if not self.message_complete.is_set():  # no cov
                # This should be guarded against with the read_mutex,
                # exception is here as a failsafe
                raise ServerError(
                    "Websocket frame assembler chunks queue ended before "
                    "message was complete."
                )
            self.message_complete.clear()
            if self.message_fetched.is_set():  # no cov
                # This should be guarded against with the read_mutex,
                # and get_in_progress check, this exception is
                # here as a failsafe
                raise ServerError(
                    "Websocket get_iter() found a message when state was "
                    "already fetched."
                )

            self.message_fetched.set()
            # this should already be empty, but set it here for safety
            self.chunks = []
            self.chunks_queue = None

    async def put(self, frame: Frame) -> None:
        """
        Add ``frame`` to the next message.
        When ``frame`` is the final frame in a message, :meth:`put` waits
        until the message is fetched, either by calling :meth:`get` or by
        iterating the return value of :meth:`get_iter`.
        :meth:`put` assumes that the stream of frames respects the protocol.
        If it doesn't, the behavior is undefined.
        """

        async with self.write_mutex:
            if frame.opcode is Opcode.TEXT:
                self.decoder = UTF8Decoder(errors="strict")
            elif frame.opcode is Opcode.BINARY:
                self.decoder = None
            elif frame.opcode is Opcode.CONT:
                pass
            else:
                # Ignore control frames.
                return
            data: Data
            if self.decoder is not None:
                data = self.decoder.decode(frame.data, frame.fin)
            else:
                data = frame.data
            if self.chunks_queue is None:
                self.chunks.append(data)
            else:
                await self.chunks_queue.put(data)

            if not frame.fin:
                return
            if not self.get_in_progress:
                # nobody is waiting for this frame, so try to pause subsequent
                # frames at the protocol level
                self.paused = self.protocol.pause_frames()
            # Message is complete. Wait until it's fetched to return.

            if self.chunks_queue is not None:
                await self.chunks_queue.put(None)
            if self.message_complete.is_set():
                # This should be guarded against with the write_mutex
                raise ServerError(
                    "Websocket put() got a new message when a message was "
                    "already in its chamber."
                )
            self.message_complete.set()  # Signal to get() it can serve the
            if self.message_fetched.is_set():
                # This should be guarded against with the write_mutex
                raise ServerError(
                    "Websocket put() got a new message when the previous "
                    "message was not yet fetched."
                )

            # Allow get() to run and eventually set the event.
            await self.message_fetched.wait()
            self.message_fetched.clear()
            self.decoder = None