Add identify buttons to ZHA devices (#61495)
* Identify buttons * clean up and add test * use Platform * update device list * Only 1 identify button per device * cleanup press until the need arises for the branch * make imports relative
This commit is contained in:
parent
fa6d6d914b
commit
41531b528e
7 changed files with 1858 additions and 1146 deletions
106
homeassistant/components/zha/button.py
Normal file
106
homeassistant/components/zha/button.py
Normal file
|
@ -0,0 +1,106 @@
|
|||
"""Support for ZHA button."""
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import functools
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.components.button import ButtonDeviceClass, ButtonEntity
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import ENTITY_CATEGORY_DIAGNOSTIC, Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from .core import discovery
|
||||
from .core.const import CHANNEL_IDENTIFY, DATA_ZHA, SIGNAL_ADD_ENTITIES
|
||||
from .core.registries import ZHA_ENTITIES
|
||||
from .core.typing import ChannelType, ZhaDeviceType
|
||||
from .entity import ZhaEntity
|
||||
|
||||
MULTI_MATCH = functools.partial(ZHA_ENTITIES.multipass_match, Platform.BUTTON)
|
||||
DEFAULT_DURATION = 5 # seconds
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up the Zigbee Home Automation button from config entry."""
|
||||
entities_to_create = hass.data[DATA_ZHA][Platform.BUTTON]
|
||||
|
||||
unsub = async_dispatcher_connect(
|
||||
hass,
|
||||
SIGNAL_ADD_ENTITIES,
|
||||
functools.partial(
|
||||
discovery.async_add_entities,
|
||||
async_add_entities,
|
||||
entities_to_create,
|
||||
update_before_add=False,
|
||||
),
|
||||
)
|
||||
config_entry.async_on_unload(unsub)
|
||||
|
||||
|
||||
class ZHAButton(ZhaEntity, ButtonEntity):
|
||||
"""Defines a ZHA button."""
|
||||
|
||||
_command_name: str = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
unique_id: str,
|
||||
zha_device: ZhaDeviceType,
|
||||
channels: list[ChannelType],
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Init this button."""
|
||||
super().__init__(unique_id, zha_device, channels, **kwargs)
|
||||
self._channel: ChannelType = channels[0]
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_args(self) -> list[Any]:
|
||||
"""Return the arguments to use in the command."""
|
||||
|
||||
async def async_press(self) -> None:
|
||||
"""Send out a update command."""
|
||||
command = getattr(self._channel, self._command_name)
|
||||
arguments = self.get_args()
|
||||
await command(*arguments)
|
||||
|
||||
|
||||
@MULTI_MATCH(channel_names=CHANNEL_IDENTIFY)
|
||||
class ZHAIdentifyButton(ZHAButton):
|
||||
"""Defines a ZHA identify button."""
|
||||
|
||||
@classmethod
|
||||
def create_entity(
|
||||
cls,
|
||||
unique_id: str,
|
||||
zha_device: ZhaDeviceType,
|
||||
channels: list[ChannelType],
|
||||
**kwargs,
|
||||
) -> ZhaEntity | None:
|
||||
"""Entity Factory.
|
||||
|
||||
Return entity if it is a supported configuration, otherwise return None
|
||||
"""
|
||||
platform_restrictions = ZHA_ENTITIES.single_device_matches[Platform.BUTTON]
|
||||
device_restrictions = platform_restrictions[zha_device.ieee]
|
||||
if CHANNEL_IDENTIFY in device_restrictions:
|
||||
return None
|
||||
device_restrictions.append(CHANNEL_IDENTIFY)
|
||||
return cls(unique_id, zha_device, channels, **kwargs)
|
||||
|
||||
_attr_device_class: ButtonDeviceClass = ButtonDeviceClass.UPDATE
|
||||
_attr_entity_category = ENTITY_CATEGORY_DIAGNOSTIC
|
||||
_command_name = "identify"
|
||||
|
||||
def get_args(self) -> list[Any]:
|
||||
"""Return the arguments to use in the command."""
|
||||
|
||||
return [DEFAULT_DURATION]
|
|
@ -102,6 +102,7 @@ CLUSTER_TYPE_OUT = "out"
|
|||
PLATFORMS = (
|
||||
Platform.ALARM_CONTROL_PANEL,
|
||||
Platform.BINARY_SENSOR,
|
||||
Platform.BUTTON,
|
||||
Platform.CLIMATE,
|
||||
Platform.COVER,
|
||||
Platform.DEVICE_TRACKER,
|
||||
|
|
|
@ -17,6 +17,7 @@ from . import const as zha_const, registries as zha_regs, typing as zha_typing
|
|||
from .. import ( # noqa: F401 pylint: disable=unused-import,
|
||||
alarm_control_panel,
|
||||
binary_sensor,
|
||||
button,
|
||||
climate,
|
||||
cover,
|
||||
device_tracker,
|
||||
|
@ -66,6 +67,7 @@ class ProbeEndpoint:
|
|||
self.discover_by_device_type(channel_pool)
|
||||
self.discover_multi_entities(channel_pool)
|
||||
self.discover_by_cluster_id(channel_pool)
|
||||
zha_regs.ZHA_ENTITIES.clean_up()
|
||||
|
||||
@callback
|
||||
def discover_by_device_type(self, channel_pool: zha_typing.ChannelPoolType) -> None:
|
||||
|
|
|
@ -9,6 +9,7 @@ import attr
|
|||
from zigpy import zcl
|
||||
import zigpy.profiles.zha
|
||||
import zigpy.profiles.zll
|
||||
from zigpy.types.named import EUI64
|
||||
|
||||
from homeassistant.const import Platform
|
||||
|
||||
|
@ -228,6 +229,9 @@ class ZHAEntityRegistry:
|
|||
lambda: collections.defaultdict(lambda: collections.defaultdict(list))
|
||||
)
|
||||
self._group_registry: dict[str, CALLABLE_T] = {}
|
||||
self.single_device_matches: dict[
|
||||
Platform, dict[EUI64, list[str]]
|
||||
] = collections.defaultdict(lambda: collections.defaultdict(list))
|
||||
|
||||
def get_entity(
|
||||
self,
|
||||
|
@ -342,5 +346,11 @@ class ZHAEntityRegistry:
|
|||
|
||||
return decorator
|
||||
|
||||
def clean_up(self) -> None:
|
||||
"""Clean up post discovery."""
|
||||
self.single_device_matches: dict[
|
||||
Platform, dict[EUI64, list[str]]
|
||||
] = collections.defaultdict(lambda: collections.defaultdict(list))
|
||||
|
||||
|
||||
ZHA_ENTITIES = ZHAEntityRegistry()
|
||||
|
|
89
tests/components/zha/test_button.py
Normal file
89
tests/components/zha/test_button.py
Normal file
|
@ -0,0 +1,89 @@
|
|||
"""Test ZHA button."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from freezegun import freeze_time
|
||||
import pytest
|
||||
from zigpy.const import SIG_EP_PROFILE
|
||||
import zigpy.profiles.zha as zha
|
||||
import zigpy.zcl.clusters.general as general
|
||||
import zigpy.zcl.clusters.security as security
|
||||
import zigpy.zcl.foundation as zcl_f
|
||||
|
||||
from homeassistant.components.button import DOMAIN, ButtonDeviceClass
|
||||
from homeassistant.components.button.const import SERVICE_PRESS
|
||||
from homeassistant.const import (
|
||||
ATTR_DEVICE_CLASS,
|
||||
ATTR_ENTITY_ID,
|
||||
ENTITY_CATEGORY_DIAGNOSTIC,
|
||||
STATE_UNKNOWN,
|
||||
)
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
|
||||
from .common import find_entity_id
|
||||
from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_TYPE
|
||||
|
||||
from tests.common import mock_coro
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def contact_sensor(hass, zigpy_device_mock, zha_device_joined_restored):
|
||||
"""Contact sensor fixture."""
|
||||
|
||||
zigpy_device = zigpy_device_mock(
|
||||
{
|
||||
1: {
|
||||
SIG_EP_INPUT: [
|
||||
general.Basic.cluster_id,
|
||||
general.Identify.cluster_id,
|
||||
security.IasZone.cluster_id,
|
||||
],
|
||||
SIG_EP_OUTPUT: [],
|
||||
SIG_EP_TYPE: zha.DeviceType.IAS_ZONE,
|
||||
SIG_EP_PROFILE: zha.PROFILE_ID,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
zha_device = await zha_device_joined_restored(zigpy_device)
|
||||
return zha_device, zigpy_device.endpoints[1].identify
|
||||
|
||||
|
||||
@freeze_time("2021-11-04 17:37:00", tz_offset=-1)
|
||||
async def test_button(hass, contact_sensor):
|
||||
"""Test zha button platform."""
|
||||
|
||||
entity_registry = er.async_get(hass)
|
||||
zha_device, cluster = contact_sensor
|
||||
assert cluster is not None
|
||||
entity_id = await find_entity_id(DOMAIN, zha_device, hass)
|
||||
assert entity_id is not None
|
||||
|
||||
state = hass.states.get(entity_id)
|
||||
assert state
|
||||
assert state.state == STATE_UNKNOWN
|
||||
assert state.attributes[ATTR_DEVICE_CLASS] == ButtonDeviceClass.UPDATE
|
||||
|
||||
entry = entity_registry.async_get(entity_id)
|
||||
assert entry
|
||||
assert entry.entity_category == ENTITY_CATEGORY_DIAGNOSTIC
|
||||
|
||||
with patch(
|
||||
"zigpy.zcl.Cluster.request",
|
||||
return_value=mock_coro([0x00, zcl_f.Status.SUCCESS]),
|
||||
):
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_PRESS,
|
||||
{ATTR_ENTITY_ID: entity_id},
|
||||
blocking=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert len(cluster.request.mock_calls) == 1
|
||||
assert cluster.request.call_args[0][0] is False
|
||||
assert cluster.request.call_args[0][1] == 0
|
||||
assert cluster.request.call_args[0][3] == 5 # duration in seconds
|
||||
|
||||
state = hass.states.get(entity_id)
|
||||
assert state
|
||||
assert state.state == "2021-11-04T16:37:00+00:00"
|
||||
assert state.attributes[ATTR_DEVICE_CLASS] == ButtonDeviceClass.UPDATE
|
|
@ -125,17 +125,25 @@ async def test_devices(
|
|||
ch.id for pool in zha_dev.channels.pools for ch in pool.client_channels.values()
|
||||
}
|
||||
assert event_channels == set(device[DEV_SIG_EVT_CHANNELS])
|
||||
|
||||
# we need to probe the class create entity factory so we need to reset this to get accurate results
|
||||
zha_regs.ZHA_ENTITIES.clean_up()
|
||||
# build a dict of entity_class -> (component, unique_id, channels) tuple
|
||||
ha_ent_info = {}
|
||||
created_entity_count = 0
|
||||
for call in _dispatch.call_args_list:
|
||||
_, component, entity_cls, unique_id, channels = call[0]
|
||||
unique_id_head = UNIQUE_ID_HD.match(unique_id).group(0) # ieee + endpoint_id
|
||||
ha_ent_info[(unique_id_head, entity_cls.__name__)] = (
|
||||
component,
|
||||
unique_id,
|
||||
channels,
|
||||
)
|
||||
# the factory can return None. We filter these out to get an accurate created entity count
|
||||
response = entity_cls.create_entity(unique_id, zha_dev, channels)
|
||||
if response:
|
||||
created_entity_count += 1
|
||||
unique_id_head = UNIQUE_ID_HD.match(unique_id).group(
|
||||
0
|
||||
) # ieee + endpoint_id
|
||||
ha_ent_info[(unique_id_head, entity_cls.__name__)] = (
|
||||
component,
|
||||
unique_id,
|
||||
channels,
|
||||
)
|
||||
|
||||
for comp_id, ent_info in device[DEV_SIG_ENT_MAP].items():
|
||||
component, unique_id = comp_id
|
||||
|
@ -156,7 +164,7 @@ async def test_devices(
|
|||
assert unique_id.startswith(ha_unique_id)
|
||||
assert {ch.name for ch in ha_channels} == set(ent_info[DEV_SIG_CHANNELS])
|
||||
|
||||
assert _dispatch.call_count == len(device[DEV_SIG_ENT_MAP])
|
||||
assert created_entity_count == len(device[DEV_SIG_ENT_MAP])
|
||||
|
||||
entity_ids = hass_disable_services.states.async_entity_ids()
|
||||
await hass_disable_services.async_block_till_done()
|
||||
|
@ -298,7 +306,6 @@ async def test_discover_endpoint(device_info, channels_mock, hass):
|
|||
assert device_info[DEV_SIG_EVT_CHANNELS] == sorted(
|
||||
ch.id for pool in channels.pools for ch in pool.client_channels.values()
|
||||
)
|
||||
assert new_ent.call_count == len(list(device_info[DEV_SIG_ENT_MAP].values()))
|
||||
|
||||
# build a dict of entity_class -> (component, unique_id, channels) tuple
|
||||
ha_ent_info = {}
|
||||
|
@ -326,8 +333,6 @@ async def test_discover_endpoint(device_info, channels_mock, hass):
|
|||
assert unique_id.startswith(ha_unique_id)
|
||||
assert {ch.name for ch in ha_channels} == set(ent_info[DEV_SIG_CHANNELS])
|
||||
|
||||
assert new_ent.call_count == len(device_info[DEV_SIG_ENT_MAP])
|
||||
|
||||
|
||||
def _ch_mock(cluster):
|
||||
"""Return mock of a channel with a cluster."""
|
||||
|
|
File diff suppressed because it is too large
Load diff
Loading…
Add table
Reference in a new issue