"""Test Websocket API http module."""
import asyncio
from datetime import timedelta
from typing import Any, cast
from unittest.mock import patch

from aiohttp import ServerDisconnectedError, WSMsgType, web
import pytest

from homeassistant.components.websocket_api import (
    async_register_command,
    const,
    http,
    websocket_command,
)
from homeassistant.components.websocket_api.connection import ActiveConnection
from homeassistant.core import HomeAssistant, callback
from homeassistant.util.dt import utcnow

from tests.common import async_fire_time_changed
from tests.typing import (
    MockHAClientWebSocket,
    WebSocketGenerator,
)


@pytest.fixture
def mock_low_queue():
    """Patch the websocket http module's MAX_PENDING_MSG to 1 for a test."""
    target = "homeassistant.components.websocket_api.http.MAX_PENDING_MSG"
    with patch(target, 1):
        # Hand control to the test; the patch is undone once it finishes.
        yield


@pytest.fixture
def mock_low_peak():
    """Patch the websocket http module's PENDING_MSG_PEAK to 5 for a test."""
    target = "homeassistant.components.websocket_api.http.PENDING_MSG_PEAK"
    with patch(target, 5):
        # Hand control to the test; the patch is undone once it finishes.
        yield


async def test_pending_msg_overflow(
    hass: HomeAssistant, mock_low_queue, websocket_client: MockHAClientWebSocket
) -> None:
    """Test that a burst of pings overflows the queue and closes the socket."""
    # mock_low_queue patches MAX_PENDING_MSG to 1, so ten pings sent
    # back-to-back are expected to exceed the pending message limit.
    for msg_id in range(1, 11):
        await websocket_client.send_json({"id": msg_id, "type": "ping"})

    closing_frame = await websocket_client.receive()
    assert closing_frame.type == WSMsgType.close


async def test_cleanup_on_cancellation(
    hass: HomeAssistant, websocket_client: MockHAClientWebSocket
) -> None:
    """Test cleanup on cancellation.

    Registers three throwaway commands: two that store a callable in the
    connection's subscription mapping (one of which raises ``ValueError``
    when called) and one whose handler raises ``asyncio.CancelledError``.
    After the cancelling command is sent, the test expects a close frame and
    an empty subscription mapping.
    """

    # Rebound by the handlers below to ``connection.subscriptions`` so the
    # test body can inspect it; stays None until one of them has run.
    subscriptions = None

    # Register a handler that registers a subscription
    @callback
    @websocket_command(
        {
            "type": "fake_subscription",
        }
    )
    def fake_subscription(
        hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
    ) -> None:
        nonlocal subscriptions
        msg_id: int = msg["id"]
        # Store a no-op callable under this message id.
        connection.subscriptions[msg_id] = callback(lambda: None)
        connection.send_result(msg_id)
        subscriptions = connection.subscriptions

    async_register_command(hass, fake_subscription)

    # Register a handler that raises on cancel
    @callback
    @websocket_command(
        {
            "type": "subscription_that_raises_on_cancel",
        }
    )
    def subscription_that_raises_on_cancel(
        hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
    ) -> None:
        nonlocal subscriptions
        msg_id: int = msg["id"]

        @callback
        def _raise():
            raise ValueError()

        # Store a callable that raises ValueError whenever it is invoked.
        # NOTE(review): presumably invoked during connection cleanup - confirm.
        connection.subscriptions[msg_id] = _raise
        connection.send_result(msg_id)
        subscriptions = connection.subscriptions

    async_register_command(hass, subscription_that_raises_on_cancel)

    # Register a handler that cancels in handler
    @callback
    @websocket_command(
        {
            "type": "cancel_in_handler",
        }
    )
    def cancel_in_handler(
        hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
    ) -> None:
        raise asyncio.CancelledError()

    async_register_command(hass, cancel_in_handler)

    # A plain ping runs none of the handlers above, so ``subscriptions`` is
    # still the initial None here.
    await websocket_client.send_json({"id": 1, "type": "ping"})
    msg = await websocket_client.receive_json()
    assert msg["id"] == 1
    assert msg["type"] == "pong"
    assert not subscriptions
    await websocket_client.send_json({"id": 2, "type": "fake_subscription"})
    msg = await websocket_client.receive_json()
    assert msg["id"] == 2
    assert msg["type"] == const.TYPE_RESULT
    assert msg["success"]
    # NOTE(review): the expected count is 2 after a single registration, so
    # the connection presumably already holds one subscription created
    # outside this test - confirm against the connection/auth setup.
    assert len(subscriptions) == 2
    await websocket_client.send_json(
        {"id": 3, "type": "subscription_that_raises_on_cancel"}
    )
    msg = await websocket_client.receive_json()
    assert msg["id"] == 3
    assert msg["type"] == const.TYPE_RESULT
    assert msg["success"]
    assert len(subscriptions) == 3
    # The handler for this command raises CancelledError; the test expects
    # the socket to close and every subscription to be removed, even though
    # one stored callable raises.
    await websocket_client.send_json({"id": 4, "type": "cancel_in_handler"})
    await hass.async_block_till_done()
    msg = await websocket_client.receive()
    assert msg.type == WSMsgType.close
    assert len(subscriptions) == 0


async def test_pending_msg_peak(
    hass: HomeAssistant,
    mock_low_peak,
    hass_ws_client: WebSocketGenerator,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test that staying above the pending message peak closes the socket."""
    real_handler_cls = http.WebSocketHandler
    captured_handler: http.WebSocketHandler | None = None

    def capture_handler(*args):
        """Build the real handler and remember it for the test body."""
        nonlocal captured_handler
        captured_handler = real_handler_cls(*args)
        return captured_handler

    target = "homeassistant.components.websocket_api.http.WebSocketHandler"
    with patch(target, capture_handler):
        websocket_client = await hass_ws_client()

    instance: http.WebSocketHandler = cast(http.WebSocketHandler, captured_handler)

    # mock_low_peak patches PENDING_MSG_PEAK to 5; queue ten messages so the
    # pending count goes past it.
    for _ in range(10):
        instance._send_message({"overload": "message"})

    # Move time forward to just past the peak observation window.
    deadline = utcnow() + timedelta(seconds=const.PENDING_MSG_PEAK_TIME + 1)
    async_fire_time_changed(hass, deadline)

    closing_frame = await websocket_client.receive()
    assert closing_frame.type == WSMsgType.close

    for expected_log in (
        "Client unable to keep up with pending messages",
        "Stayed over 5 for 5 seconds",
        "overload",
    ):
        assert expected_log in caplog.text


async def test_pending_msg_peak_recovery(
    hass: HomeAssistant,
    mock_low_peak,
    hass_ws_client: WebSocketGenerator,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test that a drained burst of messages does not trip the peak check."""
    real_handler_cls = http.WebSocketHandler
    captured_handler: http.WebSocketHandler | None = None

    def capture_handler(*args):
        """Build the real handler and remember it for the test body."""
        nonlocal captured_handler
        captured_handler = real_handler_cls(*args)
        return captured_handler

    target = "homeassistant.components.websocket_api.http.WebSocketHandler"
    with patch(target, capture_handler):
        websocket_client = await hass_ws_client()

    instance: http.WebSocketHandler = cast(http.WebSocketHandler, captured_handler)

    async def expect_text_frame() -> None:
        """Receive the next frame and require it to be a text frame."""
        frame = await websocket_client.receive()
        assert frame.type == WSMsgType.TEXT

    # Queue a burst of ten messages (the intent, per the original test, is to
    # make sure the delayed peak check gets started) ...
    for _ in range(10):
        instance._send_message({})

    # ... and let the client read every one of them back.
    for _ in range(10):
        await expect_text_frame()

    # A single follow-up message is still delivered normally.
    instance._send_message({})
    await expect_text_frame()

    # Shut down cleanly: queue one last message, then cancel the handler task.
    instance._send_message({})
    instance._handle_task.cancel()

    await expect_text_frame()
    closing_frame = await websocket_client.receive()
    assert closing_frame.type == WSMsgType.close
    assert "Client unable to keep up with pending messages" not in caplog.text


async def test_pending_msg_peak_but_does_not_overflow(
    hass: HomeAssistant,
    mock_low_peak,
    hass_ws_client: WebSocketGenerator,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test that reaching the low peak and then clearing it is not an overflow."""
    real_handler_cls = http.WebSocketHandler
    captured_handler: http.WebSocketHandler | None = None

    def capture_handler(*args):
        """Build the real handler and remember it for the test body."""
        nonlocal captured_handler
        captured_handler = real_handler_cls(*args)
        return captured_handler

    target = "homeassistant.components.websocket_api.http.WebSocketHandler"
    with patch(target, capture_handler):
        websocket_client = await hass_ws_client()

    instance: http.WebSocketHandler = cast(http.WebSocketHandler, captured_handler)

    # Push five None entries straight into the handler's queue so it sits at
    # the patched peak of 5.
    # NOTE(review): the original comment says this also kills the writer
    # task; that cannot be confirmed from this file.
    for _ in range(5):
        instance._message_queue.append(None)

    # Sending now is what triggers the peak check.
    instance._send_message({})

    # Empty the queue again before any time passes ...
    instance._message_queue.clear()

    # ... so that this send triggers the peak clear.
    instance._send_message({})

    # Move time forward to just past the peak observation window.
    deadline = utcnow() + timedelta(seconds=const.PENDING_MSG_PEAK_TIME + 1)
    async_fire_time_changed(hass, deadline)

    frame = await websocket_client.receive()
    assert frame.type == WSMsgType.TEXT

    assert "Client unable to keep up with pending messages" not in caplog.text


async def test_non_json_message(
    hass: HomeAssistant, websocket_client, caplog: pytest.LogCaptureFixture
) -> None:
    """Test get_states when a state attribute cannot be serialized to JSON."""
    # A bare object() has no JSON representation.
    hass.states.async_set("test_domain.entity", "testing", {"bad": object()})
    await websocket_client.send_json({"id": 5, "type": "get_states"})

    reply = await websocket_client.receive_json()
    assert reply["id"] == 5
    assert reply["type"] == const.TYPE_RESULT
    assert reply["success"]
    # The command still succeeds, with an empty result list.
    assert reply["result"] == []

    # The log names the failure, the offending state and the bad attribute.
    for expected_log in (
        "Unable to serialize to JSON. Bad data found",
        "State: test_domain.entity",
        "bad=<object",
    ):
        assert expected_log in caplog.text


async def test_prepare_fail(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test that a timeout while preparing the websocket response is logged."""
    # Built before the patch is entered, so the second entry is the real,
    # unpatched prepare function. The first entry makes the first call to
    # the patched prepare raise asyncio.TimeoutError.
    prepare_outcomes = (asyncio.TimeoutError, web.WebSocketResponse.prepare)

    with patch(
        "homeassistant.components.websocket_api.http.web.WebSocketResponse.prepare",
        side_effect=prepare_outcomes,
    ), pytest.raises(ServerDisconnectedError):
        await hass_ws_client(hass)

    log_output = caplog.text
    assert "Timeout preparing request" in log_output


async def test_binary_message(
    hass: HomeAssistant, websocket_client, caplog: pytest.LogCaptureFixture
) -> None:
    """Test binary messages.

    Registers five binary handlers through a throwaway command, sends binary
    frames that start with a one-byte handler prefix, and checks both the
    payloads collected by the working handlers and the errors logged for the
    remaining frames.
    """
    # Per command id: (chunks received so far, future that is resolved with
    # the joined chunks once an empty payload arrives).
    binary_payloads = {
        104: ([], asyncio.Future()),
        105: ([], asyncio.Future()),
    }

    # Register a handler
    @callback
    @websocket_command(
        {
            "type": "get_binary_message_handler",
        }
    )
    def get_binary_message_handler(
        hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
    ):
        # Rebound below to the second value returned by
        # async_register_binary_handler; the inner handler calls it.
        unsub = None

        @callback
        def binary_message_handler(
            hass: HomeAssistant, connection: ActiveConnection, payload: bytes
        ):
            nonlocal unsub
            # The handler registered by command id 103 always fails.
            if msg["id"] == 103:
                raise ValueError("Boom")

            if payload:
                # Non-empty payload: collect the chunk.
                binary_payloads[msg["id"]][0].append(payload)
            else:
                # Empty payload marks the end: publish the joined chunks and
                # call unsub (presumably unregistering this handler).
                binary_payloads[msg["id"]][1].set_result(
                    b"".join(binary_payloads[msg["id"]][0])
                )
                unsub()

        prefix, unsub = connection.async_register_binary_handler(binary_message_handler)

        # Report the assigned prefix back to the client.
        connection.send_result(msg["id"], {"prefix": prefix})

    async_register_command(hass, get_binary_message_handler)

    # Register multiple binary handlers
    # Command ids 101..105 are expected to be assigned prefixes 1..5.
    for i in range(101, 106):
        await websocket_client.send_json(
            {"id": i, "type": "get_binary_message_handler"}
        )
        result = await websocket_client.receive_json()
        assert result["id"] == i
        assert result["type"] == const.TYPE_RESULT
        assert result["success"]
        assert result["result"]["prefix"] == i - 100

    # Send message to binary
    # Each frame is a one-byte prefix followed by the payload.
    # Prefix 0: never handed out above.
    await websocket_client.send_bytes((0).to_bytes(1, "big") + b"test0")
    # Prefix 3 (command id 103): the handler raises. The assertions below
    # expect the second frame to be reported as a non-existing handler.
    # NOTE(review): presumably a failing handler gets removed - confirm.
    await websocket_client.send_bytes((3).to_bytes(1, "big") + b"test3")
    await websocket_client.send_bytes((3).to_bytes(1, "big") + b"test3")
    # Prefix 10: never handed out above.
    await websocket_client.send_bytes((10).to_bytes(1, "big") + b"test10")
    # Prefix 4: one chunk, then the empty terminating payload.
    await websocket_client.send_bytes((4).to_bytes(1, "big") + b"test4")
    await websocket_client.send_bytes((4).to_bytes(1, "big") + b"")
    # Prefix 5: two chunks, then the empty terminating payload.
    await websocket_client.send_bytes((5).to_bytes(1, "big") + b"test5")
    await websocket_client.send_bytes((5).to_bytes(1, "big") + b"test5-2")
    await websocket_client.send_bytes((5).to_bytes(1, "big") + b"")

    # Verify received
    assert await binary_payloads[104][1] == b"test4"
    assert await binary_payloads[105][1] == b"test5test5-2"
    assert "Error handling binary message" in caplog.text
    assert "Received binary message for non-existing handler 0" in caplog.text
    assert "Received binary message for non-existing handler 3" in caplog.text
    assert "Received binary message for non-existing handler 10" in caplog.text