Move group helpers into their own module (#106924)
This gets rid of the legacy need to use bind_hass, and the expand function no longer looses typing.
This commit is contained in:
parent
6a02cadc13
commit
0695bf8988
7 changed files with 181 additions and 61 deletions
|
@ -11,6 +11,7 @@ from __future__ import annotations
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant, split_entity_id
|
from homeassistant.core import HomeAssistant, split_entity_id
|
||||||
|
from homeassistant.helpers.group import expand_entity_ids
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -21,7 +22,7 @@ def is_on(hass: HomeAssistant, entity_id: str | None = None) -> bool:
|
||||||
If there is no entity id given we will check all.
|
If there is no entity id given we will check all.
|
||||||
"""
|
"""
|
||||||
if entity_id:
|
if entity_id:
|
||||||
entity_ids = hass.components.group.expand_entity_ids([entity_id])
|
entity_ids = expand_entity_ids(hass, [entity_id])
|
||||||
else:
|
else:
|
||||||
entity_ids = hass.states.entity_ids()
|
entity_ids = hass.states.entity_ids()
|
||||||
|
|
||||||
|
|
|
@ -3,10 +3,10 @@ from __future__ import annotations
|
||||||
|
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Callable, Collection, Iterable, Mapping
|
from collections.abc import Callable, Collection, Mapping
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Protocol, cast
|
from typing import Any, Protocol
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
@ -19,8 +19,6 @@ from homeassistant.const import (
|
||||||
CONF_ENTITIES,
|
CONF_ENTITIES,
|
||||||
CONF_ICON,
|
CONF_ICON,
|
||||||
CONF_NAME,
|
CONF_NAME,
|
||||||
ENTITY_MATCH_ALL,
|
|
||||||
ENTITY_MATCH_NONE,
|
|
||||||
SERVICE_RELOAD,
|
SERVICE_RELOAD,
|
||||||
STATE_OFF,
|
STATE_OFF,
|
||||||
STATE_ON,
|
STATE_ON,
|
||||||
|
@ -41,6 +39,10 @@ from homeassistant.helpers.event import (
|
||||||
EventStateChangedData,
|
EventStateChangedData,
|
||||||
async_track_state_change_event,
|
async_track_state_change_event,
|
||||||
)
|
)
|
||||||
|
from homeassistant.helpers.group import (
|
||||||
|
expand_entity_ids as _expand_entity_ids,
|
||||||
|
get_entity_ids as _get_entity_ids,
|
||||||
|
)
|
||||||
from homeassistant.helpers.integration_platform import (
|
from homeassistant.helpers.integration_platform import (
|
||||||
async_process_integration_platforms,
|
async_process_integration_platforms,
|
||||||
)
|
)
|
||||||
|
@ -167,58 +169,9 @@ def is_on(hass: HomeAssistant, entity_id: str) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@bind_hass
|
# expand_entity_ids and get_entity_ids are for backwards compatibility only
|
||||||
def expand_entity_ids(hass: HomeAssistant, entity_ids: Iterable[Any]) -> list[str]:
|
expand_entity_ids = bind_hass(_expand_entity_ids)
|
||||||
"""Return entity_ids with group entity ids replaced by their members.
|
get_entity_ids = bind_hass(_get_entity_ids)
|
||||||
|
|
||||||
Async friendly.
|
|
||||||
"""
|
|
||||||
found_ids: list[str] = []
|
|
||||||
for entity_id in entity_ids:
|
|
||||||
if not isinstance(entity_id, str) or entity_id in (
|
|
||||||
ENTITY_MATCH_NONE,
|
|
||||||
ENTITY_MATCH_ALL,
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
|
|
||||||
entity_id = entity_id.lower()
|
|
||||||
# If entity_id points at a group, expand it
|
|
||||||
if entity_id.startswith(ENTITY_PREFIX):
|
|
||||||
child_entities = get_entity_ids(hass, entity_id)
|
|
||||||
if entity_id in child_entities:
|
|
||||||
child_entities = list(child_entities)
|
|
||||||
child_entities.remove(entity_id)
|
|
||||||
found_ids.extend(
|
|
||||||
ent_id
|
|
||||||
for ent_id in expand_entity_ids(hass, child_entities)
|
|
||||||
if ent_id not in found_ids
|
|
||||||
)
|
|
||||||
elif entity_id not in found_ids:
|
|
||||||
found_ids.append(entity_id)
|
|
||||||
|
|
||||||
return found_ids
|
|
||||||
|
|
||||||
|
|
||||||
@bind_hass
|
|
||||||
def get_entity_ids(
|
|
||||||
hass: HomeAssistant, entity_id: str, domain_filter: str | None = None
|
|
||||||
) -> list[str]:
|
|
||||||
"""Get members of this group.
|
|
||||||
|
|
||||||
Async friendly.
|
|
||||||
"""
|
|
||||||
group = hass.states.get(entity_id)
|
|
||||||
|
|
||||||
if not group or ATTR_ENTITY_ID not in group.attributes:
|
|
||||||
return []
|
|
||||||
|
|
||||||
entity_ids = group.attributes[ATTR_ENTITY_ID]
|
|
||||||
if not domain_filter:
|
|
||||||
return cast(list[str], entity_ids)
|
|
||||||
|
|
||||||
domain_filter = f"{domain_filter.lower()}."
|
|
||||||
|
|
||||||
return [ent_id for ent_id in entity_ids if ent_id.startswith(domain_filter)]
|
|
||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
|
|
|
@ -25,7 +25,6 @@ from zwave_js_server.model.value import (
|
||||||
get_value_id_str,
|
get_value_id_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
from homeassistant.components.group import expand_entity_ids
|
|
||||||
from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN
|
from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN
|
||||||
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
|
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
|
@ -39,6 +38,7 @@ from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
||||||
from homeassistant.helpers.device_registry import DeviceInfo
|
from homeassistant.helpers.device_registry import DeviceInfo
|
||||||
|
from homeassistant.helpers.group import expand_entity_ids
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
|
|
|
@ -25,13 +25,13 @@ from zwave_js_server.util.node import (
|
||||||
async_set_config_parameter,
|
async_set_config_parameter,
|
||||||
)
|
)
|
||||||
|
|
||||||
from homeassistant.components.group import expand_entity_ids
|
|
||||||
from homeassistant.const import ATTR_AREA_ID, ATTR_DEVICE_ID, ATTR_ENTITY_ID
|
from homeassistant.const import ATTR_AREA_ID, ATTR_DEVICE_ID, ATTR_ENTITY_ID
|
||||||
from homeassistant.core import HomeAssistant, ServiceCall, callback
|
from homeassistant.core import HomeAssistant, ServiceCall, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||||
|
from homeassistant.helpers.group import expand_entity_ids
|
||||||
|
|
||||||
from . import const
|
from . import const
|
||||||
from .config_validation import BITMASK_SCHEMA, VALUE_SCHEMA
|
from .config_validation import BITMASK_SCHEMA, VALUE_SCHEMA
|
||||||
|
|
58
homeassistant/helpers/group.py
Normal file
58
homeassistant/helpers/group.py
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
"""Helper for groups."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from homeassistant.const import ATTR_ENTITY_ID, ENTITY_MATCH_ALL, ENTITY_MATCH_NONE
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
ENTITY_PREFIX = "group."
|
||||||
|
|
||||||
|
|
||||||
|
def expand_entity_ids(hass: HomeAssistant, entity_ids: Iterable[Any]) -> list[str]:
|
||||||
|
"""Return entity_ids with group entity ids replaced by their members.
|
||||||
|
|
||||||
|
Async friendly.
|
||||||
|
"""
|
||||||
|
found_ids: list[str] = []
|
||||||
|
for entity_id in entity_ids:
|
||||||
|
if not isinstance(entity_id, str) or entity_id in (
|
||||||
|
ENTITY_MATCH_NONE,
|
||||||
|
ENTITY_MATCH_ALL,
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
entity_id = entity_id.lower()
|
||||||
|
# If entity_id points at a group, expand it
|
||||||
|
if entity_id.startswith(ENTITY_PREFIX):
|
||||||
|
child_entities = get_entity_ids(hass, entity_id)
|
||||||
|
if entity_id in child_entities:
|
||||||
|
child_entities = list(child_entities)
|
||||||
|
child_entities.remove(entity_id)
|
||||||
|
found_ids.extend(
|
||||||
|
ent_id
|
||||||
|
for ent_id in expand_entity_ids(hass, child_entities)
|
||||||
|
if ent_id not in found_ids
|
||||||
|
)
|
||||||
|
elif entity_id not in found_ids:
|
||||||
|
found_ids.append(entity_id)
|
||||||
|
|
||||||
|
return found_ids
|
||||||
|
|
||||||
|
|
||||||
|
def get_entity_ids(
|
||||||
|
hass: HomeAssistant, entity_id: str, domain_filter: str | None = None
|
||||||
|
) -> list[str]:
|
||||||
|
"""Get members of this group.
|
||||||
|
|
||||||
|
Async friendly.
|
||||||
|
"""
|
||||||
|
group = hass.states.get(entity_id)
|
||||||
|
if not group or ATTR_ENTITY_ID not in group.attributes:
|
||||||
|
return []
|
||||||
|
entity_ids: list[str] = group.attributes[ATTR_ENTITY_ID]
|
||||||
|
if not domain_filter:
|
||||||
|
return entity_ids
|
||||||
|
domain_filter = f"{domain_filter.lower()}."
|
||||||
|
return [ent_id for ent_id in entity_ids if ent_id.startswith(domain_filter)]
|
|
@ -53,6 +53,7 @@ from . import (
|
||||||
template,
|
template,
|
||||||
translation,
|
translation,
|
||||||
)
|
)
|
||||||
|
from .group import expand_entity_ids
|
||||||
from .selector import TargetSelector
|
from .selector import TargetSelector
|
||||||
from .typing import ConfigType, TemplateVarsType
|
from .typing import ConfigType, TemplateVarsType
|
||||||
|
|
||||||
|
@ -459,9 +460,9 @@ def async_extract_referenced_entity_ids(
|
||||||
if not selector.has_any_selector:
|
if not selector.has_any_selector:
|
||||||
return selected
|
return selected
|
||||||
|
|
||||||
entity_ids = selector.entity_ids
|
entity_ids: set[str] | list[str] = selector.entity_ids
|
||||||
if expand_group:
|
if expand_group:
|
||||||
entity_ids = hass.components.group.expand_entity_ids(entity_ids)
|
entity_ids = expand_entity_ids(hass, entity_ids)
|
||||||
|
|
||||||
selected.referenced.update(entity_ids)
|
selected.referenced.update(entity_ids)
|
||||||
|
|
||||||
|
|
107
tests/helpers/test_group.py
Normal file
107
tests/helpers/test_group.py
Normal file
|
@ -0,0 +1,107 @@
|
||||||
|
"""Test the group helper."""
|
||||||
|
|
||||||
|
|
||||||
|
from homeassistant.const import ATTR_ENTITY_ID, STATE_OFF, STATE_ON
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers import group
|
||||||
|
|
||||||
|
|
||||||
|
async def test_expand_entity_ids(hass: HomeAssistant) -> None:
|
||||||
|
"""Test expand_entity_ids method."""
|
||||||
|
hass.states.async_set("light.Bowl", STATE_ON)
|
||||||
|
hass.states.async_set("light.Ceiling", STATE_OFF)
|
||||||
|
hass.states.async_set(
|
||||||
|
"group.init_group", STATE_ON, {ATTR_ENTITY_ID: ["light.bowl", "light.ceiling"]}
|
||||||
|
)
|
||||||
|
state = hass.states.get("group.init_group")
|
||||||
|
assert state is not None
|
||||||
|
assert state.attributes[ATTR_ENTITY_ID] == ["light.bowl", "light.ceiling"]
|
||||||
|
|
||||||
|
assert sorted(group.expand_entity_ids(hass, ["group.init_group"])) == [
|
||||||
|
"light.bowl",
|
||||||
|
"light.ceiling",
|
||||||
|
]
|
||||||
|
assert sorted(group.expand_entity_ids(hass, ["group.INIT_group"])) == [
|
||||||
|
"light.bowl",
|
||||||
|
"light.ceiling",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_expand_entity_ids_does_not_return_duplicates(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> None:
|
||||||
|
"""Test that expand_entity_ids does not return duplicates."""
|
||||||
|
hass.states.async_set("light.Bowl", STATE_ON)
|
||||||
|
hass.states.async_set("light.Ceiling", STATE_OFF)
|
||||||
|
hass.states.async_set(
|
||||||
|
"group.init_group", STATE_ON, {ATTR_ENTITY_ID: ["light.bowl", "light.ceiling"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert sorted(
|
||||||
|
group.expand_entity_ids(hass, ["group.init_group", "light.Ceiling"])
|
||||||
|
) == ["light.bowl", "light.ceiling"]
|
||||||
|
|
||||||
|
assert sorted(
|
||||||
|
group.expand_entity_ids(hass, ["light.bowl", "group.init_group"])
|
||||||
|
) == ["light.bowl", "light.ceiling"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_expand_entity_ids_recursive(hass: HomeAssistant) -> None:
|
||||||
|
"""Test expand_entity_ids method with a group that contains itself."""
|
||||||
|
hass.states.async_set("light.Bowl", STATE_ON)
|
||||||
|
hass.states.async_set("light.Ceiling", STATE_OFF)
|
||||||
|
hass.states.async_set(
|
||||||
|
"group.init_group", STATE_ON, {ATTR_ENTITY_ID: ["light.bowl", "light.ceiling"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
hass.states.async_set(
|
||||||
|
"group.rec_group",
|
||||||
|
STATE_ON,
|
||||||
|
{ATTR_ENTITY_ID: ["group.init_group", "light.ceiling"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert sorted(group.expand_entity_ids(hass, ["group.rec_group"])) == [
|
||||||
|
"light.bowl",
|
||||||
|
"light.ceiling",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_expand_entity_ids_ignores_non_strings(hass: HomeAssistant) -> None:
|
||||||
|
"""Test that non string elements in lists are ignored."""
|
||||||
|
assert group.expand_entity_ids(hass, [5, True]) == []
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_entity_ids(hass: HomeAssistant) -> None:
|
||||||
|
"""Test get_entity_ids method."""
|
||||||
|
hass.states.async_set("light.Bowl", STATE_ON)
|
||||||
|
hass.states.async_set("light.Ceiling", STATE_OFF)
|
||||||
|
hass.states.async_set(
|
||||||
|
"group.init_group", STATE_ON, {ATTR_ENTITY_ID: ["light.bowl", "light.ceiling"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert sorted(group.get_entity_ids(hass, "group.init_group")) == [
|
||||||
|
"light.bowl",
|
||||||
|
"light.ceiling",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_entity_ids_with_domain_filter(hass: HomeAssistant) -> None:
|
||||||
|
"""Test if get_entity_ids works with a domain_filter."""
|
||||||
|
hass.states.async_set("switch.AC", STATE_OFF)
|
||||||
|
hass.states.async_set(
|
||||||
|
"group.mixed_group", STATE_ON, {ATTR_ENTITY_ID: ["light.bowl", "switch.ac"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert group.get_entity_ids(hass, "group.mixed_group", domain_filter="switch") == [
|
||||||
|
"switch.ac"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_entity_ids_with_non_existing_group_name(hass: HomeAssistant) -> None:
|
||||||
|
"""Test get_entity_ids with a non existing group."""
|
||||||
|
assert group.get_entity_ids(hass, "non_existing") == []
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_entity_ids_with_non_group_state(hass: HomeAssistant) -> None:
|
||||||
|
"""Test get_entity_ids with a non group state."""
|
||||||
|
assert group.get_entity_ids(hass, "switch.AC") == []
|
Loading…
Add table
Add a link
Reference in a new issue