"""Pipeline storage and helper tests for the Assist pipeline integration."""

from collections.abc import AsyncGenerator
from typing import Any
from unittest.mock import ANY, patch

import pytest

from homeassistant.components import conversation
from homeassistant.components.assist_pipeline.const import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import (
    STORAGE_KEY,
    STORAGE_VERSION,
    STORAGE_VERSION_MINOR,
    Pipeline,
    PipelineData,
    PipelineStorageCollection,
    PipelineStore,
    async_create_default_pipeline,
    async_get_pipeline,
    async_get_pipelines,
    async_migrate_engine,
    async_update_pipeline,
)
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component

from . import MANY_LANGUAGES
from .conftest import MockSttProvider, MockTTSProvider

from tests.common import flush_store


@pytest.fixture(autouse=True)
async def delay_save_fixture() -> AsyncGenerator[None, None]:
    """Patch the collection helper's SAVE_DELAY to 0 for every test."""
    # NOTE(review): presumably this removes the delay before a storage
    # collection writes its data, so tests can inspect stored data right after
    # a change -- confirm against homeassistant.helpers.collection.
    with patch("homeassistant.helpers.collection.SAVE_DELAY", new=0):
        yield


@pytest.fixture(autouse=True)
async def load_homeassistant(hass: HomeAssistant) -> None:
    """Set up the homeassistant integration before every test."""
    # Fail the test early if the base integration cannot be set up.
    assert await async_setup_component(hass, "homeassistant", {})


async def test_load_pipelines(hass: HomeAssistant, init_components) -> None:
    """Make sure that we can load/save data correctly."""

    # Three pipelines to create on top of the default one; the last one has no
    # STT/TTS configuration at all.
    pipelines = [
        {
            "conversation_engine": "conversation_engine_1",
            "conversation_language": "language_1",
            "language": "language_1",
            "name": "name_1",
            "stt_engine": "stt_engine_1",
            "stt_language": "language_1",
            "tts_engine": "tts_engine_1",
            "tts_language": "language_1",
            "tts_voice": "Arnold Schwarzenegger",
            "wake_word_entity": "wakeword_entity_1",
            "wake_word_id": "wakeword_id_1",
        },
        {
            "conversation_engine": "conversation_engine_2",
            "conversation_language": "language_2",
            "language": "language_2",
            "name": "name_2",
            "stt_engine": "stt_engine_2",
            "stt_language": "language_2",
            "tts_engine": "tts_engine_2",
            "tts_language": "language_2",
            "tts_voice": "The Voice",
            "wake_word_entity": "wakeword_entity_2",
            "wake_word_id": "wakeword_id_2",
        },
        {
            "conversation_engine": "conversation_engine_3",
            "conversation_language": "language_3",
            "language": "language_3",
            "name": "name_3",
            "stt_engine": None,
            "stt_language": None,
            "tts_engine": None,
            "tts_language": None,
            "tts_voice": None,
            "wake_word_entity": "wakeword_entity_3",
            "wake_word_id": "wakeword_id_3",
        },
    ]

    pipeline_data: PipelineData = hass.data[DOMAIN]
    store1 = pipeline_data.pipeline_store
    pipeline_ids = [
        (await store1.async_create_item(pipeline)).id for pipeline in pipelines
    ]
    assert len(store1.data) == 4  # 3 manually created plus a default pipeline
    # The preferred pipeline must be the first entry of the store; read that
    # key without copying all keys into a list.
    assert store1.async_get_preferred_item() == next(iter(store1.data))

    await store1.async_delete_item(pipeline_ids[1])
    assert len(store1.data) == 3

    # A second, independent collection backed by the same storage key must see
    # exactly what the first one wrote once its store has been flushed.
    store2 = PipelineStorageCollection(
        PipelineStore(
            hass, STORAGE_VERSION, STORAGE_KEY, minor_version=STORAGE_VERSION_MINOR
        )
    )
    await flush_store(store1.store)
    await store2.async_load()

    assert len(store2.data) == 3

    # Equal content, but loaded into a distinct container.
    assert store1.data is not store2.data
    assert store1.data == store2.data
    assert store1.async_get_preferred_item() == store2.async_get_preferred_item()


async def test_loading_pipelines_from_storage(
    hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
    """Test that stored pipelines are loaded, and migrated, on start."""
    # Request the engine migration before the integration is set up; the
    # first stored pipeline below still references the old agent id.
    async_migrate_engine(
        hass,
        "conversation",
        conversation.OLD_HOME_ASSISTANT_AGENT,
        conversation.HOME_ASSISTANT_AGENT,
    )

    preferred_id = "01GX8ZWBAQYWNB1XV3EXEZ75DY"
    stored_items = [
        {
            "conversation_engine": conversation.OLD_HOME_ASSISTANT_AGENT,
            "conversation_language": "language_1",
            "id": preferred_id,
            "language": "language_1",
            "name": "name_1",
            "stt_engine": "stt_engine_1",
            "stt_language": "language_1",
            "tts_engine": "tts_engine_1",
            "tts_language": "language_1",
            "tts_voice": "Arnold Schwarzenegger",
            "wake_word_entity": "wakeword_entity_1",
            "wake_word_id": "wakeword_id_1",
        },
        {
            "conversation_engine": "conversation_engine_2",
            "conversation_language": "language_2",
            "id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX",
            "language": "language_2",
            "name": "name_2",
            "stt_engine": "stt_engine_2",
            "stt_language": "language_2",
            "tts_engine": "tts_engine_2",
            "tts_language": "language_2",
            "tts_voice": "The Voice",
            "wake_word_entity": "wakeword_entity_2",
            "wake_word_id": "wakeword_id_2",
        },
        {
            "conversation_engine": "conversation_engine_3",
            "conversation_language": "language_3",
            "id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
            "language": "language_3",
            "name": "name_3",
            "stt_engine": None,
            "stt_language": None,
            "tts_engine": None,
            "tts_language": None,
            "tts_voice": None,
            "wake_word_entity": "wakeword_entity_3",
            "wake_word_id": "wakeword_id_3",
        },
    ]
    hass_storage[STORAGE_KEY] = {
        "version": STORAGE_VERSION,
        "minor_version": STORAGE_VERSION_MINOR,
        "key": "assist_pipeline.pipelines",
        "data": {"items": stored_items, "preferred_item": preferred_id},
    }

    setup_ok = await async_setup_component(hass, "assist_pipeline", {})
    assert setup_ok

    loaded: PipelineData = hass.data[DOMAIN]
    pipeline_store = loaded.pipeline_store
    assert len(pipeline_store.data) == 3
    assert pipeline_store.async_get_preferred_item() == preferred_id
    # The old agent id must have been replaced by the new one while loading.
    migrated_pipeline = pipeline_store.data[preferred_id]
    assert migrated_pipeline.conversation_engine == conversation.HOME_ASSISTANT_AGENT


async def test_migrate_pipeline_store(
    hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
    """Test loading stored pipelines from an older version."""
    preferred_id = "01GX8ZWBAQYWNB1XV3EXEZ75DY"
    # Version 1.1 data: none of the items has wake word settings.
    hass_storage[STORAGE_KEY] = {
        "version": 1,
        "minor_version": 1,
        "key": "assist_pipeline.pipelines",
        "data": {
            "items": [
                {
                    "conversation_engine": "conversation_engine_1",
                    "conversation_language": "language_1",
                    "id": preferred_id,
                    "language": "language_1",
                    "name": "name_1",
                    "stt_engine": "stt_engine_1",
                    "stt_language": "language_1",
                    "tts_engine": "tts_engine_1",
                    "tts_language": "language_1",
                    "tts_voice": "Arnold Schwarzenegger",
                },
                {
                    "conversation_engine": "conversation_engine_2",
                    "conversation_language": "language_2",
                    "id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX",
                    "language": "language_2",
                    "name": "name_2",
                    "stt_engine": "stt_engine_2",
                    "stt_language": "language_2",
                    "tts_engine": "tts_engine_2",
                    "tts_language": "language_2",
                    "tts_voice": "The Voice",
                },
                {
                    "conversation_engine": "conversation_engine_3",
                    "conversation_language": "language_3",
                    "id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
                    "language": "language_3",
                    "name": "name_3",
                    "stt_engine": None,
                    "stt_language": None,
                    "tts_engine": None,
                    "tts_language": None,
                    "tts_voice": None,
                },
            ],
            "preferred_item": preferred_id,
        },
    }

    assert await async_setup_component(hass, "assist_pipeline", {})

    pipeline_data: PipelineData = hass.data[DOMAIN]
    store = pipeline_data.pipeline_store
    assert len(store.data) == 3
    assert store.async_get_preferred_item() == preferred_id

    # The old items had no wake word settings, so every migrated pipeline is
    # expected to expose them as unset.
    # NOTE(review): the None default is assumed from the store migration in
    # the integration -- confirm against PipelineStore.
    for pipeline in store.data.values():
        assert pipeline.wake_word_entity is None
        assert pipeline.wake_word_id is None


async def test_create_default_pipeline(
    hass: HomeAssistant, init_supporting_components
) -> None:
    """Test async_create_default_pipeline with unknown and known engine ids."""
    setup_ok = await async_setup_component(hass, "assist_pipeline", {})
    assert setup_ok

    data: PipelineData = hass.data[DOMAIN]
    assert len(data.pipeline_store.data) == 1

    # With the "bla" engine ids no pipeline is created.
    not_created = await async_create_default_pipeline(
        hass,
        stt_engine_id="bla",
        tts_engine_id="bla",
        pipeline_name="Bla pipeline",
    )
    assert not_created is None

    # With the "test" engine ids a fully populated pipeline is returned.
    created = await async_create_default_pipeline(
        hass,
        stt_engine_id="test",
        tts_engine_id="test",
        pipeline_name="Test pipeline",
    )
    expected = Pipeline(
        conversation_engine="conversation.home_assistant",
        conversation_language="en",
        id=ANY,
        language="en",
        name="Test pipeline",
        stt_engine="test",
        stt_language="en-US",
        tts_engine="test",
        tts_language="en-US",
        tts_voice="james_earl_jones",
        wake_word_entity=None,
        wake_word_id=None,
    )
    assert created == expected


async def test_get_pipeline(hass: HomeAssistant) -> None:
    """Test async_get_pipeline with and without an explicit pipeline id."""
    setup_ok = await async_setup_component(hass, "assist_pipeline", {})
    assert setup_ok

    data: PipelineData = hass.data[DOMAIN]
    pipeline_store = data.pipeline_store
    assert len(pipeline_store.data) == 1

    # Passing None must select the preferred pipeline.
    preferred = async_get_pipeline(hass, None)
    assert preferred.id == pipeline_store.async_get_preferred_item()

    # Looking the pipeline up by id must return the very same object.
    by_id = async_get_pipeline(hass, preferred.id)
    assert preferred is by_id


async def test_get_pipelines(hass: HomeAssistant) -> None:
    """Test that async_get_pipelines yields only the default pipeline."""
    setup_ok = await async_setup_component(hass, "assist_pipeline", {})
    assert setup_ok

    data: PipelineData = hass.data[DOMAIN]
    assert len(data.pipeline_store.data) == 1

    # The id is generated at runtime, so it is matched with ANY.
    expected_default = Pipeline(
        conversation_engine="conversation.home_assistant",
        conversation_language="en",
        id=ANY,
        language="en",
        name="Home Assistant",
        stt_engine=None,
        stt_language=None,
        tts_engine=None,
        tts_language=None,
        tts_voice=None,
        wake_word_entity=None,
        wake_word_id=None,
    )
    found = list(async_get_pipelines(hass))
    assert found == [expected_default]


@pytest.mark.parametrize(
    ("ha_language", "ha_country", "conv_language", "pipeline_language"),
    [
        ("en", None, "en", "en"),
        ("de", "de", "de", "de"),
        ("de", "ch", "de-CH", "de"),
        ("en", "us", "en", "en"),
        ("en", "uk", "en", "en"),
        ("pt", "pt", "pt", "pt"),
        ("pt", "br", "pt-br", "pt"),
    ],
)
async def test_default_pipeline_no_stt_tts(
    hass: HomeAssistant,
    ha_language: str,
    ha_country: str | None,
    conv_language: str,
    pipeline_language: str,
) -> None:
    """Test the default pipeline's languages when it has no STT/TTS engine."""
    # Set the configured country and language before the integration is set up.
    hass.config.country = ha_country
    hass.config.language = ha_language
    setup_ok = await async_setup_component(hass, "assist_pipeline", {})
    assert setup_ok

    data: PipelineData = hass.data[DOMAIN]
    assert len(data.pipeline_store.data) == 1

    # The only pipeline is the default one; compare it field by field.
    default_pipeline = async_get_pipeline(hass, None)
    expected = Pipeline(
        conversation_engine="conversation.home_assistant",
        conversation_language=conv_language,
        id=default_pipeline.id,
        language=pipeline_language,
        name="Home Assistant",
        stt_engine=None,
        stt_language=None,
        tts_engine=None,
        tts_language=None,
        tts_voice=None,
        wake_word_entity=None,
        wake_word_id=None,
    )
    assert default_pipeline == expected


@pytest.mark.parametrize(
    (
        "ha_language",
        "ha_country",
        "conv_language",
        "pipeline_language",
        "stt_language",
        "tts_language",
    ),
    [
        ("en", None, "en", "en", "en", "en"),
        ("de", "de", "de", "de", "de", "de"),
        ("de", "ch", "de-CH", "de", "de-CH", "de-CH"),
        ("en", "us", "en", "en", "en", "en"),
        ("en", "uk", "en", "en", "en", "en"),
        ("pt", "pt", "pt", "pt", "pt", "pt"),
        ("pt", "br", "pt-br", "pt", "pt-br", "pt-br"),
    ],
)
async def test_default_pipeline(
    hass: HomeAssistant,
    init_supporting_components,
    mock_stt_provider: MockSttProvider,
    mock_tts_provider: MockTTSProvider,
    ha_language: str,
    ha_country: str | None,
    conv_language: str,
    pipeline_language: str,
    stt_language: str,
    tts_language: str,
) -> None:
    """Test the default pipeline's languages with STT and TTS available."""
    hass.config.country = ha_country
    hass.config.language = ha_language

    # Both mocked providers report MANY_LANGUAGES while the integration is
    # being set up.
    stt_languages = patch.object(
        mock_stt_provider, "_supported_languages", MANY_LANGUAGES
    )
    tts_languages = patch.object(
        mock_tts_provider, "_supported_languages", MANY_LANGUAGES
    )
    with stt_languages, tts_languages:
        setup_ok = await async_setup_component(hass, "assist_pipeline", {})
        assert setup_ok

    data: PipelineData = hass.data[DOMAIN]
    assert len(data.pipeline_store.data) == 1

    # The only pipeline is the default one; compare it field by field.
    default_pipeline = async_get_pipeline(hass, None)
    expected = Pipeline(
        conversation_engine="conversation.home_assistant",
        conversation_language=conv_language,
        id=default_pipeline.id,
        language=pipeline_language,
        name="Home Assistant",
        stt_engine="test",
        stt_language=stt_language,
        tts_engine="test",
        tts_language=tts_language,
        tts_voice=None,
        wake_word_entity=None,
        wake_word_id=None,
    )
    assert default_pipeline == expected


async def test_default_pipeline_unsupported_stt_language(
    hass: HomeAssistant,
    init_supporting_components,
    mock_stt_provider: MockSttProvider,
) -> None:
    """Test the default pipeline when STT only supports an unknown language."""
    # The STT provider reports a single made-up language during setup.
    only_smurfish = patch.object(
        mock_stt_provider, "_supported_languages", ["smurfish"]
    )
    with only_smurfish:
        setup_ok = await async_setup_component(hass, "assist_pipeline", {})
        assert setup_ok

    data: PipelineData = hass.data[DOMAIN]
    assert len(data.pipeline_store.data) == 1

    # The default pipeline must have no STT settings but keep its TTS settings.
    default_pipeline = async_get_pipeline(hass, None)
    expected = Pipeline(
        conversation_engine="conversation.home_assistant",
        conversation_language="en",
        id=default_pipeline.id,
        language="en",
        name="Home Assistant",
        stt_engine=None,
        stt_language=None,
        tts_engine="test",
        tts_language="en-US",
        tts_voice="james_earl_jones",
        wake_word_entity=None,
        wake_word_id=None,
    )
    assert default_pipeline == expected


async def test_default_pipeline_unsupported_tts_language(
    hass: HomeAssistant,
    init_supporting_components,
    mock_tts_provider: MockTTSProvider,
) -> None:
    """Test the default pipeline when TTS only supports an unknown language."""
    # The TTS provider reports a single made-up language during setup.
    only_smurfish = patch.object(
        mock_tts_provider, "_supported_languages", ["smurfish"]
    )
    with only_smurfish:
        setup_ok = await async_setup_component(hass, "assist_pipeline", {})
        assert setup_ok

    data: PipelineData = hass.data[DOMAIN]
    assert len(data.pipeline_store.data) == 1

    # The default pipeline must have no TTS settings but keep its STT settings.
    default_pipeline = async_get_pipeline(hass, None)
    expected = Pipeline(
        conversation_engine="conversation.home_assistant",
        conversation_language="en",
        id=default_pipeline.id,
        language="en",
        name="Home Assistant",
        stt_engine="test",
        stt_language="en-US",
        tts_engine=None,
        tts_language=None,
        tts_voice=None,
        wake_word_entity=None,
        wake_word_id=None,
    )
    assert default_pipeline == expected


async def test_update_pipeline(
    hass: HomeAssistant,
    hass_storage: dict[str, Any],
) -> None:
    """Test async_update_pipeline with a full and a partial update."""
    assert await async_setup_component(hass, "assist_pipeline", {})

    # Right after setup only the default pipeline exists.
    initial_pipelines = list(async_get_pipelines(hass))
    assert initial_pipelines == [
        Pipeline(
            conversation_engine="conversation.home_assistant",
            conversation_language="en",
            id=ANY,
            language="en",
            name="Home Assistant",
            stt_engine=None,
            stt_language=None,
            tts_engine=None,
            tts_language=None,
            tts_voice=None,
            wake_word_entity=None,
            wake_word_id=None,
        )
    ]

    # First update: replace every field except the id.
    all_fields = {
        "conversation_engine": "homeassistant_1",
        "conversation_language": "de",
        "language": "de",
        "name": "Home Assistant 1",
        "stt_engine": "stt.test_1",
        "stt_language": "de",
        "tts_engine": "test_1",
        "tts_language": "de",
        "tts_voice": "test_voice",
        "wake_word_entity": "wake_work.test_1",
        "wake_word_id": "wake_word_id_1",
    }
    await async_update_pipeline(hass, initial_pipelines[0], **all_fields)

    after_full_update = list(async_get_pipelines(hass))
    updated_pipeline = after_full_update[0]
    # The same field values are expected on the pipeline object and in storage.
    expected_fields = {**all_fields, "id": updated_pipeline.id}
    assert after_full_update == [Pipeline(**expected_fields)]
    stored_items = hass_storage[STORAGE_KEY]["data"]["items"]
    assert len(stored_items) == 1
    assert stored_items[0] == expected_fields

    # Second update: change only the STT/TTS settings; all other fields must
    # keep the values from the first update.
    stt_tts_fields = {
        "stt_engine": "stt.test_2",
        "stt_language": "en",
        "tts_engine": "test_2",
        "tts_language": "en",
    }
    await async_update_pipeline(hass, updated_pipeline, **stt_tts_fields)

    expected_fields = {**expected_fields, **stt_tts_fields}
    after_partial_update = list(async_get_pipelines(hass))
    assert after_partial_update == [Pipeline(**expected_fields)]
    stored_items = hass_storage[STORAGE_KEY]["data"]["items"]
    assert len(stored_items) == 1
    assert stored_items[0] == expected_fields


async def test_migrate_after_load(
    hass: HomeAssistant, init_supporting_components
) -> None:
    """Test that engines are migrated on pipelines that are already loaded."""
    setup_ok = await async_setup_component(hass, "assist_pipeline", {})
    assert setup_ok

    data: PipelineData = hass.data[DOMAIN]
    assert len(data.pipeline_store.data) == 1

    # With the "bla" engine ids no pipeline is created.
    not_created = await async_create_default_pipeline(
        hass,
        stt_engine_id="bla",
        tts_engine_id="bla",
        pipeline_name="Bla pipeline",
    )
    assert not_created is None

    created = await async_create_default_pipeline(
        hass,
        stt_engine_id="test",
        tts_engine_id="test",
        pipeline_name="Test pipeline",
    )
    assert created is not None

    # Rename both engines of the created pipeline, STT first, then TTS.
    for engine_type, new_engine_id in (("stt", "stt.test"), ("tts", "tts.test")):
        async_migrate_engine(hass, engine_type, "test", new_engine_id)

    await hass.async_block_till_done(wait_background_tasks=True)

    # Fetch the pipeline again and check that it uses the new engine ids.
    migrated = async_get_pipeline(hass, created.id)
    assert migrated.stt_engine == "stt.test"
    assert migrated.tts_engine == "tts.test"