Deduplicate entities derived from GroupEntity (#98893)

This commit is contained in:
Erik Montnemery 2023-08-23 19:20:58 +02:00 committed by GitHub
parent ee1b6a60a0
commit 3c10d0e1f7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 78 additions and 249 deletions

View file

@ -3,7 +3,7 @@ from __future__ import annotations
from abc import abstractmethod
import asyncio
from collections.abc import Collection, Iterable
from collections.abc import Callable, Collection, Iterable, Mapping
from contextvars import ContextVar
import logging
from typing import Any, Protocol, cast
@ -473,9 +473,60 @@ class GroupEntity(Entity):
"""Representation of a Group of entities."""
_attr_should_poll = False
_entity_ids: list[str]
@callback
def async_start_preview(
self,
preview_callback: Callable[[str, Mapping[str, Any]], None],
) -> CALLBACK_TYPE:
"""Render a preview."""
for entity_id in self._entity_ids:
if (state := self.hass.states.get(entity_id)) is None:
continue
self.async_update_supported_features(entity_id, state)
@callback
def async_state_changed_listener(
event: EventType[EventStateChangedData] | None,
) -> None:
"""Handle child updates."""
self.async_update_group_state()
if event:
self.async_update_supported_features(
event.data["entity_id"], event.data["new_state"]
)
preview_callback(*self._async_generate_attributes())
async_state_changed_listener(None)
return async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
async def async_added_to_hass(self) -> None:
"""Register listeners."""
for entity_id in self._entity_ids:
if (state := self.hass.states.get(entity_id)) is None:
continue
self.async_update_supported_features(entity_id, state)
@callback
def async_state_changed_listener(
event: EventType[EventStateChangedData],
) -> None:
"""Handle child updates."""
self.async_set_context(event.context)
self.async_update_supported_features(
event.data["entity_id"], event.data["new_state"]
)
self.async_defer_or_update_ha_state()
self.async_on_remove(
async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
)
async def _update_at_start(_: HomeAssistant) -> None:
self.async_update_group_state()
@ -493,9 +544,18 @@ class GroupEntity(Entity):
self.async_write_ha_state()
@abstractmethod
@callback
def async_update_group_state(self) -> None:
"""Abstract method to update the entity."""
@callback
def async_update_supported_features(
self,
entity_id: str,
new_state: State | None,
) -> None:
"""Update dictionaries with supported features."""
class Group(Entity):
"""Track a group of entity ids."""

View file

@ -1,9 +1,6 @@
"""Platform allowing several binary sensor to be grouped into one binary sensor."""
from __future__ import annotations
from collections.abc import Callable, Mapping
from typing import Any
import voluptuous as vol
from homeassistant.components.binary_sensor import (
@ -24,14 +21,10 @@ from homeassistant.const import (
STATE_UNAVAILABLE,
STATE_UNKNOWN,
)
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import (
EventStateChangedData,
async_track_state_change_event,
)
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, EventType
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import GroupEntity
@ -116,45 +109,6 @@ class BinarySensorGroup(GroupEntity, BinarySensorEntity):
if mode:
self.mode = all
@callback
def async_start_preview(
self,
preview_callback: Callable[[str, Mapping[str, Any]], None],
) -> CALLBACK_TYPE:
"""Render a preview."""
@callback
def async_state_changed_listener(
event: EventType[EventStateChangedData] | None,
) -> None:
"""Handle child updates."""
self.async_update_group_state()
preview_callback(*self._async_generate_attributes())
async_state_changed_listener(None)
return async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
@callback
def async_state_changed_listener(
event: EventType[EventStateChangedData],
) -> None:
"""Handle child updates."""
self.async_set_context(event.context)
self.async_defer_or_update_ha_state()
self.async_on_remove(
async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
)
await super().async_added_to_hass()
@callback
def async_update_group_state(self) -> None:
"""Query all members and determine the binary sensor group state."""

View file

@ -41,11 +41,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, State, callback
from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import (
EventStateChangedData,
async_track_state_change_event,
)
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, EventType
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import GroupEntity
from .util import attribute_equal, reduce_attribute
@ -112,7 +108,7 @@ class CoverGroup(GroupEntity, CoverEntity):
def __init__(self, unique_id: str | None, name: str, entities: list[str]) -> None:
"""Initialize a CoverGroup entity."""
self._entities = entities
self._entity_ids = entities
self._covers: dict[str, set[str]] = {
KEY_OPEN_CLOSE: set(),
KEY_STOP: set(),
@ -128,21 +124,11 @@ class CoverGroup(GroupEntity, CoverEntity):
self._attr_extra_state_attributes = {ATTR_ENTITY_ID: entities}
self._attr_unique_id = unique_id
@callback
def _update_supported_features_event(
self, event: EventType[EventStateChangedData]
) -> None:
self.async_set_context(event.context)
self.async_update_supported_features(
event.data["entity_id"], event.data["new_state"]
)
@callback
def async_update_supported_features(
self,
entity_id: str,
new_state: State | None,
update_state: bool = True,
) -> None:
"""Update dictionaries with supported features."""
if not new_state:
@ -150,8 +136,6 @@ class CoverGroup(GroupEntity, CoverEntity):
values.discard(entity_id)
for values in self._tilts.values():
values.discard(entity_id)
if update_state:
self.async_defer_or_update_ha_state()
return
features = new_state.attributes.get(ATTR_SUPPORTED_FEATURES, 0)
@ -182,25 +166,6 @@ class CoverGroup(GroupEntity, CoverEntity):
else:
self._tilts[KEY_POSITION].discard(entity_id)
if update_state:
self.async_defer_or_update_ha_state()
async def async_added_to_hass(self) -> None:
"""Register listeners."""
for entity_id in self._entities:
if (new_state := self.hass.states.get(entity_id)) is None:
continue
self.async_update_supported_features(
entity_id, new_state, update_state=False
)
self.async_on_remove(
async_track_state_change_event(
self.hass, self._entities, self._update_supported_features_event
)
)
await super().async_added_to_hass()
async def async_open_cover(self, **kwargs: Any) -> None:
"""Move the covers up."""
data = {ATTR_ENTITY_ID: self._covers[KEY_OPEN_CLOSE]}
@ -278,7 +243,7 @@ class CoverGroup(GroupEntity, CoverEntity):
states = [
state.state
for entity_id in self._entities
for entity_id in self._entity_ids
if (state := self.hass.states.get(entity_id)) is not None
]
@ -292,7 +257,7 @@ class CoverGroup(GroupEntity, CoverEntity):
self._attr_is_closed = True
self._attr_is_closing = False
self._attr_is_opening = False
for entity_id in self._entities:
for entity_id in self._entity_ids:
if not (state := self.hass.states.get(entity_id)):
continue
if state.state == STATE_OPEN:
@ -347,7 +312,7 @@ class CoverGroup(GroupEntity, CoverEntity):
self._attr_supported_features = supported_features
if not self._attr_assumed_state:
for entity_id in self._entities:
for entity_id in self._entity_ids:
if (state := self.hass.states.get(entity_id)) is None:
continue
if state and state.attributes.get(ATTR_ASSUMED_STATE):

View file

@ -38,11 +38,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, State, callback
from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import (
EventStateChangedData,
async_track_state_change_event,
)
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, EventType
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import GroupEntity
from .util import (
@ -108,7 +104,7 @@ class FanGroup(GroupEntity, FanEntity):
def __init__(self, unique_id: str | None, name: str, entities: list[str]) -> None:
"""Initialize a FanGroup entity."""
self._entities = entities
self._entity_ids = entities
self._fans: dict[int, set[str]] = {flag: set() for flag in SUPPORTED_FLAGS}
self._percentage = None
self._oscillating = None
@ -144,21 +140,11 @@ class FanGroup(GroupEntity, FanEntity):
"""Return whether or not the fan is currently oscillating."""
return self._oscillating
@callback
def _update_supported_features_event(
self, event: EventType[EventStateChangedData]
) -> None:
self.async_set_context(event.context)
self.async_update_supported_features(
event.data["entity_id"], event.data["new_state"]
)
@callback
def async_update_supported_features(
self,
entity_id: str,
new_state: State | None,
update_state: bool = True,
) -> None:
"""Update dictionaries with supported features."""
if not new_state:
@ -172,25 +158,6 @@ class FanGroup(GroupEntity, FanEntity):
else:
self._fans[feature].discard(entity_id)
if update_state:
self.async_defer_or_update_ha_state()
async def async_added_to_hass(self) -> None:
"""Register listeners."""
for entity_id in self._entities:
if (new_state := self.hass.states.get(entity_id)) is None:
continue
self.async_update_supported_features(
entity_id, new_state, update_state=False
)
self.async_on_remove(
async_track_state_change_event(
self.hass, self._entities, self._update_supported_features_event
)
)
await super().async_added_to_hass()
async def async_set_percentage(self, percentage: int) -> None:
"""Set the speed of the fan, as a percentage."""
if percentage == 0:
@ -250,7 +217,7 @@ class FanGroup(GroupEntity, FanEntity):
await self.hass.services.async_call(
DOMAIN,
service,
{ATTR_ENTITY_ID: self._entities},
{ATTR_ENTITY_ID: self._entity_ids},
blocking=True,
context=self._context,
)
@ -275,7 +242,7 @@ class FanGroup(GroupEntity, FanEntity):
states = [
state
for entity_id in self._entities
for entity_id in self._entity_ids
if (state := self.hass.states.get(entity_id)) is not None
]
self._attr_assumed_state |= not states_equal(states)

View file

@ -47,11 +47,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import (
EventStateChangedData,
async_track_state_change_event,
)
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, EventType
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import GroupEntity
from .util import find_state_attributes, mean_tuple, reduce_attribute
@ -153,25 +149,6 @@ class LightGroup(GroupEntity, LightEntity):
if mode:
self.mode = all
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
@callback
def async_state_changed_listener(
event: EventType[EventStateChangedData],
) -> None:
"""Handle child updates."""
self.async_set_context(event.context)
self.async_defer_or_update_ha_state()
self.async_on_remove(
async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
)
await super().async_added_to_hass()
async def async_turn_on(self, **kwargs: Any) -> None:
"""Forward the turn_on command to all lights in the light group."""
data = {

View file

@ -31,11 +31,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import (
EventStateChangedData,
async_track_state_change_event,
)
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, EventType
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import GroupEntity
@ -114,25 +110,6 @@ class LockGroup(GroupEntity, LockEntity):
self._attr_extra_state_attributes = {ATTR_ENTITY_ID: entity_ids}
self._attr_unique_id = unique_id
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
@callback
def async_state_changed_listener(
event: EventType[EventStateChangedData],
) -> None:
"""Handle child updates."""
self.async_set_context(event.context)
self.async_defer_or_update_ha_state()
self.async_on_remove(
async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
)
await super().async_added_to_hass()
async def async_lock(self, **kwargs: Any) -> None:
"""Forward the lock command to all locks in the group."""
data = {ATTR_ENTITY_ID: self._entity_ids}

View file

@ -1,7 +1,7 @@
"""Platform allowing several sensors to be grouped into one sensor to provide numeric combinations."""
from __future__ import annotations
from collections.abc import Callable, Mapping
from collections.abc import Callable
from datetime import datetime
import logging
import statistics
@ -33,19 +33,10 @@ from homeassistant.const import (
STATE_UNAVAILABLE,
STATE_UNKNOWN,
)
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, State, callback
from homeassistant.core import HomeAssistant, State, callback
from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import (
EventStateChangedData,
async_track_state_change_event,
)
from homeassistant.helpers.typing import (
ConfigType,
DiscoveryInfoType,
EventType,
StateType,
)
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, StateType
from . import GroupEntity
from .const import CONF_IGNORE_NON_NUMERIC
@ -303,45 +294,6 @@ class SensorGroup(GroupEntity, SensorEntity):
self._state_incorrect: set[str] = set()
self._extra_state_attribute: dict[str, Any] = {}
@callback
def async_start_preview(
self,
preview_callback: Callable[[str, Mapping[str, Any]], None],
) -> CALLBACK_TYPE:
"""Render a preview."""
@callback
def async_state_changed_listener(
event: EventType[EventStateChangedData] | None,
) -> None:
"""Handle child updates."""
self.async_update_group_state()
preview_callback(*self._async_generate_attributes())
async_state_changed_listener(None)
return async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
@callback
def async_state_changed_listener(
event: EventType[EventStateChangedData],
) -> None:
"""Handle child updates."""
self.async_set_context(event.context)
self.async_defer_or_update_ha_state()
self.async_on_remove(
async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
)
await super().async_added_to_hass()
@callback
def async_update_group_state(self) -> None:
"""Query all members and determine the sensor group state."""

View file

@ -22,11 +22,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import (
EventStateChangedData,
async_track_state_change_event,
)
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, EventType
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import GroupEntity
@ -112,25 +108,6 @@ class SwitchGroup(GroupEntity, SwitchEntity):
if mode:
self.mode = all
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
@callback
def async_state_changed_listener(
event: EventType[EventStateChangedData],
) -> None:
"""Handle child updates."""
self.async_set_context(event.context)
self.async_defer_or_update_ha_state()
self.async_on_remove(
async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
)
await super().async_added_to_hass()
async def async_turn_on(self, **kwargs: Any) -> None:
"""Forward the turn_on command to all switches in the group."""
data = {ATTR_ENTITY_ID: self._entity_ids}