"""Connection session."""

from __future__ import annotations

from collections.abc import Callable, Hashable
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any, Literal

from aiohttp import web
import voluptuous as vol

from homeassistant.auth.models import RefreshToken, User
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, Unauthorized
from homeassistant.helpers.http import current_request
from homeassistant.util.json import JsonValueType

from . import const, messages
from .util import describe_request

if TYPE_CHECKING:
    from .http import WebSocketAdapter


# Connection associated with the current context, or None when there is none.
# ActiveConnection.__init__ sets it to the new instance and
# ActiveConnection.async_handle_close resets it to None.
current_connection = ContextVar["ActiveConnection | None"](
    "current_connection", default=None
)

# Handler for a JSON command: called with hass, the connection the message
# arrived on, and the message dict (the schema's output when a schema is set).
type MessageHandler = Callable[[HomeAssistant, ActiveConnection, dict[str, Any]], None]
# Handler for a binary message: called with hass, the connection, and the
# payload bytes.
type BinaryHandler = Callable[[HomeAssistant, ActiveConnection, bytes], None]


class ActiveConnection:
    """Handle an active websocket client connection.

    An instance keeps the per-connection state: the user, the subscriptions
    to tear down on close, the id of the last handled message, the features
    the client announced, and the JSON and binary handlers that incoming
    messages are dispatched to.
    """

    __slots__ = (
        "logger",
        "hass",
        "send_message",
        "user",
        "refresh_token_id",
        "subscriptions",
        "last_id",
        "can_coalesce",
        "supported_features",
        "handlers",
        "binary_handlers",
    )

    def __init__(
        self,
        logger: WebSocketAdapter,
        hass: HomeAssistant,
        send_message: Callable[[bytes | str | dict[str, Any]], None],
        user: User,
        refresh_token: RefreshToken,
    ) -> None:
        """Initialize an active connection.

        Stores the collaborators, starts with no subscriptions and no binary
        handlers, and publishes the instance through ``current_connection``.
        """
        self.logger = logger
        self.hass = hass
        self.send_message = send_message
        self.user = user
        # Only the id of the refresh token is kept, not the token object.
        self.refresh_token_id = refresh_token.id
        # Maps a subscription key to the callable that cancels it; every entry
        # is called and removed in async_handle_close.
        self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
        # Id of the last message that reached a handler; new ids must be larger.
        self.last_id = 0
        self.can_coalesce = False
        self.supported_features: dict[str, float] = {}
        # This is the mapping stored in hass.data[const.DOMAIN] itself, not a
        # copy of it.
        self.handlers: dict[str, tuple[MessageHandler, vol.Schema | Literal[False]]] = (
            self.hass.data[const.DOMAIN]
        )
        # Slot i holds the handler with binary handler id i + 1; None marks a
        # free slot.
        self.binary_handlers: list[BinaryHandler | None] = []
        current_connection.set(self)

    def __repr__(self) -> str:
        """Return the representation."""
        return f"<ActiveConnection {self.get_description(None)}>"

    def set_supported_features(self, features: dict[str, float]) -> None:
        """Set supported features.

        Also refreshes ``can_coalesce``, which is true when
        ``const.FEATURE_COALESCE_MESSAGES`` is among the given features.
        """
        self.supported_features = features
        self.can_coalesce = const.FEATURE_COALESCE_MESSAGES in features

    def get_description(self, request: web.Request | None) -> str:
        """Return a description of the connection.

        The description is the user name (empty string when the name is
        falsy), followed by a space and ``describe_request(request)`` when a
        request is given.
        """
        description = self.user.name or ""
        if request:
            description += " " + describe_request(request)
        return description

    def context(self, msg: dict[str, Any]) -> Context:
        """Return a context carrying the id of the connection's user.

        ``msg`` is accepted but not used.
        """
        return Context(user_id=self.user.id)

    @callback
    def async_register_binary_handler(
        self, handler: BinaryHandler
    ) -> tuple[int, Callable[[], None]]:
        """Register a temporary binary handler for this connection.

        Returns a binary handler_id (1 byte) and a callback to unregister the handler.

        Raises RuntimeError when all 255 slots are occupied.
        """
        if len(self.binary_handlers) < 255:
            index = len(self.binary_handlers)
            self.binary_handlers.append(None)
        else:
            # Once the list is full, we search for a None entry to reuse.
            index = None
            for idx, existing in enumerate(self.binary_handlers):
                if existing is None:
                    index = idx
                    break

        if index is None:
            raise RuntimeError("Too many binary handlers registered")

        self.binary_handlers[index] = handler

        @callback
        def unsub() -> None:
            """Unregister the handler."""
            assert index is not None
            # Only clear the slot while it still holds this registration. Once
            # the list is full a freed slot can be handed to another handler,
            # and a late or repeated unsub must not remove that one.
            if self.binary_handlers[index] is handler:
                self.binary_handlers[index] = None

        # Handler ids are 1-based: slot 0 is handler id 1.
        return index + 1, unsub

    @callback
    def send_result(self, msg_id: int, result: Any | None = None) -> None:
        """Send a result message."""
        self.send_message(messages.result_message(msg_id, result))

    @callback
    def send_event(self, msg_id: int, event: Any | None = None) -> None:
        """Send an event message."""
        self.send_message(messages.event_message(msg_id, event))

    @callback
    def send_error(
        self,
        msg_id: int,
        code: str,
        message: str,
        translation_key: str | None = None,
        translation_domain: str | None = None,
        translation_placeholders: dict[str, Any] | None = None,
    ) -> None:
        """Send an error message.

        The translation arguments are passed through to
        ``messages.error_message`` unchanged.
        """
        self.send_message(
            messages.error_message(
                msg_id,
                code,
                message,
                translation_key=translation_key,
                translation_domain=translation_domain,
                translation_placeholders=translation_placeholders,
            )
        )

    @callback
    def async_handle_binary(self, handler_id: int, payload: bytes) -> None:
        """Handle a single incoming binary message.

        Logs an error and returns when no handler is registered under
        ``handler_id``. A handler that raises is logged and unregistered.
        """
        # Handler ids are 1-based; convert to the list index.
        index = handler_id - 1
        if (
            index < 0
            or index >= len(self.binary_handlers)
            or (handler := self.binary_handlers[index]) is None
        ):
            self.logger.error(
                "Received binary message for non-existing handler %s", handler_id
            )
            return

        try:
            handler(self.hass, self, payload)
        except Exception:
            self.logger.exception("Error handling binary message")
            # Drop the failing handler so it is not called again.
            self.binary_handlers[index] = None

    @callback
    def async_handle(self, msg: JsonValueType) -> None:
        """Handle a single incoming message.

        Replies with an error message and returns when the message is
        malformed, reuses an id, or names an unknown command. Otherwise the
        registered handler is called; any exception it (or schema validation)
        raises is passed to async_handle_exception.
        """
        if (
            # Not using isinstance as we don't care about children
            # as these are always coming from JSON
            type(msg) is not dict  # noqa: E721
            or (
                # A missing id and an id of 0 are both rejected here.
                not (cur_id := msg.get("id"))
                or type(cur_id) is not int  # noqa: E721
                or cur_id < 0
                or not (type_ := msg.get("type"))
                or type(type_) is not str  # noqa: E721
            )
        ):
            self.logger.error("Received invalid command: %s", msg)
            # Echo back whatever id was supplied; use 0 when msg is not a dict.
            id_ = msg.get("id") if isinstance(msg, dict) else 0
            self.send_message(
                messages.error_message(
                    id_,  # type: ignore[arg-type]
                    const.ERR_INVALID_FORMAT,
                    "Message incorrectly formatted.",
                )
            )
            return

        if cur_id <= self.last_id:
            self.send_message(
                messages.error_message(
                    cur_id, const.ERR_ID_REUSE, "Identifier values have to increase."
                )
            )
            return

        if not (handler_schema := self.handlers.get(type_)):
            self.logger.info("Received unknown command: %s", type_)
            self.send_message(
                messages.error_message(
                    cur_id, const.ERR_UNKNOWN_COMMAND, "Unknown command."
                )
            )
            return

        handler, schema = handler_schema

        try:
            if schema is False:
                # Without a schema only the "id" and "type" keys validated
                # above are accepted.
                if len(msg) > 2:
                    raise vol.Invalid("extra keys not allowed")  # noqa: TRY301
                handler(self.hass, self, msg)
            else:
                handler(self.hass, self, schema(msg))
        except Exception as err:  # noqa: BLE001
            self.async_handle_exception(msg, err)

        # The id counts as used even when the handler raised.
        self.last_id = cur_id

    @callback
    def async_handle_close(self) -> None:
        """Handle closing down connection.

        Calls every unsubscribe callback, empties the subscriptions, replaces
        ``send_message`` with a stub that only logs, and resets the
        ``current_request`` and ``current_connection`` context variables.
        """
        # Iterate over a snapshot: an unsubscribe callback that removes its own
        # entry from self.subscriptions would otherwise make the loop itself
        # raise RuntimeError, outside the try block, and abort the close.
        for unsub in list(self.subscriptions.values()):
            try:
                unsub()
            except Exception:
                # If one fails, make sure we still try the rest
                self.logger.exception(
                    "Error unsubscribing from subscription: %s", unsub
                )
        self.subscriptions.clear()
        # From now on sending only produces a debug log entry.
        self.send_message = self._connect_closed_error
        current_request.set(None)
        current_connection.set(None)

    @callback
    def _connect_closed_error(
        self, msg: bytes | str | dict[str, Any] | Callable[[], str]
    ) -> None:
        """Log an attempt to send a message after the connection was closed."""
        self.logger.debug("Tried to send message %s on closed connection", msg)

    @callback
    def async_handle_exception(self, msg: dict[str, Any], err: Exception) -> None:
        """Handle an exception while processing a handler.

        Maps the exception to an error code and text, sends the error message
        for ``msg["id"]`` to the client, and logs the failure.
        """
        log_handler = self.logger.error

        code = const.ERR_UNKNOWN_ERROR
        err_message: str | None = None
        translation_domain: str | None = None
        translation_key: str | None = None
        translation_placeholders: dict[str, Any] | None = None

        if isinstance(err, Unauthorized):
            code = const.ERR_UNAUTHORIZED
            err_message = "Unauthorized"
        elif isinstance(err, vol.Invalid):
            code = const.ERR_INVALID_FORMAT
            err_message = vol.humanize.humanize_error(msg, err)
        elif isinstance(err, TimeoutError):
            code = const.ERR_TIMEOUT
            err_message = "Timeout"
        elif isinstance(err, HomeAssistantError):
            err_message = str(err)
            code = const.ERR_HOME_ASSISTANT_ERROR
            translation_domain = err.translation_domain
            translation_key = err.translation_key
            translation_placeholders = err.translation_placeholders

        # This if-check matches all other errors but also matches errors which
        # result in an empty message. In that case we will also log the stack
        # trace so it can be fixed.
        if not err_message:
            err_message = "Unknown error"
            log_handler = self.logger.exception

        self.send_message(
            messages.error_message(
                msg["id"],
                code,
                err_message,
                translation_domain=translation_domain,
                translation_key=translation_key,
                translation_placeholders=translation_placeholders,
            )
        )

        # The text below is only used for the log entry; the client has
        # already been sent the plain err_message above.
        if code:
            err_message += f" ({code})"
        err_message += " " + self.get_description(current_request.get())

        log_handler("Error handling message: %s", err_message)