Compare commits
1 commit
dev
...
announce-r
Author | SHA1 | Date | |
---|---|---|---|
|
3bf174b369 |
9 changed files with 218 additions and 55 deletions
|
@ -63,6 +63,20 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
"async_internal_announce",
|
||||
[AssistSatelliteEntityFeature.ANNOUNCE],
|
||||
)
|
||||
component.async_register_entity_service(
|
||||
"start_conversation",
|
||||
vol.All(
|
||||
cv.make_entity_service_schema(
|
||||
{
|
||||
vol.Optional("start_message"): str,
|
||||
vol.Optional("start_media_id"): str,
|
||||
}
|
||||
),
|
||||
cv.has_at_least_one_key("start_message", "start_media_id"),
|
||||
),
|
||||
"async_internal_start_conversation",
|
||||
[AssistSatelliteEntityFeature.START_CONVERSATION],
|
||||
)
|
||||
hass.data[CONNECTION_TEST_DATA] = {}
|
||||
async_register_websocket_api(hass)
|
||||
hass.http.register_view(ConnectionTestView())
|
||||
|
|
|
@ -26,3 +26,6 @@ class AssistSatelliteEntityFeature(IntFlag):
|
|||
|
||||
ANNOUNCE = 1
|
||||
"""Device supports remotely triggered announcements."""
|
||||
|
||||
START_CONVERSATION = 2
|
||||
"""Device supports starting conversations."""
|
||||
|
|
|
@ -187,47 +187,10 @@ class AssistSatelliteEntity(entity.Entity):
|
|||
"""
|
||||
await self._cancel_running_pipeline()
|
||||
|
||||
media_id_source: Literal["url", "media_id", "tts"] | None = None
|
||||
|
||||
if message is None:
|
||||
message = ""
|
||||
|
||||
if not media_id:
|
||||
media_id_source = "tts"
|
||||
# Synthesize audio and get URL
|
||||
pipeline_id = self._resolve_pipeline()
|
||||
pipeline = async_get_pipeline(self.hass, pipeline_id)
|
||||
|
||||
tts_options: dict[str, Any] = {}
|
||||
if pipeline.tts_voice is not None:
|
||||
tts_options[tts.ATTR_VOICE] = pipeline.tts_voice
|
||||
|
||||
if self.tts_options is not None:
|
||||
tts_options.update(self.tts_options)
|
||||
|
||||
media_id = tts_generate_media_source_id(
|
||||
self.hass,
|
||||
message,
|
||||
engine=pipeline.tts_engine,
|
||||
language=pipeline.tts_language,
|
||||
options=tts_options,
|
||||
)
|
||||
|
||||
if media_source.is_media_source_id(media_id):
|
||||
if not media_id_source:
|
||||
media_id_source = "media_id"
|
||||
media = await media_source.async_resolve_media(
|
||||
self.hass,
|
||||
media_id,
|
||||
None,
|
||||
)
|
||||
media_id = media.url
|
||||
|
||||
if not media_id_source:
|
||||
media_id_source = "url"
|
||||
|
||||
# Resolve to full URL
|
||||
media_id = async_process_play_media_url(self.hass, media_id)
|
||||
announcement = await self._resolve_media_id(message, media_id)
|
||||
|
||||
if self._is_announcing:
|
||||
raise SatelliteBusyError
|
||||
|
@ -237,9 +200,7 @@ class AssistSatelliteEntity(entity.Entity):
|
|||
|
||||
try:
|
||||
# Block until announcement is finished
|
||||
await self.async_announce(
|
||||
AssistSatelliteAnnouncement(message, media_id, media_id_source)
|
||||
)
|
||||
await self.async_announce(announcement)
|
||||
finally:
|
||||
self._is_announcing = False
|
||||
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||
|
@ -251,6 +212,44 @@ class AssistSatelliteEntity(entity.Entity):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_internal_start_conversation(
|
||||
self,
|
||||
start_message: str | None = None,
|
||||
start_media_id: str | None = None,
|
||||
) -> None:
|
||||
"""Start a conversation from the satellite.
|
||||
|
||||
If start_media_id is not provided, message is synthesized to
|
||||
audio with the selected pipeline.
|
||||
|
||||
If start_media_id is provided, it is played directly. It is possible
|
||||
to omit the message and the satellite will not show any text.
|
||||
|
||||
Calls async_start_conversation.
|
||||
"""
|
||||
await self._cancel_running_pipeline()
|
||||
|
||||
if start_message is None:
|
||||
start_message = ""
|
||||
|
||||
announcement = await self._resolve_media_id(start_message, start_media_id)
|
||||
|
||||
if self._is_announcing:
|
||||
raise SatelliteBusyError
|
||||
|
||||
self._is_announcing = True
|
||||
|
||||
try:
|
||||
await self.async_start_conversation(announcement)
|
||||
finally:
|
||||
self._is_announcing = False
|
||||
|
||||
async def async_start_conversation(
|
||||
self, start_announcement: AssistSatelliteAnnouncement
|
||||
) -> None:
|
||||
"""Start a conversation from the satellite."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_accept_pipeline_from_satellite(
|
||||
self,
|
||||
audio_stream: AsyncIterable[bytes],
|
||||
|
@ -428,3 +427,48 @@ class AssistSatelliteEntity(entity.Entity):
|
|||
vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state)
|
||||
|
||||
return vad.VadSensitivity.to_seconds(vad_sensitivity)
|
||||
|
||||
async def _resolve_media_id(
|
||||
self, message: str, media_id: str | None
|
||||
) -> AssistSatelliteAnnouncement:
|
||||
"""Resolve the media ID."""
|
||||
media_id_source: Literal["url", "media_id", "tts"] | None = None
|
||||
|
||||
if not media_id:
|
||||
media_id_source = "tts"
|
||||
# Synthesize audio and get URL
|
||||
pipeline_id = self._resolve_pipeline()
|
||||
pipeline = async_get_pipeline(self.hass, pipeline_id)
|
||||
|
||||
tts_options: dict[str, Any] = {}
|
||||
if pipeline.tts_voice is not None:
|
||||
tts_options[tts.ATTR_VOICE] = pipeline.tts_voice
|
||||
|
||||
if self.tts_options is not None:
|
||||
tts_options.update(self.tts_options)
|
||||
|
||||
media_id = tts_generate_media_source_id(
|
||||
self.hass,
|
||||
message,
|
||||
engine=pipeline.tts_engine,
|
||||
language=pipeline.tts_language,
|
||||
options=tts_options,
|
||||
)
|
||||
|
||||
if media_source.is_media_source_id(media_id):
|
||||
if not media_id_source:
|
||||
media_id_source = "media_id"
|
||||
media = await media_source.async_resolve_media(
|
||||
self.hass,
|
||||
media_id,
|
||||
None,
|
||||
)
|
||||
media_id = media.url
|
||||
|
||||
if not media_id_source:
|
||||
media_id_source = "url"
|
||||
|
||||
# Resolve to full URL
|
||||
media_id = async_process_play_media_url(self.hass, media_id)
|
||||
|
||||
return AssistSatelliteAnnouncement(message, media_id, media_id_source)
|
||||
|
|
|
@ -7,6 +7,9 @@
|
|||
"services": {
|
||||
"announce": {
|
||||
"service": "mdi:bullhorn"
|
||||
},
|
||||
"start_conversation": {
|
||||
"service": "mdi:forum"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,3 +14,19 @@ announce:
|
|||
required: false
|
||||
selector:
|
||||
text:
|
||||
start_conversation:
|
||||
target:
|
||||
entity:
|
||||
domain: assist_satellite
|
||||
supported_features:
|
||||
- assist_satellite.AssistSatelliteEntityFeature.START_CONVERSATION
|
||||
fields:
|
||||
start_message:
|
||||
required: false
|
||||
example: "You left the lights on in the living room. Turn them off?"
|
||||
selector:
|
||||
text:
|
||||
start_media_id:
|
||||
required: false
|
||||
selector:
|
||||
text:
|
||||
|
|
|
@ -25,6 +25,20 @@
|
|||
"description": "The media ID to announce instead of using text-to-speech."
|
||||
}
|
||||
}
|
||||
},
|
||||
"start_conversation": {
|
||||
"name": "Start Conversation",
|
||||
"description": "Start a conversation from a satellite.",
|
||||
"fields": {
|
||||
"start_message": {
|
||||
"name": "Message",
|
||||
"description": "The message to start with."
|
||||
},
|
||||
"start_media_id": {
|
||||
"name": "Media ID",
|
||||
"description": "The media ID to start with instead of using text-to-speech."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -87,6 +87,7 @@ def _base_components() -> dict[str, ModuleType]:
|
|||
# pylint: disable-next=import-outside-toplevel
|
||||
from homeassistant.components import (
|
||||
alarm_control_panel,
|
||||
assist_satellite,
|
||||
calendar,
|
||||
camera,
|
||||
climate,
|
||||
|
@ -107,6 +108,7 @@ def _base_components() -> dict[str, ModuleType]:
|
|||
|
||||
return {
|
||||
"alarm_control_panel": alarm_control_panel,
|
||||
"assist_satellite": assist_satellite,
|
||||
"calendar": calendar,
|
||||
"camera": camera,
|
||||
"climate": climate,
|
||||
|
|
|
@ -39,7 +39,11 @@ class MockAssistSatellite(AssistSatelliteEntity):
|
|||
"""Mock Assist Satellite Entity."""
|
||||
|
||||
_attr_name = "Test Entity"
|
||||
_attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE
|
||||
_attr_supported_features = (
|
||||
AssistSatelliteEntityFeature.ANNOUNCE
|
||||
| AssistSatelliteEntityFeature.START_CONVERSATION
|
||||
)
|
||||
_attr_tts_options = {"test-option": "test-value"}
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the mock entity."""
|
||||
|
@ -59,6 +63,7 @@ class MockAssistSatellite(AssistSatelliteEntity):
|
|||
active_wake_words=["1234"],
|
||||
max_active_wake_words=1,
|
||||
)
|
||||
self.start_conversations = []
|
||||
|
||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Handle pipeline events."""
|
||||
|
@ -79,6 +84,12 @@ class MockAssistSatellite(AssistSatelliteEntity):
|
|||
"""Set the current satellite configuration."""
|
||||
self.config = config
|
||||
|
||||
async def async_start_conversation(
|
||||
self, start_announcement: AssistSatelliteConfiguration
|
||||
) -> None:
|
||||
"""Start a conversation from the satellite."""
|
||||
self.start_conversations.append(start_announcement)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def entity() -> MockAssistSatellite:
|
||||
|
|
|
@ -30,6 +30,18 @@ from . import ENTITY_ID
|
|||
from .conftest import MockAssistSatellite
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def set_pipeline_tts(hass: HomeAssistant, init_components: ConfigEntry) -> None:
|
||||
"""Set up a pipeline with a TTS engine."""
|
||||
await async_update_pipeline(
|
||||
hass,
|
||||
async_get_pipeline(hass),
|
||||
tts_engine="tts.mock_entity",
|
||||
tts_language="en",
|
||||
tts_voice="test-voice",
|
||||
)
|
||||
|
||||
|
||||
async def test_entity_state(
|
||||
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
||||
) -> None:
|
||||
|
@ -64,7 +76,7 @@ async def test_entity_state(
|
|||
assert kwargs["stt_stream"] is audio_stream
|
||||
assert kwargs["pipeline_id"] is None
|
||||
assert kwargs["device_id"] is None
|
||||
assert kwargs["tts_audio_output"] is None
|
||||
assert kwargs["tts_audio_output"] == {"test-option": "test-value"}
|
||||
assert kwargs["wake_word_phrase"] is None
|
||||
assert kwargs["audio_settings"] == AudioSettings(
|
||||
silence_seconds=vad.VadSensitivity.to_seconds(vad.VadSensitivity.DEFAULT)
|
||||
|
@ -189,24 +201,12 @@ async def test_announce(
|
|||
expected_params: tuple[str, str],
|
||||
) -> None:
|
||||
"""Test announcing on a device."""
|
||||
await async_update_pipeline(
|
||||
hass,
|
||||
async_get_pipeline(hass),
|
||||
tts_engine="tts.mock_entity",
|
||||
tts_language="en",
|
||||
tts_voice="test-voice",
|
||||
)
|
||||
|
||||
entity._attr_tts_options = {"test-option": "test-value"}
|
||||
|
||||
original_announce = entity.async_announce
|
||||
announce_started = asyncio.Event()
|
||||
|
||||
async def async_announce(announcement):
|
||||
# Verify state change
|
||||
assert entity.state == AssistSatelliteState.RESPONDING
|
||||
await original_announce(announcement)
|
||||
announce_started.set()
|
||||
|
||||
def tts_generate_media_source_id(
|
||||
hass: HomeAssistant,
|
||||
|
@ -464,3 +464,59 @@ async def test_vad_sensitivity_entity_not_found(
|
|||
|
||||
with pytest.raises(RuntimeError):
|
||||
await entity.async_accept_pipeline_from_satellite(audio_stream)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("service_data", "expected_params"),
|
||||
[
|
||||
(
|
||||
{"start_message": "Hello"},
|
||||
AssistSatelliteAnnouncement(
|
||||
"Hello", "https://www.home-assistant.io/resolved.mp3", "tts"
|
||||
),
|
||||
),
|
||||
(
|
||||
{
|
||||
"start_message": "Hello",
|
||||
"start_media_id": "media-source://bla",
|
||||
},
|
||||
AssistSatelliteAnnouncement(
|
||||
"Hello", "https://www.home-assistant.io/resolved.mp3", "media_id"
|
||||
),
|
||||
),
|
||||
(
|
||||
{"start_media_id": "http://example.com/bla.mp3"},
|
||||
AssistSatelliteAnnouncement("", "http://example.com/bla.mp3", "url"),
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_start_conversation(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
service_data: dict,
|
||||
expected_params: tuple[str, str],
|
||||
) -> None:
|
||||
"""Test starting a conversation on a device."""
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
|
||||
return_value="media-source://bla",
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.media_source.async_resolve_media",
|
||||
return_value=PlayMedia(
|
||||
url="https://www.home-assistant.io/resolved.mp3",
|
||||
mime_type="audio/mp3",
|
||||
),
|
||||
),
|
||||
):
|
||||
await hass.services.async_call(
|
||||
"assist_satellite",
|
||||
"start_conversation",
|
||||
service_data,
|
||||
target={"entity_id": "assist_satellite.test_entity"},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
assert entity.start_conversations[0] == expected_params
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue