Move group helpers into their own module (#106924)
This gets rid of the legacy need to use bind_hass, and the expand function no longer looses typing.
This commit is contained in:
parent
6a02cadc13
commit
0695bf8988
7 changed files with 181 additions and 61 deletions
|
@ -11,6 +11,7 @@ from __future__ import annotations
|
|||
import logging
|
||||
|
||||
from homeassistant.core import HomeAssistant, split_entity_id
|
||||
from homeassistant.helpers.group import expand_entity_ids
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -21,7 +22,7 @@ def is_on(hass: HomeAssistant, entity_id: str | None = None) -> bool:
|
|||
If there is no entity id given we will check all.
|
||||
"""
|
||||
if entity_id:
|
||||
entity_ids = hass.components.group.expand_entity_ids([entity_id])
|
||||
entity_ids = expand_entity_ids(hass, [entity_id])
|
||||
else:
|
||||
entity_ids = hass.states.entity_ids()
|
||||
|
||||
|
|
|
@ -3,10 +3,10 @@ from __future__ import annotations
|
|||
|
||||
from abc import abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import Callable, Collection, Iterable, Mapping
|
||||
from collections.abc import Callable, Collection, Mapping
|
||||
from contextvars import ContextVar
|
||||
import logging
|
||||
from typing import Any, Protocol, cast
|
||||
from typing import Any, Protocol
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -19,8 +19,6 @@ from homeassistant.const import (
|
|||
CONF_ENTITIES,
|
||||
CONF_ICON,
|
||||
CONF_NAME,
|
||||
ENTITY_MATCH_ALL,
|
||||
ENTITY_MATCH_NONE,
|
||||
SERVICE_RELOAD,
|
||||
STATE_OFF,
|
||||
STATE_ON,
|
||||
|
@ -41,6 +39,10 @@ from homeassistant.helpers.event import (
|
|||
EventStateChangedData,
|
||||
async_track_state_change_event,
|
||||
)
|
||||
from homeassistant.helpers.group import (
|
||||
expand_entity_ids as _expand_entity_ids,
|
||||
get_entity_ids as _get_entity_ids,
|
||||
)
|
||||
from homeassistant.helpers.integration_platform import (
|
||||
async_process_integration_platforms,
|
||||
)
|
||||
|
@ -167,58 +169,9 @@ def is_on(hass: HomeAssistant, entity_id: str) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
@bind_hass
|
||||
def expand_entity_ids(hass: HomeAssistant, entity_ids: Iterable[Any]) -> list[str]:
|
||||
"""Return entity_ids with group entity ids replaced by their members.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
found_ids: list[str] = []
|
||||
for entity_id in entity_ids:
|
||||
if not isinstance(entity_id, str) or entity_id in (
|
||||
ENTITY_MATCH_NONE,
|
||||
ENTITY_MATCH_ALL,
|
||||
):
|
||||
continue
|
||||
|
||||
entity_id = entity_id.lower()
|
||||
# If entity_id points at a group, expand it
|
||||
if entity_id.startswith(ENTITY_PREFIX):
|
||||
child_entities = get_entity_ids(hass, entity_id)
|
||||
if entity_id in child_entities:
|
||||
child_entities = list(child_entities)
|
||||
child_entities.remove(entity_id)
|
||||
found_ids.extend(
|
||||
ent_id
|
||||
for ent_id in expand_entity_ids(hass, child_entities)
|
||||
if ent_id not in found_ids
|
||||
)
|
||||
elif entity_id not in found_ids:
|
||||
found_ids.append(entity_id)
|
||||
|
||||
return found_ids
|
||||
|
||||
|
||||
@bind_hass
|
||||
def get_entity_ids(
|
||||
hass: HomeAssistant, entity_id: str, domain_filter: str | None = None
|
||||
) -> list[str]:
|
||||
"""Get members of this group.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
group = hass.states.get(entity_id)
|
||||
|
||||
if not group or ATTR_ENTITY_ID not in group.attributes:
|
||||
return []
|
||||
|
||||
entity_ids = group.attributes[ATTR_ENTITY_ID]
|
||||
if not domain_filter:
|
||||
return cast(list[str], entity_ids)
|
||||
|
||||
domain_filter = f"{domain_filter.lower()}."
|
||||
|
||||
return [ent_id for ent_id in entity_ids if ent_id.startswith(domain_filter)]
|
||||
# expand_entity_ids and get_entity_ids are for backwards compatibility only
|
||||
expand_entity_ids = bind_hass(_expand_entity_ids)
|
||||
get_entity_ids = bind_hass(_get_entity_ids)
|
||||
|
||||
|
||||
@bind_hass
|
||||
|
|
|
@ -25,7 +25,6 @@ from zwave_js_server.model.value import (
|
|||
get_value_id_str,
|
||||
)
|
||||
|
||||
from homeassistant.components.group import expand_entity_ids
|
||||
from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
|
||||
from homeassistant.const import (
|
||||
|
@ -39,6 +38,7 @@ from homeassistant.core import HomeAssistant, callback
|
|||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
||||
from homeassistant.helpers.device_registry import DeviceInfo
|
||||
from homeassistant.helpers.group import expand_entity_ids
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .const import (
|
||||
|
|
|
@ -25,13 +25,13 @@ from zwave_js_server.util.node import (
|
|||
async_set_config_parameter,
|
||||
)
|
||||
|
||||
from homeassistant.components.group import expand_entity_ids
|
||||
from homeassistant.const import ATTR_AREA_ID, ATTR_DEVICE_ID, ATTR_ENTITY_ID
|
||||
from homeassistant.core import HomeAssistant, ServiceCall, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||
from homeassistant.helpers.group import expand_entity_ids
|
||||
|
||||
from . import const
|
||||
from .config_validation import BITMASK_SCHEMA, VALUE_SCHEMA
|
||||
|
|
58
homeassistant/helpers/group.py
Normal file
58
homeassistant/helpers/group.py
Normal file
|
@ -0,0 +1,58 @@
|
|||
"""Helper for groups."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.const import ATTR_ENTITY_ID, ENTITY_MATCH_ALL, ENTITY_MATCH_NONE
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
ENTITY_PREFIX = "group."
|
||||
|
||||
|
||||
def expand_entity_ids(hass: HomeAssistant, entity_ids: Iterable[Any]) -> list[str]:
|
||||
"""Return entity_ids with group entity ids replaced by their members.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
found_ids: list[str] = []
|
||||
for entity_id in entity_ids:
|
||||
if not isinstance(entity_id, str) or entity_id in (
|
||||
ENTITY_MATCH_NONE,
|
||||
ENTITY_MATCH_ALL,
|
||||
):
|
||||
continue
|
||||
|
||||
entity_id = entity_id.lower()
|
||||
# If entity_id points at a group, expand it
|
||||
if entity_id.startswith(ENTITY_PREFIX):
|
||||
child_entities = get_entity_ids(hass, entity_id)
|
||||
if entity_id in child_entities:
|
||||
child_entities = list(child_entities)
|
||||
child_entities.remove(entity_id)
|
||||
found_ids.extend(
|
||||
ent_id
|
||||
for ent_id in expand_entity_ids(hass, child_entities)
|
||||
if ent_id not in found_ids
|
||||
)
|
||||
elif entity_id not in found_ids:
|
||||
found_ids.append(entity_id)
|
||||
|
||||
return found_ids
|
||||
|
||||
|
||||
def get_entity_ids(
|
||||
hass: HomeAssistant, entity_id: str, domain_filter: str | None = None
|
||||
) -> list[str]:
|
||||
"""Get members of this group.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
group = hass.states.get(entity_id)
|
||||
if not group or ATTR_ENTITY_ID not in group.attributes:
|
||||
return []
|
||||
entity_ids: list[str] = group.attributes[ATTR_ENTITY_ID]
|
||||
if not domain_filter:
|
||||
return entity_ids
|
||||
domain_filter = f"{domain_filter.lower()}."
|
||||
return [ent_id for ent_id in entity_ids if ent_id.startswith(domain_filter)]
|
|
@ -53,6 +53,7 @@ from . import (
|
|||
template,
|
||||
translation,
|
||||
)
|
||||
from .group import expand_entity_ids
|
||||
from .selector import TargetSelector
|
||||
from .typing import ConfigType, TemplateVarsType
|
||||
|
||||
|
@ -459,9 +460,9 @@ def async_extract_referenced_entity_ids(
|
|||
if not selector.has_any_selector:
|
||||
return selected
|
||||
|
||||
entity_ids = selector.entity_ids
|
||||
entity_ids: set[str] | list[str] = selector.entity_ids
|
||||
if expand_group:
|
||||
entity_ids = hass.components.group.expand_entity_ids(entity_ids)
|
||||
entity_ids = expand_entity_ids(hass, entity_ids)
|
||||
|
||||
selected.referenced.update(entity_ids)
|
||||
|
||||
|
|
107
tests/helpers/test_group.py
Normal file
107
tests/helpers/test_group.py
Normal file
|
@ -0,0 +1,107 @@
|
|||
"""Test the group helper."""
|
||||
|
||||
|
||||
from homeassistant.const import ATTR_ENTITY_ID, STATE_OFF, STATE_ON
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import group
|
||||
|
||||
|
||||
async def test_expand_entity_ids(hass: HomeAssistant) -> None:
|
||||
"""Test expand_entity_ids method."""
|
||||
hass.states.async_set("light.Bowl", STATE_ON)
|
||||
hass.states.async_set("light.Ceiling", STATE_OFF)
|
||||
hass.states.async_set(
|
||||
"group.init_group", STATE_ON, {ATTR_ENTITY_ID: ["light.bowl", "light.ceiling"]}
|
||||
)
|
||||
state = hass.states.get("group.init_group")
|
||||
assert state is not None
|
||||
assert state.attributes[ATTR_ENTITY_ID] == ["light.bowl", "light.ceiling"]
|
||||
|
||||
assert sorted(group.expand_entity_ids(hass, ["group.init_group"])) == [
|
||||
"light.bowl",
|
||||
"light.ceiling",
|
||||
]
|
||||
assert sorted(group.expand_entity_ids(hass, ["group.INIT_group"])) == [
|
||||
"light.bowl",
|
||||
"light.ceiling",
|
||||
]
|
||||
|
||||
|
||||
async def test_expand_entity_ids_does_not_return_duplicates(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test that expand_entity_ids does not return duplicates."""
|
||||
hass.states.async_set("light.Bowl", STATE_ON)
|
||||
hass.states.async_set("light.Ceiling", STATE_OFF)
|
||||
hass.states.async_set(
|
||||
"group.init_group", STATE_ON, {ATTR_ENTITY_ID: ["light.bowl", "light.ceiling"]}
|
||||
)
|
||||
|
||||
assert sorted(
|
||||
group.expand_entity_ids(hass, ["group.init_group", "light.Ceiling"])
|
||||
) == ["light.bowl", "light.ceiling"]
|
||||
|
||||
assert sorted(
|
||||
group.expand_entity_ids(hass, ["light.bowl", "group.init_group"])
|
||||
) == ["light.bowl", "light.ceiling"]
|
||||
|
||||
|
||||
async def test_expand_entity_ids_recursive(hass: HomeAssistant) -> None:
|
||||
"""Test expand_entity_ids method with a group that contains itself."""
|
||||
hass.states.async_set("light.Bowl", STATE_ON)
|
||||
hass.states.async_set("light.Ceiling", STATE_OFF)
|
||||
hass.states.async_set(
|
||||
"group.init_group", STATE_ON, {ATTR_ENTITY_ID: ["light.bowl", "light.ceiling"]}
|
||||
)
|
||||
|
||||
hass.states.async_set(
|
||||
"group.rec_group",
|
||||
STATE_ON,
|
||||
{ATTR_ENTITY_ID: ["group.init_group", "light.ceiling"]},
|
||||
)
|
||||
|
||||
assert sorted(group.expand_entity_ids(hass, ["group.rec_group"])) == [
|
||||
"light.bowl",
|
||||
"light.ceiling",
|
||||
]
|
||||
|
||||
|
||||
async def test_expand_entity_ids_ignores_non_strings(hass: HomeAssistant) -> None:
|
||||
"""Test that non string elements in lists are ignored."""
|
||||
assert group.expand_entity_ids(hass, [5, True]) == []
|
||||
|
||||
|
||||
async def test_get_entity_ids(hass: HomeAssistant) -> None:
|
||||
"""Test get_entity_ids method."""
|
||||
hass.states.async_set("light.Bowl", STATE_ON)
|
||||
hass.states.async_set("light.Ceiling", STATE_OFF)
|
||||
hass.states.async_set(
|
||||
"group.init_group", STATE_ON, {ATTR_ENTITY_ID: ["light.bowl", "light.ceiling"]}
|
||||
)
|
||||
|
||||
assert sorted(group.get_entity_ids(hass, "group.init_group")) == [
|
||||
"light.bowl",
|
||||
"light.ceiling",
|
||||
]
|
||||
|
||||
|
||||
async def test_get_entity_ids_with_domain_filter(hass: HomeAssistant) -> None:
|
||||
"""Test if get_entity_ids works with a domain_filter."""
|
||||
hass.states.async_set("switch.AC", STATE_OFF)
|
||||
hass.states.async_set(
|
||||
"group.mixed_group", STATE_ON, {ATTR_ENTITY_ID: ["light.bowl", "switch.ac"]}
|
||||
)
|
||||
|
||||
assert group.get_entity_ids(hass, "group.mixed_group", domain_filter="switch") == [
|
||||
"switch.ac"
|
||||
]
|
||||
|
||||
|
||||
async def test_get_entity_ids_with_non_existing_group_name(hass: HomeAssistant) -> None:
|
||||
"""Test get_entity_ids with a non existing group."""
|
||||
assert group.get_entity_ids(hass, "non_existing") == []
|
||||
|
||||
|
||||
async def test_get_entity_ids_with_non_group_state(hass: HomeAssistant) -> None:
|
||||
"""Test get_entity_ids with a non group state."""
|
||||
assert group.get_entity_ids(hass, "switch.AC") == []
|
Loading…
Add table
Add a link
Reference in a new issue