Migrate Wyoming satellite to Assist satellite entity (#128488)
* Migrate Wyoming satellite to Assist satellite entity * Fix tests * Update homeassistant/components/wyoming/assist_satellite.py Co-authored-by: Paulus Schoutsen <balloob@gmail.com> * Update homeassistant/components/wyoming/assist_satellite.py Co-authored-by: Paulus Schoutsen <balloob@gmail.com> --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
c294130080
commit
bcac851677
12 changed files with 325 additions and 400 deletions
|
@ -14,11 +14,11 @@ from .const import ATTR_SPEAKER, DOMAIN
|
|||
from .data import WyomingService
|
||||
from .devices import SatelliteDevice
|
||||
from .models import DomainDataItem
|
||||
from .satellite import WyomingSatellite
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
SATELLITE_PLATFORMS = [
|
||||
Platform.ASSIST_SATELLITE,
|
||||
Platform.BINARY_SENSOR,
|
||||
Platform.SELECT,
|
||||
Platform.SWITCH,
|
||||
|
@ -47,51 +47,29 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
entry.async_on_unload(entry.add_update_listener(update_listener))
|
||||
|
||||
if (satellite_info := service.info.satellite) is not None:
|
||||
# Create satellite device, etc.
|
||||
item.satellite = _make_satellite(hass, entry, service)
|
||||
# Create satellite device
|
||||
dev_reg = dr.async_get(hass)
|
||||
|
||||
# Set up satellite sensors, switches, etc.
|
||||
await hass.config_entries.async_forward_entry_setups(entry, SATELLITE_PLATFORMS)
|
||||
|
||||
# Start satellite communication
|
||||
entry.async_create_background_task(
|
||||
hass,
|
||||
item.satellite.run(),
|
||||
f"Satellite {satellite_info.name}",
|
||||
# Use config entry id since only one satellite per entry is supported
|
||||
satellite_id = entry.entry_id
|
||||
device = dev_reg.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
identifiers={(DOMAIN, satellite_id)},
|
||||
name=satellite_info.name,
|
||||
suggested_area=satellite_info.area,
|
||||
)
|
||||
|
||||
entry.async_on_unload(item.satellite.stop)
|
||||
item.device = SatelliteDevice(
|
||||
satellite_id=satellite_id,
|
||||
device_id=device.id,
|
||||
)
|
||||
|
||||
# Set up satellite entity, sensors, switches, etc.
|
||||
await hass.config_entries.async_forward_entry_setups(entry, SATELLITE_PLATFORMS)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _make_satellite(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService
|
||||
) -> WyomingSatellite:
|
||||
"""Create Wyoming satellite/device from config entry and Wyoming service."""
|
||||
satellite_info = service.info.satellite
|
||||
assert satellite_info is not None
|
||||
|
||||
dev_reg = dr.async_get(hass)
|
||||
|
||||
# Use config entry id since only one satellite per entry is supported
|
||||
satellite_id = config_entry.entry_id
|
||||
|
||||
device = dev_reg.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
identifiers={(DOMAIN, satellite_id)},
|
||||
name=satellite_info.name,
|
||||
suggested_area=satellite_info.area,
|
||||
)
|
||||
|
||||
satellite_device = SatelliteDevice(
|
||||
satellite_id=satellite_id,
|
||||
device_id=device.id,
|
||||
)
|
||||
|
||||
return WyomingSatellite(hass, config_entry, service, satellite_device)
|
||||
|
||||
|
||||
async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
|
||||
"""Handle options update."""
|
||||
await hass.config_entries.async_reload(entry.entry_id)
|
||||
|
@ -102,7 +80,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
item: DomainDataItem = hass.data[DOMAIN][entry.entry_id]
|
||||
|
||||
platforms = list(item.service.platforms)
|
||||
if item.satellite is not None:
|
||||
if item.device is not None:
|
||||
platforms += SATELLITE_PLATFORMS
|
||||
|
||||
unload_ok = await hass.config_entries.async_unload_platforms(entry, platforms)
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
"""Support for Wyoming satellite services."""
|
||||
"""Assist satellite entity for Wyoming integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
from typing import Final
|
||||
from uuid import uuid4
|
||||
from typing import Any, Final
|
||||
import wave
|
||||
|
||||
from wyoming.asr import Transcribe, Transcript
|
||||
|
@ -18,20 +18,29 @@ from wyoming.info import Describe, Info
|
|||
from wyoming.ping import Ping, Pong
|
||||
from wyoming.pipeline import PipelineStage, RunPipeline
|
||||
from wyoming.satellite import PauseSatellite, RunSatellite
|
||||
from wyoming.snd import Played
|
||||
from wyoming.timer import TimerCancelled, TimerFinished, TimerStarted, TimerUpdated
|
||||
from wyoming.tts import Synthesize, SynthesizeVoice
|
||||
from wyoming.vad import VoiceStarted, VoiceStopped
|
||||
from wyoming.wake import Detect, Detection
|
||||
|
||||
from homeassistant.components import assist_pipeline, intent, stt, tts
|
||||
from homeassistant.components.assist_pipeline import select as pipeline_select
|
||||
from homeassistant.components.assist_pipeline.vad import VadSensitivity
|
||||
from homeassistant.components import assist_pipeline, intent, tts
|
||||
from homeassistant.components.assist_pipeline import PipelineEvent
|
||||
from homeassistant.components.assist_satellite import (
|
||||
AssistSatelliteConfiguration,
|
||||
AssistSatelliteEntity,
|
||||
AssistSatelliteEntityDescription,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
from homeassistant.const import EntityCategory
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from .const import DOMAIN
|
||||
from .data import WyomingService
|
||||
from .devices import SatelliteDevice
|
||||
from .entity import WyomingSatelliteEntity
|
||||
from .models import DomainDataItem
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -41,7 +50,6 @@ _RESTART_SECONDS: Final = 3
|
|||
_PING_TIMEOUT: Final = 5
|
||||
_PING_SEND_DELAY: Final = 2
|
||||
_PIPELINE_FINISH_TIMEOUT: Final = 1
|
||||
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
|
||||
|
||||
# Wyoming stage -> Assist stage
|
||||
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
|
||||
|
@ -52,21 +60,47 @@ _STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
|
|||
}
|
||||
|
||||
|
||||
class WyomingSatellite:
|
||||
"""Remove voice satellite running the Wyoming protocol."""
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up Wyoming Assist satellite entity."""
|
||||
domain_data: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
|
||||
assert domain_data.device is not None
|
||||
|
||||
async_add_entities(
|
||||
[
|
||||
WyomingAssistSatellite(
|
||||
hass, domain_data.service, domain_data.device, config_entry
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
||||
"""Assist satellite for Wyoming devices."""
|
||||
|
||||
entity_description = AssistSatelliteEntityDescription(key="assist_satellite")
|
||||
_attr_translation_key = "assist_satellite"
|
||||
_attr_entity_category = EntityCategory.CONFIG
|
||||
_attr_name = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
service: WyomingService,
|
||||
device: SatelliteDevice,
|
||||
config_entry: ConfigEntry,
|
||||
) -> None:
|
||||
"""Initialize satellite."""
|
||||
self.hass = hass
|
||||
self.config_entry = config_entry
|
||||
"""Initialize an Assist satellite."""
|
||||
WyomingSatelliteEntity.__init__(self, device)
|
||||
AssistSatelliteEntity.__init__(self)
|
||||
|
||||
self.service = service
|
||||
self.device = device
|
||||
self.config_entry = config_entry
|
||||
|
||||
self.is_running = True
|
||||
|
||||
self._client: AsyncTcpClient | None = None
|
||||
|
@ -84,6 +118,160 @@ class WyomingSatellite:
|
|||
self.device.set_pipeline_listener(self._pipeline_changed)
|
||||
self.device.set_audio_settings_listener(self._audio_settings_changed)
|
||||
|
||||
@property
|
||||
def pipeline_entity_id(self) -> str | None:
|
||||
"""Return the entity ID of the pipeline to use for the next conversation."""
|
||||
return self.device.get_pipeline_entity_id(self.hass)
|
||||
|
||||
@property
|
||||
def vad_sensitivity_entity_id(self) -> str | None:
|
||||
"""Return the entity ID of the VAD sensitivity to use for the next conversation."""
|
||||
return self.device.get_vad_sensitivity_entity_id(self.hass)
|
||||
|
||||
@property
|
||||
def tts_options(self) -> dict[str, Any] | None:
|
||||
"""Options passed for text-to-speech."""
|
||||
return {
|
||||
tts.ATTR_PREFERRED_FORMAT: "wav",
|
||||
tts.ATTR_PREFERRED_SAMPLE_RATE: 16000,
|
||||
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
|
||||
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
|
||||
}
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Run when entity about to be added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
self.start_satellite()
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Run when entity will be removed from hass."""
|
||||
await super().async_will_remove_from_hass()
|
||||
self.stop_satellite()
|
||||
|
||||
@callback
|
||||
def async_get_configuration(
|
||||
self,
|
||||
) -> AssistSatelliteConfiguration:
|
||||
"""Get the current satellite configuration."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_set_configuration(
|
||||
self, config: AssistSatelliteConfiguration
|
||||
) -> None:
|
||||
"""Set the current satellite configuration."""
|
||||
raise NotImplementedError
|
||||
|
||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Set state based on pipeline stage."""
|
||||
assert self._client is not None
|
||||
|
||||
if event.type == assist_pipeline.PipelineEventType.RUN_END:
|
||||
# Pipeline run is complete
|
||||
self._is_pipeline_running = False
|
||||
self._pipeline_ended_event.set()
|
||||
self.device.set_is_active(False)
|
||||
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
|
||||
self.hass.add_job(self._client.write_event(Detect().event()))
|
||||
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END:
|
||||
# Wake word detection
|
||||
# Inform client of wake word detection
|
||||
if event.data and (wake_word_output := event.data.get("wake_word_output")):
|
||||
detection = Detection(
|
||||
name=wake_word_output["wake_word_id"],
|
||||
timestamp=wake_word_output.get("timestamp"),
|
||||
)
|
||||
self.hass.add_job(self._client.write_event(detection.event()))
|
||||
elif event.type == assist_pipeline.PipelineEventType.STT_START:
|
||||
# Speech-to-text
|
||||
self.device.set_is_active(True)
|
||||
|
||||
if event.data:
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
Transcribe(language=event.data["metadata"]["language"]).event()
|
||||
)
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START:
|
||||
# User started speaking
|
||||
if event.data:
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
VoiceStarted(timestamp=event.data["timestamp"]).event()
|
||||
)
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END:
|
||||
# User stopped speaking
|
||||
if event.data:
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
VoiceStopped(timestamp=event.data["timestamp"]).event()
|
||||
)
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.STT_END:
|
||||
# Speech-to-text transcript
|
||||
if event.data:
|
||||
# Inform client of transript
|
||||
stt_text = event.data["stt_output"]["text"]
|
||||
self.hass.add_job(
|
||||
self._client.write_event(Transcript(text=stt_text).event())
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.TTS_START:
|
||||
# Text-to-speech text
|
||||
if event.data:
|
||||
# Inform client of text
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
Synthesize(
|
||||
text=event.data["tts_input"],
|
||||
voice=SynthesizeVoice(
|
||||
name=event.data.get("voice"),
|
||||
language=event.data.get("language"),
|
||||
),
|
||||
).event()
|
||||
)
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.TTS_END:
|
||||
# TTS stream
|
||||
if event.data and (tts_output := event.data["tts_output"]):
|
||||
media_id = tts_output["media_id"]
|
||||
self.hass.add_job(self._stream_tts(media_id))
|
||||
elif event.type == assist_pipeline.PipelineEventType.ERROR:
|
||||
# Pipeline error
|
||||
if event.data:
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
Error(
|
||||
text=event.data["message"], code=event.data["code"]
|
||||
).event()
|
||||
)
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def start_satellite(self) -> None:
|
||||
"""Start satellite task."""
|
||||
self.is_running = True
|
||||
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass, self.run(), "wyoming satellite run"
|
||||
)
|
||||
|
||||
def stop_satellite(self) -> None:
|
||||
"""Signal satellite task to stop running."""
|
||||
# Stop existing pipeline
|
||||
self._audio_queue.put_nowait(None)
|
||||
|
||||
# Tell satellite to stop running
|
||||
self._send_pause()
|
||||
|
||||
# Stop task loop
|
||||
self.is_running = False
|
||||
|
||||
# Unblock waiting for unmuted
|
||||
self._muted_changed_event.set()
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run and maintain a connection to satellite."""
|
||||
_LOGGER.debug("Running satellite task")
|
||||
|
@ -110,6 +298,9 @@ class WyomingSatellite:
|
|||
except Exception as err: # noqa: BLE001
|
||||
_LOGGER.debug("%s: %s", err.__class__.__name__, str(err))
|
||||
|
||||
# Stop any existing pipeline
|
||||
self._audio_queue.put_nowait(None)
|
||||
|
||||
# Ensure sensor is off (before restart)
|
||||
self.device.set_is_active(False)
|
||||
|
||||
|
@ -123,17 +314,6 @@ class WyomingSatellite:
|
|||
|
||||
await self.on_stopped()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal satellite task to stop running."""
|
||||
# Tell satellite to stop running
|
||||
self._send_pause()
|
||||
|
||||
# Stop task loop
|
||||
self.is_running = False
|
||||
|
||||
# Unblock waiting for unmuted
|
||||
self._muted_changed_event.set()
|
||||
|
||||
async def on_restart(self) -> None:
|
||||
"""Block until pipeline loop will be restarted."""
|
||||
_LOGGER.warning(
|
||||
|
@ -151,7 +331,7 @@ class WyomingSatellite:
|
|||
await asyncio.sleep(_RECONNECT_SECONDS)
|
||||
|
||||
async def on_muted(self) -> None:
|
||||
"""Block until device may be unmated again."""
|
||||
"""Block until device may be unmuted again."""
|
||||
await self._muted_changed_event.wait()
|
||||
|
||||
async def on_stopped(self) -> None:
|
||||
|
@ -252,6 +432,7 @@ class WyomingSatellite:
|
|||
done, pending = await asyncio.wait(
|
||||
pending, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
|
||||
if pipeline_ended_task in done:
|
||||
# Pipeline run end event was received
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
|
@ -302,7 +483,7 @@ class WyomingSatellite:
|
|||
elif AudioStop.is_type(client_event.type) and self._is_pipeline_running:
|
||||
# Stop pipeline
|
||||
_LOGGER.debug("Client requested pipeline to stop")
|
||||
self._audio_queue.put_nowait(b"")
|
||||
self._audio_queue.put_nowait(None)
|
||||
elif Info.is_type(client_event.type):
|
||||
client_info = Info.from_event(client_event)
|
||||
_LOGGER.debug("Updated client info: %s", client_info)
|
||||
|
@ -329,6 +510,9 @@ class WyomingSatellite:
|
|||
break
|
||||
|
||||
_LOGGER.debug("Client detected wake word: %s", wake_word_phrase)
|
||||
elif Played.is_type(client_event.type):
|
||||
# TTS response has finished playing on satellite
|
||||
self.tts_response_finished()
|
||||
else:
|
||||
_LOGGER.debug("Unexpected event from satellite: %s", client_event)
|
||||
|
||||
|
@ -353,72 +537,20 @@ class WyomingSatellite:
|
|||
if end_stage is None:
|
||||
raise ValueError(f"Invalid end stage: {end_stage}")
|
||||
|
||||
pipeline_id = pipeline_select.get_chosen_pipeline(
|
||||
self.hass,
|
||||
DOMAIN,
|
||||
self.device.satellite_id,
|
||||
)
|
||||
pipeline = assist_pipeline.async_get_pipeline(self.hass, pipeline_id)
|
||||
assert pipeline is not None
|
||||
|
||||
# We will push audio in through a queue
|
||||
self._audio_queue = asyncio.Queue()
|
||||
stt_stream = self._stt_stream()
|
||||
|
||||
# Start pipeline running
|
||||
_LOGGER.debug(
|
||||
"Starting pipeline %s from %s to %s",
|
||||
pipeline.name,
|
||||
start_stage,
|
||||
end_stage,
|
||||
)
|
||||
|
||||
# Reset conversation id, if necessary
|
||||
if (self._conversation_id_time is None) or (
|
||||
(time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
|
||||
):
|
||||
self._conversation_id = None
|
||||
|
||||
if self._conversation_id is None:
|
||||
self._conversation_id = str(uuid4())
|
||||
|
||||
# Update timeout
|
||||
self._conversation_id_time = time.monotonic()
|
||||
|
||||
self._is_pipeline_running = True
|
||||
self._pipeline_ended_event.clear()
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
assist_pipeline.async_pipeline_from_audio_stream(
|
||||
self.hass,
|
||||
context=Context(),
|
||||
event_callback=self._event_callback,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language=pipeline.language,
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=stt_stream,
|
||||
self.async_accept_pipeline_from_satellite(
|
||||
audio_stream=self._stt_stream(),
|
||||
start_stage=start_stage,
|
||||
end_stage=end_stage,
|
||||
tts_audio_output="wav",
|
||||
pipeline_id=pipeline_id,
|
||||
audio_settings=assist_pipeline.AudioSettings(
|
||||
noise_suppression_level=self.device.noise_suppression_level,
|
||||
auto_gain_dbfs=self.device.auto_gain,
|
||||
volume_multiplier=self.device.volume_multiplier,
|
||||
silence_seconds=VadSensitivity.to_seconds(
|
||||
self.device.vad_sensitivity
|
||||
),
|
||||
),
|
||||
device_id=self.device.device_id,
|
||||
wake_word_phrase=wake_word_phrase,
|
||||
conversation_id=self._conversation_id,
|
||||
),
|
||||
name="wyoming satellite pipeline",
|
||||
"wyoming satellite pipeline",
|
||||
)
|
||||
|
||||
async def _send_delayed_ping(self) -> None:
|
||||
|
@ -431,91 +563,6 @@ class WyomingSatellite:
|
|||
except ConnectionError:
|
||||
pass # handled with timeout
|
||||
|
||||
def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None:
|
||||
"""Translate pipeline events into Wyoming events."""
|
||||
assert self._client is not None
|
||||
|
||||
if event.type == assist_pipeline.PipelineEventType.RUN_END:
|
||||
# Pipeline run is complete
|
||||
self._is_pipeline_running = False
|
||||
self._pipeline_ended_event.set()
|
||||
self.device.set_is_active(False)
|
||||
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
|
||||
self.hass.add_job(self._client.write_event(Detect().event()))
|
||||
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END:
|
||||
# Wake word detection
|
||||
# Inform client of wake word detection
|
||||
if event.data and (wake_word_output := event.data.get("wake_word_output")):
|
||||
detection = Detection(
|
||||
name=wake_word_output["wake_word_id"],
|
||||
timestamp=wake_word_output.get("timestamp"),
|
||||
)
|
||||
self.hass.add_job(self._client.write_event(detection.event()))
|
||||
elif event.type == assist_pipeline.PipelineEventType.STT_START:
|
||||
# Speech-to-text
|
||||
self.device.set_is_active(True)
|
||||
|
||||
if event.data:
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
Transcribe(language=event.data["metadata"]["language"]).event()
|
||||
)
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START:
|
||||
# User started speaking
|
||||
if event.data:
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
VoiceStarted(timestamp=event.data["timestamp"]).event()
|
||||
)
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END:
|
||||
# User stopped speaking
|
||||
if event.data:
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
VoiceStopped(timestamp=event.data["timestamp"]).event()
|
||||
)
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.STT_END:
|
||||
# Speech-to-text transcript
|
||||
if event.data:
|
||||
# Inform client of transript
|
||||
stt_text = event.data["stt_output"]["text"]
|
||||
self.hass.add_job(
|
||||
self._client.write_event(Transcript(text=stt_text).event())
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.TTS_START:
|
||||
# Text-to-speech text
|
||||
if event.data:
|
||||
# Inform client of text
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
Synthesize(
|
||||
text=event.data["tts_input"],
|
||||
voice=SynthesizeVoice(
|
||||
name=event.data.get("voice"),
|
||||
language=event.data.get("language"),
|
||||
),
|
||||
).event()
|
||||
)
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.TTS_END:
|
||||
# TTS stream
|
||||
if event.data and (tts_output := event.data["tts_output"]):
|
||||
media_id = tts_output["media_id"]
|
||||
self.hass.add_job(self._stream_tts(media_id))
|
||||
elif event.type == assist_pipeline.PipelineEventType.ERROR:
|
||||
# Pipeline error
|
||||
if event.data:
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
Error(
|
||||
text=event.data["message"], code=event.data["code"]
|
||||
).event()
|
||||
)
|
||||
)
|
||||
|
||||
async def _connect(self) -> None:
|
||||
"""Connect to satellite over TCP."""
|
||||
await self._disconnect()
|
||||
|
@ -576,16 +623,16 @@ class WyomingSatellite:
|
|||
|
||||
async def _stt_stream(self) -> AsyncGenerator[bytes]:
|
||||
"""Yield audio chunks from a queue."""
|
||||
try:
|
||||
is_first_chunk = True
|
||||
while chunk := await self._audio_queue.get():
|
||||
if is_first_chunk:
|
||||
is_first_chunk = False
|
||||
_LOGGER.debug("Receiving audio from satellite")
|
||||
is_first_chunk = True
|
||||
while chunk := await self._audio_queue.get():
|
||||
if chunk is None:
|
||||
break
|
||||
|
||||
yield chunk
|
||||
except asyncio.CancelledError:
|
||||
pass # ignore
|
||||
if is_first_chunk:
|
||||
is_first_chunk = False
|
||||
_LOGGER.debug("Receiving audio from satellite")
|
||||
|
||||
yield chunk
|
||||
|
||||
@callback
|
||||
def _handle_timer(
|
|
@ -28,9 +28,9 @@ async def async_setup_entry(
|
|||
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
|
||||
|
||||
# Setup is only forwarded for satellites
|
||||
assert item.satellite is not None
|
||||
assert item.device is not None
|
||||
|
||||
async_add_entities([WyomingSatelliteAssistInProgress(item.satellite.device)])
|
||||
async_add_entities([WyomingSatelliteAssistInProgress(item.device)])
|
||||
|
||||
|
||||
class WyomingSatelliteAssistInProgress(WyomingSatelliteEntity, BinarySensorEntity):
|
||||
|
|
|
@ -6,7 +6,7 @@ from homeassistant.helpers import entity
|
|||
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
|
||||
|
||||
from .const import DOMAIN
|
||||
from .satellite import SatelliteDevice
|
||||
from .devices import SatelliteDevice
|
||||
|
||||
|
||||
class WyomingSatelliteEntity(entity.Entity):
|
||||
|
|
|
@ -3,7 +3,12 @@
|
|||
"name": "Wyoming Protocol",
|
||||
"codeowners": ["@balloob", "@synesthesiam"],
|
||||
"config_flow": true,
|
||||
"dependencies": ["assist_pipeline", "intent", "conversation"],
|
||||
"dependencies": [
|
||||
"assist_satellite",
|
||||
"assist_pipeline",
|
||||
"intent",
|
||||
"conversation"
|
||||
],
|
||||
"documentation": "https://www.home-assistant.io/integrations/wyoming",
|
||||
"integration_type": "service",
|
||||
"iot_class": "local_push",
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
from dataclasses import dataclass
|
||||
|
||||
from .data import WyomingService
|
||||
from .satellite import WyomingSatellite
|
||||
from .devices import SatelliteDevice
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -11,4 +11,4 @@ class DomainDataItem:
|
|||
"""Domain data item."""
|
||||
|
||||
service: WyomingService
|
||||
satellite: WyomingSatellite | None = None
|
||||
device: SatelliteDevice | None = None
|
||||
|
|
|
@ -30,13 +30,12 @@ async def async_setup_entry(
|
|||
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
|
||||
|
||||
# Setup is only forwarded for satellites
|
||||
assert item.satellite is not None
|
||||
assert item.device is not None
|
||||
|
||||
device = item.satellite.device
|
||||
async_add_entities(
|
||||
[
|
||||
WyomingSatelliteAutoGainNumber(device),
|
||||
WyomingSatelliteVolumeMultiplierNumber(device),
|
||||
WyomingSatelliteAutoGainNumber(item.device),
|
||||
WyomingSatelliteVolumeMultiplierNumber(item.device),
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
@ -42,14 +42,13 @@ async def async_setup_entry(
|
|||
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
|
||||
|
||||
# Setup is only forwarded for satellites
|
||||
assert item.satellite is not None
|
||||
assert item.device is not None
|
||||
|
||||
device = item.satellite.device
|
||||
async_add_entities(
|
||||
[
|
||||
WyomingSatellitePipelineSelect(hass, device),
|
||||
WyomingSatelliteNoiseSuppressionLevelSelect(device),
|
||||
WyomingSatelliteVadSensitivitySelect(hass, device),
|
||||
WyomingSatellitePipelineSelect(hass, item.device),
|
||||
WyomingSatelliteNoiseSuppressionLevelSelect(item.device),
|
||||
WyomingSatelliteVadSensitivitySelect(hass, item.device),
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
@ -27,9 +27,9 @@ async def async_setup_entry(
|
|||
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
|
||||
|
||||
# Setup is only forwarded for satellites
|
||||
assert item.satellite is not None
|
||||
assert item.device is not None
|
||||
|
||||
async_add_entities([WyomingSatelliteMuteSwitch(item.satellite.device)])
|
||||
async_add_entities([WyomingSatelliteMuteSwitch(item.device)])
|
||||
|
||||
|
||||
class WyomingSatelliteMuteSwitch(
|
||||
|
@ -51,7 +51,7 @@ class WyomingSatelliteMuteSwitch(
|
|||
|
||||
# Default to off
|
||||
self._attr_is_on = (state is not None) and (state.state == STATE_ON)
|
||||
self._device.is_muted = self._attr_is_on
|
||||
self._device.set_is_muted(self._attr_is_on)
|
||||
|
||||
async def async_turn_on(self, **kwargs: Any) -> None:
|
||||
"""Turn on."""
|
||||
|
|
|
@ -150,10 +150,10 @@ async def reload_satellite(
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.run"
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.run"
|
||||
) as _run_mock,
|
||||
):
|
||||
# _run_mock: satellite task does not actually run
|
||||
await hass.config_entries.async_reload(config_entry_id)
|
||||
|
||||
return hass.data[DOMAIN][config_entry_id].satellite.device
|
||||
return hass.data[DOMAIN][config_entry_id].device
|
||||
|
|
|
@ -152,7 +152,7 @@ async def init_satellite(hass: HomeAssistant, satellite_config_entry: ConfigEntr
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.run"
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.run"
|
||||
) as _run_mock,
|
||||
):
|
||||
# _run_mock: satellite task does not actually run
|
||||
|
@ -164,4 +164,4 @@ async def satellite_device(
|
|||
hass: HomeAssistant, init_satellite, satellite_config_entry: ConfigEntry
|
||||
) -> SatelliteDevice:
|
||||
"""Get a satellite device fixture."""
|
||||
return hass.data[DOMAIN][satellite_config_entry.entry_id].satellite.device
|
||||
return hass.data[DOMAIN][satellite_config_entry.entry_id].device
|
||||
|
|
|
@ -23,6 +23,7 @@ from wyoming.vad import VoiceStarted, VoiceStopped
|
|||
from wyoming.wake import Detect, Detection
|
||||
|
||||
from homeassistant.components import assist_pipeline, wyoming
|
||||
from homeassistant.components.wyoming.assist_satellite import WyomingAssistSatellite
|
||||
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||
from homeassistant.const import STATE_ON
|
||||
from homeassistant.core import HomeAssistant, State
|
||||
|
@ -240,23 +241,22 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
|
||||
"homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
|
||||
return_value=("wav", get_test_wav()),
|
||||
),
|
||||
patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0),
|
||||
patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||
entry.entry_id
|
||||
].satellite.device
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device
|
||||
assert device is not None
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.connect_event.wait()
|
||||
|
@ -443,7 +443,7 @@ async def test_satellite_muted(hass: HomeAssistant) -> None:
|
|||
"""Test callback for a satellite that has been muted."""
|
||||
on_muted_event = asyncio.Event()
|
||||
|
||||
original_on_muted = wyoming.satellite.WyomingSatellite.on_muted
|
||||
original_on_muted = WyomingAssistSatellite.on_muted
|
||||
|
||||
async def on_muted(self):
|
||||
# Trigger original function
|
||||
|
@ -462,12 +462,16 @@ async def test_satellite_muted(hass: HomeAssistant) -> None:
|
|||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient([]),
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.switch.WyomingSatelliteMuteSwitch.async_get_last_state",
|
||||
return_value=State("switch.test_mute", STATE_ON),
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_muted",
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_muted",
|
||||
on_muted,
|
||||
),
|
||||
):
|
||||
|
@ -484,11 +488,11 @@ async def test_satellite_restart(hass: HomeAssistant) -> None:
|
|||
"""Test pipeline loop restart after unexpected error."""
|
||||
on_restart_event = asyncio.Event()
|
||||
|
||||
original_on_restart = wyoming.satellite.WyomingSatellite.on_restart
|
||||
original_on_restart = WyomingAssistSatellite.on_restart
|
||||
|
||||
async def on_restart(self):
|
||||
await original_on_restart(self)
|
||||
self.stop()
|
||||
self.stop_satellite()
|
||||
on_restart_event.set()
|
||||
|
||||
with (
|
||||
|
@ -497,14 +501,14 @@ async def test_satellite_restart(hass: HomeAssistant) -> None:
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite._connect_and_loop",
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._connect_and_loop",
|
||||
side_effect=RuntimeError(),
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_restart",
|
||||
on_restart,
|
||||
),
|
||||
patch("homeassistant.components.wyoming.satellite._RESTART_SECONDS", 0),
|
||||
patch("homeassistant.components.wyoming.assist_satellite._RESTART_SECONDS", 0),
|
||||
):
|
||||
await setup_config_entry(hass)
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -517,7 +521,7 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None:
|
|||
reconnect_event = asyncio.Event()
|
||||
stopped_event = asyncio.Event()
|
||||
|
||||
original_on_reconnect = wyoming.satellite.WyomingSatellite.on_reconnect
|
||||
original_on_reconnect = WyomingAssistSatellite.on_reconnect
|
||||
|
||||
async def on_reconnect(self):
|
||||
await original_on_reconnect(self)
|
||||
|
@ -526,7 +530,7 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None:
|
|||
num_reconnects += 1
|
||||
if num_reconnects >= 2:
|
||||
reconnect_event.set()
|
||||
self.stop()
|
||||
self.stop_satellite()
|
||||
|
||||
async def on_stopped(self):
|
||||
stopped_event.set()
|
||||
|
@ -537,18 +541,20 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None:
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient.connect",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient.connect",
|
||||
side_effect=ConnectionRefusedError(),
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_reconnect",
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_reconnect",
|
||||
on_reconnect,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_stopped",
|
||||
on_stopped,
|
||||
),
|
||||
patch("homeassistant.components.wyoming.satellite._RECONNECT_SECONDS", 0),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.assist_satellite._RECONNECT_SECONDS", 0
|
||||
),
|
||||
):
|
||||
await setup_config_entry(hass)
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -561,7 +567,7 @@ async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None
|
|||
on_restart_event = asyncio.Event()
|
||||
|
||||
async def on_restart(self):
|
||||
self.stop()
|
||||
self.stop_satellite()
|
||||
on_restart_event.set()
|
||||
|
||||
with (
|
||||
|
@ -570,14 +576,14 @@ async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
MockAsyncTcpClient([]), # no RunPipeline event
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
) as mock_run_pipeline,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_restart",
|
||||
on_restart,
|
||||
),
|
||||
):
|
||||
|
@ -603,7 +609,7 @@ async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None
|
|||
async def on_restart(self):
|
||||
# Pretend sensor got stuck on
|
||||
self.device.is_active = True
|
||||
self.stop()
|
||||
self.stop_satellite()
|
||||
on_restart_event.set()
|
||||
|
||||
async def on_stopped(self):
|
||||
|
@ -615,25 +621,23 @@ async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
MockAsyncTcpClient(events),
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
) as mock_run_pipeline,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_restart",
|
||||
on_restart,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_stopped",
|
||||
on_stopped,
|
||||
),
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||
entry.entry_id
|
||||
].satellite.device
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await on_restart_event.wait()
|
||||
|
@ -665,11 +669,11 @@ async def test_satellite_error_during_pipeline(hass: HomeAssistant) -> None:
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
wraps=_async_pipeline_from_audio_stream,
|
||||
) as mock_run_pipeline,
|
||||
):
|
||||
|
@ -701,7 +705,7 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None:
|
|||
"""Test satellite receiving non-WAV audio from text-to-speech."""
|
||||
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
|
||||
|
||||
original_stream_tts = wyoming.satellite.WyomingSatellite._stream_tts
|
||||
original_stream_tts = WyomingAssistSatellite._stream_tts
|
||||
error_event = asyncio.Event()
|
||||
|
||||
async def _stream_tts(self, media_id):
|
||||
|
@ -724,19 +728,19 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None:
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
wraps=_async_pipeline_from_audio_stream,
|
||||
) as mock_run_pipeline,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
|
||||
"homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
|
||||
return_value=("mp3", bytes(1)),
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite._stream_tts",
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._stream_tts",
|
||||
_stream_tts,
|
||||
),
|
||||
):
|
||||
|
@ -819,18 +823,16 @@ async def test_pipeline_changed(hass: HomeAssistant) -> None:
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
async_pipeline_from_audio_stream,
|
||||
),
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||
entry.entry_id
|
||||
].satellite.device
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.connect_event.wait()
|
||||
|
@ -893,18 +895,16 @@ async def test_audio_settings_changed(hass: HomeAssistant) -> None:
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
async_pipeline_from_audio_stream,
|
||||
),
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||
entry.entry_id
|
||||
].satellite.device
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.connect_event.wait()
|
||||
|
@ -938,7 +938,7 @@ async def test_invalid_stages(hass: HomeAssistant) -> None:
|
|||
).event(),
|
||||
]
|
||||
|
||||
original_run_pipeline_once = wyoming.satellite.WyomingSatellite._run_pipeline_once
|
||||
original_run_pipeline_once = WyomingAssistSatellite._run_pipeline_once
|
||||
start_stage_event = asyncio.Event()
|
||||
end_stage_event = asyncio.Event()
|
||||
|
||||
|
@ -967,11 +967,11 @@ async def test_invalid_stages(hass: HomeAssistant) -> None:
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite._run_pipeline_once",
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._run_pipeline_once",
|
||||
_run_pipeline_once,
|
||||
),
|
||||
):
|
||||
|
@ -1029,11 +1029,11 @@ async def test_client_stops_pipeline(hass: HomeAssistant) -> None:
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
async_pipeline_from_audio_stream,
|
||||
),
|
||||
):
|
||||
|
@ -1083,11 +1083,11 @@ async def test_wake_word_phrase(hass: HomeAssistant) -> None:
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
wraps=_async_pipeline_from_audio_stream,
|
||||
) as mock_run_pipeline,
|
||||
):
|
||||
|
@ -1114,14 +1114,12 @@ async def test_timers(hass: HomeAssistant) -> None:
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient([]),
|
||||
) as mock_client,
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||
entry.entry_id
|
||||
].satellite.device
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.connect_event.wait()
|
||||
|
@ -1285,104 +1283,3 @@ async def test_timers(hass: HomeAssistant) -> None:
|
|||
timer_finished = mock_client.timer_finished
|
||||
assert timer_finished is not None
|
||||
assert timer_finished.id == timer_started.id
|
||||
|
||||
|
||||
async def test_satellite_conversation_id(hass: HomeAssistant) -> None:
|
||||
"""Test that the same conversation id is used until timeout."""
|
||||
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
|
||||
|
||||
events = [
|
||||
RunPipeline(
|
||||
start_stage=PipelineStage.WAKE,
|
||||
end_stage=PipelineStage.TTS,
|
||||
restart_on_end=True,
|
||||
).event(),
|
||||
]
|
||||
|
||||
pipeline_kwargs: dict[str, Any] = {}
|
||||
pipeline_event_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = (
|
||||
None
|
||||
)
|
||||
run_pipeline_called = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(
|
||||
hass: HomeAssistant,
|
||||
context,
|
||||
event_callback,
|
||||
stt_metadata,
|
||||
stt_stream,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
nonlocal pipeline_kwargs, pipeline_event_callback
|
||||
pipeline_kwargs = kwargs
|
||||
pipeline_event_callback = event_callback
|
||||
|
||||
run_pipeline_called.set()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
|
||||
return_value=("wav", get_test_wav()),
|
||||
),
|
||||
patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0),
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
satellite: wyoming.WyomingSatellite = hass.data[wyoming.DOMAIN][
|
||||
entry.entry_id
|
||||
].satellite
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.connect_event.wait()
|
||||
await mock_client.run_satellite_event.wait()
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await run_pipeline_called.wait()
|
||||
|
||||
assert pipeline_event_callback is not None
|
||||
|
||||
# A conversation id should have been generated
|
||||
conversation_id = pipeline_kwargs.get("conversation_id")
|
||||
assert conversation_id
|
||||
|
||||
# Reset and run again
|
||||
run_pipeline_called.clear()
|
||||
pipeline_kwargs.clear()
|
||||
|
||||
pipeline_event_callback(
|
||||
assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END)
|
||||
)
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await run_pipeline_called.wait()
|
||||
|
||||
# Should be the same conversation id
|
||||
assert pipeline_kwargs.get("conversation_id") == conversation_id
|
||||
|
||||
# Reset and run again, but this time "time out"
|
||||
satellite._conversation_id_time = None
|
||||
run_pipeline_called.clear()
|
||||
pipeline_kwargs.clear()
|
||||
|
||||
pipeline_event_callback(
|
||||
assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END)
|
||||
)
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await run_pipeline_called.wait()
|
||||
|
||||
# Should be a different conversation id
|
||||
new_conversation_id = pipeline_kwargs.get("conversation_id")
|
||||
assert new_conversation_id
|
||||
assert new_conversation_id != conversation_id
|
||||
|
|
Loading…
Add table
Reference in a new issue