Migrate Wyoming satellite to Assist satellite entity (#128488)

* Migrate Wyoming satellite to Assist satellite entity

* Fix tests

* Update homeassistant/components/wyoming/assist_satellite.py

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>

* Update homeassistant/components/wyoming/assist_satellite.py

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Michael Hansen 2024-10-16 08:59:44 -05:00 committed by GitHub
parent c294130080
commit bcac851677
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 325 additions and 400 deletions

View file

@ -14,11 +14,11 @@ from .const import ATTR_SPEAKER, DOMAIN
from .data import WyomingService
from .devices import SatelliteDevice
from .models import DomainDataItem
from .satellite import WyomingSatellite
_LOGGER = logging.getLogger(__name__)
SATELLITE_PLATFORMS = [
Platform.ASSIST_SATELLITE,
Platform.BINARY_SENSOR,
Platform.SELECT,
Platform.SWITCH,
@ -47,51 +47,29 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
entry.async_on_unload(entry.add_update_listener(update_listener))
if (satellite_info := service.info.satellite) is not None:
# Create satellite device, etc.
item.satellite = _make_satellite(hass, entry, service)
# Create satellite device
dev_reg = dr.async_get(hass)
# Set up satellite sensors, switches, etc.
await hass.config_entries.async_forward_entry_setups(entry, SATELLITE_PLATFORMS)
# Start satellite communication
entry.async_create_background_task(
hass,
item.satellite.run(),
f"Satellite {satellite_info.name}",
# Use config entry id since only one satellite per entry is supported
satellite_id = entry.entry_id
device = dev_reg.async_get_or_create(
config_entry_id=entry.entry_id,
identifiers={(DOMAIN, satellite_id)},
name=satellite_info.name,
suggested_area=satellite_info.area,
)
entry.async_on_unload(item.satellite.stop)
item.device = SatelliteDevice(
satellite_id=satellite_id,
device_id=device.id,
)
# Set up satellite entity, sensors, switches, etc.
await hass.config_entries.async_forward_entry_setups(entry, SATELLITE_PLATFORMS)
return True
def _make_satellite(
hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService
) -> WyomingSatellite:
"""Create Wyoming satellite/device from config entry and Wyoming service."""
satellite_info = service.info.satellite
assert satellite_info is not None
dev_reg = dr.async_get(hass)
# Use config entry id since only one satellite per entry is supported
satellite_id = config_entry.entry_id
device = dev_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={(DOMAIN, satellite_id)},
name=satellite_info.name,
suggested_area=satellite_info.area,
)
satellite_device = SatelliteDevice(
satellite_id=satellite_id,
device_id=device.id,
)
return WyomingSatellite(hass, config_entry, service, satellite_device)
async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
"""Handle options update."""
await hass.config_entries.async_reload(entry.entry_id)
@ -102,7 +80,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
item: DomainDataItem = hass.data[DOMAIN][entry.entry_id]
platforms = list(item.service.platforms)
if item.satellite is not None:
if item.device is not None:
platforms += SATELLITE_PLATFORMS
unload_ok = await hass.config_entries.async_unload_platforms(entry, platforms)

View file

@ -1,12 +1,12 @@
"""Support for Wyoming satellite services."""
"""Assist satellite entity for Wyoming integration."""
from __future__ import annotations
import asyncio
from collections.abc import AsyncGenerator
import io
import logging
import time
from typing import Final
from uuid import uuid4
from typing import Any, Final
import wave
from wyoming.asr import Transcribe, Transcript
@ -18,20 +18,29 @@ from wyoming.info import Describe, Info
from wyoming.ping import Ping, Pong
from wyoming.pipeline import PipelineStage, RunPipeline
from wyoming.satellite import PauseSatellite, RunSatellite
from wyoming.snd import Played
from wyoming.timer import TimerCancelled, TimerFinished, TimerStarted, TimerUpdated
from wyoming.tts import Synthesize, SynthesizeVoice
from wyoming.vad import VoiceStarted, VoiceStopped
from wyoming.wake import Detect, Detection
from homeassistant.components import assist_pipeline, intent, stt, tts
from homeassistant.components.assist_pipeline import select as pipeline_select
from homeassistant.components.assist_pipeline.vad import VadSensitivity
from homeassistant.components import assist_pipeline, intent, tts
from homeassistant.components.assist_pipeline import PipelineEvent
from homeassistant.components.assist_satellite import (
AssistSatelliteConfiguration,
AssistSatelliteEntity,
AssistSatelliteEntityDescription,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.const import EntityCategory
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN
from .data import WyomingService
from .devices import SatelliteDevice
from .entity import WyomingSatelliteEntity
from .models import DomainDataItem
_LOGGER = logging.getLogger(__name__)
@ -41,7 +50,6 @@ _RESTART_SECONDS: Final = 3
_PING_TIMEOUT: Final = 5
_PING_SEND_DELAY: Final = 2
_PIPELINE_FINISH_TIMEOUT: Final = 1
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
# Wyoming stage -> Assist stage
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
@ -52,21 +60,47 @@ _STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
}
class WyomingSatellite:
"""Remove voice satellite running the Wyoming protocol."""
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up Wyoming Assist satellite entity."""
domain_data: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
assert domain_data.device is not None
async_add_entities(
[
WyomingAssistSatellite(
hass, domain_data.service, domain_data.device, config_entry
)
]
)
class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
"""Assist satellite for Wyoming devices."""
entity_description = AssistSatelliteEntityDescription(key="assist_satellite")
_attr_translation_key = "assist_satellite"
_attr_entity_category = EntityCategory.CONFIG
_attr_name = None
def __init__(
self,
hass: HomeAssistant,
config_entry: ConfigEntry,
service: WyomingService,
device: SatelliteDevice,
config_entry: ConfigEntry,
) -> None:
"""Initialize satellite."""
self.hass = hass
self.config_entry = config_entry
"""Initialize an Assist satellite."""
WyomingSatelliteEntity.__init__(self, device)
AssistSatelliteEntity.__init__(self)
self.service = service
self.device = device
self.config_entry = config_entry
self.is_running = True
self._client: AsyncTcpClient | None = None
@ -84,6 +118,160 @@ class WyomingSatellite:
self.device.set_pipeline_listener(self._pipeline_changed)
self.device.set_audio_settings_listener(self._audio_settings_changed)
@property
def pipeline_entity_id(self) -> str | None:
"""Return the entity ID of the pipeline to use for the next conversation."""
return self.device.get_pipeline_entity_id(self.hass)
@property
def vad_sensitivity_entity_id(self) -> str | None:
"""Return the entity ID of the VAD sensitivity to use for the next conversation."""
return self.device.get_vad_sensitivity_entity_id(self.hass)
@property
def tts_options(self) -> dict[str, Any] | None:
"""Options passed for text-to-speech."""
return {
tts.ATTR_PREFERRED_FORMAT: "wav",
tts.ATTR_PREFERRED_SAMPLE_RATE: 16000,
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
}
async def async_added_to_hass(self) -> None:
"""Run when entity about to be added to hass."""
await super().async_added_to_hass()
self.start_satellite()
async def async_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass."""
await super().async_will_remove_from_hass()
self.stop_satellite()
@callback
def async_get_configuration(
self,
) -> AssistSatelliteConfiguration:
"""Get the current satellite configuration."""
raise NotImplementedError
async def async_set_configuration(
self, config: AssistSatelliteConfiguration
) -> None:
"""Set the current satellite configuration."""
raise NotImplementedError
def on_pipeline_event(self, event: PipelineEvent) -> None:
"""Set state based on pipeline stage."""
assert self._client is not None
if event.type == assist_pipeline.PipelineEventType.RUN_END:
# Pipeline run is complete
self._is_pipeline_running = False
self._pipeline_ended_event.set()
self.device.set_is_active(False)
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
self.hass.add_job(self._client.write_event(Detect().event()))
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END:
# Wake word detection
# Inform client of wake word detection
if event.data and (wake_word_output := event.data.get("wake_word_output")):
detection = Detection(
name=wake_word_output["wake_word_id"],
timestamp=wake_word_output.get("timestamp"),
)
self.hass.add_job(self._client.write_event(detection.event()))
elif event.type == assist_pipeline.PipelineEventType.STT_START:
# Speech-to-text
self.device.set_is_active(True)
if event.data:
self.hass.add_job(
self._client.write_event(
Transcribe(language=event.data["metadata"]["language"]).event()
)
)
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START:
# User started speaking
if event.data:
self.hass.add_job(
self._client.write_event(
VoiceStarted(timestamp=event.data["timestamp"]).event()
)
)
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END:
# User stopped speaking
if event.data:
self.hass.add_job(
self._client.write_event(
VoiceStopped(timestamp=event.data["timestamp"]).event()
)
)
elif event.type == assist_pipeline.PipelineEventType.STT_END:
# Speech-to-text transcript
if event.data:
# Inform client of transript
stt_text = event.data["stt_output"]["text"]
self.hass.add_job(
self._client.write_event(Transcript(text=stt_text).event())
)
elif event.type == assist_pipeline.PipelineEventType.TTS_START:
# Text-to-speech text
if event.data:
# Inform client of text
self.hass.add_job(
self._client.write_event(
Synthesize(
text=event.data["tts_input"],
voice=SynthesizeVoice(
name=event.data.get("voice"),
language=event.data.get("language"),
),
).event()
)
)
elif event.type == assist_pipeline.PipelineEventType.TTS_END:
# TTS stream
if event.data and (tts_output := event.data["tts_output"]):
media_id = tts_output["media_id"]
self.hass.add_job(self._stream_tts(media_id))
elif event.type == assist_pipeline.PipelineEventType.ERROR:
# Pipeline error
if event.data:
self.hass.add_job(
self._client.write_event(
Error(
text=event.data["message"], code=event.data["code"]
).event()
)
)
# -------------------------------------------------------------------------
def start_satellite(self) -> None:
"""Start satellite task."""
self.is_running = True
self.config_entry.async_create_background_task(
self.hass, self.run(), "wyoming satellite run"
)
def stop_satellite(self) -> None:
"""Signal satellite task to stop running."""
# Stop existing pipeline
self._audio_queue.put_nowait(None)
# Tell satellite to stop running
self._send_pause()
# Stop task loop
self.is_running = False
# Unblock waiting for unmuted
self._muted_changed_event.set()
# -------------------------------------------------------------------------
async def run(self) -> None:
"""Run and maintain a connection to satellite."""
_LOGGER.debug("Running satellite task")
@ -110,6 +298,9 @@ class WyomingSatellite:
except Exception as err: # noqa: BLE001
_LOGGER.debug("%s: %s", err.__class__.__name__, str(err))
# Stop any existing pipeline
self._audio_queue.put_nowait(None)
# Ensure sensor is off (before restart)
self.device.set_is_active(False)
@ -123,17 +314,6 @@ class WyomingSatellite:
await self.on_stopped()
def stop(self) -> None:
"""Signal satellite task to stop running."""
# Tell satellite to stop running
self._send_pause()
# Stop task loop
self.is_running = False
# Unblock waiting for unmuted
self._muted_changed_event.set()
async def on_restart(self) -> None:
"""Block until pipeline loop will be restarted."""
_LOGGER.warning(
@ -151,7 +331,7 @@ class WyomingSatellite:
await asyncio.sleep(_RECONNECT_SECONDS)
async def on_muted(self) -> None:
"""Block until device may be unmated again."""
"""Block until device may be unmuted again."""
await self._muted_changed_event.wait()
async def on_stopped(self) -> None:
@ -252,6 +432,7 @@ class WyomingSatellite:
done, pending = await asyncio.wait(
pending, return_when=asyncio.FIRST_COMPLETED
)
if pipeline_ended_task in done:
# Pipeline run end event was received
_LOGGER.debug("Pipeline finished")
@ -302,7 +483,7 @@ class WyomingSatellite:
elif AudioStop.is_type(client_event.type) and self._is_pipeline_running:
# Stop pipeline
_LOGGER.debug("Client requested pipeline to stop")
self._audio_queue.put_nowait(b"")
self._audio_queue.put_nowait(None)
elif Info.is_type(client_event.type):
client_info = Info.from_event(client_event)
_LOGGER.debug("Updated client info: %s", client_info)
@ -329,6 +510,9 @@ class WyomingSatellite:
break
_LOGGER.debug("Client detected wake word: %s", wake_word_phrase)
elif Played.is_type(client_event.type):
# TTS response has finished playing on satellite
self.tts_response_finished()
else:
_LOGGER.debug("Unexpected event from satellite: %s", client_event)
@ -353,72 +537,20 @@ class WyomingSatellite:
if end_stage is None:
raise ValueError(f"Invalid end stage: {end_stage}")
pipeline_id = pipeline_select.get_chosen_pipeline(
self.hass,
DOMAIN,
self.device.satellite_id,
)
pipeline = assist_pipeline.async_get_pipeline(self.hass, pipeline_id)
assert pipeline is not None
# We will push audio in through a queue
self._audio_queue = asyncio.Queue()
stt_stream = self._stt_stream()
# Start pipeline running
_LOGGER.debug(
"Starting pipeline %s from %s to %s",
pipeline.name,
start_stage,
end_stage,
)
# Reset conversation id, if necessary
if (self._conversation_id_time is None) or (
(time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
):
self._conversation_id = None
if self._conversation_id is None:
self._conversation_id = str(uuid4())
# Update timeout
self._conversation_id_time = time.monotonic()
self._is_pipeline_running = True
self._pipeline_ended_event.clear()
self.config_entry.async_create_background_task(
self.hass,
assist_pipeline.async_pipeline_from_audio_stream(
self.hass,
context=Context(),
event_callback=self._event_callback,
stt_metadata=stt.SpeechMetadata(
language=pipeline.language,
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=stt_stream,
self.async_accept_pipeline_from_satellite(
audio_stream=self._stt_stream(),
start_stage=start_stage,
end_stage=end_stage,
tts_audio_output="wav",
pipeline_id=pipeline_id,
audio_settings=assist_pipeline.AudioSettings(
noise_suppression_level=self.device.noise_suppression_level,
auto_gain_dbfs=self.device.auto_gain,
volume_multiplier=self.device.volume_multiplier,
silence_seconds=VadSensitivity.to_seconds(
self.device.vad_sensitivity
),
),
device_id=self.device.device_id,
wake_word_phrase=wake_word_phrase,
conversation_id=self._conversation_id,
),
name="wyoming satellite pipeline",
"wyoming satellite pipeline",
)
async def _send_delayed_ping(self) -> None:
@ -431,91 +563,6 @@ class WyomingSatellite:
except ConnectionError:
pass # handled with timeout
def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None:
"""Translate pipeline events into Wyoming events."""
assert self._client is not None
if event.type == assist_pipeline.PipelineEventType.RUN_END:
# Pipeline run is complete
self._is_pipeline_running = False
self._pipeline_ended_event.set()
self.device.set_is_active(False)
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
self.hass.add_job(self._client.write_event(Detect().event()))
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END:
# Wake word detection
# Inform client of wake word detection
if event.data and (wake_word_output := event.data.get("wake_word_output")):
detection = Detection(
name=wake_word_output["wake_word_id"],
timestamp=wake_word_output.get("timestamp"),
)
self.hass.add_job(self._client.write_event(detection.event()))
elif event.type == assist_pipeline.PipelineEventType.STT_START:
# Speech-to-text
self.device.set_is_active(True)
if event.data:
self.hass.add_job(
self._client.write_event(
Transcribe(language=event.data["metadata"]["language"]).event()
)
)
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START:
# User started speaking
if event.data:
self.hass.add_job(
self._client.write_event(
VoiceStarted(timestamp=event.data["timestamp"]).event()
)
)
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END:
# User stopped speaking
if event.data:
self.hass.add_job(
self._client.write_event(
VoiceStopped(timestamp=event.data["timestamp"]).event()
)
)
elif event.type == assist_pipeline.PipelineEventType.STT_END:
# Speech-to-text transcript
if event.data:
# Inform client of transript
stt_text = event.data["stt_output"]["text"]
self.hass.add_job(
self._client.write_event(Transcript(text=stt_text).event())
)
elif event.type == assist_pipeline.PipelineEventType.TTS_START:
# Text-to-speech text
if event.data:
# Inform client of text
self.hass.add_job(
self._client.write_event(
Synthesize(
text=event.data["tts_input"],
voice=SynthesizeVoice(
name=event.data.get("voice"),
language=event.data.get("language"),
),
).event()
)
)
elif event.type == assist_pipeline.PipelineEventType.TTS_END:
# TTS stream
if event.data and (tts_output := event.data["tts_output"]):
media_id = tts_output["media_id"]
self.hass.add_job(self._stream_tts(media_id))
elif event.type == assist_pipeline.PipelineEventType.ERROR:
# Pipeline error
if event.data:
self.hass.add_job(
self._client.write_event(
Error(
text=event.data["message"], code=event.data["code"]
).event()
)
)
async def _connect(self) -> None:
"""Connect to satellite over TCP."""
await self._disconnect()
@ -576,16 +623,16 @@ class WyomingSatellite:
async def _stt_stream(self) -> AsyncGenerator[bytes]:
"""Yield audio chunks from a queue."""
try:
is_first_chunk = True
while chunk := await self._audio_queue.get():
if is_first_chunk:
is_first_chunk = False
_LOGGER.debug("Receiving audio from satellite")
is_first_chunk = True
while chunk := await self._audio_queue.get():
if chunk is None:
break
yield chunk
except asyncio.CancelledError:
pass # ignore
if is_first_chunk:
is_first_chunk = False
_LOGGER.debug("Receiving audio from satellite")
yield chunk
@callback
def _handle_timer(

View file

@ -28,9 +28,9 @@ async def async_setup_entry(
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
# Setup is only forwarded for satellites
assert item.satellite is not None
assert item.device is not None
async_add_entities([WyomingSatelliteAssistInProgress(item.satellite.device)])
async_add_entities([WyomingSatelliteAssistInProgress(item.device)])
class WyomingSatelliteAssistInProgress(WyomingSatelliteEntity, BinarySensorEntity):

View file

@ -6,7 +6,7 @@ from homeassistant.helpers import entity
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
from .const import DOMAIN
from .satellite import SatelliteDevice
from .devices import SatelliteDevice
class WyomingSatelliteEntity(entity.Entity):

View file

@ -3,7 +3,12 @@
"name": "Wyoming Protocol",
"codeowners": ["@balloob", "@synesthesiam"],
"config_flow": true,
"dependencies": ["assist_pipeline", "intent", "conversation"],
"dependencies": [
"assist_satellite",
"assist_pipeline",
"intent",
"conversation"
],
"documentation": "https://www.home-assistant.io/integrations/wyoming",
"integration_type": "service",
"iot_class": "local_push",

View file

@ -3,7 +3,7 @@
from dataclasses import dataclass
from .data import WyomingService
from .satellite import WyomingSatellite
from .devices import SatelliteDevice
@dataclass
@ -11,4 +11,4 @@ class DomainDataItem:
"""Domain data item."""
service: WyomingService
satellite: WyomingSatellite | None = None
device: SatelliteDevice | None = None

View file

@ -30,13 +30,12 @@ async def async_setup_entry(
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
# Setup is only forwarded for satellites
assert item.satellite is not None
assert item.device is not None
device = item.satellite.device
async_add_entities(
[
WyomingSatelliteAutoGainNumber(device),
WyomingSatelliteVolumeMultiplierNumber(device),
WyomingSatelliteAutoGainNumber(item.device),
WyomingSatelliteVolumeMultiplierNumber(item.device),
]
)

View file

@ -42,14 +42,13 @@ async def async_setup_entry(
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
# Setup is only forwarded for satellites
assert item.satellite is not None
assert item.device is not None
device = item.satellite.device
async_add_entities(
[
WyomingSatellitePipelineSelect(hass, device),
WyomingSatelliteNoiseSuppressionLevelSelect(device),
WyomingSatelliteVadSensitivitySelect(hass, device),
WyomingSatellitePipelineSelect(hass, item.device),
WyomingSatelliteNoiseSuppressionLevelSelect(item.device),
WyomingSatelliteVadSensitivitySelect(hass, item.device),
]
)

View file

@ -27,9 +27,9 @@ async def async_setup_entry(
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
# Setup is only forwarded for satellites
assert item.satellite is not None
assert item.device is not None
async_add_entities([WyomingSatelliteMuteSwitch(item.satellite.device)])
async_add_entities([WyomingSatelliteMuteSwitch(item.device)])
class WyomingSatelliteMuteSwitch(
@ -51,7 +51,7 @@ class WyomingSatelliteMuteSwitch(
# Default to off
self._attr_is_on = (state is not None) and (state.state == STATE_ON)
self._device.is_muted = self._attr_is_on
self._device.set_is_muted(self._attr_is_on)
async def async_turn_on(self, **kwargs: Any) -> None:
"""Turn on."""

View file

@ -150,10 +150,10 @@ async def reload_satellite(
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.run"
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.run"
) as _run_mock,
):
# _run_mock: satellite task does not actually run
await hass.config_entries.async_reload(config_entry_id)
return hass.data[DOMAIN][config_entry_id].satellite.device
return hass.data[DOMAIN][config_entry_id].device

View file

@ -152,7 +152,7 @@ async def init_satellite(hass: HomeAssistant, satellite_config_entry: ConfigEntr
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.run"
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.run"
) as _run_mock,
):
# _run_mock: satellite task does not actually run
@ -164,4 +164,4 @@ async def satellite_device(
hass: HomeAssistant, init_satellite, satellite_config_entry: ConfigEntry
) -> SatelliteDevice:
"""Get a satellite device fixture."""
return hass.data[DOMAIN][satellite_config_entry.entry_id].satellite.device
return hass.data[DOMAIN][satellite_config_entry.entry_id].device

View file

@ -23,6 +23,7 @@ from wyoming.vad import VoiceStarted, VoiceStopped
from wyoming.wake import Detect, Detection
from homeassistant.components import assist_pipeline, wyoming
from homeassistant.components.wyoming.assist_satellite import WyomingAssistSatellite
from homeassistant.components.wyoming.devices import SatelliteDevice
from homeassistant.const import STATE_ON
from homeassistant.core import HomeAssistant, State
@ -240,23 +241,22 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
"homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
return_value=("wav", get_test_wav()),
),
patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0),
patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
):
entry = await setup_config_entry(hass)
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
entry.entry_id
].satellite.device
device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device
assert device is not None
async with asyncio.timeout(1):
await mock_client.connect_event.wait()
@ -443,7 +443,7 @@ async def test_satellite_muted(hass: HomeAssistant) -> None:
"""Test callback for a satellite that has been muted."""
on_muted_event = asyncio.Event()
original_on_muted = wyoming.satellite.WyomingSatellite.on_muted
original_on_muted = WyomingAssistSatellite.on_muted
async def on_muted(self):
# Trigger original function
@ -462,12 +462,16 @@ async def test_satellite_muted(hass: HomeAssistant) -> None:
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient([]),
),
patch(
"homeassistant.components.wyoming.switch.WyomingSatelliteMuteSwitch.async_get_last_state",
return_value=State("switch.test_mute", STATE_ON),
),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_muted",
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_muted",
on_muted,
),
):
@ -484,11 +488,11 @@ async def test_satellite_restart(hass: HomeAssistant) -> None:
"""Test pipeline loop restart after unexpected error."""
on_restart_event = asyncio.Event()
original_on_restart = wyoming.satellite.WyomingSatellite.on_restart
original_on_restart = WyomingAssistSatellite.on_restart
async def on_restart(self):
await original_on_restart(self)
self.stop()
self.stop_satellite()
on_restart_event.set()
with (
@ -497,14 +501,14 @@ async def test_satellite_restart(hass: HomeAssistant) -> None:
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite._connect_and_loop",
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._connect_and_loop",
side_effect=RuntimeError(),
),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_restart",
on_restart,
),
patch("homeassistant.components.wyoming.satellite._RESTART_SECONDS", 0),
patch("homeassistant.components.wyoming.assist_satellite._RESTART_SECONDS", 0),
):
await setup_config_entry(hass)
async with asyncio.timeout(1):
@ -517,7 +521,7 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None:
reconnect_event = asyncio.Event()
stopped_event = asyncio.Event()
original_on_reconnect = wyoming.satellite.WyomingSatellite.on_reconnect
original_on_reconnect = WyomingAssistSatellite.on_reconnect
async def on_reconnect(self):
await original_on_reconnect(self)
@ -526,7 +530,7 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None:
num_reconnects += 1
if num_reconnects >= 2:
reconnect_event.set()
self.stop()
self.stop_satellite()
async def on_stopped(self):
stopped_event.set()
@ -537,18 +541,20 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None:
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient.connect",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient.connect",
side_effect=ConnectionRefusedError(),
),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_reconnect",
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_reconnect",
on_reconnect,
),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_stopped",
on_stopped,
),
patch("homeassistant.components.wyoming.satellite._RECONNECT_SECONDS", 0),
patch(
"homeassistant.components.wyoming.assist_satellite._RECONNECT_SECONDS", 0
),
):
await setup_config_entry(hass)
async with asyncio.timeout(1):
@ -561,7 +567,7 @@ async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None
on_restart_event = asyncio.Event()
async def on_restart(self):
self.stop()
self.stop_satellite()
on_restart_event.set()
with (
@ -570,14 +576,14 @@ async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
MockAsyncTcpClient([]), # no RunPipeline event
),
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
) as mock_run_pipeline,
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_restart",
on_restart,
),
):
@ -603,7 +609,7 @@ async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None
async def on_restart(self):
# Pretend sensor got stuck on
self.device.is_active = True
self.stop()
self.stop_satellite()
on_restart_event.set()
async def on_stopped(self):
@ -615,25 +621,23 @@ async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
MockAsyncTcpClient(events),
),
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
) as mock_run_pipeline,
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_restart",
on_restart,
),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_stopped",
on_stopped,
),
):
entry = await setup_config_entry(hass)
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
entry.entry_id
].satellite.device
device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device
async with asyncio.timeout(1):
await on_restart_event.wait()
@ -665,11 +669,11 @@ async def test_satellite_error_during_pipeline(hass: HomeAssistant) -> None:
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
wraps=_async_pipeline_from_audio_stream,
) as mock_run_pipeline,
):
@ -701,7 +705,7 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None:
"""Test satellite receiving non-WAV audio from text-to-speech."""
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
original_stream_tts = wyoming.satellite.WyomingSatellite._stream_tts
original_stream_tts = WyomingAssistSatellite._stream_tts
error_event = asyncio.Event()
async def _stream_tts(self, media_id):
@ -724,19 +728,19 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None:
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
wraps=_async_pipeline_from_audio_stream,
) as mock_run_pipeline,
patch(
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
"homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
return_value=("mp3", bytes(1)),
),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite._stream_tts",
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._stream_tts",
_stream_tts,
),
):
@ -819,18 +823,16 @@ async def test_pipeline_changed(hass: HomeAssistant) -> None:
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
async_pipeline_from_audio_stream,
),
):
entry = await setup_config_entry(hass)
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
entry.entry_id
].satellite.device
device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device
async with asyncio.timeout(1):
await mock_client.connect_event.wait()
@ -893,18 +895,16 @@ async def test_audio_settings_changed(hass: HomeAssistant) -> None:
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
async_pipeline_from_audio_stream,
),
):
entry = await setup_config_entry(hass)
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
entry.entry_id
].satellite.device
device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device
async with asyncio.timeout(1):
await mock_client.connect_event.wait()
@ -938,7 +938,7 @@ async def test_invalid_stages(hass: HomeAssistant) -> None:
).event(),
]
original_run_pipeline_once = wyoming.satellite.WyomingSatellite._run_pipeline_once
original_run_pipeline_once = WyomingAssistSatellite._run_pipeline_once
start_stage_event = asyncio.Event()
end_stage_event = asyncio.Event()
@ -967,11 +967,11 @@ async def test_invalid_stages(hass: HomeAssistant) -> None:
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite._run_pipeline_once",
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._run_pipeline_once",
_run_pipeline_once,
),
):
@ -1029,11 +1029,11 @@ async def test_client_stops_pipeline(hass: HomeAssistant) -> None:
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
async_pipeline_from_audio_stream,
),
):
@ -1083,11 +1083,11 @@ async def test_wake_word_phrase(hass: HomeAssistant) -> None:
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
),
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
wraps=_async_pipeline_from_audio_stream,
) as mock_run_pipeline,
):
@ -1114,14 +1114,12 @@ async def test_timers(hass: HomeAssistant) -> None:
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient([]),
) as mock_client,
):
entry = await setup_config_entry(hass)
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
entry.entry_id
].satellite.device
device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device
async with asyncio.timeout(1):
await mock_client.connect_event.wait()
@ -1285,104 +1283,3 @@ async def test_timers(hass: HomeAssistant) -> None:
timer_finished = mock_client.timer_finished
assert timer_finished is not None
assert timer_finished.id == timer_started.id
async def test_satellite_conversation_id(hass: HomeAssistant) -> None:
"""Test that the same conversation id is used until timeout."""
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
events = [
RunPipeline(
start_stage=PipelineStage.WAKE,
end_stage=PipelineStage.TTS,
restart_on_end=True,
).event(),
]
pipeline_kwargs: dict[str, Any] = {}
pipeline_event_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = (
None
)
run_pipeline_called = asyncio.Event()
async def async_pipeline_from_audio_stream(
hass: HomeAssistant,
context,
event_callback,
stt_metadata,
stt_stream,
**kwargs,
) -> None:
nonlocal pipeline_kwargs, pipeline_event_callback
pipeline_kwargs = kwargs
pipeline_event_callback = event_callback
run_pipeline_called.set()
with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
return_value=("wav", get_test_wav()),
),
patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0),
):
entry = await setup_config_entry(hass)
satellite: wyoming.WyomingSatellite = hass.data[wyoming.DOMAIN][
entry.entry_id
].satellite
async with asyncio.timeout(1):
await mock_client.connect_event.wait()
await mock_client.run_satellite_event.wait()
async with asyncio.timeout(1):
await run_pipeline_called.wait()
assert pipeline_event_callback is not None
# A conversation id should have been generated
conversation_id = pipeline_kwargs.get("conversation_id")
assert conversation_id
# Reset and run again
run_pipeline_called.clear()
pipeline_kwargs.clear()
pipeline_event_callback(
assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END)
)
async with asyncio.timeout(1):
await run_pipeline_called.wait()
# Should be the same conversation id
assert pipeline_kwargs.get("conversation_id") == conversation_id
# Reset and run again, but this time "time out"
satellite._conversation_id_time = None
run_pipeline_called.clear()
pipeline_kwargs.clear()
pipeline_event_callback(
assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END)
)
async with asyncio.timeout(1):
await run_pipeline_called.wait()
# Should be a different conversation id
new_conversation_id = pipeline_kwargs.get("conversation_id")
assert new_conversation_id
assert new_conversation_id != conversation_id