Allow picking a pipeline for voip devices (#91524)
* Allow picking a pipeline for voip device * Add tests * Fix test * Adjust on new pipeline data
This commit is contained in:
parent
9bd12f6503
commit
bd22e0bd43
12 changed files with 323 additions and 8 deletions
|
@ -52,16 +52,13 @@ async def async_pipeline_from_audio_stream(
|
|||
tts_options: dict | None = None,
|
||||
) -> None:
|
||||
"""Create an audio pipeline from an audio stream."""
|
||||
if language is None:
|
||||
if language is None and pipeline_id is None:
|
||||
language = hass.config.language
|
||||
|
||||
# Temporary workaround for language codes
|
||||
if language == "en":
|
||||
language = "en-US"
|
||||
|
||||
if stt_metadata.language == "":
|
||||
stt_metadata.language = language
|
||||
|
||||
if context is None:
|
||||
context = Context()
|
||||
|
||||
|
@ -75,6 +72,9 @@ async def async_pipeline_from_audio_stream(
|
|||
"pipeline_not_found", f"Pipeline {pipeline_id} not found"
|
||||
)
|
||||
|
||||
if stt_metadata.language == "":
|
||||
stt_metadata.language = pipeline.language
|
||||
|
||||
pipeline_input = PipelineInput(
|
||||
conversation_id=conversation_id,
|
||||
stt_metadata=stt_metadata,
|
||||
|
|
|
@ -105,7 +105,7 @@ class Pipeline:
|
|||
"""A voice assistant pipeline."""
|
||||
|
||||
conversation_engine: str | None
|
||||
language: str | None
|
||||
language: str
|
||||
name: str
|
||||
stt_engine: str | None
|
||||
tts_engine: str | None
|
||||
|
|
95
homeassistant/components/assist_pipeline/select.py
Normal file
95
homeassistant/components/assist_pipeline/select.py
Normal file
|
@ -0,0 +1,95 @@
|
|||
"""Select entities for a pipeline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
from homeassistant.components.select import SelectEntity, SelectEntityDescription
|
||||
from homeassistant.const import EntityCategory, Platform
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import collection, entity_registry as er, restore_state
|
||||
|
||||
from .const import DOMAIN
|
||||
from .pipeline import PipelineStorageCollection
|
||||
|
||||
OPTION_PREFERRED = "preferred"
|
||||
|
||||
|
||||
@callback
|
||||
def get_chosen_pipeline(
|
||||
hass: HomeAssistant, domain: str, unique_id_prefix: str
|
||||
) -> str | None:
|
||||
"""Get the chosen pipeline for a domain."""
|
||||
ent_reg = er.async_get(hass)
|
||||
pipeline_entity_id = ent_reg.async_get_entity_id(
|
||||
Platform.SELECT, domain, f"{unique_id_prefix}-pipeline"
|
||||
)
|
||||
if pipeline_entity_id is None:
|
||||
return None
|
||||
|
||||
state = hass.states.get(pipeline_entity_id)
|
||||
if state is None or state.state == OPTION_PREFERRED:
|
||||
return None
|
||||
|
||||
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
||||
return next(
|
||||
(item.id for item in pipeline_store.async_items() if item.name == state.state),
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
|
||||
"""Entity to represent a pipeline selector."""
|
||||
|
||||
entity_description = SelectEntityDescription(
|
||||
key="pipeline",
|
||||
translation_key="pipeline",
|
||||
entity_category=EntityCategory.CONFIG,
|
||||
)
|
||||
_attr_should_poll = False
|
||||
_attr_current_option = OPTION_PREFERRED
|
||||
_attr_options = [OPTION_PREFERRED]
|
||||
|
||||
def __init__(self, hass: HomeAssistant, unique_id_prefix: str) -> None:
|
||||
"""Initialize a pipeline selector."""
|
||||
self._attr_unique_id = f"{unique_id_prefix}-pipeline"
|
||||
self.hass = hass
|
||||
self._update_options()
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""When entity is added to Home Assistant."""
|
||||
await super().async_added_to_hass()
|
||||
|
||||
pipeline_store: PipelineStorageCollection = self.hass.data[
|
||||
DOMAIN
|
||||
].pipeline_store
|
||||
pipeline_store.async_add_change_set_listener(self._pipelines_updated)
|
||||
|
||||
state = await self.async_get_last_state()
|
||||
if state is not None and state.state in self.options:
|
||||
self._attr_current_option = state.state
|
||||
|
||||
async def async_select_option(self, option: str) -> None:
|
||||
"""Select an option."""
|
||||
self._attr_current_option = option
|
||||
self.async_write_ha_state()
|
||||
|
||||
async def _pipelines_updated(
|
||||
self, change_sets: Iterable[collection.CollectionChangeSet]
|
||||
) -> None:
|
||||
"""Handle pipeline update."""
|
||||
self._update_options()
|
||||
self.async_write_ha_state()
|
||||
|
||||
@callback
|
||||
def _update_options(self) -> None:
|
||||
"""Handle pipeline update."""
|
||||
pipeline_store: PipelineStorageCollection = self.hass.data[
|
||||
DOMAIN
|
||||
].pipeline_store
|
||||
options = [OPTION_PREFERRED]
|
||||
options.extend(sorted(item.name for item in pipeline_store.async_items()))
|
||||
self._attr_options = options
|
||||
|
||||
if self._attr_current_option not in options:
|
||||
self._attr_current_option = OPTION_PREFERRED
|
12
homeassistant/components/assist_pipeline/strings.json
Normal file
12
homeassistant/components/assist_pipeline/strings.json
Normal file
|
@ -0,0 +1,12 @@
|
|||
{
|
||||
"entity": {
|
||||
"select": {
|
||||
"pipeline": {
|
||||
"name": "Assist Pipeline",
|
||||
"state": {
|
||||
"preferred": "Preferred"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -19,6 +19,7 @@ from .voip import HassVoipDatagramProtocol
|
|||
|
||||
PLATFORMS = (
|
||||
Platform.BINARY_SENSOR,
|
||||
Platform.SELECT,
|
||||
Platform.SWITCH,
|
||||
)
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
|
46
homeassistant/components/voip/select.py
Normal file
46
homeassistant/components/voip/select.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
"""Select entities for VoIP integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from homeassistant.components.assist_pipeline.select import AssistPipelineSelect
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from .const import DOMAIN
|
||||
from .devices import VoIPDevice
|
||||
from .entity import VoIPEntity
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import DomainData
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up VoIP switch entities."""
|
||||
domain_data: DomainData = hass.data[DOMAIN]
|
||||
|
||||
@callback
|
||||
def async_add_device(device: VoIPDevice) -> None:
|
||||
"""Add device."""
|
||||
async_add_entities([VoipPipelineSelect(hass, device)])
|
||||
|
||||
domain_data.devices.async_add_new_device_listener(async_add_device)
|
||||
|
||||
async_add_entities(
|
||||
[VoipPipelineSelect(hass, device) for device in domain_data.devices]
|
||||
)
|
||||
|
||||
|
||||
class VoipPipelineSelect(VoIPEntity, AssistPipelineSelect):
|
||||
"""Pipeline selector for VoIP devices."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant, device: VoIPDevice) -> None:
|
||||
"""Initialize a pipeline selector."""
|
||||
VoIPEntity.__init__(self, device)
|
||||
AssistPipelineSelect.__init__(self, hass, device.voip_id)
|
|
@ -24,6 +24,14 @@
|
|||
"allow_call": {
|
||||
"name": "Allow Calls"
|
||||
}
|
||||
},
|
||||
"select": {
|
||||
"pipeline": {
|
||||
"name": "[%key:component::assist_pipeline::entity::select::pipeline::name%]",
|
||||
"state": {
|
||||
"preferred": "[%key:component::assist_pipeline::entity::select::pipeline::state::preferred%]"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,11 +15,14 @@ from homeassistant.components.assist_pipeline import (
|
|||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
async_pipeline_from_audio_stream,
|
||||
select as pipeline_select,
|
||||
)
|
||||
from homeassistant.components.assist_pipeline.vad import VoiceCommandSegmenter
|
||||
from homeassistant.const import __version__
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .const import DOMAIN
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .devices import VoIPDevice, VoIPDevices
|
||||
|
||||
|
@ -151,7 +154,9 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
|||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=stt_stream(),
|
||||
language=self.language,
|
||||
pipeline_id=pipeline_select.get_chosen_pipeline(
|
||||
self.hass, DOMAIN, self.voip_device.voip_id
|
||||
),
|
||||
conversation_id=self._conversation_id,
|
||||
tts_options={tts.ATTR_AUDIO_OUTPUT: "raw"},
|
||||
)
|
||||
|
|
|
@ -143,20 +143,24 @@ class ObservableCollection(ABC, Generic[_ItemT]):
|
|||
return list(self.data.values())
|
||||
|
||||
@callback
|
||||
def async_add_listener(self, listener: ChangeListener) -> None:
|
||||
def async_add_listener(self, listener: ChangeListener) -> Callable[[], None]:
|
||||
"""Add a listener.
|
||||
|
||||
Will be called with (change_type, item_id, updated_config).
|
||||
"""
|
||||
self.listeners.append(listener)
|
||||
return lambda: self.listeners.remove(listener)
|
||||
|
||||
@callback
|
||||
def async_add_change_set_listener(self, listener: ChangeSetListener) -> None:
|
||||
def async_add_change_set_listener(
|
||||
self, listener: ChangeSetListener
|
||||
) -> Callable[[], None]:
|
||||
"""Add a listener for a full change set.
|
||||
|
||||
Will be called with [(change_type, item_id, updated_config), ...]
|
||||
"""
|
||||
self.change_set_listeners.append(listener)
|
||||
return lambda: self.change_set_listeners.remove(listener)
|
||||
|
||||
async def notify_changes(self, change_sets: Iterable[CollectionChangeSet]) -> None:
|
||||
"""Notify listeners of a change."""
|
||||
|
|
|
@ -6,6 +6,8 @@ from unittest.mock import AsyncMock, Mock
|
|||
import pytest
|
||||
|
||||
from homeassistant.components import stt, tts
|
||||
from homeassistant.components.assist_pipeline import DOMAIN
|
||||
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
@ -137,3 +139,9 @@ async def init_components(
|
|||
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
|
||||
assert await async_setup_component(hass, "media_source", {})
|
||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection:
|
||||
"""Return pipeline storage collection."""
|
||||
return hass.data[DOMAIN].pipeline_store
|
||||
|
|
117
tests/components/assist_pipeline/test_select.py
Normal file
117
tests/components/assist_pipeline/test_select.py
Normal file
|
@ -0,0 +1,117 @@
|
|||
"""Test select entity."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.assist_pipeline import Pipeline
|
||||
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
|
||||
from homeassistant.components.assist_pipeline.select import AssistPipelineSelect
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from tests.common import MockConfigEntry, MockPlatform, mock_entity_platform
|
||||
|
||||
|
||||
class SelectPlatform(MockPlatform):
|
||||
"""Fake select platform."""
|
||||
|
||||
# pylint: disable=method-hidden
|
||||
async def async_setup_entry(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up fake select platform."""
|
||||
async_add_entities([AssistPipelineSelect(hass, "test")])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_select(hass: HomeAssistant, init_components) -> ConfigEntry:
|
||||
"""Initialize select entity."""
|
||||
mock_entity_platform(hass, "select.assist_pipeline", SelectPlatform())
|
||||
config_entry = MockConfigEntry(domain="assist_pipeline")
|
||||
assert await hass.config_entries.async_forward_entry_setup(config_entry, "select")
|
||||
return config_entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def pipeline_1(
|
||||
hass: HomeAssistant, init_select, pipeline_storage: PipelineStorageCollection
|
||||
) -> Pipeline:
|
||||
"""Create a pipeline."""
|
||||
return await pipeline_storage.async_create_item(
|
||||
{
|
||||
"name": "Test 1",
|
||||
"language": "en-US",
|
||||
"conversation_engine": None,
|
||||
"tts_engine": None,
|
||||
"stt_engine": None,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def pipeline_2(
|
||||
hass: HomeAssistant, init_select, pipeline_storage: PipelineStorageCollection
|
||||
) -> Pipeline:
|
||||
"""Create a pipeline."""
|
||||
return await pipeline_storage.async_create_item(
|
||||
{
|
||||
"name": "Test 2",
|
||||
"language": "en-US",
|
||||
"conversation_engine": None,
|
||||
"tts_engine": None,
|
||||
"stt_engine": None,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def test_select_entity_changing_pipelines(
|
||||
hass: HomeAssistant,
|
||||
init_select: ConfigEntry,
|
||||
pipeline_1: Pipeline,
|
||||
pipeline_2: Pipeline,
|
||||
pipeline_storage: PipelineStorageCollection,
|
||||
) -> None:
|
||||
"""Test entity tracking pipeline changes."""
|
||||
config_entry = init_select # nicer naming
|
||||
|
||||
state = hass.states.get("select.assist_pipeline_test_pipeline")
|
||||
assert state is not None
|
||||
assert state.state == "preferred"
|
||||
assert state.attributes["options"] == [
|
||||
"preferred",
|
||||
pipeline_1.name,
|
||||
pipeline_2.name,
|
||||
]
|
||||
|
||||
# Change select to new pipeline
|
||||
await hass.services.async_call(
|
||||
"select",
|
||||
"select_option",
|
||||
{
|
||||
"entity_id": "select.assist_pipeline_test_pipeline",
|
||||
"option": pipeline_2.name,
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
state = hass.states.get("select.assist_pipeline_test_pipeline")
|
||||
assert state.state == pipeline_2.name
|
||||
|
||||
# Reload config entry to test selected option persists
|
||||
assert await hass.config_entries.async_forward_entry_unload(config_entry, "select")
|
||||
assert await hass.config_entries.async_forward_entry_setup(config_entry, "select")
|
||||
|
||||
state = hass.states.get("select.assist_pipeline_test_pipeline")
|
||||
assert state.state == pipeline_2.name
|
||||
|
||||
# Remove selected pipeline
|
||||
await pipeline_storage.async_delete_item(pipeline_2.id)
|
||||
|
||||
state = hass.states.get("select.assist_pipeline_test_pipeline")
|
||||
assert state.state == "preferred"
|
||||
assert state.attributes["options"] == ["preferred", pipeline_1.name]
|
19
tests/components/voip/test_select.py
Normal file
19
tests/components/voip/test_select.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
"""Test VoIP select."""
|
||||
from homeassistant.components.voip.devices import VoIPDevice
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
|
||||
async def test_pipeline_select(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
voip_device: VoIPDevice,
|
||||
) -> None:
|
||||
"""Test pipeline select.
|
||||
|
||||
Functionality is tested in assist_pipeline/test_select.py.
|
||||
This test is only to ensure it is set up.
|
||||
"""
|
||||
state = hass.states.get("select.192_168_1_210_assist_pipeline")
|
||||
assert state is not None
|
||||
assert state.state == "preferred"
|
Loading…
Add table
Reference in a new issue