"""Classes for voice assistant pipelines."""
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterable, Callable
from dataclasses import asdict, dataclass, field
import logging
from typing import Any

import voluptuous as vol

from homeassistant.backports.enum import StrEnum
from homeassistant.components import conversation, media_source, stt, tts, websocket_api
from homeassistant.components.tts.media_source import (
    generate_media_source_id as tts_generate_media_source_id,
)
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.helpers.collection import (
    CollectionError,
    ItemNotFound,
    SerializedStorageCollection,
    StorageCollection,
    StorageCollectionWebsocket,
)
from homeassistant.helpers.storage import Store
from homeassistant.util import dt as dt_util, ulid as ulid_util
from homeassistant.util.limited_size_dict import LimitedSizeDict

from .const import DOMAIN
from .error import (
    IntentRecognitionError,
    PipelineError,
    SpeechToTextError,
    TextToSpeechError,
)

_LOGGER = logging.getLogger(__name__)

# Location and format version of the persisted pipeline collection
STORAGE_KEY = f"{DOMAIN}.pipelines"
STORAGE_VERSION = 1

# Create/update schema: every pipeline setting is a required string
STORAGE_FIELDS = {
    vol.Required(setting): str
    for setting in (
        "conversation_engine",
        "language",
        "name",
        "stt_engine",
        "tts_engine",
    )
}

# How many recent runs (and their events) are remembered per pipeline
STORED_PIPELINE_RUNS = 10

# NOTE(review): not referenced in this file section; presumably a storage
# save delay in seconds — confirm against its users.
SAVE_DELAY = 10


async def async_get_pipeline(
    hass: HomeAssistant, pipeline_id: str | None = None, language: str | None = None
) -> Pipeline | None:
    """Get a pipeline by id, or get/create a default pipeline for a language.

    Args:
        hass: Home Assistant instance; ``hass.data[DOMAIN]`` holds the
            pipeline store.
        pipeline_id: Id of a stored pipeline. When given, that pipeline is
            returned, or None if no pipeline has this id.
        language: Language for the default pipeline, used only when
            ``pipeline_id`` is None. Falls back to the configured Home
            Assistant language.

    Returns:
        The requested pipeline. Without ``pipeline_id``, this is a pipeline
        named after the language that uses the default engine for every
        stage. A stored pipeline with exactly those settings is reused. A new
        one is created and stored only when none exists yet.
    """
    pipeline_data: PipelineData = hass.data[DOMAIN]
    pipeline_store = pipeline_data.pipeline_store

    if pipeline_id is not None:
        return pipeline_store.data.get(pipeline_id)

    # Construct a pipeline for the required/configured language
    language = language or hass.config.language
    default_settings: dict[str, Any] = {
        "name": language,
        "language": language,
        "stt_engine": None,  # first engine
        "conversation_engine": None,  # first agent
        "tts_engine": None,  # first engine
    }

    # Reuse an equivalent pipeline created by an earlier call. Otherwise every
    # call would persist another identical pipeline, so the store (and the
    # per-pipeline run log keyed by pipeline id) would grow without bound.
    for pipeline in pipeline_store.data.values():
        if all(
            getattr(pipeline, setting) == value
            for setting, value in default_settings.items()
        ):
            return pipeline

    return await pipeline_store.async_create_item(default_settings)


class PipelineEventType(StrEnum):
    """Event types emitted during a pipeline run."""

    RUN_START = "run-start"
    RUN_END = "run-end"
    STT_START = "stt-start"
    STT_END = "stt-end"
    INTENT_START = "intent-start"
    INTENT_END = "intent-end"
    TTS_START = "tts-start"
    TTS_END = "tts-end"
    ERROR = "error"


def _utc_timestamp() -> str:
    """Return the current UTC time as an ISO 8601 string."""
    return dt_util.utcnow().isoformat()


@dataclass(frozen=True)
class PipelineEvent:
    """Immutable record of something that happened during a pipeline run."""

    type: PipelineEventType
    # Event payload; None for events that carry no data (e.g. run end)
    data: dict[str, Any] | None = None
    # Creation time, captured when the event object is instantiated
    timestamp: str = field(default_factory=_utc_timestamp)


# Listener invoked with every event of a run
PipelineEventCallback = Callable[[PipelineEvent], None]


@dataclass(frozen=True)
class Pipeline:
    """Immutable configuration of a voice assistant pipeline.

    An engine/language value of None means nothing specific is configured;
    the run falls back to a default when the stage is prepared.
    """

    conversation_engine: str | None
    language: str | None
    name: str
    stt_engine: str | None
    tts_engine: str | None

    id: str = field(default_factory=ulid_util.ulid)

    def to_json(self) -> dict[str, Any]:
        """Return every field as a JSON serializable dict, keys in sorted order."""
        values = asdict(self)
        # Alphabetical key order keeps the stored representation stable
        return {key: values[key] for key in sorted(values)}


class PipelineStage(StrEnum):
    """Stages of a pipeline."""

    STT = "stt"
    INTENT = "intent"
    TTS = "tts"


PIPELINE_STAGE_ORDER = [
    PipelineStage.STT,
    PipelineStage.INTENT,
    PipelineStage.TTS,
]


class PipelineRunValidationError(Exception):
    """Raised when a pipeline run cannot be started as requested."""


class InvalidPipelineStagesError(PipelineRunValidationError):
    """Raised when a run's end stage comes before its start stage."""

    def __init__(self, start_stage: PipelineStage, end_stage: PipelineStage) -> None:
        """Build the error message from the offending start and end stages."""
        message = f"Invalid stage combination: start={start_stage}, end={end_stage}"
        super().__init__(message)


@dataclass
class PipelineRun:
    """State and stage implementations for one run of a pipeline.

    A run covers the stages ``start_stage`` through ``end_stage``. Each stage
    has a ``prepare_*`` coroutine that resolves and checks its engine, and a
    coroutine that executes it. Progress is reported through
    ``process_event``.
    """

    hass: HomeAssistant
    context: Context
    pipeline: Pipeline
    start_stage: PipelineStage
    end_stage: PipelineStage
    event_callback: PipelineEventCallback
    language: str = None  # type: ignore[assignment]
    runner_data: Any | None = None
    stt_provider: stt.Provider | None = None
    intent_agent: str | None = None
    tts_engine: str | None = None
    tts_options: dict | None = None

    id: str = field(default_factory=ulid_util.ulid)

    def __post_init__(self) -> None:
        """Resolve the run language, validate the stages and register the run."""
        # The pipeline's language wins over the Home Assistant default
        self.language = self.pipeline.language or self.hass.config.language

        # Stages always flow stt -> intent -> tts; the end may not precede the start
        end_position = PIPELINE_STAGE_ORDER.index(self.end_stage)
        start_position = PIPELINE_STAGE_ORDER.index(self.start_stage)
        if end_position < start_position:
            raise InvalidPipelineStagesError(self.start_stage, self.end_stage)

        # Start an (initially empty) event log for this run in the
        # size-limited per-pipeline run history
        pipeline_data: PipelineData = self.hass.data[DOMAIN]
        recorded_runs = pipeline_data.pipeline_runs.get(self.pipeline.id)
        if recorded_runs is None:
            recorded_runs = LimitedSizeDict(size_limit=STORED_PIPELINE_RUNS)
            pipeline_data.pipeline_runs[self.pipeline.id] = recorded_runs
        recorded_runs[self.id] = []

    @callback
    def process_event(self, event: PipelineEvent) -> None:
        """Forward an event to the listener, then record it in the run log."""
        self.event_callback(event)
        pipeline_data: PipelineData = self.hass.data[DOMAIN]
        recorded_runs = pipeline_data.pipeline_runs[self.pipeline.id]
        # The run may already have been evicted from the size-limited log
        if self.id in recorded_runs:
            recorded_runs[self.id].append(event)

    def start(self) -> None:
        """Emit run start event."""
        start_data: dict[str, Any] = {
            "pipeline": self.pipeline.name,
            "language": self.language,
        }
        if self.runner_data is not None:
            start_data["runner_data"] = self.runner_data
        self.process_event(PipelineEvent(PipelineEventType.RUN_START, start_data))

    def end(self) -> None:
        """Emit run end event."""
        self.process_event(PipelineEvent(PipelineEventType.RUN_END))

    async def prepare_speech_to_text(self, metadata: stt.SpeechMetadata) -> None:
        """Resolve the speech to text provider and check it accepts the audio.

        Raises:
            SpeechToTextError: if no provider is found or it rejects ``metadata``.
        """
        provider = stt.async_get_provider(self.hass, self.pipeline.stt_engine)

        if provider is None:
            requested = self.pipeline.stt_engine or "default"
            raise SpeechToTextError(
                code="stt-provider-missing",
                message=f"No speech to text provider for: {requested}",
            )

        if not provider.check_metadata(metadata):
            raise SpeechToTextError(
                code="stt-provider-unsupported-metadata",
                message=(
                    f"Provider {provider.name} does not support input speech "
                    "to text metadata"
                ),
            )

        self.stt_provider = provider

    async def speech_to_text(
        self,
        metadata: stt.SpeechMetadata,
        stream: AsyncIterable[bytes],
    ) -> str:
        """Transcribe the audio stream and return the recognized text.

        Raises:
            RuntimeError: if ``prepare_speech_to_text`` was not run first.
            SpeechToTextError: if transcription fails or yields no text.
        """
        if self.stt_provider is None:
            raise RuntimeError("Speech to text was not prepared")

        start_data = {
            "engine": self.stt_provider.name,
            "metadata": asdict(metadata),
        }
        self.process_event(PipelineEvent(PipelineEventType.STT_START, start_data))

        try:
            # Transcribe audio stream
            result = await self.stt_provider.async_process_audio_stream(
                metadata, stream
            )
        except Exception as src_error:
            _LOGGER.exception("Unexpected error during speech to text")
            raise SpeechToTextError(
                code="stt-stream-failed",
                message="Unexpected error during speech to text",
            ) from src_error

        _LOGGER.debug("speech-to-text result %s", result)

        if result.result != stt.SpeechResultState.SUCCESS:
            raise SpeechToTextError(
                code="stt-stream-failed",
                message="Speech to text failed",
            )
        if not result.text:
            raise SpeechToTextError(
                code="stt-no-text-recognized", message="No text recognized"
            )

        recognized = result.text
        end_data = {"stt_output": {"text": recognized}}
        self.process_event(PipelineEvent(PipelineEventType.STT_END, end_data))
        return recognized

    async def prepare_recognize_intent(self) -> None:
        """Resolve the conversation agent used for intent recognition.

        Raises:
            IntentRecognitionError: if the agent cannot be found.
        """
        requested = self.pipeline.conversation_engine
        agent_info = conversation.async_get_agent_info(
            self.hass,
            # Without a configured engine use the Home Assistant agent
            # (the conversation integration default is currently the last one set)
            requested or conversation.HOME_ASSISTANT_AGENT,
        )

        if agent_info is None:
            missing = requested or "default"
            raise IntentRecognitionError(
                code="intent-not-supported",
                message=f"Intent recognition engine {missing} is not found",
            )

        self.intent_agent = agent_info["id"]

    async def recognize_intent(
        self, intent_input: str, conversation_id: str | None
    ) -> str:
        """Send the text to the conversation agent and return the text to speak.

        Raises:
            RuntimeError: if ``prepare_recognize_intent`` was not run first.
            IntentRecognitionError: if the agent raises unexpectedly.
        """
        if self.intent_agent is None:
            raise RuntimeError("Recognize intent was not prepared")

        start_data = {
            "engine": self.intent_agent,
            "intent_input": intent_input,
        }
        self.process_event(PipelineEvent(PipelineEventType.INTENT_START, start_data))

        try:
            conversation_result = await conversation.async_converse(
                hass=self.hass,
                text=intent_input,
                conversation_id=conversation_id,
                context=self.context,
                language=self.language,
                agent_id=self.intent_agent,
            )
        except Exception as src_error:
            _LOGGER.exception("Unexpected error during intent recognition")
            raise IntentRecognitionError(
                code="intent-failed",
                message="Unexpected error during intent recognition",
            ) from src_error

        _LOGGER.debug("conversation result %s", conversation_result)

        end_data = {"intent_output": conversation_result.as_dict()}
        self.process_event(PipelineEvent(PipelineEventType.INTENT_END, end_data))

        # Only the plain-text speech is forwarded; missing speech becomes ""
        plain = conversation_result.response.speech.get("plain", {})
        speech: str = plain.get("speech", "")
        return speech

    async def prepare_text_to_speech(self) -> None:
        """Resolve the text to speech engine and check it supports the options.

        Raises:
            TextToSpeechError: if the engine is missing or does not support
                the run language or ``tts_options``.
        """
        engine = tts.async_resolve_engine(self.hass, self.pipeline.tts_engine)

        if engine is None:
            missing = self.pipeline.tts_engine or "default"
            raise TextToSpeechError(
                code="tts-not-supported",
                message=f"Text to speech engine '{missing}' not found",
            )

        supported = await tts.async_support_options(
            self.hass,
            engine,
            self.language,
            self.tts_options,
        )
        if not supported:
            raise TextToSpeechError(
                code="tts-not-supported",
                message=(
                    f"Text to speech engine {engine} "
                    f"does not support language {self.language} or options {self.tts_options}"
                ),
            )

        self.tts_engine = engine

    async def text_to_speech(self, tts_input: str) -> str:
        """Synthesize speech for the text and return the URL of the audio.

        Raises:
            RuntimeError: if ``prepare_text_to_speech`` was not run first.
            TextToSpeechError: if synthesis or media resolution fails.
        """
        if self.tts_engine is None:
            raise RuntimeError("Text to speech was not prepared")

        start_data = {
            "engine": self.tts_engine,
            "tts_input": tts_input,
        }
        self.process_event(PipelineEvent(PipelineEventType.TTS_START, start_data))

        try:
            # Synthesize audio and get URL
            tts_media_id = tts_generate_media_source_id(
                self.hass,
                tts_input,
                engine=self.tts_engine,
                language=self.language,
                options=self.tts_options,
            )
            tts_media = await media_source.async_resolve_media(
                self.hass,
                tts_media_id,
                None,
            )
        except Exception as src_error:
            _LOGGER.exception("Unexpected error during text to speech")
            raise TextToSpeechError(
                code="tts-failed",
                message="Unexpected error during text to speech",
            ) from src_error

        _LOGGER.debug("TTS result %s", tts_media)

        tts_output = {"media_id": tts_media_id, **asdict(tts_media)}
        self.process_event(
            PipelineEvent(PipelineEventType.TTS_END, {"tts_output": tts_output})
        )
        return tts_media.url


@dataclass
class PipelineInput:
    """Input to a pipeline run."""

    run: PipelineRun

    stt_metadata: stt.SpeechMetadata | None = None
    """Metadata of stt input audio. Required when start_stage = stt."""

    stt_stream: AsyncIterable[bytes] | None = None
    """Input audio for stt. Required when start_stage = stt."""

    intent_input: str | None = None
    """Input for conversation agent. Required when start_stage = intent."""

    tts_input: str | None = None
    """Input for text to speech. Required when start_stage = tts."""

    conversation_id: str | None = None
    # Passed to the conversation agent during intent recognition

    async def execute(self) -> None:
        """Run the stages from ``run.start_stage`` through ``run.end_stage``.

        A run-start event is emitted first, followed by the events of each
        executed stage. If a stage raises PipelineError, an error event is
        emitted and the method returns without a run-end event. Otherwise a
        run-end event is emitted.

        ``validate`` must be awaited first. It prepares the stage engines,
        and the stage methods raise RuntimeError when not prepared.
        """
        self.run.start()
        current_stage = self.run.start_stage

        try:
            # Speech to text
            intent_input = self.intent_input
            if current_stage == PipelineStage.STT:
                assert self.stt_metadata is not None
                assert self.stt_stream is not None
                intent_input = await self.run.speech_to_text(
                    self.stt_metadata,
                    self.stt_stream,
                )
                current_stage = PipelineStage.INTENT

            if self.run.end_stage != PipelineStage.STT:
                tts_input = self.tts_input

                # Intent recognition
                if current_stage == PipelineStage.INTENT:
                    assert intent_input is not None
                    tts_input = await self.run.recognize_intent(
                        intent_input, self.conversation_id
                    )
                    current_stage = PipelineStage.TTS

                # Text to speech
                if self.run.end_stage != PipelineStage.INTENT:
                    if current_stage == PipelineStage.TTS:
                        assert tts_input is not None
                        await self.run.text_to_speech(tts_input)

        except PipelineError as err:
            self.run.process_event(
                PipelineEvent(
                    PipelineEventType.ERROR,
                    {"code": err.code, "message": err.message},
                )
            )
            return

        self.run.end()

    async def validate(self) -> None:
        """Validate the input and prepare the stages that this run executes.

        Only stages between ``run.start_stage`` and ``run.end_stage``
        (inclusive) are prepared. A run that stops early therefore does not
        fail because an engine for a later, unused stage is unavailable.

        Raises:
            PipelineRunValidationError: if an input required by the start
                stage is missing.
            SpeechToTextError, IntentRecognitionError, TextToSpeechError:
                raised by the ``prepare_*`` coroutines of the run.
        """
        if self.run.start_stage == PipelineStage.STT:
            if self.stt_metadata is None:
                raise PipelineRunValidationError(
                    "stt_metadata is required for speech to text"
                )

            if self.stt_stream is None:
                raise PipelineRunValidationError(
                    "stt_stream is required for speech to text"
                )
        elif self.run.start_stage == PipelineStage.INTENT:
            if self.intent_input is None:
                raise PipelineRunValidationError(
                    "intent_input is required for intent recognition"
                )
        elif self.run.start_stage == PipelineStage.TTS:
            if self.tts_input is None:
                raise PipelineRunValidationError(
                    "tts_input is required for text to speech"
                )

        start_stage_index = PIPELINE_STAGE_ORDER.index(self.run.start_stage)
        end_stage_index = PIPELINE_STAGE_ORDER.index(self.run.end_stage)

        def _is_executed(stage: PipelineStage) -> bool:
            """Return True if stage lies within [start_stage, end_stage]."""
            return (
                start_stage_index
                <= PIPELINE_STAGE_ORDER.index(stage)
                <= end_stage_index
            )

        prepare_tasks = []

        if _is_executed(PipelineStage.STT):
            # self.stt_metadata can't be None or we'd raise above
            prepare_tasks.append(self.run.prepare_speech_to_text(self.stt_metadata))  # type: ignore[arg-type]

        if _is_executed(PipelineStage.INTENT):
            prepare_tasks.append(self.run.prepare_recognize_intent())

        if _is_executed(PipelineStage.TTS):
            prepare_tasks.append(self.run.prepare_text_to_speech())

        if prepare_tasks:
            await asyncio.gather(*prepare_tasks)


class PipelinePreferred(CollectionError):
    """Raised when trying to delete the preferred pipeline."""

    def __init__(self, item_id: str) -> None:
        """Build the error message and remember the offending pipeline id."""
        message = f"Item {item_id} preferred."
        super().__init__(message)
        self.item_id = item_id


class SerializedPipelineStorageCollection(SerializedStorageCollection):
    """Stored form of the pipeline collection, including the preferred pipeline."""

    # Id of the preferred pipeline, or None when none has been set
    preferred_item: str | None


class PipelineStorageCollection(
    StorageCollection[Pipeline, SerializedPipelineStorageCollection]
):
    """Persistent collection of pipelines that also tracks a preferred pipeline."""

    CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)

    # Id of the preferred pipeline; the first created pipeline becomes preferred
    _preferred_item: str | None = None

    async def _async_load_data(self) -> SerializedPipelineStorageCollection | None:
        """Load stored data and restore the preferred pipeline id from it."""
        data = await super()._async_load_data()
        if data:
            self._preferred_item = data["preferred_item"]
        return data

    async def _process_create_data(self, data: dict) -> dict:
        """Return the data unchanged; the WS API has already validated it."""
        return data

    @callback
    def _get_suggested_id(self, info: dict) -> str:
        """Suggest a fresh ULID; the config itself is not used."""
        return ulid_util.ulid()

    async def _update_data(self, item: Pipeline, update_data: dict) -> Pipeline:
        """Return a replacement pipeline that keeps the id of ``item``."""
        return Pipeline(id=item.id, **update_data)

    def _create_item(self, item_id: str, data: dict) -> Pipeline:
        """Create a pipeline; the first one ever created becomes preferred."""
        if self._preferred_item is None:
            self._preferred_item = item_id
        return Pipeline(id=item_id, **data)

    def _deserialize_item(self, data: dict) -> Pipeline:
        """Rebuild a pipeline from its stored dict."""
        return Pipeline(**data)

    def _serialize_item(self, item_id: str, item: Pipeline) -> dict:
        """Return the dict stored for a pipeline."""
        return item.to_json()

    async def async_delete_item(self, item_id: str) -> None:
        """Delete a pipeline, refusing to delete the preferred one.

        Raises:
            PipelinePreferred: if ``item_id`` is the preferred pipeline.
        """
        if item_id == self._preferred_item:
            raise PipelinePreferred(item_id)
        await super().async_delete_item(item_id)

    @callback
    def async_get_preferred_item(self) -> str | None:
        """Return the id of the preferred pipeline, if any."""
        return self._preferred_item

    @callback
    def async_set_preferred_item(self, item_id: str) -> None:
        """Mark an existing pipeline as preferred and schedule a save.

        Raises:
            ItemNotFound: if no pipeline has id ``item_id``.
        """
        if item_id not in self.data:
            raise ItemNotFound(item_id)
        self._preferred_item = item_id
        self._async_schedule_save()

    @callback
    def _data_to_save(self) -> SerializedPipelineStorageCollection:
        """Return JSON-compatible data for storing to file."""
        stored_items = super()._base_data_to_save()["items"]
        return {"items": stored_items, "preferred_item": self._preferred_item}


class PipelineStorageCollectionWebsocket(
    StorageCollectionWebsocket[PipelineStorageCollection]
):
    """Websocket API for the pipeline collection and its preferred pipeline."""

    @callback
    def async_setup(
        self,
        hass: HomeAssistant,
        *,
        create_list: bool = True,
        create_create: bool = True,
    ) -> None:
        """Register the standard collection commands plus ``set_preferred``."""
        super().async_setup(hass, create_list=create_list, create_create=create_create)

        command_type = f"{self.api_prefix}/set_preferred"
        # Admin-only command; the coroutine handler is wrapped for the WS API
        handler = websocket_api.require_admin(
            websocket_api.async_response(self.ws_set_preferred_item)
        )
        schema = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
            {
                vol.Required("type"): command_type,
                vol.Required(self.item_id_key): str,
            }
        )
        websocket_api.async_register_command(hass, command_type, handler, schema)

    def ws_list_item(
        self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
    ) -> None:
        """Send all pipelines together with the preferred pipeline id."""
        collection = self.storage_collection
        connection.send_result(
            msg["id"],
            {
                "pipelines": collection.async_items(),
                "preferred_pipeline": collection.async_get_preferred_item(),
            },
        )

    async def ws_delete_item(
        self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
    ) -> None:
        """Delete a pipeline, reporting "not allowed" for the preferred one."""
        try:
            await super().ws_delete_item(hass, connection, msg)
        except PipelinePreferred as exc:
            error_code = websocket_api.const.ERR_NOT_ALLOWED
            connection.send_error(msg["id"], error_code, str(exc))

    async def ws_set_preferred_item(
        self,
        hass: HomeAssistant,
        connection: websocket_api.ActiveConnection,
        msg: dict[str, Any],
    ) -> None:
        """Set the preferred pipeline, or report that it does not exist."""
        try:
            self.storage_collection.async_set_preferred_item(msg[self.item_id_key])
        except ItemNotFound:
            error_code = websocket_api.const.ERR_NOT_FOUND
            connection.send_error(msg["id"], error_code, "unknown item")
        else:
            connection.send_result(msg["id"])


@dataclass
class PipelineData:
    """Integration state kept in ``hass.data[DOMAIN]``."""

    # Recorded run events, keyed by pipeline id and then by run id
    pipeline_runs: dict[str, LimitedSizeDict[str, list[PipelineEvent]]]
    # Storage collection holding the configured pipelines
    pipeline_store: PipelineStorageCollection


async def async_setup_pipeline_store(hass: HomeAssistant) -> None:
    """Load the pipeline store, expose it over websocket and publish it in hass.data."""
    store = Store(hass, STORAGE_VERSION, STORAGE_KEY)
    pipeline_store = PipelineStorageCollection(store)
    await pipeline_store.async_load()

    websocket = PipelineStorageCollectionWebsocket(
        pipeline_store,
        f"{DOMAIN}/pipeline",
        "pipeline",
        STORAGE_FIELDS,
        STORAGE_FIELDS,
    )
    websocket.async_setup(hass)

    # Start with an empty run history
    hass.data[DOMAIN] = PipelineData(pipeline_runs={}, pipeline_store=pipeline_store)