"""Mobile app websocket API."""
from __future__ import annotations

from functools import wraps

import voluptuous as vol

from homeassistant.components import websocket_api
from homeassistant.core import callback

from .const import CONF_USER_ID, DATA_CONFIG_ENTRIES, DATA_PUSH_CHANNEL, DOMAIN
from .push_notification import PushChannel


@callback
def async_setup_commands(hass):
    """Register the mobile app websocket command handlers on ``hass``."""
    for handler in (
        handle_push_notification_channel,
        handle_push_notification_confirm,
    ):
        websocket_api.async_register_command(hass, handler)


def _ensure_webhook_access(func):
    """Wrap a websocket handler so it only runs for the owner of the webhook.

    Before the handler runs, the wrapper looks up ``msg["webhook_id"]`` among
    the mobile app config entries:

    * no entry for the webhook ID -> ``ERR_NOT_FOUND`` is sent, handler skipped;
    * entry belongs to another user -> ``ERR_UNAUTHORIZED`` is sent, handler
      skipped;
    * otherwise the wrapped handler is called with the original arguments.
    """

    @callback
    @wraps(func)
    def with_webhook_access(hass, connection, msg):
        entries = hass.data[DOMAIN][DATA_CONFIG_ENTRIES]
        entry = entries.get(msg["webhook_id"])

        if entry is None:
            error = (websocket_api.ERR_NOT_FOUND, "Webhook ID not found")
        elif entry.data[CONF_USER_ID] != connection.user.id:
            error = (
                websocket_api.ERR_UNAUTHORIZED,
                "User not linked to this webhook ID",
            )
        else:
            # Ownership confirmed: hand off to the real handler.
            func(hass, connection, msg)
            return

        connection.send_error(msg["id"], *error)

    return with_webhook_access


@callback
@_ensure_webhook_access
@websocket_api.websocket_command(
    {
        vol.Required("type"): "mobile_app/push_notification_confirm",
        vol.Required("webhook_id"): str,
        vol.Required("confirm_id"): str,
    }
)
def handle_push_notification_confirm(hass, connection, msg):
    """Confirm receipt of a push notification.

    Looks up the push channel registered for ``msg["webhook_id"]`` and passes
    ``msg["confirm_id"]`` to its ``async_confirm_notification``.

    Sends an empty result when the confirmation is accepted. Otherwise sends
    an ``ERR_NOT_FOUND`` error whose message says which lookup failed: the
    push channel itself, or the confirmation on an existing channel.
    """
    channel: PushChannel | None = hass.data[DOMAIN][DATA_PUSH_CHANNEL].get(
        msg["webhook_id"]
    )
    if channel is None:
        connection.send_error(
            msg["id"],
            websocket_api.ERR_NOT_FOUND,
            "Push notification channel not found",
        )
        return

    if not channel.async_confirm_notification(msg["confirm_id"]):
        # The channel exists (checked above), but it rejected this confirm ID.
        # Report the confirmation as missing, not the channel.
        connection.send_error(
            msg["id"],
            websocket_api.ERR_NOT_FOUND,
            "Push notification confirmation not found",
        )
        return

    connection.send_result(msg["id"])


@websocket_api.websocket_command(
    {
        vol.Required("type"): "mobile_app/push_notification_channel",
        vol.Required("webhook_id"): str,
        vol.Optional("support_confirm", default=False): bool,
    }
)
@_ensure_webhook_access
@websocket_api.async_response
async def handle_push_notification_channel(hass, connection, msg):
    """Open a direct push notification channel on this websocket connection.

    Any channel already registered for the webhook ID is torn down first, then
    a new ``PushChannel`` is registered in its place. The channel forwards each
    notification to the client as an event message tagged with this command's
    ID. Unsubscribing the command schedules the channel's teardown.
    """
    msg_id = msg["id"]
    webhook_id = msg["webhook_id"]
    channels: dict[str, PushChannel] = hass.data[DOMAIN][DATA_PUSH_CHANNEL]

    if webhook_id in channels:
        # One channel per webhook: close the existing one before replacing it.
        await channels[webhook_id].async_teardown()

    def forward_notification(data):
        """Relay notification ``data`` to the client as an event message."""
        return connection.send_message(
            websocket_api.messages.event_message(msg_id, data)
        )

    @callback
    def on_channel_teardown():
        """Unregister this channel and drop its connection subscription."""
        # A newer channel may have replaced this one already; leave that alone.
        if channels.get(webhook_id) == channel:
            channels.pop(webhook_id)

        # The subscription may already be gone, hence the default.
        connection.subscriptions.pop(msg_id, None)

    channel = PushChannel(
        hass,
        webhook_id,
        msg["support_confirm"],
        forward_notification,
        on_channel_teardown,
    )
    channels[webhook_id] = channel

    def unsubscribe():
        """Schedule teardown of the channel when the client unsubscribes."""
        return hass.async_create_task(channel.async_teardown())

    connection.subscriptions[msg_id] = unsubscribe
    connection.send_result(msg_id)