Allow picking a pipeline for voip devices (#91524)

* Allow picking a pipeline for voip device

* Add tests

* Fix test

* Adjust on new pipeline data
This commit is contained in:
Paulus Schoutsen 2023-04-17 13:09:11 -04:00 committed by GitHub
parent 9bd12f6503
commit bd22e0bd43
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 323 additions and 8 deletions

View file

@ -52,16 +52,13 @@ async def async_pipeline_from_audio_stream(
tts_options: dict | None = None,
) -> None:
"""Create an audio pipeline from an audio stream."""
if language is None:
if language is None and pipeline_id is None:
language = hass.config.language
# Temporary workaround for language codes
if language == "en":
language = "en-US"
if stt_metadata.language == "":
stt_metadata.language = language
if context is None:
context = Context()
@ -75,6 +72,9 @@ async def async_pipeline_from_audio_stream(
"pipeline_not_found", f"Pipeline {pipeline_id} not found"
)
if stt_metadata.language == "":
stt_metadata.language = pipeline.language
pipeline_input = PipelineInput(
conversation_id=conversation_id,
stt_metadata=stt_metadata,

View file

@ -105,7 +105,7 @@ class Pipeline:
"""A voice assistant pipeline."""
conversation_engine: str | None
language: str | None
language: str
name: str
stt_engine: str | None
tts_engine: str | None

View file

@ -0,0 +1,95 @@
"""Select entities for a pipeline."""
from __future__ import annotations
from collections.abc import Iterable
from homeassistant.components.select import SelectEntity, SelectEntityDescription
from homeassistant.const import EntityCategory, Platform
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import collection, entity_registry as er, restore_state
from .const import DOMAIN
from .pipeline import PipelineStorageCollection
OPTION_PREFERRED = "preferred"
@callback
def get_chosen_pipeline(
hass: HomeAssistant, domain: str, unique_id_prefix: str
) -> str | None:
"""Get the chosen pipeline for a domain."""
ent_reg = er.async_get(hass)
pipeline_entity_id = ent_reg.async_get_entity_id(
Platform.SELECT, domain, f"{unique_id_prefix}-pipeline"
)
if pipeline_entity_id is None:
return None
state = hass.states.get(pipeline_entity_id)
if state is None or state.state == OPTION_PREFERRED:
return None
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
return next(
(item.id for item in pipeline_store.async_items() if item.name == state.state),
None,
)
class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
"""Entity to represent a pipeline selector."""
entity_description = SelectEntityDescription(
key="pipeline",
translation_key="pipeline",
entity_category=EntityCategory.CONFIG,
)
_attr_should_poll = False
_attr_current_option = OPTION_PREFERRED
_attr_options = [OPTION_PREFERRED]
def __init__(self, hass: HomeAssistant, unique_id_prefix: str) -> None:
"""Initialize a pipeline selector."""
self._attr_unique_id = f"{unique_id_prefix}-pipeline"
self.hass = hass
self._update_options()
async def async_added_to_hass(self) -> None:
"""When entity is added to Home Assistant."""
await super().async_added_to_hass()
pipeline_store: PipelineStorageCollection = self.hass.data[
DOMAIN
].pipeline_store
pipeline_store.async_add_change_set_listener(self._pipelines_updated)
state = await self.async_get_last_state()
if state is not None and state.state in self.options:
self._attr_current_option = state.state
async def async_select_option(self, option: str) -> None:
"""Select an option."""
self._attr_current_option = option
self.async_write_ha_state()
async def _pipelines_updated(
self, change_sets: Iterable[collection.CollectionChangeSet]
) -> None:
"""Handle pipeline update."""
self._update_options()
self.async_write_ha_state()
@callback
def _update_options(self) -> None:
"""Handle pipeline update."""
pipeline_store: PipelineStorageCollection = self.hass.data[
DOMAIN
].pipeline_store
options = [OPTION_PREFERRED]
options.extend(sorted(item.name for item in pipeline_store.async_items()))
self._attr_options = options
if self._attr_current_option not in options:
self._attr_current_option = OPTION_PREFERRED

View file

@ -0,0 +1,12 @@
{
"entity": {
"select": {
"pipeline": {
"name": "Assist Pipeline",
"state": {
"preferred": "Preferred"
}
}
}
}
}

View file

@ -19,6 +19,7 @@ from .voip import HassVoipDatagramProtocol
PLATFORMS = (
Platform.BINARY_SENSOR,
Platform.SELECT,
Platform.SWITCH,
)
_LOGGER = logging.getLogger(__name__)

View file

@ -0,0 +1,46 @@
"""Select entities for VoIP integration."""
from __future__ import annotations
from typing import TYPE_CHECKING
from homeassistant.components.assist_pipeline.select import AssistPipelineSelect
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN
from .devices import VoIPDevice
from .entity import VoIPEntity
if TYPE_CHECKING:
from . import DomainData
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up VoIP switch entities."""
domain_data: DomainData = hass.data[DOMAIN]
@callback
def async_add_device(device: VoIPDevice) -> None:
"""Add device."""
async_add_entities([VoipPipelineSelect(hass, device)])
domain_data.devices.async_add_new_device_listener(async_add_device)
async_add_entities(
[VoipPipelineSelect(hass, device) for device in domain_data.devices]
)
class VoipPipelineSelect(VoIPEntity, AssistPipelineSelect):
"""Pipeline selector for VoIP devices."""
def __init__(self, hass: HomeAssistant, device: VoIPDevice) -> None:
"""Initialize a pipeline selector."""
VoIPEntity.__init__(self, device)
AssistPipelineSelect.__init__(self, hass, device.voip_id)

View file

@ -24,6 +24,14 @@
"allow_call": {
"name": "Allow Calls"
}
},
"select": {
"pipeline": {
"name": "[%key:component::assist_pipeline::entity::select::pipeline::name%]",
"state": {
"preferred": "[%key:component::assist_pipeline::entity::select::pipeline::state::preferred%]"
}
}
}
}
}

View file

@ -15,11 +15,14 @@ from homeassistant.components.assist_pipeline import (
PipelineEvent,
PipelineEventType,
async_pipeline_from_audio_stream,
select as pipeline_select,
)
from homeassistant.components.assist_pipeline.vad import VoiceCommandSegmenter
from homeassistant.const import __version__
from homeassistant.core import HomeAssistant
from .const import DOMAIN
if TYPE_CHECKING:
from .devices import VoIPDevice, VoIPDevices
@ -151,7 +154,9 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=stt_stream(),
language=self.language,
pipeline_id=pipeline_select.get_chosen_pipeline(
self.hass, DOMAIN, self.voip_device.voip_id
),
conversation_id=self._conversation_id,
tts_options={tts.ATTR_AUDIO_OUTPUT: "raw"},
)

View file

@ -143,20 +143,24 @@ class ObservableCollection(ABC, Generic[_ItemT]):
return list(self.data.values())
@callback
def async_add_listener(self, listener: ChangeListener) -> None:
def async_add_listener(self, listener: ChangeListener) -> Callable[[], None]:
"""Add a listener.
Will be called with (change_type, item_id, updated_config).
"""
self.listeners.append(listener)
return lambda: self.listeners.remove(listener)
@callback
def async_add_change_set_listener(self, listener: ChangeSetListener) -> None:
def async_add_change_set_listener(
self, listener: ChangeSetListener
) -> Callable[[], None]:
"""Add a listener for a full change set.
Will be called with [(change_type, item_id, updated_config), ...]
"""
self.change_set_listeners.append(listener)
return lambda: self.change_set_listeners.remove(listener)
async def notify_changes(self, change_sets: Iterable[CollectionChangeSet]) -> None:
"""Notify listeners of a change."""

View file

@ -6,6 +6,8 @@ from unittest.mock import AsyncMock, Mock
import pytest
from homeassistant.components import stt, tts
from homeassistant.components.assist_pipeline import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
from homeassistant.core import HomeAssistant
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_setup_component
@ -137,3 +139,9 @@ async def init_components(
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
assert await async_setup_component(hass, "media_source", {})
assert await async_setup_component(hass, "assist_pipeline", {})
@pytest.fixture
def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection:
"""Return pipeline storage collection."""
return hass.data[DOMAIN].pipeline_store

View file

@ -0,0 +1,117 @@
"""Test select entity."""
from __future__ import annotations
import pytest
from homeassistant.components.assist_pipeline import Pipeline
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
from homeassistant.components.assist_pipeline.select import AssistPipelineSelect
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from tests.common import MockConfigEntry, MockPlatform, mock_entity_platform
class SelectPlatform(MockPlatform):
"""Fake select platform."""
# pylint: disable=method-hidden
async def async_setup_entry(
self,
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up fake select platform."""
async_add_entities([AssistPipelineSelect(hass, "test")])
@pytest.fixture
async def init_select(hass: HomeAssistant, init_components) -> ConfigEntry:
"""Initialize select entity."""
mock_entity_platform(hass, "select.assist_pipeline", SelectPlatform())
config_entry = MockConfigEntry(domain="assist_pipeline")
assert await hass.config_entries.async_forward_entry_setup(config_entry, "select")
return config_entry
@pytest.fixture
async def pipeline_1(
hass: HomeAssistant, init_select, pipeline_storage: PipelineStorageCollection
) -> Pipeline:
"""Create a pipeline."""
return await pipeline_storage.async_create_item(
{
"name": "Test 1",
"language": "en-US",
"conversation_engine": None,
"tts_engine": None,
"stt_engine": None,
}
)
@pytest.fixture
async def pipeline_2(
hass: HomeAssistant, init_select, pipeline_storage: PipelineStorageCollection
) -> Pipeline:
"""Create a pipeline."""
return await pipeline_storage.async_create_item(
{
"name": "Test 2",
"language": "en-US",
"conversation_engine": None,
"tts_engine": None,
"stt_engine": None,
}
)
async def test_select_entity_changing_pipelines(
hass: HomeAssistant,
init_select: ConfigEntry,
pipeline_1: Pipeline,
pipeline_2: Pipeline,
pipeline_storage: PipelineStorageCollection,
) -> None:
"""Test entity tracking pipeline changes."""
config_entry = init_select # nicer naming
state = hass.states.get("select.assist_pipeline_test_pipeline")
assert state is not None
assert state.state == "preferred"
assert state.attributes["options"] == [
"preferred",
pipeline_1.name,
pipeline_2.name,
]
# Change select to new pipeline
await hass.services.async_call(
"select",
"select_option",
{
"entity_id": "select.assist_pipeline_test_pipeline",
"option": pipeline_2.name,
},
blocking=True,
)
state = hass.states.get("select.assist_pipeline_test_pipeline")
assert state.state == pipeline_2.name
# Reload config entry to test selected option persists
assert await hass.config_entries.async_forward_entry_unload(config_entry, "select")
assert await hass.config_entries.async_forward_entry_setup(config_entry, "select")
state = hass.states.get("select.assist_pipeline_test_pipeline")
assert state.state == pipeline_2.name
# Remove selected pipeline
await pipeline_storage.async_delete_item(pipeline_2.id)
state = hass.states.get("select.assist_pipeline_test_pipeline")
assert state.state == "preferred"
assert state.attributes["options"] == ["preferred", pipeline_1.name]

View file

@ -0,0 +1,19 @@
"""Test VoIP select."""
from homeassistant.components.voip.devices import VoIPDevice
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
async def test_pipeline_select(
hass: HomeAssistant,
config_entry: ConfigEntry,
voip_device: VoIPDevice,
) -> None:
"""Test pipeline select.
Functionality is tested in assist_pipeline/test_select.py.
This test is only to ensure it is set up.
"""
state = hass.states.get("select.192_168_1_210_assist_pipeline")
assert state is not None
assert state.state == "preferred"