Add exposed entities to the Assist LLM API prompt (#118203)
* Add exposed entities to the Assist LLM API prompt * Check expose entities in Google test * Copy Google default prompt test cases to LLM tests
This commit is contained in:
parent
c391d73fec
commit
ecb05989ca
4 changed files with 526 additions and 88 deletions
|
@ -11,10 +11,13 @@ from homeassistant.helpers import (
|
|||
area_registry as ar,
|
||||
config_validation as cv,
|
||||
device_registry as dr,
|
||||
entity_registry as er,
|
||||
floor_registry as fr,
|
||||
intent,
|
||||
llm,
|
||||
)
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util import yaml
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
@ -158,10 +161,12 @@ async def test_assist_api_description(hass: HomeAssistant) -> None:
|
|||
async def test_assist_api_prompt(
|
||||
hass: HomeAssistant,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
area_registry: ar.AreaRegistry,
|
||||
floor_registry: fr.FloorRegistry,
|
||||
) -> None:
|
||||
"""Test prompt for the assist API."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
context = Context()
|
||||
tool_input = llm.ToolInput(
|
||||
tool_name=None,
|
||||
|
@ -170,41 +175,232 @@ async def test_assist_api_prompt(
|
|||
context=context,
|
||||
user_prompt="test_text",
|
||||
language="*",
|
||||
assistant="test_assistant",
|
||||
assistant="conversation",
|
||||
device_id="test_device",
|
||||
)
|
||||
api = llm.async_get_api(hass, "assist")
|
||||
prompt = await api.async_get_api_prompt(tool_input)
|
||||
assert prompt == (
|
||||
"Call the intent tools to control Home Assistant."
|
||||
" Just pass the name to the intent."
|
||||
"Only if the user wants to control a device, tell them to expose entities to their "
|
||||
"voice assistant in Home Assistant."
|
||||
)
|
||||
|
||||
# Expose entities
|
||||
entry = MockConfigEntry(title=None)
|
||||
entry.add_to_hass(hass)
|
||||
tool_input.device_id = device_registry.async_get_or_create(
|
||||
device = device_registry.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
connections={("test", "1234")},
|
||||
name="Test Device",
|
||||
manufacturer="Test Manufacturer",
|
||||
model="Test Model",
|
||||
suggested_area="Test Area",
|
||||
).id
|
||||
prompt = await api.async_get_api_prompt(tool_input)
|
||||
assert prompt == (
|
||||
"Call the intent tools to control Home Assistant."
|
||||
" Just pass the name to the intent. You are in Test Area."
|
||||
)
|
||||
area = area_registry.async_get_area_by_name("Test Area")
|
||||
area_registry.async_update(area.id, aliases=["Alternative name"])
|
||||
entry1 = entity_registry.async_get_or_create(
|
||||
"light",
|
||||
"kitchen",
|
||||
"mock-id-kitchen",
|
||||
original_name="Kitchen",
|
||||
suggested_object_id="kitchen",
|
||||
)
|
||||
entry2 = entity_registry.async_get_or_create(
|
||||
"light",
|
||||
"living_room",
|
||||
"mock-id-living-room",
|
||||
original_name="Living Room",
|
||||
suggested_object_id="living_room",
|
||||
device_id=device.id,
|
||||
)
|
||||
hass.states.async_set(entry1.entity_id, "on", {"friendly_name": "Kitchen"})
|
||||
hass.states.async_set(entry2.entity_id, "on", {"friendly_name": "Living Room"})
|
||||
|
||||
def create_entity(device: dr.DeviceEntry, write_state=True) -> None:
|
||||
"""Create an entity for a device and track entity_id."""
|
||||
entity = entity_registry.async_get_or_create(
|
||||
"light",
|
||||
"test",
|
||||
device.id,
|
||||
device_id=device.id,
|
||||
original_name=str(device.name or "Unnamed Device"),
|
||||
suggested_object_id=str(device.name or "unnamed_device"),
|
||||
)
|
||||
if write_state:
|
||||
entity.write_unavailable_state(hass)
|
||||
|
||||
create_entity(
|
||||
device_registry.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
connections={("test", "1234")},
|
||||
name="Test Device",
|
||||
manufacturer="Test Manufacturer",
|
||||
model="Test Model",
|
||||
suggested_area="Test Area",
|
||||
)
|
||||
)
|
||||
for i in range(3):
|
||||
create_entity(
|
||||
device_registry.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
connections={("test", f"{i}abcd")},
|
||||
name="Test Service",
|
||||
manufacturer="Test Manufacturer",
|
||||
model="Test Model",
|
||||
suggested_area="Test Area",
|
||||
entry_type=dr.DeviceEntryType.SERVICE,
|
||||
)
|
||||
)
|
||||
create_entity(
|
||||
device_registry.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
connections={("test", "5678")},
|
||||
name="Test Device 2",
|
||||
manufacturer="Test Manufacturer 2",
|
||||
model="Device 2",
|
||||
suggested_area="Test Area 2",
|
||||
)
|
||||
)
|
||||
create_entity(
|
||||
device_registry.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
connections={("test", "9876")},
|
||||
name="Test Device 3",
|
||||
manufacturer="Test Manufacturer 3",
|
||||
model="Test Model 3A",
|
||||
suggested_area="Test Area 2",
|
||||
)
|
||||
)
|
||||
create_entity(
|
||||
device_registry.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
connections={("test", "qwer")},
|
||||
name="Test Device 4",
|
||||
suggested_area="Test Area 2",
|
||||
)
|
||||
)
|
||||
device2 = device_registry.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
connections={("test", "9876-disabled")},
|
||||
name="Test Device 3 - disabled",
|
||||
manufacturer="Test Manufacturer 3",
|
||||
model="Test Model 3A",
|
||||
suggested_area="Test Area 2",
|
||||
)
|
||||
device_registry.async_update_device(
|
||||
device2.id, disabled_by=dr.DeviceEntryDisabler.USER
|
||||
)
|
||||
create_entity(device2, False)
|
||||
create_entity(
|
||||
device_registry.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
connections={("test", "9876-no-name")},
|
||||
manufacturer="Test Manufacturer NoName",
|
||||
model="Test Model NoName",
|
||||
suggested_area="Test Area 2",
|
||||
)
|
||||
)
|
||||
create_entity(
|
||||
device_registry.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
connections={("test", "9876-integer-values")},
|
||||
name=1,
|
||||
manufacturer=2,
|
||||
model=3,
|
||||
suggested_area="Test Area 2",
|
||||
)
|
||||
)
|
||||
|
||||
exposed_entities = llm._get_exposed_entities(hass, tool_input.assistant)
|
||||
assert exposed_entities == {
|
||||
"light.1": {
|
||||
"areas": "Test Area 2",
|
||||
"names": "1",
|
||||
"state": "unavailable",
|
||||
},
|
||||
entry1.entity_id: {
|
||||
"names": "Kitchen",
|
||||
"state": "on",
|
||||
},
|
||||
entry2.entity_id: {
|
||||
"areas": "Test Area, Alternative name",
|
||||
"names": "Living Room",
|
||||
"state": "on",
|
||||
},
|
||||
"light.test_device": {
|
||||
"areas": "Test Area, Alternative name",
|
||||
"names": "Test Device",
|
||||
"state": "unavailable",
|
||||
},
|
||||
"light.test_device_2": {
|
||||
"areas": "Test Area 2",
|
||||
"names": "Test Device 2",
|
||||
"state": "unavailable",
|
||||
},
|
||||
"light.test_device_3": {
|
||||
"areas": "Test Area 2",
|
||||
"names": "Test Device 3",
|
||||
"state": "unavailable",
|
||||
},
|
||||
"light.test_device_4": {
|
||||
"areas": "Test Area 2",
|
||||
"names": "Test Device 4",
|
||||
"state": "unavailable",
|
||||
},
|
||||
"light.test_service": {
|
||||
"areas": "Test Area, Alternative name",
|
||||
"names": "Test Service",
|
||||
"state": "unavailable",
|
||||
},
|
||||
"light.test_service_2": {
|
||||
"areas": "Test Area, Alternative name",
|
||||
"names": "Test Service",
|
||||
"state": "unavailable",
|
||||
},
|
||||
"light.test_service_3": {
|
||||
"areas": "Test Area, Alternative name",
|
||||
"names": "Test Service",
|
||||
"state": "unavailable",
|
||||
},
|
||||
"light.unnamed_device": {
|
||||
"areas": "Test Area 2",
|
||||
"names": "Unnamed Device",
|
||||
"state": "unavailable",
|
||||
},
|
||||
}
|
||||
exposed_entities_prompt = (
|
||||
"An overview of the areas and the devices in this smart home:\n"
|
||||
+ yaml.dump(exposed_entities)
|
||||
)
|
||||
first_part_prompt = (
|
||||
"Call the intent tools to control Home Assistant. "
|
||||
"Just pass the name to the intent. "
|
||||
"When controlling an area, prefer passing area name."
|
||||
)
|
||||
|
||||
prompt = await api.async_get_api_prompt(tool_input)
|
||||
assert prompt == (
|
||||
f"""{first_part_prompt}
|
||||
{exposed_entities_prompt}"""
|
||||
)
|
||||
|
||||
# Fake that request is made from a specific device ID
|
||||
tool_input.device_id = device.id
|
||||
prompt = await api.async_get_api_prompt(tool_input)
|
||||
assert prompt == (
|
||||
f"""{first_part_prompt}
|
||||
You are in Test Area.
|
||||
{exposed_entities_prompt}"""
|
||||
)
|
||||
|
||||
# Add floor
|
||||
floor = floor_registry.async_create("second floor")
|
||||
area = area_registry.async_get_area_by_name("Test Area")
|
||||
area_registry.async_update(area.id, floor_id=floor.floor_id)
|
||||
prompt = await api.async_get_api_prompt(tool_input)
|
||||
assert prompt == (
|
||||
"Call the intent tools to control Home Assistant."
|
||||
" Just pass the name to the intent. You are in Test Area (second floor)."
|
||||
f"""{first_part_prompt}
|
||||
You are in Test Area (second floor).
|
||||
{exposed_entities_prompt}"""
|
||||
)
|
||||
|
||||
# Add user
|
||||
context.user_id = "12345"
|
||||
mock_user = Mock()
|
||||
mock_user.id = "12345"
|
||||
|
@ -212,7 +408,8 @@ async def test_assist_api_prompt(
|
|||
with patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user):
|
||||
prompt = await api.async_get_api_prompt(tool_input)
|
||||
assert prompt == (
|
||||
"Call the intent tools to control Home Assistant."
|
||||
" Just pass the name to the intent. You are in Test Area (second floor)."
|
||||
" The user name is Test User."
|
||||
f"""{first_part_prompt}
|
||||
You are in Test Area (second floor).
|
||||
The user name is Test User.
|
||||
{exposed_entities_prompt}"""
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue