Update ZHA group entity when Zigbee group membership changes (#33378)

* cleanup group entities

* add test

* appease pylint

* fix order
This commit is contained in:
David F. Mulcahey 2020-03-28 20:38:48 -04:00 committed by GitHub
parent 5bedc4ede2
commit f7ae78f78e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 149 additions and 131 deletions

View file

@ -208,6 +208,7 @@ SIGNAL_SET_LEVEL = "set_level"
SIGNAL_STATE_ATTR = "update_state_attribute"
SIGNAL_UPDATE_DEVICE = "{}_zha_update_device"
SIGNAL_REMOVE_GROUP = "remove_group"
SIGNAL_GROUP_MEMBERSHIP_CHANGE = "group_membership_change"
UNKNOWN = "unknown"
UNKNOWN_MANUFACTURER = "unk_manufacturer"

View file

@ -52,6 +52,7 @@ from .const import (
DEFAULT_DATABASE_NAME,
DOMAIN,
SIGNAL_ADD_ENTITIES,
SIGNAL_GROUP_MEMBERSHIP_CHANGE,
SIGNAL_REMOVE,
SIGNAL_REMOVE_GROUP,
UNKNOWN_MANUFACTURER,
@ -256,6 +257,9 @@ class ZHAGateway:
zha_group = self._async_get_or_create_group(zigpy_group)
zha_group.info("group_member_removed - endpoint: %s", endpoint)
self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_MEMBER_REMOVED)
async_dispatcher_send(
self._hass, f"{SIGNAL_GROUP_MEMBERSHIP_CHANGE}_{zigpy_group.group_id}"
)
def group_member_added(
self, zigpy_group: ZigpyGroupType, endpoint: ZigpyEndpointType
@ -265,6 +269,9 @@ class ZHAGateway:
zha_group = self._async_get_or_create_group(zigpy_group)
zha_group.info("group_member_added - endpoint: %s", endpoint)
self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_MEMBER_ADDED)
async_dispatcher_send(
self._hass, f"{SIGNAL_GROUP_MEMBERSHIP_CHANGE}_{zigpy_group.group_id}"
)
def group_added(self, zigpy_group: ZigpyGroupType) -> None:
"""Handle zigpy group added event."""

View file

@ -3,12 +3,13 @@
import asyncio
import logging
import time
from typing import Any, Awaitable, Dict, List
from typing import Any, Awaitable, Dict, List, Optional
from homeassistant.core import callback
from homeassistant.core import CALLBACK_TYPE, State, callback
from homeassistant.helpers import entity
from homeassistant.helpers.device_registry import CONNECTION_ZIGBEE
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.event import async_track_state_change
from homeassistant.helpers.restore_state import RestoreEntity
from .core.const import (
@ -18,7 +19,9 @@ from .core.const import (
DATA_ZHA,
DATA_ZHA_BRIDGE_ID,
DOMAIN,
SIGNAL_GROUP_MEMBERSHIP_CHANGE,
SIGNAL_REMOVE,
SIGNAL_REMOVE_GROUP,
)
from .core.helpers import LogMixin
from .core.typing import CALLABLE_T, ChannelsType, ChannelType, ZhaDeviceType
@ -213,3 +216,75 @@ class ZhaEntity(BaseZhaEntity):
for channel in self.cluster_channels.values():
if hasattr(channel, "async_update"):
await channel.async_update()
class ZhaGroupEntity(BaseZhaEntity):
"""A base class for ZHA group entities."""
def __init__(
self, entity_ids: List[str], unique_id: str, group_id: int, zha_device, **kwargs
) -> None:
"""Initialize a light group."""
super().__init__(unique_id, zha_device, **kwargs)
self._name = f"{zha_device.gateway.groups.get(group_id).name}_group_{group_id}"
self._group_id: int = group_id
self._entity_ids: List[str] = entity_ids
self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
await super().async_added_to_hass()
await self.async_accept_signal(
None,
f"{SIGNAL_REMOVE_GROUP}_{self._group_id}",
self.async_remove,
signal_override=True,
)
await self.async_accept_signal(
None,
f"{SIGNAL_GROUP_MEMBERSHIP_CHANGE}_{self._group_id}",
self._update_group_entities,
signal_override=True,
)
@callback
def async_state_changed_listener(
entity_id: str, old_state: State, new_state: State
):
"""Handle child updates."""
self.async_schedule_update_ha_state(True)
self._async_unsub_state_changed = async_track_state_change(
self.hass, self._entity_ids, async_state_changed_listener
)
await self.async_update()
def _update_group_entities(self):
"""Update tracked entities when membership changes."""
group = self.zha_device.gateway.get_group(self._group_id)
self._entity_ids = group.get_domain_entity_ids(self.platform.domain)
if self._async_unsub_state_changed is not None:
self._async_unsub_state_changed()
@callback
def async_state_changed_listener(
entity_id: str, old_state: State, new_state: State
):
"""Handle child updates."""
self.async_schedule_update_ha_state(True)
self._async_unsub_state_changed = async_track_state_change(
self.hass, self._entity_ids, async_state_changed_listener
)
async def async_will_remove_from_hass(self) -> None:
"""Handle removal from Home Assistant."""
await super().async_will_remove_from_hass()
if self._async_unsub_state_changed is not None:
self._async_unsub_state_changed()
self._async_unsub_state_changed = None
async def async_update(self) -> None:
"""Update the state of the group entity."""
pass

View file

@ -1,7 +1,7 @@
"""Fans on Zigbee Home Automation networks."""
import functools
import logging
from typing import List, Optional
from typing import List
from zigpy.exceptions import DeliveryError
import zigpy.zcl.clusters.hvac as hvac
@ -16,9 +16,8 @@ from homeassistant.components.fan import (
FanEntity,
)
from homeassistant.const import STATE_UNAVAILABLE
from homeassistant.core import CALLBACK_TYPE, State, callback
from homeassistant.core import State, callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.event import async_track_state_change
from .core import discovery
from .core.const import (
@ -27,10 +26,9 @@ from .core.const import (
DATA_ZHA_DISPATCHERS,
SIGNAL_ADD_ENTITIES,
SIGNAL_ATTR_UPDATED,
SIGNAL_REMOVE_GROUP,
)
from .core.registries import ZHA_ENTITIES
from .entity import BaseZhaEntity, ZhaEntity
from .entity import ZhaEntity, ZhaGroupEntity
_LOGGER = logging.getLogger(__name__)
@ -73,7 +71,7 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub)
class BaseFan(BaseZhaEntity, FanEntity):
class BaseFan(FanEntity):
"""Base representation of a ZHA fan."""
def __init__(self, *args, **kwargs):
@ -120,9 +118,14 @@ class BaseFan(BaseZhaEntity, FanEntity):
await self._fan_channel.async_set_speed(SPEED_TO_VALUE[speed])
self.async_set_state(0, "fan_mode", speed)
@callback
def async_set_state(self, attr_id, attr_name, value):
"""Handle state update from channel."""
pass
@STRICT_MATCH(channel_names=CHANNEL_FAN)
class ZhaFan(ZhaEntity, BaseFan):
class ZhaFan(BaseFan, ZhaEntity):
"""Representation of a ZHA fan."""
def __init__(self, unique_id, zha_device, channels, **kwargs):
@ -158,19 +161,15 @@ class ZhaFan(ZhaEntity, BaseFan):
@GROUP_MATCH()
class FanGroup(BaseFan):
class FanGroup(BaseFan, ZhaGroupEntity):
"""Representation of a fan group."""
def __init__(
self, entity_ids: List[str], unique_id: str, group_id: int, zha_device, **kwargs
) -> None:
"""Initialize a fan group."""
super().__init__(unique_id, zha_device, **kwargs)
self._name: str = f"{zha_device.gateway.groups.get(group_id).name}_group_{group_id}"
self._group_id: int = group_id
super().__init__(entity_ids, unique_id, group_id, zha_device, **kwargs)
self._available: bool = False
self._entity_ids: List[str] = entity_ids
self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None
group = self.zha_device.gateway.get_group(self._group_id)
self._fan_channel = group.endpoint[hvac.Fan.cluster_id]
@ -185,35 +184,6 @@ class FanGroup(BaseFan):
self._fan_channel.async_set_speed = async_set_speed
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
await super().async_added_to_hass()
await self.async_accept_signal(
None,
f"{SIGNAL_REMOVE_GROUP}_{self._group_id}",
self.async_remove,
signal_override=True,
)
@callback
def async_state_changed_listener(
entity_id: str, old_state: State, new_state: State
):
"""Handle child updates."""
self.async_schedule_update_ha_state(True)
self._async_unsub_state_changed = async_track_state_change(
self.hass, self._entity_ids, async_state_changed_listener
)
await self.async_update()
async def async_will_remove_from_hass(self) -> None:
"""Handle removal from Home Assistant."""
await super().async_will_remove_from_hass()
if self._async_unsub_state_changed is not None:
self._async_unsub_state_changed()
self._async_unsub_state_changed = None
async def async_update(self):
"""Attempt to retrieve on off state from the fan."""
all_states = [self.hass.states.get(x) for x in self._entity_ids]

View file

@ -30,12 +30,9 @@ from homeassistant.components.light import (
SUPPORT_WHITE_VALUE,
)
from homeassistant.const import ATTR_SUPPORTED_FEATURES, STATE_ON, STATE_UNAVAILABLE
from homeassistant.core import CALLBACK_TYPE, State, callback
from homeassistant.core import State, callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.event import (
async_track_state_change,
async_track_time_interval,
)
from homeassistant.helpers.event import async_track_time_interval
import homeassistant.util.color as color_util
from .core import discovery, helpers
@ -50,12 +47,12 @@ from .core.const import (
EFFECT_DEFAULT_VARIANT,
SIGNAL_ADD_ENTITIES,
SIGNAL_ATTR_UPDATED,
SIGNAL_REMOVE_GROUP,
SIGNAL_SET_LEVEL,
)
from .core.helpers import LogMixin
from .core.registries import ZHA_ENTITIES
from .core.typing import ZhaDeviceType
from .entity import BaseZhaEntity, ZhaEntity
from .entity import ZhaEntity, ZhaGroupEntity
_LOGGER = logging.getLogger(__name__)
@ -100,7 +97,7 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub)
class BaseLight(BaseZhaEntity, light.Light):
class BaseLight(LogMixin, light.Light):
"""Operations common to all light entities."""
def __init__(self, *args, **kwargs):
@ -307,7 +304,7 @@ class BaseLight(BaseZhaEntity, light.Light):
@STRICT_MATCH(channel_names=CHANNEL_ON_OFF, aux_channels={CHANNEL_COLOR, CHANNEL_LEVEL})
class Light(ZhaEntity, BaseLight):
class Light(BaseLight, ZhaEntity):
"""Representation of a ZHA or ZLL light."""
_REFRESH_INTERVAL = (45, 75)
@ -471,52 +468,19 @@ class HueLight(Light):
@GROUP_MATCH()
class LightGroup(BaseLight):
class LightGroup(BaseLight, ZhaGroupEntity):
"""Representation of a light group."""
def __init__(
self, entity_ids: List[str], unique_id: str, group_id: int, zha_device, **kwargs
) -> None:
"""Initialize a light group."""
super().__init__(unique_id, zha_device, **kwargs)
self._name = f"{zha_device.gateway.groups.get(group_id).name}_group_{group_id}"
self._group_id: int = group_id
self._entity_ids: List[str] = entity_ids
super().__init__(entity_ids, unique_id, group_id, zha_device, **kwargs)
group = self.zha_device.gateway.get_group(self._group_id)
self._on_off_channel = group.endpoint[OnOff.cluster_id]
self._level_channel = group.endpoint[LevelControl.cluster_id]
self._color_channel = group.endpoint[Color.cluster_id]
self._identify_channel = group.endpoint[Identify.cluster_id]
self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
await super().async_added_to_hass()
await self.async_accept_signal(
None,
f"{SIGNAL_REMOVE_GROUP}_{self._group_id}",
self.async_remove,
signal_override=True,
)
@callback
def async_state_changed_listener(
entity_id: str, old_state: State, new_state: State
):
"""Handle child updates."""
self.async_schedule_update_ha_state(True)
self._async_unsub_state_changed = async_track_state_change(
self.hass, self._entity_ids, async_state_changed_listener
)
await self.async_update()
async def async_will_remove_from_hass(self) -> None:
"""Handle removal from Home Assistant."""
await super().async_will_remove_from_hass()
if self._async_unsub_state_changed is not None:
self._async_unsub_state_changed()
self._async_unsub_state_changed = None
async def async_update(self) -> None:
"""Query all members and determine the light group state."""

View file

@ -1,16 +1,15 @@
"""Switches on Zigbee Home Automation networks."""
import functools
import logging
from typing import Any, List, Optional
from typing import Any, List
from zigpy.zcl.clusters.general import OnOff
from zigpy.zcl.foundation import Status
from homeassistant.components.switch import DOMAIN, SwitchDevice
from homeassistant.const import STATE_ON, STATE_UNAVAILABLE
from homeassistant.core import CALLBACK_TYPE, State, callback
from homeassistant.core import State, callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.event import async_track_state_change
from .core import discovery
from .core.const import (
@ -19,10 +18,9 @@ from .core.const import (
DATA_ZHA_DISPATCHERS,
SIGNAL_ADD_ENTITIES,
SIGNAL_ATTR_UPDATED,
SIGNAL_REMOVE_GROUP,
)
from .core.registries import ZHA_ENTITIES
from .entity import BaseZhaEntity, ZhaEntity
from .entity import ZhaEntity, ZhaGroupEntity
_LOGGER = logging.getLogger(__name__)
STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, DOMAIN)
@ -43,7 +41,7 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub)
class BaseSwitch(BaseZhaEntity, SwitchDevice):
class BaseSwitch(SwitchDevice):
"""Common base class for zha switches."""
def __init__(self, *args, **kwargs):
@ -77,7 +75,7 @@ class BaseSwitch(BaseZhaEntity, SwitchDevice):
@STRICT_MATCH(channel_names=CHANNEL_ON_OFF)
class Switch(ZhaEntity, BaseSwitch):
class Switch(BaseSwitch, ZhaEntity):
"""ZHA switch."""
def __init__(self, unique_id, zha_device, channels, **kwargs):
@ -113,50 +111,17 @@ class Switch(ZhaEntity, BaseSwitch):
@GROUP_MATCH()
class SwitchGroup(BaseSwitch):
class SwitchGroup(BaseSwitch, ZhaGroupEntity):
"""Representation of a switch group."""
def __init__(
self, entity_ids: List[str], unique_id: str, group_id: int, zha_device, **kwargs
) -> None:
"""Initialize a switch group."""
super().__init__(unique_id, zha_device, **kwargs)
self._name: str = f"{zha_device.gateway.groups.get(group_id).name}_group_{group_id}"
self._group_id: int = group_id
super().__init__(entity_ids, unique_id, group_id, zha_device, **kwargs)
self._available: bool = False
self._entity_ids: List[str] = entity_ids
group = self.zha_device.gateway.get_group(self._group_id)
self._on_off_channel = group.endpoint[OnOff.cluster_id]
self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
await super().async_added_to_hass()
await self.async_accept_signal(
None,
f"{SIGNAL_REMOVE_GROUP}_{self._group_id}",
self.async_remove,
signal_override=True,
)
@callback
def async_state_changed_listener(
entity_id: str, old_state: State, new_state: State
):
"""Handle child updates."""
self.async_schedule_update_ha_state(True)
self._async_unsub_state_changed = async_track_state_change(
self.hass, self._entity_ids, async_state_changed_listener
)
await self.async_update()
async def async_will_remove_from_hass(self) -> None:
"""Handle removal from Home Assistant."""
await super().async_will_remove_from_hass()
if self._async_unsub_state_changed is not None:
self._async_unsub_state_changed()
self._async_unsub_state_changed = None
async def async_update(self) -> None:
"""Query all members and determine the light group state."""

View file

@ -30,6 +30,7 @@ ON = 1
OFF = 0
IEEE_GROUPABLE_DEVICE = "01:2d:6f:00:0a:90:69:e8"
IEEE_GROUPABLE_DEVICE2 = "02:2d:6f:00:0a:90:69:e8"
IEEE_GROUPABLE_DEVICE3 = "03:2d:6f:00:0a:90:69:e8"
LIGHT_ON_OFF = {
1: {
@ -140,6 +141,31 @@ async def device_light_2(hass, zigpy_device_mock, zha_device_joined):
return zha_device
@pytest.fixture
async def device_light_3(hass, zigpy_device_mock, zha_device_joined):
"""Test zha light platform."""
zigpy_device = zigpy_device_mock(
{
1: {
"in_clusters": [
general.OnOff.cluster_id,
general.LevelControl.cluster_id,
lighting.Color.cluster_id,
general.Groups.cluster_id,
general.Identify.cluster_id,
],
"out_clusters": [],
"device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT,
}
},
ieee=IEEE_GROUPABLE_DEVICE3,
)
zha_device = await zha_device_joined(zigpy_device)
zha_device.set_available(True)
return zha_device
@patch("zigpy.zcl.clusters.general.OnOff.read_attributes", new=MagicMock())
async def test_light_refresh(hass, zigpy_device_mock, zha_device_joined_restored):
"""Test zha light platform refresh."""
@ -414,7 +440,7 @@ async def async_test_flash_from_hass(hass, cluster, entity_id, flash):
async def async_test_zha_group_light_entity(
hass, device_light_1, device_light_2, coordinator
hass, device_light_1, device_light_2, device_light_3, coordinator
):
"""Test the light entity for a ZHA group."""
zha_gateway = get_zha_gateway(hass)
@ -445,6 +471,7 @@ async def async_test_zha_group_light_entity(
dev1_cluster_on_off = device_light_1.endpoints[1].on_off
dev2_cluster_on_off = device_light_2.endpoints[1].on_off
dev3_cluster_on_off = device_light_3.endpoints[1].on_off
# test that the lights were created and that they are unavailable
assert hass.states.get(entity_id).state == STATE_UNAVAILABLE
@ -503,3 +530,12 @@ async def async_test_zha_group_light_entity(
# test that group light is now back on
assert hass.states.get(entity_id).state == STATE_ON
# test that group light is now off
await group_cluster_on_off.off()
assert hass.states.get(entity_id).state == STATE_OFF
# add a new member and test that his state is also tracked
await zha_group.async_add_members([device_light_3.ieee])
await dev3_cluster_on_off.on()
assert hass.states.get(entity_id).state == STATE_ON