Add identify buttons to ZHA devices (#61495)

* Identify buttons

* clean up and add test

* use Platform

* update device list

* Only 1 identify button per device

* cleanup press until the need arises for the branch

* make imports relative
This commit is contained in:
David F. Mulcahey 2021-12-23 17:52:42 -05:00 committed by GitHub
parent fa6d6d914b
commit 41531b528e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 1858 additions and 1146 deletions

View file

@ -0,0 +1,106 @@
"""Support for ZHA button."""
from __future__ import annotations
import abc
import functools
import logging
from typing import Any
from homeassistant.components.button import ButtonDeviceClass, ButtonEntity
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ENTITY_CATEGORY_DIAGNOSTIC, Platform
from homeassistant.core import HomeAssistant
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .core import discovery
from .core.const import CHANNEL_IDENTIFY, DATA_ZHA, SIGNAL_ADD_ENTITIES
from .core.registries import ZHA_ENTITIES
from .core.typing import ChannelType, ZhaDeviceType
from .entity import ZhaEntity
MULTI_MATCH = functools.partial(ZHA_ENTITIES.multipass_match, Platform.BUTTON)
DEFAULT_DURATION = 5 # seconds
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up the Zigbee Home Automation button from config entry."""
entities_to_create = hass.data[DATA_ZHA][Platform.BUTTON]
unsub = async_dispatcher_connect(
hass,
SIGNAL_ADD_ENTITIES,
functools.partial(
discovery.async_add_entities,
async_add_entities,
entities_to_create,
update_before_add=False,
),
)
config_entry.async_on_unload(unsub)
class ZHAButton(ZhaEntity, ButtonEntity):
"""Defines a ZHA button."""
_command_name: str = None
def __init__(
self,
unique_id: str,
zha_device: ZhaDeviceType,
channels: list[ChannelType],
**kwargs,
) -> None:
"""Init this button."""
super().__init__(unique_id, zha_device, channels, **kwargs)
self._channel: ChannelType = channels[0]
@abc.abstractmethod
def get_args(self) -> list[Any]:
"""Return the arguments to use in the command."""
async def async_press(self) -> None:
"""Send out a update command."""
command = getattr(self._channel, self._command_name)
arguments = self.get_args()
await command(*arguments)
@MULTI_MATCH(channel_names=CHANNEL_IDENTIFY)
class ZHAIdentifyButton(ZHAButton):
"""Defines a ZHA identify button."""
@classmethod
def create_entity(
cls,
unique_id: str,
zha_device: ZhaDeviceType,
channels: list[ChannelType],
**kwargs,
) -> ZhaEntity | None:
"""Entity Factory.
Return entity if it is a supported configuration, otherwise return None
"""
platform_restrictions = ZHA_ENTITIES.single_device_matches[Platform.BUTTON]
device_restrictions = platform_restrictions[zha_device.ieee]
if CHANNEL_IDENTIFY in device_restrictions:
return None
device_restrictions.append(CHANNEL_IDENTIFY)
return cls(unique_id, zha_device, channels, **kwargs)
_attr_device_class: ButtonDeviceClass = ButtonDeviceClass.UPDATE
_attr_entity_category = ENTITY_CATEGORY_DIAGNOSTIC
_command_name = "identify"
def get_args(self) -> list[Any]:
"""Return the arguments to use in the command."""
return [DEFAULT_DURATION]

View file

@ -102,6 +102,7 @@ CLUSTER_TYPE_OUT = "out"
PLATFORMS = ( PLATFORMS = (
Platform.ALARM_CONTROL_PANEL, Platform.ALARM_CONTROL_PANEL,
Platform.BINARY_SENSOR, Platform.BINARY_SENSOR,
Platform.BUTTON,
Platform.CLIMATE, Platform.CLIMATE,
Platform.COVER, Platform.COVER,
Platform.DEVICE_TRACKER, Platform.DEVICE_TRACKER,

View file

@ -17,6 +17,7 @@ from . import const as zha_const, registries as zha_regs, typing as zha_typing
from .. import ( # noqa: F401 pylint: disable=unused-import, from .. import ( # noqa: F401 pylint: disable=unused-import,
alarm_control_panel, alarm_control_panel,
binary_sensor, binary_sensor,
button,
climate, climate,
cover, cover,
device_tracker, device_tracker,
@ -66,6 +67,7 @@ class ProbeEndpoint:
self.discover_by_device_type(channel_pool) self.discover_by_device_type(channel_pool)
self.discover_multi_entities(channel_pool) self.discover_multi_entities(channel_pool)
self.discover_by_cluster_id(channel_pool) self.discover_by_cluster_id(channel_pool)
zha_regs.ZHA_ENTITIES.clean_up()
@callback @callback
def discover_by_device_type(self, channel_pool: zha_typing.ChannelPoolType) -> None: def discover_by_device_type(self, channel_pool: zha_typing.ChannelPoolType) -> None:

View file

@ -9,6 +9,7 @@ import attr
from zigpy import zcl from zigpy import zcl
import zigpy.profiles.zha import zigpy.profiles.zha
import zigpy.profiles.zll import zigpy.profiles.zll
from zigpy.types.named import EUI64
from homeassistant.const import Platform from homeassistant.const import Platform
@ -228,6 +229,9 @@ class ZHAEntityRegistry:
lambda: collections.defaultdict(lambda: collections.defaultdict(list)) lambda: collections.defaultdict(lambda: collections.defaultdict(list))
) )
self._group_registry: dict[str, CALLABLE_T] = {} self._group_registry: dict[str, CALLABLE_T] = {}
self.single_device_matches: dict[
Platform, dict[EUI64, list[str]]
] = collections.defaultdict(lambda: collections.defaultdict(list))
def get_entity( def get_entity(
self, self,
@ -342,5 +346,11 @@ class ZHAEntityRegistry:
return decorator return decorator
def clean_up(self) -> None:
"""Clean up post discovery."""
self.single_device_matches: dict[
Platform, dict[EUI64, list[str]]
] = collections.defaultdict(lambda: collections.defaultdict(list))
ZHA_ENTITIES = ZHAEntityRegistry() ZHA_ENTITIES = ZHAEntityRegistry()

View file

@ -0,0 +1,89 @@
"""Test ZHA button."""
from unittest.mock import patch
from freezegun import freeze_time
import pytest
from zigpy.const import SIG_EP_PROFILE
import zigpy.profiles.zha as zha
import zigpy.zcl.clusters.general as general
import zigpy.zcl.clusters.security as security
import zigpy.zcl.foundation as zcl_f
from homeassistant.components.button import DOMAIN, ButtonDeviceClass
from homeassistant.components.button.const import SERVICE_PRESS
from homeassistant.const import (
ATTR_DEVICE_CLASS,
ATTR_ENTITY_ID,
ENTITY_CATEGORY_DIAGNOSTIC,
STATE_UNKNOWN,
)
from homeassistant.helpers import entity_registry as er
from .common import find_entity_id
from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_TYPE
from tests.common import mock_coro
@pytest.fixture
async def contact_sensor(hass, zigpy_device_mock, zha_device_joined_restored):
"""Contact sensor fixture."""
zigpy_device = zigpy_device_mock(
{
1: {
SIG_EP_INPUT: [
general.Basic.cluster_id,
general.Identify.cluster_id,
security.IasZone.cluster_id,
],
SIG_EP_OUTPUT: [],
SIG_EP_TYPE: zha.DeviceType.IAS_ZONE,
SIG_EP_PROFILE: zha.PROFILE_ID,
}
},
)
zha_device = await zha_device_joined_restored(zigpy_device)
return zha_device, zigpy_device.endpoints[1].identify
@freeze_time("2021-11-04 17:37:00", tz_offset=-1)
async def test_button(hass, contact_sensor):
"""Test zha button platform."""
entity_registry = er.async_get(hass)
zha_device, cluster = contact_sensor
assert cluster is not None
entity_id = await find_entity_id(DOMAIN, zha_device, hass)
assert entity_id is not None
state = hass.states.get(entity_id)
assert state
assert state.state == STATE_UNKNOWN
assert state.attributes[ATTR_DEVICE_CLASS] == ButtonDeviceClass.UPDATE
entry = entity_registry.async_get(entity_id)
assert entry
assert entry.entity_category == ENTITY_CATEGORY_DIAGNOSTIC
with patch(
"zigpy.zcl.Cluster.request",
return_value=mock_coro([0x00, zcl_f.Status.SUCCESS]),
):
await hass.services.async_call(
DOMAIN,
SERVICE_PRESS,
{ATTR_ENTITY_ID: entity_id},
blocking=True,
)
await hass.async_block_till_done()
assert len(cluster.request.mock_calls) == 1
assert cluster.request.call_args[0][0] is False
assert cluster.request.call_args[0][1] == 0
assert cluster.request.call_args[0][3] == 5 # duration in seconds
state = hass.states.get(entity_id)
assert state
assert state.state == "2021-11-04T16:37:00+00:00"
assert state.attributes[ATTR_DEVICE_CLASS] == ButtonDeviceClass.UPDATE

View file

@ -125,17 +125,25 @@ async def test_devices(
ch.id for pool in zha_dev.channels.pools for ch in pool.client_channels.values() ch.id for pool in zha_dev.channels.pools for ch in pool.client_channels.values()
} }
assert event_channels == set(device[DEV_SIG_EVT_CHANNELS]) assert event_channels == set(device[DEV_SIG_EVT_CHANNELS])
# we need to probe the class create entity factory so we need to reset this to get accurate results
zha_regs.ZHA_ENTITIES.clean_up()
# build a dict of entity_class -> (component, unique_id, channels) tuple # build a dict of entity_class -> (component, unique_id, channels) tuple
ha_ent_info = {} ha_ent_info = {}
created_entity_count = 0
for call in _dispatch.call_args_list: for call in _dispatch.call_args_list:
_, component, entity_cls, unique_id, channels = call[0] _, component, entity_cls, unique_id, channels = call[0]
unique_id_head = UNIQUE_ID_HD.match(unique_id).group(0) # ieee + endpoint_id # the factory can return None. We filter these out to get an accurate created entity count
ha_ent_info[(unique_id_head, entity_cls.__name__)] = ( response = entity_cls.create_entity(unique_id, zha_dev, channels)
component, if response:
unique_id, created_entity_count += 1
channels, unique_id_head = UNIQUE_ID_HD.match(unique_id).group(
) 0
) # ieee + endpoint_id
ha_ent_info[(unique_id_head, entity_cls.__name__)] = (
component,
unique_id,
channels,
)
for comp_id, ent_info in device[DEV_SIG_ENT_MAP].items(): for comp_id, ent_info in device[DEV_SIG_ENT_MAP].items():
component, unique_id = comp_id component, unique_id = comp_id
@ -156,7 +164,7 @@ async def test_devices(
assert unique_id.startswith(ha_unique_id) assert unique_id.startswith(ha_unique_id)
assert {ch.name for ch in ha_channels} == set(ent_info[DEV_SIG_CHANNELS]) assert {ch.name for ch in ha_channels} == set(ent_info[DEV_SIG_CHANNELS])
assert _dispatch.call_count == len(device[DEV_SIG_ENT_MAP]) assert created_entity_count == len(device[DEV_SIG_ENT_MAP])
entity_ids = hass_disable_services.states.async_entity_ids() entity_ids = hass_disable_services.states.async_entity_ids()
await hass_disable_services.async_block_till_done() await hass_disable_services.async_block_till_done()
@ -298,7 +306,6 @@ async def test_discover_endpoint(device_info, channels_mock, hass):
assert device_info[DEV_SIG_EVT_CHANNELS] == sorted( assert device_info[DEV_SIG_EVT_CHANNELS] == sorted(
ch.id for pool in channels.pools for ch in pool.client_channels.values() ch.id for pool in channels.pools for ch in pool.client_channels.values()
) )
assert new_ent.call_count == len(list(device_info[DEV_SIG_ENT_MAP].values()))
# build a dict of entity_class -> (component, unique_id, channels) tuple # build a dict of entity_class -> (component, unique_id, channels) tuple
ha_ent_info = {} ha_ent_info = {}
@ -326,8 +333,6 @@ async def test_discover_endpoint(device_info, channels_mock, hass):
assert unique_id.startswith(ha_unique_id) assert unique_id.startswith(ha_unique_id)
assert {ch.name for ch in ha_channels} == set(ent_info[DEV_SIG_CHANNELS]) assert {ch.name for ch in ha_channels} == set(ent_info[DEV_SIG_CHANNELS])
assert new_ent.call_count == len(device_info[DEV_SIG_ENT_MAP])
def _ch_mock(cluster): def _ch_mock(cluster):
"""Return mock of a channel with a cluster.""" """Return mock of a channel with a cluster."""

File diff suppressed because it is too large Load diff