Add exposed entities to the Assist LLM API prompt (#118203)

* Add exposed entities to the Assist LLM API prompt

* Check expose entities in Google test

* Copy Google default prompt test cases to LLM tests
This commit is contained in:
Paulus Schoutsen 2024-05-27 00:27:08 -04:00 committed by GitHub
parent c391d73fec
commit ecb05989ca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 526 additions and 88 deletions

View file

@ -11,10 +11,13 @@ from homeassistant.helpers import (
area_registry as ar,
config_validation as cv,
device_registry as dr,
entity_registry as er,
floor_registry as fr,
intent,
llm,
)
from homeassistant.setup import async_setup_component
from homeassistant.util import yaml
from tests.common import MockConfigEntry
@ -158,10 +161,12 @@ async def test_assist_api_description(hass: HomeAssistant) -> None:
async def test_assist_api_prompt(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
area_registry: ar.AreaRegistry,
floor_registry: fr.FloorRegistry,
) -> None:
"""Test prompt for the assist API."""
assert await async_setup_component(hass, "homeassistant", {})
context = Context()
tool_input = llm.ToolInput(
tool_name=None,
@ -170,41 +175,232 @@ async def test_assist_api_prompt(
context=context,
user_prompt="test_text",
language="*",
assistant="test_assistant",
assistant="conversation",
device_id="test_device",
)
api = llm.async_get_api(hass, "assist")
prompt = await api.async_get_api_prompt(tool_input)
assert prompt == (
"Call the intent tools to control Home Assistant."
" Just pass the name to the intent."
"Only if the user wants to control a device, tell them to expose entities to their "
"voice assistant in Home Assistant."
)
# Expose entities
entry = MockConfigEntry(title=None)
entry.add_to_hass(hass)
tool_input.device_id = device_registry.async_get_or_create(
device = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "1234")},
name="Test Device",
manufacturer="Test Manufacturer",
model="Test Model",
suggested_area="Test Area",
).id
prompt = await api.async_get_api_prompt(tool_input)
assert prompt == (
"Call the intent tools to control Home Assistant."
" Just pass the name to the intent. You are in Test Area."
)
area = area_registry.async_get_area_by_name("Test Area")
area_registry.async_update(area.id, aliases=["Alternative name"])
entry1 = entity_registry.async_get_or_create(
"light",
"kitchen",
"mock-id-kitchen",
original_name="Kitchen",
suggested_object_id="kitchen",
)
entry2 = entity_registry.async_get_or_create(
"light",
"living_room",
"mock-id-living-room",
original_name="Living Room",
suggested_object_id="living_room",
device_id=device.id,
)
hass.states.async_set(entry1.entity_id, "on", {"friendly_name": "Kitchen"})
hass.states.async_set(entry2.entity_id, "on", {"friendly_name": "Living Room"})
def create_entity(device: dr.DeviceEntry, write_state=True) -> None:
"""Create an entity for a device and track entity_id."""
entity = entity_registry.async_get_or_create(
"light",
"test",
device.id,
device_id=device.id,
original_name=str(device.name or "Unnamed Device"),
suggested_object_id=str(device.name or "unnamed_device"),
)
if write_state:
entity.write_unavailable_state(hass)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "1234")},
name="Test Device",
manufacturer="Test Manufacturer",
model="Test Model",
suggested_area="Test Area",
)
)
for i in range(3):
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", f"{i}abcd")},
name="Test Service",
manufacturer="Test Manufacturer",
model="Test Model",
suggested_area="Test Area",
entry_type=dr.DeviceEntryType.SERVICE,
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "5678")},
name="Test Device 2",
manufacturer="Test Manufacturer 2",
model="Device 2",
suggested_area="Test Area 2",
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876")},
name="Test Device 3",
manufacturer="Test Manufacturer 3",
model="Test Model 3A",
suggested_area="Test Area 2",
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "qwer")},
name="Test Device 4",
suggested_area="Test Area 2",
)
)
device2 = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-disabled")},
name="Test Device 3 - disabled",
manufacturer="Test Manufacturer 3",
model="Test Model 3A",
suggested_area="Test Area 2",
)
device_registry.async_update_device(
device2.id, disabled_by=dr.DeviceEntryDisabler.USER
)
create_entity(device2, False)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-no-name")},
manufacturer="Test Manufacturer NoName",
model="Test Model NoName",
suggested_area="Test Area 2",
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-integer-values")},
name=1,
manufacturer=2,
model=3,
suggested_area="Test Area 2",
)
)
exposed_entities = llm._get_exposed_entities(hass, tool_input.assistant)
assert exposed_entities == {
"light.1": {
"areas": "Test Area 2",
"names": "1",
"state": "unavailable",
},
entry1.entity_id: {
"names": "Kitchen",
"state": "on",
},
entry2.entity_id: {
"areas": "Test Area, Alternative name",
"names": "Living Room",
"state": "on",
},
"light.test_device": {
"areas": "Test Area, Alternative name",
"names": "Test Device",
"state": "unavailable",
},
"light.test_device_2": {
"areas": "Test Area 2",
"names": "Test Device 2",
"state": "unavailable",
},
"light.test_device_3": {
"areas": "Test Area 2",
"names": "Test Device 3",
"state": "unavailable",
},
"light.test_device_4": {
"areas": "Test Area 2",
"names": "Test Device 4",
"state": "unavailable",
},
"light.test_service": {
"areas": "Test Area, Alternative name",
"names": "Test Service",
"state": "unavailable",
},
"light.test_service_2": {
"areas": "Test Area, Alternative name",
"names": "Test Service",
"state": "unavailable",
},
"light.test_service_3": {
"areas": "Test Area, Alternative name",
"names": "Test Service",
"state": "unavailable",
},
"light.unnamed_device": {
"areas": "Test Area 2",
"names": "Unnamed Device",
"state": "unavailable",
},
}
exposed_entities_prompt = (
"An overview of the areas and the devices in this smart home:\n"
+ yaml.dump(exposed_entities)
)
first_part_prompt = (
"Call the intent tools to control Home Assistant. "
"Just pass the name to the intent. "
"When controlling an area, prefer passing area name."
)
prompt = await api.async_get_api_prompt(tool_input)
assert prompt == (
f"""{first_part_prompt}
{exposed_entities_prompt}"""
)
# Fake that request is made from a specific device ID
tool_input.device_id = device.id
prompt = await api.async_get_api_prompt(tool_input)
assert prompt == (
f"""{first_part_prompt}
You are in Test Area.
{exposed_entities_prompt}"""
)
# Add floor
floor = floor_registry.async_create("second floor")
area = area_registry.async_get_area_by_name("Test Area")
area_registry.async_update(area.id, floor_id=floor.floor_id)
prompt = await api.async_get_api_prompt(tool_input)
assert prompt == (
"Call the intent tools to control Home Assistant."
" Just pass the name to the intent. You are in Test Area (second floor)."
f"""{first_part_prompt}
You are in Test Area (second floor).
{exposed_entities_prompt}"""
)
# Add user
context.user_id = "12345"
mock_user = Mock()
mock_user.id = "12345"
@ -212,7 +408,8 @@ async def test_assist_api_prompt(
with patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user):
prompt = await api.async_get_api_prompt(tool_input)
assert prompt == (
"Call the intent tools to control Home Assistant."
" Just pass the name to the intent. You are in Test Area (second floor)."
" The user name is Test User."
f"""{first_part_prompt}
You are in Test Area (second floor).
The user name is Test User.
{exposed_entities_prompt}"""
)