Deduplicate entities derived from GroupEntity (#98893)

This commit is contained in:
Erik Montnemery 2023-08-23 19:20:58 +02:00 committed by GitHub
parent ee1b6a60a0
commit 3c10d0e1f7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 78 additions and 249 deletions

View file

@ -3,7 +3,7 @@ from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod
import asyncio import asyncio
from collections.abc import Collection, Iterable from collections.abc import Callable, Collection, Iterable, Mapping
from contextvars import ContextVar from contextvars import ContextVar
import logging import logging
from typing import Any, Protocol, cast from typing import Any, Protocol, cast
@ -473,9 +473,60 @@ class GroupEntity(Entity):
"""Representation of a Group of entities.""" """Representation of a Group of entities."""
_attr_should_poll = False _attr_should_poll = False
_entity_ids: list[str]
@callback
def async_start_preview(
self,
preview_callback: Callable[[str, Mapping[str, Any]], None],
) -> CALLBACK_TYPE:
"""Render a preview."""
for entity_id in self._entity_ids:
if (state := self.hass.states.get(entity_id)) is None:
continue
self.async_update_supported_features(entity_id, state)
@callback
def async_state_changed_listener(
event: EventType[EventStateChangedData] | None,
) -> None:
"""Handle child updates."""
self.async_update_group_state()
if event:
self.async_update_supported_features(
event.data["entity_id"], event.data["new_state"]
)
preview_callback(*self._async_generate_attributes())
async_state_changed_listener(None)
return async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Register listeners.""" """Register listeners."""
for entity_id in self._entity_ids:
if (state := self.hass.states.get(entity_id)) is None:
continue
self.async_update_supported_features(entity_id, state)
@callback
def async_state_changed_listener(
event: EventType[EventStateChangedData],
) -> None:
"""Handle child updates."""
self.async_set_context(event.context)
self.async_update_supported_features(
event.data["entity_id"], event.data["new_state"]
)
self.async_defer_or_update_ha_state()
self.async_on_remove(
async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
)
async def _update_at_start(_: HomeAssistant) -> None: async def _update_at_start(_: HomeAssistant) -> None:
self.async_update_group_state() self.async_update_group_state()
@ -493,9 +544,18 @@ class GroupEntity(Entity):
self.async_write_ha_state() self.async_write_ha_state()
@abstractmethod @abstractmethod
@callback
def async_update_group_state(self) -> None: def async_update_group_state(self) -> None:
"""Abstract method to update the entity.""" """Abstract method to update the entity."""
@callback
def async_update_supported_features(
self,
entity_id: str,
new_state: State | None,
) -> None:
"""Update dictionaries with supported features."""
class Group(Entity): class Group(Entity):
"""Track a group of entity ids.""" """Track a group of entity ids."""

View file

@ -1,9 +1,6 @@
"""Platform allowing several binary sensor to be grouped into one binary sensor.""" """Platform allowing several binary sensor to be grouped into one binary sensor."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Mapping
from typing import Any
import voluptuous as vol import voluptuous as vol
from homeassistant.components.binary_sensor import ( from homeassistant.components.binary_sensor import (
@ -24,14 +21,10 @@ from homeassistant.const import (
STATE_UNAVAILABLE, STATE_UNAVAILABLE,
STATE_UNKNOWN, STATE_UNKNOWN,
) )
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, entity_registry as er from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import ( from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
EventStateChangedData,
async_track_state_change_event,
)
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, EventType
from . import GroupEntity from . import GroupEntity
@ -116,45 +109,6 @@ class BinarySensorGroup(GroupEntity, BinarySensorEntity):
if mode: if mode:
self.mode = all self.mode = all
@callback
def async_start_preview(
self,
preview_callback: Callable[[str, Mapping[str, Any]], None],
) -> CALLBACK_TYPE:
"""Render a preview."""
@callback
def async_state_changed_listener(
event: EventType[EventStateChangedData] | None,
) -> None:
"""Handle child updates."""
self.async_update_group_state()
preview_callback(*self._async_generate_attributes())
async_state_changed_listener(None)
return async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
@callback
def async_state_changed_listener(
event: EventType[EventStateChangedData],
) -> None:
"""Handle child updates."""
self.async_set_context(event.context)
self.async_defer_or_update_ha_state()
self.async_on_remove(
async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
)
await super().async_added_to_hass()
@callback @callback
def async_update_group_state(self) -> None: def async_update_group_state(self) -> None:
"""Query all members and determine the binary sensor group state.""" """Query all members and determine the binary sensor group state."""

View file

@ -41,11 +41,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, State, callback from homeassistant.core import HomeAssistant, State, callback
from homeassistant.helpers import config_validation as cv, entity_registry as er from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import ( from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
EventStateChangedData,
async_track_state_change_event,
)
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, EventType
from . import GroupEntity from . import GroupEntity
from .util import attribute_equal, reduce_attribute from .util import attribute_equal, reduce_attribute
@ -112,7 +108,7 @@ class CoverGroup(GroupEntity, CoverEntity):
def __init__(self, unique_id: str | None, name: str, entities: list[str]) -> None: def __init__(self, unique_id: str | None, name: str, entities: list[str]) -> None:
"""Initialize a CoverGroup entity.""" """Initialize a CoverGroup entity."""
self._entities = entities self._entity_ids = entities
self._covers: dict[str, set[str]] = { self._covers: dict[str, set[str]] = {
KEY_OPEN_CLOSE: set(), KEY_OPEN_CLOSE: set(),
KEY_STOP: set(), KEY_STOP: set(),
@ -128,21 +124,11 @@ class CoverGroup(GroupEntity, CoverEntity):
self._attr_extra_state_attributes = {ATTR_ENTITY_ID: entities} self._attr_extra_state_attributes = {ATTR_ENTITY_ID: entities}
self._attr_unique_id = unique_id self._attr_unique_id = unique_id
@callback
def _update_supported_features_event(
self, event: EventType[EventStateChangedData]
) -> None:
self.async_set_context(event.context)
self.async_update_supported_features(
event.data["entity_id"], event.data["new_state"]
)
@callback @callback
def async_update_supported_features( def async_update_supported_features(
self, self,
entity_id: str, entity_id: str,
new_state: State | None, new_state: State | None,
update_state: bool = True,
) -> None: ) -> None:
"""Update dictionaries with supported features.""" """Update dictionaries with supported features."""
if not new_state: if not new_state:
@ -150,8 +136,6 @@ class CoverGroup(GroupEntity, CoverEntity):
values.discard(entity_id) values.discard(entity_id)
for values in self._tilts.values(): for values in self._tilts.values():
values.discard(entity_id) values.discard(entity_id)
if update_state:
self.async_defer_or_update_ha_state()
return return
features = new_state.attributes.get(ATTR_SUPPORTED_FEATURES, 0) features = new_state.attributes.get(ATTR_SUPPORTED_FEATURES, 0)
@ -182,25 +166,6 @@ class CoverGroup(GroupEntity, CoverEntity):
else: else:
self._tilts[KEY_POSITION].discard(entity_id) self._tilts[KEY_POSITION].discard(entity_id)
if update_state:
self.async_defer_or_update_ha_state()
async def async_added_to_hass(self) -> None:
"""Register listeners."""
for entity_id in self._entities:
if (new_state := self.hass.states.get(entity_id)) is None:
continue
self.async_update_supported_features(
entity_id, new_state, update_state=False
)
self.async_on_remove(
async_track_state_change_event(
self.hass, self._entities, self._update_supported_features_event
)
)
await super().async_added_to_hass()
async def async_open_cover(self, **kwargs: Any) -> None: async def async_open_cover(self, **kwargs: Any) -> None:
"""Move the covers up.""" """Move the covers up."""
data = {ATTR_ENTITY_ID: self._covers[KEY_OPEN_CLOSE]} data = {ATTR_ENTITY_ID: self._covers[KEY_OPEN_CLOSE]}
@ -278,7 +243,7 @@ class CoverGroup(GroupEntity, CoverEntity):
states = [ states = [
state.state state.state
for entity_id in self._entities for entity_id in self._entity_ids
if (state := self.hass.states.get(entity_id)) is not None if (state := self.hass.states.get(entity_id)) is not None
] ]
@ -292,7 +257,7 @@ class CoverGroup(GroupEntity, CoverEntity):
self._attr_is_closed = True self._attr_is_closed = True
self._attr_is_closing = False self._attr_is_closing = False
self._attr_is_opening = False self._attr_is_opening = False
for entity_id in self._entities: for entity_id in self._entity_ids:
if not (state := self.hass.states.get(entity_id)): if not (state := self.hass.states.get(entity_id)):
continue continue
if state.state == STATE_OPEN: if state.state == STATE_OPEN:
@ -347,7 +312,7 @@ class CoverGroup(GroupEntity, CoverEntity):
self._attr_supported_features = supported_features self._attr_supported_features = supported_features
if not self._attr_assumed_state: if not self._attr_assumed_state:
for entity_id in self._entities: for entity_id in self._entity_ids:
if (state := self.hass.states.get(entity_id)) is None: if (state := self.hass.states.get(entity_id)) is None:
continue continue
if state and state.attributes.get(ATTR_ASSUMED_STATE): if state and state.attributes.get(ATTR_ASSUMED_STATE):

View file

@ -38,11 +38,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, State, callback from homeassistant.core import HomeAssistant, State, callback
from homeassistant.helpers import config_validation as cv, entity_registry as er from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import ( from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
EventStateChangedData,
async_track_state_change_event,
)
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, EventType
from . import GroupEntity from . import GroupEntity
from .util import ( from .util import (
@ -108,7 +104,7 @@ class FanGroup(GroupEntity, FanEntity):
def __init__(self, unique_id: str | None, name: str, entities: list[str]) -> None: def __init__(self, unique_id: str | None, name: str, entities: list[str]) -> None:
"""Initialize a FanGroup entity.""" """Initialize a FanGroup entity."""
self._entities = entities self._entity_ids = entities
self._fans: dict[int, set[str]] = {flag: set() for flag in SUPPORTED_FLAGS} self._fans: dict[int, set[str]] = {flag: set() for flag in SUPPORTED_FLAGS}
self._percentage = None self._percentage = None
self._oscillating = None self._oscillating = None
@ -144,21 +140,11 @@ class FanGroup(GroupEntity, FanEntity):
"""Return whether or not the fan is currently oscillating.""" """Return whether or not the fan is currently oscillating."""
return self._oscillating return self._oscillating
@callback
def _update_supported_features_event(
self, event: EventType[EventStateChangedData]
) -> None:
self.async_set_context(event.context)
self.async_update_supported_features(
event.data["entity_id"], event.data["new_state"]
)
@callback @callback
def async_update_supported_features( def async_update_supported_features(
self, self,
entity_id: str, entity_id: str,
new_state: State | None, new_state: State | None,
update_state: bool = True,
) -> None: ) -> None:
"""Update dictionaries with supported features.""" """Update dictionaries with supported features."""
if not new_state: if not new_state:
@ -172,25 +158,6 @@ class FanGroup(GroupEntity, FanEntity):
else: else:
self._fans[feature].discard(entity_id) self._fans[feature].discard(entity_id)
if update_state:
self.async_defer_or_update_ha_state()
async def async_added_to_hass(self) -> None:
"""Register listeners."""
for entity_id in self._entities:
if (new_state := self.hass.states.get(entity_id)) is None:
continue
self.async_update_supported_features(
entity_id, new_state, update_state=False
)
self.async_on_remove(
async_track_state_change_event(
self.hass, self._entities, self._update_supported_features_event
)
)
await super().async_added_to_hass()
async def async_set_percentage(self, percentage: int) -> None: async def async_set_percentage(self, percentage: int) -> None:
"""Set the speed of the fan, as a percentage.""" """Set the speed of the fan, as a percentage."""
if percentage == 0: if percentage == 0:
@ -250,7 +217,7 @@ class FanGroup(GroupEntity, FanEntity):
await self.hass.services.async_call( await self.hass.services.async_call(
DOMAIN, DOMAIN,
service, service,
{ATTR_ENTITY_ID: self._entities}, {ATTR_ENTITY_ID: self._entity_ids},
blocking=True, blocking=True,
context=self._context, context=self._context,
) )
@ -275,7 +242,7 @@ class FanGroup(GroupEntity, FanEntity):
states = [ states = [
state state
for entity_id in self._entities for entity_id in self._entity_ids
if (state := self.hass.states.get(entity_id)) is not None if (state := self.hass.states.get(entity_id)) is not None
] ]
self._attr_assumed_state |= not states_equal(states) self._attr_assumed_state |= not states_equal(states)

View file

@ -47,11 +47,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, entity_registry as er from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import ( from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
EventStateChangedData,
async_track_state_change_event,
)
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, EventType
from . import GroupEntity from . import GroupEntity
from .util import find_state_attributes, mean_tuple, reduce_attribute from .util import find_state_attributes, mean_tuple, reduce_attribute
@ -153,25 +149,6 @@ class LightGroup(GroupEntity, LightEntity):
if mode: if mode:
self.mode = all self.mode = all
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
@callback
def async_state_changed_listener(
event: EventType[EventStateChangedData],
) -> None:
"""Handle child updates."""
self.async_set_context(event.context)
self.async_defer_or_update_ha_state()
self.async_on_remove(
async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
)
await super().async_added_to_hass()
async def async_turn_on(self, **kwargs: Any) -> None: async def async_turn_on(self, **kwargs: Any) -> None:
"""Forward the turn_on command to all lights in the light group.""" """Forward the turn_on command to all lights in the light group."""
data = { data = {

View file

@ -31,11 +31,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, entity_registry as er from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import ( from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
EventStateChangedData,
async_track_state_change_event,
)
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, EventType
from . import GroupEntity from . import GroupEntity
@ -114,25 +110,6 @@ class LockGroup(GroupEntity, LockEntity):
self._attr_extra_state_attributes = {ATTR_ENTITY_ID: entity_ids} self._attr_extra_state_attributes = {ATTR_ENTITY_ID: entity_ids}
self._attr_unique_id = unique_id self._attr_unique_id = unique_id
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
@callback
def async_state_changed_listener(
event: EventType[EventStateChangedData],
) -> None:
"""Handle child updates."""
self.async_set_context(event.context)
self.async_defer_or_update_ha_state()
self.async_on_remove(
async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
)
await super().async_added_to_hass()
async def async_lock(self, **kwargs: Any) -> None: async def async_lock(self, **kwargs: Any) -> None:
"""Forward the lock command to all locks in the group.""" """Forward the lock command to all locks in the group."""
data = {ATTR_ENTITY_ID: self._entity_ids} data = {ATTR_ENTITY_ID: self._entity_ids}

View file

@ -1,7 +1,7 @@
"""Platform allowing several sensors to be grouped into one sensor to provide numeric combinations.""" """Platform allowing several sensors to be grouped into one sensor to provide numeric combinations."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Mapping from collections.abc import Callable
from datetime import datetime from datetime import datetime
import logging import logging
import statistics import statistics
@ -33,19 +33,10 @@ from homeassistant.const import (
STATE_UNAVAILABLE, STATE_UNAVAILABLE,
STATE_UNKNOWN, STATE_UNKNOWN,
) )
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, State, callback from homeassistant.core import HomeAssistant, State, callback
from homeassistant.helpers import config_validation as cv, entity_registry as er from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import ( from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, StateType
EventStateChangedData,
async_track_state_change_event,
)
from homeassistant.helpers.typing import (
ConfigType,
DiscoveryInfoType,
EventType,
StateType,
)
from . import GroupEntity from . import GroupEntity
from .const import CONF_IGNORE_NON_NUMERIC from .const import CONF_IGNORE_NON_NUMERIC
@ -303,45 +294,6 @@ class SensorGroup(GroupEntity, SensorEntity):
self._state_incorrect: set[str] = set() self._state_incorrect: set[str] = set()
self._extra_state_attribute: dict[str, Any] = {} self._extra_state_attribute: dict[str, Any] = {}
@callback
def async_start_preview(
self,
preview_callback: Callable[[str, Mapping[str, Any]], None],
) -> CALLBACK_TYPE:
"""Render a preview."""
@callback
def async_state_changed_listener(
event: EventType[EventStateChangedData] | None,
) -> None:
"""Handle child updates."""
self.async_update_group_state()
preview_callback(*self._async_generate_attributes())
async_state_changed_listener(None)
return async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
@callback
def async_state_changed_listener(
event: EventType[EventStateChangedData],
) -> None:
"""Handle child updates."""
self.async_set_context(event.context)
self.async_defer_or_update_ha_state()
self.async_on_remove(
async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
)
await super().async_added_to_hass()
@callback @callback
def async_update_group_state(self) -> None: def async_update_group_state(self) -> None:
"""Query all members and determine the sensor group state.""" """Query all members and determine the sensor group state."""

View file

@ -22,11 +22,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, entity_registry as er from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import ( from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
EventStateChangedData,
async_track_state_change_event,
)
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, EventType
from . import GroupEntity from . import GroupEntity
@ -112,25 +108,6 @@ class SwitchGroup(GroupEntity, SwitchEntity):
if mode: if mode:
self.mode = all self.mode = all
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
@callback
def async_state_changed_listener(
event: EventType[EventStateChangedData],
) -> None:
"""Handle child updates."""
self.async_set_context(event.context)
self.async_defer_or_update_ha_state()
self.async_on_remove(
async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
)
await super().async_added_to_hass()
async def async_turn_on(self, **kwargs: Any) -> None: async def async_turn_on(self, **kwargs: Any) -> None:
"""Forward the turn_on command to all switches in the group.""" """Forward the turn_on command to all switches in the group."""
data = {ATTR_ENTITY_ID: self._entity_ids} data = {ATTR_ENTITY_ID: self._entity_ids}