diff --git a/homeassistant/components/assist_satellite/__init__.py b/homeassistant/components/assist_satellite/__init__.py index 6932fa3180c..b913d9f5102 100644 --- a/homeassistant/components/assist_satellite/__init__.py +++ b/homeassistant/components/assist_satellite/__init__.py @@ -63,6 +63,20 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: "async_internal_announce", [AssistSatelliteEntityFeature.ANNOUNCE], ) + component.async_register_entity_service( + "start_conversation", + vol.All( + cv.make_entity_service_schema( + { + vol.Optional("start_message"): str, + vol.Optional("start_media_id"): str, + } + ), + cv.has_at_least_one_key("start_message", "start_media_id"), + ), + "async_internal_start_conversation", + [AssistSatelliteEntityFeature.START_CONVERSATION], + ) hass.data[CONNECTION_TEST_DATA] = {} async_register_websocket_api(hass) hass.http.register_view(ConnectionTestView()) diff --git a/homeassistant/components/assist_satellite/const.py b/homeassistant/components/assist_satellite/const.py index 73bc126f7ba..f17df9614f3 100644 --- a/homeassistant/components/assist_satellite/const.py +++ b/homeassistant/components/assist_satellite/const.py @@ -26,3 +26,6 @@ class AssistSatelliteEntityFeature(IntFlag): ANNOUNCE = 1 """Device supports remotely triggered announcements.""" + + START_CONVERSATION = 2 + """Device supports starting conversations.""" diff --git a/homeassistant/components/assist_satellite/entity.py b/homeassistant/components/assist_satellite/entity.py index 23b588b569e..f798ac33e81 100644 --- a/homeassistant/components/assist_satellite/entity.py +++ b/homeassistant/components/assist_satellite/entity.py @@ -187,47 +187,10 @@ class AssistSatelliteEntity(entity.Entity): """ await self._cancel_running_pipeline() - media_id_source: Literal["url", "media_id", "tts"] | None = None - if message is None: message = "" - if not media_id: - media_id_source = "tts" - # Synthesize audio and get URL - pipeline_id = self._resolve_pipeline() - pipeline = async_get_pipeline(self.hass, pipeline_id) - - tts_options: dict[str, Any] = {} - if pipeline.tts_voice is not None: - tts_options[tts.ATTR_VOICE] = pipeline.tts_voice - - if self.tts_options is not None: - tts_options.update(self.tts_options) - - media_id = tts_generate_media_source_id( - self.hass, - message, - engine=pipeline.tts_engine, - language=pipeline.tts_language, - options=tts_options, - ) - - if media_source.is_media_source_id(media_id): - if not media_id_source: - media_id_source = "media_id" - media = await media_source.async_resolve_media( - self.hass, - media_id, - None, - ) - media_id = media.url - - if not media_id_source: - media_id_source = "url" - - # Resolve to full URL - media_id = async_process_play_media_url(self.hass, media_id) + announcement = await self._resolve_media_id(message, media_id) if self._is_announcing: raise SatelliteBusyError @@ -237,9 +200,7 @@ class AssistSatelliteEntity(entity.Entity): try: # Block until announcement is finished - await self.async_announce( - AssistSatelliteAnnouncement(message, media_id, media_id_source) - ) + await self.async_announce(announcement) finally: self._is_announcing = False self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD) @@ -251,6 +212,44 @@ class AssistSatelliteEntity(entity.Entity): """ raise NotImplementedError + async def async_internal_start_conversation( + self, + start_message: str | None = None, + start_media_id: str | None = None, + ) -> None: + """Start a conversation from the satellite. + + If start_media_id is not provided, message is synthesized to + audio with the selected pipeline. + + If start_media_id is provided, it is played directly. It is possible + to omit the message and the satellite will not show any text. + + Calls async_start_conversation. + """ + await self._cancel_running_pipeline() + + if start_message is None: + start_message = "" + + announcement = await self._resolve_media_id(start_message, start_media_id) + + if self._is_announcing: + raise SatelliteBusyError + + self._is_announcing = True + + try: + await self.async_start_conversation(announcement) + finally: + self._is_announcing = False + + async def async_start_conversation( + self, start_announcement: AssistSatelliteAnnouncement + ) -> None: + """Start a conversation from the satellite.""" + raise NotImplementedError + async def async_accept_pipeline_from_satellite( self, audio_stream: AsyncIterable[bytes], @@ -428,3 +427,48 @@ class AssistSatelliteEntity(entity.Entity): vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state) return vad.VadSensitivity.to_seconds(vad_sensitivity) + + async def _resolve_media_id( + self, message: str, media_id: str | None + ) -> AssistSatelliteAnnouncement: + """Resolve the media ID.""" + media_id_source: Literal["url", "media_id", "tts"] | None = None + + if not media_id: + media_id_source = "tts" + # Synthesize audio and get URL + pipeline_id = self._resolve_pipeline() + pipeline = async_get_pipeline(self.hass, pipeline_id) + + tts_options: dict[str, Any] = {} + if pipeline.tts_voice is not None: + tts_options[tts.ATTR_VOICE] = pipeline.tts_voice + + if self.tts_options is not None: + tts_options.update(self.tts_options) + + media_id = tts_generate_media_source_id( + self.hass, + message, + engine=pipeline.tts_engine, + language=pipeline.tts_language, + options=tts_options, + ) + + if media_source.is_media_source_id(media_id): + if not media_id_source: + media_id_source = "media_id" + media = await media_source.async_resolve_media( + self.hass, + media_id, + None, + ) + media_id = media.url + + if not media_id_source: + media_id_source = "url" + + # Resolve to full URL + media_id = async_process_play_media_url(self.hass, media_id) + + return AssistSatelliteAnnouncement(message, media_id, media_id_source) diff --git a/homeassistant/components/assist_satellite/icons.json b/homeassistant/components/assist_satellite/icons.json index a98c3aefc5b..1ed29541621 100644 --- a/homeassistant/components/assist_satellite/icons.json +++ b/homeassistant/components/assist_satellite/icons.json @@ -7,6 +7,9 @@ "services": { "announce": { "service": "mdi:bullhorn" + }, + "start_conversation": { + "service": "mdi:forum" } } } diff --git a/homeassistant/components/assist_satellite/services.yaml b/homeassistant/components/assist_satellite/services.yaml index e7fefc4705f..c6ebf81a225 100644 --- a/homeassistant/components/assist_satellite/services.yaml +++ b/homeassistant/components/assist_satellite/services.yaml @@ -14,3 +14,19 @@ announce: required: false selector: text: +start_conversation: + target: + entity: + domain: assist_satellite + supported_features: + - assist_satellite.AssistSatelliteEntityFeature.START_CONVERSATION + fields: + start_message: + required: false + example: "You left the lights on in the living room. Turn them off?" + selector: + text: + start_media_id: + required: false + selector: + text: diff --git a/homeassistant/components/assist_satellite/strings.json b/homeassistant/components/assist_satellite/strings.json index 1d07882daae..a1c34e33cb8 100644 --- a/homeassistant/components/assist_satellite/strings.json +++ b/homeassistant/components/assist_satellite/strings.json @@ -25,6 +25,20 @@ "description": "The media ID to announce instead of using text-to-speech." } } + }, + "start_conversation": { + "name": "Start Conversation", + "description": "Start a conversation from a satellite.", + "fields": { + "start_message": { + "name": "Message", + "description": "The message to start with." + }, + "start_media_id": { + "name": "Media ID", + "description": "The media ID to start with instead of using text-to-speech." + } + } } } } diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index ac21f1da3fc..f17829011a2 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -87,6 +87,7 @@ def _base_components() -> dict[str, ModuleType]: # pylint: disable-next=import-outside-toplevel from homeassistant.components import ( alarm_control_panel, + assist_satellite, calendar, camera, climate, @@ -107,6 +108,7 @@ def _base_components() -> dict[str, ModuleType]: return { "alarm_control_panel": alarm_control_panel, + "assist_satellite": assist_satellite, "calendar": calendar, "camera": camera, "climate": climate, diff --git a/tests/components/assist_satellite/conftest.py b/tests/components/assist_satellite/conftest.py index 9e9bfd959e6..dbb3a95599e 100644 --- a/tests/components/assist_satellite/conftest.py +++ b/tests/components/assist_satellite/conftest.py @@ -39,7 +39,11 @@ class MockAssistSatellite(AssistSatelliteEntity): """Mock Assist Satellite Entity.""" _attr_name = "Test Entity" - _attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE + _attr_supported_features = ( + AssistSatelliteEntityFeature.ANNOUNCE + | AssistSatelliteEntityFeature.START_CONVERSATION + ) + _attr_tts_options = {"test-option": "test-value"} def __init__(self) -> None: """Initialize the mock entity.""" @@ -59,6 +63,7 @@ class MockAssistSatellite(AssistSatelliteEntity): active_wake_words=["1234"], max_active_wake_words=1, ) + self.start_conversations = [] def on_pipeline_event(self, event: PipelineEvent) -> None: """Handle pipeline events.""" @@ -79,6 +84,12 @@ class MockAssistSatellite(AssistSatelliteEntity): """Set the current satellite configuration.""" self.config = config + async def async_start_conversation( + self, start_announcement: AssistSatelliteConfiguration + ) -> None: + """Start a conversation from the satellite.""" + self.start_conversations.append(start_announcement) + @pytest.fixture def entity() -> MockAssistSatellite: diff --git a/tests/components/assist_satellite/test_entity.py b/tests/components/assist_satellite/test_entity.py index b2347184bec..0a410b2307f 100644 --- a/tests/components/assist_satellite/test_entity.py +++ b/tests/components/assist_satellite/test_entity.py @@ -30,6 +30,18 @@ from . import ENTITY_ID from .conftest import MockAssistSatellite +@pytest.fixture(autouse=True) +async def set_pipeline_tts(hass: HomeAssistant, init_components: ConfigEntry) -> None: + """Set up a pipeline with a TTS engine.""" + await async_update_pipeline( + hass, + async_get_pipeline(hass), + tts_engine="tts.mock_entity", + tts_language="en", + tts_voice="test-voice", + ) + + async def test_entity_state( hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite ) -> None: @@ -64,7 +76,7 @@ async def test_entity_state( assert kwargs["stt_stream"] is audio_stream assert kwargs["pipeline_id"] is None assert kwargs["device_id"] is None - assert kwargs["tts_audio_output"] is None + assert kwargs["tts_audio_output"] == {"test-option": "test-value"} assert kwargs["wake_word_phrase"] is None assert kwargs["audio_settings"] == AudioSettings( silence_seconds=vad.VadSensitivity.to_seconds(vad.VadSensitivity.DEFAULT) @@ -189,24 +201,12 @@ async def test_announce( expected_params: tuple[str, str], ) -> None: """Test announcing on a device.""" - await async_update_pipeline( - hass, - async_get_pipeline(hass), - tts_engine="tts.mock_entity", - tts_language="en", - tts_voice="test-voice", - ) - - entity._attr_tts_options = {"test-option": "test-value"} - original_announce = entity.async_announce - announce_started = asyncio.Event() async def async_announce(announcement): # Verify state change assert entity.state == AssistSatelliteState.RESPONDING await original_announce(announcement) - announce_started.set() def tts_generate_media_source_id( hass: HomeAssistant, @@ -464,3 +464,59 @@ async def test_vad_sensitivity_entity_not_found( with pytest.raises(RuntimeError): await entity.async_accept_pipeline_from_satellite(audio_stream) + + +@pytest.mark.parametrize( + ("service_data", "expected_params"), + [ + ( + {"start_message": "Hello"}, + AssistSatelliteAnnouncement( + "Hello", "https://www.home-assistant.io/resolved.mp3", "tts" + ), + ), + ( + { + "start_message": "Hello", + "start_media_id": "media-source://bla", + }, + AssistSatelliteAnnouncement( + "Hello", "https://www.home-assistant.io/resolved.mp3", "media_id" + ), + ), + ( + {"start_media_id": "http://example.com/bla.mp3"}, + AssistSatelliteAnnouncement("", "http://example.com/bla.mp3", "url"), + ), + ], +) +async def test_start_conversation( + hass: HomeAssistant, + init_components: ConfigEntry, + entity: MockAssistSatellite, + service_data: dict, + expected_params: tuple[str, str], +) -> None: + """Test starting a conversation on a device.""" + with ( + patch( + "homeassistant.components.assist_satellite.entity.tts_generate_media_source_id", + return_value="media-source://bla", + ), + patch( + "homeassistant.components.media_source.async_resolve_media", + return_value=PlayMedia( + url="https://www.home-assistant.io/resolved.mp3", + mime_type="audio/mp3", + ), + ), + ): + await hass.services.async_call( + "assist_satellite", + "start_conversation", + service_data, + target={"entity_id": "assist_satellite.test_entity"}, + blocking=True, + ) + + assert entity.start_conversations[0] == expected_params