from typing import TYPE_CHECKING, Optional, Sequence, cast

from websockets.connection import CLOSED, CLOSING, OPEN
from websockets.server import ServerConnection
from websockets.typing import Subprotocol

from sanic.exceptions import ServerError
from sanic.log import logger
from sanic.server import HttpProtocol

from ..websockets.impl import WebsocketImplProtocol


if TYPE_CHECKING:
    from websockets import http11


class WebSocketProtocol(HttpProtocol):
    """HTTP protocol that can be upgraded in place to a WebSocket connection.

    Until :meth:`websocket_handshake` succeeds, transport callbacks are
    delegated to :class:`HttpProtocol`. After the handshake,
    ``self.websocket`` holds a :class:`WebsocketImplProtocol`. Incoming
    data, EOF, connection loss and close requests are then routed to it.
    """

    __slots__ = (
        "websocket",
        "websocket_timeout",
        "websocket_max_size",
        "websocket_ping_interval",
        "websocket_ping_timeout",
    )

    def __init__(
        self,
        *args,
        websocket_timeout: float = 10.0,
        websocket_max_size: Optional[int] = None,
        websocket_ping_interval: Optional[float] = 20.0,
        websocket_ping_timeout: Optional[float] = 20.0,
        **kwargs,
    ):
        """Create the protocol and remember the WebSocket settings.

        :param websocket_timeout: used as ``close_timeout`` for the
            WebSocket implementation once the connection is upgraded.
        :param websocket_max_size: passed as ``max_size`` to
            :class:`ServerConnection`.
        :param websocket_ping_interval: passed as ``ping_interval`` to
            :class:`WebsocketImplProtocol`.
        :param websocket_ping_timeout: passed as ``ping_timeout`` to
            :class:`WebsocketImplProtocol`.

        All other positional and keyword arguments go unchanged to
        :class:`HttpProtocol`.
        """
        super().__init__(*args, **kwargs)
        # Stays None until websocket_handshake() upgrades the connection.
        self.websocket: Optional[WebsocketImplProtocol] = None
        self.websocket_timeout = websocket_timeout
        self.websocket_max_size = websocket_max_size
        self.websocket_ping_interval = websocket_ping_interval
        self.websocket_ping_timeout = websocket_ping_timeout

    def connection_lost(self, exc):
        """Tell the WebSocket layer (if any), then run the HTTP cleanup."""
        if self.websocket is not None:
            self.websocket.connection_lost(exc)
        super().connection_lost(exc)

    def data_received(self, data):
        """Route raw bytes to the WebSocket layer once upgraded."""
        if self.websocket is not None:
            self.websocket.data_received(data)
        else:
            # Pass it to HttpProtocol handler first
            # That will (hopefully) upgrade it to a websocket.
            super().data_received(data)

    def eof_received(self) -> Optional[bool]:
        """Let the WebSocket layer decide how to handle EOF.

        Returns the WebSocket layer's answer when upgraded, otherwise
        False. Per ``asyncio.Protocol``, a false value lets the transport
        close itself.
        """
        if self.websocket is not None:
            return self.websocket.eof_received()
        return False

    def close(self, timeout: Optional[float] = None):
        """Close the connection at the end of the HTTP connection task.

        :param timeout: forwarded to :meth:`HttpProtocol.close` for
            non-upgraded connections. It is ignored once upgraded, because
            the WebSocket layer runs its own closing sequence.
        """
        # Called by HttpProtocol at the end of connection_task
        # If we've upgraded to websocket, we do our own closing
        if self.websocket is not None:
            # Note, we don't want to use websocket.close()
            # That is used for user's application code to send a
            # websocket close packet. This is different.
            self.websocket.end_connection(1001)
        else:
            # Forward the caller's timeout instead of silently dropping it.
            super().close(timeout=timeout)

    def close_if_idle(self):
        """Shut the connection down when the server is stopping.

        Plain HTTP connections defer to :class:`HttpProtocol`. For an
        upgraded connection:

        * return True if the WebSocket is already closing or closed;
        * otherwise start a close with code 1001 (RFC 6455 "Going Away")
          and return None. The close runs as a task on the WebSocket's
          loop when it has one; otherwise the connection is ended directly.
        """
        # Called by Sanic Server when shutting down
        # If we've upgraded to websocket, shut it down
        if self.websocket is not None:
            if self.websocket.connection.state in (CLOSING, CLOSED):
                return True
            elif self.websocket.loop is not None:
                self.websocket.loop.create_task(self.websocket.close(1001))
            else:
                self.websocket.end_connection(1001)
        else:
            return super().close_if_idle()

    async def websocket_handshake(
        self, request, subprotocols: Optional[Sequence[str]] = None
    ):
        """Do the WebSocket opening handshake and upgrade this connection.

        The ``websockets`` package validates *request* and builds the
        handshake response. On success, the raw response is written to the
        transport and a :class:`WebsocketImplProtocol` takes over.

        :param request: the incoming HTTP upgrade request.
        :param subprotocols: subprotocols the server accepts. Any iterable
            works (list, set, frozenset, ...); it is converted to a list.
        :return: the new :class:`WebsocketImplProtocol`, which is also
            stored on ``self.websocket``.
        :raises ServerError: with status 500 if building or running the
            handshake raises. If ``accept`` returns a response outside
            100-299, raises with that response's status and decoded body.
        """
        # let the websockets package do the handshake with the client
        try:
            if subprotocols is not None:
                # subprotocols can be a set or frozenset,
                # but ServerConnection needs a list
                subprotocols = cast(
                    Optional[Sequence[Subprotocol]],
                    [Subprotocol(subprotocol) for subprotocol in subprotocols],
                )
            ws_conn = ServerConnection(
                max_size=self.websocket_max_size,
                subprotocols=subprotocols,
                state=OPEN,
                logger=logger,
            )
            resp: "http11.Response" = ws_conn.accept(request)
        except Exception as exc:
            msg = (
                "Failed to open a WebSocket connection.\n"
                "See server log for more information.\n"
            )
            # Keep the root cause attached as __cause__.
            raise ServerError(msg, status_code=500) from exc

        if not 100 <= resp.status_code <= 299:
            # Rejection response: the body is bytes, so decode it so the
            # exception message is readable text rather than "b'...'".
            message = (
                resp.body.decode("utf-8", errors="replace")
                if resp.body is not None
                else None
            )
            raise ServerError(message, resp.status_code)

        rbody = bytearray(
            f"HTTP/1.1 {resp.status_code} {resp.reason_phrase}\r\n".encode()
        )
        rbody += "".join(
            f"{k}: {v}\r\n" for k, v in resp.headers.items()
        ).encode()
        rbody += b"\r\n"
        if resp.body:
            # Append only a non-empty body, with no trailing CRLFs. Extra
            # bytes after the headers would be read as WebSocket frame data
            # once the connection is upgraded.
            rbody += resp.body
        await super().send(rbody)

        self.websocket = WebsocketImplProtocol(
            ws_conn,
            ping_interval=self.websocket_ping_interval,
            ping_timeout=self.websocket_ping_timeout,
            close_timeout=self.websocket_timeout,
        )
        # Use the transport's event loop when the request exposes one.
        transport = getattr(request, "transport", None)
        loop = getattr(transport, "loop", None)
        await self.websocket.connection_made(self, loop=loop)
        return self.websocket