From 7923471b9463d98625057dd89ec5fd6989f9b545 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Tue, 7 May 2024 21:01:03 -0500 Subject: [PATCH] Intent target matching and media player enhancements (#115445) * Working * Tests are passing * Fix climate * Requested changes from review --- homeassistant/components/climate/intent.py | 2 + .../components/conversation/default_agent.py | 128 +-- homeassistant/components/intent/__init__.py | 83 +- homeassistant/components/light/intent.py | 146 +-- .../components/media_player/intent.py | 100 +- homeassistant/helpers/intent.py | 897 ++++++++++++------ tests/components/climate/test_intent.py | 56 +- .../conversation/test_default_agent.py | 24 +- .../test_default_agent_intents.py | 26 +- tests/components/intent/test_init.py | 2 +- tests/components/light/test_intent.py | 19 - tests/components/media_player/test_intent.py | 428 ++++++++- tests/helpers/test_intent.py | 412 +++++++- 13 files changed, 1734 insertions(+), 589 deletions(-) diff --git a/homeassistant/components/climate/intent.py b/homeassistant/components/climate/intent.py index 3073d3e3c26..632e678be94 100644 --- a/homeassistant/components/climate/intent.py +++ b/homeassistant/components/climate/intent.py @@ -56,6 +56,7 @@ class GetTemperatureIntent(intent.IntentHandler): if climate_state is None: raise intent.NoStatesMatchedError( + reason=intent.MatchFailedReason.AREA, name=entity_text or entity_name, area=area_name or area_id, floor=None, @@ -74,6 +75,7 @@ class GetTemperatureIntent(intent.IntentHandler): if climate_state is None: raise intent.NoStatesMatchedError( + reason=intent.MatchFailedReason.NAME, name=entity_name, area=None, floor=None, diff --git a/homeassistant/components/conversation/default_agent.py b/homeassistant/components/conversation/default_agent.py index 10c60747d6c..0bf645c0460 100644 --- a/homeassistant/components/conversation/default_agent.py +++ b/homeassistant/components/conversation/default_agent.py @@ -351,10 +351,10 @@ class DefaultAgent(ConversationEntity): language, assistant=DOMAIN, ) - except intent.NoStatesMatchedError as no_states_error: + except intent.MatchFailedError as match_error: # Intent was valid, but no entities matched the constraints. - error_response_type, error_response_args = _get_no_states_matched_response( - no_states_error + error_response_type, error_response_args = _get_match_error_response( + match_error ) return _make_error_result( language, @@ -364,20 +364,6 @@ class DefaultAgent(ConversationEntity): ), conversation_id, ) - except intent.DuplicateNamesMatchedError as duplicate_names_error: - # Intent was valid, but two or more entities with the same name matched. - ( - error_response_type, - error_response_args, - ) = _get_duplicate_names_matched_response(duplicate_names_error) - return _make_error_result( - language, - intent.IntentResponseErrorCode.NO_VALID_TARGETS, - self._get_error_text( - error_response_type, lang_intents, **error_response_args - ), - conversation_id, - ) except intent.IntentHandleError: # Intent was valid and entities matched constraints, but an error # occurred during handling. @@ -804,34 +790,34 @@ class DefaultAgent(ConversationEntity): _LOGGER.debug("Exposed entities: %s", entity_names) # Expose all areas. - # - # We pass in area id here with the expectation that no two areas will - # share the same name or alias. areas = ar.async_get(self.hass) area_names = [] for area in areas.async_list_areas(): - area_names.append((area.name, area.id)) - if area.aliases: - for alias in area.aliases: - if not alias.strip(): - continue + area_names.append((area.name, area.name)) + if not area.aliases: + continue - area_names.append((alias, area.id)) + for alias in area.aliases: + alias = alias.strip() + if not alias: + continue + + area_names.append((alias, alias)) # Expose all floors. - # - # We pass in floor id here with the expectation that no two floors will - # share the same name or alias. floors = fr.async_get(self.hass) floor_names = [] for floor in floors.async_list_floors(): - floor_names.append((floor.name, floor.floor_id)) - if floor.aliases: - for alias in floor.aliases: - if not alias.strip(): - continue + floor_names.append((floor.name, floor.name)) + if not floor.aliases: + continue - floor_names.append((alias, floor.floor_id)) + for alias in floor.aliases: + alias = alias.strip() + if not alias: + continue + + floor_names.append((alias, floor.name)) self._slot_lists = { "area": TextSlotList.from_tuples(area_names, allow_template=False), @@ -1021,61 +1007,77 @@ def _get_unmatched_response(result: RecognizeResult) -> tuple[ErrorKey, dict[str return ErrorKey.NO_INTENT, {} -def _get_no_states_matched_response( - no_states_error: intent.NoStatesMatchedError, +def _get_match_error_response( + match_error: intent.MatchFailedError, ) -> tuple[ErrorKey, dict[str, Any]]: - """Return key and template arguments for error when intent returns no matching states.""" + """Return key and template arguments for error when target matching fails.""" - # Device classes should be checked before domains - if no_states_error.device_classes: - device_class = next(iter(no_states_error.device_classes)) # first device class - if no_states_error.area: + constraints, result = match_error.constraints, match_error.result + reason = result.no_match_reason + + if ( + reason + in (intent.MatchFailedReason.DEVICE_CLASS, intent.MatchFailedReason.DOMAIN) + ) and constraints.device_classes: + device_class = next(iter(constraints.device_classes)) # first device class + if constraints.area_name: # device_class in area return ErrorKey.NO_DEVICE_CLASS_IN_AREA, { "device_class": device_class, - "area": no_states_error.area, + "area": constraints.area_name, } # device_class only return ErrorKey.NO_DEVICE_CLASS, {"device_class": device_class} - if no_states_error.domains: - domain = next(iter(no_states_error.domains)) # first domain - if no_states_error.area: + if (reason == intent.MatchFailedReason.DOMAIN) and constraints.domains: + domain = next(iter(constraints.domains)) # first domain + if constraints.area_name: # domain in area return ErrorKey.NO_DOMAIN_IN_AREA, { "domain": domain, - "area": no_states_error.area, + "area": constraints.area_name, } - if no_states_error.floor: + if constraints.floor_name: # domain in floor return ErrorKey.NO_DOMAIN_IN_FLOOR, { "domain": domain, - "floor": no_states_error.floor, + "floor": constraints.floor_name, } # domain only return ErrorKey.NO_DOMAIN, {"domain": domain} + if reason == intent.MatchFailedReason.DUPLICATE_NAME: + if constraints.floor_name: + # duplicate on floor + return ErrorKey.DUPLICATE_ENTITIES_IN_FLOOR, { + "entity": result.no_match_name, + "floor": constraints.floor_name, + } + + if constraints.area_name: + # duplicate on area + return ErrorKey.DUPLICATE_ENTITIES_IN_AREA, { + "entity": result.no_match_name, + "area": constraints.area_name, + } + + return ErrorKey.DUPLICATE_ENTITIES, {"entity": result.no_match_name} + + if reason == intent.MatchFailedReason.INVALID_AREA: + # Invalid area name + return ErrorKey.NO_AREA, {"area": result.no_match_name} + + if reason == intent.MatchFailedReason.INVALID_FLOOR: + # Invalid floor name + return ErrorKey.NO_FLOOR, {"floor": result.no_match_name} + # Default error return ErrorKey.NO_INTENT, {} -def _get_duplicate_names_matched_response( - duplicate_names_error: intent.DuplicateNamesMatchedError, -) -> tuple[ErrorKey, dict[str, Any]]: - """Return key and template arguments for error when intent returns duplicate matches.""" - - if duplicate_names_error.area: - return ErrorKey.DUPLICATE_ENTITIES_IN_AREA, { - "entity": duplicate_names_error.name, - "area": duplicate_names_error.area, - } - - return ErrorKey.DUPLICATE_ENTITIES, {"entity": duplicate_names_error.name} - - def _collect_list_references(expression: Expression, list_names: set[str]) -> None: """Collect list reference names recursively.""" if isinstance(expression, Sequence): diff --git a/homeassistant/components/intent/__init__.py b/homeassistant/components/intent/__init__.py index 7fd9fd4b712..d367cc20ac5 100644 --- a/homeassistant/components/intent/__init__.py +++ b/homeassistant/components/intent/__init__.py @@ -35,12 +35,7 @@ from homeassistant.const import ( SERVICE_TURN_ON, ) from homeassistant.core import DOMAIN as HA_DOMAIN, HomeAssistant, State -from homeassistant.helpers import ( - area_registry as ar, - config_validation as cv, - integration_platform, - intent, -) +from homeassistant.helpers import config_validation as cv, integration_platform, intent from homeassistant.helpers.typing import ConfigType from .const import DOMAIN @@ -176,7 +171,7 @@ class GetStateIntentHandler(intent.IntentHandler): intent_type = intent.INTENT_GET_STATE slot_schema = { - vol.Any("name", "area"): cv.string, + vol.Any("name", "area", "floor"): cv.string, vol.Optional("domain"): vol.All(cv.ensure_list, [cv.string]), vol.Optional("device_class"): vol.All(cv.ensure_list, [cv.string]), vol.Optional("state"): vol.All(cv.ensure_list, [cv.string]), @@ -190,18 +185,13 @@ class GetStateIntentHandler(intent.IntentHandler): # Entity name to match name_slot = slots.get("name", {}) entity_name: str | None = name_slot.get("value") - entity_text: str | None = name_slot.get("text") - # Look up area first to fail early + # Get area/floor info area_slot = slots.get("area", {}) area_id = area_slot.get("value") - area_name = area_slot.get("text") - area: ar.AreaEntry | None = None - if area_id is not None: - areas = ar.async_get(hass) - area = areas.async_get_area(area_id) - if area is None: - raise intent.IntentHandleError(f"No area named {area_name}") + + floor_slot = slots.get("floor", {}) + floor_id = floor_slot.get("value") # Optional domain/device class filters. # Convert to sets for speed. @@ -218,32 +208,24 @@ class GetStateIntentHandler(intent.IntentHandler): if "state" in slots: state_names = set(slots["state"]["value"]) - states = list( - intent.async_match_states( - hass, - name=entity_name, - area=area, - domains=domains, - device_classes=device_classes, - assistant=intent_obj.assistant, - ) + match_constraints = intent.MatchTargetsConstraints( + name=entity_name, + area_name=area_id, + floor_name=floor_id, + domains=domains, + device_classes=device_classes, + assistant=intent_obj.assistant, ) - - _LOGGER.debug( - "Found %s state(s) that matched: name=%s, area=%s, domains=%s, device_classes=%s, assistant=%s", - len(states), - entity_name, - area, - domains, - device_classes, - intent_obj.assistant, - ) - - if entity_name and (len(states) > 1): - # Multiple entities matched for the same name - raise intent.DuplicateNamesMatchedError( - name=entity_text or entity_name, - area=area_name or area_id, + match_result = intent.async_match_targets(hass, match_constraints) + if ( + (not match_result.is_match) + and (match_result.no_match_reason is not None) + and (not match_result.no_match_reason.is_no_entities_reason()) + ): + # Don't try to answer questions for certain errors. + # Other match failure reasons are OK. + raise intent.MatchFailedError( + result=match_result, constraints=match_constraints ) # Create response @@ -251,13 +233,24 @@ class GetStateIntentHandler(intent.IntentHandler): response.response_type = intent.IntentResponseType.QUERY_ANSWER success_results: list[intent.IntentResponseTarget] = [] - if area is not None: - success_results.append( + if match_result.areas: + success_results.extend( intent.IntentResponseTarget( type=intent.IntentResponseTargetType.AREA, name=area.name, id=area.id, ) + for area in match_result.areas + ) + + if match_result.floors: + success_results.extend( + intent.IntentResponseTarget( + type=intent.IntentResponseTargetType.FLOOR, + name=floor.name, + id=floor.floor_id, + ) + for floor in match_result.floors ) # If we are matching a state name (e.g., "which lights are on?"), then @@ -271,7 +264,7 @@ class GetStateIntentHandler(intent.IntentHandler): matched_states: list[State] = [] unmatched_states: list[State] = [] - for state in states: + for state in match_result.states: success_results.append( intent.IntentResponseTarget( type=intent.IntentResponseTargetType.ENTITY, @@ -309,7 +302,7 @@ class SetPositionIntentHandler(intent.DynamicServiceIntentHandler): """Create set position handler.""" super().__init__( intent.INTENT_SET_POSITION, - extra_slots={ATTR_POSITION: vol.All(vol.Range(min=0, max=100))}, + required_slots={ATTR_POSITION: vol.All(vol.Range(min=0, max=100))}, ) def get_domain_and_service( diff --git a/homeassistant/components/light/intent.py b/homeassistant/components/light/intent.py index 53127babee9..1092c42d6d2 100644 --- a/homeassistant/components/light/intent.py +++ b/homeassistant/components/light/intent.py @@ -2,25 +2,16 @@ from __future__ import annotations -import asyncio import logging -from typing import Any import voluptuous as vol -from homeassistant.const import ATTR_ENTITY_ID, SERVICE_TURN_ON +from homeassistant.const import SERVICE_TURN_ON from homeassistant.core import HomeAssistant -from homeassistant.helpers import area_registry as ar, config_validation as cv, intent +from homeassistant.helpers import intent import homeassistant.util.color as color_util -from . import ( - ATTR_BRIGHTNESS_PCT, - ATTR_RGB_COLOR, - ATTR_SUPPORTED_COLOR_MODES, - DOMAIN, - brightness_supported, - color_supported, -) +from . import ATTR_BRIGHTNESS_PCT, ATTR_RGB_COLOR, DOMAIN _LOGGER = logging.getLogger(__name__) @@ -29,120 +20,17 @@ INTENT_SET = "HassLightSet" async def async_setup_intents(hass: HomeAssistant) -> None: """Set up the light intents.""" - intent.async_register(hass, SetIntentHandler()) - - -class SetIntentHandler(intent.IntentHandler): - """Handle set color intents.""" - - intent_type = INTENT_SET - slot_schema = { - vol.Any("name", "area"): cv.string, - vol.Optional("domain"): vol.All(cv.ensure_list, [cv.string]), - vol.Optional("device_class"): vol.All(cv.ensure_list, [cv.string]), - vol.Optional("color"): color_util.color_name_to_rgb, - vol.Optional("brightness"): vol.All(vol.Coerce(int), vol.Range(0, 100)), - } - - async def async_handle(self, intent_obj: intent.Intent) -> intent.IntentResponse: - """Handle the hass intent.""" - hass = intent_obj.hass - service_data: dict[str, Any] = {} - slots = self.async_validate_slots(intent_obj.slots) - - name: str | None = slots.get("name", {}).get("value") - if name == "all": - # Don't match on name if targeting all entities - name = None - - # Look up area first to fail early - area_name = slots.get("area", {}).get("value") - area: ar.AreaEntry | None = None - if area_name is not None: - areas = ar.async_get(hass) - area = areas.async_get_area(area_name) or areas.async_get_area_by_name( - area_name - ) - if area is None: - raise intent.IntentHandleError(f"No area named {area_name}") - - # Optional domain/device class filters. - # Convert to sets for speed. - domains: set[str] | None = None - device_classes: set[str] | None = None - - if "domain" in slots: - domains = set(slots["domain"]["value"]) - - if "device_class" in slots: - device_classes = set(slots["device_class"]["value"]) - - states = list( - intent.async_match_states( - hass, - name=name, - area=area, - domains=domains, - device_classes=device_classes, - ) - ) - - if not states: - raise intent.IntentHandleError("No entities matched") - - if "color" in slots: - service_data[ATTR_RGB_COLOR] = slots["color"]["value"] - - if "brightness" in slots: - service_data[ATTR_BRIGHTNESS_PCT] = slots["brightness"]["value"] - - response = intent_obj.create_response() - needs_brightness = ATTR_BRIGHTNESS_PCT in service_data - needs_color = ATTR_RGB_COLOR in service_data - - success_results: list[intent.IntentResponseTarget] = [] - failed_results: list[intent.IntentResponseTarget] = [] - service_coros = [] - - if area is not None: - success_results.append( - intent.IntentResponseTarget( - type=intent.IntentResponseTargetType.AREA, - name=area.name, - id=area.id, - ) - ) - - for state in states: - target = intent.IntentResponseTarget( - type=intent.IntentResponseTargetType.ENTITY, - name=state.name, - id=state.entity_id, - ) - - # Test brightness/color - supported_color_modes = state.attributes.get(ATTR_SUPPORTED_COLOR_MODES) - if (needs_color and not color_supported(supported_color_modes)) or ( - needs_brightness and not brightness_supported(supported_color_modes) - ): - failed_results.append(target) - continue - - service_coros.append( - hass.services.async_call( - DOMAIN, - SERVICE_TURN_ON, - {**service_data, ATTR_ENTITY_ID: state.entity_id}, - context=intent_obj.context, - ) - ) - success_results.append(target) - - # Handle service calls in parallel. - await asyncio.gather(*service_coros) - - response.async_set_results( - success_results=success_results, failed_results=failed_results - ) - - return response + intent.async_register( + hass, + intent.ServiceIntentHandler( + INTENT_SET, + DOMAIN, + SERVICE_TURN_ON, + optional_slots={ + ("color", ATTR_RGB_COLOR): color_util.color_name_to_rgb, + ("brightness", ATTR_BRIGHTNESS_PCT): vol.All( + vol.Coerce(int), vol.Range(0, 100) + ), + }, + ), + ) diff --git a/homeassistant/components/media_player/intent.py b/homeassistant/components/media_player/intent.py index b0c0e7f559e..3a3237bf663 100644 --- a/homeassistant/components/media_player/intent.py +++ b/homeassistant/components/media_player/intent.py @@ -12,27 +12,29 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers import intent from . import ATTR_MEDIA_VOLUME_LEVEL, DOMAIN +from .const import MediaPlayerEntityFeature, MediaPlayerState INTENT_MEDIA_PAUSE = "HassMediaPause" INTENT_MEDIA_UNPAUSE = "HassMediaUnpause" INTENT_MEDIA_NEXT = "HassMediaNext" INTENT_SET_VOLUME = "HassSetVolume" +DATA_LAST_PAUSED = f"{DOMAIN}.last_paused" + async def async_setup_intents(hass: HomeAssistant) -> None: """Set up the media_player intents.""" - intent.async_register( - hass, - intent.ServiceIntentHandler(INTENT_MEDIA_UNPAUSE, DOMAIN, SERVICE_MEDIA_PLAY), - ) - intent.async_register( - hass, - intent.ServiceIntentHandler(INTENT_MEDIA_PAUSE, DOMAIN, SERVICE_MEDIA_PAUSE), - ) + intent.async_register(hass, MediaUnpauseHandler()) + intent.async_register(hass, MediaPauseHandler()) intent.async_register( hass, intent.ServiceIntentHandler( - INTENT_MEDIA_NEXT, DOMAIN, SERVICE_MEDIA_NEXT_TRACK + INTENT_MEDIA_NEXT, + DOMAIN, + SERVICE_MEDIA_NEXT_TRACK, + required_domains={DOMAIN}, + required_features=MediaPlayerEntityFeature.NEXT_TRACK, + required_states={MediaPlayerState.PLAYING}, ), ) intent.async_register( @@ -41,10 +43,88 @@ async def async_setup_intents(hass: HomeAssistant) -> None: INTENT_SET_VOLUME, DOMAIN, SERVICE_VOLUME_SET, - extra_slots={ + required_domains={DOMAIN}, + required_states={MediaPlayerState.PLAYING}, + required_features=MediaPlayerEntityFeature.VOLUME_SET, + required_slots={ ATTR_MEDIA_VOLUME_LEVEL: vol.All( vol.Range(min=0, max=100), lambda val: val / 100 ) }, ), ) + + +class MediaPauseHandler(intent.ServiceIntentHandler): + """Handler for pause intent. Records last paused media players.""" + + def __init__(self) -> None: + """Initialize handler.""" + super().__init__( + INTENT_MEDIA_PAUSE, + DOMAIN, + SERVICE_MEDIA_PAUSE, + required_domains={DOMAIN}, + required_features=MediaPlayerEntityFeature.PAUSE, + required_states={MediaPlayerState.PLAYING}, + ) + + async def async_handle_states( + self, + intent_obj: intent.Intent, + match_result: intent.MatchTargetsResult, + match_constraints: intent.MatchTargetsConstraints, + match_preferences: intent.MatchTargetsPreferences | None = None, + ) -> intent.IntentResponse: + """Record last paused media players.""" + hass = intent_obj.hass + + if match_result.is_match: + # Save entity ids of paused media players + hass.data[DATA_LAST_PAUSED] = {s.entity_id for s in match_result.states} + + return await super().async_handle_states( + intent_obj, match_result, match_constraints + ) + + +class MediaUnpauseHandler(intent.ServiceIntentHandler): + """Handler for unpause/resume intent. Uses last paused media players.""" + + def __init__(self) -> None: + """Initialize handler.""" + super().__init__( + INTENT_MEDIA_UNPAUSE, + DOMAIN, + SERVICE_MEDIA_PLAY, + required_domains={DOMAIN}, + required_states={MediaPlayerState.PAUSED}, + ) + + async def async_handle_states( + self, + intent_obj: intent.Intent, + match_result: intent.MatchTargetsResult, + match_constraints: intent.MatchTargetsConstraints, + match_preferences: intent.MatchTargetsPreferences | None = None, + ) -> intent.IntentResponse: + """Unpause last paused media players.""" + hass = intent_obj.hass + + if ( + match_result.is_match + and (not match_constraints.name) + and (last_paused := hass.data.get(DATA_LAST_PAUSED)) + ): + # Resume only the previously paused media players if they are in the + # targeted set. + targeted_ids = {s.entity_id for s in match_result.states} + overlapping_ids = targeted_ids.intersection(last_paused) + if overlapping_ids: + match_result.states = [ + s for s in match_result.states if s.entity_id in overlapping_ids + ] + + return await super().async_handle_states( + intent_obj, match_result, match_constraints + ) diff --git a/homeassistant/helpers/intent.py b/homeassistant/helpers/intent.py index 8d7f34007f8..daf0229e8ce 100644 --- a/homeassistant/helpers/intent.py +++ b/homeassistant/helpers/intent.py @@ -6,9 +6,10 @@ from abc import abstractmethod import asyncio from collections.abc import Collection, Coroutine, Iterable import dataclasses -from dataclasses import dataclass -from enum import Enum +from dataclasses import dataclass, field +from enum import Enum, auto from functools import cached_property +from itertools import groupby import logging from typing import Any @@ -145,11 +146,144 @@ class IntentUnexpectedError(IntentError): """Unexpected error while handling intent.""" -class NoStatesMatchedError(IntentError): +class MatchFailedReason(Enum): + """Possible reasons for match failure in async_match_targets.""" + + NAME = auto() + """No entities matched name constraint.""" + + AREA = auto() + """No entities matched area constraint.""" + + FLOOR = auto() + """No entities matched floor constraint.""" + + DOMAIN = auto() + """No entities matched domain constraint.""" + + DEVICE_CLASS = auto() + """No entities matched device class constraint.""" + + FEATURE = auto() + """No entities matched supported features constraint.""" + + STATE = auto() + """No entities matched required states constraint.""" + + ASSISTANT = auto() + """No entities matched exposed to assistant constraint.""" + + INVALID_AREA = auto() + """Area name from constraint does not exist.""" + + INVALID_FLOOR = auto() + """Floor name from constraint does not exist.""" + + DUPLICATE_NAME = auto() + """Two or more entities matched the same name constraint and could not be disambiguated.""" + + def is_no_entities_reason(self) -> bool: + """Return True if the match failed because no entities matched.""" + return self not in ( + MatchFailedReason.INVALID_AREA, + MatchFailedReason.INVALID_FLOOR, + MatchFailedReason.DUPLICATE_NAME, + ) + + +@dataclass +class MatchTargetsConstraints: + """Constraints for async_match_targets.""" + + name: str | None = None + """Entity name or alias.""" + + area_name: str | None = None + """Area name, id, or alias.""" + + floor_name: str | None = None + """Floor name, id, or alias.""" + + domains: Collection[str] | None = None + """Domain names.""" + + device_classes: Collection[str] | None = None + """Device class names.""" + + features: int | None = None + """Required supported features.""" + + states: Collection[str] | None = None + """Required states for entities.""" + + assistant: str | None = None + """Name of assistant that entities should be exposed to.""" + + allow_duplicate_names: bool = False + """True if entities with duplicate names are allowed in result.""" + + +@dataclass +class MatchTargetsPreferences: + """Preferences used to disambiguate duplicate name matches in async_match_targets.""" + + area_id: str | None = None + """Id of area to use when deduplicating names.""" + + floor_id: str | None = None + """Id of floor to use when deduplicating names.""" + + +@dataclass +class MatchTargetsResult: + """Result from async_match_targets.""" + + is_match: bool + """True if one or more entities matched.""" + + no_match_reason: MatchFailedReason | None = None + """Reason for failed match when is_match = False.""" + + states: list[State] = field(default_factory=list) + """List of matched entity states when is_match = True.""" + + no_match_name: str | None = None + """Name of invalid area/floor or duplicate name when match fails for those reasons.""" + + areas: list[area_registry.AreaEntry] = field(default_factory=list) + """Areas that were targeted.""" + + floors: list[floor_registry.FloorEntry] = field(default_factory=list) + """Floors that were targeted.""" + + +class MatchFailedError(IntentError): + """Error when target matching fails.""" + + def __init__( + self, + result: MatchTargetsResult, + constraints: MatchTargetsConstraints, + preferences: MatchTargetsPreferences | None = None, + ) -> None: + """Initialize error.""" + super().__init__() + + self.result = result + self.constraints = constraints + self.preferences = preferences + + def __str__(self) -> str: + """Return string representation.""" + return f"" + + +class NoStatesMatchedError(MatchFailedError): """Error when no states match the intent's constraints.""" def __init__( self, + reason: MatchFailedReason, name: str | None = None, area: str | None = None, floor: str | None = None, @@ -157,123 +291,379 @@ class NoStatesMatchedError(IntentError): device_classes: set[str] | None = None, ) -> None: """Initialize error.""" - super().__init__() - - self.name = name - self.area = area - self.floor = floor - self.domains = domains - self.device_classes = device_classes + super().__init__( + result=MatchTargetsResult(False, reason), + constraints=MatchTargetsConstraints( + name=name, + area_name=area, + floor_name=floor, + domains=domains, + device_classes=device_classes, + ), + ) -class DuplicateNamesMatchedError(IntentError): - """Error when two or more entities with the same name matched.""" +@dataclass +class MatchTargetsCandidate: + """Candidate for async_match_targets.""" - def __init__(self, name: str, area: str | None) -> None: - """Initialize error.""" - super().__init__() - - self.name = name - self.area = area + state: State + entity: entity_registry.RegistryEntry | None = None + area: area_registry.AreaEntry | None = None + floor: floor_registry.FloorEntry | None = None + device: device_registry.DeviceEntry | None = None + matched_name: str | None = None -def _is_device_class( - state: State, - entity: entity_registry.RegistryEntry | None, - device_classes: Collection[str], -) -> bool: - """Return true if entity device class matches.""" - # Try entity first - if (entity is not None) and (entity.device_class is not None): - # Entity device class can be None or blank as "unset" - if entity.device_class in device_classes: - return True - - # Fall back to state attribute - device_class = state.attributes.get(ATTR_DEVICE_CLASS) - return (device_class is not None) and (device_class in device_classes) - - -def _has_name( - state: State, entity: entity_registry.RegistryEntry | None, name: str -) -> bool: - """Return true if entity name or alias matches.""" - if name in (state.entity_id, state.name.casefold()): - return True - - # Check name/aliases - if (entity is None) or (not entity.aliases): - return False - - return any(name == alias.casefold() for alias in entity.aliases) - - -def _find_area( - id_or_name: str, areas: area_registry.AreaRegistry -) -> area_registry.AreaEntry | None: - """Find an area by id or name, checking aliases too.""" - area = areas.async_get_area(id_or_name) or areas.async_get_area_by_name(id_or_name) - if area is not None: - return area - - # Check area aliases - for maybe_area in areas.areas.values(): - if not maybe_area.aliases: +def _find_areas( + name: str, areas: area_registry.AreaRegistry +) -> Iterable[area_registry.AreaEntry]: + """Find all areas matching a name (including aliases).""" + name_norm = _normalize_name(name) + for area in areas.async_list_areas(): + # Accept name or area id + if (area.id == name) or (_normalize_name(area.name) == name_norm): + yield area continue - for area_alias in maybe_area.aliases: - if id_or_name == area_alias.casefold(): - return maybe_area - - return None - - -def _find_floor( - id_or_name: str, floors: floor_registry.FloorRegistry -) -> floor_registry.FloorEntry | None: - """Find an floor by id or name, checking aliases too.""" - floor = floors.async_get_floor(id_or_name) or floors.async_get_floor_by_name( - id_or_name - ) - if floor is not None: - return floor - - # Check floor aliases - for maybe_floor in floors.floors.values(): - if not maybe_floor.aliases: + if not area.aliases: continue - for floor_alias in maybe_floor.aliases: - if id_or_name == floor_alias.casefold(): - return maybe_floor - - return None + for alias in area.aliases: + if _normalize_name(alias) == name_norm: + yield area + break -def _filter_by_areas( - states_and_entities: list[tuple[State, entity_registry.RegistryEntry | None]], - areas: Iterable[area_registry.AreaEntry], +def _find_floors( + name: str, floors: floor_registry.FloorRegistry +) -> Iterable[floor_registry.FloorEntry]: + """Find all floors matching a name (including aliases).""" + name_norm = _normalize_name(name) + for floor in floors.async_list_floors(): + # Accept name or floor id + if (floor.floor_id == name) or (_normalize_name(floor.name) == name_norm): + yield floor + continue + + if not floor.aliases: + continue + + for alias in floor.aliases: + if _normalize_name(alias) == name_norm: + yield floor + break + + +def _normalize_name(name: str) -> str: + """Normalize name for comparison.""" + return name.strip().casefold() + + +def _filter_by_name( + name: str, + candidates: Iterable[MatchTargetsCandidate], +) -> Iterable[MatchTargetsCandidate]: + """Filter candidates by name.""" + name_norm = _normalize_name(name) + + for candidate in candidates: + # Accept name or entity id + if (candidate.state.entity_id == name) or _normalize_name( + candidate.state.name + ) == name_norm: + candidate.matched_name = name + yield candidate + continue + + if candidate.entity is None: + continue + + if candidate.entity.name and ( + _normalize_name(candidate.entity.name) == name_norm + ): + candidate.matched_name = name + yield candidate + continue + + # Check aliases + if candidate.entity.aliases: + for alias in candidate.entity.aliases: + if _normalize_name(alias) == name_norm: + candidate.matched_name = name + yield candidate + break + + +def _filter_by_features( + features: int, + candidates: Iterable[MatchTargetsCandidate], +) -> Iterable[MatchTargetsCandidate]: + """Filter candidates by supported features.""" + for candidate in candidates: + if (candidate.entity is not None) and ( + (candidate.entity.supported_features & features) == features + ): + yield candidate + continue + + supported_features = candidate.state.attributes.get(ATTR_SUPPORTED_FEATURES, 0) + if (supported_features & features) == features: + yield candidate + + +def _filter_by_device_classes( + device_classes: Iterable[str], + candidates: Iterable[MatchTargetsCandidate], +) -> Iterable[MatchTargetsCandidate]: + """Filter candidates by device classes.""" + for candidate in candidates: + if ( + (candidate.entity is not None) + and candidate.entity.device_class + and (candidate.entity.device_class in device_classes) + ): + yield candidate + continue + + device_class = candidate.state.attributes.get(ATTR_DEVICE_CLASS) + if device_class and (device_class in device_classes): + yield candidate + + +def _add_areas( + areas: area_registry.AreaRegistry, devices: device_registry.DeviceRegistry, -) -> Iterable[tuple[State, entity_registry.RegistryEntry | None]]: - """Filter state/entity pairs by an area.""" - filter_area_ids: set[str | None] = {a.id for a in areas} - entity_area_ids: dict[str, str | None] = {} - for _state, entity in states_and_entities: - if entity is None: + candidates: Iterable[MatchTargetsCandidate], +) -> None: + """Add area and device entries to match candidates.""" + for candidate in candidates: + if candidate.entity is None: continue - if entity.area_id: - # Use entity's area id first - entity_area_ids[entity.id] = entity.area_id - elif entity.device_id: - # Fall back to device area if not set on entity - device = devices.async_get(entity.device_id) - if device is not None: - entity_area_ids[entity.id] = device.area_id + if candidate.entity.device_id: + candidate.device = devices.async_get(candidate.entity.device_id) - for state, entity in states_and_entities: - if (entity is not None) and (entity_area_ids.get(entity.id) in filter_area_ids): - yield (state, entity) + if candidate.entity.area_id: + # Use entity area first + candidate.area = areas.async_get_area(candidate.entity.area_id) + assert candidate.area is not None + elif (candidate.device is not None) and candidate.device.area_id: + # Fall back to device area + candidate.area = areas.async_get_area(candidate.device.area_id) + + +@callback +def async_match_targets( # noqa: C901 + hass: HomeAssistant, + constraints: MatchTargetsConstraints, + preferences: MatchTargetsPreferences | None = None, + states: list[State] | None = None, +) -> MatchTargetsResult: + """Match entities based on constraints in order to handle an intent.""" + preferences = preferences or MatchTargetsPreferences() + filtered_by_domain = False + + if not states: + # Get all states and filter by domain + states = hass.states.async_all(constraints.domains) + filtered_by_domain = True + if not states: + return MatchTargetsResult(False, MatchFailedReason.DOMAIN) + + if constraints.assistant: + # Filter by exposure + states = [ + s + for s in states + if async_should_expose(hass, constraints.assistant, s.entity_id) + ] + if not states: + return MatchTargetsResult(False, MatchFailedReason.ASSISTANT) + + if constraints.domains and (not filtered_by_domain): + # Filter by domain (if we didn't already do it) + states = [s for s in states if s.domain in constraints.domains] + if not states: + return MatchTargetsResult(False, MatchFailedReason.DOMAIN) + + if constraints.states: + # Filter by state + states = [s for s in states if s.state in constraints.states] + if not states: + return MatchTargetsResult(False, MatchFailedReason.STATE) + + # Exit early so we can to avoid registry lookups + if not ( + constraints.name + or constraints.features + or constraints.device_classes + or constraints.area_name + or constraints.floor_name + ): + return MatchTargetsResult(True, states=states) + + # We need entity registry entries now + er = entity_registry.async_get(hass) + candidates = [MatchTargetsCandidate(s, er.async_get(s.entity_id)) for s in states] + + if constraints.name: + # Filter by entity name or alias + candidates = list(_filter_by_name(constraints.name, candidates)) + if not candidates: + return MatchTargetsResult(False, MatchFailedReason.NAME) + + if constraints.features: + # Filter by supported features + candidates = list(_filter_by_features(constraints.features, candidates)) + if not candidates: + return MatchTargetsResult(False, MatchFailedReason.FEATURE) + + if constraints.device_classes: + # Filter by device class + candidates = list( + _filter_by_device_classes(constraints.device_classes, candidates) + ) + if not candidates: + return MatchTargetsResult(False, MatchFailedReason.DEVICE_CLASS) + + # Check floor/area constraints + targeted_floors: list[floor_registry.FloorEntry] | None = None + targeted_areas: list[area_registry.AreaEntry] | None = None + + # True when area information has been added to candidates + areas_added = False + + if constraints.floor_name or constraints.area_name: + ar = area_registry.async_get(hass) + dr = device_registry.async_get(hass) + _add_areas(ar, dr, candidates) + areas_added = True + + if constraints.floor_name: + # Filter by areas associated with floor + fr = floor_registry.async_get(hass) + targeted_floors = list(_find_floors(constraints.floor_name, fr)) + if not targeted_floors: + return MatchTargetsResult( + False, + MatchFailedReason.INVALID_FLOOR, + no_match_name=constraints.floor_name, + ) + + possible_floor_ids = {floor.floor_id for floor in targeted_floors} + possible_area_ids = { + area.id + for area in ar.async_list_areas() + if area.floor_id in possible_floor_ids + } + + candidates = [ + c + for c in candidates + if (c.area is not None) and (c.area.id in possible_area_ids) + ] + if not candidates: + return MatchTargetsResult( + False, MatchFailedReason.FLOOR, floors=targeted_floors + ) + else: + # All areas are possible + possible_area_ids = {area.id for area in ar.async_list_areas()} + + if constraints.area_name: + targeted_areas = list(_find_areas(constraints.area_name, ar)) + if not targeted_areas: + return MatchTargetsResult( + False, + MatchFailedReason.INVALID_AREA, + no_match_name=constraints.area_name, + ) + + matching_area_ids = {area.id for area in targeted_areas} + + # May be constrained by floors above + possible_area_ids.intersection_update(matching_area_ids) + candidates = [ + c + for c in candidates + if (c.area is not None) and (c.area.id in possible_area_ids) + ] + if not candidates: + return MatchTargetsResult( + False, MatchFailedReason.AREA, areas=targeted_areas + ) + + if constraints.name and (not constraints.allow_duplicate_names): + # Check for duplicates + if not areas_added: + ar = area_registry.async_get(hass) + dr = device_registry.async_get(hass) + _add_areas(ar, dr, candidates) + areas_added = True + + sorted_candidates = sorted( + [c for c in candidates if c.matched_name], + key=lambda c: c.matched_name or "", + ) + final_candidates: list[MatchTargetsCandidate] = [] + for name, group in groupby(sorted_candidates, key=lambda c: c.matched_name): + group_candidates = list(group) + if len(group_candidates) < 2: + # No duplicates for name + final_candidates.extend(group_candidates) + continue + + # Try to disambiguate by preferences + if preferences.floor_id: + group_candidates = [ + c + for c in group_candidates + if (c.area is not None) + and (c.area.floor_id == preferences.floor_id) + ] + if len(group_candidates) < 2: + # Disambiguated by floor + final_candidates.extend(group_candidates) + continue + + if preferences.area_id: + group_candidates = [ + c + for c in group_candidates + if (c.area is not None) and (c.area.id == preferences.area_id) + ] + if len(group_candidates) < 2: + # Disambiguated by area + final_candidates.extend(group_candidates) + continue + + # Couldn't disambiguate duplicate names + return MatchTargetsResult( + False, + MatchFailedReason.DUPLICATE_NAME, + no_match_name=name, + areas=targeted_areas or [], + floors=targeted_floors or [], + ) + + if not final_candidates: + return MatchTargetsResult( + False, + MatchFailedReason.NAME, + areas=targeted_areas or [], + floors=targeted_floors or [], + ) + + candidates = final_candidates + + return MatchTargetsResult( + True, + None, + states=[c.state for c in candidates], + areas=targeted_areas or [], + floors=targeted_floors or [], + ) @callback @@ -282,111 +672,24 @@ def async_match_states( hass: HomeAssistant, name: str | None = None, area_name: str | None = None, - area: area_registry.AreaEntry | None = None, floor_name: str | None = None, - floor: floor_registry.FloorEntry | None = None, domains: Collection[str] | None = None, device_classes: Collection[str] | None = None, - states: Iterable[State] | None = None, - entities: entity_registry.EntityRegistry | None = None, - areas: area_registry.AreaRegistry | None = None, - floors: floor_registry.FloorRegistry | None = None, - devices: device_registry.DeviceRegistry | None = None, - assistant: str | None = None, + states: list[State] | None = None, ) -> Iterable[State]: - """Find states that match the constraints.""" - if states is None: - # All states - states = hass.states.async_all() - - if entities is None: - entities = entity_registry.async_get(hass) - - if devices is None: - devices = device_registry.async_get(hass) - - if areas is None: - areas = area_registry.async_get(hass) - - if floors is None: - floors = floor_registry.async_get(hass) - - # Gather entities - states_and_entities: list[tuple[State, entity_registry.RegistryEntry | None]] = [] - for state in states: - entity = entities.async_get(state.entity_id) - if (entity is not None) and entity.entity_category: - # Skip diagnostic entities - continue - - states_and_entities.append((state, entity)) - - # Filter by domain and device class - if domains: - states_and_entities = [ - (state, entity) - for state, entity in states_and_entities - if state.domain in domains - ] - - if device_classes: - # Check device class in state attribute and in entity entry (if available) - states_and_entities = [ - (state, entity) - for state, entity in states_and_entities - if _is_device_class(state, entity, device_classes) - ] - - filter_areas: list[area_registry.AreaEntry] = [] - - if (floor is None) and (floor_name is not None): - # Look up floor by name - floor = _find_floor(floor_name, floors) - if floor is None: - _LOGGER.warning("Floor not found: %s", floor_name) - return - - if floor is not None: - filter_areas = [ - a for a in areas.async_list_areas() if a.floor_id == floor.floor_id - ] - - if (area is None) and (area_name is not None): - # Look up area by name - area = _find_area(area_name, areas) - if area is None: - _LOGGER.warning("Area not found: %s", area_name) - return - - if area is not None: - filter_areas = [area] - - if filter_areas: - # Filter by states/entities by area - states_and_entities = list( - _filter_by_areas(states_and_entities, filter_areas, devices) - ) - - if assistant is not None: - # Filter by exposure - states_and_entities = [ - (state, entity) - for state, entity in states_and_entities - if async_should_expose(hass, assistant, state.entity_id) - ] - - if name is not None: - # Filter by name - name = name.casefold() - - # Check states - for state, entity in states_and_entities: - if _has_name(state, entity, name): - yield state - else: - # Not filtered by name - for state, _entity in states_and_entities: - yield state + """Simplified interface to async_match_targets that returns states matching the constraints.""" + result = async_match_targets( + hass, + constraints=MatchTargetsConstraints( + name=name, + area_name=area_name, + floor_name=floor_name, + domains=domains, + device_classes=device_classes, + ), + states=states, + ) + return result.states @callback @@ -447,6 +750,8 @@ class DynamicServiceIntentHandler(IntentHandler): vol.Any("name", "area", "floor"): cv.string, vol.Optional("domain"): vol.All(cv.ensure_list, [cv.string]), vol.Optional("device_class"): vol.All(cv.ensure_list, [cv.string]), + vol.Optional("preferred_area_id"): cv.string, + vol.Optional("preferred_floor_id"): cv.string, } # We use a small timeout in service calls to (hopefully) pass validation @@ -457,12 +762,36 @@ class DynamicServiceIntentHandler(IntentHandler): self, intent_type: str, speech: str | None = None, - extra_slots: dict[str, vol.Schema] | None = None, + required_slots: dict[str | tuple[str, str], vol.Schema] | None = None, + optional_slots: dict[str | tuple[str, str], vol.Schema] | None = None, + required_domains: set[str] | None = None, + required_features: int | None = None, + required_states: set[str] | None = None, ) -> None: """Create Service Intent Handler.""" self.intent_type = intent_type self.speech = speech - self.extra_slots = extra_slots + self.required_domains = required_domains + self.required_features = required_features + self.required_states = required_states + + self.required_slots: dict[tuple[str, str], vol.Schema] = {} + if required_slots: + for key, value_schema in required_slots.items(): + if isinstance(key, str): + # Slot name/service data key + key = (key, key) + + self.required_slots[key] = value_schema + + self.optional_slots: dict[tuple[str, str], vol.Schema] = {} + if optional_slots: + for key, value_schema in optional_slots.items(): + if isinstance(key, str): + # Slot name/service data key + key = (key, key) + + self.optional_slots[key] = value_schema @cached_property def _slot_schema(self) -> vol.Schema: @@ -470,12 +799,16 @@ class DynamicServiceIntentHandler(IntentHandler): if self.slot_schema is None: raise ValueError("Slot schema is not defined") - if self.extra_slots: + if self.required_slots or self.optional_slots: slot_schema = { **self.slot_schema, **{ - vol.Required(key): schema - for key, schema in self.extra_slots.items() + vol.Required(key[0]): schema + for key, schema in self.required_slots.items() + }, + **{ + vol.Optional(key[0]): schema + for key, schema in self.optional_slots.items() }, } else: @@ -508,97 +841,107 @@ class DynamicServiceIntentHandler(IntentHandler): # Don't match on name if targeting all entities entity_name = None - # Look up area to fail early + # Get area/floor info area_slot = slots.get("area", {}) area_id = area_slot.get("value") - area_name = area_slot.get("text") - area: area_registry.AreaEntry | None = None - if area_id is not None: - areas = area_registry.async_get(hass) - area = areas.async_get_area(area_id) - if area is None: - raise IntentHandleError(f"No area named {area_name}") - # Look up floor to fail early floor_slot = slots.get("floor", {}) floor_id = floor_slot.get("value") - floor_name = floor_slot.get("text") - floor: floor_registry.FloorEntry | None = None - if floor_id is not None: - floors = floor_registry.async_get(hass) - floor = floors.async_get_floor(floor_id) - if floor is None: - raise IntentHandleError(f"No floor named {floor_name}") # Optional domain/device class filters. # Convert to sets for speed. - domains: set[str] | None = None + domains: set[str] | None = self.required_domains device_classes: set[str] | None = None if "domain" in slots: domains = set(slots["domain"]["value"]) + if self.required_domains: + # Must be a subset of intent's required domain(s) + domains.intersection_update(self.required_domains) if "device_class" in slots: device_classes = set(slots["device_class"]["value"]) - states = list( - async_match_states( - hass, - name=entity_name, - area=area, - floor=floor, - domains=domains, - device_classes=device_classes, - assistant=intent_obj.assistant, - ) + match_constraints = MatchTargetsConstraints( + name=entity_name, + area_name=area_id, + floor_name=floor_id, + domains=domains, + device_classes=device_classes, + assistant=intent_obj.assistant, + features=self.required_features, + states=self.required_states, + ) + match_preferences = MatchTargetsPreferences( + area_id=slots.get("preferred_area_id", {}).get("value"), + floor_id=slots.get("preferred_floor_id", {}).get("value"), ) - if not states: - # No states matched constraints - raise NoStatesMatchedError( - name=entity_text or entity_name, - area=area_name or area_id, - floor=floor_name or floor_id, - domains=domains, - device_classes=device_classes, + match_result = async_match_targets(hass, match_constraints, match_preferences) + if not match_result.is_match: + raise MatchFailedError( + result=match_result, + constraints=match_constraints, + preferences=match_preferences, ) - if entity_name and (len(states) > 1): - # Multiple entities matched for the same name - raise DuplicateNamesMatchedError( - name=entity_text or entity_name, - area=area_name or area_id, - ) + # Ensure name is text + if ("name" in slots) and entity_text: + slots["name"]["value"] = entity_text + + # Replace area/floor values with the resolved ids for use in templates + if ("area" in slots) and match_result.areas: + slots["area"]["value"] = match_result.areas[0].id + + if ("floor" in slots) and match_result.floors: + slots["floor"]["value"] = match_result.floors[0].floor_id # Update intent slots to include any transformations done by the schemas intent_obj.slots = slots - response = await self.async_handle_states(intent_obj, states, area) + response = await self.async_handle_states( + intent_obj, match_result, match_constraints, match_preferences + ) # Make the matched states available in the response - response.async_set_states(matched_states=states, unmatched_states=[]) + response.async_set_states( + matched_states=match_result.states, unmatched_states=[] + ) return response async def async_handle_states( self, intent_obj: Intent, - states: list[State], - area: area_registry.AreaEntry | None = None, + match_result: MatchTargetsResult, + match_constraints: MatchTargetsConstraints, + match_preferences: MatchTargetsPreferences | None = None, ) -> IntentResponse: """Complete action on matched entity states.""" - assert states, "No states" - hass = intent_obj.hass - success_results: list[IntentResponseTarget] = [] + states = match_result.states response = intent_obj.create_response() - if area is not None: - success_results.append( + hass = intent_obj.hass + success_results: list[IntentResponseTarget] = [] + + if match_result.floors: + success_results.extend( + IntentResponseTarget( + type=IntentResponseTargetType.FLOOR, + name=floor.name, + id=floor.floor_id, + ) + for floor in match_result.floors + ) + speech_name = match_result.floors[0].name + elif match_result.areas: + success_results.extend( IntentResponseTarget( type=IntentResponseTargetType.AREA, name=area.name, id=area.id ) + for area in match_result.areas ) - speech_name = area.name + speech_name = match_result.areas[0].name else: speech_name = states[0].name @@ -654,11 +997,20 @@ class DynamicServiceIntentHandler(IntentHandler): hass = intent_obj.hass service_data: dict[str, Any] = {ATTR_ENTITY_ID: state.entity_id} - if self.extra_slots: + if self.required_slots: service_data.update( - {key: intent_obj.slots[key]["value"] for key in self.extra_slots} + { + key[1]: intent_obj.slots[key[0]]["value"] + for key in self.required_slots + } ) + if self.optional_slots: + for key in self.optional_slots: + value = intent_obj.slots.get(key[0]) + if value: + service_data[key[1]] = value["value"] + await self._run_then_background( hass.async_create_task_internal( hass.services.async_call( @@ -702,10 +1054,22 @@ class ServiceIntentHandler(DynamicServiceIntentHandler): domain: str, service: str, speech: str | None = None, - extra_slots: dict[str, vol.Schema] | None = None, + required_slots: dict[str | tuple[str, str], vol.Schema] | None = None, + optional_slots: dict[str | tuple[str, str], vol.Schema] | None = None, + required_domains: set[str] | None = None, + required_features: int | None = None, + required_states: set[str] | None = None, ) -> None: """Create service handler.""" - super().__init__(intent_type, speech=speech, extra_slots=extra_slots) + super().__init__( + intent_type, + speech=speech, + required_slots=required_slots, + optional_slots=optional_slots, + required_domains=required_domains, + required_features=required_features, + required_states=required_states, + ) self.domain = domain self.service = service @@ -806,6 +1170,7 @@ class IntentResponseTargetType(str, Enum): """Type of target for an intent response.""" AREA = "area" + FLOOR = "floor" DEVICE = "device" ENTITY = "entity" DOMAIN = "domain" diff --git a/tests/components/climate/test_intent.py b/tests/components/climate/test_intent.py index e4f92759793..1aaea386320 100644 --- a/tests/components/climate/test_intent.py +++ b/tests/components/climate/test_intent.py @@ -183,7 +183,7 @@ async def test_get_temperature( assert state.attributes["current_temperature"] == 22.0 # Check area with no climate entities - with pytest.raises(intent.NoStatesMatchedError) as error: + with pytest.raises(intent.MatchFailedError) as error: response = await intent.async_handle( hass, "test", @@ -192,14 +192,16 @@ async def test_get_temperature( ) # Exception should contain details of what we tried to match - assert isinstance(error.value, intent.NoStatesMatchedError) - assert error.value.name is None - assert error.value.area == office_area.name - assert error.value.domains == {DOMAIN} - assert error.value.device_classes is None + assert isinstance(error.value, intent.MatchFailedError) + assert error.value.result.no_match_reason == intent.MatchFailedReason.AREA + constraints = error.value.constraints + assert constraints.name is None + assert constraints.area_name == office_area.name + assert constraints.domains == {DOMAIN} + assert constraints.device_classes is None # Check wrong name - with pytest.raises(intent.NoStatesMatchedError) as error: + with pytest.raises(intent.MatchFailedError) as error: response = await intent.async_handle( hass, "test", @@ -207,14 +209,16 @@ async def test_get_temperature( {"name": {"value": "Does not exist"}}, ) - assert isinstance(error.value, intent.NoStatesMatchedError) - assert error.value.name == "Does not exist" - assert error.value.area is None - assert error.value.domains == {DOMAIN} - assert error.value.device_classes is None + assert isinstance(error.value, intent.MatchFailedError) + assert error.value.result.no_match_reason == intent.MatchFailedReason.NAME + constraints = error.value.constraints + assert constraints.name == "Does not exist" + assert constraints.area_name is None + assert constraints.domains == {DOMAIN} + assert constraints.device_classes is None # Check wrong name with area - with pytest.raises(intent.NoStatesMatchedError) as error: + with pytest.raises(intent.MatchFailedError) as error: response = await intent.async_handle( hass, "test", @@ -222,11 +226,13 @@ async def test_get_temperature( {"name": {"value": "Climate 1"}, "area": {"value": bedroom_area.name}}, ) - assert isinstance(error.value, intent.NoStatesMatchedError) - assert error.value.name == "Climate 1" - assert error.value.area == bedroom_area.name - assert error.value.domains == {DOMAIN} - assert error.value.device_classes is None + assert isinstance(error.value, intent.MatchFailedError) + assert error.value.result.no_match_reason == intent.MatchFailedReason.AREA + constraints = error.value.constraints + assert constraints.name == "Climate 1" + assert constraints.area_name == bedroom_area.name + assert constraints.domains == {DOMAIN} + assert constraints.device_classes is None async def test_get_temperature_no_entities( @@ -275,7 +281,7 @@ async def test_get_temperature_no_state( with ( patch("homeassistant.core.StateMachine.async_all", return_value=[]), - pytest.raises(intent.NoStatesMatchedError) as error, + pytest.raises(intent.MatchFailedError) as error, ): await intent.async_handle( hass, @@ -285,8 +291,10 @@ async def test_get_temperature_no_state( ) # Exception should contain details of what we tried to match - assert isinstance(error.value, intent.NoStatesMatchedError) - assert error.value.name is None - assert error.value.area == "Living Room" - assert error.value.domains == {DOMAIN} - assert error.value.device_classes is None + assert isinstance(error.value, intent.MatchFailedError) + assert error.value.result.no_match_reason == intent.MatchFailedReason.AREA + constraints = error.value.constraints + assert constraints.name is None + assert constraints.area_name == "Living Room" + assert constraints.domains == {DOMAIN} + assert constraints.device_classes is None diff --git a/tests/components/conversation/test_default_agent.py b/tests/components/conversation/test_default_agent.py index 9048a1259c5..f100dc810fb 100644 --- a/tests/components/conversation/test_default_agent.py +++ b/tests/components/conversation/test_default_agent.py @@ -6,12 +6,12 @@ from unittest.mock import AsyncMock, patch from hassil.recognize import Intent, IntentData, MatchEntity, RecognizeResult import pytest -from homeassistant.components import conversation +from homeassistant.components import conversation, cover from homeassistant.components.conversation import default_agent from homeassistant.components.homeassistant.exposed_entities import ( async_get_assistant_settings, ) -from homeassistant.const import ATTR_FRIENDLY_NAME +from homeassistant.const import ATTR_DEVICE_CLASS, ATTR_FRIENDLY_NAME, STATE_CLOSED from homeassistant.core import DOMAIN as HASS_DOMAIN, Context, HomeAssistant from homeassistant.helpers import ( area_registry as ar, @@ -607,14 +607,23 @@ async def test_error_no_domain_in_floor( async def test_error_no_device_class(hass: HomeAssistant, init_components) -> None: """Test error message when no entities of a device class exist.""" + # Create a cover entity that is not a window. + # This ensures that the filtering below won't exit early because there are + # no entities in the cover domain. + hass.states.async_set( + "cover.garage_door", + STATE_CLOSED, + attributes={ATTR_DEVICE_CLASS: cover.CoverDeviceClass.GARAGE}, + ) # We don't have a sentence for opening all windows + cover_domain = MatchEntity(name="domain", value="cover", text="cover") window_class = MatchEntity(name="device_class", value="window", text="windows") recognize_result = RecognizeResult( intent=Intent("HassTurnOn"), intent_data=IntentData([]), - entities={"device_class": window_class}, - entities_list=[window_class], + entities={"domain": cover_domain, "device_class": window_class}, + entities_list=[cover_domain, window_class], ) with patch( @@ -792,7 +801,9 @@ async def test_no_states_matched_default_error( with patch( "homeassistant.components.conversation.default_agent.intent.async_handle", - side_effect=intent.NoStatesMatchedError(), + side_effect=intent.MatchFailedError( + intent.MatchTargetsResult(False), intent.MatchTargetsConstraints() + ), ): result = await conversation.async_converse( hass, "turn on lights in the kitchen", None, Context(), None @@ -863,17 +874,14 @@ async def test_empty_aliases( assert slot_lists.keys() == {"area", "name", "floor"} areas = slot_lists["area"] assert len(areas.values) == 1 - assert areas.values[0].value_out == area_kitchen.id assert areas.values[0].text_in.text == area_kitchen.normalized_name names = slot_lists["name"] assert len(names.values) == 1 - assert names.values[0].value_out == kitchen_light.name assert names.values[0].text_in.text == kitchen_light.name floors = slot_lists["floor"] assert len(floors.values) == 1 - assert floors.values[0].value_out == floor_1.floor_id assert floors.values[0].text_in.text == floor_1.name diff --git a/tests/components/conversation/test_default_agent_intents.py b/tests/components/conversation/test_default_agent_intents.py index 9636ac07f63..16b0ccf3107 100644 --- a/tests/components/conversation/test_default_agent_intents.py +++ b/tests/components/conversation/test_default_agent_intents.py @@ -12,9 +12,17 @@ from homeassistant.components import ( ) from homeassistant.components.cover import intent as cover_intent from homeassistant.components.homeassistant.exposed_entities import async_expose_entity -from homeassistant.components.media_player import intent as media_player_intent +from homeassistant.components.media_player import ( + MediaPlayerEntityFeature, + intent as media_player_intent, +) from homeassistant.components.vacuum import intent as vaccum_intent -from homeassistant.const import STATE_CLOSED +from homeassistant.const import ( + ATTR_SUPPORTED_FEATURES, + STATE_CLOSED, + STATE_PAUSED, + STATE_PLAYING, +) from homeassistant.core import Context, HomeAssistant from homeassistant.helpers import ( area_registry as ar, @@ -189,7 +197,13 @@ async def test_media_player_intents( await media_player_intent.async_setup_intents(hass) entity_id = f"{media_player.DOMAIN}.tv" - hass.states.async_set(entity_id, media_player.STATE_PLAYING) + attributes = { + ATTR_SUPPORTED_FEATURES: MediaPlayerEntityFeature.PAUSE + | MediaPlayerEntityFeature.NEXT_TRACK + | MediaPlayerEntityFeature.VOLUME_SET + } + + hass.states.async_set(entity_id, STATE_PLAYING, attributes=attributes) async_expose_entity(hass, conversation.DOMAIN, entity_id, True) # pause @@ -206,6 +220,9 @@ async def test_media_player_intents( call = calls[0] assert call.data == {"entity_id": entity_id} + # Unpause requires paused state + hass.states.async_set(entity_id, STATE_PAUSED, attributes=attributes) + # unpause calls = async_mock_service( hass, media_player.DOMAIN, media_player.SERVICE_MEDIA_PLAY @@ -222,6 +239,9 @@ async def test_media_player_intents( call = calls[0] assert call.data == {"entity_id": entity_id} + # Next track requires playing state + hass.states.async_set(entity_id, STATE_PLAYING, attributes=attributes) + # next calls = async_mock_service( hass, media_player.DOMAIN, media_player.SERVICE_MEDIA_NEXT_TRACK diff --git a/tests/components/intent/test_init.py b/tests/components/intent/test_init.py index 77a6a368c01..586ea7dd8a2 100644 --- a/tests/components/intent/test_init.py +++ b/tests/components/intent/test_init.py @@ -422,7 +422,7 @@ async def test_get_state_intent( assert not result.matched_states and not result.unmatched_states # Test unknown area failure - with pytest.raises(intent.IntentHandleError): + with pytest.raises(intent.MatchFailedError): await intent.async_handle( hass, "test", diff --git a/tests/components/light/test_intent.py b/tests/components/light/test_intent.py index b21b9367bba..94457928b5b 100644 --- a/tests/components/light/test_intent.py +++ b/tests/components/light/test_intent.py @@ -34,25 +34,6 @@ async def test_intent_set_color(hass: HomeAssistant) -> None: assert call.data.get(light.ATTR_RGB_COLOR) == (0, 0, 255) -async def test_intent_set_color_tests_feature(hass: HomeAssistant) -> None: - """Test the set color intent.""" - hass.states.async_set("light.hello", "off") - calls = async_mock_service(hass, light.DOMAIN, light.SERVICE_TURN_ON) - await intent.async_setup_intents(hass) - - response = await async_handle( - hass, - "test", - intent.INTENT_SET, - {"name": {"value": "Hello"}, "color": {"value": "blue"}}, - ) - - # Response should contain one failed target - assert len(response.success_results) == 0 - assert len(response.failed_results) == 1 - assert len(calls) == 0 - - async def test_intent_set_color_and_brightness(hass: HomeAssistant) -> None: """Test the set color intent.""" hass.states.async_set( diff --git a/tests/components/media_player/test_intent.py b/tests/components/media_player/test_intent.py index b0ea7fe8e94..8cce7cff44c 100644 --- a/tests/components/media_player/test_intent.py +++ b/tests/components/media_player/test_intent.py @@ -1,5 +1,7 @@ """The tests for the media_player platform.""" +import pytest + from homeassistant.components.media_player import ( DOMAIN, SERVICE_MEDIA_NEXT_TRACK, @@ -8,9 +10,20 @@ from homeassistant.components.media_player import ( SERVICE_VOLUME_SET, intent as media_player_intent, ) -from homeassistant.const import STATE_IDLE +from homeassistant.components.media_player.const import MediaPlayerEntityFeature +from homeassistant.const import ( + ATTR_SUPPORTED_FEATURES, + STATE_IDLE, + STATE_PAUSED, + STATE_PLAYING, +) from homeassistant.core import HomeAssistant -from homeassistant.helpers import intent +from homeassistant.helpers import ( + area_registry as ar, + entity_registry as er, + floor_registry as fr, + intent, +) from tests.common import async_mock_service @@ -20,14 +33,19 @@ async def test_pause_media_player_intent(hass: HomeAssistant) -> None: await media_player_intent.async_setup_intents(hass) entity_id = f"{DOMAIN}.test_media_player" - hass.states.async_set(entity_id, STATE_IDLE) - calls = async_mock_service(hass, DOMAIN, SERVICE_MEDIA_PAUSE) + attributes = {ATTR_SUPPORTED_FEATURES: MediaPlayerEntityFeature.PAUSE} + + hass.states.async_set(entity_id, STATE_PLAYING, attributes=attributes) + calls = async_mock_service( + hass, + DOMAIN, + SERVICE_MEDIA_PAUSE, + ) response = await intent.async_handle( hass, "test", media_player_intent.INTENT_MEDIA_PAUSE, - {"name": {"value": "test media player"}}, ) await hass.async_block_till_done() @@ -38,20 +56,45 @@ async def test_pause_media_player_intent(hass: HomeAssistant) -> None: assert call.service == SERVICE_MEDIA_PAUSE assert call.data == {"entity_id": entity_id} + # Test if not playing + hass.states.async_set(entity_id, STATE_IDLE, attributes=attributes) + + with pytest.raises(intent.MatchFailedError): + response = await intent.async_handle( + hass, + "test", + media_player_intent.INTENT_MEDIA_PAUSE, + ) + await hass.async_block_till_done() + + # Test feature not supported + hass.states.async_set( + entity_id, + STATE_PLAYING, + attributes={ATTR_SUPPORTED_FEATURES: MediaPlayerEntityFeature(0)}, + ) + + with pytest.raises(intent.MatchFailedError): + response = await intent.async_handle( + hass, + "test", + media_player_intent.INTENT_MEDIA_PAUSE, + ) + await hass.async_block_till_done() + async def test_unpause_media_player_intent(hass: HomeAssistant) -> None: """Test HassMediaUnpause intent for media players.""" await media_player_intent.async_setup_intents(hass) entity_id = f"{DOMAIN}.test_media_player" - hass.states.async_set(entity_id, STATE_IDLE) + hass.states.async_set(entity_id, STATE_PAUSED) calls = async_mock_service(hass, DOMAIN, SERVICE_MEDIA_PLAY) response = await intent.async_handle( hass, "test", media_player_intent.INTENT_MEDIA_UNPAUSE, - {"name": {"value": "test media player"}}, ) await hass.async_block_till_done() @@ -62,20 +105,36 @@ async def test_unpause_media_player_intent(hass: HomeAssistant) -> None: assert call.service == SERVICE_MEDIA_PLAY assert call.data == {"entity_id": entity_id} + # Test if not paused + hass.states.async_set( + entity_id, + STATE_PLAYING, + ) + + with pytest.raises(intent.MatchFailedError): + response = await intent.async_handle( + hass, + "test", + media_player_intent.INTENT_MEDIA_UNPAUSE, + ) + await hass.async_block_till_done() + async def test_next_media_player_intent(hass: HomeAssistant) -> None: """Test HassMediaNext intent for media players.""" await media_player_intent.async_setup_intents(hass) entity_id = f"{DOMAIN}.test_media_player" - hass.states.async_set(entity_id, STATE_IDLE) + attributes = {ATTR_SUPPORTED_FEATURES: MediaPlayerEntityFeature.NEXT_TRACK} + + hass.states.async_set(entity_id, STATE_PLAYING, attributes=attributes) + calls = async_mock_service(hass, DOMAIN, SERVICE_MEDIA_NEXT_TRACK) response = await intent.async_handle( hass, "test", media_player_intent.INTENT_MEDIA_NEXT, - {"name": {"value": "test media player"}}, ) await hass.async_block_till_done() @@ -86,20 +145,49 @@ async def test_next_media_player_intent(hass: HomeAssistant) -> None: assert call.service == SERVICE_MEDIA_NEXT_TRACK assert call.data == {"entity_id": entity_id} + # Test if not playing + hass.states.async_set(entity_id, STATE_IDLE, attributes=attributes) + + with pytest.raises(intent.MatchFailedError): + response = await intent.async_handle( + hass, + "test", + media_player_intent.INTENT_MEDIA_NEXT, + ) + await hass.async_block_till_done() + + # Test feature not supported + hass.states.async_set( + entity_id, + STATE_PLAYING, + attributes={ATTR_SUPPORTED_FEATURES: MediaPlayerEntityFeature(0)}, + ) + + with pytest.raises(intent.MatchFailedError): + response = await intent.async_handle( + hass, + "test", + media_player_intent.INTENT_MEDIA_NEXT, + {"name": {"value": "test media player"}}, + ) + await hass.async_block_till_done() + async def test_volume_media_player_intent(hass: HomeAssistant) -> None: """Test HassSetVolume intent for media players.""" await media_player_intent.async_setup_intents(hass) entity_id = f"{DOMAIN}.test_media_player" - hass.states.async_set(entity_id, STATE_IDLE) + attributes = {ATTR_SUPPORTED_FEATURES: MediaPlayerEntityFeature.VOLUME_SET} + + hass.states.async_set(entity_id, STATE_PLAYING, attributes=attributes) calls = async_mock_service(hass, DOMAIN, SERVICE_VOLUME_SET) response = await intent.async_handle( hass, "test", media_player_intent.INTENT_SET_VOLUME, - {"name": {"value": "test media player"}, "volume_level": {"value": 50}}, + {"volume_level": {"value": 50}}, ) await hass.async_block_till_done() @@ -109,3 +197,321 @@ async def test_volume_media_player_intent(hass: HomeAssistant) -> None: assert call.domain == DOMAIN assert call.service == SERVICE_VOLUME_SET assert call.data == {"entity_id": entity_id, "volume_level": 0.5} + + # Test if not playing + hass.states.async_set(entity_id, STATE_IDLE, attributes=attributes) + + with pytest.raises(intent.MatchFailedError): + response = await intent.async_handle( + hass, + "test", + media_player_intent.INTENT_SET_VOLUME, + {"volume_level": {"value": 50}}, + ) + await hass.async_block_till_done() + + # Test feature not supported + hass.states.async_set( + entity_id, + STATE_PLAYING, + attributes={ATTR_SUPPORTED_FEATURES: MediaPlayerEntityFeature(0)}, + ) + + with pytest.raises(intent.MatchFailedError): + response = await intent.async_handle( + hass, + "test", + media_player_intent.INTENT_SET_VOLUME, + {"volume_level": {"value": 50}}, + ) + await hass.async_block_till_done() + + +async def test_multiple_media_players( + hass: HomeAssistant, + area_registry: ar.AreaRegistry, + entity_registry: er.EntityRegistry, + floor_registry: fr.FloorRegistry, +) -> None: + """Test HassMedia* intents with multiple media players.""" + await media_player_intent.async_setup_intents(hass) + + attributes = { + ATTR_SUPPORTED_FEATURES: MediaPlayerEntityFeature.PAUSE + | MediaPlayerEntityFeature.NEXT_TRACK + | MediaPlayerEntityFeature.VOLUME_SET + } + + # House layout + # Floor 1 (ground): + # - Kitchen + # - Smart speaker + # - Living room + # - TV + # - Smart speaker + # Floor 2 (upstairs): + # - Bedroom + # - TV + # - Smart speaker + # - Bathroom + # - Smart speaker + + # Floor 1 + floor_1 = floor_registry.async_create("first floor", aliases={"ground"}) + area_kitchen = area_registry.async_get_or_create("kitchen") + area_kitchen = area_registry.async_update( + area_kitchen.id, floor_id=floor_1.floor_id + ) + area_living_room = area_registry.async_get_or_create("living room") + area_living_room = area_registry.async_update( + area_living_room.id, floor_id=floor_1.floor_id + ) + + kitchen_smart_speaker = entity_registry.async_get_or_create( + "media_player", "test", "kitchen_smart_speaker" + ) + kitchen_smart_speaker = entity_registry.async_update_entity( + kitchen_smart_speaker.entity_id, name="smart speaker", area_id=area_kitchen.id + ) + hass.states.async_set( + kitchen_smart_speaker.entity_id, STATE_PAUSED, attributes=attributes + ) + + living_room_smart_speaker = entity_registry.async_get_or_create( + "media_player", "test", "living_room_smart_speaker" + ) + living_room_smart_speaker = entity_registry.async_update_entity( + living_room_smart_speaker.entity_id, + name="smart speaker", + area_id=area_living_room.id, + ) + hass.states.async_set( + living_room_smart_speaker.entity_id, STATE_PAUSED, attributes=attributes + ) + + living_room_tv = entity_registry.async_get_or_create( + "media_player", "test", "living_room_tv" + ) + living_room_tv = entity_registry.async_update_entity( + living_room_tv.entity_id, name="TV", area_id=area_living_room.id + ) + hass.states.async_set( + living_room_tv.entity_id, STATE_PLAYING, attributes=attributes + ) + + # Floor 2 + floor_2 = floor_registry.async_create("second floor", aliases={"upstairs"}) + area_bedroom = area_registry.async_get_or_create("bedroom") + area_bedroom = area_registry.async_update( + area_bedroom.id, floor_id=floor_2.floor_id + ) + area_bathroom = area_registry.async_get_or_create("bathroom") + area_bathroom = area_registry.async_update( + area_bathroom.id, floor_id=floor_2.floor_id + ) + + bedroom_tv = entity_registry.async_get_or_create( + "media_player", "test", "bedroom_tv" + ) + bedroom_tv = entity_registry.async_update_entity( + bedroom_tv.entity_id, name="TV", area_id=area_bedroom.id + ) + hass.states.async_set(bedroom_tv.entity_id, STATE_PLAYING, attributes=attributes) + + bedroom_smart_speaker = entity_registry.async_get_or_create( + "media_player", "test", "bedroom_smart_speaker" + ) + bedroom_smart_speaker = entity_registry.async_update_entity( + bedroom_smart_speaker.entity_id, name="smart speaker", area_id=area_bedroom.id + ) + hass.states.async_set( + bedroom_smart_speaker.entity_id, STATE_PAUSED, attributes=attributes + ) + + bathroom_smart_speaker = entity_registry.async_get_or_create( + "media_player", "test", "bathroom_smart_speaker" + ) + bathroom_smart_speaker = entity_registry.async_update_entity( + bathroom_smart_speaker.entity_id, name="smart speaker", area_id=area_bathroom.id + ) + hass.states.async_set( + bathroom_smart_speaker.entity_id, STATE_PAUSED, attributes=attributes + ) + + # ----- + + # There are multiple TV's currently playing + with pytest.raises(intent.MatchFailedError): + response = await intent.async_handle( + hass, + "test", + media_player_intent.INTENT_MEDIA_PAUSE, + {"name": {"value": "TV"}}, + ) + await hass.async_block_till_done() + + # Pause the upstairs TV + calls = async_mock_service(hass, DOMAIN, SERVICE_MEDIA_PAUSE) + response = await intent.async_handle( + hass, + "test", + media_player_intent.INTENT_MEDIA_PAUSE, + {"name": {"value": "TV"}, "floor": {"value": "upstairs"}}, + ) + await hass.async_block_till_done() + assert response.response_type == intent.IntentResponseType.ACTION_DONE + assert len(calls) == 1 + assert calls[0].data == {"entity_id": bedroom_tv.entity_id} + hass.states.async_set(bedroom_tv.entity_id, STATE_PAUSED, attributes=attributes) + + # Now we can pause the only playing TV (living room) + calls = async_mock_service(hass, DOMAIN, SERVICE_MEDIA_PAUSE) + response = await intent.async_handle( + hass, + "test", + media_player_intent.INTENT_MEDIA_PAUSE, + {"name": {"value": "TV"}}, + ) + + await hass.async_block_till_done() + assert response.response_type == intent.IntentResponseType.ACTION_DONE + assert len(calls) == 1 + assert calls[0].data == {"entity_id": living_room_tv.entity_id} + hass.states.async_set(living_room_tv.entity_id, STATE_PAUSED, attributes=attributes) + + # Unpause the kitchen smart speaker (explicit area) + calls = async_mock_service(hass, DOMAIN, SERVICE_MEDIA_PLAY) + response = await intent.async_handle( + hass, + "test", + media_player_intent.INTENT_MEDIA_UNPAUSE, + {"name": {"value": "smart speaker"}, "area": {"value": "kitchen"}}, + ) + await hass.async_block_till_done() + assert response.response_type == intent.IntentResponseType.ACTION_DONE + assert len(calls) == 1 + assert calls[0].data == {"entity_id": kitchen_smart_speaker.entity_id} + hass.states.async_set( + kitchen_smart_speaker.entity_id, STATE_PLAYING, attributes=attributes + ) + + # Unpause living room smart speaker (context area) + calls = async_mock_service(hass, DOMAIN, SERVICE_MEDIA_PLAY) + response = await intent.async_handle( + hass, + "test", + media_player_intent.INTENT_MEDIA_UNPAUSE, + { + "name": {"value": "smart speaker"}, + "preferred_area_id": {"value": area_living_room.id}, + }, + ) + await hass.async_block_till_done() + assert response.response_type == intent.IntentResponseType.ACTION_DONE + assert len(calls) == 1 + assert calls[0].data == {"entity_id": living_room_smart_speaker.entity_id} + hass.states.async_set( + living_room_smart_speaker.entity_id, STATE_PLAYING, attributes=attributes + ) + + # Unpause all of the upstairs media players + calls = async_mock_service(hass, DOMAIN, SERVICE_MEDIA_PLAY) + response = await intent.async_handle( + hass, + "test", + media_player_intent.INTENT_MEDIA_UNPAUSE, + {"floor": {"value": "upstairs"}}, + ) + await hass.async_block_till_done() + assert response.response_type == intent.IntentResponseType.ACTION_DONE + assert len(calls) == 3 + assert {call.data["entity_id"] for call in calls} == { + bedroom_tv.entity_id, + bedroom_smart_speaker.entity_id, + bathroom_smart_speaker.entity_id, + } + for entity in (bedroom_tv, bedroom_smart_speaker, bathroom_smart_speaker): + hass.states.async_set(entity.entity_id, STATE_PLAYING, attributes=attributes) + + # Pause bedroom TV (context floor) + calls = async_mock_service(hass, DOMAIN, SERVICE_MEDIA_PAUSE) + response = await intent.async_handle( + hass, + "test", + media_player_intent.INTENT_MEDIA_PAUSE, + { + "name": {"value": "TV"}, + "preferred_floor_id": {"value": floor_2.floor_id}, + }, + ) + await hass.async_block_till_done() + assert response.response_type == intent.IntentResponseType.ACTION_DONE + assert len(calls) == 1 + assert calls[0].data == {"entity_id": bedroom_tv.entity_id} + hass.states.async_set(bedroom_tv.entity_id, STATE_PAUSED, attributes=attributes) + + # Set volume in the bathroom + calls = async_mock_service(hass, DOMAIN, SERVICE_VOLUME_SET) + response = await intent.async_handle( + hass, + "test", + media_player_intent.INTENT_SET_VOLUME, + {"area": {"value": "bathroom"}, "volume_level": {"value": 50}}, + ) + await hass.async_block_till_done() + assert response.response_type == intent.IntentResponseType.ACTION_DONE + assert len(calls) == 1 + assert calls[0].data == { + "entity_id": bathroom_smart_speaker.entity_id, + "volume_level": 0.5, + } + + # Next track in the kitchen (only media player that is playing on ground floor) + hass.states.async_set( + living_room_smart_speaker.entity_id, STATE_PAUSED, attributes=attributes + ) + + calls = async_mock_service(hass, DOMAIN, SERVICE_MEDIA_NEXT_TRACK) + response = await intent.async_handle( + hass, + "test", + media_player_intent.INTENT_MEDIA_NEXT, + {"floor": {"value": "ground"}}, + ) + await hass.async_block_till_done() + assert response.response_type == intent.IntentResponseType.ACTION_DONE + assert len(calls) == 1 + assert calls[0].data == {"entity_id": kitchen_smart_speaker.entity_id} + + # Pause the kitchen smart speaker (all ground floor media players are now paused) + calls = async_mock_service(hass, DOMAIN, SERVICE_MEDIA_PAUSE) + response = await intent.async_handle( + hass, + "test", + media_player_intent.INTENT_MEDIA_PAUSE, + {"area": {"value": "kitchen"}}, + ) + await hass.async_block_till_done() + assert response.response_type == intent.IntentResponseType.ACTION_DONE + assert len(calls) == 1 + assert calls[0].data == {"entity_id": kitchen_smart_speaker.entity_id} + + hass.states.async_set( + kitchen_smart_speaker.entity_id, STATE_PAUSED, attributes=attributes + ) + + # Unpause with no context (only kitchen should be resumed) + calls = async_mock_service(hass, DOMAIN, SERVICE_MEDIA_PLAY) + response = await intent.async_handle( + hass, + "test", + media_player_intent.INTENT_MEDIA_UNPAUSE, + ) + await hass.async_block_till_done() + assert response.response_type == intent.IntentResponseType.ACTION_DONE + assert len(calls) == 1 + assert calls[0].data == {"entity_id": kitchen_smart_speaker.entity_id} + + hass.states.async_set( + kitchen_smart_speaker.entity_id, STATE_PLAYING, attributes=attributes + ) diff --git a/tests/helpers/test_intent.py b/tests/helpers/test_intent.py index d77eb698205..5e54277b423 100644 --- a/tests/helpers/test_intent.py +++ b/tests/helpers/test_intent.py @@ -6,9 +6,13 @@ from unittest.mock import MagicMock, patch import pytest import voluptuous as vol -from homeassistant.components import conversation -from homeassistant.components.switch import SwitchDeviceClass -from homeassistant.const import ATTR_FRIENDLY_NAME +from homeassistant.components import conversation, light, switch +from homeassistant.components.homeassistant.exposed_entities import async_expose_entity +from homeassistant.const import ( + ATTR_DEVICE_CLASS, + ATTR_FRIENDLY_NAME, + ATTR_SUPPORTED_FEATURES, +) from homeassistant.core import Context, HomeAssistant, State from homeassistant.helpers import ( area_registry as ar, @@ -20,13 +24,13 @@ from homeassistant.helpers import ( ) from homeassistant.setup import async_setup_component -from tests.common import MockConfigEntry +from tests.common import MockConfigEntry, async_mock_service class MockIntentHandler(intent.IntentHandler): """Provide a mock intent handler.""" - def __init__(self, slot_schema): + def __init__(self, slot_schema) -> None: """Initialize the mock handler.""" self.slot_schema = slot_schema @@ -73,7 +77,7 @@ async def test_async_match_states( entity_registry.async_update_entity( state2.entity_id, area_id=area_bedroom.id, - device_class=SwitchDeviceClass.OUTLET, + device_class=switch.SwitchDeviceClass.OUTLET, aliases={"kill switch"}, ) @@ -126,7 +130,7 @@ async def test_async_match_states( assert list( intent.async_match_states( hass, - device_classes={SwitchDeviceClass.OUTLET}, + device_classes={switch.SwitchDeviceClass.OUTLET}, area_name="bedroom", states=[state1, state2], ) @@ -162,6 +166,346 @@ async def test_async_match_states( ) +async def test_async_match_targets( + hass: HomeAssistant, + area_registry: ar.AreaRegistry, + entity_registry: er.EntityRegistry, + floor_registry: fr.FloorRegistry, + device_registry: dr.DeviceRegistry, +) -> None: + """Tests for async_match_targets function.""" + # Needed for exposure + assert await async_setup_component(hass, "homeassistant", {}) + + # House layout + # Floor 1 (ground): + # - Kitchen + # - Outlet + # - Bathroom + # - Light + # Floor 2 (upstairs) + # - Bedroom + # - Switch + # - Bathroom + # - Light + # Floor 3 (also upstairs) + # - Bedroom + # - Switch + # - Bathroom + # - Light + + # Floor 1 + floor_1 = floor_registry.async_create("first floor", aliases={"ground"}) + area_kitchen = area_registry.async_get_or_create("kitchen") + area_kitchen = area_registry.async_update( + area_kitchen.id, floor_id=floor_1.floor_id + ) + area_bathroom_1 = area_registry.async_get_or_create("first floor bathroom") + area_bathroom_1 = area_registry.async_update( + area_bathroom_1.id, aliases={"bathroom"}, floor_id=floor_1.floor_id + ) + + kitchen_outlet = entity_registry.async_get_or_create( + "switch", "test", "kitchen_outlet" + ) + kitchen_outlet = entity_registry.async_update_entity( + kitchen_outlet.entity_id, + name="kitchen outlet", + device_class=switch.SwitchDeviceClass.OUTLET, + area_id=area_kitchen.id, + ) + state_kitchen_outlet = State(kitchen_outlet.entity_id, "on") + + bathroom_light_1 = entity_registry.async_get_or_create( + "light", "test", "bathroom_light_1" + ) + bathroom_light_1 = entity_registry.async_update_entity( + bathroom_light_1.entity_id, + name="bathroom light", + aliases={"overhead light"}, + area_id=area_bathroom_1.id, + ) + state_bathroom_light_1 = State(bathroom_light_1.entity_id, "off") + + # Floor 2 + floor_2 = floor_registry.async_create("second floor", aliases={"upstairs"}) + area_bedroom_2 = area_registry.async_get_or_create("bedroom") + area_bedroom_2 = area_registry.async_update( + area_bedroom_2.id, floor_id=floor_2.floor_id + ) + area_bathroom_2 = area_registry.async_get_or_create("second floor bathroom") + area_bathroom_2 = area_registry.async_update( + area_bathroom_2.id, aliases={"bathroom"}, floor_id=floor_2.floor_id + ) + + bedroom_switch_2 = entity_registry.async_get_or_create( + "switch", "test", "bedroom_switch_2" + ) + bedroom_switch_2 = entity_registry.async_update_entity( + bedroom_switch_2.entity_id, + name="second floor bedroom switch", + area_id=area_bedroom_2.id, + ) + state_bedroom_switch_2 = State( + bedroom_switch_2.entity_id, + "off", + ) + + bathroom_light_2 = entity_registry.async_get_or_create( + "light", "test", "bathroom_light_2" + ) + bathroom_light_2 = entity_registry.async_update_entity( + bathroom_light_2.entity_id, + aliases={"bathroom light", "overhead light"}, + area_id=area_bathroom_2.id, + supported_features=light.LightEntityFeature.EFFECT, + ) + state_bathroom_light_2 = State(bathroom_light_2.entity_id, "off") + + # Floor 3 + floor_3 = floor_registry.async_create("third floor", aliases={"upstairs"}) + area_bedroom_3 = area_registry.async_get_or_create("bedroom") + area_bedroom_3 = area_registry.async_update( + area_bedroom_3.id, floor_id=floor_3.floor_id + ) + area_bathroom_3 = area_registry.async_get_or_create("third floor bathroom") + area_bathroom_3 = area_registry.async_update( + area_bathroom_3.id, aliases={"bathroom"}, floor_id=floor_3.floor_id + ) + + bedroom_switch_3 = entity_registry.async_get_or_create( + "switch", "test", "bedroom_switch_3" + ) + bedroom_switch_3 = entity_registry.async_update_entity( + bedroom_switch_3.entity_id, + name="third floor bedroom switch", + area_id=area_bedroom_3.id, + ) + state_bedroom_switch_3 = State( + bedroom_switch_3.entity_id, + "off", + attributes={ATTR_DEVICE_CLASS: switch.SwitchDeviceClass.OUTLET}, + ) + + bathroom_light_3 = entity_registry.async_get_or_create( + "light", "test", "bathroom_light_3" + ) + bathroom_light_3 = entity_registry.async_update_entity( + bathroom_light_3.entity_id, + name="overhead light", + area_id=area_bathroom_3.id, + ) + state_bathroom_light_3 = State( + bathroom_light_3.entity_id, + "on", + attributes={ + ATTR_FRIENDLY_NAME: "bathroom light", + ATTR_SUPPORTED_FEATURES: light.LightEntityFeature.EFFECT, + }, + ) + + # ----- + bathroom_light_states = [ + state_bathroom_light_1, + state_bathroom_light_2, + state_bathroom_light_3, + ] + states = [ + *bathroom_light_states, + state_kitchen_outlet, + state_bedroom_switch_2, + state_bedroom_switch_3, + ] + + # Not a unique name + result = intent.async_match_targets( + hass, + intent.MatchTargetsConstraints(name="bathroom light"), + states=states, + ) + assert not result.is_match + assert result.no_match_reason == intent.MatchFailedReason.DUPLICATE_NAME + assert result.no_match_name == "bathroom light" + + # Works with duplicate names allowed + result = intent.async_match_targets( + hass, + intent.MatchTargetsConstraints( + name="bathroom light", allow_duplicate_names=True + ), + states=states, + ) + assert result.is_match + assert {s.entity_id for s in result.states} == { + s.entity_id for s in bathroom_light_states + } + + # Also works when name is not a constraint + result = intent.async_match_targets( + hass, + intent.MatchTargetsConstraints(domains={"light"}), + states=states, + ) + assert result.is_match + assert {s.entity_id for s in result.states} == { + s.entity_id for s in bathroom_light_states + } + + # We can disambiguate by preferred floor (from context) + result = intent.async_match_targets( + hass, + intent.MatchTargetsConstraints(name="bathroom light"), + intent.MatchTargetsPreferences(floor_id=floor_3.floor_id), + states=states, + ) + assert result.is_match + assert len(result.states) == 1 + assert result.states[0].entity_id == bathroom_light_3.entity_id + + # Also disambiguate by preferred area (from context) + result = intent.async_match_targets( + hass, + intent.MatchTargetsConstraints(name="bathroom light"), + intent.MatchTargetsPreferences(area_id=area_bathroom_2.id), + states=states, + ) + assert result.is_match + assert len(result.states) == 1 + assert result.states[0].entity_id == bathroom_light_2.entity_id + + # Disambiguate by floor name, if unique + result = intent.async_match_targets( + hass, + intent.MatchTargetsConstraints(name="bathroom light", floor_name="ground"), + states=states, + ) + assert result.is_match + assert len(result.states) == 1 + assert result.states[0].entity_id == bathroom_light_1.entity_id + + # Doesn't work if floor name/alias is not unique + result = intent.async_match_targets( + hass, + intent.MatchTargetsConstraints(name="bathroom light", floor_name="upstairs"), + states=states, + ) + assert not result.is_match + assert result.no_match_reason == intent.MatchFailedReason.DUPLICATE_NAME + + # Disambiguate by area name, if unique + result = intent.async_match_targets( + hass, + intent.MatchTargetsConstraints( + name="bathroom light", area_name="first floor bathroom" + ), + states=states, + ) + assert result.is_match + assert len(result.states) == 1 + assert result.states[0].entity_id == bathroom_light_1.entity_id + + # Doesn't work if area name/alias is not unique + result = intent.async_match_targets( + hass, + intent.MatchTargetsConstraints(name="bathroom light", area_name="bathroom"), + states=states, + ) + assert not result.is_match + assert result.no_match_reason == intent.MatchFailedReason.DUPLICATE_NAME + + # Does work if floor/area name combo is unique + result = intent.async_match_targets( + hass, + intent.MatchTargetsConstraints( + name="bathroom light", area_name="bathroom", floor_name="ground" + ), + states=states, + ) + assert result.is_match + assert len(result.states) == 1 + assert result.states[0].entity_id == bathroom_light_1.entity_id + + # Doesn't work if area is not part of the floor + result = intent.async_match_targets( + hass, + intent.MatchTargetsConstraints( + name="bathroom light", + area_name="second floor bathroom", + floor_name="ground", + ), + states=states, + ) + assert not result.is_match + assert result.no_match_reason == intent.MatchFailedReason.AREA + + # Check state constraint (only third floor bathroom light is on) + result = intent.async_match_targets( + hass, + intent.MatchTargetsConstraints(domains={"light"}, states={"on"}), + states=states, + ) + assert result.is_match + assert len(result.states) == 1 + assert result.states[0].entity_id == bathroom_light_3.entity_id + + result = intent.async_match_targets( + hass, + intent.MatchTargetsConstraints( + domains={"light"}, states={"on"}, floor_name="ground" + ), + states=states, + ) + assert not result.is_match + + # Check assistant constraint (exposure) + result = intent.async_match_targets( + hass, + intent.MatchTargetsConstraints(assistant="test"), + states=states, + ) + assert not result.is_match + + async_expose_entity(hass, "test", bathroom_light_1.entity_id, True) + result = intent.async_match_targets( + hass, + intent.MatchTargetsConstraints(assistant="test"), + states=states, + ) + assert result.is_match + assert len(result.states) == 1 + assert result.states[0].entity_id == bathroom_light_1.entity_id + + # Check device class constraint + result = intent.async_match_targets( + hass, + intent.MatchTargetsConstraints( + domains={"switch"}, device_classes={switch.SwitchDeviceClass.OUTLET} + ), + states=states, + ) + assert result.is_match + assert len(result.states) == 2 + assert {s.entity_id for s in result.states} == { + kitchen_outlet.entity_id, + bedroom_switch_3.entity_id, + } + + # Check features constraint (second and third floor bathroom lights have effects) + result = intent.async_match_targets( + hass, + intent.MatchTargetsConstraints( + domains={"light"}, features=light.LightEntityFeature.EFFECT + ), + states=states, + ) + assert result.is_match + assert len(result.states) == 2 + assert {s.entity_id for s in result.states} == { + bathroom_light_2.entity_id, + bathroom_light_3.entity_id, + } + + async def test_match_device_area( hass: HomeAssistant, area_registry: ar.AreaRegistry, @@ -353,24 +697,72 @@ async def test_validate_then_run_in_background(hass: HomeAssistant) -> None: async def test_invalid_area_floor_names(hass: HomeAssistant) -> None: - """Test that we throw an intent handle error with invalid area/floor names.""" + """Test that we throw an appropriate errors with invalid area/floor names.""" handler = intent.ServiceIntentHandler( "TestType", "light", "turn_on", "Turned {} on" ) intent.async_register(hass, handler) - with pytest.raises(intent.IntentHandleError): + with pytest.raises(intent.MatchFailedError) as err: await intent.async_handle( hass, "test", "TestType", slots={"area": {"value": "invalid area"}}, ) + assert err.value.result.no_match_reason == intent.MatchFailedReason.INVALID_AREA - with pytest.raises(intent.IntentHandleError): + with pytest.raises(intent.MatchFailedError) as err: await intent.async_handle( hass, "test", "TestType", slots={"floor": {"value": "invalid floor"}}, ) + assert ( + err.value.result.no_match_reason == intent.MatchFailedReason.INVALID_FLOOR + ) + + +async def test_service_intent_handler_required_domains(hass: HomeAssistant) -> None: + """Test that required_domains restricts the domain of a ServiceIntentHandler.""" + hass.states.async_set("light.kitchen", "off") + hass.states.async_set("switch.bedroom", "off") + + calls = async_mock_service(hass, "homeassistant", "turn_on") + handler = intent.ServiceIntentHandler( + "TestType", + "homeassistant", + "turn_on", + "Turned {} on", + required_domains={"light"}, + ) + intent.async_register(hass, handler) + + # Should work fine + result = await intent.async_handle( + hass, + "test", + "TestType", + slots={"name": {"value": "kitchen"}, "domain": {"value": "light"}}, + ) + assert result.response_type == intent.IntentResponseType.ACTION_DONE + assert len(calls) == 1 + + # Fails because the intent handler is restricted to lights only + with pytest.raises(intent.MatchFailedError): + await intent.async_handle( + hass, + "test", + "TestType", + slots={"name": {"value": "bedroom"}}, + ) + + # Still fails even if we provide the domain + with pytest.raises(intent.MatchFailedError): + await intent.async_handle( + hass, + "test", + "TestType", + slots={"name": {"value": "bedroom"}, "domain": {"value": "switch"}}, + )