diff --git a/homeassistant/components/__init__.py b/homeassistant/components/__init__.py index 690b38b4871..839a66af25d 100644 --- a/homeassistant/components/__init__.py +++ b/homeassistant/components/__init__.py @@ -11,6 +11,7 @@ from __future__ import annotations import logging from homeassistant.core import HomeAssistant, split_entity_id +from homeassistant.helpers.group import expand_entity_ids _LOGGER = logging.getLogger(__name__) @@ -21,7 +22,7 @@ def is_on(hass: HomeAssistant, entity_id: str | None = None) -> bool: If there is no entity id given we will check all. """ if entity_id: - entity_ids = hass.components.group.expand_entity_ids([entity_id]) + entity_ids = expand_entity_ids(hass, [entity_id]) else: entity_ids = hass.states.entity_ids() diff --git a/homeassistant/components/group/__init__.py b/homeassistant/components/group/__init__.py index a2a61b3016a..894a20629ee 100644 --- a/homeassistant/components/group/__init__.py +++ b/homeassistant/components/group/__init__.py @@ -3,10 +3,10 @@ from __future__ import annotations from abc import abstractmethod import asyncio -from collections.abc import Callable, Collection, Iterable, Mapping +from collections.abc import Callable, Collection, Mapping from contextvars import ContextVar import logging -from typing import Any, Protocol, cast +from typing import Any, Protocol import voluptuous as vol @@ -19,8 +19,6 @@ from homeassistant.const import ( CONF_ENTITIES, CONF_ICON, CONF_NAME, - ENTITY_MATCH_ALL, - ENTITY_MATCH_NONE, SERVICE_RELOAD, STATE_OFF, STATE_ON, @@ -41,6 +39,10 @@ from homeassistant.helpers.event import ( EventStateChangedData, async_track_state_change_event, ) +from homeassistant.helpers.group import ( + expand_entity_ids as _expand_entity_ids, + get_entity_ids as _get_entity_ids, +) from homeassistant.helpers.integration_platform import ( async_process_integration_platforms, ) @@ -167,58 +169,9 @@ def is_on(hass: HomeAssistant, entity_id: str) -> bool: return False -@bind_hass -def expand_entity_ids(hass: HomeAssistant, entity_ids: Iterable[Any]) -> list[str]: - """Return entity_ids with group entity ids replaced by their members. - - Async friendly. - """ - found_ids: list[str] = [] - for entity_id in entity_ids: - if not isinstance(entity_id, str) or entity_id in ( - ENTITY_MATCH_NONE, - ENTITY_MATCH_ALL, - ): - continue - - entity_id = entity_id.lower() - # If entity_id points at a group, expand it - if entity_id.startswith(ENTITY_PREFIX): - child_entities = get_entity_ids(hass, entity_id) - if entity_id in child_entities: - child_entities = list(child_entities) - child_entities.remove(entity_id) - found_ids.extend( - ent_id - for ent_id in expand_entity_ids(hass, child_entities) - if ent_id not in found_ids - ) - elif entity_id not in found_ids: - found_ids.append(entity_id) - - return found_ids - - -@bind_hass -def get_entity_ids( - hass: HomeAssistant, entity_id: str, domain_filter: str | None = None -) -> list[str]: - """Get members of this group. - - Async friendly. - """ - group = hass.states.get(entity_id) - - if not group or ATTR_ENTITY_ID not in group.attributes: - return [] - - entity_ids = group.attributes[ATTR_ENTITY_ID] - if not domain_filter: - return cast(list[str], entity_ids) - - domain_filter = f"{domain_filter.lower()}." - - return [ent_id for ent_id in entity_ids if ent_id.startswith(domain_filter)] +# expand_entity_ids and get_entity_ids are for backwards compatibility only +expand_entity_ids = bind_hass(_expand_entity_ids) +get_entity_ids = bind_hass(_get_entity_ids) @bind_hass diff --git a/homeassistant/components/zwave_js/helpers.py b/homeassistant/components/zwave_js/helpers.py index a211832039b..c8eb02ad6cb 100644 --- a/homeassistant/components/zwave_js/helpers.py +++ b/homeassistant/components/zwave_js/helpers.py @@ -25,7 +25,6 @@ from zwave_js_server.model.value import ( get_value_id_str, ) -from homeassistant.components.group import expand_entity_ids from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.const import ( @@ -39,6 +38,7 @@ from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers.device_registry import DeviceInfo +from homeassistant.helpers.group import expand_entity_ids from homeassistant.helpers.typing import ConfigType from .const import ( diff --git a/homeassistant/components/zwave_js/services.py b/homeassistant/components/zwave_js/services.py index 9b4f9827c1d..e8ef1df4b96 100644 --- a/homeassistant/components/zwave_js/services.py +++ b/homeassistant/components/zwave_js/services.py @@ -25,13 +25,13 @@ from zwave_js_server.util.node import ( async_set_config_parameter, ) -from homeassistant.components.group import expand_entity_ids from homeassistant.const import ATTR_AREA_ID, ATTR_DEVICE_ID, ATTR_ENTITY_ID from homeassistant.core import HomeAssistant, ServiceCall, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import device_registry as dr, entity_registry as er import homeassistant.helpers.config_validation as cv from homeassistant.helpers.dispatcher import async_dispatcher_send +from homeassistant.helpers.group import expand_entity_ids from . import const from .config_validation import BITMASK_SCHEMA, VALUE_SCHEMA diff --git a/homeassistant/helpers/group.py b/homeassistant/helpers/group.py new file mode 100644 index 00000000000..437df226118 --- /dev/null +++ b/homeassistant/helpers/group.py @@ -0,0 +1,58 @@ +"""Helper for groups.""" +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any + +from homeassistant.const import ATTR_ENTITY_ID, ENTITY_MATCH_ALL, ENTITY_MATCH_NONE +from homeassistant.core import HomeAssistant + +ENTITY_PREFIX = "group." + + +def expand_entity_ids(hass: HomeAssistant, entity_ids: Iterable[Any]) -> list[str]: + """Return entity_ids with group entity ids replaced by their members. + + Async friendly. + """ + found_ids: list[str] = [] + for entity_id in entity_ids: + if not isinstance(entity_id, str) or entity_id in ( + ENTITY_MATCH_NONE, + ENTITY_MATCH_ALL, + ): + continue + + entity_id = entity_id.lower() + # If entity_id points at a group, expand it + if entity_id.startswith(ENTITY_PREFIX): + child_entities = get_entity_ids(hass, entity_id) + if entity_id in child_entities: + child_entities = list(child_entities) + child_entities.remove(entity_id) + found_ids.extend( + ent_id + for ent_id in expand_entity_ids(hass, child_entities) + if ent_id not in found_ids + ) + elif entity_id not in found_ids: + found_ids.append(entity_id) + + return found_ids + + +def get_entity_ids( + hass: HomeAssistant, entity_id: str, domain_filter: str | None = None +) -> list[str]: + """Get members of this group. + + Async friendly. + """ + group = hass.states.get(entity_id) + if not group or ATTR_ENTITY_ID not in group.attributes: + return [] + entity_ids: list[str] = group.attributes[ATTR_ENTITY_ID] + if not domain_filter: + return entity_ids + domain_filter = f"{domain_filter.lower()}." + return [ent_id for ent_id in entity_ids if ent_id.startswith(domain_filter)] diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 59fd061d8c9..4813a54ac8b 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -53,6 +53,7 @@ from . import ( template, translation, ) +from .group import expand_entity_ids from .selector import TargetSelector from .typing import ConfigType, TemplateVarsType @@ -459,9 +460,9 @@ def async_extract_referenced_entity_ids( if not selector.has_any_selector: return selected - entity_ids = selector.entity_ids + entity_ids: set[str] | list[str] = selector.entity_ids if expand_group: - entity_ids = hass.components.group.expand_entity_ids(entity_ids) + entity_ids = expand_entity_ids(hass, entity_ids) selected.referenced.update(entity_ids) diff --git a/tests/helpers/test_group.py b/tests/helpers/test_group.py new file mode 100644 index 00000000000..b1300009607 --- /dev/null +++ b/tests/helpers/test_group.py @@ -0,0 +1,107 @@ +"""Test the group helper.""" + + +from homeassistant.const import ATTR_ENTITY_ID, STATE_OFF, STATE_ON +from homeassistant.core import HomeAssistant +from homeassistant.helpers import group + + +async def test_expand_entity_ids(hass: HomeAssistant) -> None: + """Test expand_entity_ids method.""" + hass.states.async_set("light.Bowl", STATE_ON) + hass.states.async_set("light.Ceiling", STATE_OFF) + hass.states.async_set( + "group.init_group", STATE_ON, {ATTR_ENTITY_ID: ["light.bowl", "light.ceiling"]} + ) + state = hass.states.get("group.init_group") + assert state is not None + assert state.attributes[ATTR_ENTITY_ID] == ["light.bowl", "light.ceiling"] + + assert sorted(group.expand_entity_ids(hass, ["group.init_group"])) == [ + "light.bowl", + "light.ceiling", + ] + assert sorted(group.expand_entity_ids(hass, ["group.INIT_group"])) == [ + "light.bowl", + "light.ceiling", + ] + + +async def test_expand_entity_ids_does_not_return_duplicates( + hass: HomeAssistant, +) -> None: + """Test that expand_entity_ids does not return duplicates.""" + hass.states.async_set("light.Bowl", STATE_ON) + hass.states.async_set("light.Ceiling", STATE_OFF) + hass.states.async_set( + "group.init_group", STATE_ON, {ATTR_ENTITY_ID: ["light.bowl", "light.ceiling"]} + ) + + assert sorted( + group.expand_entity_ids(hass, ["group.init_group", "light.Ceiling"]) + ) == ["light.bowl", "light.ceiling"] + + assert sorted( + group.expand_entity_ids(hass, ["light.bowl", "group.init_group"]) + ) == ["light.bowl", "light.ceiling"] + + +async def test_expand_entity_ids_recursive(hass: HomeAssistant) -> None: + """Test expand_entity_ids method with a group that contains itself.""" + hass.states.async_set("light.Bowl", STATE_ON) + hass.states.async_set("light.Ceiling", STATE_OFF) + hass.states.async_set( + "group.init_group", STATE_ON, {ATTR_ENTITY_ID: ["light.bowl", "light.ceiling"]} + ) + + hass.states.async_set( + "group.rec_group", + STATE_ON, + {ATTR_ENTITY_ID: ["group.init_group", "light.ceiling"]}, + ) + + assert sorted(group.expand_entity_ids(hass, ["group.rec_group"])) == [ + "light.bowl", + "light.ceiling", + ] + + +async def test_expand_entity_ids_ignores_non_strings(hass: HomeAssistant) -> None: + """Test that non string elements in lists are ignored.""" + assert group.expand_entity_ids(hass, [5, True]) == [] + + +async def test_get_entity_ids(hass: HomeAssistant) -> None: + """Test get_entity_ids method.""" + hass.states.async_set("light.Bowl", STATE_ON) + hass.states.async_set("light.Ceiling", STATE_OFF) + hass.states.async_set( + "group.init_group", STATE_ON, {ATTR_ENTITY_ID: ["light.bowl", "light.ceiling"]} + ) + + assert sorted(group.get_entity_ids(hass, "group.init_group")) == [ + "light.bowl", + "light.ceiling", + ] + + +async def test_get_entity_ids_with_domain_filter(hass: HomeAssistant) -> None: + """Test if get_entity_ids works with a domain_filter.""" + hass.states.async_set("switch.AC", STATE_OFF) + hass.states.async_set( + "group.mixed_group", STATE_ON, {ATTR_ENTITY_ID: ["light.bowl", "switch.ac"]} + ) + + assert group.get_entity_ids(hass, "group.mixed_group", domain_filter="switch") == [ + "switch.ac" + ] + + +async def test_get_entity_ids_with_non_existing_group_name(hass: HomeAssistant) -> None: + """Test get_entity_ids with a non existing group.""" + assert group.get_entity_ids(hass, "non_existing") == [] + + +async def test_get_entity_ids_with_non_group_state(hass: HomeAssistant) -> None: + """Test get_entity_ids with a non group state.""" + assert group.get_entity_ids(hass, "switch.AC") == []