Compare commits

...
Sign in to create a new pull request.

25 commits

Author SHA1 Message Date
Michael Hansen
b77dff5e65 Fix wyoming version 2024-08-26 09:54:42 -05:00
Michael Hansen
07ac3df4ff Add more tests 2024-08-26 09:34:19 -05:00
Michael Hansen
d11dace22b Add remote trigger to Wyoming satellites 2024-08-26 09:34:19 -05:00
Michael Hansen
59a6b1ebfa Add trigger services 2024-08-26 09:34:19 -05:00
Michael Hansen
1d2bced1f0
Merge branch 'assist-satellite' into synesthesiam-20240822-assist-satellite-trigger 2024-08-26 09:29:55 -05:00
Michael Hansen
644427ecc7
Add Assist satellite entity + VoIP (#123830)
* Add assist_satellite and implement VoIP

* Fix tests

* More tests

* Improve test

* Update entity state

* Set state correctly

* Move more functionality into base class

* Move RTP protocol into entity

* Fix tests

* Remove string

* Move to util method

* Align states better with pipeline events

* Remove public async_get_satellite_entity

* WAITING_FOR_WAKE_WORD

* Pass entity ids for pipeline/vad sensitivity

* Remove connect/disconnect

* Clean up

* Final cleanup
2024-08-25 16:19:36 +02:00
Michael Hansen
d48fcb3221 Add tts_input to _async_accept_pipeline_from_satellite 2024-08-22 16:05:13 -05:00
Michael Hansen
c468d9c5c9 Final cleanup 2024-08-21 10:34:42 -05:00
Michael Hansen
33d0d2cfed Clean up 2024-08-21 08:32:52 -05:00
Michael Hansen
a4876e435c Remove connect/disconnect 2024-08-21 08:11:46 -05:00
Michael Hansen
93da8de1e4 Pass entity ids for pipeline/vad sensitivity 2024-08-20 11:19:19 -05:00
Michael Hansen
bd0a97a3b7 WAITING_FOR_WAKE_WORD 2024-08-19 16:31:05 -05:00
Michael Hansen
9a483613e1 Remove public async_get_satellite_entity 2024-08-19 16:03:49 -05:00
Michael Hansen
f1c0bdf5be Align states better with pipeline events 2024-08-19 14:21:33 -05:00
Michael Hansen
f4d6e46fed Move to util method 2024-08-19 14:12:48 -05:00
Michael Hansen
ecec1d3208 Remove string 2024-08-19 12:04:08 -05:00
Michael Hansen
66be7b9648 Fix tests 2024-08-19 12:02:28 -05:00
Michael Hansen
f6e5d2d80b Move RTP protocol into entity 2024-08-19 10:55:41 -05:00
Michael Hansen
712e4e5f50 Move more functionality into base class 2024-08-19 10:55:41 -05:00
Michael Hansen
1e1623309d Set state correctly 2024-08-19 10:55:41 -05:00
Michael Hansen
d32a681f28 Update entity state 2024-08-19 10:55:41 -05:00
Michael Hansen
337fe974f7 Improve test 2024-08-19 10:55:41 -05:00
Michael Hansen
d7e9f6aae4 More tests 2024-08-19 10:55:41 -05:00
Michael Hansen
ec1866e131 Fix tests 2024-08-19 10:55:41 -05:00
Michael Hansen
b21e2360b9 Add assist_satellite and implement VoIP 2024-08-19 10:55:40 -05:00
46 changed files with 2275 additions and 1022 deletions

View file

@ -143,6 +143,8 @@ build.json @home-assistant/supervisor
/tests/components/aseko_pool_live/ @milanmeu
/homeassistant/components/assist_pipeline/ @balloob @synesthesiam
/tests/components/assist_pipeline/ @balloob @synesthesiam
/homeassistant/components/assist_satellite/ @synesthesiam
/tests/components/assist_satellite/ @synesthesiam
/homeassistant/components/asuswrt/ @kennedyshead @ollo69
/tests/components/asuswrt/ @kennedyshead @ollo69
/homeassistant/components/atag/ @MatsNL

View file

@ -16,6 +16,7 @@ from .const import (
DATA_LAST_WAKE_UP,
DOMAIN,
EVENT_RECORDING,
OPTION_PREFERRED,
SAMPLE_CHANNELS,
SAMPLE_RATE,
SAMPLE_WIDTH,
@ -57,6 +58,7 @@ __all__ = (
"PipelineNotFound",
"WakeWordSettings",
"EVENT_RECORDING",
"OPTION_PREFERRED",
"SAMPLES_PER_CHUNK",
"SAMPLE_RATE",
"SAMPLE_WIDTH",
@ -100,6 +102,7 @@ async def async_pipeline_from_audio_stream(
pipeline_id: str | None = None,
conversation_id: str | None = None,
tts_audio_output: str | None = None,
tts_input: str | None = None,
wake_word_settings: WakeWordSettings | None = None,
audio_settings: AudioSettings | None = None,
device_id: str | None = None,
@ -116,6 +119,7 @@ async def async_pipeline_from_audio_stream(
stt_metadata=stt_metadata,
stt_stream=stt_stream,
wake_word_phrase=wake_word_phrase,
tts_input=tts_input,
run=PipelineRun(
hass,
context=context,

View file

@ -22,3 +22,5 @@ SAMPLE_CHANNELS = 1 # mono
MS_PER_CHUNK = 10
SAMPLES_PER_CHUNK = SAMPLE_RATE // (1000 // MS_PER_CHUNK) # 10 ms @ 16Khz
BYTES_PER_CHUNK = SAMPLES_PER_CHUNK * SAMPLE_WIDTH * SAMPLE_CHANNELS # 16-bit
OPTION_PREFERRED = "preferred"

View file

@ -504,7 +504,7 @@ class AudioSettings:
is_vad_enabled: bool = True
"""True if VAD is used to determine the end of the voice command."""
silence_seconds: float = 0.5
silence_seconds: float = 0.7
"""Seconds of silence after voice command has ended."""
def __post_init__(self) -> None:
@ -906,6 +906,8 @@ class PipelineRun:
metadata,
self._speech_to_text_stream(audio_stream=stream, stt_vad=stt_vad),
)
except (asyncio.CancelledError, TimeoutError):
raise # expected
except Exception as src_error:
_LOGGER.exception("Unexpected error during speech-to-text")
raise SpeechToTextError(

View file

@ -9,12 +9,10 @@ from homeassistant.const import EntityCategory, Platform
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import collection, entity_registry as er, restore_state
from .const import DOMAIN
from .const import DOMAIN, OPTION_PREFERRED
from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection
from .vad import VadSensitivity
OPTION_PREFERRED = "preferred"
@callback
def get_chosen_pipeline(

View file

@ -0,0 +1,104 @@
"""Base class for assist satellite entities."""
import logging
import voluptuous as vol
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, SupportsResponse
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN
from .entity import AssistSatelliteEntity
from .models import (
AssistSatelliteEntityFeature,
AssistSatelliteState,
PipelineRunConfig,
PipelineRunResult,
)
__all__ = [
"DOMAIN",
"AssistSatelliteEntity",
"AssistSatelliteEntityFeature",
"AssistSatelliteState",
"PipelineRunConfig",
"PipelineRunResult",
"SERVICE_WAIT_WAKE",
"SERVICE_GET_COMMAND",
"SERVICE_SAY_TEXT",
"ATTR_WAKE_WORDS",
"ATTR_PROCESS",
"ATTR_ANNOUNCE_TEXT",
]
_LOGGER = logging.getLogger(__name__)
PLATFORM_SCHEMA_BASE = cv.PLATFORM_SCHEMA_BASE
ATTR_WAKE_WORDS = "wake_words"
ATTR_PROCESS = "process"
ATTR_ANNOUNCE_TEXT = "announce_text"
SERVICE_WAIT_WAKE = "wait_wake"
SERVICE_GET_COMMAND = "get_command"
SERVICE_SAY_TEXT = "say_text"
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
component = hass.data[DOMAIN] = EntityComponent[AssistSatelliteEntity](
_LOGGER, DOMAIN, hass
)
await component.async_setup(config)
component.async_register_entity_service(
name=SERVICE_WAIT_WAKE,
schema=cv.make_entity_service_schema(
{
vol.Required(ATTR_WAKE_WORDS): [cv.string],
vol.Optional(ATTR_ANNOUNCE_TEXT): cv.string,
}
),
func="async_wait_wake",
required_features=[AssistSatelliteEntityFeature.TRIGGER_PIPELINE],
supports_response=SupportsResponse.OPTIONAL,
)
component.async_register_entity_service(
name=SERVICE_GET_COMMAND,
schema=cv.make_entity_service_schema(
{
vol.Optional(ATTR_PROCESS): cv.boolean,
vol.Optional(ATTR_ANNOUNCE_TEXT): cv.string,
}
),
func="async_get_command",
required_features=[AssistSatelliteEntityFeature.TRIGGER_PIPELINE],
supports_response=SupportsResponse.OPTIONAL,
)
component.async_register_entity_service(
name=SERVICE_SAY_TEXT,
schema=cv.make_entity_service_schema(
{vol.Required(ATTR_ANNOUNCE_TEXT): cv.string}
),
func="async_say_text",
required_features=[AssistSatelliteEntityFeature.TRIGGER_PIPELINE],
supports_response=SupportsResponse.NONE,
)
return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up a config entry."""
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
return await component.async_setup_entry(entry)
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
return await component.async_unload_entry(entry)

View file

@ -0,0 +1,3 @@
"""Constants for assist satellite."""
DOMAIN = "assist_satellite"

View file

@ -0,0 +1,231 @@
"""Assist satellite entity."""
from collections.abc import AsyncIterable
import time
from typing import Final
from homeassistant.components import stt
from homeassistant.components.assist_pipeline import (
OPTION_PREFERRED,
AudioSettings,
PipelineEvent,
PipelineEventType,
PipelineStage,
async_get_pipelines,
async_pipeline_from_audio_stream,
vad,
)
from homeassistant.const import EntityCategory
from homeassistant.core import Context
from homeassistant.helpers import entity
from homeassistant.helpers.entity import EntityDescription
from homeassistant.util import ulid
from .models import (
AssistSatelliteEntityFeature,
AssistSatelliteState,
PipelineRunConfig,
PipelineRunResult,
)
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
class AssistSatelliteEntity(entity.Entity):
"""Entity encapsulating the state and functionality of an Assist satellite."""
entity_description = EntityDescription(
key="assist_satellite",
translation_key="assist_satellite",
entity_category=EntityCategory.CONFIG,
)
_attr_has_entity_name = True
_attr_name = None
_attr_should_poll = False
_attr_state: AssistSatelliteState | None = AssistSatelliteState.LISTENING_WAKE_WORD
_attr_supported_features = AssistSatelliteEntityFeature(0)
_conversation_id: str | None = None
_conversation_id_time: float | None = None
_run_has_tts: bool = False
async def async_trigger_pipeline_on_satellite(
self,
start_stage: PipelineStage,
end_stage: PipelineStage,
run_config: PipelineRunConfig,
) -> PipelineRunResult | None:
"""Run a pipeline on the satellite from start to end stage.
Can be called from a service.
Requires TRIGGER_PIPELINE supported feature.
- announce when start/end = "tts"
- listen for wake word when start/end = "wake"
- listen for command when start/end = "stt" (no processing)
- listen for command when start = "stt", end = "tts" (with processing)
"""
raise NotImplementedError
async def async_wait_wake(
self, wake_words: list[str], announce_text: str | None = None
) -> str | None:
"""Listen for one or more wake words on the satellite.
Returns the detected wake word phrase or None.
"""
if announce_text:
await self.async_say_text(announce_text)
result = await self.async_trigger_pipeline_on_satellite(
PipelineStage.WAKE_WORD,
PipelineStage.WAKE_WORD,
PipelineRunConfig(wake_word_names=wake_words),
)
if result is None:
return None
return result.detected_wake_word
async def async_get_command(
self, process: bool = False, announce_text: str | None = None
) -> str | None:
"""Get the text of a voice command from the satellite, optionally processing it.
Returns the spoken text or None.
"""
if announce_text:
await self.async_say_text(announce_text)
if process:
end_stage = PipelineStage.TTS
else:
end_stage = PipelineStage.STT
result = await self.async_trigger_pipeline_on_satellite(
PipelineStage.STT, end_stage, PipelineRunConfig()
)
if result is None:
return None
return result.command_text
async def async_say_text(self, announce_text: str) -> None:
"""Speak the text on the satellite."""
await self.async_trigger_pipeline_on_satellite(
PipelineStage.TTS,
PipelineStage.TTS,
PipelineRunConfig(announce_text=announce_text),
)
async def _async_accept_pipeline_from_satellite(
self,
audio_stream: AsyncIterable[bytes],
start_stage: PipelineStage = PipelineStage.STT,
end_stage: PipelineStage = PipelineStage.TTS,
pipeline_entity_id: str | None = None,
vad_sensitivity_entity_id: str | None = None,
wake_word_phrase: str | None = None,
tts_input: str | None = None,
) -> None:
"""Triggers an Assist pipeline in Home Assistant from a satellite."""
pipeline_id: str | None = None
vad_sensitivity = vad.VadSensitivity.DEFAULT
if pipeline_entity_id:
# Resolve pipeline by name
pipeline_entity_state = self.hass.states.get(pipeline_entity_id)
if (pipeline_entity_state is not None) and (
pipeline_entity_state.state != OPTION_PREFERRED
):
for pipeline in async_get_pipelines(self.hass):
if pipeline.name == pipeline_entity_state.state:
pipeline_id = pipeline.id
break
if vad_sensitivity_entity_id:
vad_sensitivity_state = self.hass.states.get(vad_sensitivity_entity_id)
if vad_sensitivity_state is not None:
vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state)
device_id: str | None = None
if self.registry_entry is not None:
device_id = self.registry_entry.device_id
# Refresh context if necessary
if (
(self._context is None)
or (self._context_set is None)
or ((time.time() - self._context_set) > entity.CONTEXT_RECENT_TIME_SECONDS)
):
self.async_set_context(Context())
assert self._context is not None
# Reset conversation id if necessary
if (self._conversation_id_time is None) or (
(time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
):
self._conversation_id = None
if self._conversation_id is None:
self._conversation_id = ulid.ulid()
# Update timeout
self._conversation_id_time = time.monotonic()
# Set entity state based on pipeline events
self._run_has_tts = False
await async_pipeline_from_audio_stream(
self.hass,
context=self._context,
event_callback=self.on_pipeline_event,
stt_metadata=stt.SpeechMetadata(
language="", # set in async_pipeline_from_audio_stream
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_stream,
pipeline_id=pipeline_id,
conversation_id=self._conversation_id,
device_id=device_id,
tts_audio_output="wav",
wake_word_phrase=wake_word_phrase,
tts_input=tts_input,
audio_settings=AudioSettings(
silence_seconds=vad.VadSensitivity.to_seconds(vad_sensitivity)
),
start_stage=start_stage,
end_stage=end_stage,
)
def on_pipeline_event(self, event: PipelineEvent) -> None:
"""Set state based on pipeline stage."""
if event.type == PipelineEventType.WAKE_WORD_START:
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
elif event.type == PipelineEventType.STT_START:
self._set_state(AssistSatelliteState.LISTENING_COMMAND)
elif event.type == PipelineEventType.INTENT_START:
self._set_state(AssistSatelliteState.PROCESSING)
elif event.type == PipelineEventType.TTS_START:
# Wait until tts_response_finished is called to return to waiting state
self._run_has_tts = True
self._set_state(AssistSatelliteState.RESPONDING)
elif event.type == PipelineEventType.RUN_END:
if not self._run_has_tts:
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
def _set_state(self, state: AssistSatelliteState):
"""Set the entity's state."""
self._attr_state = state
self.async_write_ha_state()
def tts_response_finished(self) -> None:
"""Tell entity that the text-to-speech response has finished playing."""
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)

View file

@ -0,0 +1,12 @@
{
"entity_component": {
"_": {
"default": "mdi:comment-processing-outline"
}
},
"services": {
"wait_wake": "mdi:microphone-message",
"get_command": "mdi:comment-text-outline",
"say_text": "mdi:speaker-message"
}
}

View file

@ -0,0 +1,9 @@
{
"domain": "assist_satellite",
"name": "Assist Satellite",
"codeowners": ["@synesthesiam"],
"config_flow": false,
"dependencies": ["assist_pipeline", "stt"],
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
"integration_type": "entity"
}

View file

@ -0,0 +1,49 @@
"""Models for assist satellite."""
from dataclasses import dataclass
from enum import IntFlag, StrEnum
class AssistSatelliteState(StrEnum):
"""Valid states of an Assist satellite entity."""
LISTENING_WAKE_WORD = "listening_wake_word"
"""Device is streaming audio for wake word detection to Home Assistant."""
LISTENING_COMMAND = "listening_command"
"""Device is streaming audio with the voice command to Home Assistant."""
PROCESSING = "processing"
"""Home Assistant is processing the voice command."""
RESPONDING = "responding"
"""Device is speaking the response."""
class AssistSatelliteEntityFeature(IntFlag):
"""Supported features of Assist satellite entity."""
TRIGGER_PIPELINE = 1
"""Device supports remote triggering of a pipeline."""
@dataclass(frozen=True)
class PipelineRunConfig:
"""Configuration for a satellite pipeline run."""
wake_word_names: list[str] | None = None
"""Wake word names to listen for (start_stage = wake)."""
announce_text: str | None = None
"""Text to announce using text-to-speech (start_stage = wake, stt, or tts)."""
@dataclass(frozen=True)
class PipelineRunResult:
"""Result of a pipeline run."""
detected_wake_word: str | None = None
"""Name of detected wake word (None if timeout)."""
command_text: str | None = None
"""Transcript of speech-to-text for voice command."""

View file

@ -0,0 +1,46 @@
wait_wake:
target:
entity:
domain: assist_satellite
supported_features:
- assist_satellite.AssistSatelliteEntityFeature.TRIGGER_PIPELINE
fields:
wake_words:
required: true
example: "ok nabu"
selector:
text:
multiple: true
announce_text:
required: false
example: "Please say ok nabu."
selector:
text:
get_command:
target:
entity:
domain: assist_satellite
supported_features:
- assist_satellite.AssistSatelliteEntityFeature.TRIGGER_PIPELINE
fields:
process:
required: false
selector:
boolean:
announce_text:
required: false
example: "What would you like for dinner?"
selector:
text:
say_text:
target:
entity:
domain: assist_satellite
supported_features:
- assist_satellite.AssistSatelliteEntityFeature.TRIGGER_PIPELINE
fields:
announce_text:
required: true
example: "Dinner is ready!"
selector:
text:

View file

@ -0,0 +1,54 @@
{
"entity": {
"assist_satellite": {
"assist_satellite": {
"state": {
"listening_wake_word": "Wake word",
"listening_command": "Voice command",
"responding": "Responding",
"processing": "Processing"
}
}
}
},
"services": {
"wait_wake": {
"name": "Wait for wake words",
"description": "Wait for one or more wake words to be spoken",
"fields": {
"wake_words": {
"name": "Wake words",
"description": "Names of wake words to wait for"
},
"announce_text": {
"name": "Announce text",
"description": "Text to speak before waiting for wake words"
}
}
},
"get_command": {
"name": "Get voice command from satellite",
"description": "Records and transcribes a command from a voice satellite",
"fields": {
"process": {
"name": "Process command",
"description": "Process the text of the command in Home Assistant"
},
"announce_text": {
"name": "announce_text",
"description": "Text to speak before recording command"
}
}
},
"say_text": {
"name": "Say text",
"description": "Speak text from a voice satellite",
"fields": {
"announce_text": {
"name": "Announce text",
"description": "Text to speak"
}
}
}
}
}

View file

@ -20,6 +20,7 @@ from .devices import VoIPDevices
from .voip import HassVoipDatagramProtocol
PLATFORMS = (
Platform.ASSIST_SATELLITE,
Platform.BINARY_SENSOR,
Platform.SELECT,
Platform.SWITCH,

View file

@ -0,0 +1,298 @@
"""Assist satellite entity for VoIP integration."""
from __future__ import annotations
import asyncio
from enum import IntFlag
from functools import partial
import io
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Final
import wave
from voip_utils import RtpDatagramProtocol
from homeassistant.components import tts
from homeassistant.components.assist_pipeline import (
PipelineEvent,
PipelineEventType,
PipelineNotFound,
)
from homeassistant.components.assist_satellite import AssistSatelliteEntity
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util.async_ import queue_to_iterable
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
from .devices import VoIPDevice
from .entity import VoIPEntity
if TYPE_CHECKING:
from . import DomainData
_LOGGER = logging.getLogger(__name__)
_PIPELINE_TIMEOUT_SEC: Final = 30
class Tones(IntFlag):
"""Feedback tones for specific events."""
LISTENING = 1
PROCESSING = 2
ERROR = 4
_TONE_FILENAMES: dict[Tones, str] = {
Tones.LISTENING: "tone.pcm",
Tones.PROCESSING: "processing.pcm",
Tones.ERROR: "error.pcm",
}
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up VoIP Assist satellite entity."""
domain_data: DomainData = hass.data[DOMAIN]
@callback
def async_add_device(device: VoIPDevice) -> None:
"""Add device."""
async_add_entities([VoipAssistSatellite(hass, device, config_entry)])
domain_data.devices.async_add_new_device_listener(async_add_device)
entities: list[VoIPEntity] = [
VoipAssistSatellite(hass, device, config_entry)
for device in domain_data.devices
]
async_add_entities(entities)
class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol):
"""Assist satellite for VoIP devices."""
def __init__(
self,
hass: HomeAssistant,
voip_device: VoIPDevice,
config_entry: ConfigEntry,
tones=Tones.LISTENING | Tones.PROCESSING | Tones.ERROR,
) -> None:
"""Initialize an Assist satellite."""
VoIPEntity.__init__(self, voip_device)
AssistSatelliteEntity.__init__(self)
RtpDatagramProtocol.__init__(self)
self.config_entry = config_entry
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
self._audio_chunk_timeout: float = 2.0
self._pipeline_task: asyncio.Task | None = None
self._pipeline_had_error: bool = False
self._tts_done = asyncio.Event()
self._tts_extra_timeout: float = 1.0
self._tone_bytes: dict[Tones, bytes] = {}
self._tones = tones
self._processing_tone_done = asyncio.Event()
async def async_added_to_hass(self) -> None:
"""Run when entity about to be added to hass."""
self.voip_device.protocol = self
async def async_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass."""
if self.voip_device.protocol == self:
self.voip_device.protocol = None
# -------------------------------------------------------------------------
# VoIP
# -------------------------------------------------------------------------
def on_chunk(self, audio_bytes: bytes) -> None:
"""Handle raw audio chunk."""
if self._pipeline_task is None:
self._clear_audio_queue()
# Run pipeline until voice command finishes, then start over
self._pipeline_task = self.config_entry.async_create_background_task(
self.hass,
self._run_pipeline(),
"voip_pipeline_run",
)
self._audio_queue.put_nowait(audio_bytes)
async def _run_pipeline(
self,
) -> None:
"""Forward audio to pipeline STT and handle TTS."""
self.async_set_context(Context(user_id=self.config_entry.data["user"]))
self.voip_device.set_is_active(True)
# Play listening tone at the start of each cycle
await self._play_tone(Tones.LISTENING, silence_before=0.2)
try:
self._tts_done.clear()
# Run pipeline with a timeout
_LOGGER.debug("Starting pipeline")
async with asyncio.timeout(_PIPELINE_TIMEOUT_SEC):
await self._async_accept_pipeline_from_satellite( # noqa: SLF001
audio_stream=queue_to_iterable(
self._audio_queue, timeout=self._audio_chunk_timeout
),
pipeline_entity_id=self.voip_device.get_pipeline_entity_id(
self.hass
),
vad_sensitivity_entity_id=self.voip_device.get_vad_sensitivity_entity_id(
self.hass
),
)
if self._pipeline_had_error:
self._pipeline_had_error = False
await self._play_tone(Tones.ERROR)
else:
# Block until TTS is done speaking.
#
# This is set in _send_tts and has a timeout that's based on the
# length of the TTS audio.
await self._tts_done.wait()
_LOGGER.debug("Pipeline finished")
except PipelineNotFound:
_LOGGER.warning("Pipeline not found")
except (asyncio.CancelledError, TimeoutError):
# Expected after caller hangs up
_LOGGER.debug("Pipeline cancelled or timed out")
self.disconnect()
self._clear_audio_queue()
finally:
self.voip_device.set_is_active(False)
# Allow pipeline to run again
self._pipeline_task = None
def _clear_audio_queue(self) -> None:
"""Ensure audio queue is empty."""
while not self._audio_queue.empty():
self._audio_queue.get_nowait()
def on_pipeline_event(self, event: PipelineEvent) -> None:
"""Set state based on pipeline stage."""
super().on_pipeline_event(event)
if event.type == PipelineEventType.STT_END:
if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
self._processing_tone_done.clear()
self.config_entry.async_create_background_task(
self.hass, self._play_tone(Tones.PROCESSING), "voip_process_tone"
)
elif event.type == PipelineEventType.TTS_END:
# Send TTS audio to caller over RTP
if event.data and (tts_output := event.data["tts_output"]):
media_id = tts_output["media_id"]
self.config_entry.async_create_background_task(
self.hass,
self._send_tts(media_id),
"voip_pipeline_tts",
)
else:
# Empty TTS response
self._tts_done.set()
elif event.type == PipelineEventType.ERROR:
# Play error tone instead of wait for TTS when pipeline is finished.
self._pipeline_had_error = True
async def _send_tts(self, media_id: str) -> None:
"""Send TTS audio to caller via RTP."""
try:
if self.transport is None:
return # not connected
extension, data = await tts.async_get_media_source_audio(
self.hass,
media_id,
)
if extension != "wav":
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
# Don't overlap TTS and processing beep
await self._processing_tone_done.wait()
with io.BytesIO(data) as wav_io:
with wave.open(wav_io, "rb") as wav_file:
sample_rate = wav_file.getframerate()
sample_width = wav_file.getsampwidth()
sample_channels = wav_file.getnchannels()
if (
(sample_rate != RATE)
or (sample_width != WIDTH)
or (sample_channels != CHANNELS)
):
raise ValueError(
f"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS},"
f" got {sample_rate}/{sample_width}/{sample_channels}"
)
audio_bytes = wav_file.readframes(wav_file.getnframes())
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
# Time out 1 second after TTS audio should be finished
tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
tts_seconds = tts_samples / RATE
async with asyncio.timeout(tts_seconds + self._tts_extra_timeout):
# TTS audio is 16Khz 16-bit mono
await self._async_send_audio(audio_bytes)
except TimeoutError:
_LOGGER.warning("TTS timeout")
raise
finally:
# Signal pipeline to restart
self._tts_done.set()
# Update satellite state
self.tts_response_finished()
async def _async_send_audio(self, audio_bytes: bytes, **kwargs):
"""Send audio in executor."""
await self.hass.async_add_executor_job(
partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS, **kwargs)
)
async def _play_tone(self, tone: Tones, silence_before: float = 0.0) -> None:
"""Play a tone as feedback to the user if it's enabled."""
if (self._tones & tone) != tone:
return # not enabled
if tone not in self._tone_bytes:
# Do I/O in executor
self._tone_bytes[tone] = await self.hass.async_add_executor_job(
self._load_pcm,
_TONE_FILENAMES[tone],
)
await self._async_send_audio(
self._tone_bytes[tone],
silence_before=silence_before,
)
if tone == Tones.PROCESSING:
self._processing_tone_done.set()
def _load_pcm(self, file_name: str) -> bytes:
"""Load raw audio (16Khz, 16-bit mono)."""
return (Path(__file__).parent / file_name).read_bytes()

View file

@ -51,10 +51,12 @@ class VoIPCallInProgress(VoIPEntity, BinarySensorEntity):
"""Call when entity about to be added to hass."""
await super().async_added_to_hass()
self.async_on_remove(self._device.async_listen_update(self._is_active_changed))
self.async_on_remove(
self.voip_device.async_listen_update(self._is_active_changed)
)
@callback
def _is_active_changed(self, device: VoIPDevice) -> None:
"""Call when active state changed."""
self._attr_is_on = self._device.is_active
self._attr_is_on = self.voip_device.is_active
self.async_write_ha_state()

View file

@ -5,7 +5,7 @@ from __future__ import annotations
from collections.abc import Callable, Iterator
from dataclasses import dataclass, field
from voip_utils import CallInfo
from voip_utils import CallInfo, VoipDatagramProtocol
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import Event, HomeAssistant, callback
@ -22,6 +22,7 @@ class VoIPDevice:
device_id: str
is_active: bool = False
update_listeners: list[Callable[[VoIPDevice], None]] = field(default_factory=list)
protocol: VoipDatagramProtocol | None = None
@callback
def set_is_active(self, active: bool) -> None:
@ -56,6 +57,18 @@ class VoIPDevice:
return False
def get_pipeline_entity_id(self, hass: HomeAssistant) -> str | None:
"""Return entity id for pipeline select."""
ent_reg = er.async_get(hass)
return ent_reg.async_get_entity_id("select", DOMAIN, f"{self.voip_id}-pipeline")
def get_vad_sensitivity_entity_id(self, hass: HomeAssistant) -> str | None:
"""Return entity id for VAD sensitivity."""
ent_reg = er.async_get(hass)
return ent_reg.async_get_entity_id(
"select", DOMAIN, f"{self.voip_id}-vad_sensitivity"
)
class VoIPDevices:
"""Class to store devices."""

View file

@ -15,10 +15,10 @@ class VoIPEntity(entity.Entity):
_attr_has_entity_name = True
_attr_should_poll = False
def __init__(self, device: VoIPDevice) -> None:
def __init__(self, voip_device: VoIPDevice) -> None:
"""Initialize VoIP entity."""
self._device = device
self._attr_unique_id = f"{device.voip_id}-{self.entity_description.key}"
self.voip_device = voip_device
self._attr_unique_id = f"{voip_device.voip_id}-{self.entity_description.key}"
self._attr_device_info = DeviceInfo(
identifiers={(DOMAIN, device.voip_id)},
identifiers={(DOMAIN, voip_device.voip_id)},
)

View file

@ -3,7 +3,7 @@
"name": "Voice over IP",
"codeowners": ["@balloob", "@synesthesiam"],
"config_flow": true,
"dependencies": ["assist_pipeline"],
"dependencies": ["assist_pipeline", "assist_satellite"],
"documentation": "https://www.home-assistant.io/integrations/voip",
"iot_class": "local_push",
"quality_scale": "internal",

View file

@ -10,6 +10,16 @@
}
},
"entity": {
"assist_satellite": {
"assist_satellite": {
"state": {
"listening_wake_word": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::listening_wake_word%]",
"listening_command": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::listening_command%]",
"responding": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::responding%]",
"processing": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::processing%]"
}
}
},
"binary_sensor": {
"call_in_progress": {
"name": "Call in progress"

View file

@ -3,15 +3,11 @@
from __future__ import annotations
import asyncio
from collections import deque
from collections.abc import AsyncIterable, MutableSequence, Sequence
from functools import partial
import io
import logging
from pathlib import Path
import time
from typing import TYPE_CHECKING
import wave
from voip_utils import (
CallInfo,
@ -21,33 +17,19 @@ from voip_utils import (
VoipDatagramProtocol,
)
from homeassistant.components import assist_pipeline, stt, tts
from homeassistant.components.assist_pipeline import (
Pipeline,
PipelineEvent,
PipelineEventType,
PipelineNotFound,
async_get_pipeline,
async_pipeline_from_audio_stream,
select as pipeline_select,
)
from homeassistant.components.assist_pipeline.audio_enhancer import (
AudioEnhancer,
MicroVadEnhancer,
)
from homeassistant.components.assist_pipeline.vad import (
AudioBuffer,
VadSensitivity,
VoiceCommandSegmenter,
)
from homeassistant.const import __version__
from homeassistant.core import Context, HomeAssistant
from homeassistant.util.ulid import ulid_now
from homeassistant.core import HomeAssistant
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
if TYPE_CHECKING:
from .devices import VoIPDevice, VoIPDevices
from .devices import VoIPDevices
_LOGGER = logging.getLogger(__name__)
@ -60,11 +42,8 @@ def make_protocol(
) -> VoipDatagramProtocol:
"""Plays a pre-recorded message if pipeline is misconfigured."""
voip_device = devices.async_get_or_create(call_info)
pipeline_id = pipeline_select.get_chosen_pipeline(
hass,
DOMAIN,
voip_device.voip_id,
)
pipeline_id = pipeline_select.get_chosen_pipeline(hass, DOMAIN, voip_device.voip_id)
try:
pipeline: Pipeline | None = async_get_pipeline(hass, pipeline_id)
except PipelineNotFound:
@ -83,22 +62,18 @@ def make_protocol(
rtcp_state=rtcp_state,
)
vad_sensitivity = pipeline_select.get_vad_sensitivity(
hass,
DOMAIN,
voip_device.voip_id,
)
if (protocol := voip_device.protocol) is None:
raise ValueError("VoIP satellite not found")
# Pipeline is properly configured
return PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(user_id=devices.config_entry.data["user"]),
opus_payload_type=call_info.opus_payload_type,
silence_seconds=VadSensitivity.to_seconds(vad_sensitivity),
rtcp_state=rtcp_state,
)
protocol._rtp_input.opus_payload_type = call_info.opus_payload_type # noqa: SLF001
protocol._rtp_output.opus_payload_type = call_info.opus_payload_type # noqa: SLF001
protocol.rtcp_state = rtcp_state
if protocol.rtcp_state is not None:
# Automatically disconnect when BYE is received over RTCP
protocol.rtcp_state.bye_callback = protocol.disconnect
return protocol
class HassVoipDatagramProtocol(VoipDatagramProtocol):
@ -143,372 +118,6 @@ class HassVoipDatagramProtocol(VoipDatagramProtocol):
await self._closed_event.wait()
class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
"""Run a voice assistant pipeline in a loop for a VoIP call."""
def __init__(
self,
hass: HomeAssistant,
language: str,
voip_device: VoIPDevice,
context: Context,
opus_payload_type: int,
pipeline_timeout: float = 30.0,
audio_timeout: float = 2.0,
buffered_chunks_before_speech: int = 100,
listening_tone_enabled: bool = True,
processing_tone_enabled: bool = True,
error_tone_enabled: bool = True,
tone_delay: float = 0.2,
tts_extra_timeout: float = 1.0,
silence_seconds: float = 1.0,
rtcp_state: RtcpState | None = None,
) -> None:
"""Set up pipeline RTP server."""
super().__init__(
rate=RATE,
width=WIDTH,
channels=CHANNELS,
opus_payload_type=opus_payload_type,
rtcp_state=rtcp_state,
)
self.hass = hass
self.language = language
self.voip_device = voip_device
self.pipeline: Pipeline | None = None
self.pipeline_timeout = pipeline_timeout
self.audio_timeout = audio_timeout
self.buffered_chunks_before_speech = buffered_chunks_before_speech
self.listening_tone_enabled = listening_tone_enabled
self.processing_tone_enabled = processing_tone_enabled
self.error_tone_enabled = error_tone_enabled
self.tone_delay = tone_delay
self.tts_extra_timeout = tts_extra_timeout
self.silence_seconds = silence_seconds
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
self._context = context
self._conversation_id: str | None = None
self._pipeline_task: asyncio.Task | None = None
self._tts_done = asyncio.Event()
self._session_id: str | None = None
self._tone_bytes: bytes | None = None
self._processing_bytes: bytes | None = None
self._error_bytes: bytes | None = None
self._pipeline_error: bool = False
def connection_made(self, transport):
"""Server is ready."""
super().connection_made(transport)
self.voip_device.set_is_active(True)
def connection_lost(self, exc):
"""Handle connection is lost or closed."""
super().connection_lost(exc)
self.voip_device.set_is_active(False)
def on_chunk(self, audio_bytes: bytes) -> None:
"""Handle raw audio chunk."""
if self._pipeline_task is None:
self._clear_audio_queue()
# Run pipeline until voice command finishes, then start over
self._pipeline_task = self.hass.async_create_background_task(
self._run_pipeline(),
"voip_pipeline_run",
)
self._audio_queue.put_nowait(audio_bytes)
async def _run_pipeline(
self,
) -> None:
"""Forward audio to pipeline STT and handle TTS."""
if self._session_id is None:
self._session_id = ulid_now()
# Play listening tone at the start of each cycle
if self.listening_tone_enabled:
await self._play_listening_tone()
try:
# Wait for speech before starting pipeline
segmenter = VoiceCommandSegmenter(silence_seconds=self.silence_seconds)
audio_enhancer = MicroVadEnhancer(0, 0, True)
chunk_buffer: deque[bytes] = deque(
maxlen=self.buffered_chunks_before_speech,
)
speech_detected = await self._wait_for_speech(
segmenter,
audio_enhancer,
chunk_buffer,
)
if not speech_detected:
_LOGGER.debug("No speech detected")
return
_LOGGER.debug("Starting pipeline")
self._tts_done.clear()
async def stt_stream():
try:
async for chunk in self._segment_audio(
segmenter,
audio_enhancer,
chunk_buffer,
):
yield chunk
if self.processing_tone_enabled:
await self._play_processing_tone()
except TimeoutError:
# Expected after caller hangs up
_LOGGER.debug("Audio timeout")
self._session_id = None
self.disconnect()
finally:
self._clear_audio_queue()
# Run pipeline with a timeout
async with asyncio.timeout(self.pipeline_timeout):
await async_pipeline_from_audio_stream(
self.hass,
context=self._context,
event_callback=self._event_callback,
stt_metadata=stt.SpeechMetadata(
language="", # set in async_pipeline_from_audio_stream
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=stt_stream(),
pipeline_id=pipeline_select.get_chosen_pipeline(
self.hass, DOMAIN, self.voip_device.voip_id
),
conversation_id=self._conversation_id,
device_id=self.voip_device.device_id,
tts_audio_output="wav",
)
if self._pipeline_error:
self._pipeline_error = False
if self.error_tone_enabled:
await self._play_error_tone()
else:
# Block until TTS is done speaking.
#
# This is set in _send_tts and has a timeout that's based on the
# length of the TTS audio.
await self._tts_done.wait()
_LOGGER.debug("Pipeline finished")
except PipelineNotFound:
_LOGGER.warning("Pipeline not found")
except TimeoutError:
# Expected after caller hangs up
_LOGGER.debug("Pipeline timeout")
self._session_id = None
self.disconnect()
finally:
# Allow pipeline to run again
self._pipeline_task = None
async def _wait_for_speech(
self,
segmenter: VoiceCommandSegmenter,
audio_enhancer: AudioEnhancer,
chunk_buffer: MutableSequence[bytes],
):
"""Buffer audio chunks until speech is detected.
Returns True if speech was detected, False otherwise.
"""
# Timeout if no audio comes in for a while.
# This means the caller hung up.
async with asyncio.timeout(self.audio_timeout):
chunk = await self._audio_queue.get()
vad_buffer = AudioBuffer(assist_pipeline.SAMPLES_PER_CHUNK * WIDTH)
while chunk:
chunk_buffer.append(chunk)
segmenter.process_with_vad(
chunk,
assist_pipeline.SAMPLES_PER_CHUNK,
lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
vad_buffer,
)
if segmenter.in_command:
# Buffer until command starts
if len(vad_buffer) > 0:
chunk_buffer.append(vad_buffer.bytes())
return True
async with asyncio.timeout(self.audio_timeout):
chunk = await self._audio_queue.get()
return False
async def _segment_audio(
self,
segmenter: VoiceCommandSegmenter,
audio_enhancer: AudioEnhancer,
chunk_buffer: Sequence[bytes],
) -> AsyncIterable[bytes]:
"""Yield audio chunks until voice command has finished."""
# Buffered chunks first
for buffered_chunk in chunk_buffer:
yield buffered_chunk
# Timeout if no audio comes in for a while.
# This means the caller hung up.
async with asyncio.timeout(self.audio_timeout):
chunk = await self._audio_queue.get()
vad_buffer = AudioBuffer(assist_pipeline.SAMPLES_PER_CHUNK * WIDTH)
while chunk:
if not segmenter.process_with_vad(
chunk,
assist_pipeline.SAMPLES_PER_CHUNK,
lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
vad_buffer,
):
# Voice command is finished
break
yield chunk
async with asyncio.timeout(self.audio_timeout):
chunk = await self._audio_queue.get()
def _clear_audio_queue(self) -> None:
while not self._audio_queue.empty():
self._audio_queue.get_nowait()
def _event_callback(self, event: PipelineEvent):
if not event.data:
return
if event.type == PipelineEventType.INTENT_END:
# Capture conversation id
self._conversation_id = event.data["intent_output"]["conversation_id"]
elif event.type == PipelineEventType.TTS_END:
# Send TTS audio to caller over RTP
tts_output = event.data["tts_output"]
if tts_output:
media_id = tts_output["media_id"]
self.hass.async_create_background_task(
self._send_tts(media_id),
"voip_pipeline_tts",
)
else:
# Empty TTS response
self._tts_done.set()
elif event.type == PipelineEventType.ERROR:
# Play error tone instead of wait for TTS
self._pipeline_error = True
async def _send_tts(self, media_id: str) -> None:
"""Send TTS audio to caller via RTP."""
try:
if self.transport is None:
return
extension, data = await tts.async_get_media_source_audio(
self.hass,
media_id,
)
if extension != "wav":
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
with io.BytesIO(data) as wav_io:
with wave.open(wav_io, "rb") as wav_file:
sample_rate = wav_file.getframerate()
sample_width = wav_file.getsampwidth()
sample_channels = wav_file.getnchannels()
if (
(sample_rate != RATE)
or (sample_width != WIDTH)
or (sample_channels != CHANNELS)
):
raise ValueError(
f"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS},"
f" got {sample_rate}/{sample_width}/{sample_channels}"
)
audio_bytes = wav_file.readframes(wav_file.getnframes())
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
# Time out 1 second after TTS audio should be finished
tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
tts_seconds = tts_samples / RATE
async with asyncio.timeout(tts_seconds + self.tts_extra_timeout):
# TTS audio is 16Khz 16-bit mono
await self._async_send_audio(audio_bytes)
except TimeoutError:
_LOGGER.warning("TTS timeout")
raise
finally:
# Signal pipeline to restart
self._tts_done.set()
async def _async_send_audio(self, audio_bytes: bytes, **kwargs):
"""Send audio in executor."""
await self.hass.async_add_executor_job(
partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS, **kwargs)
)
async def _play_listening_tone(self) -> None:
"""Play a tone to indicate that Home Assistant is listening."""
if self._tone_bytes is None:
# Do I/O in executor
self._tone_bytes = await self.hass.async_add_executor_job(
self._load_pcm,
"tone.pcm",
)
await self._async_send_audio(
self._tone_bytes,
silence_before=self.tone_delay,
)
async def _play_processing_tone(self) -> None:
"""Play a tone to indicate that Home Assistant is processing the voice command."""
if self._processing_bytes is None:
# Do I/O in executor
self._processing_bytes = await self.hass.async_add_executor_job(
self._load_pcm,
"processing.pcm",
)
await self._async_send_audio(self._processing_bytes)
async def _play_error_tone(self) -> None:
"""Play a tone to indicate a pipeline error occurred."""
if self._error_bytes is None:
# Do I/O in executor
self._error_bytes = await self.hass.async_add_executor_job(
self._load_pcm,
"error.pcm",
)
await self._async_send_audio(self._error_bytes)
def _load_pcm(self, file_name: str) -> bytes:
"""Load raw audio (16Khz, 16-bit mono)."""
return (Path(__file__).parent / file_name).read_bytes()
class PreRecordMessageProtocol(RtpDatagramProtocol):
"""Plays a pre-recorded message on a loop."""

View file

@ -14,11 +14,11 @@ from .const import ATTR_SPEAKER, DOMAIN
from .data import WyomingService
from .devices import SatelliteDevice
from .models import DomainDataItem
from .satellite import WyomingSatellite
_LOGGER = logging.getLogger(__name__)
SATELLITE_PLATFORMS = [
Platform.ASSIST_SATELLITE,
Platform.BINARY_SENSOR,
Platform.SELECT,
Platform.SWITCH,
@ -47,51 +47,25 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
entry.async_on_unload(entry.add_update_listener(update_listener))
if (satellite_info := service.info.satellite) is not None:
# Create satellite device, etc.
item.satellite = _make_satellite(hass, entry, service)
# Create satellite device
dev_reg = dr.async_get(hass)
# Set up satellite sensors, switches, etc.
# Use config entry id since only one satellite per entry is supported
satellite_id = entry.entry_id
device = dev_reg.async_get_or_create(
config_entry_id=entry.entry_id,
identifiers={(DOMAIN, satellite_id)},
name=satellite_info.name,
suggested_area=satellite_info.area,
)
item.satellite_device = SatelliteDevice(satellite_id, device.id)
# Set up satellite entity, sensors, switches, etc.
await hass.config_entries.async_forward_entry_setups(entry, SATELLITE_PLATFORMS)
# Start satellite communication
entry.async_create_background_task(
hass,
item.satellite.run(),
f"Satellite {satellite_info.name}",
)
entry.async_on_unload(item.satellite.stop)
return True
def _make_satellite(
hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService
) -> WyomingSatellite:
"""Create Wyoming satellite/device from config entry and Wyoming service."""
satellite_info = service.info.satellite
assert satellite_info is not None
dev_reg = dr.async_get(hass)
# Use config entry id since only one satellite per entry is supported
satellite_id = config_entry.entry_id
device = dev_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={(DOMAIN, satellite_id)},
name=satellite_info.name,
suggested_area=satellite_info.area,
)
satellite_device = SatelliteDevice(
satellite_id=satellite_id,
device_id=device.id,
)
return WyomingSatellite(hass, config_entry, service, satellite_device)
async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
"""Handle options update."""
await hass.config_entries.async_reload(entry.entry_id)
@ -102,7 +76,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
item: DomainDataItem = hass.data[DOMAIN][entry.entry_id]
platforms = list(item.service.platforms)
if item.satellite is not None:
if item.satellite_device is not None:
platforms += SATELLITE_PLATFORMS
unload_ok = await hass.config_entries.async_unload_platforms(entry, platforms)

View file

@ -1,12 +1,11 @@
"""Support for Wyoming satellite services."""
import asyncio
from collections.abc import AsyncGenerator
from collections import defaultdict, deque
import io
import logging
import time
from typing import Final
from uuid import uuid4
import wave
from wyoming.asr import Transcribe, Transcript
@ -18,20 +17,23 @@ from wyoming.info import Describe, Info
from wyoming.ping import Ping, Pong
from wyoming.pipeline import PipelineStage, RunPipeline
from wyoming.satellite import PauseSatellite, RunSatellite
from wyoming.snd import Played
from wyoming.timer import TimerCancelled, TimerFinished, TimerStarted, TimerUpdated
from wyoming.tts import Synthesize, SynthesizeVoice
from wyoming.vad import VoiceStarted, VoiceStopped
from wyoming.wake import Detect, Detection
from homeassistant.components import assist_pipeline, intent, stt, tts
from homeassistant.components.assist_pipeline import select as pipeline_select
from homeassistant.components.assist_pipeline.vad import VadSensitivity
from homeassistant.components import assist_pipeline, assist_satellite, intent, tts
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util.async_ import queue_to_iterable
from .const import DOMAIN
from .data import WyomingService
from .devices import SatelliteDevice
from .entity import WyomingEntity
from .models import DomainDataItem
_LOGGER = logging.getLogger(__name__)
@ -41,19 +43,47 @@ _RESTART_SECONDS: Final = 3
_PING_TIMEOUT: Final = 5
_PING_SEND_DELAY: Final = 2
_PIPELINE_FINISH_TIMEOUT: Final = 1
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
_STOP_CHUNK: Final = b""
# Wyoming stage -> Assist stage
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
_ASSIST_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
PipelineStage.WAKE: assist_pipeline.PipelineStage.WAKE_WORD,
PipelineStage.ASR: assist_pipeline.PipelineStage.STT,
PipelineStage.HANDLE: assist_pipeline.PipelineStage.INTENT,
PipelineStage.TTS: assist_pipeline.PipelineStage.TTS,
}
_WYOMING_STAGES: dict[assist_pipeline.PipelineStage, PipelineStage] = {
assist_pipeline.PipelineStage.WAKE_WORD: PipelineStage.WAKE,
assist_pipeline.PipelineStage.STT: PipelineStage.ASR,
assist_pipeline.PipelineStage.INTENT: PipelineStage.HANDLE,
assist_pipeline.PipelineStage.TTS: PipelineStage.TTS,
}
class WyomingSatellite:
"""Remove voice satellite running the Wyoming protocol."""
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up VoIP Assist satellite entity."""
domain_data: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
assert domain_data.satellite_device is not None
async_add_entities(
[
WyomingSatellite(
hass, config_entry, domain_data.service, domain_data.satellite_device
)
]
)
class WyomingSatellite(WyomingEntity, assist_satellite.AssistSatelliteEntity):
"""Remote voice satellite running the Wyoming protocol."""
_attr_supported_features = (
assist_satellite.AssistSatelliteEntityFeature.TRIGGER_PIPELINE
)
def __init__(
self,
@ -63,6 +93,9 @@ class WyomingSatellite:
device: SatelliteDevice,
) -> None:
"""Initialize satellite."""
WyomingEntity.__init__(self, device)
assist_satellite.AssistSatelliteEntity.__init__(self)
self.hass = hass
self.config_entry = config_entry
self.service = service
@ -70,20 +103,171 @@ class WyomingSatellite:
self.is_running = True
self._client: AsyncTcpClient | None = None
self._chunk_converter = AudioChunkConverter(rate=16000, width=2, channels=1)
self._chunk_converter = AudioChunkConverter(
rate=assist_pipeline.SAMPLE_RATE,
width=assist_pipeline.SAMPLE_WIDTH,
channels=assist_pipeline.SAMPLE_CHANNELS,
)
self._is_pipeline_running = False
self._pipeline_ended_event = asyncio.Event()
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
self._pipeline_id: str | None = None
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
self._muted_changed_event = asyncio.Event()
self._conversation_id: str | None = None
self._conversation_id_time: float | None = None
# Results of remotely triggered pipelines
self._pipeline_result_futures: dict[
assist_pipeline.PipelineStage, deque[asyncio.Future]
] = defaultdict(deque)
self._played_timeout_id: int | None = None
self.device.set_is_muted_listener(self._muted_changed)
self.device.set_pipeline_listener(self._pipeline_changed)
self.device.set_audio_settings_listener(self._audio_settings_changed)
async def async_added_to_hass(self) -> None:
"""Run when entity about to be added to hass."""
self.config_entry.async_create_background_task(
self.hass, self.run(), "wyoming_satellite_run"
)
async def async_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass."""
self.stop()
async def async_trigger_pipeline_on_satellite(
self,
start_stage: assist_pipeline.PipelineStage,
end_stage: assist_pipeline.PipelineStage,
run_config: assist_satellite.PipelineRunConfig,
) -> assist_satellite.PipelineRunResult | None:
"""Run a pipeline on the satellite from start to end stage."""
if self._client is None:
return None # not connected
result_future: asyncio.Future[str | None] = asyncio.Future()
self._pipeline_result_futures[start_stage].append(result_future)
await self._client.write_event(
RunPipeline(
start_stage=_WYOMING_STAGES[start_stage],
end_stage=_WYOMING_STAGES[end_stage],
wake_word_names=run_config.wake_word_names,
announce_text=run_config.announce_text,
).event()
)
# Wait for result
result = await result_future
if start_stage == assist_pipeline.PipelineStage.WAKE_WORD:
return assist_satellite.PipelineRunResult(detected_wake_word=result)
if start_stage == assist_pipeline.PipelineStage.STT:
return assist_satellite.PipelineRunResult(command_text=result)
return None
def on_pipeline_event(self, event: assist_pipeline.PipelineEvent) -> None:
"""Translate pipeline events into Wyoming events."""
super().on_pipeline_event(event)
if self._client is None:
return # stopping
if event.type == assist_pipeline.PipelineEventType.RUN_END:
# Pipeline run is complete
self._is_pipeline_running = False
self._pipeline_ended_event.set()
self.device.set_is_active(False)
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
self.hass.add_job(self._client.write_event(Detect().event()))
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END:
# Wake word detection
# Inform client of wake word detection
if event.data and (wake_word_output := event.data.get("wake_word_output")):
detected_wake_word = wake_word_output["wake_word_id"]
detection = Detection(
name=detected_wake_word,
timestamp=wake_word_output.get("timestamp"),
)
self.hass.add_job(self._client.write_event(detection.event()))
# Set result for remote pipeline trigger
if result_futures := self._pipeline_result_futures[
assist_pipeline.PipelineStage.WAKE_WORD
]:
result_futures.popleft().set_result(detected_wake_word)
elif event.type == assist_pipeline.PipelineEventType.STT_START:
# Speech-to-text
self.device.set_is_active(True)
if event.data:
self.hass.add_job(
self._client.write_event(
Transcribe(language=event.data["metadata"]["language"]).event()
)
)
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START:
# User started speaking
if event.data:
self.hass.add_job(
self._client.write_event(
VoiceStarted(timestamp=event.data["timestamp"]).event()
)
)
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END:
# User stopped speaking
if event.data:
self.hass.add_job(
self._client.write_event(
VoiceStopped(timestamp=event.data["timestamp"]).event()
)
)
elif event.type == assist_pipeline.PipelineEventType.STT_END:
# Speech-to-text transcript
if event.data:
# Inform client of transript
stt_text = event.data["stt_output"]["text"]
self.hass.add_job(
self._client.write_event(Transcript(text=stt_text).event())
)
# Set result for remote pipeline trigger
if result_futures := self._pipeline_result_futures[
assist_pipeline.PipelineStage.STT
]:
result_futures.popleft().set_result(stt_text)
elif event.type == assist_pipeline.PipelineEventType.TTS_START:
# Text-to-speech text
if event.data:
# Inform client of text
self.hass.add_job(
self._client.write_event(
Synthesize(
text=event.data["tts_input"],
voice=SynthesizeVoice(
name=event.data.get("voice"),
language=event.data.get("language"),
),
).event()
)
)
elif event.type == assist_pipeline.PipelineEventType.TTS_END:
# TTS stream
if event.data and (tts_output := event.data["tts_output"]):
media_id = tts_output["media_id"]
self.hass.add_job(self._stream_tts(media_id))
elif event.type == assist_pipeline.PipelineEventType.ERROR:
# Pipeline error
if event.data:
self.hass.add_job(
self._client.write_event(
Error(
text=event.data["message"], code=event.data["code"]
).event()
)
)
# -------------------------------------------------------------------------
async def run(self) -> None:
"""Run and maintain a connection to satellite."""
_LOGGER.debug("Running satellite task")
@ -125,6 +309,9 @@ class WyomingSatellite:
def stop(self) -> None:
"""Signal satellite task to stop running."""
# Cancel any running pipeline
self._audio_queue.put_nowait(_STOP_CHUNK)
# Tell satellite to stop running
self._send_pause()
@ -173,7 +360,7 @@ class WyomingSatellite:
"""Run when device muted status changes."""
if self.device.is_muted:
# Cancel any running pipeline
self._audio_queue.put_nowait(None)
self._audio_queue.put_nowait(_STOP_CHUNK)
# Send pause event so satellite can react immediately
self._send_pause()
@ -185,13 +372,13 @@ class WyomingSatellite:
"""Run when device pipeline changes."""
# Cancel any running pipeline
self._audio_queue.put_nowait(None)
self._audio_queue.put_nowait(_STOP_CHUNK)
def _audio_settings_changed(self) -> None:
"""Run when device audio settings."""
# Cancel any running pipeline
self._audio_queue.put_nowait(None)
self._audio_queue.put_nowait(_STOP_CHUNK)
async def _connect_and_loop(self) -> None:
"""Connect to satellite and run pipelines until an error occurs."""
@ -222,7 +409,9 @@ class WyomingSatellite:
async def _run_pipeline_loop(self) -> None:
"""Run a pipeline one or more times."""
assert self._client is not None
if self._client is None:
return # stopping
client_info: Info | None = None
wake_word_phrase: str | None = None
run_pipeline: RunPipeline | None = None
@ -302,7 +491,7 @@ class WyomingSatellite:
elif AudioStop.is_type(client_event.type) and self._is_pipeline_running:
# Stop pipeline
_LOGGER.debug("Client requested pipeline to stop")
self._audio_queue.put_nowait(b"")
self._audio_queue.put_nowait(_STOP_CHUNK)
elif Info.is_type(client_event.type):
client_info = Info.from_event(client_event)
_LOGGER.debug("Updated client info: %s", client_info)
@ -328,7 +517,19 @@ class WyomingSatellite:
if found_phrase:
break
if result_futures := self._pipeline_result_futures[
assist_pipeline.PipelineStage.WAKE_WORD
]:
result_futures.popleft().set_result(wake_word_phrase)
_LOGGER.debug("Client detected wake word: %s", wake_word_phrase)
elif Played.is_type(client_event.type):
# Set result for remote pipeline trigger
self._played_timeout_id = None
if result_futures := self._pipeline_result_futures[
assist_pipeline.PipelineStage.TTS
]:
result_futures.popleft().set_result(None)
else:
_LOGGER.debug("Unexpected event from satellite: %s", client_event)
@ -344,8 +545,8 @@ class WyomingSatellite:
"""Run a pipeline once."""
_LOGGER.debug("Received run information: %s", run_pipeline)
start_stage = _STAGES.get(run_pipeline.start_stage)
end_stage = _STAGES.get(run_pipeline.end_stage)
start_stage = _ASSIST_STAGES.get(run_pipeline.start_stage)
end_stage = _ASSIST_STAGES.get(run_pipeline.end_stage)
if start_stage is None:
raise ValueError(f"Invalid start stage: {start_stage}")
@ -353,77 +554,32 @@ class WyomingSatellite:
if end_stage is None:
raise ValueError(f"Invalid end stage: {end_stage}")
pipeline_id = pipeline_select.get_chosen_pipeline(
self.hass,
DOMAIN,
self.device.satellite_id,
)
pipeline = assist_pipeline.async_get_pipeline(self.hass, pipeline_id)
assert pipeline is not None
# We will push audio in through a queue
self._audio_queue = asyncio.Queue()
stt_stream = self._stt_stream()
# Start pipeline running
_LOGGER.debug(
"Starting pipeline %s from %s to %s",
pipeline.name,
start_stage,
end_stage,
)
# Reset conversation id, if necessary
if (self._conversation_id_time is None) or (
(time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
):
self._conversation_id = None
if self._conversation_id is None:
self._conversation_id = str(uuid4())
# Update timeout
self._conversation_id_time = time.monotonic()
self._is_pipeline_running = True
self._pipeline_ended_event.clear()
self.config_entry.async_create_background_task(
self.hass,
assist_pipeline.async_pipeline_from_audio_stream(
self.hass,
context=Context(),
event_callback=self._event_callback,
stt_metadata=stt.SpeechMetadata(
language=pipeline.language,
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
self._async_accept_pipeline_from_satellite(
queue_to_iterable(self._audio_queue),
start_stage,
end_stage,
pipeline_entity_id=self.device.get_pipeline_entity_id(self.hass),
vad_sensitivity_entity_id=self.device.get_vad_sensitivity_entity_id(
self.hass
),
stt_stream=stt_stream,
start_stage=start_stage,
end_stage=end_stage,
tts_audio_output="wav",
pipeline_id=pipeline_id,
audio_settings=assist_pipeline.AudioSettings(
noise_suppression_level=self.device.noise_suppression_level,
auto_gain_dbfs=self.device.auto_gain,
volume_multiplier=self.device.volume_multiplier,
silence_seconds=VadSensitivity.to_seconds(
self.device.vad_sensitivity
),
),
device_id=self.device.device_id,
wake_word_phrase=wake_word_phrase,
conversation_id=self._conversation_id,
tts_input=run_pipeline.announce_text,
),
name="wyoming satellite pipeline",
)
async def _send_delayed_ping(self) -> None:
"""Send ping to satellite after a delay."""
assert self._client is not None
if self._client is None:
return # stopping
try:
await asyncio.sleep(_PING_SEND_DELAY)
@ -431,91 +587,6 @@ class WyomingSatellite:
except ConnectionError:
pass # handled with timeout
def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None:
"""Translate pipeline events into Wyoming events."""
assert self._client is not None
if event.type == assist_pipeline.PipelineEventType.RUN_END:
# Pipeline run is complete
self._is_pipeline_running = False
self._pipeline_ended_event.set()
self.device.set_is_active(False)
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
self.hass.add_job(self._client.write_event(Detect().event()))
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END:
# Wake word detection
# Inform client of wake word detection
if event.data and (wake_word_output := event.data.get("wake_word_output")):
detection = Detection(
name=wake_word_output["wake_word_id"],
timestamp=wake_word_output.get("timestamp"),
)
self.hass.add_job(self._client.write_event(detection.event()))
elif event.type == assist_pipeline.PipelineEventType.STT_START:
# Speech-to-text
self.device.set_is_active(True)
if event.data:
self.hass.add_job(
self._client.write_event(
Transcribe(language=event.data["metadata"]["language"]).event()
)
)
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START:
# User started speaking
if event.data:
self.hass.add_job(
self._client.write_event(
VoiceStarted(timestamp=event.data["timestamp"]).event()
)
)
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END:
# User stopped speaking
if event.data:
self.hass.add_job(
self._client.write_event(
VoiceStopped(timestamp=event.data["timestamp"]).event()
)
)
elif event.type == assist_pipeline.PipelineEventType.STT_END:
# Speech-to-text transcript
if event.data:
# Inform client of transript
stt_text = event.data["stt_output"]["text"]
self.hass.add_job(
self._client.write_event(Transcript(text=stt_text).event())
)
elif event.type == assist_pipeline.PipelineEventType.TTS_START:
# Text-to-speech text
if event.data:
# Inform client of text
self.hass.add_job(
self._client.write_event(
Synthesize(
text=event.data["tts_input"],
voice=SynthesizeVoice(
name=event.data.get("voice"),
language=event.data.get("language"),
),
).event()
)
)
elif event.type == assist_pipeline.PipelineEventType.TTS_END:
# TTS stream
if event.data and (tts_output := event.data["tts_output"]):
media_id = tts_output["media_id"]
self.hass.add_job(self._stream_tts(media_id))
elif event.type == assist_pipeline.PipelineEventType.ERROR:
# Pipeline error
if event.data:
self.hass.add_job(
self._client.write_event(
Error(
text=event.data["message"], code=event.data["code"]
).event()
)
)
async def _connect(self) -> None:
"""Connect to satellite over TCP."""
await self._disconnect()
@ -537,62 +608,78 @@ class WyomingSatellite:
async def _stream_tts(self, media_id: str) -> None:
"""Stream TTS WAV audio to satellite in chunks."""
assert self._client is not None
extension, data = await tts.async_get_media_source_audio(self.hass, media_id)
if extension != "wav":
raise ValueError(f"Cannot stream audio format to satellite: {extension}")
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
sample_rate = wav_file.getframerate()
sample_width = wav_file.getsampwidth()
sample_channels = wav_file.getnchannels()
_LOGGER.debug("Streaming %s TTS sample(s)", wav_file.getnframes())
timestamp = 0
await self._client.write_event(
AudioStart(
rate=sample_rate,
width=sample_width,
channels=sample_channels,
timestamp=timestamp,
).event()
)
# Stream audio chunks
while audio_bytes := wav_file.readframes(_SAMPLES_PER_CHUNK):
chunk = AudioChunk(
rate=sample_rate,
width=sample_width,
channels=sample_channels,
audio=audio_bytes,
timestamp=timestamp,
)
await self._client.write_event(chunk.event())
timestamp += chunk.seconds
await self._client.write_event(AudioStop(timestamp=timestamp).event())
_LOGGER.debug("TTS streaming complete")
async def _stt_stream(self) -> AsyncGenerator[bytes]:
"""Yield audio chunks from a queue."""
try:
is_first_chunk = True
while chunk := await self._audio_queue.get():
if is_first_chunk:
is_first_chunk = False
_LOGGER.debug("Receiving audio from satellite")
if self._client is None:
return # stopping
yield chunk
except asyncio.CancelledError:
pass # ignore
extension, data = await tts.async_get_media_source_audio(
self.hass, media_id
)
if extension != "wav":
raise ValueError(
f"Cannot stream audio format to satellite: {extension}"
)
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
sample_rate = wav_file.getframerate()
sample_width = wav_file.getsampwidth()
sample_channels = wav_file.getnchannels()
num_frames = wav_file.getnframes()
_LOGGER.debug("Streaming %s TTS sample(s)", num_frames)
wav_seconds = num_frames / sample_rate
self._played_timeout_id = time.monotonic_ns()
self.config_entry.async_create_background_task(
self.hass,
self._tts_played_timeout(self._played_timeout_id, wav_seconds + 1),
"wyoming tts timeout",
)
timestamp = 0
await self._client.write_event(
AudioStart(
rate=sample_rate,
width=sample_width,
channels=sample_channels,
timestamp=timestamp,
).event()
)
# Stream audio chunks
while audio_bytes := wav_file.readframes(_SAMPLES_PER_CHUNK):
chunk = AudioChunk(
rate=sample_rate,
width=sample_width,
channels=sample_channels,
audio=audio_bytes,
timestamp=timestamp,
)
await self._client.write_event(chunk.event())
timestamp += chunk.seconds
await self._client.write_event(AudioStop(timestamp=timestamp).event())
_LOGGER.debug("TTS streaming complete")
finally:
self.tts_response_finished()
async def _tts_played_timeout(self, timeout_id: int, timeout_sec: float) -> None:
"""Set pipeline result after timeout if Played message is not received."""
await asyncio.sleep(timeout_sec)
if self._played_timeout_id != timeout_id:
return
if result_futures := self._pipeline_result_futures[
assist_pipeline.PipelineStage.TTS
]:
result_futures.popleft().set_result(None)
@callback
def _handle_timer(
self, event_type: intent.TimerEventType, timer: intent.TimerInfo
) -> None:
"""Forward timer events to satellite."""
assert self._client is not None
if self._client is None:
return # stopping
_LOGGER.debug("Timer event: type=%s, info=%s", event_type, timer)
event: Event | None = None

View file

@ -13,7 +13,7 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN
from .entity import WyomingSatelliteEntity
from .entity import WyomingEntity
if TYPE_CHECKING:
from .models import DomainDataItem
@ -28,12 +28,12 @@ async def async_setup_entry(
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
# Setup is only forwarded for satellites
assert item.satellite is not None
assert item.satellite_device is not None
async_add_entities([WyomingSatelliteAssistInProgress(item.satellite.device)])
async_add_entities([WyomingSatelliteAssistInProgress(item.satellite_device)])
class WyomingSatelliteAssistInProgress(WyomingSatelliteEntity, BinarySensorEntity):
class WyomingSatelliteAssistInProgress(WyomingEntity, BinarySensorEntity):
"""Entity to represent Assist is in progress for satellite."""
entity_description = BinarySensorEntityDescription(

View file

@ -157,3 +157,10 @@ class SatelliteDevice:
return ent_reg.async_get_entity_id(
"select", DOMAIN, f"{self.satellite_id}-vad_sensitivity"
)
def get_satellite_entity_id(self, hass: HomeAssistant) -> str | None:
"""Return entity id for satellite."""
ent_reg = er.async_get(hass)
return ent_reg.async_get_entity_id(
"assist_satellite", DOMAIN, f"{self.satellite_id}-assist_satellite"
)

View file

@ -6,10 +6,10 @@ from homeassistant.helpers import entity
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
from .const import DOMAIN
from .satellite import SatelliteDevice
from .devices import SatelliteDevice
class WyomingSatelliteEntity(entity.Entity):
class WyomingEntity(entity.Entity):
"""Wyoming satellite entity."""
_attr_has_entity_name = True

View file

@ -3,10 +3,15 @@
"name": "Wyoming Protocol",
"codeowners": ["@balloob", "@synesthesiam"],
"config_flow": true,
"dependencies": ["assist_pipeline", "intent", "conversation"],
"dependencies": [
"assist_satellite",
"assist_pipeline",
"intent",
"conversation"
],
"documentation": "https://www.home-assistant.io/integrations/wyoming",
"integration_type": "service",
"iot_class": "local_push",
"requirements": ["wyoming==1.5.4"],
"requirements": ["wyoming==1.6.0"],
"zeroconf": ["_wyoming._tcp.local."]
}

View file

@ -3,7 +3,7 @@
from dataclasses import dataclass
from .data import WyomingService
from .satellite import WyomingSatellite
from .devices import SatelliteDevice
@dataclass
@ -11,4 +11,4 @@ class DomainDataItem:
"""Domain data item."""
service: WyomingService
satellite: WyomingSatellite | None = None
satellite_device: SatelliteDevice | None = None

View file

@ -11,7 +11,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN
from .entity import WyomingSatelliteEntity
from .entity import WyomingEntity
if TYPE_CHECKING:
from .models import DomainDataItem
@ -30,9 +30,9 @@ async def async_setup_entry(
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
# Setup is only forwarded for satellites
assert item.satellite is not None
assert item.satellite_device is not None
device = item.satellite.device
device = item.satellite_device
async_add_entities(
[
WyomingSatelliteAutoGainNumber(device),
@ -41,7 +41,7 @@ async def async_setup_entry(
)
class WyomingSatelliteAutoGainNumber(WyomingSatelliteEntity, RestoreNumber):
class WyomingSatelliteAutoGainNumber(WyomingEntity, RestoreNumber):
"""Entity to represent auto gain amount."""
entity_description = NumberEntityDescription(
@ -70,7 +70,7 @@ class WyomingSatelliteAutoGainNumber(WyomingSatelliteEntity, RestoreNumber):
self._device.set_auto_gain(auto_gain)
class WyomingSatelliteVolumeMultiplierNumber(WyomingSatelliteEntity, RestoreNumber):
class WyomingSatelliteVolumeMultiplierNumber(WyomingEntity, RestoreNumber):
"""Entity to represent microphone volume multiplier."""
entity_description = NumberEntityDescription(

View file

@ -18,7 +18,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN
from .devices import SatelliteDevice
from .entity import WyomingSatelliteEntity
from .entity import WyomingEntity
if TYPE_CHECKING:
from .models import DomainDataItem
@ -42,9 +42,9 @@ async def async_setup_entry(
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
# Setup is only forwarded for satellites
assert item.satellite is not None
assert item.satellite_device is not None
device = item.satellite.device
device = item.satellite_device
async_add_entities(
[
WyomingSatellitePipelineSelect(hass, device),
@ -54,14 +54,14 @@ async def async_setup_entry(
)
class WyomingSatellitePipelineSelect(WyomingSatelliteEntity, AssistPipelineSelect):
class WyomingSatellitePipelineSelect(WyomingEntity, AssistPipelineSelect):
"""Pipeline selector for Wyoming satellites."""
def __init__(self, hass: HomeAssistant, device: SatelliteDevice) -> None:
"""Initialize a pipeline selector."""
self.device = device
WyomingSatelliteEntity.__init__(self, device)
WyomingEntity.__init__(self, device)
AssistPipelineSelect.__init__(self, hass, DOMAIN, device.satellite_id)
async def async_select_option(self, option: str) -> None:
@ -71,7 +71,7 @@ class WyomingSatellitePipelineSelect(WyomingSatelliteEntity, AssistPipelineSelec
class WyomingSatelliteNoiseSuppressionLevelSelect(
WyomingSatelliteEntity, SelectEntity, restore_state.RestoreEntity
WyomingEntity, SelectEntity, restore_state.RestoreEntity
):
"""Entity to represent noise suppression level setting."""
@ -99,16 +99,14 @@ class WyomingSatelliteNoiseSuppressionLevelSelect(
self._device.set_noise_suppression_level(_NOISE_SUPPRESSION_LEVEL[option])
class WyomingSatelliteVadSensitivitySelect(
WyomingSatelliteEntity, VadSensitivitySelect
):
class WyomingSatelliteVadSensitivitySelect(WyomingEntity, VadSensitivitySelect):
"""VAD sensitivity selector for Wyoming satellites."""
def __init__(self, hass: HomeAssistant, device: SatelliteDevice) -> None:
"""Initialize a VAD sensitivity selector."""
self.device = device
WyomingSatelliteEntity.__init__(self, device)
WyomingEntity.__init__(self, device)
VadSensitivitySelect.__init__(self, hass, device.satellite_id)
async def async_select_option(self, option: str) -> None:

View file

@ -12,7 +12,7 @@ from homeassistant.helpers import restore_state
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN
from .entity import WyomingSatelliteEntity
from .entity import WyomingEntity
if TYPE_CHECKING:
from .models import DomainDataItem
@ -27,13 +27,13 @@ async def async_setup_entry(
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
# Setup is only forwarded for satellites
assert item.satellite is not None
assert item.satellite_device is not None
async_add_entities([WyomingSatelliteMuteSwitch(item.satellite.device)])
async_add_entities([WyomingSatelliteMuteSwitch(item.satellite_device)])
class WyomingSatelliteMuteSwitch(
WyomingSatelliteEntity, restore_state.RestoreEntity, SwitchEntity
WyomingEntity, restore_state.RestoreEntity, SwitchEntity
):
"""Entity to represent if satellite is muted."""
@ -51,7 +51,7 @@ class WyomingSatelliteMuteSwitch(
# Default to off
self._attr_is_on = (state is not None) and (state.state == STATE_ON)
self._device.is_muted = self._attr_is_on
self._device.set_is_muted(self._attr_is_on)
async def async_turn_on(self, **kwargs: Any) -> None:
"""Turn on."""

View file

@ -41,6 +41,7 @@ class Platform(StrEnum):
AIR_QUALITY = "air_quality"
ALARM_CONTROL_PANEL = "alarm_control_panel"
ASSIST_SATELLITE = "assist_satellite"
BINARY_SENSOR = "binary_sensor"
BUTTON = "button"
CALENDAR = "calendar"

View file

@ -5,22 +5,28 @@ from __future__ import annotations
from asyncio import (
AbstractEventLoop,
Future,
Queue,
Semaphore,
Task,
TimerHandle,
gather,
get_running_loop,
timeout as async_timeout,
)
from collections.abc import Awaitable, Callable, Coroutine
from collections.abc import AsyncIterable, Awaitable, Callable, Coroutine
import concurrent.futures
import logging
import threading
from typing import Any
from typing_extensions import TypeVar
_LOGGER = logging.getLogger(__name__)
_SHUTDOWN_RUN_CALLBACK_THREADSAFE = "_shutdown_run_callback_threadsafe"
_DataT = TypeVar("_DataT", default=Any)
def create_eager_task[_T](
coro: Coroutine[Any, Any, _T],
@ -138,3 +144,20 @@ def get_scheduled_timer_handles(loop: AbstractEventLoop) -> list[TimerHandle]:
"""Return a list of scheduled TimerHandles."""
handles: list[TimerHandle] = loop._scheduled # type: ignore[attr-defined] # noqa: SLF001
return handles
async def queue_to_iterable(
queue: Queue[_DataT], timeout: float | None = None
) -> AsyncIterable[_DataT]:
"""Stream items from a queue until None with an optional timeout per item."""
if timeout is None:
while (item := await queue.get()) is not None:
yield item
else:
async with async_timeout(timeout):
item = await queue.get()
while item is not None:
yield item
async with async_timeout(timeout):
item = await queue.get()

View file

@ -2936,7 +2936,7 @@ wled==0.20.2
wolf-comm==0.0.9
# homeassistant.components.wyoming
wyoming==1.5.4
wyoming==1.6.0
# homeassistant.components.xbox
xbox-webapi==2.0.11

View file

@ -2319,7 +2319,7 @@ wled==0.20.2
wolf-comm==0.0.9
# homeassistant.components.wyoming
wyoming==1.5.4
wyoming==1.6.0
# homeassistant.components.xbox
xbox-webapi==2.0.11

View file

@ -0,0 +1,11 @@
"""Tests for the Assist satellite integration."""
from homeassistant.components import assist_satellite
class MockSatelliteEntity(assist_satellite.AssistSatelliteEntity):
"""Mock satellite that supports pipeline triggering."""
_attr_supported_features = (
assist_satellite.AssistSatelliteEntityFeature.TRIGGER_PIPELINE
)

View file

@ -0,0 +1,104 @@
"""Common fixtures for the Assist satellite tests."""
from collections.abc import Generator
import pytest
from homeassistant.components import assist_satellite
from homeassistant.config_entries import ConfigEntry, ConfigFlow
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.setup import async_setup_component
from . import MockSatelliteEntity
from tests.common import (
MockConfigEntry,
MockModule,
MockPlatform,
mock_config_flow,
mock_integration,
mock_platform,
)
TEST_DOMAIN = "test"
async def mock_config_entry_setup(
hass: HomeAssistant, satellite_entity: MockSatelliteEntity
) -> MockConfigEntry:
"""Set up a test satellite platform via config entry."""
async def async_setup_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Set up test config entry."""
await hass.config_entries.async_forward_entry_setups(
config_entry, [assist_satellite.DOMAIN]
)
return True
async def async_unload_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Unload test config entry."""
await hass.config_entries.async_forward_entry_unload(
config_entry, assist_satellite.DOMAIN
)
return True
mock_integration(
hass,
MockModule(
TEST_DOMAIN,
async_setup_entry=async_setup_entry_init,
async_unload_entry=async_unload_entry_init,
),
)
async def async_setup_entry_platform(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up test tts platform via config entry."""
async_add_entities([satellite_entity])
loaded_platform = MockPlatform(async_setup_entry=async_setup_entry_platform)
mock_platform(hass, f"{TEST_DOMAIN}.{assist_satellite.DOMAIN}", loaded_platform)
config_entry = MockConfigEntry(domain=TEST_DOMAIN)
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
return config_entry
class AssistSatelliteFlow(ConfigFlow):
"""Test flow."""
@pytest.fixture(autouse=True)
def config_flow_fixture(hass: HomeAssistant) -> Generator[None]:
"""Mock config flow."""
mock_platform(hass, f"{TEST_DOMAIN}.config_flow")
with mock_config_flow(TEST_DOMAIN, AssistSatelliteFlow):
yield
@pytest.fixture
def setup_mock_satellite_entity() -> MockSatelliteEntity:
"""Test satellite entity."""
return MockSatelliteEntity()
@pytest.fixture
async def mock_satellite(
hass: HomeAssistant, setup_mock_satellite_entity: MockSatelliteEntity
) -> MockSatelliteEntity:
"""Create a config entry."""
assert await async_setup_component(hass, "homeassistant", {})
await mock_config_entry_setup(hass, setup_mock_satellite_entity)
return setup_mock_satellite_entity

View file

@ -0,0 +1,204 @@
"""Tests for Assist satellite."""
from unittest.mock import patch
from homeassistant.components import assist_pipeline, assist_satellite
from homeassistant.const import ATTR_ENTITY_ID
from homeassistant.core import HomeAssistant
from . import MockSatelliteEntity
async def test_wait_wake(
hass: HomeAssistant, mock_satellite: MockSatelliteEntity
) -> None:
"""Test wait_wake service."""
test_wake_word = "test-wake-word"
with patch.object(
mock_satellite,
"async_trigger_pipeline_on_satellite",
return_value=assist_satellite.PipelineRunResult(
detected_wake_word=test_wake_word
),
) as mock_async_trigger_pipeline_on_satellite:
result = await hass.services.async_call(
assist_satellite.DOMAIN,
assist_satellite.SERVICE_WAIT_WAKE,
{
ATTR_ENTITY_ID: mock_satellite.entity_id,
assist_satellite.ATTR_WAKE_WORDS: [test_wake_word],
},
return_response=True,
blocking=True,
)
mock_async_trigger_pipeline_on_satellite.assert_called_once_with(
assist_pipeline.PipelineStage.WAKE_WORD,
assist_pipeline.PipelineStage.WAKE_WORD,
assist_satellite.PipelineRunConfig(wake_word_names=[test_wake_word]),
)
assert result == {mock_satellite.entity_id: test_wake_word}
async def test_announce_wait_wake(
hass: HomeAssistant, mock_satellite: MockSatelliteEntity
) -> None:
"""Test wait_wake service with announcement."""
test_wake_word = "test-wake-word"
announce_text = "test-announce-text"
with patch.object(
mock_satellite,
"async_trigger_pipeline_on_satellite",
return_value=assist_satellite.PipelineRunResult(
detected_wake_word=test_wake_word
),
) as mock_async_trigger_pipeline_on_satellite:
result = await hass.services.async_call(
assist_satellite.DOMAIN,
assist_satellite.SERVICE_WAIT_WAKE,
{
ATTR_ENTITY_ID: mock_satellite.entity_id,
assist_satellite.ATTR_ANNOUNCE_TEXT: announce_text,
assist_satellite.ATTR_WAKE_WORDS: [test_wake_word],
},
return_response=True,
blocking=True,
)
assert mock_async_trigger_pipeline_on_satellite.call_count == 2
assert mock_async_trigger_pipeline_on_satellite.call_args_list[0].args == (
assist_pipeline.PipelineStage.TTS,
assist_pipeline.PipelineStage.TTS,
assist_satellite.PipelineRunConfig(announce_text=announce_text),
)
assert mock_async_trigger_pipeline_on_satellite.call_args_list[1].args == (
assist_pipeline.PipelineStage.WAKE_WORD,
assist_pipeline.PipelineStage.WAKE_WORD,
assist_satellite.PipelineRunConfig(wake_word_names=[test_wake_word]),
)
assert result == {mock_satellite.entity_id: test_wake_word}
async def test_get_command(
hass: HomeAssistant, mock_satellite: MockSatelliteEntity
) -> None:
"""Test get_command service."""
test_command = "test-command"
with patch.object(
mock_satellite,
"async_trigger_pipeline_on_satellite",
return_value=assist_satellite.PipelineRunResult(command_text=test_command),
) as mock_async_trigger_pipeline_on_satellite:
result = await hass.services.async_call(
assist_satellite.DOMAIN,
assist_satellite.SERVICE_GET_COMMAND,
{ATTR_ENTITY_ID: mock_satellite.entity_id},
return_response=True,
blocking=True,
)
mock_async_trigger_pipeline_on_satellite.assert_called_once_with(
assist_pipeline.PipelineStage.STT,
assist_pipeline.PipelineStage.STT,
assist_satellite.PipelineRunConfig(),
)
assert result == {mock_satellite.entity_id: test_command}
async def test_announce_get_command(
hass: HomeAssistant, mock_satellite: MockSatelliteEntity
) -> None:
"""Test get_command service with announcement."""
test_command = "test-command"
announce_text = "test-announce-text"
with patch.object(
mock_satellite,
"async_trigger_pipeline_on_satellite",
return_value=assist_satellite.PipelineRunResult(command_text=test_command),
) as mock_async_trigger_pipeline_on_satellite:
result = await hass.services.async_call(
assist_satellite.DOMAIN,
assist_satellite.SERVICE_GET_COMMAND,
{
ATTR_ENTITY_ID: mock_satellite.entity_id,
assist_satellite.ATTR_ANNOUNCE_TEXT: announce_text,
},
return_response=True,
blocking=True,
)
assert mock_async_trigger_pipeline_on_satellite.call_count == 2
assert mock_async_trigger_pipeline_on_satellite.call_args_list[0].args == (
assist_pipeline.PipelineStage.TTS,
assist_pipeline.PipelineStage.TTS,
assist_satellite.PipelineRunConfig(announce_text=announce_text),
)
assert mock_async_trigger_pipeline_on_satellite.call_args_list[1].args == (
assist_pipeline.PipelineStage.STT,
assist_pipeline.PipelineStage.STT,
assist_satellite.PipelineRunConfig(),
)
assert result == {mock_satellite.entity_id: test_command}
async def test_get_command_process(
hass: HomeAssistant, mock_satellite: MockSatelliteEntity
) -> None:
"""Test get_command service with processing enabled."""
test_command = "test-command"
with patch.object(
mock_satellite,
"async_trigger_pipeline_on_satellite",
return_value=assist_satellite.PipelineRunResult(command_text=test_command),
) as mock_async_trigger_pipeline_on_satellite:
result = await hass.services.async_call(
assist_satellite.DOMAIN,
assist_satellite.SERVICE_GET_COMMAND,
{
ATTR_ENTITY_ID: mock_satellite.entity_id,
assist_satellite.ATTR_PROCESS: True,
},
return_response=True,
blocking=True,
)
# Pipeline should run to TTS stage now
mock_async_trigger_pipeline_on_satellite.assert_called_once_with(
assist_pipeline.PipelineStage.STT,
assist_pipeline.PipelineStage.TTS,
assist_satellite.PipelineRunConfig(),
)
assert result == {mock_satellite.entity_id: test_command}
async def test_say_text(
hass: HomeAssistant, mock_satellite: MockSatelliteEntity
) -> None:
"""Test say_text service."""
announce_text = "test-announce-text"
with patch.object(
mock_satellite, "async_trigger_pipeline_on_satellite", return_value=None
) as mock_async_trigger_pipeline_on_satellite:
result = await hass.services.async_call(
assist_satellite.DOMAIN,
assist_satellite.SERVICE_SAY_TEXT,
{
ATTR_ENTITY_ID: mock_satellite.entity_id,
assist_satellite.ATTR_ANNOUNCE_TEXT: announce_text,
},
return_response=False,
blocking=True,
)
mock_async_trigger_pipeline_on_satellite.assert_called_once_with(
assist_pipeline.PipelineStage.TTS,
assist_pipeline.PipelineStage.TTS,
assist_satellite.PipelineRunConfig(announce_text=announce_text),
)
assert result is None

File diff suppressed because one or more lines are too long

View file

@ -3,15 +3,26 @@
import asyncio
import io
from pathlib import Path
import time
from unittest.mock import AsyncMock, Mock, patch
import wave
import pytest
from syrupy.assertion import SnapshotAssertion
from voip_utils import CallInfo
from homeassistant.components import assist_pipeline, voip
from homeassistant.components.voip.devices import VoIPDevice
from homeassistant.components import assist_pipeline, assist_satellite, voip
from homeassistant.components.assist_satellite import (
AssistSatelliteEntity,
AssistSatelliteState,
)
from homeassistant.components.voip import HassVoipDatagramProtocol
from homeassistant.components.voip.assist_satellite import Tones, VoipAssistSatellite
from homeassistant.components.voip.devices import VoIPDevice, VoIPDevices
from homeassistant.components.voip.voip import PreRecordMessageProtocol, make_protocol
from homeassistant.const import STATE_OFF, STATE_ON, Platform
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.setup import async_setup_component
_ONE_SECOND = 16000 * 2 # 16Khz 16-bit
@ -35,33 +46,180 @@ def _empty_wav() -> bytes:
return wav_io.getvalue()
def async_get_satellite_entity(
hass: HomeAssistant, domain: str, unique_id_prefix: str
) -> AssistSatelliteEntity | None:
"""Get Assist satellite entity."""
ent_reg = er.async_get(hass)
satellite_entity_id = ent_reg.async_get_entity_id(
Platform.ASSIST_SATELLITE, domain, f"{unique_id_prefix}-assist_satellite"
)
if satellite_entity_id is None:
return None
component: EntityComponent[AssistSatelliteEntity] = hass.data[
assist_satellite.DOMAIN
]
return component.get_entity(satellite_entity_id)
async def test_is_valid_call(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
call_info: CallInfo,
) -> None:
"""Test that a call is now allowed from an unknown device."""
assert await async_setup_component(hass, "voip", {})
protocol = HassVoipDatagramProtocol(hass, voip_devices)
assert not protocol.is_valid_call(call_info)
ent_reg = er.async_get(hass)
allowed_call_entity_id = ent_reg.async_get_entity_id(
"switch", voip.DOMAIN, f"{voip_device.voip_id}-allow_call"
)
assert allowed_call_entity_id is not None
state = hass.states.get(allowed_call_entity_id)
assert state is not None
assert state.state == STATE_OFF
# Allow calls
hass.states.async_set(allowed_call_entity_id, STATE_ON)
assert protocol.is_valid_call(call_info)
async def test_calls_not_allowed(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
call_info: CallInfo,
snapshot: SnapshotAssertion,
) -> None:
"""Test that a pre-recorded message is played when calls aren't allowed."""
assert await async_setup_component(hass, "voip", {})
protocol: PreRecordMessageProtocol = make_protocol(hass, voip_devices, call_info)
assert isinstance(protocol, PreRecordMessageProtocol)
assert protocol.file_name == "problem.pcm"
# Test the playback
done = asyncio.Event()
played_audio_bytes = b""
def send_audio(audio_bytes: bytes, **kwargs):
nonlocal played_audio_bytes
# Should be problem.pcm from components/voip
played_audio_bytes = audio_bytes
done.set()
protocol.transport = Mock()
protocol.loop_delay = 0
with patch.object(protocol, "send_audio", send_audio):
protocol.on_chunk(bytes(_ONE_SECOND))
async with asyncio.timeout(1):
await done.wait()
assert sum(played_audio_bytes) > 0
assert played_audio_bytes == snapshot()
async def test_pipeline_not_found(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
call_info: CallInfo,
snapshot: SnapshotAssertion,
) -> None:
"""Test that a pre-recorded message is played when a pipeline isn't found."""
assert await async_setup_component(hass, "voip", {})
with patch(
"homeassistant.components.voip.voip.async_get_pipeline", return_value=None
):
protocol: PreRecordMessageProtocol = make_protocol(
hass, voip_devices, call_info
)
assert isinstance(protocol, PreRecordMessageProtocol)
assert protocol.file_name == "problem.pcm"
async def test_satellite_prepared(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
call_info: CallInfo,
snapshot: SnapshotAssertion,
) -> None:
"""Test that satellite is prepared for a call."""
assert await async_setup_component(hass, "voip", {})
pipeline = assist_pipeline.Pipeline(
conversation_engine="test",
conversation_language="en",
language="en",
name="test",
stt_engine="test",
stt_language="en",
tts_engine="test",
tts_language="en",
tts_voice=None,
wake_word_entity=None,
wake_word_id=None,
)
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
with (
patch(
"homeassistant.components.voip.voip.async_get_pipeline",
return_value=pipeline,
),
):
protocol = make_protocol(hass, voip_devices, call_info)
assert protocol == satellite
async def test_pipeline(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
call_info: CallInfo,
) -> None:
"""Test that pipeline function is called from RTP protocol."""
assert await async_setup_component(hass, "voip", {})
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
if sum(chunk) > 0:
return 1
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
voip_user_id = satellite.config_entry.data["user"]
assert voip_user_id
return 0
# Satellite is muted until a call begins
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
done = asyncio.Event()
# Used to test that audio queue is cleared before pipeline starts
bad_chunk = bytes([1, 2, 3, 4])
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
async def async_pipeline_from_audio_stream(
hass: HomeAssistant, context: Context, *args, device_id: str | None, **kwargs
):
assert context.user_id == voip_user_id
assert device_id == voip_device.device_id
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
in_command = False
async for chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
assert _chunk != bad_chunk
assert chunk != bad_chunk
if sum(chunk) > 0:
in_command = True
elif in_command:
break # done with command
# Test empty data
event_callback(
@ -71,6 +229,38 @@ async def test_pipeline(
)
)
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_START,
data={"engine": "test", "metadata": {}},
)
)
assert satellite.state == AssistSatelliteState.LISTENING_COMMAND
# Fake STT result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_END,
data={"stt_output": {"text": "fake-text"}},
)
)
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.INTENT_START,
data={
"engine": "test",
"language": hass.config.language,
"intent_input": "fake-text",
"conversation_id": None,
"device_id": None,
},
)
)
assert satellite.state == AssistSatelliteState.PROCESSING
# Fake intent result
event_callback(
assist_pipeline.PipelineEvent(
@ -83,6 +273,21 @@ async def test_pipeline(
)
)
# Fake tts result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.TTS_START,
data={
"engine": "test",
"language": hass.config.language,
"voice": "test",
"tts_input": "fake-text",
},
)
)
assert satellite.state == AssistSatelliteState.RESPONDING
# Proceed with media output
event_callback(
assist_pipeline.PipelineEvent(
@ -91,6 +296,18 @@ async def test_pipeline(
)
)
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.RUN_END
)
)
original_tts_response_finished = satellite.tts_response_finished
def tts_response_finished():
original_tts_response_finished()
done.set()
async def async_get_media_source_audio(
hass: HomeAssistant,
media_source_id: str,
@ -100,102 +317,56 @@ async def test_pipeline(
with (
patch(
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
patch.object(satellite, "tts_response_finished", tts_response_finished),
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
listening_tone_enabled=False,
processing_tone_enabled=False,
error_tone_enabled=False,
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("aggressive"),
)
rtp_protocol.transport = Mock()
satellite._tones = Tones(0)
satellite.transport = Mock()
satellite.connection_made(satellite.transport)
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
# Ensure audio queue is cleared before pipeline starts
rtp_protocol._audio_queue.put_nowait(bad_chunk)
satellite._audio_queue.put_nowait(bad_chunk)
def send_audio(*args, **kwargs):
# Test finished successfully
done.set()
# Don't send audio
pass
rtp_protocol.send_audio = Mock(side_effect=send_audio)
satellite.send_audio = Mock(side_effect=send_audio)
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes aggressive VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
# silence
satellite.on_chunk(bytes(_ONE_SECOND))
# Wait for mock pipeline to exhaust the audio stream
async with asyncio.timeout(1):
await done.wait()
async def test_pipeline_timeout(hass: HomeAssistant, voip_device: VoIPDevice) -> None:
"""Test timeout during pipeline run."""
assert await async_setup_component(hass, "voip", {})
done = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
await asyncio.sleep(10)
with (
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.PipelineRtpDatagramProtocol._wait_for_speech",
return_value=True,
),
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
pipeline_timeout=0.001,
listening_tone_enabled=False,
processing_tone_enabled=False,
error_tone_enabled=False,
)
transport = Mock(spec=["close"])
rtp_protocol.connection_made(transport)
# Closing the transport will cause the test to succeed
transport.close.side_effect = done.set
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
# Wait for mock pipeline to time out
async with asyncio.timeout(1):
await done.wait()
# Finished speaking
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice) -> None:
async def test_stt_stream_timeout(
hass: HomeAssistant, voip_devices: VoIPDevices, voip_device: VoIPDevice
) -> None:
"""Test timeout in STT stream during pipeline run."""
assert await async_setup_component(hass, "voip", {})
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
done = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
@ -205,28 +376,19 @@ async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice)
pass
with patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
audio_timeout=0.001,
listening_tone_enabled=False,
processing_tone_enabled=False,
error_tone_enabled=False,
)
satellite._tones = Tones(0)
satellite._audio_chunk_timeout = 0.001
transport = Mock(spec=["close"])
rtp_protocol.connection_made(transport)
satellite.connection_made(transport)
# Closing the transport will cause the test to succeed
transport.close.side_effect = done.set
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# Wait for mock pipeline to time out
async with asyncio.timeout(1):
@ -235,26 +397,34 @@ async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice)
async def test_tts_timeout(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
) -> None:
"""Test that TTS will time out based on its length."""
assert await async_setup_component(hass, "voip", {})
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
if sum(chunk) > 0:
return 1
return 0
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
done = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
pass
in_command = False
async for chunk in stt_stream:
if sum(chunk) > 0:
in_command = True
elif in_command:
break # done with command
# Fake STT result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_END,
data={"stt_output": {"text": "fake-text"}},
)
)
# Fake intent result
event_callback(
@ -278,15 +448,7 @@ async def test_tts_timeout(
tone_bytes = bytes([1, 2, 3, 4])
def send_audio(audio_bytes, **kwargs):
if audio_bytes == tone_bytes:
# Not TTS
return
# Block here to force a timeout in _send_tts
time.sleep(2)
async def async_send_audio(audio_bytes, **kwargs):
async def async_send_audio(audio_bytes: bytes, **kwargs):
if audio_bytes == tone_bytes:
# Not TTS
return
@ -303,37 +465,22 @@ async def test_tts_timeout(
with (
patch(
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
tts_extra_timeout=0.001,
listening_tone_enabled=True,
processing_tone_enabled=True,
error_tone_enabled=True,
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("relaxed"),
)
rtp_protocol._tone_bytes = tone_bytes
rtp_protocol._processing_bytes = tone_bytes
rtp_protocol._error_bytes = tone_bytes
rtp_protocol.transport = Mock()
rtp_protocol.send_audio = Mock()
satellite._tts_extra_timeout = 0.001
for tone in Tones:
satellite._tone_bytes[tone] = tone_bytes
original_send_tts = rtp_protocol._send_tts
satellite.transport = Mock()
satellite.send_audio = Mock()
original_send_tts = satellite._send_tts
async def send_tts(*args, **kwargs):
# Call original then end test successfully
@ -342,17 +489,17 @@ async def test_tts_timeout(
done.set()
rtp_protocol._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
satellite._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes relaxed VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
# silence
satellite.on_chunk(bytes(_ONE_SECOND))
# Wait for mock pipeline to exhaust the audio stream
async with asyncio.timeout(1):
@ -361,26 +508,34 @@ async def test_tts_timeout(
async def test_tts_wrong_extension(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
) -> None:
"""Test that TTS will only stream WAV audio."""
assert await async_setup_component(hass, "voip", {})
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
if sum(chunk) > 0:
return 1
return 0
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
done = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
pass
in_command = False
async for chunk in stt_stream:
if sum(chunk) > 0:
in_command = True
elif in_command:
break # done with command
# Fake STT result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_END,
data={"stt_output": {"text": "fake-text"}},
)
)
# Fake intent result
event_callback(
@ -411,28 +566,17 @@ async def test_tts_wrong_extension(
with (
patch(
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
)
rtp_protocol.transport = Mock()
satellite.transport = Mock()
original_send_tts = rtp_protocol._send_tts
original_send_tts = satellite._send_tts
async def send_tts(*args, **kwargs):
# Call original then end test successfully
@ -441,16 +585,16 @@ async def test_tts_wrong_extension(
done.set()
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes relaxed VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
satellite.on_chunk(bytes(_ONE_SECOND * 4))
# Wait for mock pipeline to exhaust the audio stream
async with asyncio.timeout(1):
@ -459,26 +603,34 @@ async def test_tts_wrong_extension(
async def test_tts_wrong_wav_format(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
) -> None:
"""Test that TTS will only stream WAV audio with a specific format."""
assert await async_setup_component(hass, "voip", {})
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
if sum(chunk) > 0:
return 1
return 0
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
done = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
pass
in_command = False
async for chunk in stt_stream:
if sum(chunk) > 0:
in_command = True
elif in_command:
break # done with command
# Fake STT result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_END,
data={"stt_output": {"text": "fake-text"}},
)
)
# Fake intent result
event_callback(
@ -516,28 +668,17 @@ async def test_tts_wrong_wav_format(
with (
patch(
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
)
rtp_protocol.transport = Mock()
satellite.transport = Mock()
original_send_tts = rtp_protocol._send_tts
original_send_tts = satellite._send_tts
async def send_tts(*args, **kwargs):
# Call original then end test successfully
@ -546,16 +687,16 @@ async def test_tts_wrong_wav_format(
done.set()
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes relaxed VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
satellite.on_chunk(bytes(_ONE_SECOND * 4))
# Wait for mock pipeline to exhaust the audio stream
async with asyncio.timeout(1):
@ -564,24 +705,32 @@ async def test_tts_wrong_wav_format(
async def test_empty_tts_output(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
) -> None:
"""Test that TTS will not stream when output is empty."""
assert await async_setup_component(hass, "voip", {})
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
if sum(chunk) > 0:
return 1
return 0
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
async def async_pipeline_from_audio_stream(*args, **kwargs):
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
pass
in_command = False
async for chunk in stt_stream:
if sum(chunk) > 0:
in_command = True
elif in_command:
break # done with command
# Fake STT result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.STT_END,
data={"stt_output": {"text": "fake-text"}},
)
)
# Fake intent result
event_callback(
@ -605,37 +754,78 @@ async def test_empty_tts_output(
with (
patch(
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.voip.PipelineRtpDatagramProtocol._send_tts",
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
) as mock_send_tts,
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
)
rtp_protocol.transport = Mock()
satellite.transport = Mock()
# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
satellite.on_chunk(bytes(_ONE_SECOND))
# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes relaxed VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
satellite.on_chunk(bytes(_ONE_SECOND * 4))
# Wait for mock pipeline to finish
async with asyncio.timeout(1):
await rtp_protocol._tts_done.wait()
await satellite._tts_done.wait()
mock_send_tts.assert_not_called()
async def test_pipeline_error(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
snapshot: SnapshotAssertion,
) -> None:
"""Test that a pipeline error causes the error tone to be played."""
assert await async_setup_component(hass, "voip", {})
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
done = asyncio.Event()
played_audio_bytes = b""
async def async_pipeline_from_audio_stream(*args, **kwargs):
# Fake error
event_callback = kwargs["event_callback"]
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.ERROR,
data={"code": "error-code", "message": "error message"},
)
)
async def async_send_audio(audio_bytes: bytes, **kwargs):
nonlocal played_audio_bytes
# Should be error.pcm from components/voip
played_audio_bytes = audio_bytes
done.set()
with (
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
):
satellite._tones = Tones.ERROR
satellite.transport = Mock()
satellite._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
satellite.on_chunk(bytes(_ONE_SECOND))
# Wait for error tone to be played
async with asyncio.timeout(1):
await done.wait()
assert sum(played_audio_bytes) > 0
assert played_audio_bytes == snapshot()

View file

@ -150,10 +150,10 @@ async def reload_satellite(
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.run"
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite.run"
) as _run_mock,
):
# _run_mock: satellite task does not actually run
await hass.config_entries.async_reload(config_entry_id)
return hass.data[DOMAIN][config_entry_id].satellite.device
return hass.data[DOMAIN][config_entry_id].satellite_device

View file

@ -152,7 +152,7 @@ async def init_satellite(hass: HomeAssistant, satellite_config_entry: ConfigEntr
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.run"
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite.run"
) as _run_mock,
):
# _run_mock: satellite task does not actually run
@ -164,4 +164,4 @@ async def satellite_device(
hass: HomeAssistant, init_satellite, satellite_config_entry: ConfigEntry
) -> SatelliteDevice:
"""Get a satellite device fixture."""
return hass.data[DOMAIN][satellite_config_entry.entry_id].satellite.device
return hass.data[DOMAIN][satellite_config_entry.entry_id].satellite_device

View file

@ -2,7 +2,7 @@
from __future__ import annotations
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
from homeassistant.components.assist_pipeline import OPTION_PREFERRED
from homeassistant.components.wyoming import DOMAIN
from homeassistant.components.wyoming.devices import SatelliteDevice
from homeassistant.config_entries import ConfigEntry

View file

@ -17,14 +17,17 @@ from wyoming.info import Info
from wyoming.ping import Ping, Pong
from wyoming.pipeline import PipelineStage, RunPipeline
from wyoming.satellite import RunSatellite
from wyoming.snd import Played
from wyoming.timer import TimerCancelled, TimerFinished, TimerStarted, TimerUpdated
from wyoming.tts import Synthesize
from wyoming.vad import VoiceStarted, VoiceStopped
from wyoming.wake import Detect, Detection
from homeassistant.components import assist_pipeline, wyoming
from homeassistant.components import assist_pipeline, assist_satellite, wyoming
from homeassistant.components.wyoming.assist_satellite import WyomingSatellite
from homeassistant.components.wyoming.devices import SatelliteDevice
from homeassistant.const import STATE_ON
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_ENTITY_ID, STATE_ON
from homeassistant.core import HomeAssistant, State
from homeassistant.helpers import intent as intent_helper
from homeassistant.setup import async_setup_component
@ -69,10 +72,17 @@ def get_test_wav() -> bytes:
return wav_io.getvalue()
def get_device(hass: HomeAssistant, entry: ConfigEntry) -> SatelliteDevice:
"""Get the satellite device for a config entry."""
device = hass.data[wyoming.DOMAIN][entry.entry_id].satellite_device
assert isinstance(device, SatelliteDevice)
return device
class SatelliteAsyncTcpClient(MockAsyncTcpClient):
"""Satellite AsyncTcpClient."""
def __init__(self, responses: list[Event]) -> None:
def __init__(self, responses: list[Event], auto_audio: bool = True) -> None:
"""Initialize client."""
super().__init__(responses)
@ -124,9 +134,16 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient):
self.timer_finished_event = asyncio.Event()
self.timer_finished: TimerFinished | None = None
self.run_pipeline_event = asyncio.Event()
self.run_pipeline_count = asyncio.Semaphore()
self.run_pipeline: RunPipeline | None = None
self.run_pipeline_list: list[RunPipeline] = []
self._mic_audio_chunk = AudioChunk(
rate=16000, width=2, channels=1, audio=b"chunk"
).event()
self._auto_audio = auto_audio
self._event_injected = asyncio.Event()
async def connect(self) -> None:
"""Connect."""
@ -184,17 +201,29 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient):
elif TimerFinished.is_type(event.type):
self.timer_finished = TimerFinished.from_event(event)
self.timer_finished_event.set()
elif RunPipeline.is_type(event.type):
self.run_pipeline = RunPipeline.from_event(event)
self.run_pipeline_list.append(self.run_pipeline)
self.run_pipeline_event.set()
self.run_pipeline_count.release()
async def read_event(self) -> Event | None:
"""Receive."""
event = await super().read_event()
while True:
event = await super().read_event()
if event is not None:
return event
# Keep sending audio chunks instead of None
return event or self._mic_audio_chunk
if self._auto_audio:
# Keep sending audio chunks instead of None
return self._mic_audio_chunk
await self._event_injected.wait()
def inject_event(self, event: Event) -> None:
"""Put an event in as the next response."""
self.responses = [event, *self.responses]
self._event_injected.set()
async def test_satellite_pipeline(hass: HomeAssistant) -> None:
@ -240,23 +269,21 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
"homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
return_value=("wav", get_test_wav()),
),
patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0),
patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
):
entry = await setup_config_entry(hass)
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
entry.entry_id
].satellite.device
device = get_device(hass, entry)
async with asyncio.timeout(1):
await mock_client.connect_event.wait()
@ -443,7 +470,7 @@ async def test_satellite_muted(hass: HomeAssistant) -> None:
"""Test callback for a satellite that has been muted."""
on_muted_event = asyncio.Event()
original_on_muted = wyoming.satellite.WyomingSatellite.on_muted
original_on_muted = WyomingSatellite.on_muted
async def on_muted(self):
# Trigger original function
@ -457,6 +484,18 @@ async def test_satellite_muted(hass: HomeAssistant) -> None:
self.device.set_is_muted(False)
on_muted_event.set()
async def async_pipeline_from_audio_stream(
hass: HomeAssistant,
context,
event_callback,
stt_metadata,
stt_stream,
**kwargs,
) -> None:
async for chunk in stt_stream:
if not chunk:
break
with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
@ -467,9 +506,17 @@ async def test_satellite_muted(hass: HomeAssistant) -> None:
return_value=State("switch.test_mute", STATE_ON),
),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_muted",
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite.on_muted",
on_muted,
),
patch(
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient([]),
),
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
async_pipeline_from_audio_stream,
),
):
entry = await setup_config_entry(hass)
async with asyncio.timeout(1):
@ -484,7 +531,7 @@ async def test_satellite_restart(hass: HomeAssistant) -> None:
"""Test pipeline loop restart after unexpected error."""
on_restart_event = asyncio.Event()
original_on_restart = wyoming.satellite.WyomingSatellite.on_restart
original_on_restart = WyomingSatellite.on_restart
async def on_restart(self):
await original_on_restart(self)
@ -497,14 +544,14 @@ async def test_satellite_restart(hass: HomeAssistant) -> None:
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite._connect_and_loop",
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite._connect_and_loop",
side_effect=RuntimeError(),
),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite.on_restart",
on_restart,
),
patch("homeassistant.components.wyoming.satellite._RESTART_SECONDS", 0),
patch("homeassistant.components.wyoming.assist_satellite._RESTART_SECONDS", 0),
):
await setup_config_entry(hass)
async with asyncio.timeout(1):
@ -517,7 +564,7 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None:
reconnect_event = asyncio.Event()
stopped_event = asyncio.Event()
original_on_reconnect = wyoming.satellite.WyomingSatellite.on_reconnect
original_on_reconnect = WyomingSatellite.on_reconnect
async def on_reconnect(self):
await original_on_reconnect(self)
@ -537,18 +584,20 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None:
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient.connect",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient.connect",
side_effect=ConnectionRefusedError(),
),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_reconnect",
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite.on_reconnect",
on_reconnect,
),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite.on_stopped",
on_stopped,
),
patch("homeassistant.components.wyoming.satellite._RECONNECT_SECONDS", 0),
patch(
"homeassistant.components.wyoming.assist_satellite._RECONNECT_SECONDS", 0
),
):
await setup_config_entry(hass)
async with asyncio.timeout(1):
@ -570,14 +619,14 @@ async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
MockAsyncTcpClient([]), # no RunPipeline event
),
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
) as mock_run_pipeline,
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite.on_restart",
on_restart,
),
):
@ -615,25 +664,23 @@ async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
MockAsyncTcpClient(events),
),
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
) as mock_run_pipeline,
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite.on_restart",
on_restart,
),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite.on_stopped",
on_stopped,
),
):
entry = await setup_config_entry(hass)
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
entry.entry_id
].satellite.device
device = get_device(hass, entry)
async with asyncio.timeout(1):
await on_restart_event.wait()
@ -665,11 +712,11 @@ async def test_satellite_error_during_pipeline(hass: HomeAssistant) -> None:
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
wraps=_async_pipeline_from_audio_stream,
) as mock_run_pipeline,
):
@ -701,7 +748,7 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None:
"""Test satellite receiving non-WAV audio from text-to-speech."""
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
original_stream_tts = wyoming.satellite.WyomingSatellite._stream_tts
original_stream_tts = WyomingSatellite._stream_tts
error_event = asyncio.Event()
async def _stream_tts(self, media_id):
@ -724,19 +771,19 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None:
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
wraps=_async_pipeline_from_audio_stream,
) as mock_run_pipeline,
patch(
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
"homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
return_value=("mp3", bytes(1)),
),
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite._stream_tts",
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite._stream_tts",
_stream_tts,
),
):
@ -808,8 +855,9 @@ async def test_pipeline_changed(hass: HomeAssistant) -> None:
pipeline_event_callback = event_callback
run_pipeline_called.set()
async for _chunk in stt_stream:
pass
async for chunk in stt_stream:
if not chunk:
break
pipeline_stopped.set()
@ -819,18 +867,16 @@ async def test_pipeline_changed(hass: HomeAssistant) -> None:
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
async_pipeline_from_audio_stream,
),
):
entry = await setup_config_entry(hass)
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
entry.entry_id
].satellite.device
device = get_device(hass, entry)
async with asyncio.timeout(1):
await mock_client.connect_event.wait()
@ -882,8 +928,9 @@ async def test_audio_settings_changed(hass: HomeAssistant) -> None:
pipeline_event_callback = event_callback
run_pipeline_called.set()
async for _chunk in stt_stream:
pass
async for chunk in stt_stream:
if not chunk:
break
pipeline_stopped.set()
@ -893,18 +940,16 @@ async def test_audio_settings_changed(hass: HomeAssistant) -> None:
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
async_pipeline_from_audio_stream,
),
):
entry = await setup_config_entry(hass)
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
entry.entry_id
].satellite.device
device = get_device(hass, entry)
async with asyncio.timeout(1):
await mock_client.connect_event.wait()
@ -938,7 +983,7 @@ async def test_invalid_stages(hass: HomeAssistant) -> None:
).event(),
]
original_run_pipeline_once = wyoming.satellite.WyomingSatellite._run_pipeline_once
original_run_pipeline_once = WyomingSatellite._run_pipeline_once
start_stage_event = asyncio.Event()
end_stage_event = asyncio.Event()
@ -967,11 +1012,11 @@ async def test_invalid_stages(hass: HomeAssistant) -> None:
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite._run_pipeline_once",
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite._run_pipeline_once",
_run_pipeline_once,
),
):
@ -1018,8 +1063,9 @@ async def test_client_stops_pipeline(hass: HomeAssistant) -> None:
pipeline_event_callback = event_callback
run_pipeline_called.set()
async for _chunk in stt_stream:
pass
async for chunk in stt_stream:
if not chunk:
break
pipeline_stopped.set()
@ -1029,11 +1075,11 @@ async def test_client_stops_pipeline(hass: HomeAssistant) -> None:
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
async_pipeline_from_audio_stream,
),
):
@ -1083,11 +1129,11 @@ async def test_wake_word_phrase(hass: HomeAssistant) -> None:
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
),
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
wraps=_async_pipeline_from_audio_stream,
) as mock_run_pipeline,
):
@ -1114,14 +1160,12 @@ async def test_timers(hass: HomeAssistant) -> None:
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient([]),
) as mock_client,
):
entry = await setup_config_entry(hass)
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
entry.entry_id
].satellite.device
device = get_device(hass, entry)
async with asyncio.timeout(1):
await mock_client.connect_event.wait()
@ -1325,23 +1369,20 @@ async def test_satellite_conversation_id(hass: HomeAssistant) -> None:
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client,
patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
"homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
return_value=("wav", get_test_wav()),
),
patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0),
patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
):
entry = await setup_config_entry(hass)
satellite: wyoming.WyomingSatellite = hass.data[wyoming.DOMAIN][
entry.entry_id
].satellite
await setup_config_entry(hass)
async with asyncio.timeout(1):
await mock_client.connect_event.wait()
@ -1370,19 +1411,128 @@ async def test_satellite_conversation_id(hass: HomeAssistant) -> None:
# Should be the same conversation id
assert pipeline_kwargs.get("conversation_id") == conversation_id
# Reset and run again, but this time "time out"
satellite._conversation_id_time = None
run_pipeline_called.clear()
pipeline_kwargs.clear()
pipeline_event_callback(
assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END)
)
async def test_say_text(hass: HomeAssistant) -> None:
"""Test say text service call."""
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
test_text = "test-text"
with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient([]),
) as mock_client,
patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
):
entry = await setup_config_entry(hass)
device = get_device(hass, entry)
satellite_entity_id = device.get_satellite_entity_id(hass)
async with asyncio.timeout(1):
await run_pipeline_called.wait()
await mock_client.connect_event.wait()
await mock_client.run_satellite_event.wait()
# Should be a different conversation id
new_conversation_id = pipeline_kwargs.get("conversation_id")
assert new_conversation_id
assert new_conversation_id != conversation_id
async with asyncio.timeout(1):
await hass.services.async_call(
assist_satellite.DOMAIN,
assist_satellite.SERVICE_SAY_TEXT,
{
ATTR_ENTITY_ID: satellite_entity_id,
assist_satellite.ATTR_ANNOUNCE_TEXT: test_text,
},
blocking=False,
)
await mock_client.run_pipeline_event.wait()
assert mock_client.run_pipeline is not None
rp: RunPipeline = mock_client.run_pipeline
assert rp.start_stage == PipelineStage.TTS
assert rp.end_stage == PipelineStage.TTS
assert rp.announce_text == test_text
async def test_get_command(hass: HomeAssistant) -> None:
"""Test get command service call."""
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
test_command = "test-command"
test_text = "test-text"
async def async_pipeline_from_audio_stream(
hass: HomeAssistant,
context,
event_callback,
stt_metadata,
stt_stream,
**kwargs,
) -> None:
event_callback(
assist_pipeline.PipelineEvent(
assist_pipeline.PipelineEventType.STT_END,
{"stt_output": {"text": test_command}},
)
)
with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient([], auto_audio=False),
) as mock_client,
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
async_pipeline_from_audio_stream,
),
patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
):
entry = await setup_config_entry(hass)
device = get_device(hass, entry)
satellite_entity_id = device.get_satellite_entity_id(hass)
async with asyncio.timeout(1):
await mock_client.connect_event.wait()
await mock_client.run_satellite_event.wait()
async with asyncio.timeout(1):
task = asyncio.create_task(
hass.services.async_call(
assist_satellite.DOMAIN,
assist_satellite.SERVICE_GET_COMMAND,
{
ATTR_ENTITY_ID: satellite_entity_id,
assist_satellite.ATTR_ANNOUNCE_TEXT: test_text,
},
blocking=True,
return_response=True,
)
)
await mock_client.run_pipeline_event.wait()
# Announcement happens first
assert mock_client.run_pipeline is not None
rp: RunPipeline = mock_client.run_pipeline
assert rp.start_stage == PipelineStage.TTS
assert rp.end_stage == PipelineStage.TTS
assert rp.announce_text == test_text
mock_client.run_pipeline_event.clear()
mock_client.run_pipeline = None
mock_client.inject_event(Played().event())
# Command happens next
await mock_client.run_pipeline_event.wait()
assert mock_client.run_pipeline is not None
rp = mock_client.run_pipeline
assert rp.start_stage == PipelineStage.ASR
assert rp.end_stage == PipelineStage.ASR
mock_client.inject_event(rp.event())
result = await task
assert result == {satellite_entity_id: test_command}

View file

@ -3,8 +3,8 @@
from unittest.mock import Mock, patch
from homeassistant.components import assist_pipeline
from homeassistant.components.assist_pipeline import OPTION_PREFERRED
from homeassistant.components.assist_pipeline.pipeline import PipelineData
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
from homeassistant.components.assist_pipeline.vad import VadSensitivity
from homeassistant.components.wyoming.devices import SatelliteDevice
from homeassistant.config_entries import ConfigEntry

View file

@ -213,3 +213,43 @@ async def test_get_scheduled_timer_handles(hass: HomeAssistant) -> None:
timer_handle.cancel()
timer_handle2.cancel()
timer_handle3.cancel()
async def test_queue_to_iterable() -> None:
"""Test queue_to_iterable."""
queue: asyncio.Queue[int | None] = asyncio.Queue()
expected_items = list(range(10))
for i in expected_items:
await queue.put(i)
# Will terminate the stream
await queue.put(None)
actual_items = [item async for item in hasync.queue_to_iterable(queue)]
assert expected_items == actual_items
# Check timeout
assert queue.empty()
# Time out on first item
async with asyncio.timeout(1):
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
# Should time out very quickly
async for _item in hasync.queue_to_iterable(queue, timeout=0.01):
await asyncio.sleep(1)
# Check timeout on second item
assert queue.empty()
await queue.put(12345)
# Time out on second item
async with asyncio.timeout(1):
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
# Should time out very quickly
async for item in hasync.queue_to_iterable(queue, timeout=0.01):
if item != 12345:
await asyncio.sleep(1)
assert queue.empty()