Add CONTROL supported feature to Google conversation when API access (#123046)
* Add CONTROL supported feature to Google conversation when API access * Better function name * Handle entry update inline * Reload instead of update
This commit is contained in:
parent
f6ad018f8f
commit
aa6f0cd55a
3 changed files with 29 additions and 8 deletions
|
@ -172,6 +172,10 @@ class GoogleGenerativeAIConversationEntity(
|
|||
model="Generative AI",
|
||||
entry_type=dr.DeviceEntryType.SERVICE,
|
||||
)
|
||||
if self.entry.options.get(CONF_LLM_HASS_API):
|
||||
self._attr_supported_features = (
|
||||
conversation.ConversationEntityFeature.CONTROL
|
||||
)
|
||||
|
||||
@property
|
||||
def supported_languages(self) -> list[str] | Literal["*"]:
|
||||
|
@ -185,6 +189,9 @@ class GoogleGenerativeAIConversationEntity(
|
|||
self.hass, "conversation", self.entry.entry_id, self.entity_id
|
||||
)
|
||||
conversation.async_set_agent(self.hass, self.entry, self)
|
||||
self.entry.async_on_unload(
|
||||
self.entry.add_update_listener(self._async_entry_update_listener)
|
||||
)
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""When entity will be removed from Home Assistant."""
|
||||
|
@ -405,3 +412,10 @@ class GoogleGenerativeAIConversationEntity(
|
|||
parts.append(llm_api.api_prompt)
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
async def _async_entry_update_listener(
|
||||
self, hass: HomeAssistant, entry: ConfigEntry
|
||||
) -> None:
|
||||
"""Handle options update."""
|
||||
# Reload as we update device info + entity name + supported features
|
||||
await hass.config_entries.async_reload(entry.entry_id)
|
||||
|
|
|
@ -215,7 +215,7 @@
|
|||
),
|
||||
])
|
||||
# ---
|
||||
# name: test_default_prompt[config_entry_options0-None]
|
||||
# name: test_default_prompt[config_entry_options0-0-None]
|
||||
list([
|
||||
tuple(
|
||||
'',
|
||||
|
@ -263,7 +263,7 @@
|
|||
),
|
||||
])
|
||||
# ---
|
||||
# name: test_default_prompt[config_entry_options0-conversation.google_generative_ai_conversation]
|
||||
# name: test_default_prompt[config_entry_options0-0-conversation.google_generative_ai_conversation]
|
||||
list([
|
||||
tuple(
|
||||
'',
|
||||
|
@ -311,7 +311,7 @@
|
|||
),
|
||||
])
|
||||
# ---
|
||||
# name: test_default_prompt[config_entry_options1-None]
|
||||
# name: test_default_prompt[config_entry_options1-1-None]
|
||||
list([
|
||||
tuple(
|
||||
'',
|
||||
|
@ -360,7 +360,7 @@
|
|||
),
|
||||
])
|
||||
# ---
|
||||
# name: test_default_prompt[config_entry_options1-conversation.google_generative_ai_conversation]
|
||||
# name: test_default_prompt[config_entry_options1-1-conversation.google_generative_ai_conversation]
|
||||
list([
|
||||
tuple(
|
||||
'',
|
||||
|
|
|
@ -19,7 +19,7 @@ from homeassistant.components.google_generative_ai_conversation.conversation imp
|
|||
_escape_decode,
|
||||
_format_schema,
|
||||
)
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.const import ATTR_SUPPORTED_FEATURES, CONF_LLM_HASS_API
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import intent, llm
|
||||
|
@ -39,10 +39,13 @@ def freeze_the_time():
|
|||
"agent_id", [None, "conversation.google_generative_ai_conversation"]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"config_entry_options",
|
||||
("config_entry_options", "expected_features"),
|
||||
[
|
||||
{},
|
||||
({}, 0),
|
||||
(
|
||||
{CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
|
||||
conversation.ConversationEntityFeature.CONTROL,
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
|
@ -52,6 +55,7 @@ async def test_default_prompt(
|
|||
snapshot: SnapshotAssertion,
|
||||
agent_id: str | None,
|
||||
config_entry_options: {},
|
||||
expected_features: conversation.ConversationEntityFeature,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test that the default prompt works."""
|
||||
|
@ -98,6 +102,9 @@ async def test_default_prompt(
|
|||
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
|
||||
assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options)
|
||||
|
||||
state = hass.states.get("conversation.google_generative_ai_conversation")
|
||||
assert state.attributes[ATTR_SUPPORTED_FEATURES] == expected_features
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_name", "supports_system_instruction"),
|
||||
|
|
Loading…
Add table
Reference in a new issue