diff --git a/homeassistant/components/wyoming/__init__.py b/homeassistant/components/wyoming/__init__.py index 00d587e2bb4..d639933ece6 100644 --- a/homeassistant/components/wyoming/__init__.py +++ b/homeassistant/components/wyoming/__init__.py @@ -14,11 +14,11 @@ from .const import ATTR_SPEAKER, DOMAIN from .data import WyomingService from .devices import SatelliteDevice from .models import DomainDataItem -from .satellite import WyomingSatellite _LOGGER = logging.getLogger(__name__) SATELLITE_PLATFORMS = [ + Platform.ASSIST_SATELLITE, Platform.BINARY_SENSOR, Platform.SELECT, Platform.SWITCH, @@ -47,51 +47,29 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: entry.async_on_unload(entry.add_update_listener(update_listener)) if (satellite_info := service.info.satellite) is not None: - # Create satellite device, etc. - item.satellite = _make_satellite(hass, entry, service) + # Create satellite device + dev_reg = dr.async_get(hass) - # Set up satellite sensors, switches, etc. - await hass.config_entries.async_forward_entry_setups(entry, SATELLITE_PLATFORMS) - - # Start satellite communication - entry.async_create_background_task( - hass, - item.satellite.run(), - f"Satellite {satellite_info.name}", + # Use config entry id since only one satellite per entry is supported + satellite_id = entry.entry_id + device = dev_reg.async_get_or_create( + config_entry_id=entry.entry_id, + identifiers={(DOMAIN, satellite_id)}, + name=satellite_info.name, + suggested_area=satellite_info.area, ) - entry.async_on_unload(item.satellite.stop) + item.device = SatelliteDevice( + satellite_id=satellite_id, + device_id=device.id, + ) + + # Set up satellite entity, sensors, switches, etc. + await hass.config_entries.async_forward_entry_setups(entry, SATELLITE_PLATFORMS) return True -def _make_satellite( - hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService -) -> WyomingSatellite: - """Create Wyoming satellite/device from config entry and Wyoming service.""" - satellite_info = service.info.satellite - assert satellite_info is not None - - dev_reg = dr.async_get(hass) - - # Use config entry id since only one satellite per entry is supported - satellite_id = config_entry.entry_id - - device = dev_reg.async_get_or_create( - config_entry_id=config_entry.entry_id, - identifiers={(DOMAIN, satellite_id)}, - name=satellite_info.name, - suggested_area=satellite_info.area, - ) - - satellite_device = SatelliteDevice( - satellite_id=satellite_id, - device_id=device.id, - ) - - return WyomingSatellite(hass, config_entry, service, satellite_device) - - async def update_listener(hass: HomeAssistant, entry: ConfigEntry): """Handle options update.""" await hass.config_entries.async_reload(entry.entry_id) @@ -102,7 +80,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: item: DomainDataItem = hass.data[DOMAIN][entry.entry_id] platforms = list(item.service.platforms) - if item.satellite is not None: + if item.device is not None: platforms += SATELLITE_PLATFORMS unload_ok = await hass.config_entries.async_unload_platforms(entry, platforms) diff --git a/homeassistant/components/wyoming/satellite.py b/homeassistant/components/wyoming/assist_satellite.py similarity index 82% rename from homeassistant/components/wyoming/satellite.py rename to homeassistant/components/wyoming/assist_satellite.py index 781f0706c68..83422bd686a 100644 --- a/homeassistant/components/wyoming/satellite.py +++ b/homeassistant/components/wyoming/assist_satellite.py @@ -1,12 +1,12 @@ -"""Support for Wyoming satellite services.""" +"""Assist satellite entity for Wyoming integration.""" + +from __future__ import annotations import asyncio from collections.abc import AsyncGenerator import io import logging -import time -from typing import Final -from uuid import uuid4 +from typing import Any, Final import wave from wyoming.asr import Transcribe, Transcript @@ -18,20 +18,29 @@ from wyoming.info import Describe, Info from wyoming.ping import Ping, Pong from wyoming.pipeline import PipelineStage, RunPipeline from wyoming.satellite import PauseSatellite, RunSatellite +from wyoming.snd import Played from wyoming.timer import TimerCancelled, TimerFinished, TimerStarted, TimerUpdated from wyoming.tts import Synthesize, SynthesizeVoice from wyoming.vad import VoiceStarted, VoiceStopped from wyoming.wake import Detect, Detection -from homeassistant.components import assist_pipeline, intent, stt, tts -from homeassistant.components.assist_pipeline import select as pipeline_select -from homeassistant.components.assist_pipeline.vad import VadSensitivity +from homeassistant.components import assist_pipeline, intent, tts +from homeassistant.components.assist_pipeline import PipelineEvent +from homeassistant.components.assist_satellite import ( + AssistSatelliteConfiguration, + AssistSatelliteEntity, + AssistSatelliteEntityDescription, +) from homeassistant.config_entries import ConfigEntry -from homeassistant.core import Context, HomeAssistant, callback +from homeassistant.const import EntityCategory +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers.entity_platform import AddEntitiesCallback from .const import DOMAIN from .data import WyomingService from .devices import SatelliteDevice +from .entity import WyomingSatelliteEntity +from .models import DomainDataItem _LOGGER = logging.getLogger(__name__) @@ -41,7 +50,6 @@ _RESTART_SECONDS: Final = 3 _PING_TIMEOUT: Final = 5 _PING_SEND_DELAY: Final = 2 _PIPELINE_FINISH_TIMEOUT: Final = 1 -_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes # Wyoming stage -> Assist stage _STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = { @@ -52,21 +60,47 @@ _STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = { } -class WyomingSatellite: - """Remove voice satellite running the Wyoming protocol.""" +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Set up Wyoming Assist satellite entity.""" + domain_data: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id] + assert domain_data.device is not None + + async_add_entities( + [ + WyomingAssistSatellite( + hass, domain_data.service, domain_data.device, config_entry + ) + ] + ) + + +class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity): + """Assist satellite for Wyoming devices.""" + + entity_description = AssistSatelliteEntityDescription(key="assist_satellite") + _attr_translation_key = "assist_satellite" + _attr_entity_category = EntityCategory.CONFIG + _attr_name = None def __init__( self, hass: HomeAssistant, - config_entry: ConfigEntry, service: WyomingService, device: SatelliteDevice, + config_entry: ConfigEntry, ) -> None: - """Initialize satellite.""" - self.hass = hass - self.config_entry = config_entry + """Initialize an Assist satellite.""" + WyomingSatelliteEntity.__init__(self, device) + AssistSatelliteEntity.__init__(self) + self.service = service self.device = device + self.config_entry = config_entry + self.is_running = True self._client: AsyncTcpClient | None = None @@ -84,6 +118,160 @@ class WyomingSatellite: self.device.set_pipeline_listener(self._pipeline_changed) self.device.set_audio_settings_listener(self._audio_settings_changed) + @property + def pipeline_entity_id(self) -> str | None: + """Return the entity ID of the pipeline to use for the next conversation.""" + return self.device.get_pipeline_entity_id(self.hass) + + @property + def vad_sensitivity_entity_id(self) -> str | None: + """Return the entity ID of the VAD sensitivity to use for the next conversation.""" + return self.device.get_vad_sensitivity_entity_id(self.hass) + + @property + def tts_options(self) -> dict[str, Any] | None: + """Options passed for text-to-speech.""" + return { + tts.ATTR_PREFERRED_FORMAT: "wav", + tts.ATTR_PREFERRED_SAMPLE_RATE: 16000, + tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1, + tts.ATTR_PREFERRED_SAMPLE_BYTES: 2, + } + + async def async_added_to_hass(self) -> None: + """Run when entity about to be added to hass.""" + await super().async_added_to_hass() + self.start_satellite() + + async def async_will_remove_from_hass(self) -> None: + """Run when entity will be removed from hass.""" + await super().async_will_remove_from_hass() + self.stop_satellite() + + @callback + def async_get_configuration( + self, + ) -> AssistSatelliteConfiguration: + """Get the current satellite configuration.""" + raise NotImplementedError + + async def async_set_configuration( + self, config: AssistSatelliteConfiguration + ) -> None: + """Set the current satellite configuration.""" + raise NotImplementedError + + def on_pipeline_event(self, event: PipelineEvent) -> None: + """Set state based on pipeline stage.""" + assert self._client is not None + + if event.type == assist_pipeline.PipelineEventType.RUN_END: + # Pipeline run is complete + self._is_pipeline_running = False + self._pipeline_ended_event.set() + self.device.set_is_active(False) + elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START: + self.hass.add_job(self._client.write_event(Detect().event())) + elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END: + # Wake word detection + # Inform client of wake word detection + if event.data and (wake_word_output := event.data.get("wake_word_output")): + detection = Detection( + name=wake_word_output["wake_word_id"], + timestamp=wake_word_output.get("timestamp"), + ) + self.hass.add_job(self._client.write_event(detection.event())) + elif event.type == assist_pipeline.PipelineEventType.STT_START: + # Speech-to-text + self.device.set_is_active(True) + + if event.data: + self.hass.add_job( + self._client.write_event( + Transcribe(language=event.data["metadata"]["language"]).event() + ) + ) + elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START: + # User started speaking + if event.data: + self.hass.add_job( + self._client.write_event( + VoiceStarted(timestamp=event.data["timestamp"]).event() + ) + ) + elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END: + # User stopped speaking + if event.data: + self.hass.add_job( + self._client.write_event( + VoiceStopped(timestamp=event.data["timestamp"]).event() + ) + ) + elif event.type == assist_pipeline.PipelineEventType.STT_END: + # Speech-to-text transcript + if event.data: + # Inform client of transript + stt_text = event.data["stt_output"]["text"] + self.hass.add_job( + self._client.write_event(Transcript(text=stt_text).event()) + ) + elif event.type == assist_pipeline.PipelineEventType.TTS_START: + # Text-to-speech text + if event.data: + # Inform client of text + self.hass.add_job( + self._client.write_event( + Synthesize( + text=event.data["tts_input"], + voice=SynthesizeVoice( + name=event.data.get("voice"), + language=event.data.get("language"), + ), + ).event() + ) + ) + elif event.type == assist_pipeline.PipelineEventType.TTS_END: + # TTS stream + if event.data and (tts_output := event.data["tts_output"]): + media_id = tts_output["media_id"] + self.hass.add_job(self._stream_tts(media_id)) + elif event.type == assist_pipeline.PipelineEventType.ERROR: + # Pipeline error + if event.data: + self.hass.add_job( + self._client.write_event( + Error( + text=event.data["message"], code=event.data["code"] + ).event() + ) + ) + + # ------------------------------------------------------------------------- + + def start_satellite(self) -> None: + """Start satellite task.""" + self.is_running = True + + self.config_entry.async_create_background_task( + self.hass, self.run(), "wyoming satellite run" + ) + + def stop_satellite(self) -> None: + """Signal satellite task to stop running.""" + # Stop existing pipeline + self._audio_queue.put_nowait(None) + + # Tell satellite to stop running + self._send_pause() + + # Stop task loop + self.is_running = False + + # Unblock waiting for unmuted + self._muted_changed_event.set() + + # ------------------------------------------------------------------------- + async def run(self) -> None: """Run and maintain a connection to satellite.""" _LOGGER.debug("Running satellite task") @@ -110,6 +298,9 @@ class WyomingSatellite: except Exception as err: # noqa: BLE001 _LOGGER.debug("%s: %s", err.__class__.__name__, str(err)) + # Stop any existing pipeline + self._audio_queue.put_nowait(None) + # Ensure sensor is off (before restart) self.device.set_is_active(False) @@ -123,17 +314,6 @@ class WyomingSatellite: await self.on_stopped() - def stop(self) -> None: - """Signal satellite task to stop running.""" - # Tell satellite to stop running - self._send_pause() - - # Stop task loop - self.is_running = False - - # Unblock waiting for unmuted - self._muted_changed_event.set() - async def on_restart(self) -> None: """Block until pipeline loop will be restarted.""" _LOGGER.warning( @@ -151,7 +331,7 @@ class WyomingSatellite: await asyncio.sleep(_RECONNECT_SECONDS) async def on_muted(self) -> None: - """Block until device may be unmated again.""" + """Block until device may be unmuted again.""" await self._muted_changed_event.wait() async def on_stopped(self) -> None: @@ -252,6 +432,7 @@ class WyomingSatellite: done, pending = await asyncio.wait( pending, return_when=asyncio.FIRST_COMPLETED ) + if pipeline_ended_task in done: # Pipeline run end event was received _LOGGER.debug("Pipeline finished") @@ -302,7 +483,7 @@ class WyomingSatellite: elif AudioStop.is_type(client_event.type) and self._is_pipeline_running: # Stop pipeline _LOGGER.debug("Client requested pipeline to stop") - self._audio_queue.put_nowait(b"") + self._audio_queue.put_nowait(None) elif Info.is_type(client_event.type): client_info = Info.from_event(client_event) _LOGGER.debug("Updated client info: %s", client_info) @@ -329,6 +510,9 @@ class WyomingSatellite: break _LOGGER.debug("Client detected wake word: %s", wake_word_phrase) + elif Played.is_type(client_event.type): + # TTS response has finished playing on satellite + self.tts_response_finished() else: _LOGGER.debug("Unexpected event from satellite: %s", client_event) @@ -353,72 +537,20 @@ class WyomingSatellite: if end_stage is None: raise ValueError(f"Invalid end stage: {end_stage}") - pipeline_id = pipeline_select.get_chosen_pipeline( - self.hass, - DOMAIN, - self.device.satellite_id, - ) - pipeline = assist_pipeline.async_get_pipeline(self.hass, pipeline_id) - assert pipeline is not None - # We will push audio in through a queue self._audio_queue = asyncio.Queue() - stt_stream = self._stt_stream() - - # Start pipeline running - _LOGGER.debug( - "Starting pipeline %s from %s to %s", - pipeline.name, - start_stage, - end_stage, - ) - - # Reset conversation id, if necessary - if (self._conversation_id_time is None) or ( - (time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC - ): - self._conversation_id = None - - if self._conversation_id is None: - self._conversation_id = str(uuid4()) - - # Update timeout - self._conversation_id_time = time.monotonic() self._is_pipeline_running = True self._pipeline_ended_event.clear() self.config_entry.async_create_background_task( self.hass, - assist_pipeline.async_pipeline_from_audio_stream( - self.hass, - context=Context(), - event_callback=self._event_callback, - stt_metadata=stt.SpeechMetadata( - language=pipeline.language, - format=stt.AudioFormats.WAV, - codec=stt.AudioCodecs.PCM, - bit_rate=stt.AudioBitRates.BITRATE_16, - sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, - channel=stt.AudioChannels.CHANNEL_MONO, - ), - stt_stream=stt_stream, + self.async_accept_pipeline_from_satellite( + audio_stream=self._stt_stream(), start_stage=start_stage, end_stage=end_stage, - tts_audio_output="wav", - pipeline_id=pipeline_id, - audio_settings=assist_pipeline.AudioSettings( - noise_suppression_level=self.device.noise_suppression_level, - auto_gain_dbfs=self.device.auto_gain, - volume_multiplier=self.device.volume_multiplier, - silence_seconds=VadSensitivity.to_seconds( - self.device.vad_sensitivity - ), - ), - device_id=self.device.device_id, wake_word_phrase=wake_word_phrase, - conversation_id=self._conversation_id, ), - name="wyoming satellite pipeline", + "wyoming satellite pipeline", ) async def _send_delayed_ping(self) -> None: @@ -431,91 +563,6 @@ class WyomingSatellite: except ConnectionError: pass # handled with timeout - def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None: - """Translate pipeline events into Wyoming events.""" - assert self._client is not None - - if event.type == assist_pipeline.PipelineEventType.RUN_END: - # Pipeline run is complete - self._is_pipeline_running = False - self._pipeline_ended_event.set() - self.device.set_is_active(False) - elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START: - self.hass.add_job(self._client.write_event(Detect().event())) - elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END: - # Wake word detection - # Inform client of wake word detection - if event.data and (wake_word_output := event.data.get("wake_word_output")): - detection = Detection( - name=wake_word_output["wake_word_id"], - timestamp=wake_word_output.get("timestamp"), - ) - self.hass.add_job(self._client.write_event(detection.event())) - elif event.type == assist_pipeline.PipelineEventType.STT_START: - # Speech-to-text - self.device.set_is_active(True) - - if event.data: - self.hass.add_job( - self._client.write_event( - Transcribe(language=event.data["metadata"]["language"]).event() - ) - ) - elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START: - # User started speaking - if event.data: - self.hass.add_job( - self._client.write_event( - VoiceStarted(timestamp=event.data["timestamp"]).event() - ) - ) - elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END: - # User stopped speaking - if event.data: - self.hass.add_job( - self._client.write_event( - VoiceStopped(timestamp=event.data["timestamp"]).event() - ) - ) - elif event.type == assist_pipeline.PipelineEventType.STT_END: - # Speech-to-text transcript - if event.data: - # Inform client of transript - stt_text = event.data["stt_output"]["text"] - self.hass.add_job( - self._client.write_event(Transcript(text=stt_text).event()) - ) - elif event.type == assist_pipeline.PipelineEventType.TTS_START: - # Text-to-speech text - if event.data: - # Inform client of text - self.hass.add_job( - self._client.write_event( - Synthesize( - text=event.data["tts_input"], - voice=SynthesizeVoice( - name=event.data.get("voice"), - language=event.data.get("language"), - ), - ).event() - ) - ) - elif event.type == assist_pipeline.PipelineEventType.TTS_END: - # TTS stream - if event.data and (tts_output := event.data["tts_output"]): - media_id = tts_output["media_id"] - self.hass.add_job(self._stream_tts(media_id)) - elif event.type == assist_pipeline.PipelineEventType.ERROR: - # Pipeline error - if event.data: - self.hass.add_job( - self._client.write_event( - Error( - text=event.data["message"], code=event.data["code"] - ).event() - ) - ) - async def _connect(self) -> None: """Connect to satellite over TCP.""" await self._disconnect() @@ -576,16 +623,16 @@ class WyomingSatellite: async def _stt_stream(self) -> AsyncGenerator[bytes]: """Yield audio chunks from a queue.""" - try: - is_first_chunk = True - while chunk := await self._audio_queue.get(): - if is_first_chunk: - is_first_chunk = False - _LOGGER.debug("Receiving audio from satellite") + is_first_chunk = True + while chunk := await self._audio_queue.get(): + if chunk is None: + break - yield chunk - except asyncio.CancelledError: - pass # ignore + if is_first_chunk: + is_first_chunk = False + _LOGGER.debug("Receiving audio from satellite") + + yield chunk @callback def _handle_timer( diff --git a/homeassistant/components/wyoming/binary_sensor.py b/homeassistant/components/wyoming/binary_sensor.py index ac5db0cda99..24ee073ec4d 100644 --- a/homeassistant/components/wyoming/binary_sensor.py +++ b/homeassistant/components/wyoming/binary_sensor.py @@ -28,9 +28,9 @@ async def async_setup_entry( item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id] # Setup is only forwarded for satellites - assert item.satellite is not None + assert item.device is not None - async_add_entities([WyomingSatelliteAssistInProgress(item.satellite.device)]) + async_add_entities([WyomingSatelliteAssistInProgress(item.device)]) class WyomingSatelliteAssistInProgress(WyomingSatelliteEntity, BinarySensorEntity): diff --git a/homeassistant/components/wyoming/entity.py b/homeassistant/components/wyoming/entity.py index 4591283036f..1ce105fb860 100644 --- a/homeassistant/components/wyoming/entity.py +++ b/homeassistant/components/wyoming/entity.py @@ -6,7 +6,7 @@ from homeassistant.helpers import entity from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo from .const import DOMAIN -from .satellite import SatelliteDevice +from .devices import SatelliteDevice class WyomingSatelliteEntity(entity.Entity): diff --git a/homeassistant/components/wyoming/manifest.json b/homeassistant/components/wyoming/manifest.json index 30104a88dce..b837d2a9e76 100644 --- a/homeassistant/components/wyoming/manifest.json +++ b/homeassistant/components/wyoming/manifest.json @@ -3,7 +3,12 @@ "name": "Wyoming Protocol", "codeowners": ["@balloob", "@synesthesiam"], "config_flow": true, - "dependencies": ["assist_pipeline", "intent", "conversation"], + "dependencies": [ + "assist_satellite", + "assist_pipeline", + "intent", + "conversation" + ], "documentation": "https://www.home-assistant.io/integrations/wyoming", "integration_type": "service", "iot_class": "local_push", diff --git a/homeassistant/components/wyoming/models.py b/homeassistant/components/wyoming/models.py index 066af144d78..b819d06f916 100644 --- a/homeassistant/components/wyoming/models.py +++ b/homeassistant/components/wyoming/models.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from .data import WyomingService -from .satellite import WyomingSatellite +from .devices import SatelliteDevice @dataclass @@ -11,4 +11,4 @@ class DomainDataItem: """Domain data item.""" service: WyomingService - satellite: WyomingSatellite | None = None + device: SatelliteDevice | None = None diff --git a/homeassistant/components/wyoming/number.py b/homeassistant/components/wyoming/number.py index 5e769eeb06d..d9a58cc3333 100644 --- a/homeassistant/components/wyoming/number.py +++ b/homeassistant/components/wyoming/number.py @@ -30,13 +30,12 @@ async def async_setup_entry( item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id] # Setup is only forwarded for satellites - assert item.satellite is not None + assert item.device is not None - device = item.satellite.device async_add_entities( [ - WyomingSatelliteAutoGainNumber(device), - WyomingSatelliteVolumeMultiplierNumber(device), + WyomingSatelliteAutoGainNumber(item.device), + WyomingSatelliteVolumeMultiplierNumber(item.device), ] ) diff --git a/homeassistant/components/wyoming/select.py b/homeassistant/components/wyoming/select.py index f852b4d0434..bbcaab81710 100644 --- a/homeassistant/components/wyoming/select.py +++ b/homeassistant/components/wyoming/select.py @@ -42,14 +42,13 @@ async def async_setup_entry( item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id] # Setup is only forwarded for satellites - assert item.satellite is not None + assert item.device is not None - device = item.satellite.device async_add_entities( [ - WyomingSatellitePipelineSelect(hass, device), - WyomingSatelliteNoiseSuppressionLevelSelect(device), - WyomingSatelliteVadSensitivitySelect(hass, device), + WyomingSatellitePipelineSelect(hass, item.device), + WyomingSatelliteNoiseSuppressionLevelSelect(item.device), + WyomingSatelliteVadSensitivitySelect(hass, item.device), ] ) diff --git a/homeassistant/components/wyoming/switch.py b/homeassistant/components/wyoming/switch.py index c012c60bc5a..308429331c3 100644 --- a/homeassistant/components/wyoming/switch.py +++ b/homeassistant/components/wyoming/switch.py @@ -27,9 +27,9 @@ async def async_setup_entry( item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id] # Setup is only forwarded for satellites - assert item.satellite is not None + assert item.device is not None - async_add_entities([WyomingSatelliteMuteSwitch(item.satellite.device)]) + async_add_entities([WyomingSatelliteMuteSwitch(item.device)]) class WyomingSatelliteMuteSwitch( @@ -51,7 +51,7 @@ class WyomingSatelliteMuteSwitch( # Default to off self._attr_is_on = (state is not None) and (state.state == STATE_ON) - self._device.is_muted = self._attr_is_on + self._device.set_is_muted(self._attr_is_on) async def async_turn_on(self, **kwargs: Any) -> None: """Turn on.""" diff --git a/tests/components/wyoming/__init__.py b/tests/components/wyoming/__init__.py index 5bfbbfe87b2..30703159994 100644 --- a/tests/components/wyoming/__init__.py +++ b/tests/components/wyoming/__init__.py @@ -150,10 +150,10 @@ async def reload_satellite( return_value=SATELLITE_INFO, ), patch( - "homeassistant.components.wyoming.satellite.WyomingSatellite.run" + "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.run" ) as _run_mock, ): # _run_mock: satellite task does not actually run await hass.config_entries.async_reload(config_entry_id) - return hass.data[DOMAIN][config_entry_id].satellite.device + return hass.data[DOMAIN][config_entry_id].device diff --git a/tests/components/wyoming/conftest.py b/tests/components/wyoming/conftest.py index 770186d92aa..d504f98a5b0 100644 --- a/tests/components/wyoming/conftest.py +++ b/tests/components/wyoming/conftest.py @@ -152,7 +152,7 @@ async def init_satellite(hass: HomeAssistant, satellite_config_entry: ConfigEntr return_value=SATELLITE_INFO, ), patch( - "homeassistant.components.wyoming.satellite.WyomingSatellite.run" + "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.run" ) as _run_mock, ): # _run_mock: satellite task does not actually run @@ -164,4 +164,4 @@ async def satellite_device( hass: HomeAssistant, init_satellite, satellite_config_entry: ConfigEntry ) -> SatelliteDevice: """Get a satellite device fixture.""" - return hass.data[DOMAIN][satellite_config_entry.entry_id].satellite.device + return hass.data[DOMAIN][satellite_config_entry.entry_id].device diff --git a/tests/components/wyoming/test_satellite.py b/tests/components/wyoming/test_satellite.py index 1a291153ad0..f293f976242 100644 --- a/tests/components/wyoming/test_satellite.py +++ b/tests/components/wyoming/test_satellite.py @@ -23,6 +23,7 @@ from wyoming.vad import VoiceStarted, VoiceStopped from wyoming.wake import Detect, Detection from homeassistant.components import assist_pipeline, wyoming +from homeassistant.components.wyoming.assist_satellite import WyomingAssistSatellite from homeassistant.components.wyoming.devices import SatelliteDevice from homeassistant.const import STATE_ON from homeassistant.core import HomeAssistant, State @@ -240,23 +241,22 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None: return_value=SATELLITE_INFO, ), patch( - "homeassistant.components.wyoming.satellite.AsyncTcpClient", + "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient", SatelliteAsyncTcpClient(events), ) as mock_client, patch( - "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", async_pipeline_from_audio_stream, ), patch( - "homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio", + "homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio", return_value=("wav", get_test_wav()), ), - patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0), + patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0), ): entry = await setup_config_entry(hass) - device: SatelliteDevice = hass.data[wyoming.DOMAIN][ - entry.entry_id - ].satellite.device + device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device + assert device is not None async with asyncio.timeout(1): await mock_client.connect_event.wait() @@ -443,7 +443,7 @@ async def test_satellite_muted(hass: HomeAssistant) -> None: """Test callback for a satellite that has been muted.""" on_muted_event = asyncio.Event() - original_on_muted = wyoming.satellite.WyomingSatellite.on_muted + original_on_muted = WyomingAssistSatellite.on_muted async def on_muted(self): # Trigger original function @@ -462,12 +462,16 @@ async def test_satellite_muted(hass: HomeAssistant) -> None: "homeassistant.components.wyoming.data.load_wyoming_info", return_value=SATELLITE_INFO, ), + patch( + "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient", + SatelliteAsyncTcpClient([]), + ), patch( "homeassistant.components.wyoming.switch.WyomingSatelliteMuteSwitch.async_get_last_state", return_value=State("switch.test_mute", STATE_ON), ), patch( - "homeassistant.components.wyoming.satellite.WyomingSatellite.on_muted", + "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_muted", on_muted, ), ): @@ -484,11 +488,11 @@ async def test_satellite_restart(hass: HomeAssistant) -> None: """Test pipeline loop restart after unexpected error.""" on_restart_event = asyncio.Event() - original_on_restart = wyoming.satellite.WyomingSatellite.on_restart + original_on_restart = WyomingAssistSatellite.on_restart async def on_restart(self): await original_on_restart(self) - self.stop() + self.stop_satellite() on_restart_event.set() with ( @@ -497,14 +501,14 @@ async def test_satellite_restart(hass: HomeAssistant) -> None: return_value=SATELLITE_INFO, ), patch( - "homeassistant.components.wyoming.satellite.WyomingSatellite._connect_and_loop", + "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._connect_and_loop", side_effect=RuntimeError(), ), patch( - "homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart", + "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_restart", on_restart, ), - patch("homeassistant.components.wyoming.satellite._RESTART_SECONDS", 0), + patch("homeassistant.components.wyoming.assist_satellite._RESTART_SECONDS", 0), ): await setup_config_entry(hass) async with asyncio.timeout(1): @@ -517,7 +521,7 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None: reconnect_event = asyncio.Event() stopped_event = asyncio.Event() - original_on_reconnect = wyoming.satellite.WyomingSatellite.on_reconnect + original_on_reconnect = WyomingAssistSatellite.on_reconnect async def on_reconnect(self): await original_on_reconnect(self) @@ -526,7 +530,7 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None: num_reconnects += 1 if num_reconnects >= 2: reconnect_event.set() - self.stop() + self.stop_satellite() async def on_stopped(self): stopped_event.set() @@ -537,18 +541,20 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None: return_value=SATELLITE_INFO, ), patch( - "homeassistant.components.wyoming.satellite.AsyncTcpClient.connect", + "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient.connect", side_effect=ConnectionRefusedError(), ), patch( - "homeassistant.components.wyoming.satellite.WyomingSatellite.on_reconnect", + "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_reconnect", on_reconnect, ), patch( - "homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped", + "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_stopped", on_stopped, ), - patch("homeassistant.components.wyoming.satellite._RECONNECT_SECONDS", 0), + patch( + "homeassistant.components.wyoming.assist_satellite._RECONNECT_SECONDS", 0 + ), ): await setup_config_entry(hass) async with asyncio.timeout(1): @@ -561,7 +567,7 @@ async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None on_restart_event = asyncio.Event() async def on_restart(self): - self.stop() + self.stop_satellite() on_restart_event.set() with ( @@ -570,14 +576,14 @@ async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None return_value=SATELLITE_INFO, ), patch( - "homeassistant.components.wyoming.satellite.AsyncTcpClient", + "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient", MockAsyncTcpClient([]), # no RunPipeline event ), patch( - "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", ) as mock_run_pipeline, patch( - "homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart", + "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_restart", on_restart, ), ): @@ -603,7 +609,7 @@ async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None async def on_restart(self): # Pretend sensor got stuck on self.device.is_active = True - self.stop() + self.stop_satellite() on_restart_event.set() async def on_stopped(self): @@ -615,25 +621,23 @@ async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None return_value=SATELLITE_INFO, ), patch( - "homeassistant.components.wyoming.satellite.AsyncTcpClient", + "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient", MockAsyncTcpClient(events), ), patch( - "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", ) as mock_run_pipeline, patch( - "homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart", + "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_restart", on_restart, ), patch( - "homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped", + "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_stopped", on_stopped, ), ): entry = await setup_config_entry(hass) - device: SatelliteDevice = hass.data[wyoming.DOMAIN][ - entry.entry_id - ].satellite.device + device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device async with asyncio.timeout(1): await on_restart_event.wait() @@ -665,11 +669,11 @@ async def test_satellite_error_during_pipeline(hass: HomeAssistant) -> None: return_value=SATELLITE_INFO, ), patch( - "homeassistant.components.wyoming.satellite.AsyncTcpClient", + "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient", SatelliteAsyncTcpClient(events), ) as mock_client, patch( - "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", wraps=_async_pipeline_from_audio_stream, ) as mock_run_pipeline, ): @@ -701,7 +705,7 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None: """Test satellite receiving non-WAV audio from text-to-speech.""" assert await async_setup_component(hass, assist_pipeline.DOMAIN, {}) - original_stream_tts = wyoming.satellite.WyomingSatellite._stream_tts + original_stream_tts = WyomingAssistSatellite._stream_tts error_event = asyncio.Event() async def _stream_tts(self, media_id): @@ -724,19 +728,19 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None: return_value=SATELLITE_INFO, ), patch( - "homeassistant.components.wyoming.satellite.AsyncTcpClient", + "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient", SatelliteAsyncTcpClient(events), ) as mock_client, patch( - "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", wraps=_async_pipeline_from_audio_stream, ) as mock_run_pipeline, patch( - "homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio", + "homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio", return_value=("mp3", bytes(1)), ), patch( - "homeassistant.components.wyoming.satellite.WyomingSatellite._stream_tts", + "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._stream_tts", _stream_tts, ), ): @@ -819,18 +823,16 @@ async def test_pipeline_changed(hass: HomeAssistant) -> None: return_value=SATELLITE_INFO, ), patch( - "homeassistant.components.wyoming.satellite.AsyncTcpClient", + "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient", SatelliteAsyncTcpClient(events), ) as mock_client, patch( - "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", async_pipeline_from_audio_stream, ), ): entry = await setup_config_entry(hass) - device: SatelliteDevice = hass.data[wyoming.DOMAIN][ - entry.entry_id - ].satellite.device + device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device async with asyncio.timeout(1): await mock_client.connect_event.wait() @@ -893,18 +895,16 @@ async def test_audio_settings_changed(hass: HomeAssistant) -> None: return_value=SATELLITE_INFO, ), patch( - "homeassistant.components.wyoming.satellite.AsyncTcpClient", + "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient", SatelliteAsyncTcpClient(events), ) as mock_client, patch( - "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", async_pipeline_from_audio_stream, ), ): entry = await setup_config_entry(hass) - device: SatelliteDevice = hass.data[wyoming.DOMAIN][ - entry.entry_id - ].satellite.device + device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device async with asyncio.timeout(1): await mock_client.connect_event.wait() @@ -938,7 +938,7 @@ async def test_invalid_stages(hass: HomeAssistant) -> None: ).event(), ] - original_run_pipeline_once = wyoming.satellite.WyomingSatellite._run_pipeline_once + original_run_pipeline_once = WyomingAssistSatellite._run_pipeline_once start_stage_event = asyncio.Event() end_stage_event = asyncio.Event() @@ -967,11 +967,11 @@ async def test_invalid_stages(hass: HomeAssistant) -> None: return_value=SATELLITE_INFO, ), patch( - "homeassistant.components.wyoming.satellite.AsyncTcpClient", + "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient", SatelliteAsyncTcpClient(events), ) as mock_client, patch( - "homeassistant.components.wyoming.satellite.WyomingSatellite._run_pipeline_once", + "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._run_pipeline_once", _run_pipeline_once, ), ): @@ -1029,11 +1029,11 @@ async def test_client_stops_pipeline(hass: HomeAssistant) -> None: return_value=SATELLITE_INFO, ), patch( - "homeassistant.components.wyoming.satellite.AsyncTcpClient", + "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient", SatelliteAsyncTcpClient(events), ) as mock_client, patch( - "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", async_pipeline_from_audio_stream, ), ): @@ -1083,11 +1083,11 @@ async def test_wake_word_phrase(hass: HomeAssistant) -> None: return_value=SATELLITE_INFO, ), patch( - "homeassistant.components.wyoming.satellite.AsyncTcpClient", + "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient", SatelliteAsyncTcpClient(events), ), patch( - "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", wraps=_async_pipeline_from_audio_stream, ) as mock_run_pipeline, ): @@ -1114,14 +1114,12 @@ async def test_timers(hass: HomeAssistant) -> None: return_value=SATELLITE_INFO, ), patch( - "homeassistant.components.wyoming.satellite.AsyncTcpClient", + "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient", SatelliteAsyncTcpClient([]), ) as mock_client, ): entry = await setup_config_entry(hass) - device: SatelliteDevice = hass.data[wyoming.DOMAIN][ - entry.entry_id - ].satellite.device + device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device async with asyncio.timeout(1): await mock_client.connect_event.wait() @@ -1285,104 +1283,3 @@ async def test_timers(hass: HomeAssistant) -> None: timer_finished = mock_client.timer_finished assert timer_finished is not None assert timer_finished.id == timer_started.id - - -async def test_satellite_conversation_id(hass: HomeAssistant) -> None: - """Test that the same conversation id is used until timeout.""" - assert await async_setup_component(hass, assist_pipeline.DOMAIN, {}) - - events = [ - RunPipeline( - start_stage=PipelineStage.WAKE, - end_stage=PipelineStage.TTS, - restart_on_end=True, - ).event(), - ] - - pipeline_kwargs: dict[str, Any] = {} - pipeline_event_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = ( - None - ) - run_pipeline_called = asyncio.Event() - - async def async_pipeline_from_audio_stream( - hass: HomeAssistant, - context, - event_callback, - stt_metadata, - stt_stream, - **kwargs, - ) -> None: - nonlocal pipeline_kwargs, pipeline_event_callback - pipeline_kwargs = kwargs - pipeline_event_callback = event_callback - - run_pipeline_called.set() - - with ( - patch( - "homeassistant.components.wyoming.data.load_wyoming_info", - return_value=SATELLITE_INFO, - ), - patch( - "homeassistant.components.wyoming.satellite.AsyncTcpClient", - SatelliteAsyncTcpClient(events), - ) as mock_client, - patch( - "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", - async_pipeline_from_audio_stream, - ), - patch( - "homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio", - return_value=("wav", get_test_wav()), - ), - patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0), - ): - entry = await setup_config_entry(hass) - satellite: wyoming.WyomingSatellite = hass.data[wyoming.DOMAIN][ - entry.entry_id - ].satellite - - async with asyncio.timeout(1): - await mock_client.connect_event.wait() - await mock_client.run_satellite_event.wait() - - async with asyncio.timeout(1): - await run_pipeline_called.wait() - - assert pipeline_event_callback is not None - - # A conversation id should have been generated - conversation_id = pipeline_kwargs.get("conversation_id") - assert conversation_id - - # Reset and run again - run_pipeline_called.clear() - pipeline_kwargs.clear() - - pipeline_event_callback( - assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END) - ) - - async with asyncio.timeout(1): - await run_pipeline_called.wait() - - # Should be the same conversation id - assert pipeline_kwargs.get("conversation_id") == conversation_id - - # Reset and run again, but this time "time out" - satellite._conversation_id_time = None - run_pipeline_called.clear() - pipeline_kwargs.clear() - - pipeline_event_callback( - assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END) - ) - - async with asyncio.timeout(1): - await run_pipeline_called.wait() - - # Should be a different conversation id - new_conversation_id = pipeline_kwargs.get("conversation_id") - assert new_conversation_id - assert new_conversation_id != conversation_id