Add CONTROL supported feature to Google conversation when API access (#123046)

* Add CONTROL supported feature to Google conversation when API access

* Better function name

* Handle entry update inline

* Reload instead of update
This commit is contained in:
Paulus Schoutsen 2024-08-03 08:16:30 +02:00 committed by GitHub
parent f6ad018f8f
commit aa6f0cd55a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 29 additions and 8 deletions

View file

@ -172,6 +172,10 @@ class GoogleGenerativeAIConversationEntity(
model="Generative AI",
entry_type=dr.DeviceEntryType.SERVICE,
)
if self.entry.options.get(CONF_LLM_HASS_API):
self._attr_supported_features = (
conversation.ConversationEntityFeature.CONTROL
)
@property
def supported_languages(self) -> list[str] | Literal["*"]:
@ -185,6 +189,9 @@ class GoogleGenerativeAIConversationEntity(
self.hass, "conversation", self.entry.entry_id, self.entity_id
)
conversation.async_set_agent(self.hass, self.entry, self)
self.entry.async_on_unload(
self.entry.add_update_listener(self._async_entry_update_listener)
)
async def async_will_remove_from_hass(self) -> None:
"""When entity will be removed from Home Assistant."""
@ -405,3 +412,10 @@ class GoogleGenerativeAIConversationEntity(
parts.append(llm_api.api_prompt)
return "\n".join(parts)
async def _async_entry_update_listener(
self, hass: HomeAssistant, entry: ConfigEntry
) -> None:
"""Handle options update."""
# Reload as we update device info + entity name + supported features
await hass.config_entries.async_reload(entry.entry_id)

View file

@ -215,7 +215,7 @@
),
])
# ---
# name: test_default_prompt[config_entry_options0-None]
# name: test_default_prompt[config_entry_options0-0-None]
list([
tuple(
'',
@ -263,7 +263,7 @@
),
])
# ---
# name: test_default_prompt[config_entry_options0-conversation.google_generative_ai_conversation]
# name: test_default_prompt[config_entry_options0-0-conversation.google_generative_ai_conversation]
list([
tuple(
'',
@ -311,7 +311,7 @@
),
])
# ---
# name: test_default_prompt[config_entry_options1-None]
# name: test_default_prompt[config_entry_options1-1-None]
list([
tuple(
'',
@ -360,7 +360,7 @@
),
])
# ---
# name: test_default_prompt[config_entry_options1-conversation.google_generative_ai_conversation]
# name: test_default_prompt[config_entry_options1-1-conversation.google_generative_ai_conversation]
list([
tuple(
'',

View file

@ -19,7 +19,7 @@ from homeassistant.components.google_generative_ai_conversation.conversation imp
_escape_decode,
_format_schema,
)
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.const import ATTR_SUPPORTED_FEATURES, CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import intent, llm
@ -39,10 +39,13 @@ def freeze_the_time():
"agent_id", [None, "conversation.google_generative_ai_conversation"]
)
@pytest.mark.parametrize(
"config_entry_options",
("config_entry_options", "expected_features"),
[
{},
({}, 0),
(
{CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
conversation.ConversationEntityFeature.CONTROL,
),
],
)
@pytest.mark.usefixtures("mock_init_component")
@ -52,6 +55,7 @@ async def test_default_prompt(
snapshot: SnapshotAssertion,
agent_id: str | None,
config_entry_options: {},
expected_features: conversation.ConversationEntityFeature,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test that the default prompt works."""
@ -98,6 +102,9 @@ async def test_default_prompt(
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options)
state = hass.states.get("conversation.google_generative_ai_conversation")
assert state.attributes[ATTR_SUPPORTED_FEATURES] == expected_features
@pytest.mark.parametrize(
("model_name", "supports_system_instruction"),