Compare commits
25 commits
dev
...
synesthesi
Author | SHA1 | Date | |
---|---|---|---|
|
b77dff5e65 | ||
|
07ac3df4ff | ||
|
d11dace22b | ||
|
59a6b1ebfa | ||
|
1d2bced1f0 | ||
|
644427ecc7 | ||
|
d48fcb3221 | ||
|
c468d9c5c9 | ||
|
33d0d2cfed | ||
|
a4876e435c | ||
|
93da8de1e4 | ||
|
bd0a97a3b7 | ||
|
9a483613e1 | ||
|
f1c0bdf5be | ||
|
f4d6e46fed | ||
|
ecec1d3208 | ||
|
66be7b9648 | ||
|
f6e5d2d80b | ||
|
712e4e5f50 | ||
|
1e1623309d | ||
|
d32a681f28 | ||
|
337fe974f7 | ||
|
d7e9f6aae4 | ||
|
ec1866e131 | ||
|
b21e2360b9 |
46 changed files with 2275 additions and 1022 deletions
|
@ -143,6 +143,8 @@ build.json @home-assistant/supervisor
|
|||
/tests/components/aseko_pool_live/ @milanmeu
|
||||
/homeassistant/components/assist_pipeline/ @balloob @synesthesiam
|
||||
/tests/components/assist_pipeline/ @balloob @synesthesiam
|
||||
/homeassistant/components/assist_satellite/ @synesthesiam
|
||||
/tests/components/assist_satellite/ @synesthesiam
|
||||
/homeassistant/components/asuswrt/ @kennedyshead @ollo69
|
||||
/tests/components/asuswrt/ @kennedyshead @ollo69
|
||||
/homeassistant/components/atag/ @MatsNL
|
||||
|
|
|
@ -16,6 +16,7 @@ from .const import (
|
|||
DATA_LAST_WAKE_UP,
|
||||
DOMAIN,
|
||||
EVENT_RECORDING,
|
||||
OPTION_PREFERRED,
|
||||
SAMPLE_CHANNELS,
|
||||
SAMPLE_RATE,
|
||||
SAMPLE_WIDTH,
|
||||
|
@ -57,6 +58,7 @@ __all__ = (
|
|||
"PipelineNotFound",
|
||||
"WakeWordSettings",
|
||||
"EVENT_RECORDING",
|
||||
"OPTION_PREFERRED",
|
||||
"SAMPLES_PER_CHUNK",
|
||||
"SAMPLE_RATE",
|
||||
"SAMPLE_WIDTH",
|
||||
|
@ -100,6 +102,7 @@ async def async_pipeline_from_audio_stream(
|
|||
pipeline_id: str | None = None,
|
||||
conversation_id: str | None = None,
|
||||
tts_audio_output: str | None = None,
|
||||
tts_input: str | None = None,
|
||||
wake_word_settings: WakeWordSettings | None = None,
|
||||
audio_settings: AudioSettings | None = None,
|
||||
device_id: str | None = None,
|
||||
|
@ -116,6 +119,7 @@ async def async_pipeline_from_audio_stream(
|
|||
stt_metadata=stt_metadata,
|
||||
stt_stream=stt_stream,
|
||||
wake_word_phrase=wake_word_phrase,
|
||||
tts_input=tts_input,
|
||||
run=PipelineRun(
|
||||
hass,
|
||||
context=context,
|
||||
|
|
|
@ -22,3 +22,5 @@ SAMPLE_CHANNELS = 1 # mono
|
|||
MS_PER_CHUNK = 10
|
||||
SAMPLES_PER_CHUNK = SAMPLE_RATE // (1000 // MS_PER_CHUNK) # 10 ms @ 16Khz
|
||||
BYTES_PER_CHUNK = SAMPLES_PER_CHUNK * SAMPLE_WIDTH * SAMPLE_CHANNELS # 16-bit
|
||||
|
||||
OPTION_PREFERRED = "preferred"
|
||||
|
|
|
@ -504,7 +504,7 @@ class AudioSettings:
|
|||
is_vad_enabled: bool = True
|
||||
"""True if VAD is used to determine the end of the voice command."""
|
||||
|
||||
silence_seconds: float = 0.5
|
||||
silence_seconds: float = 0.7
|
||||
"""Seconds of silence after voice command has ended."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
|
@ -906,6 +906,8 @@ class PipelineRun:
|
|||
metadata,
|
||||
self._speech_to_text_stream(audio_stream=stream, stt_vad=stt_vad),
|
||||
)
|
||||
except (asyncio.CancelledError, TimeoutError):
|
||||
raise # expected
|
||||
except Exception as src_error:
|
||||
_LOGGER.exception("Unexpected error during speech-to-text")
|
||||
raise SpeechToTextError(
|
||||
|
|
|
@ -9,12 +9,10 @@ from homeassistant.const import EntityCategory, Platform
|
|||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import collection, entity_registry as er, restore_state
|
||||
|
||||
from .const import DOMAIN
|
||||
from .const import DOMAIN, OPTION_PREFERRED
|
||||
from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection
|
||||
from .vad import VadSensitivity
|
||||
|
||||
OPTION_PREFERRED = "preferred"
|
||||
|
||||
|
||||
@callback
|
||||
def get_chosen_pipeline(
|
||||
|
|
104
homeassistant/components/assist_satellite/__init__.py
Normal file
104
homeassistant/components/assist_satellite/__init__.py
Normal file
|
@ -0,0 +1,104 @@
|
|||
"""Base class for assist satellite entities."""
|
||||
|
||||
import logging
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant, SupportsResponse
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .const import DOMAIN
|
||||
from .entity import AssistSatelliteEntity
|
||||
from .models import (
|
||||
AssistSatelliteEntityFeature,
|
||||
AssistSatelliteState,
|
||||
PipelineRunConfig,
|
||||
PipelineRunResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DOMAIN",
|
||||
"AssistSatelliteEntity",
|
||||
"AssistSatelliteEntityFeature",
|
||||
"AssistSatelliteState",
|
||||
"PipelineRunConfig",
|
||||
"PipelineRunResult",
|
||||
"SERVICE_WAIT_WAKE",
|
||||
"SERVICE_GET_COMMAND",
|
||||
"SERVICE_SAY_TEXT",
|
||||
"ATTR_WAKE_WORDS",
|
||||
"ATTR_PROCESS",
|
||||
"ATTR_ANNOUNCE_TEXT",
|
||||
]
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
PLATFORM_SCHEMA_BASE = cv.PLATFORM_SCHEMA_BASE
|
||||
|
||||
ATTR_WAKE_WORDS = "wake_words"
|
||||
ATTR_PROCESS = "process"
|
||||
ATTR_ANNOUNCE_TEXT = "announce_text"
|
||||
|
||||
SERVICE_WAIT_WAKE = "wait_wake"
|
||||
SERVICE_GET_COMMAND = "get_command"
|
||||
SERVICE_SAY_TEXT = "say_text"
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
component = hass.data[DOMAIN] = EntityComponent[AssistSatelliteEntity](
|
||||
_LOGGER, DOMAIN, hass
|
||||
)
|
||||
await component.async_setup(config)
|
||||
|
||||
component.async_register_entity_service(
|
||||
name=SERVICE_WAIT_WAKE,
|
||||
schema=cv.make_entity_service_schema(
|
||||
{
|
||||
vol.Required(ATTR_WAKE_WORDS): [cv.string],
|
||||
vol.Optional(ATTR_ANNOUNCE_TEXT): cv.string,
|
||||
}
|
||||
),
|
||||
func="async_wait_wake",
|
||||
required_features=[AssistSatelliteEntityFeature.TRIGGER_PIPELINE],
|
||||
supports_response=SupportsResponse.OPTIONAL,
|
||||
)
|
||||
|
||||
component.async_register_entity_service(
|
||||
name=SERVICE_GET_COMMAND,
|
||||
schema=cv.make_entity_service_schema(
|
||||
{
|
||||
vol.Optional(ATTR_PROCESS): cv.boolean,
|
||||
vol.Optional(ATTR_ANNOUNCE_TEXT): cv.string,
|
||||
}
|
||||
),
|
||||
func="async_get_command",
|
||||
required_features=[AssistSatelliteEntityFeature.TRIGGER_PIPELINE],
|
||||
supports_response=SupportsResponse.OPTIONAL,
|
||||
)
|
||||
|
||||
component.async_register_entity_service(
|
||||
name=SERVICE_SAY_TEXT,
|
||||
schema=cv.make_entity_service_schema(
|
||||
{vol.Required(ATTR_ANNOUNCE_TEXT): cv.string}
|
||||
),
|
||||
func="async_say_text",
|
||||
required_features=[AssistSatelliteEntityFeature.TRIGGER_PIPELINE],
|
||||
supports_response=SupportsResponse.NONE,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up a config entry."""
|
||||
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||
return await component.async_setup_entry(entry)
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||
return await component.async_unload_entry(entry)
|
3
homeassistant/components/assist_satellite/const.py
Normal file
3
homeassistant/components/assist_satellite/const.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
"""Constants for assist satellite."""
|
||||
|
||||
DOMAIN = "assist_satellite"
|
231
homeassistant/components/assist_satellite/entity.py
Normal file
231
homeassistant/components/assist_satellite/entity.py
Normal file
|
@ -0,0 +1,231 @@
|
|||
"""Assist satellite entity."""
|
||||
|
||||
from collections.abc import AsyncIterable
|
||||
import time
|
||||
from typing import Final
|
||||
|
||||
from homeassistant.components import stt
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
OPTION_PREFERRED,
|
||||
AudioSettings,
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineStage,
|
||||
async_get_pipelines,
|
||||
async_pipeline_from_audio_stream,
|
||||
vad,
|
||||
)
|
||||
from homeassistant.const import EntityCategory
|
||||
from homeassistant.core import Context
|
||||
from homeassistant.helpers import entity
|
||||
from homeassistant.helpers.entity import EntityDescription
|
||||
from homeassistant.util import ulid
|
||||
|
||||
from .models import (
|
||||
AssistSatelliteEntityFeature,
|
||||
AssistSatelliteState,
|
||||
PipelineRunConfig,
|
||||
PipelineRunResult,
|
||||
)
|
||||
|
||||
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
|
||||
|
||||
|
||||
class AssistSatelliteEntity(entity.Entity):
|
||||
"""Entity encapsulating the state and functionality of an Assist satellite."""
|
||||
|
||||
entity_description = EntityDescription(
|
||||
key="assist_satellite",
|
||||
translation_key="assist_satellite",
|
||||
entity_category=EntityCategory.CONFIG,
|
||||
)
|
||||
_attr_has_entity_name = True
|
||||
_attr_name = None
|
||||
_attr_should_poll = False
|
||||
_attr_state: AssistSatelliteState | None = AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
_attr_supported_features = AssistSatelliteEntityFeature(0)
|
||||
|
||||
_conversation_id: str | None = None
|
||||
_conversation_id_time: float | None = None
|
||||
|
||||
_run_has_tts: bool = False
|
||||
|
||||
async def async_trigger_pipeline_on_satellite(
|
||||
self,
|
||||
start_stage: PipelineStage,
|
||||
end_stage: PipelineStage,
|
||||
run_config: PipelineRunConfig,
|
||||
) -> PipelineRunResult | None:
|
||||
"""Run a pipeline on the satellite from start to end stage.
|
||||
|
||||
Can be called from a service.
|
||||
Requires TRIGGER_PIPELINE supported feature.
|
||||
|
||||
- announce when start/end = "tts"
|
||||
- listen for wake word when start/end = "wake"
|
||||
- listen for command when start/end = "stt" (no processing)
|
||||
- listen for command when start = "stt", end = "tts" (with processing)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_wait_wake(
|
||||
self, wake_words: list[str], announce_text: str | None = None
|
||||
) -> str | None:
|
||||
"""Listen for one or more wake words on the satellite.
|
||||
|
||||
Returns the detected wake word phrase or None.
|
||||
"""
|
||||
if announce_text:
|
||||
await self.async_say_text(announce_text)
|
||||
|
||||
result = await self.async_trigger_pipeline_on_satellite(
|
||||
PipelineStage.WAKE_WORD,
|
||||
PipelineStage.WAKE_WORD,
|
||||
PipelineRunConfig(wake_word_names=wake_words),
|
||||
)
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
return result.detected_wake_word
|
||||
|
||||
async def async_get_command(
|
||||
self, process: bool = False, announce_text: str | None = None
|
||||
) -> str | None:
|
||||
"""Get the text of a voice command from the satellite, optionally processing it.
|
||||
|
||||
Returns the spoken text or None.
|
||||
"""
|
||||
if announce_text:
|
||||
await self.async_say_text(announce_text)
|
||||
|
||||
if process:
|
||||
end_stage = PipelineStage.TTS
|
||||
else:
|
||||
end_stage = PipelineStage.STT
|
||||
|
||||
result = await self.async_trigger_pipeline_on_satellite(
|
||||
PipelineStage.STT, end_stage, PipelineRunConfig()
|
||||
)
|
||||
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
return result.command_text
|
||||
|
||||
async def async_say_text(self, announce_text: str) -> None:
|
||||
"""Speak the text on the satellite."""
|
||||
await self.async_trigger_pipeline_on_satellite(
|
||||
PipelineStage.TTS,
|
||||
PipelineStage.TTS,
|
||||
PipelineRunConfig(announce_text=announce_text),
|
||||
)
|
||||
|
||||
async def _async_accept_pipeline_from_satellite(
|
||||
self,
|
||||
audio_stream: AsyncIterable[bytes],
|
||||
start_stage: PipelineStage = PipelineStage.STT,
|
||||
end_stage: PipelineStage = PipelineStage.TTS,
|
||||
pipeline_entity_id: str | None = None,
|
||||
vad_sensitivity_entity_id: str | None = None,
|
||||
wake_word_phrase: str | None = None,
|
||||
tts_input: str | None = None,
|
||||
) -> None:
|
||||
"""Triggers an Assist pipeline in Home Assistant from a satellite."""
|
||||
pipeline_id: str | None = None
|
||||
vad_sensitivity = vad.VadSensitivity.DEFAULT
|
||||
|
||||
if pipeline_entity_id:
|
||||
# Resolve pipeline by name
|
||||
pipeline_entity_state = self.hass.states.get(pipeline_entity_id)
|
||||
if (pipeline_entity_state is not None) and (
|
||||
pipeline_entity_state.state != OPTION_PREFERRED
|
||||
):
|
||||
for pipeline in async_get_pipelines(self.hass):
|
||||
if pipeline.name == pipeline_entity_state.state:
|
||||
pipeline_id = pipeline.id
|
||||
break
|
||||
|
||||
if vad_sensitivity_entity_id:
|
||||
vad_sensitivity_state = self.hass.states.get(vad_sensitivity_entity_id)
|
||||
if vad_sensitivity_state is not None:
|
||||
vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state)
|
||||
|
||||
device_id: str | None = None
|
||||
if self.registry_entry is not None:
|
||||
device_id = self.registry_entry.device_id
|
||||
|
||||
# Refresh context if necessary
|
||||
if (
|
||||
(self._context is None)
|
||||
or (self._context_set is None)
|
||||
or ((time.time() - self._context_set) > entity.CONTEXT_RECENT_TIME_SECONDS)
|
||||
):
|
||||
self.async_set_context(Context())
|
||||
|
||||
assert self._context is not None
|
||||
|
||||
# Reset conversation id if necessary
|
||||
if (self._conversation_id_time is None) or (
|
||||
(time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
|
||||
):
|
||||
self._conversation_id = None
|
||||
|
||||
if self._conversation_id is None:
|
||||
self._conversation_id = ulid.ulid()
|
||||
|
||||
# Update timeout
|
||||
self._conversation_id_time = time.monotonic()
|
||||
|
||||
# Set entity state based on pipeline events
|
||||
self._run_has_tts = False
|
||||
|
||||
await async_pipeline_from_audio_stream(
|
||||
self.hass,
|
||||
context=self._context,
|
||||
event_callback=self.on_pipeline_event,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="", # set in async_pipeline_from_audio_stream
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=audio_stream,
|
||||
pipeline_id=pipeline_id,
|
||||
conversation_id=self._conversation_id,
|
||||
device_id=device_id,
|
||||
tts_audio_output="wav",
|
||||
wake_word_phrase=wake_word_phrase,
|
||||
tts_input=tts_input,
|
||||
audio_settings=AudioSettings(
|
||||
silence_seconds=vad.VadSensitivity.to_seconds(vad_sensitivity)
|
||||
),
|
||||
start_stage=start_stage,
|
||||
end_stage=end_stage,
|
||||
)
|
||||
|
||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Set state based on pipeline stage."""
|
||||
if event.type == PipelineEventType.WAKE_WORD_START:
|
||||
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||
elif event.type == PipelineEventType.STT_START:
|
||||
self._set_state(AssistSatelliteState.LISTENING_COMMAND)
|
||||
elif event.type == PipelineEventType.INTENT_START:
|
||||
self._set_state(AssistSatelliteState.PROCESSING)
|
||||
elif event.type == PipelineEventType.TTS_START:
|
||||
# Wait until tts_response_finished is called to return to waiting state
|
||||
self._run_has_tts = True
|
||||
self._set_state(AssistSatelliteState.RESPONDING)
|
||||
elif event.type == PipelineEventType.RUN_END:
|
||||
if not self._run_has_tts:
|
||||
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||
|
||||
def _set_state(self, state: AssistSatelliteState):
|
||||
"""Set the entity's state."""
|
||||
self._attr_state = state
|
||||
self.async_write_ha_state()
|
||||
|
||||
def tts_response_finished(self) -> None:
|
||||
"""Tell entity that the text-to-speech response has finished playing."""
|
||||
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
12
homeassistant/components/assist_satellite/icons.json
Normal file
12
homeassistant/components/assist_satellite/icons.json
Normal file
|
@ -0,0 +1,12 @@
|
|||
{
|
||||
"entity_component": {
|
||||
"_": {
|
||||
"default": "mdi:comment-processing-outline"
|
||||
}
|
||||
},
|
||||
"services": {
|
||||
"wait_wake": "mdi:microphone-message",
|
||||
"get_command": "mdi:comment-text-outline",
|
||||
"say_text": "mdi:speaker-message"
|
||||
}
|
||||
}
|
9
homeassistant/components/assist_satellite/manifest.json
Normal file
9
homeassistant/components/assist_satellite/manifest.json
Normal file
|
@ -0,0 +1,9 @@
|
|||
{
|
||||
"domain": "assist_satellite",
|
||||
"name": "Assist Satellite",
|
||||
"codeowners": ["@synesthesiam"],
|
||||
"config_flow": false,
|
||||
"dependencies": ["assist_pipeline", "stt"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
|
||||
"integration_type": "entity"
|
||||
}
|
49
homeassistant/components/assist_satellite/models.py
Normal file
49
homeassistant/components/assist_satellite/models.py
Normal file
|
@ -0,0 +1,49 @@
|
|||
"""Models for assist satellite."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import IntFlag, StrEnum
|
||||
|
||||
|
||||
class AssistSatelliteState(StrEnum):
|
||||
"""Valid states of an Assist satellite entity."""
|
||||
|
||||
LISTENING_WAKE_WORD = "listening_wake_word"
|
||||
"""Device is streaming audio for wake word detection to Home Assistant."""
|
||||
|
||||
LISTENING_COMMAND = "listening_command"
|
||||
"""Device is streaming audio with the voice command to Home Assistant."""
|
||||
|
||||
PROCESSING = "processing"
|
||||
"""Home Assistant is processing the voice command."""
|
||||
|
||||
RESPONDING = "responding"
|
||||
"""Device is speaking the response."""
|
||||
|
||||
|
||||
class AssistSatelliteEntityFeature(IntFlag):
|
||||
"""Supported features of Assist satellite entity."""
|
||||
|
||||
TRIGGER_PIPELINE = 1
|
||||
"""Device supports remote triggering of a pipeline."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PipelineRunConfig:
|
||||
"""Configuration for a satellite pipeline run."""
|
||||
|
||||
wake_word_names: list[str] | None = None
|
||||
"""Wake word names to listen for (start_stage = wake)."""
|
||||
|
||||
announce_text: str | None = None
|
||||
"""Text to announce using text-to-speech (start_stage = wake, stt, or tts)."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PipelineRunResult:
|
||||
"""Result of a pipeline run."""
|
||||
|
||||
detected_wake_word: str | None = None
|
||||
"""Name of detected wake word (None if timeout)."""
|
||||
|
||||
command_text: str | None = None
|
||||
"""Transcript of speech-to-text for voice command."""
|
46
homeassistant/components/assist_satellite/services.yaml
Normal file
46
homeassistant/components/assist_satellite/services.yaml
Normal file
|
@ -0,0 +1,46 @@
|
|||
wait_wake:
|
||||
target:
|
||||
entity:
|
||||
domain: assist_satellite
|
||||
supported_features:
|
||||
- assist_satellite.AssistSatelliteEntityFeature.TRIGGER_PIPELINE
|
||||
fields:
|
||||
wake_words:
|
||||
required: true
|
||||
example: "ok nabu"
|
||||
selector:
|
||||
text:
|
||||
multiple: true
|
||||
announce_text:
|
||||
required: false
|
||||
example: "Please say ok nabu."
|
||||
selector:
|
||||
text:
|
||||
get_command:
|
||||
target:
|
||||
entity:
|
||||
domain: assist_satellite
|
||||
supported_features:
|
||||
- assist_satellite.AssistSatelliteEntityFeature.TRIGGER_PIPELINE
|
||||
fields:
|
||||
process:
|
||||
required: false
|
||||
selector:
|
||||
boolean:
|
||||
announce_text:
|
||||
required: false
|
||||
example: "What would you like for dinner?"
|
||||
selector:
|
||||
text:
|
||||
say_text:
|
||||
target:
|
||||
entity:
|
||||
domain: assist_satellite
|
||||
supported_features:
|
||||
- assist_satellite.AssistSatelliteEntityFeature.TRIGGER_PIPELINE
|
||||
fields:
|
||||
announce_text:
|
||||
required: true
|
||||
example: "Dinner is ready!"
|
||||
selector:
|
||||
text:
|
54
homeassistant/components/assist_satellite/strings.json
Normal file
54
homeassistant/components/assist_satellite/strings.json
Normal file
|
@ -0,0 +1,54 @@
|
|||
{
|
||||
"entity": {
|
||||
"assist_satellite": {
|
||||
"assist_satellite": {
|
||||
"state": {
|
||||
"listening_wake_word": "Wake word",
|
||||
"listening_command": "Voice command",
|
||||
"responding": "Responding",
|
||||
"processing": "Processing"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"services": {
|
||||
"wait_wake": {
|
||||
"name": "Wait for wake words",
|
||||
"description": "Wait for one or more wake words to be spoken",
|
||||
"fields": {
|
||||
"wake_words": {
|
||||
"name": "Wake words",
|
||||
"description": "Names of wake words to wait for"
|
||||
},
|
||||
"announce_text": {
|
||||
"name": "Announce text",
|
||||
"description": "Text to speak before waiting for wake words"
|
||||
}
|
||||
}
|
||||
},
|
||||
"get_command": {
|
||||
"name": "Get voice command from satellite",
|
||||
"description": "Records and transcribes a command from a voice satellite",
|
||||
"fields": {
|
||||
"process": {
|
||||
"name": "Process command",
|
||||
"description": "Process the text of the command in Home Assistant"
|
||||
},
|
||||
"announce_text": {
|
||||
"name": "announce_text",
|
||||
"description": "Text to speak before recording command"
|
||||
}
|
||||
}
|
||||
},
|
||||
"say_text": {
|
||||
"name": "Say text",
|
||||
"description": "Speak text from a voice satellite",
|
||||
"fields": {
|
||||
"announce_text": {
|
||||
"name": "Announce text",
|
||||
"description": "Text to speak"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -20,6 +20,7 @@ from .devices import VoIPDevices
|
|||
from .voip import HassVoipDatagramProtocol
|
||||
|
||||
PLATFORMS = (
|
||||
Platform.ASSIST_SATELLITE,
|
||||
Platform.BINARY_SENSOR,
|
||||
Platform.SELECT,
|
||||
Platform.SWITCH,
|
||||
|
|
298
homeassistant/components/voip/assist_satellite.py
Normal file
298
homeassistant/components/voip/assist_satellite.py
Normal file
|
@ -0,0 +1,298 @@
|
|||
"""Assist satellite entity for VoIP integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from enum import IntFlag
|
||||
from functools import partial
|
||||
import io
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Final
|
||||
import wave
|
||||
|
||||
from voip_utils import RtpDatagramProtocol
|
||||
|
||||
from homeassistant.components import tts
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineNotFound,
|
||||
)
|
||||
from homeassistant.components.assist_satellite import AssistSatelliteEntity
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.util.async_ import queue_to_iterable
|
||||
|
||||
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
|
||||
from .devices import VoIPDevice
|
||||
from .entity import VoIPEntity
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import DomainData
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_PIPELINE_TIMEOUT_SEC: Final = 30
|
||||
|
||||
|
||||
class Tones(IntFlag):
|
||||
"""Feedback tones for specific events."""
|
||||
|
||||
LISTENING = 1
|
||||
PROCESSING = 2
|
||||
ERROR = 4
|
||||
|
||||
|
||||
_TONE_FILENAMES: dict[Tones, str] = {
|
||||
Tones.LISTENING: "tone.pcm",
|
||||
Tones.PROCESSING: "processing.pcm",
|
||||
Tones.ERROR: "error.pcm",
|
||||
}
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up VoIP Assist satellite entity."""
|
||||
domain_data: DomainData = hass.data[DOMAIN]
|
||||
|
||||
@callback
|
||||
def async_add_device(device: VoIPDevice) -> None:
|
||||
"""Add device."""
|
||||
async_add_entities([VoipAssistSatellite(hass, device, config_entry)])
|
||||
|
||||
domain_data.devices.async_add_new_device_listener(async_add_device)
|
||||
|
||||
entities: list[VoIPEntity] = [
|
||||
VoipAssistSatellite(hass, device, config_entry)
|
||||
for device in domain_data.devices
|
||||
]
|
||||
|
||||
async_add_entities(entities)
|
||||
|
||||
|
||||
class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol):
|
||||
"""Assist satellite for VoIP devices."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
voip_device: VoIPDevice,
|
||||
config_entry: ConfigEntry,
|
||||
tones=Tones.LISTENING | Tones.PROCESSING | Tones.ERROR,
|
||||
) -> None:
|
||||
"""Initialize an Assist satellite."""
|
||||
VoIPEntity.__init__(self, voip_device)
|
||||
AssistSatelliteEntity.__init__(self)
|
||||
RtpDatagramProtocol.__init__(self)
|
||||
|
||||
self.config_entry = config_entry
|
||||
|
||||
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
self._audio_chunk_timeout: float = 2.0
|
||||
self._pipeline_task: asyncio.Task | None = None
|
||||
self._pipeline_had_error: bool = False
|
||||
self._tts_done = asyncio.Event()
|
||||
self._tts_extra_timeout: float = 1.0
|
||||
self._tone_bytes: dict[Tones, bytes] = {}
|
||||
self._tones = tones
|
||||
self._processing_tone_done = asyncio.Event()
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Run when entity about to be added to hass."""
|
||||
self.voip_device.protocol = self
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Run when entity will be removed from hass."""
|
||||
if self.voip_device.protocol == self:
|
||||
self.voip_device.protocol = None
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# VoIP
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def on_chunk(self, audio_bytes: bytes) -> None:
|
||||
"""Handle raw audio chunk."""
|
||||
if self._pipeline_task is None:
|
||||
self._clear_audio_queue()
|
||||
|
||||
# Run pipeline until voice command finishes, then start over
|
||||
self._pipeline_task = self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self._run_pipeline(),
|
||||
"voip_pipeline_run",
|
||||
)
|
||||
|
||||
self._audio_queue.put_nowait(audio_bytes)
|
||||
|
||||
async def _run_pipeline(
|
||||
self,
|
||||
) -> None:
|
||||
"""Forward audio to pipeline STT and handle TTS."""
|
||||
self.async_set_context(Context(user_id=self.config_entry.data["user"]))
|
||||
self.voip_device.set_is_active(True)
|
||||
|
||||
# Play listening tone at the start of each cycle
|
||||
await self._play_tone(Tones.LISTENING, silence_before=0.2)
|
||||
|
||||
try:
|
||||
self._tts_done.clear()
|
||||
|
||||
# Run pipeline with a timeout
|
||||
_LOGGER.debug("Starting pipeline")
|
||||
async with asyncio.timeout(_PIPELINE_TIMEOUT_SEC):
|
||||
await self._async_accept_pipeline_from_satellite( # noqa: SLF001
|
||||
audio_stream=queue_to_iterable(
|
||||
self._audio_queue, timeout=self._audio_chunk_timeout
|
||||
),
|
||||
pipeline_entity_id=self.voip_device.get_pipeline_entity_id(
|
||||
self.hass
|
||||
),
|
||||
vad_sensitivity_entity_id=self.voip_device.get_vad_sensitivity_entity_id(
|
||||
self.hass
|
||||
),
|
||||
)
|
||||
|
||||
if self._pipeline_had_error:
|
||||
self._pipeline_had_error = False
|
||||
await self._play_tone(Tones.ERROR)
|
||||
else:
|
||||
# Block until TTS is done speaking.
|
||||
#
|
||||
# This is set in _send_tts and has a timeout that's based on the
|
||||
# length of the TTS audio.
|
||||
await self._tts_done.wait()
|
||||
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
except PipelineNotFound:
|
||||
_LOGGER.warning("Pipeline not found")
|
||||
except (asyncio.CancelledError, TimeoutError):
|
||||
# Expected after caller hangs up
|
||||
_LOGGER.debug("Pipeline cancelled or timed out")
|
||||
self.disconnect()
|
||||
self._clear_audio_queue()
|
||||
finally:
|
||||
self.voip_device.set_is_active(False)
|
||||
|
||||
# Allow pipeline to run again
|
||||
self._pipeline_task = None
|
||||
|
||||
def _clear_audio_queue(self) -> None:
|
||||
"""Ensure audio queue is empty."""
|
||||
while not self._audio_queue.empty():
|
||||
self._audio_queue.get_nowait()
|
||||
|
||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Set state based on pipeline stage."""
|
||||
super().on_pipeline_event(event)
|
||||
|
||||
if event.type == PipelineEventType.STT_END:
|
||||
if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
|
||||
self._processing_tone_done.clear()
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass, self._play_tone(Tones.PROCESSING), "voip_process_tone"
|
||||
)
|
||||
elif event.type == PipelineEventType.TTS_END:
|
||||
# Send TTS audio to caller over RTP
|
||||
if event.data and (tts_output := event.data["tts_output"]):
|
||||
media_id = tts_output["media_id"]
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self._send_tts(media_id),
|
||||
"voip_pipeline_tts",
|
||||
)
|
||||
else:
|
||||
# Empty TTS response
|
||||
self._tts_done.set()
|
||||
elif event.type == PipelineEventType.ERROR:
|
||||
# Play error tone instead of wait for TTS when pipeline is finished.
|
||||
self._pipeline_had_error = True
|
||||
|
||||
async def _send_tts(self, media_id: str) -> None:
|
||||
"""Send TTS audio to caller via RTP."""
|
||||
try:
|
||||
if self.transport is None:
|
||||
return # not connected
|
||||
|
||||
extension, data = await tts.async_get_media_source_audio(
|
||||
self.hass,
|
||||
media_id,
|
||||
)
|
||||
|
||||
if extension != "wav":
|
||||
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
|
||||
|
||||
if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
|
||||
# Don't overlap TTS and processing beep
|
||||
await self._processing_tone_done.wait()
|
||||
|
||||
with io.BytesIO(data) as wav_io:
|
||||
with wave.open(wav_io, "rb") as wav_file:
|
||||
sample_rate = wav_file.getframerate()
|
||||
sample_width = wav_file.getsampwidth()
|
||||
sample_channels = wav_file.getnchannels()
|
||||
|
||||
if (
|
||||
(sample_rate != RATE)
|
||||
or (sample_width != WIDTH)
|
||||
or (sample_channels != CHANNELS)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS},"
|
||||
f" got {sample_rate}/{sample_width}/{sample_channels}"
|
||||
)
|
||||
|
||||
audio_bytes = wav_file.readframes(wav_file.getnframes())
|
||||
|
||||
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
|
||||
|
||||
# Time out 1 second after TTS audio should be finished
|
||||
tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
|
||||
tts_seconds = tts_samples / RATE
|
||||
|
||||
async with asyncio.timeout(tts_seconds + self._tts_extra_timeout):
|
||||
# TTS audio is 16Khz 16-bit mono
|
||||
await self._async_send_audio(audio_bytes)
|
||||
except TimeoutError:
|
||||
_LOGGER.warning("TTS timeout")
|
||||
raise
|
||||
finally:
|
||||
# Signal pipeline to restart
|
||||
self._tts_done.set()
|
||||
|
||||
# Update satellite state
|
||||
self.tts_response_finished()
|
||||
|
||||
async def _async_send_audio(self, audio_bytes: bytes, **kwargs):
|
||||
"""Send audio in executor."""
|
||||
await self.hass.async_add_executor_job(
|
||||
partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS, **kwargs)
|
||||
)
|
||||
|
||||
async def _play_tone(self, tone: Tones, silence_before: float = 0.0) -> None:
|
||||
"""Play a tone as feedback to the user if it's enabled."""
|
||||
if (self._tones & tone) != tone:
|
||||
return # not enabled
|
||||
|
||||
if tone not in self._tone_bytes:
|
||||
# Do I/O in executor
|
||||
self._tone_bytes[tone] = await self.hass.async_add_executor_job(
|
||||
self._load_pcm,
|
||||
_TONE_FILENAMES[tone],
|
||||
)
|
||||
|
||||
await self._async_send_audio(
|
||||
self._tone_bytes[tone],
|
||||
silence_before=silence_before,
|
||||
)
|
||||
|
||||
if tone == Tones.PROCESSING:
|
||||
self._processing_tone_done.set()
|
||||
|
||||
def _load_pcm(self, file_name: str) -> bytes:
|
||||
"""Load raw audio (16Khz, 16-bit mono)."""
|
||||
return (Path(__file__).parent / file_name).read_bytes()
|
|
@ -51,10 +51,12 @@ class VoIPCallInProgress(VoIPEntity, BinarySensorEntity):
|
|||
"""Call when entity about to be added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
|
||||
self.async_on_remove(self._device.async_listen_update(self._is_active_changed))
|
||||
self.async_on_remove(
|
||||
self.voip_device.async_listen_update(self._is_active_changed)
|
||||
)
|
||||
|
||||
@callback
|
||||
def _is_active_changed(self, device: VoIPDevice) -> None:
|
||||
"""Call when active state changed."""
|
||||
self._attr_is_on = self._device.is_active
|
||||
self._attr_is_on = self.voip_device.is_active
|
||||
self.async_write_ha_state()
|
||||
|
|
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||
from collections.abc import Callable, Iterator
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from voip_utils import CallInfo
|
||||
from voip_utils import CallInfo, VoipDatagramProtocol
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
|
@ -22,6 +22,7 @@ class VoIPDevice:
|
|||
device_id: str
|
||||
is_active: bool = False
|
||||
update_listeners: list[Callable[[VoIPDevice], None]] = field(default_factory=list)
|
||||
protocol: VoipDatagramProtocol | None = None
|
||||
|
||||
@callback
|
||||
def set_is_active(self, active: bool) -> None:
|
||||
|
@ -56,6 +57,18 @@ class VoIPDevice:
|
|||
|
||||
return False
|
||||
|
||||
def get_pipeline_entity_id(self, hass: HomeAssistant) -> str | None:
|
||||
"""Return entity id for pipeline select."""
|
||||
ent_reg = er.async_get(hass)
|
||||
return ent_reg.async_get_entity_id("select", DOMAIN, f"{self.voip_id}-pipeline")
|
||||
|
||||
def get_vad_sensitivity_entity_id(self, hass: HomeAssistant) -> str | None:
|
||||
"""Return entity id for VAD sensitivity."""
|
||||
ent_reg = er.async_get(hass)
|
||||
return ent_reg.async_get_entity_id(
|
||||
"select", DOMAIN, f"{self.voip_id}-vad_sensitivity"
|
||||
)
|
||||
|
||||
|
||||
class VoIPDevices:
|
||||
"""Class to store devices."""
|
||||
|
|
|
@ -15,10 +15,10 @@ class VoIPEntity(entity.Entity):
|
|||
_attr_has_entity_name = True
|
||||
_attr_should_poll = False
|
||||
|
||||
def __init__(self, device: VoIPDevice) -> None:
|
||||
def __init__(self, voip_device: VoIPDevice) -> None:
|
||||
"""Initialize VoIP entity."""
|
||||
self._device = device
|
||||
self._attr_unique_id = f"{device.voip_id}-{self.entity_description.key}"
|
||||
self.voip_device = voip_device
|
||||
self._attr_unique_id = f"{voip_device.voip_id}-{self.entity_description.key}"
|
||||
self._attr_device_info = DeviceInfo(
|
||||
identifiers={(DOMAIN, device.voip_id)},
|
||||
identifiers={(DOMAIN, voip_device.voip_id)},
|
||||
)
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
"name": "Voice over IP",
|
||||
"codeowners": ["@balloob", "@synesthesiam"],
|
||||
"config_flow": true,
|
||||
"dependencies": ["assist_pipeline"],
|
||||
"dependencies": ["assist_pipeline", "assist_satellite"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/voip",
|
||||
"iot_class": "local_push",
|
||||
"quality_scale": "internal",
|
||||
|
|
|
@ -10,6 +10,16 @@
|
|||
}
|
||||
},
|
||||
"entity": {
|
||||
"assist_satellite": {
|
||||
"assist_satellite": {
|
||||
"state": {
|
||||
"listening_wake_word": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::listening_wake_word%]",
|
||||
"listening_command": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::listening_command%]",
|
||||
"responding": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::responding%]",
|
||||
"processing": "[%key:component::assist_satellite::entity::assist_satellite::assist_satellite::state::processing%]"
|
||||
}
|
||||
}
|
||||
},
|
||||
"binary_sensor": {
|
||||
"call_in_progress": {
|
||||
"name": "Call in progress"
|
||||
|
|
|
@ -3,15 +3,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from collections.abc import AsyncIterable, MutableSequence, Sequence
|
||||
from functools import partial
|
||||
import io
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
import wave
|
||||
|
||||
from voip_utils import (
|
||||
CallInfo,
|
||||
|
@ -21,33 +17,19 @@ from voip_utils import (
|
|||
VoipDatagramProtocol,
|
||||
)
|
||||
|
||||
from homeassistant.components import assist_pipeline, stt, tts
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
Pipeline,
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineNotFound,
|
||||
async_get_pipeline,
|
||||
async_pipeline_from_audio_stream,
|
||||
select as pipeline_select,
|
||||
)
|
||||
from homeassistant.components.assist_pipeline.audio_enhancer import (
|
||||
AudioEnhancer,
|
||||
MicroVadEnhancer,
|
||||
)
|
||||
from homeassistant.components.assist_pipeline.vad import (
|
||||
AudioBuffer,
|
||||
VadSensitivity,
|
||||
VoiceCommandSegmenter,
|
||||
)
|
||||
from homeassistant.const import __version__
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.util.ulid import ulid_now
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .devices import VoIPDevice, VoIPDevices
|
||||
from .devices import VoIPDevices
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -60,11 +42,8 @@ def make_protocol(
|
|||
) -> VoipDatagramProtocol:
|
||||
"""Plays a pre-recorded message if pipeline is misconfigured."""
|
||||
voip_device = devices.async_get_or_create(call_info)
|
||||
pipeline_id = pipeline_select.get_chosen_pipeline(
|
||||
hass,
|
||||
DOMAIN,
|
||||
voip_device.voip_id,
|
||||
)
|
||||
|
||||
pipeline_id = pipeline_select.get_chosen_pipeline(hass, DOMAIN, voip_device.voip_id)
|
||||
try:
|
||||
pipeline: Pipeline | None = async_get_pipeline(hass, pipeline_id)
|
||||
except PipelineNotFound:
|
||||
|
@ -83,22 +62,18 @@ def make_protocol(
|
|||
rtcp_state=rtcp_state,
|
||||
)
|
||||
|
||||
vad_sensitivity = pipeline_select.get_vad_sensitivity(
|
||||
hass,
|
||||
DOMAIN,
|
||||
voip_device.voip_id,
|
||||
)
|
||||
if (protocol := voip_device.protocol) is None:
|
||||
raise ValueError("VoIP satellite not found")
|
||||
|
||||
# Pipeline is properly configured
|
||||
return PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(user_id=devices.config_entry.data["user"]),
|
||||
opus_payload_type=call_info.opus_payload_type,
|
||||
silence_seconds=VadSensitivity.to_seconds(vad_sensitivity),
|
||||
rtcp_state=rtcp_state,
|
||||
)
|
||||
protocol._rtp_input.opus_payload_type = call_info.opus_payload_type # noqa: SLF001
|
||||
protocol._rtp_output.opus_payload_type = call_info.opus_payload_type # noqa: SLF001
|
||||
|
||||
protocol.rtcp_state = rtcp_state
|
||||
if protocol.rtcp_state is not None:
|
||||
# Automatically disconnect when BYE is received over RTCP
|
||||
protocol.rtcp_state.bye_callback = protocol.disconnect
|
||||
|
||||
return protocol
|
||||
|
||||
|
||||
class HassVoipDatagramProtocol(VoipDatagramProtocol):
|
||||
|
@ -143,372 +118,6 @@ class HassVoipDatagramProtocol(VoipDatagramProtocol):
|
|||
await self._closed_event.wait()
|
||||
|
||||
|
||||
class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||
"""Run a voice assistant pipeline in a loop for a VoIP call."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
language: str,
|
||||
voip_device: VoIPDevice,
|
||||
context: Context,
|
||||
opus_payload_type: int,
|
||||
pipeline_timeout: float = 30.0,
|
||||
audio_timeout: float = 2.0,
|
||||
buffered_chunks_before_speech: int = 100,
|
||||
listening_tone_enabled: bool = True,
|
||||
processing_tone_enabled: bool = True,
|
||||
error_tone_enabled: bool = True,
|
||||
tone_delay: float = 0.2,
|
||||
tts_extra_timeout: float = 1.0,
|
||||
silence_seconds: float = 1.0,
|
||||
rtcp_state: RtcpState | None = None,
|
||||
) -> None:
|
||||
"""Set up pipeline RTP server."""
|
||||
super().__init__(
|
||||
rate=RATE,
|
||||
width=WIDTH,
|
||||
channels=CHANNELS,
|
||||
opus_payload_type=opus_payload_type,
|
||||
rtcp_state=rtcp_state,
|
||||
)
|
||||
|
||||
self.hass = hass
|
||||
self.language = language
|
||||
self.voip_device = voip_device
|
||||
self.pipeline: Pipeline | None = None
|
||||
self.pipeline_timeout = pipeline_timeout
|
||||
self.audio_timeout = audio_timeout
|
||||
self.buffered_chunks_before_speech = buffered_chunks_before_speech
|
||||
self.listening_tone_enabled = listening_tone_enabled
|
||||
self.processing_tone_enabled = processing_tone_enabled
|
||||
self.error_tone_enabled = error_tone_enabled
|
||||
self.tone_delay = tone_delay
|
||||
self.tts_extra_timeout = tts_extra_timeout
|
||||
self.silence_seconds = silence_seconds
|
||||
|
||||
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
self._context = context
|
||||
self._conversation_id: str | None = None
|
||||
self._pipeline_task: asyncio.Task | None = None
|
||||
self._tts_done = asyncio.Event()
|
||||
self._session_id: str | None = None
|
||||
self._tone_bytes: bytes | None = None
|
||||
self._processing_bytes: bytes | None = None
|
||||
self._error_bytes: bytes | None = None
|
||||
self._pipeline_error: bool = False
|
||||
|
||||
def connection_made(self, transport):
|
||||
"""Server is ready."""
|
||||
super().connection_made(transport)
|
||||
self.voip_device.set_is_active(True)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
"""Handle connection is lost or closed."""
|
||||
super().connection_lost(exc)
|
||||
self.voip_device.set_is_active(False)
|
||||
|
||||
def on_chunk(self, audio_bytes: bytes) -> None:
|
||||
"""Handle raw audio chunk."""
|
||||
if self._pipeline_task is None:
|
||||
self._clear_audio_queue()
|
||||
|
||||
# Run pipeline until voice command finishes, then start over
|
||||
self._pipeline_task = self.hass.async_create_background_task(
|
||||
self._run_pipeline(),
|
||||
"voip_pipeline_run",
|
||||
)
|
||||
|
||||
self._audio_queue.put_nowait(audio_bytes)
|
||||
|
||||
async def _run_pipeline(
|
||||
self,
|
||||
) -> None:
|
||||
"""Forward audio to pipeline STT and handle TTS."""
|
||||
if self._session_id is None:
|
||||
self._session_id = ulid_now()
|
||||
|
||||
# Play listening tone at the start of each cycle
|
||||
if self.listening_tone_enabled:
|
||||
await self._play_listening_tone()
|
||||
|
||||
try:
|
||||
# Wait for speech before starting pipeline
|
||||
segmenter = VoiceCommandSegmenter(silence_seconds=self.silence_seconds)
|
||||
audio_enhancer = MicroVadEnhancer(0, 0, True)
|
||||
chunk_buffer: deque[bytes] = deque(
|
||||
maxlen=self.buffered_chunks_before_speech,
|
||||
)
|
||||
speech_detected = await self._wait_for_speech(
|
||||
segmenter,
|
||||
audio_enhancer,
|
||||
chunk_buffer,
|
||||
)
|
||||
if not speech_detected:
|
||||
_LOGGER.debug("No speech detected")
|
||||
return
|
||||
|
||||
_LOGGER.debug("Starting pipeline")
|
||||
self._tts_done.clear()
|
||||
|
||||
async def stt_stream():
|
||||
try:
|
||||
async for chunk in self._segment_audio(
|
||||
segmenter,
|
||||
audio_enhancer,
|
||||
chunk_buffer,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
if self.processing_tone_enabled:
|
||||
await self._play_processing_tone()
|
||||
except TimeoutError:
|
||||
# Expected after caller hangs up
|
||||
_LOGGER.debug("Audio timeout")
|
||||
self._session_id = None
|
||||
self.disconnect()
|
||||
finally:
|
||||
self._clear_audio_queue()
|
||||
|
||||
# Run pipeline with a timeout
|
||||
async with asyncio.timeout(self.pipeline_timeout):
|
||||
await async_pipeline_from_audio_stream(
|
||||
self.hass,
|
||||
context=self._context,
|
||||
event_callback=self._event_callback,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="", # set in async_pipeline_from_audio_stream
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=stt_stream(),
|
||||
pipeline_id=pipeline_select.get_chosen_pipeline(
|
||||
self.hass, DOMAIN, self.voip_device.voip_id
|
||||
),
|
||||
conversation_id=self._conversation_id,
|
||||
device_id=self.voip_device.device_id,
|
||||
tts_audio_output="wav",
|
||||
)
|
||||
|
||||
if self._pipeline_error:
|
||||
self._pipeline_error = False
|
||||
if self.error_tone_enabled:
|
||||
await self._play_error_tone()
|
||||
else:
|
||||
# Block until TTS is done speaking.
|
||||
#
|
||||
# This is set in _send_tts and has a timeout that's based on the
|
||||
# length of the TTS audio.
|
||||
await self._tts_done.wait()
|
||||
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
except PipelineNotFound:
|
||||
_LOGGER.warning("Pipeline not found")
|
||||
except TimeoutError:
|
||||
# Expected after caller hangs up
|
||||
_LOGGER.debug("Pipeline timeout")
|
||||
self._session_id = None
|
||||
self.disconnect()
|
||||
finally:
|
||||
# Allow pipeline to run again
|
||||
self._pipeline_task = None
|
||||
|
||||
async def _wait_for_speech(
|
||||
self,
|
||||
segmenter: VoiceCommandSegmenter,
|
||||
audio_enhancer: AudioEnhancer,
|
||||
chunk_buffer: MutableSequence[bytes],
|
||||
):
|
||||
"""Buffer audio chunks until speech is detected.
|
||||
|
||||
Returns True if speech was detected, False otherwise.
|
||||
"""
|
||||
# Timeout if no audio comes in for a while.
|
||||
# This means the caller hung up.
|
||||
async with asyncio.timeout(self.audio_timeout):
|
||||
chunk = await self._audio_queue.get()
|
||||
|
||||
vad_buffer = AudioBuffer(assist_pipeline.SAMPLES_PER_CHUNK * WIDTH)
|
||||
|
||||
while chunk:
|
||||
chunk_buffer.append(chunk)
|
||||
|
||||
segmenter.process_with_vad(
|
||||
chunk,
|
||||
assist_pipeline.SAMPLES_PER_CHUNK,
|
||||
lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
|
||||
vad_buffer,
|
||||
)
|
||||
if segmenter.in_command:
|
||||
# Buffer until command starts
|
||||
if len(vad_buffer) > 0:
|
||||
chunk_buffer.append(vad_buffer.bytes())
|
||||
|
||||
return True
|
||||
|
||||
async with asyncio.timeout(self.audio_timeout):
|
||||
chunk = await self._audio_queue.get()
|
||||
|
||||
return False
|
||||
|
||||
async def _segment_audio(
|
||||
self,
|
||||
segmenter: VoiceCommandSegmenter,
|
||||
audio_enhancer: AudioEnhancer,
|
||||
chunk_buffer: Sequence[bytes],
|
||||
) -> AsyncIterable[bytes]:
|
||||
"""Yield audio chunks until voice command has finished."""
|
||||
# Buffered chunks first
|
||||
for buffered_chunk in chunk_buffer:
|
||||
yield buffered_chunk
|
||||
|
||||
# Timeout if no audio comes in for a while.
|
||||
# This means the caller hung up.
|
||||
async with asyncio.timeout(self.audio_timeout):
|
||||
chunk = await self._audio_queue.get()
|
||||
|
||||
vad_buffer = AudioBuffer(assist_pipeline.SAMPLES_PER_CHUNK * WIDTH)
|
||||
|
||||
while chunk:
|
||||
if not segmenter.process_with_vad(
|
||||
chunk,
|
||||
assist_pipeline.SAMPLES_PER_CHUNK,
|
||||
lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
|
||||
vad_buffer,
|
||||
):
|
||||
# Voice command is finished
|
||||
break
|
||||
|
||||
yield chunk
|
||||
|
||||
async with asyncio.timeout(self.audio_timeout):
|
||||
chunk = await self._audio_queue.get()
|
||||
|
||||
def _clear_audio_queue(self) -> None:
|
||||
while not self._audio_queue.empty():
|
||||
self._audio_queue.get_nowait()
|
||||
|
||||
def _event_callback(self, event: PipelineEvent):
|
||||
if not event.data:
|
||||
return
|
||||
|
||||
if event.type == PipelineEventType.INTENT_END:
|
||||
# Capture conversation id
|
||||
self._conversation_id = event.data["intent_output"]["conversation_id"]
|
||||
elif event.type == PipelineEventType.TTS_END:
|
||||
# Send TTS audio to caller over RTP
|
||||
tts_output = event.data["tts_output"]
|
||||
if tts_output:
|
||||
media_id = tts_output["media_id"]
|
||||
self.hass.async_create_background_task(
|
||||
self._send_tts(media_id),
|
||||
"voip_pipeline_tts",
|
||||
)
|
||||
else:
|
||||
# Empty TTS response
|
||||
self._tts_done.set()
|
||||
elif event.type == PipelineEventType.ERROR:
|
||||
# Play error tone instead of wait for TTS
|
||||
self._pipeline_error = True
|
||||
|
||||
async def _send_tts(self, media_id: str) -> None:
|
||||
"""Send TTS audio to caller via RTP."""
|
||||
try:
|
||||
if self.transport is None:
|
||||
return
|
||||
|
||||
extension, data = await tts.async_get_media_source_audio(
|
||||
self.hass,
|
||||
media_id,
|
||||
)
|
||||
|
||||
if extension != "wav":
|
||||
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
|
||||
|
||||
with io.BytesIO(data) as wav_io:
|
||||
with wave.open(wav_io, "rb") as wav_file:
|
||||
sample_rate = wav_file.getframerate()
|
||||
sample_width = wav_file.getsampwidth()
|
||||
sample_channels = wav_file.getnchannels()
|
||||
|
||||
if (
|
||||
(sample_rate != RATE)
|
||||
or (sample_width != WIDTH)
|
||||
or (sample_channels != CHANNELS)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS},"
|
||||
f" got {sample_rate}/{sample_width}/{sample_channels}"
|
||||
)
|
||||
|
||||
audio_bytes = wav_file.readframes(wav_file.getnframes())
|
||||
|
||||
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
|
||||
|
||||
# Time out 1 second after TTS audio should be finished
|
||||
tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
|
||||
tts_seconds = tts_samples / RATE
|
||||
|
||||
async with asyncio.timeout(tts_seconds + self.tts_extra_timeout):
|
||||
# TTS audio is 16Khz 16-bit mono
|
||||
await self._async_send_audio(audio_bytes)
|
||||
except TimeoutError:
|
||||
_LOGGER.warning("TTS timeout")
|
||||
raise
|
||||
finally:
|
||||
# Signal pipeline to restart
|
||||
self._tts_done.set()
|
||||
|
||||
async def _async_send_audio(self, audio_bytes: bytes, **kwargs):
|
||||
"""Send audio in executor."""
|
||||
await self.hass.async_add_executor_job(
|
||||
partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS, **kwargs)
|
||||
)
|
||||
|
||||
async def _play_listening_tone(self) -> None:
|
||||
"""Play a tone to indicate that Home Assistant is listening."""
|
||||
if self._tone_bytes is None:
|
||||
# Do I/O in executor
|
||||
self._tone_bytes = await self.hass.async_add_executor_job(
|
||||
self._load_pcm,
|
||||
"tone.pcm",
|
||||
)
|
||||
|
||||
await self._async_send_audio(
|
||||
self._tone_bytes,
|
||||
silence_before=self.tone_delay,
|
||||
)
|
||||
|
||||
async def _play_processing_tone(self) -> None:
|
||||
"""Play a tone to indicate that Home Assistant is processing the voice command."""
|
||||
if self._processing_bytes is None:
|
||||
# Do I/O in executor
|
||||
self._processing_bytes = await self.hass.async_add_executor_job(
|
||||
self._load_pcm,
|
||||
"processing.pcm",
|
||||
)
|
||||
|
||||
await self._async_send_audio(self._processing_bytes)
|
||||
|
||||
async def _play_error_tone(self) -> None:
|
||||
"""Play a tone to indicate a pipeline error occurred."""
|
||||
if self._error_bytes is None:
|
||||
# Do I/O in executor
|
||||
self._error_bytes = await self.hass.async_add_executor_job(
|
||||
self._load_pcm,
|
||||
"error.pcm",
|
||||
)
|
||||
|
||||
await self._async_send_audio(self._error_bytes)
|
||||
|
||||
def _load_pcm(self, file_name: str) -> bytes:
|
||||
"""Load raw audio (16Khz, 16-bit mono)."""
|
||||
return (Path(__file__).parent / file_name).read_bytes()
|
||||
|
||||
|
||||
class PreRecordMessageProtocol(RtpDatagramProtocol):
|
||||
"""Plays a pre-recorded message on a loop."""
|
||||
|
||||
|
|
|
@ -14,11 +14,11 @@ from .const import ATTR_SPEAKER, DOMAIN
|
|||
from .data import WyomingService
|
||||
from .devices import SatelliteDevice
|
||||
from .models import DomainDataItem
|
||||
from .satellite import WyomingSatellite
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
SATELLITE_PLATFORMS = [
|
||||
Platform.ASSIST_SATELLITE,
|
||||
Platform.BINARY_SENSOR,
|
||||
Platform.SELECT,
|
||||
Platform.SWITCH,
|
||||
|
@ -47,51 +47,25 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
entry.async_on_unload(entry.add_update_listener(update_listener))
|
||||
|
||||
if (satellite_info := service.info.satellite) is not None:
|
||||
# Create satellite device, etc.
|
||||
item.satellite = _make_satellite(hass, entry, service)
|
||||
# Create satellite device
|
||||
dev_reg = dr.async_get(hass)
|
||||
|
||||
# Set up satellite sensors, switches, etc.
|
||||
# Use config entry id since only one satellite per entry is supported
|
||||
satellite_id = entry.entry_id
|
||||
device = dev_reg.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
identifiers={(DOMAIN, satellite_id)},
|
||||
name=satellite_info.name,
|
||||
suggested_area=satellite_info.area,
|
||||
)
|
||||
item.satellite_device = SatelliteDevice(satellite_id, device.id)
|
||||
|
||||
# Set up satellite entity, sensors, switches, etc.
|
||||
await hass.config_entries.async_forward_entry_setups(entry, SATELLITE_PLATFORMS)
|
||||
|
||||
# Start satellite communication
|
||||
entry.async_create_background_task(
|
||||
hass,
|
||||
item.satellite.run(),
|
||||
f"Satellite {satellite_info.name}",
|
||||
)
|
||||
|
||||
entry.async_on_unload(item.satellite.stop)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _make_satellite(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService
|
||||
) -> WyomingSatellite:
|
||||
"""Create Wyoming satellite/device from config entry and Wyoming service."""
|
||||
satellite_info = service.info.satellite
|
||||
assert satellite_info is not None
|
||||
|
||||
dev_reg = dr.async_get(hass)
|
||||
|
||||
# Use config entry id since only one satellite per entry is supported
|
||||
satellite_id = config_entry.entry_id
|
||||
|
||||
device = dev_reg.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
identifiers={(DOMAIN, satellite_id)},
|
||||
name=satellite_info.name,
|
||||
suggested_area=satellite_info.area,
|
||||
)
|
||||
|
||||
satellite_device = SatelliteDevice(
|
||||
satellite_id=satellite_id,
|
||||
device_id=device.id,
|
||||
)
|
||||
|
||||
return WyomingSatellite(hass, config_entry, service, satellite_device)
|
||||
|
||||
|
||||
async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
|
||||
"""Handle options update."""
|
||||
await hass.config_entries.async_reload(entry.entry_id)
|
||||
|
@ -102,7 +76,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
item: DomainDataItem = hass.data[DOMAIN][entry.entry_id]
|
||||
|
||||
platforms = list(item.service.platforms)
|
||||
if item.satellite is not None:
|
||||
if item.satellite_device is not None:
|
||||
platforms += SATELLITE_PLATFORMS
|
||||
|
||||
unload_ok = await hass.config_entries.async_unload_platforms(entry, platforms)
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
"""Support for Wyoming satellite services."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections import defaultdict, deque
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
from typing import Final
|
||||
from uuid import uuid4
|
||||
import wave
|
||||
|
||||
from wyoming.asr import Transcribe, Transcript
|
||||
|
@ -18,20 +17,23 @@ from wyoming.info import Describe, Info
|
|||
from wyoming.ping import Ping, Pong
|
||||
from wyoming.pipeline import PipelineStage, RunPipeline
|
||||
from wyoming.satellite import PauseSatellite, RunSatellite
|
||||
from wyoming.snd import Played
|
||||
from wyoming.timer import TimerCancelled, TimerFinished, TimerStarted, TimerUpdated
|
||||
from wyoming.tts import Synthesize, SynthesizeVoice
|
||||
from wyoming.vad import VoiceStarted, VoiceStopped
|
||||
from wyoming.wake import Detect, Detection
|
||||
|
||||
from homeassistant.components import assist_pipeline, intent, stt, tts
|
||||
from homeassistant.components.assist_pipeline import select as pipeline_select
|
||||
from homeassistant.components.assist_pipeline.vad import VadSensitivity
|
||||
from homeassistant.components import assist_pipeline, assist_satellite, intent, tts
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.util.async_ import queue_to_iterable
|
||||
|
||||
from .const import DOMAIN
|
||||
from .data import WyomingService
|
||||
from .devices import SatelliteDevice
|
||||
from .entity import WyomingEntity
|
||||
from .models import DomainDataItem
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -41,19 +43,47 @@ _RESTART_SECONDS: Final = 3
|
|||
_PING_TIMEOUT: Final = 5
|
||||
_PING_SEND_DELAY: Final = 2
|
||||
_PIPELINE_FINISH_TIMEOUT: Final = 1
|
||||
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
|
||||
_STOP_CHUNK: Final = b""
|
||||
|
||||
# Wyoming stage -> Assist stage
|
||||
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
|
||||
_ASSIST_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
|
||||
PipelineStage.WAKE: assist_pipeline.PipelineStage.WAKE_WORD,
|
||||
PipelineStage.ASR: assist_pipeline.PipelineStage.STT,
|
||||
PipelineStage.HANDLE: assist_pipeline.PipelineStage.INTENT,
|
||||
PipelineStage.TTS: assist_pipeline.PipelineStage.TTS,
|
||||
}
|
||||
_WYOMING_STAGES: dict[assist_pipeline.PipelineStage, PipelineStage] = {
|
||||
assist_pipeline.PipelineStage.WAKE_WORD: PipelineStage.WAKE,
|
||||
assist_pipeline.PipelineStage.STT: PipelineStage.ASR,
|
||||
assist_pipeline.PipelineStage.INTENT: PipelineStage.HANDLE,
|
||||
assist_pipeline.PipelineStage.TTS: PipelineStage.TTS,
|
||||
}
|
||||
|
||||
|
||||
class WyomingSatellite:
|
||||
"""Remove voice satellite running the Wyoming protocol."""
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up VoIP Assist satellite entity."""
|
||||
domain_data: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
|
||||
assert domain_data.satellite_device is not None
|
||||
|
||||
async_add_entities(
|
||||
[
|
||||
WyomingSatellite(
|
||||
hass, config_entry, domain_data.service, domain_data.satellite_device
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class WyomingSatellite(WyomingEntity, assist_satellite.AssistSatelliteEntity):
|
||||
"""Remote voice satellite running the Wyoming protocol."""
|
||||
|
||||
_attr_supported_features = (
|
||||
assist_satellite.AssistSatelliteEntityFeature.TRIGGER_PIPELINE
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -63,6 +93,9 @@ class WyomingSatellite:
|
|||
device: SatelliteDevice,
|
||||
) -> None:
|
||||
"""Initialize satellite."""
|
||||
WyomingEntity.__init__(self, device)
|
||||
assist_satellite.AssistSatelliteEntity.__init__(self)
|
||||
|
||||
self.hass = hass
|
||||
self.config_entry = config_entry
|
||||
self.service = service
|
||||
|
@ -70,20 +103,171 @@ class WyomingSatellite:
|
|||
self.is_running = True
|
||||
|
||||
self._client: AsyncTcpClient | None = None
|
||||
self._chunk_converter = AudioChunkConverter(rate=16000, width=2, channels=1)
|
||||
self._chunk_converter = AudioChunkConverter(
|
||||
rate=assist_pipeline.SAMPLE_RATE,
|
||||
width=assist_pipeline.SAMPLE_WIDTH,
|
||||
channels=assist_pipeline.SAMPLE_CHANNELS,
|
||||
)
|
||||
self._is_pipeline_running = False
|
||||
self._pipeline_ended_event = asyncio.Event()
|
||||
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||
self._pipeline_id: str | None = None
|
||||
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
self._muted_changed_event = asyncio.Event()
|
||||
|
||||
self._conversation_id: str | None = None
|
||||
self._conversation_id_time: float | None = None
|
||||
# Results of remotely triggered pipelines
|
||||
self._pipeline_result_futures: dict[
|
||||
assist_pipeline.PipelineStage, deque[asyncio.Future]
|
||||
] = defaultdict(deque)
|
||||
self._played_timeout_id: int | None = None
|
||||
|
||||
self.device.set_is_muted_listener(self._muted_changed)
|
||||
self.device.set_pipeline_listener(self._pipeline_changed)
|
||||
self.device.set_audio_settings_listener(self._audio_settings_changed)
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Run when entity about to be added to hass."""
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass, self.run(), "wyoming_satellite_run"
|
||||
)
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Run when entity will be removed from hass."""
|
||||
self.stop()
|
||||
|
||||
async def async_trigger_pipeline_on_satellite(
|
||||
self,
|
||||
start_stage: assist_pipeline.PipelineStage,
|
||||
end_stage: assist_pipeline.PipelineStage,
|
||||
run_config: assist_satellite.PipelineRunConfig,
|
||||
) -> assist_satellite.PipelineRunResult | None:
|
||||
"""Run a pipeline on the satellite from start to end stage."""
|
||||
if self._client is None:
|
||||
return None # not connected
|
||||
|
||||
result_future: asyncio.Future[str | None] = asyncio.Future()
|
||||
self._pipeline_result_futures[start_stage].append(result_future)
|
||||
|
||||
await self._client.write_event(
|
||||
RunPipeline(
|
||||
start_stage=_WYOMING_STAGES[start_stage],
|
||||
end_stage=_WYOMING_STAGES[end_stage],
|
||||
wake_word_names=run_config.wake_word_names,
|
||||
announce_text=run_config.announce_text,
|
||||
).event()
|
||||
)
|
||||
|
||||
# Wait for result
|
||||
result = await result_future
|
||||
if start_stage == assist_pipeline.PipelineStage.WAKE_WORD:
|
||||
return assist_satellite.PipelineRunResult(detected_wake_word=result)
|
||||
|
||||
if start_stage == assist_pipeline.PipelineStage.STT:
|
||||
return assist_satellite.PipelineRunResult(command_text=result)
|
||||
|
||||
return None
|
||||
|
||||
def on_pipeline_event(self, event: assist_pipeline.PipelineEvent) -> None:
|
||||
"""Translate pipeline events into Wyoming events."""
|
||||
super().on_pipeline_event(event)
|
||||
|
||||
if self._client is None:
|
||||
return # stopping
|
||||
|
||||
if event.type == assist_pipeline.PipelineEventType.RUN_END:
|
||||
# Pipeline run is complete
|
||||
self._is_pipeline_running = False
|
||||
self._pipeline_ended_event.set()
|
||||
self.device.set_is_active(False)
|
||||
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
|
||||
self.hass.add_job(self._client.write_event(Detect().event()))
|
||||
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END:
|
||||
# Wake word detection
|
||||
# Inform client of wake word detection
|
||||
if event.data and (wake_word_output := event.data.get("wake_word_output")):
|
||||
detected_wake_word = wake_word_output["wake_word_id"]
|
||||
detection = Detection(
|
||||
name=detected_wake_word,
|
||||
timestamp=wake_word_output.get("timestamp"),
|
||||
)
|
||||
self.hass.add_job(self._client.write_event(detection.event()))
|
||||
|
||||
# Set result for remote pipeline trigger
|
||||
if result_futures := self._pipeline_result_futures[
|
||||
assist_pipeline.PipelineStage.WAKE_WORD
|
||||
]:
|
||||
result_futures.popleft().set_result(detected_wake_word)
|
||||
elif event.type == assist_pipeline.PipelineEventType.STT_START:
|
||||
# Speech-to-text
|
||||
self.device.set_is_active(True)
|
||||
|
||||
if event.data:
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
Transcribe(language=event.data["metadata"]["language"]).event()
|
||||
)
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START:
|
||||
# User started speaking
|
||||
if event.data:
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
VoiceStarted(timestamp=event.data["timestamp"]).event()
|
||||
)
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END:
|
||||
# User stopped speaking
|
||||
if event.data:
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
VoiceStopped(timestamp=event.data["timestamp"]).event()
|
||||
)
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.STT_END:
|
||||
# Speech-to-text transcript
|
||||
if event.data:
|
||||
# Inform client of transript
|
||||
stt_text = event.data["stt_output"]["text"]
|
||||
self.hass.add_job(
|
||||
self._client.write_event(Transcript(text=stt_text).event())
|
||||
)
|
||||
|
||||
# Set result for remote pipeline trigger
|
||||
if result_futures := self._pipeline_result_futures[
|
||||
assist_pipeline.PipelineStage.STT
|
||||
]:
|
||||
result_futures.popleft().set_result(stt_text)
|
||||
elif event.type == assist_pipeline.PipelineEventType.TTS_START:
|
||||
# Text-to-speech text
|
||||
if event.data:
|
||||
# Inform client of text
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
Synthesize(
|
||||
text=event.data["tts_input"],
|
||||
voice=SynthesizeVoice(
|
||||
name=event.data.get("voice"),
|
||||
language=event.data.get("language"),
|
||||
),
|
||||
).event()
|
||||
)
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.TTS_END:
|
||||
# TTS stream
|
||||
if event.data and (tts_output := event.data["tts_output"]):
|
||||
media_id = tts_output["media_id"]
|
||||
self.hass.add_job(self._stream_tts(media_id))
|
||||
elif event.type == assist_pipeline.PipelineEventType.ERROR:
|
||||
# Pipeline error
|
||||
if event.data:
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
Error(
|
||||
text=event.data["message"], code=event.data["code"]
|
||||
).event()
|
||||
)
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run and maintain a connection to satellite."""
|
||||
_LOGGER.debug("Running satellite task")
|
||||
|
@ -125,6 +309,9 @@ class WyomingSatellite:
|
|||
|
||||
def stop(self) -> None:
|
||||
"""Signal satellite task to stop running."""
|
||||
# Cancel any running pipeline
|
||||
self._audio_queue.put_nowait(_STOP_CHUNK)
|
||||
|
||||
# Tell satellite to stop running
|
||||
self._send_pause()
|
||||
|
||||
|
@ -173,7 +360,7 @@ class WyomingSatellite:
|
|||
"""Run when device muted status changes."""
|
||||
if self.device.is_muted:
|
||||
# Cancel any running pipeline
|
||||
self._audio_queue.put_nowait(None)
|
||||
self._audio_queue.put_nowait(_STOP_CHUNK)
|
||||
|
||||
# Send pause event so satellite can react immediately
|
||||
self._send_pause()
|
||||
|
@ -185,13 +372,13 @@ class WyomingSatellite:
|
|||
"""Run when device pipeline changes."""
|
||||
|
||||
# Cancel any running pipeline
|
||||
self._audio_queue.put_nowait(None)
|
||||
self._audio_queue.put_nowait(_STOP_CHUNK)
|
||||
|
||||
def _audio_settings_changed(self) -> None:
|
||||
"""Run when device audio settings."""
|
||||
|
||||
# Cancel any running pipeline
|
||||
self._audio_queue.put_nowait(None)
|
||||
self._audio_queue.put_nowait(_STOP_CHUNK)
|
||||
|
||||
async def _connect_and_loop(self) -> None:
|
||||
"""Connect to satellite and run pipelines until an error occurs."""
|
||||
|
@ -222,7 +409,9 @@ class WyomingSatellite:
|
|||
|
||||
async def _run_pipeline_loop(self) -> None:
|
||||
"""Run a pipeline one or more times."""
|
||||
assert self._client is not None
|
||||
if self._client is None:
|
||||
return # stopping
|
||||
|
||||
client_info: Info | None = None
|
||||
wake_word_phrase: str | None = None
|
||||
run_pipeline: RunPipeline | None = None
|
||||
|
@ -302,7 +491,7 @@ class WyomingSatellite:
|
|||
elif AudioStop.is_type(client_event.type) and self._is_pipeline_running:
|
||||
# Stop pipeline
|
||||
_LOGGER.debug("Client requested pipeline to stop")
|
||||
self._audio_queue.put_nowait(b"")
|
||||
self._audio_queue.put_nowait(_STOP_CHUNK)
|
||||
elif Info.is_type(client_event.type):
|
||||
client_info = Info.from_event(client_event)
|
||||
_LOGGER.debug("Updated client info: %s", client_info)
|
||||
|
@ -328,7 +517,19 @@ class WyomingSatellite:
|
|||
if found_phrase:
|
||||
break
|
||||
|
||||
if result_futures := self._pipeline_result_futures[
|
||||
assist_pipeline.PipelineStage.WAKE_WORD
|
||||
]:
|
||||
result_futures.popleft().set_result(wake_word_phrase)
|
||||
|
||||
_LOGGER.debug("Client detected wake word: %s", wake_word_phrase)
|
||||
elif Played.is_type(client_event.type):
|
||||
# Set result for remote pipeline trigger
|
||||
self._played_timeout_id = None
|
||||
if result_futures := self._pipeline_result_futures[
|
||||
assist_pipeline.PipelineStage.TTS
|
||||
]:
|
||||
result_futures.popleft().set_result(None)
|
||||
else:
|
||||
_LOGGER.debug("Unexpected event from satellite: %s", client_event)
|
||||
|
||||
|
@ -344,8 +545,8 @@ class WyomingSatellite:
|
|||
"""Run a pipeline once."""
|
||||
_LOGGER.debug("Received run information: %s", run_pipeline)
|
||||
|
||||
start_stage = _STAGES.get(run_pipeline.start_stage)
|
||||
end_stage = _STAGES.get(run_pipeline.end_stage)
|
||||
start_stage = _ASSIST_STAGES.get(run_pipeline.start_stage)
|
||||
end_stage = _ASSIST_STAGES.get(run_pipeline.end_stage)
|
||||
|
||||
if start_stage is None:
|
||||
raise ValueError(f"Invalid start stage: {start_stage}")
|
||||
|
@ -353,77 +554,32 @@ class WyomingSatellite:
|
|||
if end_stage is None:
|
||||
raise ValueError(f"Invalid end stage: {end_stage}")
|
||||
|
||||
pipeline_id = pipeline_select.get_chosen_pipeline(
|
||||
self.hass,
|
||||
DOMAIN,
|
||||
self.device.satellite_id,
|
||||
)
|
||||
pipeline = assist_pipeline.async_get_pipeline(self.hass, pipeline_id)
|
||||
assert pipeline is not None
|
||||
|
||||
# We will push audio in through a queue
|
||||
self._audio_queue = asyncio.Queue()
|
||||
stt_stream = self._stt_stream()
|
||||
|
||||
# Start pipeline running
|
||||
_LOGGER.debug(
|
||||
"Starting pipeline %s from %s to %s",
|
||||
pipeline.name,
|
||||
start_stage,
|
||||
end_stage,
|
||||
)
|
||||
|
||||
# Reset conversation id, if necessary
|
||||
if (self._conversation_id_time is None) or (
|
||||
(time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
|
||||
):
|
||||
self._conversation_id = None
|
||||
|
||||
if self._conversation_id is None:
|
||||
self._conversation_id = str(uuid4())
|
||||
|
||||
# Update timeout
|
||||
self._conversation_id_time = time.monotonic()
|
||||
|
||||
self._is_pipeline_running = True
|
||||
self._pipeline_ended_event.clear()
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
assist_pipeline.async_pipeline_from_audio_stream(
|
||||
self.hass,
|
||||
context=Context(),
|
||||
event_callback=self._event_callback,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language=pipeline.language,
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
self._async_accept_pipeline_from_satellite(
|
||||
queue_to_iterable(self._audio_queue),
|
||||
start_stage,
|
||||
end_stage,
|
||||
pipeline_entity_id=self.device.get_pipeline_entity_id(self.hass),
|
||||
vad_sensitivity_entity_id=self.device.get_vad_sensitivity_entity_id(
|
||||
self.hass
|
||||
),
|
||||
stt_stream=stt_stream,
|
||||
start_stage=start_stage,
|
||||
end_stage=end_stage,
|
||||
tts_audio_output="wav",
|
||||
pipeline_id=pipeline_id,
|
||||
audio_settings=assist_pipeline.AudioSettings(
|
||||
noise_suppression_level=self.device.noise_suppression_level,
|
||||
auto_gain_dbfs=self.device.auto_gain,
|
||||
volume_multiplier=self.device.volume_multiplier,
|
||||
silence_seconds=VadSensitivity.to_seconds(
|
||||
self.device.vad_sensitivity
|
||||
),
|
||||
),
|
||||
device_id=self.device.device_id,
|
||||
wake_word_phrase=wake_word_phrase,
|
||||
conversation_id=self._conversation_id,
|
||||
tts_input=run_pipeline.announce_text,
|
||||
),
|
||||
name="wyoming satellite pipeline",
|
||||
)
|
||||
|
||||
async def _send_delayed_ping(self) -> None:
|
||||
"""Send ping to satellite after a delay."""
|
||||
assert self._client is not None
|
||||
if self._client is None:
|
||||
return # stopping
|
||||
|
||||
try:
|
||||
await asyncio.sleep(_PING_SEND_DELAY)
|
||||
|
@ -431,91 +587,6 @@ class WyomingSatellite:
|
|||
except ConnectionError:
|
||||
pass # handled with timeout
|
||||
|
||||
def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None:
|
||||
"""Translate pipeline events into Wyoming events."""
|
||||
assert self._client is not None
|
||||
|
||||
if event.type == assist_pipeline.PipelineEventType.RUN_END:
|
||||
# Pipeline run is complete
|
||||
self._is_pipeline_running = False
|
||||
self._pipeline_ended_event.set()
|
||||
self.device.set_is_active(False)
|
||||
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
|
||||
self.hass.add_job(self._client.write_event(Detect().event()))
|
||||
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END:
|
||||
# Wake word detection
|
||||
# Inform client of wake word detection
|
||||
if event.data and (wake_word_output := event.data.get("wake_word_output")):
|
||||
detection = Detection(
|
||||
name=wake_word_output["wake_word_id"],
|
||||
timestamp=wake_word_output.get("timestamp"),
|
||||
)
|
||||
self.hass.add_job(self._client.write_event(detection.event()))
|
||||
elif event.type == assist_pipeline.PipelineEventType.STT_START:
|
||||
# Speech-to-text
|
||||
self.device.set_is_active(True)
|
||||
|
||||
if event.data:
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
Transcribe(language=event.data["metadata"]["language"]).event()
|
||||
)
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START:
|
||||
# User started speaking
|
||||
if event.data:
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
VoiceStarted(timestamp=event.data["timestamp"]).event()
|
||||
)
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END:
|
||||
# User stopped speaking
|
||||
if event.data:
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
VoiceStopped(timestamp=event.data["timestamp"]).event()
|
||||
)
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.STT_END:
|
||||
# Speech-to-text transcript
|
||||
if event.data:
|
||||
# Inform client of transript
|
||||
stt_text = event.data["stt_output"]["text"]
|
||||
self.hass.add_job(
|
||||
self._client.write_event(Transcript(text=stt_text).event())
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.TTS_START:
|
||||
# Text-to-speech text
|
||||
if event.data:
|
||||
# Inform client of text
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
Synthesize(
|
||||
text=event.data["tts_input"],
|
||||
voice=SynthesizeVoice(
|
||||
name=event.data.get("voice"),
|
||||
language=event.data.get("language"),
|
||||
),
|
||||
).event()
|
||||
)
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.TTS_END:
|
||||
# TTS stream
|
||||
if event.data and (tts_output := event.data["tts_output"]):
|
||||
media_id = tts_output["media_id"]
|
||||
self.hass.add_job(self._stream_tts(media_id))
|
||||
elif event.type == assist_pipeline.PipelineEventType.ERROR:
|
||||
# Pipeline error
|
||||
if event.data:
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
Error(
|
||||
text=event.data["message"], code=event.data["code"]
|
||||
).event()
|
||||
)
|
||||
)
|
||||
|
||||
async def _connect(self) -> None:
|
||||
"""Connect to satellite over TCP."""
|
||||
await self._disconnect()
|
||||
|
@ -537,62 +608,78 @@ class WyomingSatellite:
|
|||
|
||||
async def _stream_tts(self, media_id: str) -> None:
|
||||
"""Stream TTS WAV audio to satellite in chunks."""
|
||||
assert self._client is not None
|
||||
|
||||
extension, data = await tts.async_get_media_source_audio(self.hass, media_id)
|
||||
if extension != "wav":
|
||||
raise ValueError(f"Cannot stream audio format to satellite: {extension}")
|
||||
|
||||
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
||||
sample_rate = wav_file.getframerate()
|
||||
sample_width = wav_file.getsampwidth()
|
||||
sample_channels = wav_file.getnchannels()
|
||||
_LOGGER.debug("Streaming %s TTS sample(s)", wav_file.getnframes())
|
||||
|
||||
timestamp = 0
|
||||
await self._client.write_event(
|
||||
AudioStart(
|
||||
rate=sample_rate,
|
||||
width=sample_width,
|
||||
channels=sample_channels,
|
||||
timestamp=timestamp,
|
||||
).event()
|
||||
)
|
||||
|
||||
# Stream audio chunks
|
||||
while audio_bytes := wav_file.readframes(_SAMPLES_PER_CHUNK):
|
||||
chunk = AudioChunk(
|
||||
rate=sample_rate,
|
||||
width=sample_width,
|
||||
channels=sample_channels,
|
||||
audio=audio_bytes,
|
||||
timestamp=timestamp,
|
||||
)
|
||||
await self._client.write_event(chunk.event())
|
||||
timestamp += chunk.seconds
|
||||
|
||||
await self._client.write_event(AudioStop(timestamp=timestamp).event())
|
||||
_LOGGER.debug("TTS streaming complete")
|
||||
|
||||
async def _stt_stream(self) -> AsyncGenerator[bytes]:
|
||||
"""Yield audio chunks from a queue."""
|
||||
try:
|
||||
is_first_chunk = True
|
||||
while chunk := await self._audio_queue.get():
|
||||
if is_first_chunk:
|
||||
is_first_chunk = False
|
||||
_LOGGER.debug("Receiving audio from satellite")
|
||||
if self._client is None:
|
||||
return # stopping
|
||||
|
||||
yield chunk
|
||||
except asyncio.CancelledError:
|
||||
pass # ignore
|
||||
extension, data = await tts.async_get_media_source_audio(
|
||||
self.hass, media_id
|
||||
)
|
||||
if extension != "wav":
|
||||
raise ValueError(
|
||||
f"Cannot stream audio format to satellite: {extension}"
|
||||
)
|
||||
|
||||
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
||||
sample_rate = wav_file.getframerate()
|
||||
sample_width = wav_file.getsampwidth()
|
||||
sample_channels = wav_file.getnchannels()
|
||||
num_frames = wav_file.getnframes()
|
||||
_LOGGER.debug("Streaming %s TTS sample(s)", num_frames)
|
||||
|
||||
wav_seconds = num_frames / sample_rate
|
||||
self._played_timeout_id = time.monotonic_ns()
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self._tts_played_timeout(self._played_timeout_id, wav_seconds + 1),
|
||||
"wyoming tts timeout",
|
||||
)
|
||||
|
||||
timestamp = 0
|
||||
await self._client.write_event(
|
||||
AudioStart(
|
||||
rate=sample_rate,
|
||||
width=sample_width,
|
||||
channels=sample_channels,
|
||||
timestamp=timestamp,
|
||||
).event()
|
||||
)
|
||||
|
||||
# Stream audio chunks
|
||||
while audio_bytes := wav_file.readframes(_SAMPLES_PER_CHUNK):
|
||||
chunk = AudioChunk(
|
||||
rate=sample_rate,
|
||||
width=sample_width,
|
||||
channels=sample_channels,
|
||||
audio=audio_bytes,
|
||||
timestamp=timestamp,
|
||||
)
|
||||
await self._client.write_event(chunk.event())
|
||||
timestamp += chunk.seconds
|
||||
|
||||
await self._client.write_event(AudioStop(timestamp=timestamp).event())
|
||||
_LOGGER.debug("TTS streaming complete")
|
||||
finally:
|
||||
self.tts_response_finished()
|
||||
|
||||
async def _tts_played_timeout(self, timeout_id: int, timeout_sec: float) -> None:
|
||||
"""Set pipeline result after timeout if Played message is not received."""
|
||||
await asyncio.sleep(timeout_sec)
|
||||
if self._played_timeout_id != timeout_id:
|
||||
return
|
||||
|
||||
if result_futures := self._pipeline_result_futures[
|
||||
assist_pipeline.PipelineStage.TTS
|
||||
]:
|
||||
result_futures.popleft().set_result(None)
|
||||
|
||||
@callback
|
||||
def _handle_timer(
|
||||
self, event_type: intent.TimerEventType, timer: intent.TimerInfo
|
||||
) -> None:
|
||||
"""Forward timer events to satellite."""
|
||||
assert self._client is not None
|
||||
if self._client is None:
|
||||
return # stopping
|
||||
|
||||
_LOGGER.debug("Timer event: type=%s, info=%s", event_type, timer)
|
||||
event: Event | None = None
|
|
@ -13,7 +13,7 @@ from homeassistant.core import HomeAssistant, callback
|
|||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from .const import DOMAIN
|
||||
from .entity import WyomingSatelliteEntity
|
||||
from .entity import WyomingEntity
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .models import DomainDataItem
|
||||
|
@ -28,12 +28,12 @@ async def async_setup_entry(
|
|||
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
|
||||
|
||||
# Setup is only forwarded for satellites
|
||||
assert item.satellite is not None
|
||||
assert item.satellite_device is not None
|
||||
|
||||
async_add_entities([WyomingSatelliteAssistInProgress(item.satellite.device)])
|
||||
async_add_entities([WyomingSatelliteAssistInProgress(item.satellite_device)])
|
||||
|
||||
|
||||
class WyomingSatelliteAssistInProgress(WyomingSatelliteEntity, BinarySensorEntity):
|
||||
class WyomingSatelliteAssistInProgress(WyomingEntity, BinarySensorEntity):
|
||||
"""Entity to represent Assist is in progress for satellite."""
|
||||
|
||||
entity_description = BinarySensorEntityDescription(
|
||||
|
|
|
@ -157,3 +157,10 @@ class SatelliteDevice:
|
|||
return ent_reg.async_get_entity_id(
|
||||
"select", DOMAIN, f"{self.satellite_id}-vad_sensitivity"
|
||||
)
|
||||
|
||||
def get_satellite_entity_id(self, hass: HomeAssistant) -> str | None:
|
||||
"""Return entity id for satellite."""
|
||||
ent_reg = er.async_get(hass)
|
||||
return ent_reg.async_get_entity_id(
|
||||
"assist_satellite", DOMAIN, f"{self.satellite_id}-assist_satellite"
|
||||
)
|
||||
|
|
|
@ -6,10 +6,10 @@ from homeassistant.helpers import entity
|
|||
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
|
||||
|
||||
from .const import DOMAIN
|
||||
from .satellite import SatelliteDevice
|
||||
from .devices import SatelliteDevice
|
||||
|
||||
|
||||
class WyomingSatelliteEntity(entity.Entity):
|
||||
class WyomingEntity(entity.Entity):
|
||||
"""Wyoming satellite entity."""
|
||||
|
||||
_attr_has_entity_name = True
|
||||
|
|
|
@ -3,10 +3,15 @@
|
|||
"name": "Wyoming Protocol",
|
||||
"codeowners": ["@balloob", "@synesthesiam"],
|
||||
"config_flow": true,
|
||||
"dependencies": ["assist_pipeline", "intent", "conversation"],
|
||||
"dependencies": [
|
||||
"assist_satellite",
|
||||
"assist_pipeline",
|
||||
"intent",
|
||||
"conversation"
|
||||
],
|
||||
"documentation": "https://www.home-assistant.io/integrations/wyoming",
|
||||
"integration_type": "service",
|
||||
"iot_class": "local_push",
|
||||
"requirements": ["wyoming==1.5.4"],
|
||||
"requirements": ["wyoming==1.6.0"],
|
||||
"zeroconf": ["_wyoming._tcp.local."]
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
from dataclasses import dataclass
|
||||
|
||||
from .data import WyomingService
|
||||
from .satellite import WyomingSatellite
|
||||
from .devices import SatelliteDevice
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -11,4 +11,4 @@ class DomainDataItem:
|
|||
"""Domain data item."""
|
||||
|
||||
service: WyomingService
|
||||
satellite: WyomingSatellite | None = None
|
||||
satellite_device: SatelliteDevice | None = None
|
||||
|
|
|
@ -11,7 +11,7 @@ from homeassistant.core import HomeAssistant
|
|||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from .const import DOMAIN
|
||||
from .entity import WyomingSatelliteEntity
|
||||
from .entity import WyomingEntity
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .models import DomainDataItem
|
||||
|
@ -30,9 +30,9 @@ async def async_setup_entry(
|
|||
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
|
||||
|
||||
# Setup is only forwarded for satellites
|
||||
assert item.satellite is not None
|
||||
assert item.satellite_device is not None
|
||||
|
||||
device = item.satellite.device
|
||||
device = item.satellite_device
|
||||
async_add_entities(
|
||||
[
|
||||
WyomingSatelliteAutoGainNumber(device),
|
||||
|
@ -41,7 +41,7 @@ async def async_setup_entry(
|
|||
)
|
||||
|
||||
|
||||
class WyomingSatelliteAutoGainNumber(WyomingSatelliteEntity, RestoreNumber):
|
||||
class WyomingSatelliteAutoGainNumber(WyomingEntity, RestoreNumber):
|
||||
"""Entity to represent auto gain amount."""
|
||||
|
||||
entity_description = NumberEntityDescription(
|
||||
|
@ -70,7 +70,7 @@ class WyomingSatelliteAutoGainNumber(WyomingSatelliteEntity, RestoreNumber):
|
|||
self._device.set_auto_gain(auto_gain)
|
||||
|
||||
|
||||
class WyomingSatelliteVolumeMultiplierNumber(WyomingSatelliteEntity, RestoreNumber):
|
||||
class WyomingSatelliteVolumeMultiplierNumber(WyomingEntity, RestoreNumber):
|
||||
"""Entity to represent microphone volume multiplier."""
|
||||
|
||||
entity_description = NumberEntityDescription(
|
||||
|
|
|
@ -18,7 +18,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
|||
|
||||
from .const import DOMAIN
|
||||
from .devices import SatelliteDevice
|
||||
from .entity import WyomingSatelliteEntity
|
||||
from .entity import WyomingEntity
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .models import DomainDataItem
|
||||
|
@ -42,9 +42,9 @@ async def async_setup_entry(
|
|||
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
|
||||
|
||||
# Setup is only forwarded for satellites
|
||||
assert item.satellite is not None
|
||||
assert item.satellite_device is not None
|
||||
|
||||
device = item.satellite.device
|
||||
device = item.satellite_device
|
||||
async_add_entities(
|
||||
[
|
||||
WyomingSatellitePipelineSelect(hass, device),
|
||||
|
@ -54,14 +54,14 @@ async def async_setup_entry(
|
|||
)
|
||||
|
||||
|
||||
class WyomingSatellitePipelineSelect(WyomingSatelliteEntity, AssistPipelineSelect):
|
||||
class WyomingSatellitePipelineSelect(WyomingEntity, AssistPipelineSelect):
|
||||
"""Pipeline selector for Wyoming satellites."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant, device: SatelliteDevice) -> None:
|
||||
"""Initialize a pipeline selector."""
|
||||
self.device = device
|
||||
|
||||
WyomingSatelliteEntity.__init__(self, device)
|
||||
WyomingEntity.__init__(self, device)
|
||||
AssistPipelineSelect.__init__(self, hass, DOMAIN, device.satellite_id)
|
||||
|
||||
async def async_select_option(self, option: str) -> None:
|
||||
|
@ -71,7 +71,7 @@ class WyomingSatellitePipelineSelect(WyomingSatelliteEntity, AssistPipelineSelec
|
|||
|
||||
|
||||
class WyomingSatelliteNoiseSuppressionLevelSelect(
|
||||
WyomingSatelliteEntity, SelectEntity, restore_state.RestoreEntity
|
||||
WyomingEntity, SelectEntity, restore_state.RestoreEntity
|
||||
):
|
||||
"""Entity to represent noise suppression level setting."""
|
||||
|
||||
|
@ -99,16 +99,14 @@ class WyomingSatelliteNoiseSuppressionLevelSelect(
|
|||
self._device.set_noise_suppression_level(_NOISE_SUPPRESSION_LEVEL[option])
|
||||
|
||||
|
||||
class WyomingSatelliteVadSensitivitySelect(
|
||||
WyomingSatelliteEntity, VadSensitivitySelect
|
||||
):
|
||||
class WyomingSatelliteVadSensitivitySelect(WyomingEntity, VadSensitivitySelect):
|
||||
"""VAD sensitivity selector for Wyoming satellites."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant, device: SatelliteDevice) -> None:
|
||||
"""Initialize a VAD sensitivity selector."""
|
||||
self.device = device
|
||||
|
||||
WyomingSatelliteEntity.__init__(self, device)
|
||||
WyomingEntity.__init__(self, device)
|
||||
VadSensitivitySelect.__init__(self, hass, device.satellite_id)
|
||||
|
||||
async def async_select_option(self, option: str) -> None:
|
||||
|
|
|
@ -12,7 +12,7 @@ from homeassistant.helpers import restore_state
|
|||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from .const import DOMAIN
|
||||
from .entity import WyomingSatelliteEntity
|
||||
from .entity import WyomingEntity
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .models import DomainDataItem
|
||||
|
@ -27,13 +27,13 @@ async def async_setup_entry(
|
|||
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
|
||||
|
||||
# Setup is only forwarded for satellites
|
||||
assert item.satellite is not None
|
||||
assert item.satellite_device is not None
|
||||
|
||||
async_add_entities([WyomingSatelliteMuteSwitch(item.satellite.device)])
|
||||
async_add_entities([WyomingSatelliteMuteSwitch(item.satellite_device)])
|
||||
|
||||
|
||||
class WyomingSatelliteMuteSwitch(
|
||||
WyomingSatelliteEntity, restore_state.RestoreEntity, SwitchEntity
|
||||
WyomingEntity, restore_state.RestoreEntity, SwitchEntity
|
||||
):
|
||||
"""Entity to represent if satellite is muted."""
|
||||
|
||||
|
@ -51,7 +51,7 @@ class WyomingSatelliteMuteSwitch(
|
|||
|
||||
# Default to off
|
||||
self._attr_is_on = (state is not None) and (state.state == STATE_ON)
|
||||
self._device.is_muted = self._attr_is_on
|
||||
self._device.set_is_muted(self._attr_is_on)
|
||||
|
||||
async def async_turn_on(self, **kwargs: Any) -> None:
|
||||
"""Turn on."""
|
||||
|
|
|
@ -41,6 +41,7 @@ class Platform(StrEnum):
|
|||
|
||||
AIR_QUALITY = "air_quality"
|
||||
ALARM_CONTROL_PANEL = "alarm_control_panel"
|
||||
ASSIST_SATELLITE = "assist_satellite"
|
||||
BINARY_SENSOR = "binary_sensor"
|
||||
BUTTON = "button"
|
||||
CALENDAR = "calendar"
|
||||
|
|
|
@ -5,22 +5,28 @@ from __future__ import annotations
|
|||
from asyncio import (
|
||||
AbstractEventLoop,
|
||||
Future,
|
||||
Queue,
|
||||
Semaphore,
|
||||
Task,
|
||||
TimerHandle,
|
||||
gather,
|
||||
get_running_loop,
|
||||
timeout as async_timeout,
|
||||
)
|
||||
from collections.abc import Awaitable, Callable, Coroutine
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, Coroutine
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_SHUTDOWN_RUN_CALLBACK_THREADSAFE = "_shutdown_run_callback_threadsafe"
|
||||
|
||||
_DataT = TypeVar("_DataT", default=Any)
|
||||
|
||||
|
||||
def create_eager_task[_T](
|
||||
coro: Coroutine[Any, Any, _T],
|
||||
|
@ -138,3 +144,20 @@ def get_scheduled_timer_handles(loop: AbstractEventLoop) -> list[TimerHandle]:
|
|||
"""Return a list of scheduled TimerHandles."""
|
||||
handles: list[TimerHandle] = loop._scheduled # type: ignore[attr-defined] # noqa: SLF001
|
||||
return handles
|
||||
|
||||
|
||||
async def queue_to_iterable(
|
||||
queue: Queue[_DataT], timeout: float | None = None
|
||||
) -> AsyncIterable[_DataT]:
|
||||
"""Stream items from a queue until None with an optional timeout per item."""
|
||||
if timeout is None:
|
||||
while (item := await queue.get()) is not None:
|
||||
yield item
|
||||
else:
|
||||
async with async_timeout(timeout):
|
||||
item = await queue.get()
|
||||
|
||||
while item is not None:
|
||||
yield item
|
||||
async with async_timeout(timeout):
|
||||
item = await queue.get()
|
||||
|
|
|
@ -2936,7 +2936,7 @@ wled==0.20.2
|
|||
wolf-comm==0.0.9
|
||||
|
||||
# homeassistant.components.wyoming
|
||||
wyoming==1.5.4
|
||||
wyoming==1.6.0
|
||||
|
||||
# homeassistant.components.xbox
|
||||
xbox-webapi==2.0.11
|
||||
|
|
|
@ -2319,7 +2319,7 @@ wled==0.20.2
|
|||
wolf-comm==0.0.9
|
||||
|
||||
# homeassistant.components.wyoming
|
||||
wyoming==1.5.4
|
||||
wyoming==1.6.0
|
||||
|
||||
# homeassistant.components.xbox
|
||||
xbox-webapi==2.0.11
|
||||
|
|
11
tests/components/assist_satellite/__init__.py
Normal file
11
tests/components/assist_satellite/__init__.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
"""Tests for the Assist satellite integration."""
|
||||
|
||||
from homeassistant.components import assist_satellite
|
||||
|
||||
|
||||
class MockSatelliteEntity(assist_satellite.AssistSatelliteEntity):
|
||||
"""Mock satellite that supports pipeline triggering."""
|
||||
|
||||
_attr_supported_features = (
|
||||
assist_satellite.AssistSatelliteEntityFeature.TRIGGER_PIPELINE
|
||||
)
|
104
tests/components/assist_satellite/conftest.py
Normal file
104
tests/components/assist_satellite/conftest.py
Normal file
|
@ -0,0 +1,104 @@
|
|||
"""Common fixtures for the Assist satellite tests."""
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import assist_satellite
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import MockSatelliteEntity
|
||||
|
||||
from tests.common import (
|
||||
MockConfigEntry,
|
||||
MockModule,
|
||||
MockPlatform,
|
||||
mock_config_flow,
|
||||
mock_integration,
|
||||
mock_platform,
|
||||
)
|
||||
|
||||
TEST_DOMAIN = "test"
|
||||
|
||||
|
||||
async def mock_config_entry_setup(
|
||||
hass: HomeAssistant, satellite_entity: MockSatelliteEntity
|
||||
) -> MockConfigEntry:
|
||||
"""Set up a test satellite platform via config entry."""
|
||||
|
||||
async def async_setup_entry_init(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Set up test config entry."""
|
||||
await hass.config_entries.async_forward_entry_setups(
|
||||
config_entry, [assist_satellite.DOMAIN]
|
||||
)
|
||||
return True
|
||||
|
||||
async def async_unload_entry_init(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Unload test config entry."""
|
||||
await hass.config_entries.async_forward_entry_unload(
|
||||
config_entry, assist_satellite.DOMAIN
|
||||
)
|
||||
return True
|
||||
|
||||
mock_integration(
|
||||
hass,
|
||||
MockModule(
|
||||
TEST_DOMAIN,
|
||||
async_setup_entry=async_setup_entry_init,
|
||||
async_unload_entry=async_unload_entry_init,
|
||||
),
|
||||
)
|
||||
|
||||
async def async_setup_entry_platform(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up test tts platform via config entry."""
|
||||
async_add_entities([satellite_entity])
|
||||
|
||||
loaded_platform = MockPlatform(async_setup_entry=async_setup_entry_platform)
|
||||
mock_platform(hass, f"{TEST_DOMAIN}.{assist_satellite.DOMAIN}", loaded_platform)
|
||||
|
||||
config_entry = MockConfigEntry(domain=TEST_DOMAIN)
|
||||
config_entry.add_to_hass(hass)
|
||||
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
return config_entry
|
||||
|
||||
|
||||
class AssistSatelliteFlow(ConfigFlow):
|
||||
"""Test flow."""
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def config_flow_fixture(hass: HomeAssistant) -> Generator[None]:
|
||||
"""Mock config flow."""
|
||||
mock_platform(hass, f"{TEST_DOMAIN}.config_flow")
|
||||
|
||||
with mock_config_flow(TEST_DOMAIN, AssistSatelliteFlow):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_mock_satellite_entity() -> MockSatelliteEntity:
|
||||
"""Test satellite entity."""
|
||||
return MockSatelliteEntity()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_satellite(
|
||||
hass: HomeAssistant, setup_mock_satellite_entity: MockSatelliteEntity
|
||||
) -> MockSatelliteEntity:
|
||||
"""Create a config entry."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
await mock_config_entry_setup(hass, setup_mock_satellite_entity)
|
||||
return setup_mock_satellite_entity
|
204
tests/components/assist_satellite/test_init.py
Normal file
204
tests/components/assist_satellite/test_init.py
Normal file
|
@ -0,0 +1,204 @@
|
|||
"""Tests for Assist satellite."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant.components import assist_pipeline, assist_satellite
|
||||
from homeassistant.const import ATTR_ENTITY_ID
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from . import MockSatelliteEntity
|
||||
|
||||
|
||||
async def test_wait_wake(
|
||||
hass: HomeAssistant, mock_satellite: MockSatelliteEntity
|
||||
) -> None:
|
||||
"""Test wait_wake service."""
|
||||
test_wake_word = "test-wake-word"
|
||||
|
||||
with patch.object(
|
||||
mock_satellite,
|
||||
"async_trigger_pipeline_on_satellite",
|
||||
return_value=assist_satellite.PipelineRunResult(
|
||||
detected_wake_word=test_wake_word
|
||||
),
|
||||
) as mock_async_trigger_pipeline_on_satellite:
|
||||
result = await hass.services.async_call(
|
||||
assist_satellite.DOMAIN,
|
||||
assist_satellite.SERVICE_WAIT_WAKE,
|
||||
{
|
||||
ATTR_ENTITY_ID: mock_satellite.entity_id,
|
||||
assist_satellite.ATTR_WAKE_WORDS: [test_wake_word],
|
||||
},
|
||||
return_response=True,
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
mock_async_trigger_pipeline_on_satellite.assert_called_once_with(
|
||||
assist_pipeline.PipelineStage.WAKE_WORD,
|
||||
assist_pipeline.PipelineStage.WAKE_WORD,
|
||||
assist_satellite.PipelineRunConfig(wake_word_names=[test_wake_word]),
|
||||
)
|
||||
assert result == {mock_satellite.entity_id: test_wake_word}
|
||||
|
||||
|
||||
async def test_announce_wait_wake(
|
||||
hass: HomeAssistant, mock_satellite: MockSatelliteEntity
|
||||
) -> None:
|
||||
"""Test wait_wake service with announcement."""
|
||||
test_wake_word = "test-wake-word"
|
||||
announce_text = "test-announce-text"
|
||||
|
||||
with patch.object(
|
||||
mock_satellite,
|
||||
"async_trigger_pipeline_on_satellite",
|
||||
return_value=assist_satellite.PipelineRunResult(
|
||||
detected_wake_word=test_wake_word
|
||||
),
|
||||
) as mock_async_trigger_pipeline_on_satellite:
|
||||
result = await hass.services.async_call(
|
||||
assist_satellite.DOMAIN,
|
||||
assist_satellite.SERVICE_WAIT_WAKE,
|
||||
{
|
||||
ATTR_ENTITY_ID: mock_satellite.entity_id,
|
||||
assist_satellite.ATTR_ANNOUNCE_TEXT: announce_text,
|
||||
assist_satellite.ATTR_WAKE_WORDS: [test_wake_word],
|
||||
},
|
||||
return_response=True,
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
assert mock_async_trigger_pipeline_on_satellite.call_count == 2
|
||||
assert mock_async_trigger_pipeline_on_satellite.call_args_list[0].args == (
|
||||
assist_pipeline.PipelineStage.TTS,
|
||||
assist_pipeline.PipelineStage.TTS,
|
||||
assist_satellite.PipelineRunConfig(announce_text=announce_text),
|
||||
)
|
||||
assert mock_async_trigger_pipeline_on_satellite.call_args_list[1].args == (
|
||||
assist_pipeline.PipelineStage.WAKE_WORD,
|
||||
assist_pipeline.PipelineStage.WAKE_WORD,
|
||||
assist_satellite.PipelineRunConfig(wake_word_names=[test_wake_word]),
|
||||
)
|
||||
assert result == {mock_satellite.entity_id: test_wake_word}
|
||||
|
||||
|
||||
async def test_get_command(
|
||||
hass: HomeAssistant, mock_satellite: MockSatelliteEntity
|
||||
) -> None:
|
||||
"""Test get_command service."""
|
||||
test_command = "test-command"
|
||||
|
||||
with patch.object(
|
||||
mock_satellite,
|
||||
"async_trigger_pipeline_on_satellite",
|
||||
return_value=assist_satellite.PipelineRunResult(command_text=test_command),
|
||||
) as mock_async_trigger_pipeline_on_satellite:
|
||||
result = await hass.services.async_call(
|
||||
assist_satellite.DOMAIN,
|
||||
assist_satellite.SERVICE_GET_COMMAND,
|
||||
{ATTR_ENTITY_ID: mock_satellite.entity_id},
|
||||
return_response=True,
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
mock_async_trigger_pipeline_on_satellite.assert_called_once_with(
|
||||
assist_pipeline.PipelineStage.STT,
|
||||
assist_pipeline.PipelineStage.STT,
|
||||
assist_satellite.PipelineRunConfig(),
|
||||
)
|
||||
assert result == {mock_satellite.entity_id: test_command}
|
||||
|
||||
|
||||
async def test_announce_get_command(
|
||||
hass: HomeAssistant, mock_satellite: MockSatelliteEntity
|
||||
) -> None:
|
||||
"""Test get_command service with announcement."""
|
||||
test_command = "test-command"
|
||||
announce_text = "test-announce-text"
|
||||
|
||||
with patch.object(
|
||||
mock_satellite,
|
||||
"async_trigger_pipeline_on_satellite",
|
||||
return_value=assist_satellite.PipelineRunResult(command_text=test_command),
|
||||
) as mock_async_trigger_pipeline_on_satellite:
|
||||
result = await hass.services.async_call(
|
||||
assist_satellite.DOMAIN,
|
||||
assist_satellite.SERVICE_GET_COMMAND,
|
||||
{
|
||||
ATTR_ENTITY_ID: mock_satellite.entity_id,
|
||||
assist_satellite.ATTR_ANNOUNCE_TEXT: announce_text,
|
||||
},
|
||||
return_response=True,
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
assert mock_async_trigger_pipeline_on_satellite.call_count == 2
|
||||
assert mock_async_trigger_pipeline_on_satellite.call_args_list[0].args == (
|
||||
assist_pipeline.PipelineStage.TTS,
|
||||
assist_pipeline.PipelineStage.TTS,
|
||||
assist_satellite.PipelineRunConfig(announce_text=announce_text),
|
||||
)
|
||||
assert mock_async_trigger_pipeline_on_satellite.call_args_list[1].args == (
|
||||
assist_pipeline.PipelineStage.STT,
|
||||
assist_pipeline.PipelineStage.STT,
|
||||
assist_satellite.PipelineRunConfig(),
|
||||
)
|
||||
assert result == {mock_satellite.entity_id: test_command}
|
||||
|
||||
|
||||
async def test_get_command_process(
|
||||
hass: HomeAssistant, mock_satellite: MockSatelliteEntity
|
||||
) -> None:
|
||||
"""Test get_command service with processing enabled."""
|
||||
test_command = "test-command"
|
||||
|
||||
with patch.object(
|
||||
mock_satellite,
|
||||
"async_trigger_pipeline_on_satellite",
|
||||
return_value=assist_satellite.PipelineRunResult(command_text=test_command),
|
||||
) as mock_async_trigger_pipeline_on_satellite:
|
||||
result = await hass.services.async_call(
|
||||
assist_satellite.DOMAIN,
|
||||
assist_satellite.SERVICE_GET_COMMAND,
|
||||
{
|
||||
ATTR_ENTITY_ID: mock_satellite.entity_id,
|
||||
assist_satellite.ATTR_PROCESS: True,
|
||||
},
|
||||
return_response=True,
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
# Pipeline should run to TTS stage now
|
||||
mock_async_trigger_pipeline_on_satellite.assert_called_once_with(
|
||||
assist_pipeline.PipelineStage.STT,
|
||||
assist_pipeline.PipelineStage.TTS,
|
||||
assist_satellite.PipelineRunConfig(),
|
||||
)
|
||||
assert result == {mock_satellite.entity_id: test_command}
|
||||
|
||||
|
||||
async def test_say_text(
|
||||
hass: HomeAssistant, mock_satellite: MockSatelliteEntity
|
||||
) -> None:
|
||||
"""Test say_text service."""
|
||||
announce_text = "test-announce-text"
|
||||
|
||||
with patch.object(
|
||||
mock_satellite, "async_trigger_pipeline_on_satellite", return_value=None
|
||||
) as mock_async_trigger_pipeline_on_satellite:
|
||||
result = await hass.services.async_call(
|
||||
assist_satellite.DOMAIN,
|
||||
assist_satellite.SERVICE_SAY_TEXT,
|
||||
{
|
||||
ATTR_ENTITY_ID: mock_satellite.entity_id,
|
||||
assist_satellite.ATTR_ANNOUNCE_TEXT: announce_text,
|
||||
},
|
||||
return_response=False,
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
mock_async_trigger_pipeline_on_satellite.assert_called_once_with(
|
||||
assist_pipeline.PipelineStage.TTS,
|
||||
assist_pipeline.PipelineStage.TTS,
|
||||
assist_satellite.PipelineRunConfig(announce_text=announce_text),
|
||||
)
|
||||
assert result is None
|
10
tests/components/voip/snapshots/test_voip.ambr
Normal file
10
tests/components/voip/snapshots/test_voip.ambr
Normal file
File diff suppressed because one or more lines are too long
|
@ -3,15 +3,26 @@
|
|||
import asyncio
|
||||
import io
|
||||
from pathlib import Path
|
||||
import time
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
import wave
|
||||
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
from voip_utils import CallInfo
|
||||
|
||||
from homeassistant.components import assist_pipeline, voip
|
||||
from homeassistant.components.voip.devices import VoIPDevice
|
||||
from homeassistant.components import assist_pipeline, assist_satellite, voip
|
||||
from homeassistant.components.assist_satellite import (
|
||||
AssistSatelliteEntity,
|
||||
AssistSatelliteState,
|
||||
)
|
||||
from homeassistant.components.voip import HassVoipDatagramProtocol
|
||||
from homeassistant.components.voip.assist_satellite import Tones, VoipAssistSatellite
|
||||
from homeassistant.components.voip.devices import VoIPDevice, VoIPDevices
|
||||
from homeassistant.components.voip.voip import PreRecordMessageProtocol, make_protocol
|
||||
from homeassistant.const import STATE_OFF, STATE_ON, Platform
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
_ONE_SECOND = 16000 * 2 # 16Khz 16-bit
|
||||
|
@ -35,33 +46,180 @@ def _empty_wav() -> bytes:
|
|||
return wav_io.getvalue()
|
||||
|
||||
|
||||
def async_get_satellite_entity(
|
||||
hass: HomeAssistant, domain: str, unique_id_prefix: str
|
||||
) -> AssistSatelliteEntity | None:
|
||||
"""Get Assist satellite entity."""
|
||||
ent_reg = er.async_get(hass)
|
||||
satellite_entity_id = ent_reg.async_get_entity_id(
|
||||
Platform.ASSIST_SATELLITE, domain, f"{unique_id_prefix}-assist_satellite"
|
||||
)
|
||||
if satellite_entity_id is None:
|
||||
return None
|
||||
|
||||
component: EntityComponent[AssistSatelliteEntity] = hass.data[
|
||||
assist_satellite.DOMAIN
|
||||
]
|
||||
return component.get_entity(satellite_entity_id)
|
||||
|
||||
|
||||
async def test_is_valid_call(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
call_info: CallInfo,
|
||||
) -> None:
|
||||
"""Test that a call is now allowed from an unknown device."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
protocol = HassVoipDatagramProtocol(hass, voip_devices)
|
||||
assert not protocol.is_valid_call(call_info)
|
||||
|
||||
ent_reg = er.async_get(hass)
|
||||
allowed_call_entity_id = ent_reg.async_get_entity_id(
|
||||
"switch", voip.DOMAIN, f"{voip_device.voip_id}-allow_call"
|
||||
)
|
||||
assert allowed_call_entity_id is not None
|
||||
state = hass.states.get(allowed_call_entity_id)
|
||||
assert state is not None
|
||||
assert state.state == STATE_OFF
|
||||
|
||||
# Allow calls
|
||||
hass.states.async_set(allowed_call_entity_id, STATE_ON)
|
||||
assert protocol.is_valid_call(call_info)
|
||||
|
||||
|
||||
async def test_calls_not_allowed(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
call_info: CallInfo,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that a pre-recorded message is played when calls aren't allowed."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
protocol: PreRecordMessageProtocol = make_protocol(hass, voip_devices, call_info)
|
||||
assert isinstance(protocol, PreRecordMessageProtocol)
|
||||
assert protocol.file_name == "problem.pcm"
|
||||
|
||||
# Test the playback
|
||||
done = asyncio.Event()
|
||||
played_audio_bytes = b""
|
||||
|
||||
def send_audio(audio_bytes: bytes, **kwargs):
|
||||
nonlocal played_audio_bytes
|
||||
|
||||
# Should be problem.pcm from components/voip
|
||||
played_audio_bytes = audio_bytes
|
||||
done.set()
|
||||
|
||||
protocol.transport = Mock()
|
||||
protocol.loop_delay = 0
|
||||
with patch.object(protocol, "send_audio", send_audio):
|
||||
protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await done.wait()
|
||||
|
||||
assert sum(played_audio_bytes) > 0
|
||||
assert played_audio_bytes == snapshot()
|
||||
|
||||
|
||||
async def test_pipeline_not_found(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
call_info: CallInfo,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that a pre-recorded message is played when a pipeline isn't found."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.voip.voip.async_get_pipeline", return_value=None
|
||||
):
|
||||
protocol: PreRecordMessageProtocol = make_protocol(
|
||||
hass, voip_devices, call_info
|
||||
)
|
||||
|
||||
assert isinstance(protocol, PreRecordMessageProtocol)
|
||||
assert protocol.file_name == "problem.pcm"
|
||||
|
||||
|
||||
async def test_satellite_prepared(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
call_info: CallInfo,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that satellite is prepared for a call."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
pipeline = assist_pipeline.Pipeline(
|
||||
conversation_engine="test",
|
||||
conversation_language="en",
|
||||
language="en",
|
||||
name="test",
|
||||
stt_engine="test",
|
||||
stt_language="en",
|
||||
tts_engine="test",
|
||||
tts_language="en",
|
||||
tts_voice=None,
|
||||
wake_word_entity=None,
|
||||
wake_word_id=None,
|
||||
)
|
||||
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_get_pipeline",
|
||||
return_value=pipeline,
|
||||
),
|
||||
):
|
||||
protocol = make_protocol(hass, voip_devices, call_info)
|
||||
assert protocol == satellite
|
||||
|
||||
|
||||
async def test_pipeline(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
call_info: CallInfo,
|
||||
) -> None:
|
||||
"""Test that pipeline function is called from RTP protocol."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
def process_10ms(self, chunk):
|
||||
"""Anything non-zero is speech."""
|
||||
if sum(chunk) > 0:
|
||||
return 1
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
voip_user_id = satellite.config_entry.data["user"]
|
||||
assert voip_user_id
|
||||
|
||||
return 0
|
||||
# Satellite is muted until a call begins
|
||||
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
# Used to test that audio queue is cleared before pipeline starts
|
||||
bad_chunk = bytes([1, 2, 3, 4])
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
|
||||
async def async_pipeline_from_audio_stream(
|
||||
hass: HomeAssistant, context: Context, *args, device_id: str | None, **kwargs
|
||||
):
|
||||
assert context.user_id == voip_user_id
|
||||
assert device_id == voip_device.device_id
|
||||
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
event_callback = kwargs["event_callback"]
|
||||
async for _chunk in stt_stream:
|
||||
in_command = False
|
||||
async for chunk in stt_stream:
|
||||
# Stream will end when VAD detects end of "speech"
|
||||
assert _chunk != bad_chunk
|
||||
assert chunk != bad_chunk
|
||||
if sum(chunk) > 0:
|
||||
in_command = True
|
||||
elif in_command:
|
||||
break # done with command
|
||||
|
||||
# Test empty data
|
||||
event_callback(
|
||||
|
@ -71,6 +229,38 @@ async def test_pipeline(
|
|||
)
|
||||
)
|
||||
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_START,
|
||||
data={"engine": "test", "metadata": {}},
|
||||
)
|
||||
)
|
||||
|
||||
assert satellite.state == AssistSatelliteState.LISTENING_COMMAND
|
||||
|
||||
# Fake STT result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "fake-text"}},
|
||||
)
|
||||
)
|
||||
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.INTENT_START,
|
||||
data={
|
||||
"engine": "test",
|
||||
"language": hass.config.language,
|
||||
"intent_input": "fake-text",
|
||||
"conversation_id": None,
|
||||
"device_id": None,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert satellite.state == AssistSatelliteState.PROCESSING
|
||||
|
||||
# Fake intent result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
|
@ -83,6 +273,21 @@ async def test_pipeline(
|
|||
)
|
||||
)
|
||||
|
||||
# Fake tts result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.TTS_START,
|
||||
data={
|
||||
"engine": "test",
|
||||
"language": hass.config.language,
|
||||
"voice": "test",
|
||||
"tts_input": "fake-text",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert satellite.state == AssistSatelliteState.RESPONDING
|
||||
|
||||
# Proceed with media output
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
|
@ -91,6 +296,18 @@ async def test_pipeline(
|
|||
)
|
||||
)
|
||||
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.RUN_END
|
||||
)
|
||||
)
|
||||
|
||||
original_tts_response_finished = satellite.tts_response_finished
|
||||
|
||||
def tts_response_finished():
|
||||
original_tts_response_finished()
|
||||
done.set()
|
||||
|
||||
async def async_get_media_source_audio(
|
||||
hass: HomeAssistant,
|
||||
media_source_id: str,
|
||||
|
@ -100,102 +317,56 @@ async def test_pipeline(
|
|||
|
||||
with (
|
||||
patch(
|
||||
"pymicro_vad.MicroVad.Process10ms",
|
||||
new=process_10ms,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
|
||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
listening_tone_enabled=False,
|
||||
processing_tone_enabled=False,
|
||||
error_tone_enabled=False,
|
||||
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("aggressive"),
|
||||
)
|
||||
rtp_protocol.transport = Mock()
|
||||
satellite._tones = Tones(0)
|
||||
satellite.transport = Mock()
|
||||
|
||||
satellite.connection_made(satellite.transport)
|
||||
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
# Ensure audio queue is cleared before pipeline starts
|
||||
rtp_protocol._audio_queue.put_nowait(bad_chunk)
|
||||
satellite._audio_queue.put_nowait(bad_chunk)
|
||||
|
||||
def send_audio(*args, **kwargs):
|
||||
# Test finished successfully
|
||||
done.set()
|
||||
# Don't send audio
|
||||
pass
|
||||
|
||||
rtp_protocol.send_audio = Mock(side_effect=send_audio)
|
||||
satellite.send_audio = Mock(side_effect=send_audio)
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# "speech"
|
||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
|
||||
# silence (assumes aggressive VAD sensitivity)
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
# silence
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# Wait for mock pipeline to exhaust the audio stream
|
||||
async with asyncio.timeout(1):
|
||||
await done.wait()
|
||||
|
||||
|
||||
async def test_pipeline_timeout(hass: HomeAssistant, voip_device: VoIPDevice) -> None:
|
||||
"""Test timeout during pipeline run."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
await asyncio.sleep(10)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.PipelineRtpDatagramProtocol._wait_for_speech",
|
||||
return_value=True,
|
||||
),
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
pipeline_timeout=0.001,
|
||||
listening_tone_enabled=False,
|
||||
processing_tone_enabled=False,
|
||||
error_tone_enabled=False,
|
||||
)
|
||||
transport = Mock(spec=["close"])
|
||||
rtp_protocol.connection_made(transport)
|
||||
|
||||
# Closing the transport will cause the test to succeed
|
||||
transport.close.side_effect = done.set
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# Wait for mock pipeline to time out
|
||||
async with asyncio.timeout(1):
|
||||
await done.wait()
|
||||
# Finished speaking
|
||||
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
|
||||
async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice) -> None:
|
||||
async def test_stt_stream_timeout(
|
||||
hass: HomeAssistant, voip_devices: VoIPDevices, voip_device: VoIPDevice
|
||||
) -> None:
|
||||
"""Test timeout in STT stream during pipeline run."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
|
@ -205,28 +376,19 @@ async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice)
|
|||
pass
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
audio_timeout=0.001,
|
||||
listening_tone_enabled=False,
|
||||
processing_tone_enabled=False,
|
||||
error_tone_enabled=False,
|
||||
)
|
||||
satellite._tones = Tones(0)
|
||||
satellite._audio_chunk_timeout = 0.001
|
||||
transport = Mock(spec=["close"])
|
||||
rtp_protocol.connection_made(transport)
|
||||
satellite.connection_made(transport)
|
||||
|
||||
# Closing the transport will cause the test to succeed
|
||||
transport.close.side_effect = done.set
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# Wait for mock pipeline to time out
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -235,26 +397,34 @@ async def test_stt_stream_timeout(hass: HomeAssistant, voip_device: VoIPDevice)
|
|||
|
||||
async def test_tts_timeout(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
) -> None:
|
||||
"""Test that TTS will time out based on its length."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
def process_10ms(self, chunk):
|
||||
"""Anything non-zero is speech."""
|
||||
if sum(chunk) > 0:
|
||||
return 1
|
||||
|
||||
return 0
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
event_callback = kwargs["event_callback"]
|
||||
async for _chunk in stt_stream:
|
||||
# Stream will end when VAD detects end of "speech"
|
||||
pass
|
||||
in_command = False
|
||||
async for chunk in stt_stream:
|
||||
if sum(chunk) > 0:
|
||||
in_command = True
|
||||
elif in_command:
|
||||
break # done with command
|
||||
|
||||
# Fake STT result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "fake-text"}},
|
||||
)
|
||||
)
|
||||
|
||||
# Fake intent result
|
||||
event_callback(
|
||||
|
@ -278,15 +448,7 @@ async def test_tts_timeout(
|
|||
|
||||
tone_bytes = bytes([1, 2, 3, 4])
|
||||
|
||||
def send_audio(audio_bytes, **kwargs):
|
||||
if audio_bytes == tone_bytes:
|
||||
# Not TTS
|
||||
return
|
||||
|
||||
# Block here to force a timeout in _send_tts
|
||||
time.sleep(2)
|
||||
|
||||
async def async_send_audio(audio_bytes, **kwargs):
|
||||
async def async_send_audio(audio_bytes: bytes, **kwargs):
|
||||
if audio_bytes == tone_bytes:
|
||||
# Not TTS
|
||||
return
|
||||
|
@ -303,37 +465,22 @@ async def test_tts_timeout(
|
|||
|
||||
with (
|
||||
patch(
|
||||
"pymicro_vad.MicroVad.Process10ms",
|
||||
new=process_10ms,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
|
||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
tts_extra_timeout=0.001,
|
||||
listening_tone_enabled=True,
|
||||
processing_tone_enabled=True,
|
||||
error_tone_enabled=True,
|
||||
silence_seconds=assist_pipeline.vad.VadSensitivity.to_seconds("relaxed"),
|
||||
)
|
||||
rtp_protocol._tone_bytes = tone_bytes
|
||||
rtp_protocol._processing_bytes = tone_bytes
|
||||
rtp_protocol._error_bytes = tone_bytes
|
||||
rtp_protocol.transport = Mock()
|
||||
rtp_protocol.send_audio = Mock()
|
||||
satellite._tts_extra_timeout = 0.001
|
||||
for tone in Tones:
|
||||
satellite._tone_bytes[tone] = tone_bytes
|
||||
|
||||
original_send_tts = rtp_protocol._send_tts
|
||||
satellite.transport = Mock()
|
||||
satellite.send_audio = Mock()
|
||||
|
||||
original_send_tts = satellite._send_tts
|
||||
|
||||
async def send_tts(*args, **kwargs):
|
||||
# Call original then end test successfully
|
||||
|
@ -342,17 +489,17 @@ async def test_tts_timeout(
|
|||
|
||||
done.set()
|
||||
|
||||
rtp_protocol._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
|
||||
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
satellite._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
|
||||
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# "speech"
|
||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
|
||||
# silence (assumes relaxed VAD sensitivity)
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
# silence
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# Wait for mock pipeline to exhaust the audio stream
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -361,26 +508,34 @@ async def test_tts_timeout(
|
|||
|
||||
async def test_tts_wrong_extension(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
) -> None:
|
||||
"""Test that TTS will only stream WAV audio."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
def process_10ms(self, chunk):
|
||||
"""Anything non-zero is speech."""
|
||||
if sum(chunk) > 0:
|
||||
return 1
|
||||
|
||||
return 0
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
event_callback = kwargs["event_callback"]
|
||||
async for _chunk in stt_stream:
|
||||
# Stream will end when VAD detects end of "speech"
|
||||
pass
|
||||
in_command = False
|
||||
async for chunk in stt_stream:
|
||||
if sum(chunk) > 0:
|
||||
in_command = True
|
||||
elif in_command:
|
||||
break # done with command
|
||||
|
||||
# Fake STT result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "fake-text"}},
|
||||
)
|
||||
)
|
||||
|
||||
# Fake intent result
|
||||
event_callback(
|
||||
|
@ -411,28 +566,17 @@ async def test_tts_wrong_extension(
|
|||
|
||||
with (
|
||||
patch(
|
||||
"pymicro_vad.MicroVad.Process10ms",
|
||||
new=process_10ms,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
|
||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
)
|
||||
rtp_protocol.transport = Mock()
|
||||
satellite.transport = Mock()
|
||||
|
||||
original_send_tts = rtp_protocol._send_tts
|
||||
original_send_tts = satellite._send_tts
|
||||
|
||||
async def send_tts(*args, **kwargs):
|
||||
# Call original then end test successfully
|
||||
|
@ -441,16 +585,16 @@ async def test_tts_wrong_extension(
|
|||
|
||||
done.set()
|
||||
|
||||
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# "speech"
|
||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
|
||||
# silence (assumes relaxed VAD sensitivity)
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
|
||||
# Wait for mock pipeline to exhaust the audio stream
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -459,26 +603,34 @@ async def test_tts_wrong_extension(
|
|||
|
||||
async def test_tts_wrong_wav_format(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
) -> None:
|
||||
"""Test that TTS will only stream WAV audio with a specific format."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
def process_10ms(self, chunk):
|
||||
"""Anything non-zero is speech."""
|
||||
if sum(chunk) > 0:
|
||||
return 1
|
||||
|
||||
return 0
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
event_callback = kwargs["event_callback"]
|
||||
async for _chunk in stt_stream:
|
||||
# Stream will end when VAD detects end of "speech"
|
||||
pass
|
||||
in_command = False
|
||||
async for chunk in stt_stream:
|
||||
if sum(chunk) > 0:
|
||||
in_command = True
|
||||
elif in_command:
|
||||
break # done with command
|
||||
|
||||
# Fake STT result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "fake-text"}},
|
||||
)
|
||||
)
|
||||
|
||||
# Fake intent result
|
||||
event_callback(
|
||||
|
@ -516,28 +668,17 @@ async def test_tts_wrong_wav_format(
|
|||
|
||||
with (
|
||||
patch(
|
||||
"pymicro_vad.MicroVad.Process10ms",
|
||||
new=process_10ms,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
|
||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
)
|
||||
rtp_protocol.transport = Mock()
|
||||
satellite.transport = Mock()
|
||||
|
||||
original_send_tts = rtp_protocol._send_tts
|
||||
original_send_tts = satellite._send_tts
|
||||
|
||||
async def send_tts(*args, **kwargs):
|
||||
# Call original then end test successfully
|
||||
|
@ -546,16 +687,16 @@ async def test_tts_wrong_wav_format(
|
|||
|
||||
done.set()
|
||||
|
||||
rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# "speech"
|
||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
|
||||
# silence (assumes relaxed VAD sensitivity)
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
|
||||
# Wait for mock pipeline to exhaust the audio stream
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -564,24 +705,32 @@ async def test_tts_wrong_wav_format(
|
|||
|
||||
async def test_empty_tts_output(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
) -> None:
|
||||
"""Test that TTS will not stream when output is empty."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
def process_10ms(self, chunk):
|
||||
"""Anything non-zero is speech."""
|
||||
if sum(chunk) > 0:
|
||||
return 1
|
||||
|
||||
return 0
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
event_callback = kwargs["event_callback"]
|
||||
async for _chunk in stt_stream:
|
||||
# Stream will end when VAD detects end of "speech"
|
||||
pass
|
||||
in_command = False
|
||||
async for chunk in stt_stream:
|
||||
if sum(chunk) > 0:
|
||||
in_command = True
|
||||
elif in_command:
|
||||
break # done with command
|
||||
|
||||
# Fake STT result
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": "fake-text"}},
|
||||
)
|
||||
)
|
||||
|
||||
# Fake intent result
|
||||
event_callback(
|
||||
|
@ -605,37 +754,78 @@ async def test_empty_tts_output(
|
|||
|
||||
with (
|
||||
patch(
|
||||
"pymicro_vad.MicroVad.Process10ms",
|
||||
new=process_10ms,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.voip.voip.PipelineRtpDatagramProtocol._send_tts",
|
||||
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
|
||||
) as mock_send_tts,
|
||||
):
|
||||
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||
hass,
|
||||
hass.config.language,
|
||||
voip_device,
|
||||
Context(),
|
||||
opus_payload_type=123,
|
||||
)
|
||||
rtp_protocol.transport = Mock()
|
||||
satellite.transport = Mock()
|
||||
|
||||
# silence
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# "speech"
|
||||
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||
|
||||
# silence (assumes relaxed VAD sensitivity)
|
||||
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
satellite.on_chunk(bytes(_ONE_SECOND * 4))
|
||||
|
||||
# Wait for mock pipeline to finish
|
||||
async with asyncio.timeout(1):
|
||||
await rtp_protocol._tts_done.wait()
|
||||
await satellite._tts_done.wait()
|
||||
|
||||
mock_send_tts.assert_not_called()
|
||||
|
||||
|
||||
async def test_pipeline_error(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that a pipeline error causes the error tone to be played."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
|
||||
done = asyncio.Event()
|
||||
played_audio_bytes = b""
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
# Fake error
|
||||
event_callback = kwargs["event_callback"]
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
type=assist_pipeline.PipelineEventType.ERROR,
|
||||
data={"code": "error-code", "message": "error message"},
|
||||
)
|
||||
)
|
||||
|
||||
async def async_send_audio(audio_bytes: bytes, **kwargs):
|
||||
nonlocal played_audio_bytes
|
||||
|
||||
# Should be error.pcm from components/voip
|
||||
played_audio_bytes = audio_bytes
|
||||
done.set()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
):
|
||||
satellite._tones = Tones.ERROR
|
||||
satellite.transport = Mock()
|
||||
satellite._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
|
||||
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
|
||||
# Wait for error tone to be played
|
||||
async with asyncio.timeout(1):
|
||||
await done.wait()
|
||||
|
||||
assert sum(played_audio_bytes) > 0
|
||||
assert played_audio_bytes == snapshot()
|
||||
|
|
|
@ -150,10 +150,10 @@ async def reload_satellite(
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.run"
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite.run"
|
||||
) as _run_mock,
|
||||
):
|
||||
# _run_mock: satellite task does not actually run
|
||||
await hass.config_entries.async_reload(config_entry_id)
|
||||
|
||||
return hass.data[DOMAIN][config_entry_id].satellite.device
|
||||
return hass.data[DOMAIN][config_entry_id].satellite_device
|
||||
|
|
|
@ -152,7 +152,7 @@ async def init_satellite(hass: HomeAssistant, satellite_config_entry: ConfigEntr
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.run"
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite.run"
|
||||
) as _run_mock,
|
||||
):
|
||||
# _run_mock: satellite task does not actually run
|
||||
|
@ -164,4 +164,4 @@ async def satellite_device(
|
|||
hass: HomeAssistant, init_satellite, satellite_config_entry: ConfigEntry
|
||||
) -> SatelliteDevice:
|
||||
"""Get a satellite device fixture."""
|
||||
return hass.data[DOMAIN][satellite_config_entry.entry_id].satellite.device
|
||||
return hass.data[DOMAIN][satellite_config_entry.entry_id].satellite_device
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
|
||||
from homeassistant.components.assist_pipeline import OPTION_PREFERRED
|
||||
from homeassistant.components.wyoming import DOMAIN
|
||||
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
|
|
|
@ -17,14 +17,17 @@ from wyoming.info import Info
|
|||
from wyoming.ping import Ping, Pong
|
||||
from wyoming.pipeline import PipelineStage, RunPipeline
|
||||
from wyoming.satellite import RunSatellite
|
||||
from wyoming.snd import Played
|
||||
from wyoming.timer import TimerCancelled, TimerFinished, TimerStarted, TimerUpdated
|
||||
from wyoming.tts import Synthesize
|
||||
from wyoming.vad import VoiceStarted, VoiceStopped
|
||||
from wyoming.wake import Detect, Detection
|
||||
|
||||
from homeassistant.components import assist_pipeline, wyoming
|
||||
from homeassistant.components import assist_pipeline, assist_satellite, wyoming
|
||||
from homeassistant.components.wyoming.assist_satellite import WyomingSatellite
|
||||
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||
from homeassistant.const import STATE_ON
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import ATTR_ENTITY_ID, STATE_ON
|
||||
from homeassistant.core import HomeAssistant, State
|
||||
from homeassistant.helpers import intent as intent_helper
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
@ -69,10 +72,17 @@ def get_test_wav() -> bytes:
|
|||
return wav_io.getvalue()
|
||||
|
||||
|
||||
def get_device(hass: HomeAssistant, entry: ConfigEntry) -> SatelliteDevice:
|
||||
"""Get the satellite device for a config entry."""
|
||||
device = hass.data[wyoming.DOMAIN][entry.entry_id].satellite_device
|
||||
assert isinstance(device, SatelliteDevice)
|
||||
return device
|
||||
|
||||
|
||||
class SatelliteAsyncTcpClient(MockAsyncTcpClient):
|
||||
"""Satellite AsyncTcpClient."""
|
||||
|
||||
def __init__(self, responses: list[Event]) -> None:
|
||||
def __init__(self, responses: list[Event], auto_audio: bool = True) -> None:
|
||||
"""Initialize client."""
|
||||
super().__init__(responses)
|
||||
|
||||
|
@ -124,9 +134,16 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient):
|
|||
self.timer_finished_event = asyncio.Event()
|
||||
self.timer_finished: TimerFinished | None = None
|
||||
|
||||
self.run_pipeline_event = asyncio.Event()
|
||||
self.run_pipeline_count = asyncio.Semaphore()
|
||||
self.run_pipeline: RunPipeline | None = None
|
||||
self.run_pipeline_list: list[RunPipeline] = []
|
||||
|
||||
self._mic_audio_chunk = AudioChunk(
|
||||
rate=16000, width=2, channels=1, audio=b"chunk"
|
||||
).event()
|
||||
self._auto_audio = auto_audio
|
||||
self._event_injected = asyncio.Event()
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Connect."""
|
||||
|
@ -184,17 +201,29 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient):
|
|||
elif TimerFinished.is_type(event.type):
|
||||
self.timer_finished = TimerFinished.from_event(event)
|
||||
self.timer_finished_event.set()
|
||||
elif RunPipeline.is_type(event.type):
|
||||
self.run_pipeline = RunPipeline.from_event(event)
|
||||
self.run_pipeline_list.append(self.run_pipeline)
|
||||
self.run_pipeline_event.set()
|
||||
self.run_pipeline_count.release()
|
||||
|
||||
async def read_event(self) -> Event | None:
|
||||
"""Receive."""
|
||||
event = await super().read_event()
|
||||
while True:
|
||||
event = await super().read_event()
|
||||
if event is not None:
|
||||
return event
|
||||
|
||||
# Keep sending audio chunks instead of None
|
||||
return event or self._mic_audio_chunk
|
||||
if self._auto_audio:
|
||||
# Keep sending audio chunks instead of None
|
||||
return self._mic_audio_chunk
|
||||
|
||||
await self._event_injected.wait()
|
||||
|
||||
def inject_event(self, event: Event) -> None:
|
||||
"""Put an event in as the next response."""
|
||||
self.responses = [event, *self.responses]
|
||||
self._event_injected.set()
|
||||
|
||||
|
||||
async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
||||
|
@ -240,23 +269,21 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
|
||||
"homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
|
||||
return_value=("wav", get_test_wav()),
|
||||
),
|
||||
patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0),
|
||||
patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||
entry.entry_id
|
||||
].satellite.device
|
||||
device = get_device(hass, entry)
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.connect_event.wait()
|
||||
|
@ -443,7 +470,7 @@ async def test_satellite_muted(hass: HomeAssistant) -> None:
|
|||
"""Test callback for a satellite that has been muted."""
|
||||
on_muted_event = asyncio.Event()
|
||||
|
||||
original_on_muted = wyoming.satellite.WyomingSatellite.on_muted
|
||||
original_on_muted = WyomingSatellite.on_muted
|
||||
|
||||
async def on_muted(self):
|
||||
# Trigger original function
|
||||
|
@ -457,6 +484,18 @@ async def test_satellite_muted(hass: HomeAssistant) -> None:
|
|||
self.device.set_is_muted(False)
|
||||
on_muted_event.set()
|
||||
|
||||
async def async_pipeline_from_audio_stream(
|
||||
hass: HomeAssistant,
|
||||
context,
|
||||
event_callback,
|
||||
stt_metadata,
|
||||
stt_stream,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
async for chunk in stt_stream:
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
|
@ -467,9 +506,17 @@ async def test_satellite_muted(hass: HomeAssistant) -> None:
|
|||
return_value=State("switch.test_mute", STATE_ON),
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_muted",
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite.on_muted",
|
||||
on_muted,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient([]),
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
async_pipeline_from_audio_stream,
|
||||
),
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -484,7 +531,7 @@ async def test_satellite_restart(hass: HomeAssistant) -> None:
|
|||
"""Test pipeline loop restart after unexpected error."""
|
||||
on_restart_event = asyncio.Event()
|
||||
|
||||
original_on_restart = wyoming.satellite.WyomingSatellite.on_restart
|
||||
original_on_restart = WyomingSatellite.on_restart
|
||||
|
||||
async def on_restart(self):
|
||||
await original_on_restart(self)
|
||||
|
@ -497,14 +544,14 @@ async def test_satellite_restart(hass: HomeAssistant) -> None:
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite._connect_and_loop",
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite._connect_and_loop",
|
||||
side_effect=RuntimeError(),
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite.on_restart",
|
||||
on_restart,
|
||||
),
|
||||
patch("homeassistant.components.wyoming.satellite._RESTART_SECONDS", 0),
|
||||
patch("homeassistant.components.wyoming.assist_satellite._RESTART_SECONDS", 0),
|
||||
):
|
||||
await setup_config_entry(hass)
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -517,7 +564,7 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None:
|
|||
reconnect_event = asyncio.Event()
|
||||
stopped_event = asyncio.Event()
|
||||
|
||||
original_on_reconnect = wyoming.satellite.WyomingSatellite.on_reconnect
|
||||
original_on_reconnect = WyomingSatellite.on_reconnect
|
||||
|
||||
async def on_reconnect(self):
|
||||
await original_on_reconnect(self)
|
||||
|
@ -537,18 +584,20 @@ async def test_satellite_reconnect(hass: HomeAssistant) -> None:
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient.connect",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient.connect",
|
||||
side_effect=ConnectionRefusedError(),
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_reconnect",
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite.on_reconnect",
|
||||
on_reconnect,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite.on_stopped",
|
||||
on_stopped,
|
||||
),
|
||||
patch("homeassistant.components.wyoming.satellite._RECONNECT_SECONDS", 0),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.assist_satellite._RECONNECT_SECONDS", 0
|
||||
),
|
||||
):
|
||||
await setup_config_entry(hass)
|
||||
async with asyncio.timeout(1):
|
||||
|
@ -570,14 +619,14 @@ async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
MockAsyncTcpClient([]), # no RunPipeline event
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
) as mock_run_pipeline,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite.on_restart",
|
||||
on_restart,
|
||||
),
|
||||
):
|
||||
|
@ -615,25 +664,23 @@ async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
MockAsyncTcpClient(events),
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
) as mock_run_pipeline,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite.on_restart",
|
||||
on_restart,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite.on_stopped",
|
||||
on_stopped,
|
||||
),
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||
entry.entry_id
|
||||
].satellite.device
|
||||
device = get_device(hass, entry)
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await on_restart_event.wait()
|
||||
|
@ -665,11 +712,11 @@ async def test_satellite_error_during_pipeline(hass: HomeAssistant) -> None:
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
wraps=_async_pipeline_from_audio_stream,
|
||||
) as mock_run_pipeline,
|
||||
):
|
||||
|
@ -701,7 +748,7 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None:
|
|||
"""Test satellite receiving non-WAV audio from text-to-speech."""
|
||||
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
|
||||
|
||||
original_stream_tts = wyoming.satellite.WyomingSatellite._stream_tts
|
||||
original_stream_tts = WyomingSatellite._stream_tts
|
||||
error_event = asyncio.Event()
|
||||
|
||||
async def _stream_tts(self, media_id):
|
||||
|
@ -724,19 +771,19 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None:
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
wraps=_async_pipeline_from_audio_stream,
|
||||
) as mock_run_pipeline,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
|
||||
"homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
|
||||
return_value=("mp3", bytes(1)),
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite._stream_tts",
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite._stream_tts",
|
||||
_stream_tts,
|
||||
),
|
||||
):
|
||||
|
@ -808,8 +855,9 @@ async def test_pipeline_changed(hass: HomeAssistant) -> None:
|
|||
pipeline_event_callback = event_callback
|
||||
|
||||
run_pipeline_called.set()
|
||||
async for _chunk in stt_stream:
|
||||
pass
|
||||
async for chunk in stt_stream:
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
pipeline_stopped.set()
|
||||
|
||||
|
@ -819,18 +867,16 @@ async def test_pipeline_changed(hass: HomeAssistant) -> None:
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
async_pipeline_from_audio_stream,
|
||||
),
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||
entry.entry_id
|
||||
].satellite.device
|
||||
device = get_device(hass, entry)
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.connect_event.wait()
|
||||
|
@ -882,8 +928,9 @@ async def test_audio_settings_changed(hass: HomeAssistant) -> None:
|
|||
pipeline_event_callback = event_callback
|
||||
|
||||
run_pipeline_called.set()
|
||||
async for _chunk in stt_stream:
|
||||
pass
|
||||
async for chunk in stt_stream:
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
pipeline_stopped.set()
|
||||
|
||||
|
@ -893,18 +940,16 @@ async def test_audio_settings_changed(hass: HomeAssistant) -> None:
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
async_pipeline_from_audio_stream,
|
||||
),
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||
entry.entry_id
|
||||
].satellite.device
|
||||
device = get_device(hass, entry)
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.connect_event.wait()
|
||||
|
@ -938,7 +983,7 @@ async def test_invalid_stages(hass: HomeAssistant) -> None:
|
|||
).event(),
|
||||
]
|
||||
|
||||
original_run_pipeline_once = wyoming.satellite.WyomingSatellite._run_pipeline_once
|
||||
original_run_pipeline_once = WyomingSatellite._run_pipeline_once
|
||||
start_stage_event = asyncio.Event()
|
||||
end_stage_event = asyncio.Event()
|
||||
|
||||
|
@ -967,11 +1012,11 @@ async def test_invalid_stages(hass: HomeAssistant) -> None:
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite._run_pipeline_once",
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingSatellite._run_pipeline_once",
|
||||
_run_pipeline_once,
|
||||
),
|
||||
):
|
||||
|
@ -1018,8 +1063,9 @@ async def test_client_stops_pipeline(hass: HomeAssistant) -> None:
|
|||
pipeline_event_callback = event_callback
|
||||
|
||||
run_pipeline_called.set()
|
||||
async for _chunk in stt_stream:
|
||||
pass
|
||||
async for chunk in stt_stream:
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
pipeline_stopped.set()
|
||||
|
||||
|
@ -1029,11 +1075,11 @@ async def test_client_stops_pipeline(hass: HomeAssistant) -> None:
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
async_pipeline_from_audio_stream,
|
||||
),
|
||||
):
|
||||
|
@ -1083,11 +1129,11 @@ async def test_wake_word_phrase(hass: HomeAssistant) -> None:
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
wraps=_async_pipeline_from_audio_stream,
|
||||
) as mock_run_pipeline,
|
||||
):
|
||||
|
@ -1114,14 +1160,12 @@ async def test_timers(hass: HomeAssistant) -> None:
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient([]),
|
||||
) as mock_client,
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||
entry.entry_id
|
||||
].satellite.device
|
||||
device = get_device(hass, entry)
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.connect_event.wait()
|
||||
|
@ -1325,23 +1369,20 @@ async def test_satellite_conversation_id(hass: HomeAssistant) -> None:
|
|||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
|
||||
"homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
|
||||
return_value=("wav", get_test_wav()),
|
||||
),
|
||||
patch("homeassistant.components.wyoming.satellite._PING_SEND_DELAY", 0),
|
||||
patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
satellite: wyoming.WyomingSatellite = hass.data[wyoming.DOMAIN][
|
||||
entry.entry_id
|
||||
].satellite
|
||||
await setup_config_entry(hass)
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.connect_event.wait()
|
||||
|
@ -1370,19 +1411,128 @@ async def test_satellite_conversation_id(hass: HomeAssistant) -> None:
|
|||
# Should be the same conversation id
|
||||
assert pipeline_kwargs.get("conversation_id") == conversation_id
|
||||
|
||||
# Reset and run again, but this time "time out"
|
||||
satellite._conversation_id_time = None
|
||||
run_pipeline_called.clear()
|
||||
pipeline_kwargs.clear()
|
||||
|
||||
pipeline_event_callback(
|
||||
assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END)
|
||||
)
|
||||
async def test_say_text(hass: HomeAssistant) -> None:
|
||||
"""Test say text service call."""
|
||||
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
|
||||
test_text = "test-text"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient([]),
|
||||
) as mock_client,
|
||||
patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
device = get_device(hass, entry)
|
||||
satellite_entity_id = device.get_satellite_entity_id(hass)
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await run_pipeline_called.wait()
|
||||
await mock_client.connect_event.wait()
|
||||
await mock_client.run_satellite_event.wait()
|
||||
|
||||
# Should be a different conversation id
|
||||
new_conversation_id = pipeline_kwargs.get("conversation_id")
|
||||
assert new_conversation_id
|
||||
assert new_conversation_id != conversation_id
|
||||
async with asyncio.timeout(1):
|
||||
await hass.services.async_call(
|
||||
assist_satellite.DOMAIN,
|
||||
assist_satellite.SERVICE_SAY_TEXT,
|
||||
{
|
||||
ATTR_ENTITY_ID: satellite_entity_id,
|
||||
assist_satellite.ATTR_ANNOUNCE_TEXT: test_text,
|
||||
},
|
||||
blocking=False,
|
||||
)
|
||||
await mock_client.run_pipeline_event.wait()
|
||||
|
||||
assert mock_client.run_pipeline is not None
|
||||
rp: RunPipeline = mock_client.run_pipeline
|
||||
assert rp.start_stage == PipelineStage.TTS
|
||||
assert rp.end_stage == PipelineStage.TTS
|
||||
assert rp.announce_text == test_text
|
||||
|
||||
|
||||
async def test_get_command(hass: HomeAssistant) -> None:
|
||||
"""Test get command service call."""
|
||||
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
|
||||
test_command = "test-command"
|
||||
test_text = "test-text"
|
||||
|
||||
async def async_pipeline_from_audio_stream(
|
||||
hass: HomeAssistant,
|
||||
context,
|
||||
event_callback,
|
||||
stt_metadata,
|
||||
stt_stream,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
assist_pipeline.PipelineEventType.STT_END,
|
||||
{"stt_output": {"text": test_command}},
|
||||
)
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient([], auto_audio=False),
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
device = get_device(hass, entry)
|
||||
satellite_entity_id = device.get_satellite_entity_id(hass)
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.connect_event.wait()
|
||||
await mock_client.run_satellite_event.wait()
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
task = asyncio.create_task(
|
||||
hass.services.async_call(
|
||||
assist_satellite.DOMAIN,
|
||||
assist_satellite.SERVICE_GET_COMMAND,
|
||||
{
|
||||
ATTR_ENTITY_ID: satellite_entity_id,
|
||||
assist_satellite.ATTR_ANNOUNCE_TEXT: test_text,
|
||||
},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
)
|
||||
await mock_client.run_pipeline_event.wait()
|
||||
|
||||
# Announcement happens first
|
||||
assert mock_client.run_pipeline is not None
|
||||
rp: RunPipeline = mock_client.run_pipeline
|
||||
assert rp.start_stage == PipelineStage.TTS
|
||||
assert rp.end_stage == PipelineStage.TTS
|
||||
assert rp.announce_text == test_text
|
||||
|
||||
mock_client.run_pipeline_event.clear()
|
||||
mock_client.run_pipeline = None
|
||||
mock_client.inject_event(Played().event())
|
||||
|
||||
# Command happens next
|
||||
await mock_client.run_pipeline_event.wait()
|
||||
assert mock_client.run_pipeline is not None
|
||||
rp = mock_client.run_pipeline
|
||||
assert rp.start_stage == PipelineStage.ASR
|
||||
assert rp.end_stage == PipelineStage.ASR
|
||||
|
||||
mock_client.inject_event(rp.event())
|
||||
|
||||
result = await task
|
||||
assert result == {satellite_entity_id: test_command}
|
||||
|
|
|
@ -3,8 +3,8 @@
|
|||
from unittest.mock import Mock, patch
|
||||
|
||||
from homeassistant.components import assist_pipeline
|
||||
from homeassistant.components.assist_pipeline import OPTION_PREFERRED
|
||||
from homeassistant.components.assist_pipeline.pipeline import PipelineData
|
||||
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
|
||||
from homeassistant.components.assist_pipeline.vad import VadSensitivity
|
||||
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
|
|
|
@ -213,3 +213,43 @@ async def test_get_scheduled_timer_handles(hass: HomeAssistant) -> None:
|
|||
timer_handle.cancel()
|
||||
timer_handle2.cancel()
|
||||
timer_handle3.cancel()
|
||||
|
||||
|
||||
async def test_queue_to_iterable() -> None:
|
||||
"""Test queue_to_iterable."""
|
||||
queue: asyncio.Queue[int | None] = asyncio.Queue()
|
||||
expected_items = list(range(10))
|
||||
|
||||
for i in expected_items:
|
||||
await queue.put(i)
|
||||
|
||||
# Will terminate the stream
|
||||
await queue.put(None)
|
||||
|
||||
actual_items = [item async for item in hasync.queue_to_iterable(queue)]
|
||||
|
||||
assert expected_items == actual_items
|
||||
|
||||
# Check timeout
|
||||
assert queue.empty()
|
||||
|
||||
# Time out on first item
|
||||
async with asyncio.timeout(1):
|
||||
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
|
||||
# Should time out very quickly
|
||||
async for _item in hasync.queue_to_iterable(queue, timeout=0.01):
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Check timeout on second item
|
||||
assert queue.empty()
|
||||
await queue.put(12345)
|
||||
|
||||
# Time out on second item
|
||||
async with asyncio.timeout(1):
|
||||
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
|
||||
# Should time out very quickly
|
||||
async for item in hasync.queue_to_iterable(queue, timeout=0.01):
|
||||
if item != 12345:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
assert queue.empty()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue