Improve LLM tool quality by more clearly specifying device_class slots (#122723)

* Limit intent / llm API device_class slots to only necessary services and limited set of values

* Fix ruff errors

* Run ruff format

* Fix typing and improve output schema

* Fix schema and improve flattening

* Revert conftest

* Revert recorder

* Fix ruff format errors

* Update using latest version of voluptuous
This commit is contained in:
Allen Porter 2024-07-31 05:36:02 -07:00 committed by GitHub
parent 7c7b408df1
commit f14471112d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 183 additions and 21 deletions

View file

@ -4,7 +4,7 @@ from homeassistant.const import SERVICE_CLOSE_COVER, SERVICE_OPEN_COVER
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import intent from homeassistant.helpers import intent
from . import DOMAIN from . import DOMAIN, CoverDeviceClass
INTENT_OPEN_COVER = "HassOpenCover" INTENT_OPEN_COVER = "HassOpenCover"
INTENT_CLOSE_COVER = "HassCloseCover" INTENT_CLOSE_COVER = "HassCloseCover"
@ -21,6 +21,7 @@ async def async_setup_intents(hass: HomeAssistant) -> None:
"Opening {}", "Opening {}",
description="Opens a cover", description="Opens a cover",
platforms={DOMAIN}, platforms={DOMAIN},
device_classes={CoverDeviceClass},
), ),
) )
intent.async_register( intent.async_register(
@ -32,5 +33,6 @@ async def async_setup_intents(hass: HomeAssistant) -> None:
"Closing {}", "Closing {}",
description="Closes a cover", description="Closes a cover",
platforms={DOMAIN}, platforms={DOMAIN},
device_classes={CoverDeviceClass},
), ),
) )

View file

@ -16,6 +16,7 @@ from homeassistant.components.cover import (
SERVICE_CLOSE_COVER, SERVICE_CLOSE_COVER,
SERVICE_OPEN_COVER, SERVICE_OPEN_COVER,
SERVICE_SET_COVER_POSITION, SERVICE_SET_COVER_POSITION,
CoverDeviceClass,
) )
from homeassistant.components.http.data_validator import RequestDataValidator from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.components.lock import ( from homeassistant.components.lock import (
@ -23,11 +24,14 @@ from homeassistant.components.lock import (
SERVICE_LOCK, SERVICE_LOCK,
SERVICE_UNLOCK, SERVICE_UNLOCK,
) )
from homeassistant.components.media_player import MediaPlayerDeviceClass
from homeassistant.components.switch import SwitchDeviceClass
from homeassistant.components.valve import ( from homeassistant.components.valve import (
DOMAIN as VALVE_DOMAIN, DOMAIN as VALVE_DOMAIN,
SERVICE_CLOSE_VALVE, SERVICE_CLOSE_VALVE,
SERVICE_OPEN_VALVE, SERVICE_OPEN_VALVE,
SERVICE_SET_VALVE_POSITION, SERVICE_SET_VALVE_POSITION,
ValveDeviceClass,
) )
from homeassistant.const import ( from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
@ -67,6 +71,13 @@ __all__ = [
"DOMAIN", "DOMAIN",
] ]
ONOFF_DEVICE_CLASSES = {
CoverDeviceClass,
ValveDeviceClass,
SwitchDeviceClass,
MediaPlayerDeviceClass,
}
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the Intent component.""" """Set up the Intent component."""
@ -85,6 +96,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
HOMEASSISTANT_DOMAIN, HOMEASSISTANT_DOMAIN,
SERVICE_TURN_ON, SERVICE_TURN_ON,
description="Turns on/opens a device or entity", description="Turns on/opens a device or entity",
device_classes=ONOFF_DEVICE_CLASSES,
), ),
) )
intent.async_register( intent.async_register(
@ -94,6 +106,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
HOMEASSISTANT_DOMAIN, HOMEASSISTANT_DOMAIN,
SERVICE_TURN_OFF, SERVICE_TURN_OFF,
description="Turns off/closes a device or entity", description="Turns off/closes a device or entity",
device_classes=ONOFF_DEVICE_CLASSES,
), ),
) )
intent.async_register( intent.async_register(
@ -103,6 +116,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
HOMEASSISTANT_DOMAIN, HOMEASSISTANT_DOMAIN,
SERVICE_TOGGLE, SERVICE_TOGGLE,
description="Toggles a device or entity", description="Toggles a device or entity",
device_classes=ONOFF_DEVICE_CLASSES,
), ),
) )
intent.async_register( intent.async_register(
@ -358,6 +372,7 @@ class SetPositionIntentHandler(intent.DynamicServiceIntentHandler):
}, },
description="Sets the position of a device or entity", description="Sets the position of a device or entity",
platforms={COVER_DOMAIN, VALVE_DOMAIN}, platforms={COVER_DOMAIN, VALVE_DOMAIN},
device_classes={CoverDeviceClass, ValveDeviceClass},
) )
def get_domain_and_service( def get_domain_and_service(

View file

@ -16,7 +16,7 @@ from homeassistant.const import (
from homeassistant.core import Context, HomeAssistant, State from homeassistant.core import Context, HomeAssistant, State
from homeassistant.helpers import intent from homeassistant.helpers import intent
from . import ATTR_MEDIA_VOLUME_LEVEL, DOMAIN from . import ATTR_MEDIA_VOLUME_LEVEL, DOMAIN, MediaPlayerDeviceClass
from .const import MediaPlayerEntityFeature, MediaPlayerState from .const import MediaPlayerEntityFeature, MediaPlayerState
INTENT_MEDIA_PAUSE = "HassMediaPause" INTENT_MEDIA_PAUSE = "HassMediaPause"
@ -69,6 +69,7 @@ async def async_setup_intents(hass: HomeAssistant) -> None:
required_states={MediaPlayerState.PLAYING}, required_states={MediaPlayerState.PLAYING},
description="Skips a media player to the next item", description="Skips a media player to the next item",
platforms={DOMAIN}, platforms={DOMAIN},
device_classes={MediaPlayerDeviceClass},
), ),
) )
intent.async_register( intent.async_register(
@ -82,6 +83,7 @@ async def async_setup_intents(hass: HomeAssistant) -> None:
required_states={MediaPlayerState.PLAYING}, required_states={MediaPlayerState.PLAYING},
description="Replays the previous item for a media player", description="Replays the previous item for a media player",
platforms={DOMAIN}, platforms={DOMAIN},
device_classes={MediaPlayerDeviceClass},
), ),
) )
intent.async_register( intent.async_register(
@ -100,6 +102,7 @@ async def async_setup_intents(hass: HomeAssistant) -> None:
}, },
description="Sets the volume of a media player", description="Sets the volume of a media player",
platforms={DOMAIN}, platforms={DOMAIN},
device_classes={MediaPlayerDeviceClass},
), ),
) )
@ -118,6 +121,7 @@ class MediaPauseHandler(intent.ServiceIntentHandler):
required_states={MediaPlayerState.PLAYING}, required_states={MediaPlayerState.PLAYING},
description="Pauses a media player", description="Pauses a media player",
platforms={DOMAIN}, platforms={DOMAIN},
device_classes={MediaPlayerDeviceClass},
) )
self.last_paused = last_paused self.last_paused = last_paused
@ -153,6 +157,7 @@ class MediaUnpauseHandler(intent.ServiceIntentHandler):
required_states={MediaPlayerState.PAUSED}, required_states={MediaPlayerState.PAUSED},
description="Resumes a media player", description="Resumes a media player",
platforms={DOMAIN}, platforms={DOMAIN},
device_classes={MediaPlayerDeviceClass},
) )
self.last_paused = last_paused self.last_paused = last_paused

View file

@ -7,7 +7,7 @@ import asyncio
from collections.abc import Callable, Collection, Coroutine, Iterable from collections.abc import Callable, Collection, Coroutine, Iterable
import dataclasses import dataclasses
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum, auto from enum import Enum, StrEnum, auto
from functools import cached_property from functools import cached_property
from itertools import groupby from itertools import groupby
import logging import logging
@ -820,6 +820,7 @@ class DynamicServiceIntentHandler(IntentHandler):
required_states: set[str] | None = None, required_states: set[str] | None = None,
description: str | None = None, description: str | None = None,
platforms: set[str] | None = None, platforms: set[str] | None = None,
device_classes: set[type[StrEnum]] | None = None,
) -> None: ) -> None:
"""Create Service Intent Handler.""" """Create Service Intent Handler."""
self.intent_type = intent_type self.intent_type = intent_type
@ -829,6 +830,7 @@ class DynamicServiceIntentHandler(IntentHandler):
self.required_states = required_states self.required_states = required_states
self.description = description self.description = description
self.platforms = platforms self.platforms = platforms
self.device_classes = device_classes
self.required_slots: _IntentSlotsType = {} self.required_slots: _IntentSlotsType = {}
if required_slots: if required_slots:
@ -851,13 +853,38 @@ class DynamicServiceIntentHandler(IntentHandler):
@cached_property @cached_property
def slot_schema(self) -> dict: def slot_schema(self) -> dict:
"""Return a slot schema.""" """Return a slot schema."""
domain_validator = (
vol.In(list(self.required_domains)) if self.required_domains else cv.string
)
slot_schema = { slot_schema = {
vol.Any("name", "area", "floor"): non_empty_string, vol.Any("name", "area", "floor"): non_empty_string,
vol.Optional("domain"): vol.All(cv.ensure_list, [cv.string]), vol.Optional("domain"): vol.All(cv.ensure_list, [domain_validator]),
vol.Optional("device_class"): vol.All(cv.ensure_list, [cv.string]),
vol.Optional("preferred_area_id"): cv.string,
vol.Optional("preferred_floor_id"): cv.string,
} }
if self.device_classes:
# The typical way to match enums is with vol.Coerce, but we build a
# flat list to make the API simpler to describe programmatically
flattened_device_classes = vol.In(
[
device_class.value
for device_class_enum in self.device_classes
for device_class in device_class_enum
]
)
slot_schema.update(
{
vol.Optional("device_class"): vol.All(
cv.ensure_list,
[flattened_device_classes],
)
}
)
slot_schema.update(
{
vol.Optional("preferred_area_id"): cv.string,
vol.Optional("preferred_floor_id"): cv.string,
}
)
if self.required_slots: if self.required_slots:
slot_schema.update( slot_schema.update(
@ -910,9 +937,6 @@ class DynamicServiceIntentHandler(IntentHandler):
if "domain" in slots: if "domain" in slots:
domains = set(slots["domain"]["value"]) domains = set(slots["domain"]["value"])
if self.required_domains:
# Must be a subset of intent's required domain(s)
domains.intersection_update(self.required_domains)
if "device_class" in slots: if "device_class" in slots:
device_classes = set(slots["device_class"]["value"]) device_classes = set(slots["device_class"]["value"])
@ -1120,6 +1144,7 @@ class ServiceIntentHandler(DynamicServiceIntentHandler):
required_states: set[str] | None = None, required_states: set[str] | None = None,
description: str | None = None, description: str | None = None,
platforms: set[str] | None = None, platforms: set[str] | None = None,
device_classes: set[type[StrEnum]] | None = None,
) -> None: ) -> None:
"""Create service handler.""" """Create service handler."""
super().__init__( super().__init__(
@ -1132,6 +1157,7 @@ class ServiceIntentHandler(DynamicServiceIntentHandler):
required_states=required_states, required_states=required_states,
description=description, description=description,
platforms=platforms, platforms=platforms,
device_classes=device_classes,
) )
self.domain = domain self.domain = domain
self.service = service self.service = service

View file

@ -123,6 +123,34 @@ async def test_cover_set_position(
assert call.data == {"entity_id": entity_id, cover.ATTR_POSITION: 50} assert call.data == {"entity_id": entity_id, cover.ATTR_POSITION: 50}
async def test_cover_device_class(
hass: HomeAssistant,
init_components,
) -> None:
"""Test the open position for covers by device class."""
await cover_intent.async_setup_intents(hass)
entity_id = f"{cover.DOMAIN}.front"
hass.states.async_set(
entity_id, STATE_CLOSED, attributes={"device_class": "garage"}
)
async_expose_entity(hass, conversation.DOMAIN, entity_id, True)
# Open service
calls = async_mock_service(hass, cover.DOMAIN, cover.SERVICE_OPEN_COVER)
result = await conversation.async_converse(
hass, "open the garage door", None, Context(), None
)
await hass.async_block_till_done()
response = result.response
assert response.response_type == intent.IntentResponseType.ACTION_DONE
assert response.speech["plain"]["speech"] == "Opened the garage"
assert len(calls) == 1
call = calls[0]
assert call.data == {"entity_id": entity_id}
async def test_valve_intents( async def test_valve_intents(
hass: HomeAssistant, hass: HomeAssistant,
init_components, init_components,

View file

@ -1,5 +1,9 @@
"""The tests for the cover platform.""" """The tests for the cover platform."""
from typing import Any
import pytest
from homeassistant.components.cover import ( from homeassistant.components.cover import (
ATTR_CURRENT_POSITION, ATTR_CURRENT_POSITION,
DOMAIN, DOMAIN,
@ -16,15 +20,24 @@ from homeassistant.setup import async_setup_component
from tests.common import async_mock_service from tests.common import async_mock_service
async def test_open_cover_intent(hass: HomeAssistant) -> None: @pytest.mark.parametrize(
("slots"),
[
({"name": {"value": "garage door"}}),
({"device_class": {"value": "garage"}}),
],
)
async def test_open_cover_intent(hass: HomeAssistant, slots: dict[str, Any]) -> None:
"""Test HassOpenCover intent.""" """Test HassOpenCover intent."""
await cover_intent.async_setup_intents(hass) await cover_intent.async_setup_intents(hass)
hass.states.async_set(f"{DOMAIN}.garage_door", STATE_CLOSED) hass.states.async_set(
f"{DOMAIN}.garage_door", STATE_CLOSED, attributes={"device_class": "garage"}
)
calls = async_mock_service(hass, DOMAIN, SERVICE_OPEN_COVER) calls = async_mock_service(hass, DOMAIN, SERVICE_OPEN_COVER)
response = await intent.async_handle( response = await intent.async_handle(
hass, "test", cover_intent.INTENT_OPEN_COVER, {"name": {"value": "garage door"}} hass, "test", cover_intent.INTENT_OPEN_COVER, slots
) )
await hass.async_block_till_done() await hass.async_block_till_done()
@ -36,18 +49,27 @@ async def test_open_cover_intent(hass: HomeAssistant) -> None:
assert call.data == {"entity_id": f"{DOMAIN}.garage_door"} assert call.data == {"entity_id": f"{DOMAIN}.garage_door"}
async def test_close_cover_intent(hass: HomeAssistant) -> None: @pytest.mark.parametrize(
("slots"),
[
({"name": {"value": "garage door"}}),
({"device_class": {"value": "garage"}}),
],
)
async def test_close_cover_intent(hass: HomeAssistant, slots: dict[str, Any]) -> None:
"""Test HassCloseCover intent.""" """Test HassCloseCover intent."""
await cover_intent.async_setup_intents(hass) await cover_intent.async_setup_intents(hass)
hass.states.async_set(f"{DOMAIN}.garage_door", STATE_OPEN) hass.states.async_set(
f"{DOMAIN}.garage_door", STATE_OPEN, attributes={"device_class": "garage"}
)
calls = async_mock_service(hass, DOMAIN, SERVICE_CLOSE_COVER) calls = async_mock_service(hass, DOMAIN, SERVICE_CLOSE_COVER)
response = await intent.async_handle( response = await intent.async_handle(
hass, hass,
"test", "test",
cover_intent.INTENT_CLOSE_COVER, cover_intent.INTENT_CLOSE_COVER,
{"name": {"value": "garage door"}}, slots,
) )
await hass.async_block_till_done() await hass.async_block_till_done()
@ -59,13 +81,22 @@ async def test_close_cover_intent(hass: HomeAssistant) -> None:
assert call.data == {"entity_id": f"{DOMAIN}.garage_door"} assert call.data == {"entity_id": f"{DOMAIN}.garage_door"}
async def test_set_cover_position(hass: HomeAssistant) -> None: @pytest.mark.parametrize(
("slots"),
[
({"name": {"value": "test cover"}, "position": {"value": 50}}),
({"device_class": {"value": "shade"}, "position": {"value": 50}}),
],
)
async def test_set_cover_position(hass: HomeAssistant, slots: dict[str, Any]) -> None:
"""Test HassSetPosition intent for covers.""" """Test HassSetPosition intent for covers."""
assert await async_setup_component(hass, "intent", {}) assert await async_setup_component(hass, "intent", {})
entity_id = f"{DOMAIN}.test_cover" entity_id = f"{DOMAIN}.test_cover"
hass.states.async_set( hass.states.async_set(
entity_id, STATE_CLOSED, attributes={ATTR_CURRENT_POSITION: 0} entity_id,
STATE_CLOSED,
attributes={ATTR_CURRENT_POSITION: 0, "device_class": "shade"},
) )
calls = async_mock_service(hass, DOMAIN, SERVICE_SET_COVER_POSITION) calls = async_mock_service(hass, DOMAIN, SERVICE_SET_COVER_POSITION)
@ -73,7 +104,7 @@ async def test_set_cover_position(hass: HomeAssistant) -> None:
hass, hass,
"test", "test",
intent.INTENT_SET_POSITION, intent.INTENT_SET_POSITION,
{"name": {"value": "test cover"}, "position": {"value": 50}}, slots,
) )
await hass.async_block_till_done() await hass.async_block_till_done()

View file

@ -765,7 +765,7 @@ async def test_service_intent_handler_required_domains(hass: HomeAssistant) -> N
) )
# Still fails even if we provide the domain # Still fails even if we provide the domain
with pytest.raises(intent.MatchFailedError): with pytest.raises(intent.InvalidSlotInfo):
await intent.async_handle( await intent.async_handle(
hass, hass,
"test", "test",
@ -777,7 +777,10 @@ async def test_service_intent_handler_required_domains(hass: HomeAssistant) -> N
async def test_service_handler_empty_strings(hass: HomeAssistant) -> None: async def test_service_handler_empty_strings(hass: HomeAssistant) -> None:
"""Test that passing empty strings for filters fails in ServiceIntentHandler.""" """Test that passing empty strings for filters fails in ServiceIntentHandler."""
handler = intent.ServiceIntentHandler( handler = intent.ServiceIntentHandler(
"TestType", "light", "turn_on", "Turned {} on" "TestType",
"light",
"turn_on",
"Turned {} on",
) )
intent.async_register(hass, handler) intent.async_register(hass, handler)
@ -814,3 +817,55 @@ async def test_service_handler_no_filter(hass: HomeAssistant) -> None:
"test", "test",
"TestType", "TestType",
) )
async def test_service_handler_device_classes(
hass: HomeAssistant, entity_registry: er.EntityRegistry
) -> None:
"""Test that passing empty strings for filters fails in ServiceIntentHandler."""
# Register a fake service and a switch intent handler
call_done = asyncio.Event()
calls = []
# Register a service that takes 0.1 seconds to execute
async def mock_service(call):
"""Mock service."""
call_done.set()
calls.append(call)
hass.services.async_register("switch", "turn_on", mock_service)
handler = intent.ServiceIntentHandler(
"TestType",
"switch",
"turn_on",
"Turned {} on",
device_classes={switch.SwitchDeviceClass},
)
intent.async_register(hass, handler)
# Create a switch enttiy and match by device class
hass.states.async_set(
"switch.bedroom", "off", attributes={"device_class": "outlet"}
)
hass.states.async_set("switch.living_room", "off")
await intent.async_handle(
hass,
"test",
"TestType",
slots={"device_class": {"value": "outlet"}},
)
await call_done.wait()
assert [call.data.get("entity_id") for call in calls] == ["switch.bedroom"]
calls.clear()
# Validate which device classes are allowed
with pytest.raises(intent.InvalidSlotInfo):
await intent.async_handle(
hass,
"test",
"TestType",
slots={"device_class": {"value": "light"}},
)