Move group helpers into their own module (#106924)

This gets rid of the legacy need to use bind_hass, and
the expand function no longer looses typing.
This commit is contained in:
J. Nick Koston 2024-01-04 06:34:56 -10:00 committed by GitHub
parent 6a02cadc13
commit 0695bf8988
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 181 additions and 61 deletions

View file

@ -11,6 +11,7 @@ from __future__ import annotations
import logging
from homeassistant.core import HomeAssistant, split_entity_id
from homeassistant.helpers.group import expand_entity_ids
_LOGGER = logging.getLogger(__name__)
@ -21,7 +22,7 @@ def is_on(hass: HomeAssistant, entity_id: str | None = None) -> bool:
If there is no entity id given we will check all.
"""
if entity_id:
entity_ids = hass.components.group.expand_entity_ids([entity_id])
entity_ids = expand_entity_ids(hass, [entity_id])
else:
entity_ids = hass.states.entity_ids()

View file

@ -3,10 +3,10 @@ from __future__ import annotations
from abc import abstractmethod
import asyncio
from collections.abc import Callable, Collection, Iterable, Mapping
from collections.abc import Callable, Collection, Mapping
from contextvars import ContextVar
import logging
from typing import Any, Protocol, cast
from typing import Any, Protocol
import voluptuous as vol
@ -19,8 +19,6 @@ from homeassistant.const import (
CONF_ENTITIES,
CONF_ICON,
CONF_NAME,
ENTITY_MATCH_ALL,
ENTITY_MATCH_NONE,
SERVICE_RELOAD,
STATE_OFF,
STATE_ON,
@ -41,6 +39,10 @@ from homeassistant.helpers.event import (
EventStateChangedData,
async_track_state_change_event,
)
from homeassistant.helpers.group import (
expand_entity_ids as _expand_entity_ids,
get_entity_ids as _get_entity_ids,
)
from homeassistant.helpers.integration_platform import (
async_process_integration_platforms,
)
@ -167,58 +169,9 @@ def is_on(hass: HomeAssistant, entity_id: str) -> bool:
return False
@bind_hass
def expand_entity_ids(hass: HomeAssistant, entity_ids: Iterable[Any]) -> list[str]:
"""Return entity_ids with group entity ids replaced by their members.
Async friendly.
"""
found_ids: list[str] = []
for entity_id in entity_ids:
if not isinstance(entity_id, str) or entity_id in (
ENTITY_MATCH_NONE,
ENTITY_MATCH_ALL,
):
continue
entity_id = entity_id.lower()
# If entity_id points at a group, expand it
if entity_id.startswith(ENTITY_PREFIX):
child_entities = get_entity_ids(hass, entity_id)
if entity_id in child_entities:
child_entities = list(child_entities)
child_entities.remove(entity_id)
found_ids.extend(
ent_id
for ent_id in expand_entity_ids(hass, child_entities)
if ent_id not in found_ids
)
elif entity_id not in found_ids:
found_ids.append(entity_id)
return found_ids
@bind_hass
def get_entity_ids(
hass: HomeAssistant, entity_id: str, domain_filter: str | None = None
) -> list[str]:
"""Get members of this group.
Async friendly.
"""
group = hass.states.get(entity_id)
if not group or ATTR_ENTITY_ID not in group.attributes:
return []
entity_ids = group.attributes[ATTR_ENTITY_ID]
if not domain_filter:
return cast(list[str], entity_ids)
domain_filter = f"{domain_filter.lower()}."
return [ent_id for ent_id in entity_ids if ent_id.startswith(domain_filter)]
# expand_entity_ids and get_entity_ids are for backwards compatibility only
expand_entity_ids = bind_hass(_expand_entity_ids)
get_entity_ids = bind_hass(_get_entity_ids)
@bind_hass

View file

@ -25,7 +25,6 @@ from zwave_js_server.model.value import (
get_value_id_str,
)
from homeassistant.components.group import expand_entity_ids
from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import (
@ -39,6 +38,7 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.device_registry import DeviceInfo
from homeassistant.helpers.group import expand_entity_ids
from homeassistant.helpers.typing import ConfigType
from .const import (

View file

@ -25,13 +25,13 @@ from zwave_js_server.util.node import (
async_set_config_parameter,
)
from homeassistant.components.group import expand_entity_ids
from homeassistant.const import ATTR_AREA_ID, ATTR_DEVICE_ID, ATTR_ENTITY_ID
from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, entity_registry as er
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.group import expand_entity_ids
from . import const
from .config_validation import BITMASK_SCHEMA, VALUE_SCHEMA

View file

@ -0,0 +1,58 @@
"""Helper for groups."""
from __future__ import annotations
from collections.abc import Iterable
from typing import Any
from homeassistant.const import ATTR_ENTITY_ID, ENTITY_MATCH_ALL, ENTITY_MATCH_NONE
from homeassistant.core import HomeAssistant
ENTITY_PREFIX = "group."
def expand_entity_ids(hass: HomeAssistant, entity_ids: Iterable[Any]) -> list[str]:
"""Return entity_ids with group entity ids replaced by their members.
Async friendly.
"""
found_ids: list[str] = []
for entity_id in entity_ids:
if not isinstance(entity_id, str) or entity_id in (
ENTITY_MATCH_NONE,
ENTITY_MATCH_ALL,
):
continue
entity_id = entity_id.lower()
# If entity_id points at a group, expand it
if entity_id.startswith(ENTITY_PREFIX):
child_entities = get_entity_ids(hass, entity_id)
if entity_id in child_entities:
child_entities = list(child_entities)
child_entities.remove(entity_id)
found_ids.extend(
ent_id
for ent_id in expand_entity_ids(hass, child_entities)
if ent_id not in found_ids
)
elif entity_id not in found_ids:
found_ids.append(entity_id)
return found_ids
def get_entity_ids(
hass: HomeAssistant, entity_id: str, domain_filter: str | None = None
) -> list[str]:
"""Get members of this group.
Async friendly.
"""
group = hass.states.get(entity_id)
if not group or ATTR_ENTITY_ID not in group.attributes:
return []
entity_ids: list[str] = group.attributes[ATTR_ENTITY_ID]
if not domain_filter:
return entity_ids
domain_filter = f"{domain_filter.lower()}."
return [ent_id for ent_id in entity_ids if ent_id.startswith(domain_filter)]

View file

@ -53,6 +53,7 @@ from . import (
template,
translation,
)
from .group import expand_entity_ids
from .selector import TargetSelector
from .typing import ConfigType, TemplateVarsType
@ -459,9 +460,9 @@ def async_extract_referenced_entity_ids(
if not selector.has_any_selector:
return selected
entity_ids = selector.entity_ids
entity_ids: set[str] | list[str] = selector.entity_ids
if expand_group:
entity_ids = hass.components.group.expand_entity_ids(entity_ids)
entity_ids = expand_entity_ids(hass, entity_ids)
selected.referenced.update(entity_ids)

107
tests/helpers/test_group.py Normal file
View file

@ -0,0 +1,107 @@
"""Test the group helper."""
from homeassistant.const import ATTR_ENTITY_ID, STATE_OFF, STATE_ON
from homeassistant.core import HomeAssistant
from homeassistant.helpers import group
async def test_expand_entity_ids(hass: HomeAssistant) -> None:
"""Test expand_entity_ids method."""
hass.states.async_set("light.Bowl", STATE_ON)
hass.states.async_set("light.Ceiling", STATE_OFF)
hass.states.async_set(
"group.init_group", STATE_ON, {ATTR_ENTITY_ID: ["light.bowl", "light.ceiling"]}
)
state = hass.states.get("group.init_group")
assert state is not None
assert state.attributes[ATTR_ENTITY_ID] == ["light.bowl", "light.ceiling"]
assert sorted(group.expand_entity_ids(hass, ["group.init_group"])) == [
"light.bowl",
"light.ceiling",
]
assert sorted(group.expand_entity_ids(hass, ["group.INIT_group"])) == [
"light.bowl",
"light.ceiling",
]
async def test_expand_entity_ids_does_not_return_duplicates(
hass: HomeAssistant,
) -> None:
"""Test that expand_entity_ids does not return duplicates."""
hass.states.async_set("light.Bowl", STATE_ON)
hass.states.async_set("light.Ceiling", STATE_OFF)
hass.states.async_set(
"group.init_group", STATE_ON, {ATTR_ENTITY_ID: ["light.bowl", "light.ceiling"]}
)
assert sorted(
group.expand_entity_ids(hass, ["group.init_group", "light.Ceiling"])
) == ["light.bowl", "light.ceiling"]
assert sorted(
group.expand_entity_ids(hass, ["light.bowl", "group.init_group"])
) == ["light.bowl", "light.ceiling"]
async def test_expand_entity_ids_recursive(hass: HomeAssistant) -> None:
"""Test expand_entity_ids method with a group that contains itself."""
hass.states.async_set("light.Bowl", STATE_ON)
hass.states.async_set("light.Ceiling", STATE_OFF)
hass.states.async_set(
"group.init_group", STATE_ON, {ATTR_ENTITY_ID: ["light.bowl", "light.ceiling"]}
)
hass.states.async_set(
"group.rec_group",
STATE_ON,
{ATTR_ENTITY_ID: ["group.init_group", "light.ceiling"]},
)
assert sorted(group.expand_entity_ids(hass, ["group.rec_group"])) == [
"light.bowl",
"light.ceiling",
]
async def test_expand_entity_ids_ignores_non_strings(hass: HomeAssistant) -> None:
"""Test that non string elements in lists are ignored."""
assert group.expand_entity_ids(hass, [5, True]) == []
async def test_get_entity_ids(hass: HomeAssistant) -> None:
"""Test get_entity_ids method."""
hass.states.async_set("light.Bowl", STATE_ON)
hass.states.async_set("light.Ceiling", STATE_OFF)
hass.states.async_set(
"group.init_group", STATE_ON, {ATTR_ENTITY_ID: ["light.bowl", "light.ceiling"]}
)
assert sorted(group.get_entity_ids(hass, "group.init_group")) == [
"light.bowl",
"light.ceiling",
]
async def test_get_entity_ids_with_domain_filter(hass: HomeAssistant) -> None:
"""Test if get_entity_ids works with a domain_filter."""
hass.states.async_set("switch.AC", STATE_OFF)
hass.states.async_set(
"group.mixed_group", STATE_ON, {ATTR_ENTITY_ID: ["light.bowl", "switch.ac"]}
)
assert group.get_entity_ids(hass, "group.mixed_group", domain_filter="switch") == [
"switch.ac"
]
async def test_get_entity_ids_with_non_existing_group_name(hass: HomeAssistant) -> None:
"""Test get_entity_ids with a non existing group."""
assert group.get_entity_ids(hass, "non_existing") == []
async def test_get_entity_ids_with_non_group_state(hass: HomeAssistant) -> None:
"""Test get_entity_ids with a non group state."""
assert group.get_entity_ids(hass, "switch.AC") == []