"""Test Wyoming satellite."""

from __future__ import annotations

import asyncio
from collections.abc import Callable
import io
from typing import Any
from unittest.mock import patch
import wave

from wyoming.asr import Transcribe, Transcript
from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.error import Error
from wyoming.event import Event
from wyoming.info import Info
from wyoming.ping import Ping, Pong
from wyoming.pipeline import PipelineStage, RunPipeline
from wyoming.satellite import RunSatellite
from wyoming.tts import Synthesize
from wyoming.vad import VoiceStarted, VoiceStopped
from wyoming.wake import Detect, Detection

from homeassistant.components import assist_pipeline, wyoming
from homeassistant.components.wyoming.data import WyomingService
from homeassistant.components.wyoming.devices import SatelliteDevice
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component

from . import SATELLITE_INFO, WAKE_WORD_INFO, MockAsyncTcpClient

from tests.common import MockConfigEntry


async def setup_config_entry(hass: HomeAssistant) -> MockConfigEntry:
    """Create, register and set up a config entry for a Wyoming satellite.

    Kept apart from the ``satellite_config_entry`` fixture in conftest.py so a
    test can install its patches before setup starts the satellite task.
    """
    connection = {"host": "1.2.3.4", "port": 1234}
    config_entry = MockConfigEntry(
        domain="wyoming", title="Test Satellite", data=connection
    )
    config_entry.add_to_hass(hass)

    await hass.config_entries.async_setup(config_entry.entry_id)
    await hass.async_block_till_done()
    return config_entry


def get_test_wav() -> bytes:
    """Return a tiny in-memory WAV file (22050 Hz, 16-bit, mono) as bytes."""
    buffer = io.BytesIO()
    with buffer:
        # The WAV writer must be closed (header finalized) before the buffer
        # contents are read, hence the nested context managers.
        with wave.open(buffer, mode="wb") as writer:
            writer.setnchannels(1)
            writer.setsampwidth(2)
            writer.setframerate(22050)

            # 3 bytes of audio: one 16-bit mono frame plus a trailing byte.
            # Callers compare streamed audio against exactly these bytes.
            writer.writeframes(b"123")

        return buffer.getvalue()


class SatelliteAsyncTcpClient(MockAsyncTcpClient):
    """Mock Wyoming client standing in for a satellite device.

    Each event written to it is matched against ``_RECORDED_EVENTS``. The
    matching ``asyncio.Event`` flag is set. Where an attribute name is given,
    the decoded message is stored first so tests can inspect it. Whenever the
    base mock's ``read_event`` returns nothing, a microphone audio chunk is
    returned instead.
    """

    # (message class, attribute for the decoded message or None, flag attribute)
    # Checked in order; only the first class whose ``is_type`` matches is used.
    _RECORDED_EVENTS: tuple[tuple[Any, str | None, str], ...] = (
        (RunSatellite, None, "run_satellite_event"),
        (Detect, None, "detect_event"),
        (Detection, "detection", "detection_event"),
        (Transcribe, "transcribe", "transcribe_event"),
        (VoiceStarted, "voice_started", "voice_started_event"),
        (VoiceStopped, "voice_stopped", "voice_stopped_event"),
        (Transcript, "transcript", "transcript_event"),
        (Synthesize, "synthesize", "synthesize_event"),
        (AudioStart, None, "tts_audio_start_event"),
        (AudioChunk, "tts_audio_chunk", "tts_audio_chunk_event"),
        (AudioStop, None, "tts_audio_stop_event"),
        (Error, "error", "error_event"),
        (Pong, "pong", "pong_event"),
        (Ping, "ping", "ping_event"),
    )

    def __init__(self, responses: list[Event]) -> None:
        """Initialize client with scripted ``responses`` (handed to the base mock)."""
        super().__init__(responses)

        # Set by connect()
        self.connect_event = asyncio.Event()

        # Events that are only signalled, not stored
        self.run_satellite_event = asyncio.Event()
        self.detect_event = asyncio.Event()
        self.tts_audio_start_event = asyncio.Event()
        self.tts_audio_stop_event = asyncio.Event()

        # Events that are signalled and whose last decoded message is kept
        self.detection_event = asyncio.Event()
        self.detection: Detection | None = None
        self.transcribe_event = asyncio.Event()
        self.transcribe: Transcribe | None = None
        self.voice_started_event = asyncio.Event()
        self.voice_started: VoiceStarted | None = None
        self.voice_stopped_event = asyncio.Event()
        self.voice_stopped: VoiceStopped | None = None
        self.transcript_event = asyncio.Event()
        self.transcript: Transcript | None = None
        self.synthesize_event = asyncio.Event()
        self.synthesize: Synthesize | None = None
        self.tts_audio_chunk_event = asyncio.Event()
        self.tts_audio_chunk: AudioChunk | None = None
        self.error_event = asyncio.Event()
        self.error: Error | None = None
        self.pong_event = asyncio.Event()
        self.pong: Pong | None = None
        self.ping_event = asyncio.Event()
        self.ping: Ping | None = None

        # Returned by read_event() whenever the base mock yields nothing
        self._mic_audio_chunk = AudioChunk(
            rate=16000, width=2, channels=1, audio=b"chunk"
        ).event()

    async def connect(self) -> None:
        """Pretend to connect and signal ``connect_event``."""
        self.connect_event.set()

    async def write_event(self, event: Event):
        """Record an event sent to the satellite device; unknown types are ignored."""
        for message_cls, message_attr, flag_attr in self._RECORDED_EVENTS:
            if not message_cls.is_type(event.type):
                continue

            # Store the decoded message before waking up any waiter
            if message_attr is not None:
                setattr(self, message_attr, message_cls.from_event(event))

            getattr(self, flag_attr).set()
            return

    async def read_event(self) -> Event | None:
        """Return the base mock's next event, or mic audio when it has none."""
        next_event = await super().read_event()

        # Keep sending audio chunks instead of None
        return next_event if next_event else self._mic_audio_chunk

    def inject_event(self, event: Event) -> None:
        """Put ``event`` at the front of the pending responses."""
        self.responses = [event] + list(self.responses)


async def test_satellite_pipeline(hass: HomeAssistant) -> None:
    """Test running a pipeline with a satellite.

    The pipeline itself is replaced by a stub that captures its event
    callback. The test emits each pipeline event by hand and checks the
    Wyoming event the satellite writes to the mocked client in response. It
    also exercises ping/pong. Finally, it checks that the pipeline is started
    again after RUN_END (``restart_on_end=True``).
    """
    assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})

    # Scripted satellite request: run WAKE -> TTS and restart when it ends
    events = [
        RunPipeline(
            start_stage=PipelineStage.WAKE,
            end_stage=PipelineStage.TTS,
            restart_on_end=True,
        ).event(),
    ]

    # Filled in by the pipeline stub below
    pipeline_kwargs: dict[str, Any] = {}
    pipeline_event_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = (
        None
    )
    run_pipeline_called = asyncio.Event()
    # NOTE(review): set by the stub but never awaited in this test; waiting on
    # it would also assert that microphone audio reaches the pipeline.
    audio_chunk_received = asyncio.Event()

    async def async_pipeline_from_audio_stream(
        hass: HomeAssistant,
        context,
        event_callback,
        stt_metadata,
        stt_stream,
        **kwargs,
    ) -> None:
        # Stand-in for the real pipeline: record how it was called so the
        # test can drive pipeline events itself, then read the STT audio
        # stream up to the first non-empty chunk.
        nonlocal pipeline_kwargs, pipeline_event_callback
        pipeline_kwargs = kwargs
        pipeline_event_callback = event_callback

        run_pipeline_called.set()
        async for chunk in stt_stream:
            if chunk:
                audio_chunk_received.set()
                break

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient(events),
        ) as mock_client,
        patch(
            "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
            async_pipeline_from_audio_stream,
        ),
        # TTS media resolves to the small test WAV file
        patch(
            "homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
            return_value=("wav", get_test_wav()),
        ),
        # No delay before pings are sent
        patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0),
    ):
        entry = await setup_config_entry(hass)
        device: SatelliteDevice = hass.data[wyoming.DOMAIN][
            entry.entry_id
        ].satellite.device

        # Satellite connects and is told to start running
        async with asyncio.timeout(1):
            await mock_client.connect_event.wait()
            await mock_client.run_satellite_event.wait()

        async with asyncio.timeout(1):
            await run_pipeline_called.wait()

            # Reset so we can check the pipeline is automatically restarted below
            run_pipeline_called.clear()

        assert pipeline_event_callback is not None
        assert pipeline_kwargs.get("device_id") == device.device_id

        # Test a ping
        mock_client.inject_event(Ping("test-ping").event())

        # Pong is expected with the same text
        async with asyncio.timeout(1):
            await mock_client.pong_event.wait()

        assert mock_client.pong is not None
        assert mock_client.pong.text == "test-ping"

        # The client should have received the first ping
        async with asyncio.timeout(1):
            await mock_client.ping_event.wait()

        assert mock_client.ping is not None

        # Reset and send a pong back.
        # We will get a second ping by the end of the test.
        mock_client.ping_event.clear()
        mock_client.ping = None
        mock_client.inject_event(Pong().event())

        # Start detecting wake word
        pipeline_event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.WAKE_WORD_START
            )
        )
        async with asyncio.timeout(1):
            await mock_client.detect_event.wait()

        assert not device.is_active
        assert not device.is_muted

        # Push in some audio
        mock_client.inject_event(
            AudioChunk(rate=16000, width=2, channels=1, audio=bytes(1024)).event()
        )

        # Wake word is detected
        pipeline_event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.WAKE_WORD_END,
                {"wake_word_output": {"wake_word_id": "test_wake_word"}},
            )
        )
        async with asyncio.timeout(1):
            await mock_client.detection_event.wait()

        # The detection sent to the satellite carries the wake word id
        assert mock_client.detection is not None
        assert mock_client.detection.name == "test_wake_word"

        # "Assist in progress" sensor should be active now
        assert device.is_active

        # Speech-to-text started
        pipeline_event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.STT_START,
                {"metadata": {"language": "en"}},
            )
        )
        async with asyncio.timeout(1):
            await mock_client.transcribe_event.wait()

        assert mock_client.transcribe is not None
        assert mock_client.transcribe.language == "en"

        # Push in some audio
        mock_client.inject_event(
            AudioChunk(rate=16000, width=2, channels=1, audio=bytes(1024)).event()
        )

        # User started speaking
        pipeline_event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.STT_VAD_START, {"timestamp": 1234}
            )
        )
        async with asyncio.timeout(1):
            await mock_client.voice_started_event.wait()

        assert mock_client.voice_started is not None
        assert mock_client.voice_started.timestamp == 1234

        # User stopped speaking
        pipeline_event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.STT_VAD_END, {"timestamp": 5678}
            )
        )
        async with asyncio.timeout(1):
            await mock_client.voice_stopped_event.wait()

        assert mock_client.voice_stopped is not None
        assert mock_client.voice_stopped.timestamp == 5678

        # Speech-to-text transcription
        pipeline_event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.STT_END,
                {"stt_output": {"text": "test transcript"}},
            )
        )
        async with asyncio.timeout(1):
            await mock_client.transcript_event.wait()

        assert mock_client.transcript is not None
        assert mock_client.transcript.text == "test transcript"

        # Text-to-speech text
        pipeline_event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.TTS_START,
                {
                    "tts_input": "test text to speak",
                    "voice": "test voice",
                },
            )
        )
        async with asyncio.timeout(1):
            await mock_client.synthesize_event.wait()

        assert mock_client.synthesize is not None
        assert mock_client.synthesize.text == "test text to speak"
        assert mock_client.synthesize.voice is not None
        assert mock_client.synthesize.voice.name == "test voice"

        # Text-to-speech media: the satellite streams the (patched) WAV as
        # AudioStart / AudioChunk / AudioStop
        pipeline_event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.TTS_END,
                {"tts_output": {"media_id": "test media id"}},
            )
        )
        async with asyncio.timeout(1):
            await mock_client.tts_audio_start_event.wait()
            await mock_client.tts_audio_chunk_event.wait()
            await mock_client.tts_audio_stop_event.wait()

        # Verify audio chunk from test WAV
        assert mock_client.tts_audio_chunk is not None
        assert mock_client.tts_audio_chunk.rate == 22050
        assert mock_client.tts_audio_chunk.width == 2
        assert mock_client.tts_audio_chunk.channels == 1
        assert mock_client.tts_audio_chunk.audio == b"123"

        # Pipeline finished
        pipeline_event_callback(
            assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END)
        )
        assert not device.is_active

        # The client should have received another ping by now
        async with asyncio.timeout(1):
            await mock_client.ping_event.wait()

        assert mock_client.ping is not None

        # Pipeline should automatically restart
        async with asyncio.timeout(1):
            await run_pipeline_called.wait()

        # Stop the satellite
        await hass.config_entries.async_unload(entry.entry_id)
        await hass.async_block_till_done()


async def test_satellite_muted(hass: HomeAssistant) -> None:
    """Test callback for a satellite that has been muted.

    The satellite is created already muted. The patched ``on_muted`` runs the
    original handler, marks the satellite as no longer running and unmutes
    the device before signalling the test.
    """
    on_muted_event = asyncio.Event()

    original_make_satellite = wyoming._make_satellite
    original_on_muted = wyoming.satellite.WyomingSatellite.on_muted

    def make_muted_satellite(
        hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService
    ):
        # Build the real satellite, but start it in the muted state
        satellite = original_make_satellite(hass, config_entry, service)
        satellite.device.set_is_muted(True)

        return satellite

    async def on_muted(self):
        # Trigger original function
        # NOTE(review): setting _muted_changed_event first presumably lets the
        # original handler return instead of waiting for a mute change — confirm
        self._muted_changed_event.set()
        await original_on_muted(self)

        # Ensure satellite stops
        self.is_running = False

        # Proceed with test
        self.device.set_is_muted(False)
        on_muted_event.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch("homeassistant.components.wyoming._make_satellite", make_muted_satellite),
        patch(
            "homeassistant.components.wyoming.satellite.WyomingSatellite.on_muted",
            on_muted,
        ),
    ):
        entry = await setup_config_entry(hass)
        async with asyncio.timeout(1):
            await on_muted_event.wait()

        # Stop the satellite
        await hass.config_entries.async_unload(entry.entry_id)
        await hass.async_block_till_done()


async def test_satellite_restart(hass: HomeAssistant) -> None:
    """Test pipeline loop restart after unexpected error."""
    restarted = asyncio.Event()
    real_on_restart = wyoming.satellite.WyomingSatellite.on_restart

    async def stop_after_restart(self):
        # Run the real restart hook, then stop the satellite after this first
        # restart and let the test continue.
        await real_on_restart(self)
        self.stop()
        restarted.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        # Every run of the connection loop fails with an unexpected error
        patch(
            "homeassistant.components.wyoming.satellite.WyomingSatellite._connect_and_loop",
            side_effect=RuntimeError(),
        ),
        patch(
            "homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
            stop_after_restart,
        ),
        patch("homeassistant.components.wyoming.satellite._RESTART_SECONDS", 0),
    ):
        await setup_config_entry(hass)

        async with asyncio.timeout(1):
            await restarted.wait()


async def test_satellite_reconnect(hass: HomeAssistant) -> None:
    """Test satellite reconnect call after connection refused.

    Every connection attempt is refused. After the second reconnect the
    satellite is stopped, and the test waits for the (patched) ``on_stopped``
    hook.
    """
    num_reconnects = 0
    reconnect_event = asyncio.Event()
    stopped_event = asyncio.Event()

    original_on_reconnect = wyoming.satellite.WyomingSatellite.on_reconnect

    async def on_reconnect(self):
        # Declare the captured counter up front rather than mid-function
        nonlocal num_reconnects

        await original_on_reconnect(self)

        num_reconnects += 1
        if num_reconnects >= 2:
            # Two reconnects observed: stop the satellite to end the loop
            reconnect_event.set()
            self.stop()

    async def on_stopped(self):
        stopped_event.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.satellite.AsyncTcpClient.connect",
            side_effect=ConnectionRefusedError(),
        ),
        patch(
            "homeassistant.components.wyoming.satellite.WyomingSatellite.on_reconnect",
            on_reconnect,
        ),
        patch(
            "homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
            on_stopped,
        ),
        patch("homeassistant.components.wyoming.satellite._RECONNECT_SECONDS", 0),
    ):
        await setup_config_entry(hass)
        async with asyncio.timeout(1):
            await reconnect_event.wait()
            await stopped_event.wait()


async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None:
    """Test satellite disconnecting before pipeline run."""
    restarted = asyncio.Event()

    async def stop_on_restart(self):
        # Expected once the mock client has no more events (a disconnect);
        # stop the satellite here and let the test continue.
        self.stop()
        restarted.set()

    # No scripted events, so no RunPipeline is ever received
    silent_client = MockAsyncTcpClient([])

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.satellite.AsyncTcpClient",
            silent_client,
        ),
        patch(
            "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
        ) as pipeline_mock,
        patch(
            "homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
            stop_on_restart,
        ),
    ):
        await setup_config_entry(hass)

        async with asyncio.timeout(1):
            await restarted.wait()

        # Without a RunPipeline event the pipeline must never have started
        pipeline_mock.assert_not_called()


async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None:
    """Test satellite disconnecting during pipeline run.

    The mock client has nothing left to send after RunPipeline. The test checks
    that the pipeline was started exactly once. It also checks that
    ``device.is_active`` ends up cleared even though the patched
    ``on_restart`` forces it on.
    """
    events = [
        RunPipeline(
            start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
        ).event(),
    ]  # no audio chunks after RunPipeline

    on_restart_event = asyncio.Event()
    on_stopped_event = asyncio.Event()

    async def on_restart(self):
        # Pretend sensor got stuck on
        self.device.is_active = True
        self.stop()
        on_restart_event.set()

    async def on_stopped(self):
        on_stopped_event.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.satellite.AsyncTcpClient",
            MockAsyncTcpClient(events),
        ),
        patch(
            "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
        ) as mock_run_pipeline,
        patch(
            "homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
            on_restart,
        ),
        patch(
            "homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
            on_stopped,
        ),
    ):
        entry = await setup_config_entry(hass)
        device: SatelliteDevice = hass.data[wyoming.DOMAIN][
            entry.entry_id
        ].satellite.device

        async with asyncio.timeout(1):
            await on_restart_event.wait()
            await on_stopped_event.wait()

        # Pipeline should have run once
        mock_run_pipeline.assert_called_once()

        # Sensor should have been turned off
        assert not device.is_active


async def test_satellite_error_during_pipeline(hass: HomeAssistant) -> None:
    """Test satellite error occurring during pipeline run.

    A pipeline ERROR event emitted through the captured event callback must be
    written to the client as a Wyoming ``Error`` with the same code and text.
    """
    events = [
        RunPipeline(
            start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
        ).event(),
    ]  # no audio chunks after RunPipeline

    pipeline_event = asyncio.Event()

    def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None:
        # Only signals that the pipeline was requested; the mock created by
        # patch(wraps=...) records the call arguments.
        pipeline_event.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient(events),
        ) as mock_client,
        patch(
            "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
            wraps=_async_pipeline_from_audio_stream,
        ) as mock_run_pipeline,
    ):
        await setup_config_entry(hass)

        async with asyncio.timeout(1):
            await pipeline_event.wait()
            await mock_client.connect_event.wait()
            await mock_client.run_satellite_event.wait()

        # Emit a pipeline error through the callback the satellite passed in
        mock_run_pipeline.assert_called_once()
        event_callback = mock_run_pipeline.call_args.kwargs["event_callback"]
        event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.ERROR,
                {"code": "test code", "message": "test message"},
            )
        )

        async with asyncio.timeout(1):
            await mock_client.error_event.wait()

        assert mock_client.error is not None
        assert mock_client.error.text == "test message"
        assert mock_client.error.code == "test code"


async def test_tts_not_wav(hass: HomeAssistant) -> None:
    """Test satellite receiving non-WAV audio from text-to-speech.

    TTS media is patched to be MP3. The wrapped ``_stream_tts`` is expected to
    raise ``ValueError``, which the wrapper turns into ``error_event``.
    """
    assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})

    original_stream_tts = wyoming.satellite.WyomingSatellite._stream_tts
    error_event = asyncio.Event()

    async def _stream_tts(self, media_id):
        # Swallow the expected ValueError and signal the test instead
        try:
            await original_stream_tts(self, media_id)
        except ValueError:
            error_event.set()

    events = [
        RunPipeline(start_stage=PipelineStage.TTS, end_stage=PipelineStage.TTS).event(),
    ]
    pipeline_event = asyncio.Event()

    def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None:
        # Only signals that the pipeline was requested; the mock created by
        # patch(wraps=...) records the call arguments.
        pipeline_event.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient(events),
        ) as mock_client,
        patch(
            "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
            wraps=_async_pipeline_from_audio_stream,
        ) as mock_run_pipeline,
        # TTS media is reported as MP3, which is not WAV
        patch(
            "homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
            return_value=("mp3", bytes(1)),
        ),
        patch(
            "homeassistant.components.wyoming.satellite.WyomingSatellite._stream_tts",
            _stream_tts,
        ),
    ):
        entry = await setup_config_entry(hass)

        async with asyncio.timeout(1):
            await pipeline_event.wait()
            await mock_client.connect_event.wait()
            await mock_client.run_satellite_event.wait()

        mock_run_pipeline.assert_called_once()
        event_callback = mock_run_pipeline.call_args.kwargs["event_callback"]

        # Text-to-speech text
        event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.TTS_START,
                {
                    "tts_input": "test text to speak",
                    "voice": "test voice",
                },
            )
        )
        async with asyncio.timeout(1):
            await mock_client.synthesize_event.wait()

        # Text-to-speech media
        event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.TTS_END,
                {"tts_output": {"media_id": "test media id"}},
            )
        )

        # Expect error because only WAV is supported
        async with asyncio.timeout(1):
            await error_event.wait()

        # Stop the satellite
        await hass.config_entries.async_unload(entry.entry_id)
        await hass.async_block_till_done()


async def test_pipeline_changed(hass: HomeAssistant) -> None:
    """Test that changing the pipeline setting stops the current pipeline.

    The pipeline stub reads the STT stream until it ends. ``pipeline_stopped``
    is therefore only set once the satellite stops feeding audio after the
    pipeline name changes.
    """
    assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})

    events = [
        RunPipeline(
            start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
        ).event(),
    ]

    pipeline_event_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = (
        None
    )
    run_pipeline_called = asyncio.Event()
    pipeline_stopped = asyncio.Event()

    async def async_pipeline_from_audio_stream(
        hass: HomeAssistant,
        context,
        event_callback,
        stt_metadata,
        stt_stream,
        **kwargs,
    ) -> None:
        # Stand-in for the real pipeline: capture the callback, then consume
        # audio until the stream ends
        nonlocal pipeline_event_callback
        pipeline_event_callback = event_callback

        run_pipeline_called.set()
        async for _chunk in stt_stream:
            pass

        pipeline_stopped.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient(events),
        ) as mock_client,
        patch(
            "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
            async_pipeline_from_audio_stream,
        ),
    ):
        entry = await setup_config_entry(hass)
        device: SatelliteDevice = hass.data[wyoming.DOMAIN][
            entry.entry_id
        ].satellite.device

        async with asyncio.timeout(1):
            await mock_client.connect_event.wait()
            await mock_client.run_satellite_event.wait()

        # Pipeline has started
        async with asyncio.timeout(1):
            await run_pipeline_called.wait()

        assert pipeline_event_callback is not None

        # Change pipelines
        device.set_pipeline_name("different pipeline")

        # Running pipeline should be cancelled
        async with asyncio.timeout(1):
            await pipeline_stopped.wait()

        # Stop the satellite
        await hass.config_entries.async_unload(entry.entry_id)
        await hass.async_block_till_done()


async def test_audio_settings_changed(hass: HomeAssistant) -> None:
    """Test that changing audio settings stops the current pipeline.

    The pipeline stub reads the STT stream until it ends. ``pipeline_stopped``
    is therefore only set once the satellite stops feeding audio after the
    noise suppression level changes.
    """
    assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})

    events = [
        RunPipeline(
            start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
        ).event(),
    ]

    pipeline_event_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = (
        None
    )
    run_pipeline_called = asyncio.Event()
    pipeline_stopped = asyncio.Event()

    async def async_pipeline_from_audio_stream(
        hass: HomeAssistant,
        context,
        event_callback,
        stt_metadata,
        stt_stream,
        **kwargs,
    ) -> None:
        # Stand-in for the real pipeline: capture the callback, then consume
        # audio until the stream ends
        nonlocal pipeline_event_callback
        pipeline_event_callback = event_callback

        run_pipeline_called.set()
        async for _chunk in stt_stream:
            pass

        pipeline_stopped.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient(events),
        ) as mock_client,
        patch(
            "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
            async_pipeline_from_audio_stream,
        ),
    ):
        entry = await setup_config_entry(hass)
        device: SatelliteDevice = hass.data[wyoming.DOMAIN][
            entry.entry_id
        ].satellite.device

        async with asyncio.timeout(1):
            await mock_client.connect_event.wait()
            await mock_client.run_satellite_event.wait()

        # Pipeline has started
        async with asyncio.timeout(1):
            await run_pipeline_called.wait()

        assert pipeline_event_callback is not None

        # Change audio setting
        device.set_noise_suppression_level(1)

        # Running pipeline should be cancelled
        async with asyncio.timeout(1):
            await pipeline_stopped.wait()

        # Stop the satellite
        await hass.config_entries.async_unload(entry.entry_id)
        await hass.async_block_till_done()


async def test_invalid_stages(hass: HomeAssistant) -> None:
    """Test error when providing invalid pipeline stages.

    The patched ``_run_pipeline_once`` calls the original twice. The first
    call uses an unsupported start stage and the second an unsupported end
    stage. Each call must raise ``ValueError``.
    """
    assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})

    events = [
        RunPipeline(
            start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
        ).event(),
    ]

    original_run_pipeline_once = wyoming.satellite.WyomingSatellite._run_pipeline_once
    start_stage_event = asyncio.Event()
    end_stage_event = asyncio.Event()

    def _run_pipeline_once(self, run_pipeline, wake_word_phrase=None):
        # wake_word_phrase is forwarded so the original receives the same
        # arguments the satellite passed, instead of silently falling back to
        # its default.

        # Set bad start stage
        run_pipeline.start_stage = PipelineStage.INTENT
        run_pipeline.end_stage = PipelineStage.TTS

        try:
            original_run_pipeline_once(self, run_pipeline, wake_word_phrase)
        except ValueError:
            start_stage_event.set()

        # Set bad end stage
        run_pipeline.start_stage = PipelineStage.WAKE
        run_pipeline.end_stage = PipelineStage.INTENT

        try:
            original_run_pipeline_once(self, run_pipeline, wake_word_phrase)
        except ValueError:
            end_stage_event.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient(events),
        ) as mock_client,
        patch(
            "homeassistant.components.wyoming.satellite.WyomingSatellite._run_pipeline_once",
            _run_pipeline_once,
        ),
    ):
        entry = await setup_config_entry(hass)

        async with asyncio.timeout(1):
            await mock_client.connect_event.wait()
            await mock_client.run_satellite_event.wait()

        async with asyncio.timeout(1):
            await start_stage_event.wait()
            await end_stage_event.wait()

        # Stop the satellite
        await hass.config_entries.async_unload(entry.entry_id)
        await hass.async_block_till_done()


async def test_client_stops_pipeline(hass: HomeAssistant) -> None:
    """Test that an AudioStop message stops the current pipeline.

    The pipeline stub reads the STT stream until it ends. ``pipeline_stopped``
    is therefore only set once the satellite stops feeding audio after the
    client sends ``AudioStop``.
    """
    assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})

    events = [
        RunPipeline(
            start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
        ).event(),
    ]

    pipeline_event_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = (
        None
    )
    run_pipeline_called = asyncio.Event()
    pipeline_stopped = asyncio.Event()

    async def async_pipeline_from_audio_stream(
        hass: HomeAssistant,
        context,
        event_callback,
        stt_metadata,
        stt_stream,
        **kwargs,
    ) -> None:
        # Stand-in for the real pipeline: capture the callback, then consume
        # audio until the stream ends
        nonlocal pipeline_event_callback
        pipeline_event_callback = event_callback

        run_pipeline_called.set()
        async for _chunk in stt_stream:
            pass

        pipeline_stopped.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient(events),
        ) as mock_client,
        patch(
            "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
            async_pipeline_from_audio_stream,
        ),
    ):
        entry = await setup_config_entry(hass)

        async with asyncio.timeout(1):
            await mock_client.connect_event.wait()
            await mock_client.run_satellite_event.wait()

        # Pipeline has started
        async with asyncio.timeout(1):
            await run_pipeline_called.wait()

        assert pipeline_event_callback is not None

        # Client sends stop message
        mock_client.inject_event(AudioStop().event())

        # Running pipeline should be cancelled
        async with asyncio.timeout(1):
            await pipeline_stopped.wait()

        # Stop the satellite
        await hass.config_entries.async_unload(entry.entry_id)
        await hass.async_block_till_done()


async def test_wake_word_phrase(hass: HomeAssistant) -> None:
    """Test that wake word phrase from info is given to pipeline."""
    # The satellite reports its wake word info, then a local detection of
    # "Test Model", then asks for a pipeline run.
    scripted = [
        Info(satellite=SATELLITE_INFO.satellite, wake=WAKE_WORD_INFO.wake).event(),
        Detection(name="Test Model").event(),
        RunPipeline(
            start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
        ).event(),
    ]
    pipeline_started = asyncio.Event()

    def _fake_pipeline(*args: Any, **kwargs: Any) -> None:
        # Only signals the call; the wrapping mock records the arguments
        pipeline_started.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient(scripted),
        ),
        patch(
            "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
            wraps=_fake_pipeline,
        ) as pipeline_mock,
    ):
        await setup_config_entry(hass)

        async with asyncio.timeout(1):
            await pipeline_started.wait()

        # The pipeline receives the wake word phrase (presumably the one
        # WAKE_WORD_INFO lists for "Test Model") for deconfliction.
        pipeline_mock.assert_called_once()
        phrase = pipeline_mock.call_args.kwargs.get("wake_word_phrase")
        assert phrase == "Test Phrase"