"""ESPHome voice assistant support."""
from __future__ import annotations

import asyncio
from collections import deque
from collections.abc import AsyncIterable, Callable, MutableSequence, Sequence
import logging
import socket
from typing import cast

from aioesphomeapi import VoiceAssistantEventType
import async_timeout

from homeassistant.components import stt, tts
from homeassistant.components.assist_pipeline import (
    PipelineEvent,
    PipelineEventType,
    PipelineNotFound,
    async_pipeline_from_audio_stream,
    select as pipeline_select,
)
from homeassistant.components.assist_pipeline.vad import VoiceCommandSegmenter
from homeassistant.components.media_player import async_process_play_media_url
from homeassistant.core import Context, HomeAssistant, callback

from .const import DOMAIN
from .entry_data import RuntimeEntryData
from .enum_mapper import EsphomeEnumMapper

_LOGGER = logging.getLogger(__name__)

UDP_PORT = 0  # Set to 0 to let the OS pick a free random port
UDP_MAX_PACKET_SIZE = 1024

_VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[
    VoiceAssistantEventType, PipelineEventType
] = EsphomeEnumMapper(
    {
        VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: PipelineEventType.ERROR,
        VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START: PipelineEventType.RUN_START,
        VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END: PipelineEventType.RUN_END,
        VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: PipelineEventType.STT_START,
        VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: PipelineEventType.STT_END,
        VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START: PipelineEventType.INTENT_START,
        VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END,
        VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START,
        VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END,
    }
)


class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
    """Receive UDP packets and forward them to the voice assistant."""

    started = False
    stopped = False
    transport: asyncio.DatagramTransport | None = None
    remote_addr: tuple[str, int] | None = None

    def __init__(
        self,
        hass: HomeAssistant,
        entry_data: RuntimeEntryData,
        handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
        handle_finished: Callable[[], None],
        audio_timeout: float = 2.0,
    ) -> None:
        """Initialize UDP receiver."""
        self.context = Context()
        self.hass = hass

        assert entry_data.device_info is not None
        self.device_info = entry_data.device_info

        self.queue: asyncio.Queue[bytes] = asyncio.Queue()
        self.handle_event = handle_event
        self.handle_finished = handle_finished
        self._tts_done = asyncio.Event()
        self.audio_timeout = audio_timeout

    async def start_server(self) -> int:
        """Start accepting connections."""

        def accept_connection() -> VoiceAssistantUDPServer:
            """Accept connection."""
            if self.started:
                raise RuntimeError("Can only start once")
            if self.stopped:
                raise RuntimeError("No longer accepting connections")

            self.started = True
            return self

        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.setblocking(False)

        sock.bind(("", UDP_PORT))

        await asyncio.get_running_loop().create_datagram_endpoint(
            accept_connection, sock=sock
        )

        return cast(int, sock.getsockname()[1])

    @callback
    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Store transport for later use."""
        self.transport = cast(asyncio.DatagramTransport, transport)

    @callback
    def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
        """Handle incoming UDP packet."""
        if not self.started or self.stopped:
            return
        if self.remote_addr is None:
            self.remote_addr = addr
        self.queue.put_nowait(data)

    def error_received(self, exc: Exception) -> None:
        """Handle when a send or receive operation raises an OSError.

        (Other than BlockingIOError or InterruptedError.)
        """
        _LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc)
        self.handle_finished()

    @callback
    def stop(self) -> None:
        """Stop the receiver."""
        self.queue.put_nowait(b"")
        self.started = False
        self.stopped = True

    def close(self) -> None:
        """Close the receiver."""
        self.started = False
        self.stopped = True
        if self.transport is not None:
            self.transport.close()

    async def _iterate_packets(self) -> AsyncIterable[bytes]:
        """Iterate over incoming packets."""
        if not self.started or self.stopped:
            raise RuntimeError("Not running")

        while data := await self.queue.get():
            yield data

    def _event_callback(self, event: PipelineEvent) -> None:
        """Handle pipeline events."""

        try:
            event_type = _VOICE_ASSISTANT_EVENT_TYPES.from_hass(event.type)
        except KeyError:
            _LOGGER.warning("Received unknown pipeline event type: %s", event.type)
            return

        data_to_send = None
        error = False
        if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END:
            assert event.data is not None
            data_to_send = {"text": event.data["stt_output"]["text"]}
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END:
            assert event.data is not None
            data_to_send = {
                "conversation_id": event.data["intent_output"]["conversation_id"] or "",
            }
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START:
            assert event.data is not None
            data_to_send = {"text": event.data["tts_input"]}
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
            assert event.data is not None
            path = event.data["tts_output"]["url"]
            url = async_process_play_media_url(self.hass, path)
            data_to_send = {"url": url}

            if self.device_info.voice_assistant_version >= 2:
                media_id = event.data["tts_output"]["media_id"]
                self.hass.async_create_background_task(
                    self._send_tts(media_id), "esphome_voice_assistant_tts"
                )
            else:
                self._tts_done.set()
        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
            assert event.data is not None
            data_to_send = {
                "code": event.data["code"],
                "message": event.data["message"],
            }
            self._tts_done.set()
            error = True

        self.handle_event(event_type, data_to_send)
        if error:
            self.handle_finished()

    async def _wait_for_speech(
        self,
        segmenter: VoiceCommandSegmenter,
        chunk_buffer: MutableSequence[bytes],
    ) -> bool:
        """Buffer audio chunks until speech is detected.

        Raises asyncio.TimeoutError if no audio data is retrievable from the queue (device stops sending packets / networking issue).

        Returns True if speech was detected
        Returns False if the connection was stopped gracefully (b"" put onto the queue).
        """
        # Timeout if no audio comes in for a while.
        async with async_timeout.timeout(self.audio_timeout):
            chunk = await self.queue.get()

        while chunk:
            segmenter.process(chunk)
            # Buffer the data we have taken from the queue
            chunk_buffer.append(chunk)
            if segmenter.in_command:
                return True

            async with async_timeout.timeout(self.audio_timeout):
                chunk = await self.queue.get()

        # If chunk is falsey, `stop()` was called
        return False

    async def _segment_audio(
        self,
        segmenter: VoiceCommandSegmenter,
        chunk_buffer: Sequence[bytes],
    ) -> AsyncIterable[bytes]:
        """Yield audio chunks until voice command has finished.

        Raises asyncio.TimeoutError if no audio data is retrievable from the queue.
        """
        # Buffered chunks first
        for buffered_chunk in chunk_buffer:
            yield buffered_chunk

        # Timeout if no audio comes in for a while.
        async with async_timeout.timeout(self.audio_timeout):
            chunk = await self.queue.get()

        while chunk:
            if not segmenter.process(chunk):
                # Voice command is finished
                break

            yield chunk

            async with async_timeout.timeout(self.audio_timeout):
                chunk = await self.queue.get()

    async def _iterate_packets_with_vad(
        self, pipeline_timeout: float
    ) -> Callable[[], AsyncIterable[bytes]] | None:
        segmenter = VoiceCommandSegmenter()
        chunk_buffer: deque[bytes] = deque(maxlen=100)
        try:
            async with async_timeout.timeout(pipeline_timeout):
                speech_detected = await self._wait_for_speech(segmenter, chunk_buffer)
                if not speech_detected:
                    _LOGGER.debug(
                        "Device stopped sending audio before speech was detected"
                    )
                    self.handle_finished()
                    return None
        except asyncio.TimeoutError:
            self.handle_event(
                VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
                {
                    "code": "speech-timeout",
                    "message": "Timed out waiting for speech",
                },
            )
            self.handle_finished()
            return None

        async def _stream_packets() -> AsyncIterable[bytes]:
            try:
                async for chunk in self._segment_audio(segmenter, chunk_buffer):
                    yield chunk
            except asyncio.TimeoutError:
                self.handle_event(
                    VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
                    {
                        "code": "speech-timeout",
                        "message": "No speech detected",
                    },
                )
                self.handle_finished()

        return _stream_packets

    async def run_pipeline(
        self,
        conversation_id: str | None,
        use_vad: bool = False,
        pipeline_timeout: float = 30.0,
    ) -> None:
        """Run the Voice Assistant pipeline."""

        tts_audio_output = (
            "raw" if self.device_info.voice_assistant_version >= 2 else "mp3"
        )

        if use_vad:
            stt_stream = await self._iterate_packets_with_vad(pipeline_timeout)
            # Error or timeout occurred and was handled already
            if stt_stream is None:
                return
        else:
            stt_stream = self._iterate_packets

        _LOGGER.debug("Starting pipeline")
        try:
            async with async_timeout.timeout(pipeline_timeout):
                await async_pipeline_from_audio_stream(
                    self.hass,
                    context=self.context,
                    event_callback=self._event_callback,
                    stt_metadata=stt.SpeechMetadata(
                        language="",  # set in async_pipeline_from_audio_stream
                        format=stt.AudioFormats.WAV,
                        codec=stt.AudioCodecs.PCM,
                        bit_rate=stt.AudioBitRates.BITRATE_16,
                        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
                        channel=stt.AudioChannels.CHANNEL_MONO,
                    ),
                    stt_stream=stt_stream(),
                    pipeline_id=pipeline_select.get_chosen_pipeline(
                        self.hass, DOMAIN, self.device_info.mac_address
                    ),
                    conversation_id=conversation_id,
                    tts_audio_output=tts_audio_output,
                )

                # Block until TTS is done sending
                await self._tts_done.wait()

            _LOGGER.debug("Pipeline finished")
        except PipelineNotFound:
            self.handle_event(
                VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
                {
                    "code": "pipeline not found",
                    "message": "Selected pipeline timeout",
                },
            )
            _LOGGER.warning("Pipeline not found")
        except asyncio.TimeoutError:
            self.handle_event(
                VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
                {
                    "code": "pipeline-timeout",
                    "message": "Pipeline timeout",
                },
            )
            _LOGGER.warning("Pipeline timeout")
        finally:
            self.handle_finished()

    async def _send_tts(self, media_id: str) -> None:
        """Send TTS audio to device via UDP."""
        try:
            if self.transport is None:
                return

            _extension, audio_bytes = await tts.async_get_media_source_audio(
                self.hass,
                media_id,
            )

            _LOGGER.debug("Sending %d bytes of audio", len(audio_bytes))

            bytes_per_sample = stt.AudioBitRates.BITRATE_16 // 8
            sample_offset = 0
            samples_left = len(audio_bytes) // bytes_per_sample

            while samples_left > 0:
                bytes_offset = sample_offset * bytes_per_sample
                chunk: bytes = audio_bytes[bytes_offset : bytes_offset + 1024]
                samples_in_chunk = len(chunk) // bytes_per_sample
                samples_left -= samples_in_chunk

                self.transport.sendto(chunk, self.remote_addr)
                await asyncio.sleep(
                    samples_in_chunk / stt.AudioSampleRates.SAMPLERATE_16000 * 0.99
                )

                sample_offset += samples_in_chunk

        finally:
            self._tts_done.set()