diff --git a/homeassistant/components/cover/intent.py b/homeassistant/components/cover/intent.py index b38f698ac3d..7580cff063a 100644 --- a/homeassistant/components/cover/intent.py +++ b/homeassistant/components/cover/intent.py @@ -4,7 +4,7 @@ from homeassistant.const import SERVICE_CLOSE_COVER, SERVICE_OPEN_COVER from homeassistant.core import HomeAssistant from homeassistant.helpers import intent -from . import DOMAIN +from . import DOMAIN, CoverDeviceClass INTENT_OPEN_COVER = "HassOpenCover" INTENT_CLOSE_COVER = "HassCloseCover" @@ -21,6 +21,7 @@ async def async_setup_intents(hass: HomeAssistant) -> None: "Opening {}", description="Opens a cover", platforms={DOMAIN}, + device_classes={CoverDeviceClass}, ), ) intent.async_register( @@ -32,5 +33,6 @@ async def async_setup_intents(hass: HomeAssistant) -> None: "Closing {}", description="Closes a cover", platforms={DOMAIN}, + device_classes={CoverDeviceClass}, ), ) diff --git a/homeassistant/components/intent/__init__.py b/homeassistant/components/intent/__init__.py index b1716a8d2d2..001f2515ebf 100644 --- a/homeassistant/components/intent/__init__.py +++ b/homeassistant/components/intent/__init__.py @@ -16,6 +16,7 @@ from homeassistant.components.cover import ( SERVICE_CLOSE_COVER, SERVICE_OPEN_COVER, SERVICE_SET_COVER_POSITION, + CoverDeviceClass, ) from homeassistant.components.http.data_validator import RequestDataValidator from homeassistant.components.lock import ( @@ -23,11 +24,14 @@ from homeassistant.components.lock import ( SERVICE_LOCK, SERVICE_UNLOCK, ) +from homeassistant.components.media_player import MediaPlayerDeviceClass +from homeassistant.components.switch import SwitchDeviceClass from homeassistant.components.valve import ( DOMAIN as VALVE_DOMAIN, SERVICE_CLOSE_VALVE, SERVICE_OPEN_VALVE, SERVICE_SET_VALVE_POSITION, + ValveDeviceClass, ) from homeassistant.const import ( ATTR_ENTITY_ID, @@ -67,6 +71,13 @@ __all__ = [ "DOMAIN", ] +ONOFF_DEVICE_CLASSES = { + CoverDeviceClass, + ValveDeviceClass, + SwitchDeviceClass, + MediaPlayerDeviceClass, +} + async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up the Intent component.""" @@ -85,6 +96,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: HOMEASSISTANT_DOMAIN, SERVICE_TURN_ON, description="Turns on/opens a device or entity", + device_classes=ONOFF_DEVICE_CLASSES, ), ) intent.async_register( @@ -94,6 +106,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: HOMEASSISTANT_DOMAIN, SERVICE_TURN_OFF, description="Turns off/closes a device or entity", + device_classes=ONOFF_DEVICE_CLASSES, ), ) intent.async_register( @@ -103,6 +116,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: HOMEASSISTANT_DOMAIN, SERVICE_TOGGLE, description="Toggles a device or entity", + device_classes=ONOFF_DEVICE_CLASSES, ), ) intent.async_register( @@ -358,6 +372,7 @@ class SetPositionIntentHandler(intent.DynamicServiceIntentHandler): }, description="Sets the position of a device or entity", platforms={COVER_DOMAIN, VALVE_DOMAIN}, + device_classes={CoverDeviceClass, ValveDeviceClass}, ) def get_domain_and_service( diff --git a/homeassistant/components/media_player/intent.py b/homeassistant/components/media_player/intent.py index 8a5d824112a..edfab2a668f 100644 --- a/homeassistant/components/media_player/intent.py +++ b/homeassistant/components/media_player/intent.py @@ -16,7 +16,7 @@ from homeassistant.const import ( from homeassistant.core import Context, HomeAssistant, State from homeassistant.helpers import intent -from . import ATTR_MEDIA_VOLUME_LEVEL, DOMAIN +from . import ATTR_MEDIA_VOLUME_LEVEL, DOMAIN, MediaPlayerDeviceClass from .const import MediaPlayerEntityFeature, MediaPlayerState INTENT_MEDIA_PAUSE = "HassMediaPause" @@ -69,6 +69,7 @@ async def async_setup_intents(hass: HomeAssistant) -> None: required_states={MediaPlayerState.PLAYING}, description="Skips a media player to the next item", platforms={DOMAIN}, + device_classes={MediaPlayerDeviceClass}, ), ) intent.async_register( @@ -82,6 +83,7 @@ async def async_setup_intents(hass: HomeAssistant) -> None: required_states={MediaPlayerState.PLAYING}, description="Replays the previous item for a media player", platforms={DOMAIN}, + device_classes={MediaPlayerDeviceClass}, ), ) intent.async_register( @@ -100,6 +102,7 @@ async def async_setup_intents(hass: HomeAssistant) -> None: }, description="Sets the volume of a media player", platforms={DOMAIN}, + device_classes={MediaPlayerDeviceClass}, ), ) @@ -118,6 +121,7 @@ class MediaPauseHandler(intent.ServiceIntentHandler): required_states={MediaPlayerState.PLAYING}, description="Pauses a media player", platforms={DOMAIN}, + device_classes={MediaPlayerDeviceClass}, ) self.last_paused = last_paused @@ -153,6 +157,7 @@ class MediaUnpauseHandler(intent.ServiceIntentHandler): required_states={MediaPlayerState.PAUSED}, description="Resumes a media player", platforms={DOMAIN}, + device_classes={MediaPlayerDeviceClass}, ) self.last_paused = last_paused diff --git a/homeassistant/helpers/intent.py b/homeassistant/helpers/intent.py index eeb160934ff..be9b57bf814 100644 --- a/homeassistant/helpers/intent.py +++ b/homeassistant/helpers/intent.py @@ -7,7 +7,7 @@ import asyncio from collections.abc import Callable, Collection, Coroutine, Iterable import dataclasses from dataclasses import dataclass, field -from enum import Enum, auto +from enum import Enum, StrEnum, auto from functools import cached_property from itertools import groupby import logging @@ -820,6 +820,7 @@ class DynamicServiceIntentHandler(IntentHandler): required_states: set[str] | None = None, description: str | None = None, platforms: set[str] | None = None, + device_classes: set[type[StrEnum]] | None = None, ) -> None: """Create Service Intent Handler.""" self.intent_type = intent_type @@ -829,6 +830,7 @@ class DynamicServiceIntentHandler(IntentHandler): self.required_states = required_states self.description = description self.platforms = platforms + self.device_classes = device_classes self.required_slots: _IntentSlotsType = {} if required_slots: @@ -851,13 +853,38 @@ class DynamicServiceIntentHandler(IntentHandler): @cached_property def slot_schema(self) -> dict: """Return a slot schema.""" + domain_validator = ( + vol.In(list(self.required_domains)) if self.required_domains else cv.string + ) slot_schema = { vol.Any("name", "area", "floor"): non_empty_string, - vol.Optional("domain"): vol.All(cv.ensure_list, [cv.string]), - vol.Optional("device_class"): vol.All(cv.ensure_list, [cv.string]), - vol.Optional("preferred_area_id"): cv.string, - vol.Optional("preferred_floor_id"): cv.string, + vol.Optional("domain"): vol.All(cv.ensure_list, [domain_validator]), } + if self.device_classes: + # The typical way to match enums is with vol.Coerce, but we build a + # flat list to make the API simpler to describe programmatically + flattened_device_classes = vol.In( + [ + device_class.value + for device_class_enum in self.device_classes + for device_class in device_class_enum + ] + ) + slot_schema.update( + { + vol.Optional("device_class"): vol.All( + cv.ensure_list, + [flattened_device_classes], + ) + } + ) + + slot_schema.update( + { + vol.Optional("preferred_area_id"): cv.string, + vol.Optional("preferred_floor_id"): cv.string, + } + ) if self.required_slots: slot_schema.update( @@ -910,9 +937,6 @@ class DynamicServiceIntentHandler(IntentHandler): if "domain" in slots: domains = set(slots["domain"]["value"]) - if self.required_domains: - # Must be a subset of intent's required domain(s) - domains.intersection_update(self.required_domains) if "device_class" in slots: device_classes = set(slots["device_class"]["value"]) @@ -1120,6 +1144,7 @@ class ServiceIntentHandler(DynamicServiceIntentHandler): required_states: set[str] | None = None, description: str | None = None, platforms: set[str] | None = None, + device_classes: set[type[StrEnum]] | None = None, ) -> None: """Create service handler.""" super().__init__( @@ -1132,6 +1157,7 @@ class ServiceIntentHandler(DynamicServiceIntentHandler): required_states=required_states, description=description, platforms=platforms, + device_classes=device_classes, ) self.domain = domain self.service = service diff --git a/tests/components/conversation/test_default_agent_intents.py b/tests/components/conversation/test_default_agent_intents.py index 8be25136df4..7bae9c43f70 100644 --- a/tests/components/conversation/test_default_agent_intents.py +++ b/tests/components/conversation/test_default_agent_intents.py @@ -123,6 +123,34 @@ async def test_cover_set_position( assert call.data == {"entity_id": entity_id, cover.ATTR_POSITION: 50} +async def test_cover_device_class( + hass: HomeAssistant, + init_components, +) -> None: + """Test the open position for covers by device class.""" + await cover_intent.async_setup_intents(hass) + + entity_id = f"{cover.DOMAIN}.front" + hass.states.async_set( + entity_id, STATE_CLOSED, attributes={"device_class": "garage"} + ) + async_expose_entity(hass, conversation.DOMAIN, entity_id, True) + + # Open service + calls = async_mock_service(hass, cover.DOMAIN, cover.SERVICE_OPEN_COVER) + result = await conversation.async_converse( + hass, "open the garage door", None, Context(), None + ) + await hass.async_block_till_done() + + response = result.response + assert response.response_type == intent.IntentResponseType.ACTION_DONE + assert response.speech["plain"]["speech"] == "Opened the garage" + assert len(calls) == 1 + call = calls[0] + assert call.data == {"entity_id": entity_id} + + async def test_valve_intents( hass: HomeAssistant, init_components, diff --git a/tests/components/cover/test_intent.py b/tests/components/cover/test_intent.py index 8ee621596db..1cf23c4c3df 100644 --- a/tests/components/cover/test_intent.py +++ b/tests/components/cover/test_intent.py @@ -1,5 +1,9 @@ """The tests for the cover platform.""" +from typing import Any + +import pytest + from homeassistant.components.cover import ( ATTR_CURRENT_POSITION, DOMAIN, @@ -16,15 +20,24 @@ from homeassistant.setup import async_setup_component from tests.common import async_mock_service -async def test_open_cover_intent(hass: HomeAssistant) -> None: +@pytest.mark.parametrize( + ("slots"), + [ + ({"name": {"value": "garage door"}}), + ({"device_class": {"value": "garage"}}), + ], +) +async def test_open_cover_intent(hass: HomeAssistant, slots: dict[str, Any]) -> None: """Test HassOpenCover intent.""" await cover_intent.async_setup_intents(hass) - hass.states.async_set(f"{DOMAIN}.garage_door", STATE_CLOSED) + hass.states.async_set( + f"{DOMAIN}.garage_door", STATE_CLOSED, attributes={"device_class": "garage"} + ) calls = async_mock_service(hass, DOMAIN, SERVICE_OPEN_COVER) response = await intent.async_handle( - hass, "test", cover_intent.INTENT_OPEN_COVER, {"name": {"value": "garage door"}} + hass, "test", cover_intent.INTENT_OPEN_COVER, slots ) await hass.async_block_till_done() @@ -36,18 +49,27 @@ async def test_open_cover_intent(hass: HomeAssistant) -> None: assert call.data == {"entity_id": f"{DOMAIN}.garage_door"} -async def test_close_cover_intent(hass: HomeAssistant) -> None: +@pytest.mark.parametrize( + ("slots"), + [ + ({"name": {"value": "garage door"}}), + ({"device_class": {"value": "garage"}}), + ], +) +async def test_close_cover_intent(hass: HomeAssistant, slots: dict[str, Any]) -> None: """Test HassCloseCover intent.""" await cover_intent.async_setup_intents(hass) - hass.states.async_set(f"{DOMAIN}.garage_door", STATE_OPEN) + hass.states.async_set( + f"{DOMAIN}.garage_door", STATE_OPEN, attributes={"device_class": "garage"} + ) calls = async_mock_service(hass, DOMAIN, SERVICE_CLOSE_COVER) response = await intent.async_handle( hass, "test", cover_intent.INTENT_CLOSE_COVER, - {"name": {"value": "garage door"}}, + slots, ) await hass.async_block_till_done() @@ -59,13 +81,22 @@ async def test_close_cover_intent(hass: HomeAssistant) -> None: assert call.data == {"entity_id": f"{DOMAIN}.garage_door"} -async def test_set_cover_position(hass: HomeAssistant) -> None: +@pytest.mark.parametrize( + ("slots"), + [ + ({"name": {"value": "test cover"}, "position": {"value": 50}}), + ({"device_class": {"value": "shade"}, "position": {"value": 50}}), + ], +) +async def test_set_cover_position(hass: HomeAssistant, slots: dict[str, Any]) -> None: """Test HassSetPosition intent for covers.""" assert await async_setup_component(hass, "intent", {}) entity_id = f"{DOMAIN}.test_cover" hass.states.async_set( - entity_id, STATE_CLOSED, attributes={ATTR_CURRENT_POSITION: 0} + entity_id, + STATE_CLOSED, + attributes={ATTR_CURRENT_POSITION: 0, "device_class": "shade"}, ) calls = async_mock_service(hass, DOMAIN, SERVICE_SET_COVER_POSITION) @@ -73,7 +104,7 @@ async def test_set_cover_position(hass: HomeAssistant) -> None: hass, "test", intent.INTENT_SET_POSITION, - {"name": {"value": "test cover"}, "position": {"value": 50}}, + slots, ) await hass.async_block_till_done() diff --git a/tests/helpers/test_intent.py b/tests/helpers/test_intent.py index c592fc50c0a..ae8c2ed65d0 100644 --- a/tests/helpers/test_intent.py +++ b/tests/helpers/test_intent.py @@ -765,7 +765,7 @@ async def test_service_intent_handler_required_domains(hass: HomeAssistant) -> N ) # Still fails even if we provide the domain - with pytest.raises(intent.MatchFailedError): + with pytest.raises(intent.InvalidSlotInfo): await intent.async_handle( hass, "test", @@ -777,7 +777,10 @@ async def test_service_intent_handler_required_domains(hass: HomeAssistant) -> N async def test_service_handler_empty_strings(hass: HomeAssistant) -> None: """Test that passing empty strings for filters fails in ServiceIntentHandler.""" handler = intent.ServiceIntentHandler( - "TestType", "light", "turn_on", "Turned {} on" + "TestType", + "light", + "turn_on", + "Turned {} on", ) intent.async_register(hass, handler) @@ -814,3 +817,55 @@ async def test_service_handler_no_filter(hass: HomeAssistant) -> None: "test", "TestType", ) + + +async def test_service_handler_device_classes( + hass: HomeAssistant, entity_registry: er.EntityRegistry +) -> None: + """Test that passing empty strings for filters fails in ServiceIntentHandler.""" + + # Register a fake service and a switch intent handler + call_done = asyncio.Event() + calls = [] + + # Register a service that takes 0.1 seconds to execute + async def mock_service(call): + """Mock service.""" + call_done.set() + calls.append(call) + + hass.services.async_register("switch", "turn_on", mock_service) + + handler = intent.ServiceIntentHandler( + "TestType", + "switch", + "turn_on", + "Turned {} on", + device_classes={switch.SwitchDeviceClass}, + ) + intent.async_register(hass, handler) + + # Create a switch enttiy and match by device class + hass.states.async_set( + "switch.bedroom", "off", attributes={"device_class": "outlet"} + ) + hass.states.async_set("switch.living_room", "off") + + await intent.async_handle( + hass, + "test", + "TestType", + slots={"device_class": {"value": "outlet"}}, + ) + await call_done.wait() + assert [call.data.get("entity_id") for call in calls] == ["switch.bedroom"] + calls.clear() + + # Validate which device classes are allowed + with pytest.raises(intent.InvalidSlotInfo): + await intent.async_handle( + hass, + "test", + "TestType", + slots={"device_class": {"value": "light"}}, + )