Add identify buttons to ZHA devices (#61495)
* Identify buttons * clean up and add test * use Platform * update device list * Only 1 identify button per device * cleanup press until the need arises for the branch * make imports relative
This commit is contained in:
parent
fa6d6d914b
commit
41531b528e
7 changed files with 1858 additions and 1146 deletions
106
homeassistant/components/zha/button.py
Normal file
106
homeassistant/components/zha/button.py
Normal file
|
@ -0,0 +1,106 @@
|
||||||
|
"""Support for ZHA button."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import abc
|
||||||
|
import functools
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from homeassistant.components.button import ButtonDeviceClass, ButtonEntity
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.const import ENTITY_CATEGORY_DIAGNOSTIC, Platform
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||||
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
|
||||||
|
from .core import discovery
|
||||||
|
from .core.const import CHANNEL_IDENTIFY, DATA_ZHA, SIGNAL_ADD_ENTITIES
|
||||||
|
from .core.registries import ZHA_ENTITIES
|
||||||
|
from .core.typing import ChannelType, ZhaDeviceType
|
||||||
|
from .entity import ZhaEntity
|
||||||
|
|
||||||
|
MULTI_MATCH = functools.partial(ZHA_ENTITIES.multipass_match, Platform.BUTTON)
|
||||||
|
DEFAULT_DURATION = 5 # seconds
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
async_add_entities: AddEntitiesCallback,
|
||||||
|
) -> None:
|
||||||
|
"""Set up the Zigbee Home Automation button from config entry."""
|
||||||
|
entities_to_create = hass.data[DATA_ZHA][Platform.BUTTON]
|
||||||
|
|
||||||
|
unsub = async_dispatcher_connect(
|
||||||
|
hass,
|
||||||
|
SIGNAL_ADD_ENTITIES,
|
||||||
|
functools.partial(
|
||||||
|
discovery.async_add_entities,
|
||||||
|
async_add_entities,
|
||||||
|
entities_to_create,
|
||||||
|
update_before_add=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
config_entry.async_on_unload(unsub)
|
||||||
|
|
||||||
|
|
||||||
|
class ZHAButton(ZhaEntity, ButtonEntity):
|
||||||
|
"""Defines a ZHA button."""
|
||||||
|
|
||||||
|
_command_name: str = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
unique_id: str,
|
||||||
|
zha_device: ZhaDeviceType,
|
||||||
|
channels: list[ChannelType],
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
"""Init this button."""
|
||||||
|
super().__init__(unique_id, zha_device, channels, **kwargs)
|
||||||
|
self._channel: ChannelType = channels[0]
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_args(self) -> list[Any]:
|
||||||
|
"""Return the arguments to use in the command."""
|
||||||
|
|
||||||
|
async def async_press(self) -> None:
|
||||||
|
"""Send out a update command."""
|
||||||
|
command = getattr(self._channel, self._command_name)
|
||||||
|
arguments = self.get_args()
|
||||||
|
await command(*arguments)
|
||||||
|
|
||||||
|
|
||||||
|
@MULTI_MATCH(channel_names=CHANNEL_IDENTIFY)
|
||||||
|
class ZHAIdentifyButton(ZHAButton):
|
||||||
|
"""Defines a ZHA identify button."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_entity(
|
||||||
|
cls,
|
||||||
|
unique_id: str,
|
||||||
|
zha_device: ZhaDeviceType,
|
||||||
|
channels: list[ChannelType],
|
||||||
|
**kwargs,
|
||||||
|
) -> ZhaEntity | None:
|
||||||
|
"""Entity Factory.
|
||||||
|
|
||||||
|
Return entity if it is a supported configuration, otherwise return None
|
||||||
|
"""
|
||||||
|
platform_restrictions = ZHA_ENTITIES.single_device_matches[Platform.BUTTON]
|
||||||
|
device_restrictions = platform_restrictions[zha_device.ieee]
|
||||||
|
if CHANNEL_IDENTIFY in device_restrictions:
|
||||||
|
return None
|
||||||
|
device_restrictions.append(CHANNEL_IDENTIFY)
|
||||||
|
return cls(unique_id, zha_device, channels, **kwargs)
|
||||||
|
|
||||||
|
_attr_device_class: ButtonDeviceClass = ButtonDeviceClass.UPDATE
|
||||||
|
_attr_entity_category = ENTITY_CATEGORY_DIAGNOSTIC
|
||||||
|
_command_name = "identify"
|
||||||
|
|
||||||
|
def get_args(self) -> list[Any]:
|
||||||
|
"""Return the arguments to use in the command."""
|
||||||
|
|
||||||
|
return [DEFAULT_DURATION]
|
|
@ -102,6 +102,7 @@ CLUSTER_TYPE_OUT = "out"
|
||||||
PLATFORMS = (
|
PLATFORMS = (
|
||||||
Platform.ALARM_CONTROL_PANEL,
|
Platform.ALARM_CONTROL_PANEL,
|
||||||
Platform.BINARY_SENSOR,
|
Platform.BINARY_SENSOR,
|
||||||
|
Platform.BUTTON,
|
||||||
Platform.CLIMATE,
|
Platform.CLIMATE,
|
||||||
Platform.COVER,
|
Platform.COVER,
|
||||||
Platform.DEVICE_TRACKER,
|
Platform.DEVICE_TRACKER,
|
||||||
|
|
|
@ -17,6 +17,7 @@ from . import const as zha_const, registries as zha_regs, typing as zha_typing
|
||||||
from .. import ( # noqa: F401 pylint: disable=unused-import,
|
from .. import ( # noqa: F401 pylint: disable=unused-import,
|
||||||
alarm_control_panel,
|
alarm_control_panel,
|
||||||
binary_sensor,
|
binary_sensor,
|
||||||
|
button,
|
||||||
climate,
|
climate,
|
||||||
cover,
|
cover,
|
||||||
device_tracker,
|
device_tracker,
|
||||||
|
@ -66,6 +67,7 @@ class ProbeEndpoint:
|
||||||
self.discover_by_device_type(channel_pool)
|
self.discover_by_device_type(channel_pool)
|
||||||
self.discover_multi_entities(channel_pool)
|
self.discover_multi_entities(channel_pool)
|
||||||
self.discover_by_cluster_id(channel_pool)
|
self.discover_by_cluster_id(channel_pool)
|
||||||
|
zha_regs.ZHA_ENTITIES.clean_up()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def discover_by_device_type(self, channel_pool: zha_typing.ChannelPoolType) -> None:
|
def discover_by_device_type(self, channel_pool: zha_typing.ChannelPoolType) -> None:
|
||||||
|
|
|
@ -9,6 +9,7 @@ import attr
|
||||||
from zigpy import zcl
|
from zigpy import zcl
|
||||||
import zigpy.profiles.zha
|
import zigpy.profiles.zha
|
||||||
import zigpy.profiles.zll
|
import zigpy.profiles.zll
|
||||||
|
from zigpy.types.named import EUI64
|
||||||
|
|
||||||
from homeassistant.const import Platform
|
from homeassistant.const import Platform
|
||||||
|
|
||||||
|
@ -228,6 +229,9 @@ class ZHAEntityRegistry:
|
||||||
lambda: collections.defaultdict(lambda: collections.defaultdict(list))
|
lambda: collections.defaultdict(lambda: collections.defaultdict(list))
|
||||||
)
|
)
|
||||||
self._group_registry: dict[str, CALLABLE_T] = {}
|
self._group_registry: dict[str, CALLABLE_T] = {}
|
||||||
|
self.single_device_matches: dict[
|
||||||
|
Platform, dict[EUI64, list[str]]
|
||||||
|
] = collections.defaultdict(lambda: collections.defaultdict(list))
|
||||||
|
|
||||||
def get_entity(
|
def get_entity(
|
||||||
self,
|
self,
|
||||||
|
@ -342,5 +346,11 @@ class ZHAEntityRegistry:
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
def clean_up(self) -> None:
|
||||||
|
"""Clean up post discovery."""
|
||||||
|
self.single_device_matches: dict[
|
||||||
|
Platform, dict[EUI64, list[str]]
|
||||||
|
] = collections.defaultdict(lambda: collections.defaultdict(list))
|
||||||
|
|
||||||
|
|
||||||
ZHA_ENTITIES = ZHAEntityRegistry()
|
ZHA_ENTITIES = ZHAEntityRegistry()
|
||||||
|
|
89
tests/components/zha/test_button.py
Normal file
89
tests/components/zha/test_button.py
Normal file
|
@ -0,0 +1,89 @@
|
||||||
|
"""Test ZHA button."""
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from freezegun import freeze_time
|
||||||
|
import pytest
|
||||||
|
from zigpy.const import SIG_EP_PROFILE
|
||||||
|
import zigpy.profiles.zha as zha
|
||||||
|
import zigpy.zcl.clusters.general as general
|
||||||
|
import zigpy.zcl.clusters.security as security
|
||||||
|
import zigpy.zcl.foundation as zcl_f
|
||||||
|
|
||||||
|
from homeassistant.components.button import DOMAIN, ButtonDeviceClass
|
||||||
|
from homeassistant.components.button.const import SERVICE_PRESS
|
||||||
|
from homeassistant.const import (
|
||||||
|
ATTR_DEVICE_CLASS,
|
||||||
|
ATTR_ENTITY_ID,
|
||||||
|
ENTITY_CATEGORY_DIAGNOSTIC,
|
||||||
|
STATE_UNKNOWN,
|
||||||
|
)
|
||||||
|
from homeassistant.helpers import entity_registry as er
|
||||||
|
|
||||||
|
from .common import find_entity_id
|
||||||
|
from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_TYPE
|
||||||
|
|
||||||
|
from tests.common import mock_coro
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def contact_sensor(hass, zigpy_device_mock, zha_device_joined_restored):
|
||||||
|
"""Contact sensor fixture."""
|
||||||
|
|
||||||
|
zigpy_device = zigpy_device_mock(
|
||||||
|
{
|
||||||
|
1: {
|
||||||
|
SIG_EP_INPUT: [
|
||||||
|
general.Basic.cluster_id,
|
||||||
|
general.Identify.cluster_id,
|
||||||
|
security.IasZone.cluster_id,
|
||||||
|
],
|
||||||
|
SIG_EP_OUTPUT: [],
|
||||||
|
SIG_EP_TYPE: zha.DeviceType.IAS_ZONE,
|
||||||
|
SIG_EP_PROFILE: zha.PROFILE_ID,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
zha_device = await zha_device_joined_restored(zigpy_device)
|
||||||
|
return zha_device, zigpy_device.endpoints[1].identify
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2021-11-04 17:37:00", tz_offset=-1)
|
||||||
|
async def test_button(hass, contact_sensor):
|
||||||
|
"""Test zha button platform."""
|
||||||
|
|
||||||
|
entity_registry = er.async_get(hass)
|
||||||
|
zha_device, cluster = contact_sensor
|
||||||
|
assert cluster is not None
|
||||||
|
entity_id = await find_entity_id(DOMAIN, zha_device, hass)
|
||||||
|
assert entity_id is not None
|
||||||
|
|
||||||
|
state = hass.states.get(entity_id)
|
||||||
|
assert state
|
||||||
|
assert state.state == STATE_UNKNOWN
|
||||||
|
assert state.attributes[ATTR_DEVICE_CLASS] == ButtonDeviceClass.UPDATE
|
||||||
|
|
||||||
|
entry = entity_registry.async_get(entity_id)
|
||||||
|
assert entry
|
||||||
|
assert entry.entity_category == ENTITY_CATEGORY_DIAGNOSTIC
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"zigpy.zcl.Cluster.request",
|
||||||
|
return_value=mock_coro([0x00, zcl_f.Status.SUCCESS]),
|
||||||
|
):
|
||||||
|
await hass.services.async_call(
|
||||||
|
DOMAIN,
|
||||||
|
SERVICE_PRESS,
|
||||||
|
{ATTR_ENTITY_ID: entity_id},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(cluster.request.mock_calls) == 1
|
||||||
|
assert cluster.request.call_args[0][0] is False
|
||||||
|
assert cluster.request.call_args[0][1] == 0
|
||||||
|
assert cluster.request.call_args[0][3] == 5 # duration in seconds
|
||||||
|
|
||||||
|
state = hass.states.get(entity_id)
|
||||||
|
assert state
|
||||||
|
assert state.state == "2021-11-04T16:37:00+00:00"
|
||||||
|
assert state.attributes[ATTR_DEVICE_CLASS] == ButtonDeviceClass.UPDATE
|
|
@ -125,17 +125,25 @@ async def test_devices(
|
||||||
ch.id for pool in zha_dev.channels.pools for ch in pool.client_channels.values()
|
ch.id for pool in zha_dev.channels.pools for ch in pool.client_channels.values()
|
||||||
}
|
}
|
||||||
assert event_channels == set(device[DEV_SIG_EVT_CHANNELS])
|
assert event_channels == set(device[DEV_SIG_EVT_CHANNELS])
|
||||||
|
# we need to probe the class create entity factory so we need to reset this to get accurate results
|
||||||
|
zha_regs.ZHA_ENTITIES.clean_up()
|
||||||
# build a dict of entity_class -> (component, unique_id, channels) tuple
|
# build a dict of entity_class -> (component, unique_id, channels) tuple
|
||||||
ha_ent_info = {}
|
ha_ent_info = {}
|
||||||
|
created_entity_count = 0
|
||||||
for call in _dispatch.call_args_list:
|
for call in _dispatch.call_args_list:
|
||||||
_, component, entity_cls, unique_id, channels = call[0]
|
_, component, entity_cls, unique_id, channels = call[0]
|
||||||
unique_id_head = UNIQUE_ID_HD.match(unique_id).group(0) # ieee + endpoint_id
|
# the factory can return None. We filter these out to get an accurate created entity count
|
||||||
ha_ent_info[(unique_id_head, entity_cls.__name__)] = (
|
response = entity_cls.create_entity(unique_id, zha_dev, channels)
|
||||||
component,
|
if response:
|
||||||
unique_id,
|
created_entity_count += 1
|
||||||
channels,
|
unique_id_head = UNIQUE_ID_HD.match(unique_id).group(
|
||||||
)
|
0
|
||||||
|
) # ieee + endpoint_id
|
||||||
|
ha_ent_info[(unique_id_head, entity_cls.__name__)] = (
|
||||||
|
component,
|
||||||
|
unique_id,
|
||||||
|
channels,
|
||||||
|
)
|
||||||
|
|
||||||
for comp_id, ent_info in device[DEV_SIG_ENT_MAP].items():
|
for comp_id, ent_info in device[DEV_SIG_ENT_MAP].items():
|
||||||
component, unique_id = comp_id
|
component, unique_id = comp_id
|
||||||
|
@ -156,7 +164,7 @@ async def test_devices(
|
||||||
assert unique_id.startswith(ha_unique_id)
|
assert unique_id.startswith(ha_unique_id)
|
||||||
assert {ch.name for ch in ha_channels} == set(ent_info[DEV_SIG_CHANNELS])
|
assert {ch.name for ch in ha_channels} == set(ent_info[DEV_SIG_CHANNELS])
|
||||||
|
|
||||||
assert _dispatch.call_count == len(device[DEV_SIG_ENT_MAP])
|
assert created_entity_count == len(device[DEV_SIG_ENT_MAP])
|
||||||
|
|
||||||
entity_ids = hass_disable_services.states.async_entity_ids()
|
entity_ids = hass_disable_services.states.async_entity_ids()
|
||||||
await hass_disable_services.async_block_till_done()
|
await hass_disable_services.async_block_till_done()
|
||||||
|
@ -298,7 +306,6 @@ async def test_discover_endpoint(device_info, channels_mock, hass):
|
||||||
assert device_info[DEV_SIG_EVT_CHANNELS] == sorted(
|
assert device_info[DEV_SIG_EVT_CHANNELS] == sorted(
|
||||||
ch.id for pool in channels.pools for ch in pool.client_channels.values()
|
ch.id for pool in channels.pools for ch in pool.client_channels.values()
|
||||||
)
|
)
|
||||||
assert new_ent.call_count == len(list(device_info[DEV_SIG_ENT_MAP].values()))
|
|
||||||
|
|
||||||
# build a dict of entity_class -> (component, unique_id, channels) tuple
|
# build a dict of entity_class -> (component, unique_id, channels) tuple
|
||||||
ha_ent_info = {}
|
ha_ent_info = {}
|
||||||
|
@ -326,8 +333,6 @@ async def test_discover_endpoint(device_info, channels_mock, hass):
|
||||||
assert unique_id.startswith(ha_unique_id)
|
assert unique_id.startswith(ha_unique_id)
|
||||||
assert {ch.name for ch in ha_channels} == set(ent_info[DEV_SIG_CHANNELS])
|
assert {ch.name for ch in ha_channels} == set(ent_info[DEV_SIG_CHANNELS])
|
||||||
|
|
||||||
assert new_ent.call_count == len(device_info[DEV_SIG_ENT_MAP])
|
|
||||||
|
|
||||||
|
|
||||||
def _ch_mock(cluster):
|
def _ch_mock(cluster):
|
||||||
"""Return mock of a channel with a cluster."""
|
"""Return mock of a channel with a cluster."""
|
||||||
|
|
File diff suppressed because it is too large
Load diff
Loading…
Add table
Reference in a new issue