Implement support for start_up_on_off in ZHA (#70110)

* Implement support for start_up_on_off

fix discovery issues

remove cover change

* add tests
This commit is contained in:
David F. Mulcahey 2022-04-24 12:50:06 -04:00 committed by GitHub
parent 8a73381b56
commit 9b8d217b0e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 273 additions and 19 deletions

View file

@ -291,6 +291,9 @@ class OnOffChannel(ZigbeeChannel):
ON_OFF = 0
REPORT_CONFIG = ({"attr": "on_off", "config": REPORT_CONFIG_IMMEDIATE},)
ZCL_INIT_ATTRS = {
"start_up_on_off": True,
}
def __init__(
self, cluster: zha_typing.ZigpyClusterType, ch_pool: zha_typing.ChannelPoolType

View file

@ -78,6 +78,7 @@ class ProbeEndpoint:
self.discover_by_device_type(channel_pool)
self.discover_multi_entities(channel_pool)
self.discover_by_cluster_id(channel_pool)
self.discover_multi_entities(channel_pool, config_diagnostic_entities=True)
zha_regs.ZHA_ENTITIES.clean_up()
@callback
@ -177,16 +178,27 @@ class ProbeEndpoint:
@staticmethod
@callback
def discover_multi_entities(channel_pool: ChannelPool) -> None:
def discover_multi_entities(
channel_pool: ChannelPool,
config_diagnostic_entities: bool = False,
) -> None:
"""Process an endpoint on and discover multiple entities."""
ep_profile_id = channel_pool.endpoint.profile_id
ep_device_type = channel_pool.endpoint.device_type
cmpt_by_dev_type = zha_regs.DEVICE_CLASS[ep_profile_id].get(ep_device_type)
remaining_channels = channel_pool.unclaimed_channels()
if config_diagnostic_entities:
matches, claimed = zha_regs.ZHA_ENTITIES.get_config_diagnostic_entity(
channel_pool.manufacturer,
channel_pool.model,
list(channel_pool.all_channels.values()),
)
else:
matches, claimed = zha_regs.ZHA_ENTITIES.get_multi_entity(
channel_pool.manufacturer, channel_pool.model, remaining_channels
channel_pool.manufacturer,
channel_pool.model,
channel_pool.unclaimed_channels(),
)
channel_pool.claim_channels(claimed)

View file

@ -232,6 +232,11 @@ class ZHAEntityRegistry:
] = collections.defaultdict(
lambda: collections.defaultdict(lambda: collections.defaultdict(list))
)
self._config_diagnostic_entity_registry: dict[
str, dict[int | str | None, dict[MatchRule, list[CALLABLE_T]]]
] = collections.defaultdict(
lambda: collections.defaultdict(lambda: collections.defaultdict(list))
)
self._group_registry: dict[str, CALLABLE_T] = {}
self.single_device_matches: dict[
Platform, dict[EUI64, list[str]]
@ -278,6 +283,33 @@ class ZHAEntityRegistry:
return result, list(all_claimed)
def get_config_diagnostic_entity(
self,
manufacturer: str,
model: str,
channels: list[ChannelType],
) -> tuple[dict[str, list[EntityClassAndChannels]], list[ChannelType]]:
"""Match ZHA Channels to potentially multiple ZHA Entity classes."""
result: dict[str, list[EntityClassAndChannels]] = collections.defaultdict(list)
all_claimed: set[ChannelType] = set()
for (
component,
stop_match_groups,
) in self._config_diagnostic_entity_registry.items():
for stop_match_grp, matches in stop_match_groups.items():
sorted_matches = sorted(matches, key=lambda x: x.weight, reverse=True)
for match in sorted_matches:
if match.strict_matched(manufacturer, model, channels):
claimed = match.claim_channels(channels)
for ent_class in stop_match_groups[stop_match_grp][match]:
ent_n_channels = EntityClassAndChannels(ent_class, claimed)
result[component].append(ent_n_channels)
all_claimed |= set(claimed)
if stop_match_grp:
break
return result, list(all_claimed)
def get_group_entity(self, component: str) -> CALLABLE_T:
"""Match a ZHA group to a ZHA Entity class."""
return self._group_registry.get(component)
@ -340,6 +372,39 @@ class ZHAEntityRegistry:
return decorator
def config_diagnostic_match(
self,
component: str,
channel_names: set[str] | str = None,
generic_ids: set[str] | str = None,
manufacturers: Callable | set[str] | str = None,
models: Callable | set[str] | str = None,
aux_channels: Callable | set[str] | str = None,
stop_on_match_group: int | str | None = None,
) -> Callable[[CALLABLE_T], CALLABLE_T]:
"""Decorate a loose match rule."""
rule = MatchRule(
channel_names,
generic_ids,
manufacturers,
models,
aux_channels,
)
def decorator(zha_entity: CALLABLE_T) -> CALLABLE_T:
"""Register a loose match rule.
All non empty fields of a match rule must match.
"""
# group the rules by channels
self._config_diagnostic_entity_registry[component][stop_on_match_group][
rule
].append(zha_entity)
return zha_entity
return decorator
def group_match(self, component: str) -> Callable[[CALLABLE_T], CALLABLE_T]:
"""Decorate a group match rule."""

View file

@ -4,6 +4,7 @@ from __future__ import annotations
from enum import Enum
import functools
from zigpy.zcl.clusters.general import OnOff
from zigpy.zcl.clusters.security import IasWd
from homeassistant.components.select import SelectEntity
@ -15,12 +16,20 @@ from homeassistant.helpers.entity import EntityCategory
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .core import discovery
from .core.const import CHANNEL_IAS_WD, DATA_ZHA, SIGNAL_ADD_ENTITIES, Strobe
from .core.const import (
CHANNEL_IAS_WD,
CHANNEL_ON_OFF,
DATA_ZHA,
SIGNAL_ADD_ENTITIES,
Strobe,
)
from .core.registries import ZHA_ENTITIES
from .core.typing import ChannelType, ZhaDeviceType
from .entity import ZhaEntity
MULTI_MATCH = functools.partial(ZHA_ENTITIES.multipass_match, Platform.SELECT)
CONFIG_DIAGNOSTIC_MATCH = functools.partial(
ZHA_ENTITIES.config_diagnostic_match, Platform.SELECT
)
async def async_setup_entry(
@ -100,7 +109,7 @@ class ZHANonZCLSelectEntity(ZHAEnumSelectEntity):
return True
@MULTI_MATCH(channel_names=CHANNEL_IAS_WD)
@CONFIG_DIAGNOSTIC_MATCH(channel_names=CHANNEL_IAS_WD)
class ZHADefaultToneSelectEntity(
ZHANonZCLSelectEntity, id_suffix=IasWd.Warning.WarningMode.__name__
):
@ -109,7 +118,7 @@ class ZHADefaultToneSelectEntity(
_enum: Enum = IasWd.Warning.WarningMode
@MULTI_MATCH(channel_names=CHANNEL_IAS_WD)
@CONFIG_DIAGNOSTIC_MATCH(channel_names=CHANNEL_IAS_WD)
class ZHADefaultSirenLevelSelectEntity(
ZHANonZCLSelectEntity, id_suffix=IasWd.Warning.SirenLevel.__name__
):
@ -118,7 +127,7 @@ class ZHADefaultSirenLevelSelectEntity(
_enum: Enum = IasWd.Warning.SirenLevel
@MULTI_MATCH(channel_names=CHANNEL_IAS_WD)
@CONFIG_DIAGNOSTIC_MATCH(channel_names=CHANNEL_IAS_WD)
class ZHADefaultStrobeLevelSelectEntity(
ZHANonZCLSelectEntity, id_suffix=IasWd.StrobeLevel.__name__
):
@ -127,8 +136,72 @@ class ZHADefaultStrobeLevelSelectEntity(
_enum: Enum = IasWd.StrobeLevel
@MULTI_MATCH(channel_names=CHANNEL_IAS_WD)
@CONFIG_DIAGNOSTIC_MATCH(channel_names=CHANNEL_IAS_WD)
class ZHADefaultStrobeSelectEntity(ZHANonZCLSelectEntity, id_suffix=Strobe.__name__):
"""Representation of a ZHA default siren strobe select entity."""
_enum: Enum = Strobe
class ZCLEnumSelectEntity(ZhaEntity, SelectEntity):
"""Representation of a ZHA ZCL enum select entity."""
_select_attr: str
_attr_entity_category = EntityCategory.CONFIG
_enum: Enum
@classmethod
def create_entity(
cls,
unique_id: str,
zha_device: ZhaDeviceType,
channels: list[ChannelType],
**kwargs,
) -> ZhaEntity | None:
"""Entity Factory.
Return entity if it is a supported configuration, otherwise return None
"""
channel = channels[0]
if cls._select_attr in channel.cluster.unsupported_attributes:
return None
return cls(unique_id, zha_device, channels, **kwargs)
def __init__(
self,
unique_id: str,
zha_device: ZhaDeviceType,
channels: list[ChannelType],
**kwargs,
) -> None:
"""Init this select entity."""
self._attr_options = [entry.name.replace("_", " ") for entry in self._enum]
self._channel: ChannelType = channels[0]
super().__init__(unique_id, zha_device, channels, **kwargs)
@property
def current_option(self) -> str | None:
"""Return the selected entity option to represent the entity state."""
option = self._channel.cluster.get(self._select_attr)
if option is None:
return None
option = self._enum(option)
return option.name.replace("_", " ")
async def async_select_option(self, option: str | int) -> None:
"""Change the selected option."""
await self._channel.cluster.write_attributes(
{self._select_attr: self._enum[option.replace(" ", "_")]}
)
self.async_write_ha_state()
@CONFIG_DIAGNOSTIC_MATCH(channel_names=CHANNEL_ON_OFF)
class ZHAStartupOnOffSelectEntity(
ZCLEnumSelectEntity, id_suffix=OnOff.StartUpOnOff.__name__
):
"""Representation of a ZHA startup onoff select entity."""
_select_attr = "start_up_on_off"
_enum: Enum = OnOff.StartUpOnOff

View file

@ -44,6 +44,16 @@ from .zha_devices_list import (
NO_TAIL_ID = re.compile("_\\d$")
UNIQUE_ID_HD = re.compile(r"^(([\da-fA-F]{2}:){7}[\da-fA-F]{2}-\d{1,3})", re.X)
IGNORE_SUFFIXES = [zigpy.zcl.clusters.general.OnOff.StartUpOnOff.__name__]
def contains_ignored_suffix(unique_id: str) -> bool:
"""Return true if the unique_id ends with an ignored suffix."""
for suffix in IGNORE_SUFFIXES:
if suffix.lower() in unique_id.lower():
return True
return False
@pytest.fixture
def channels_mock(zha_device_mock):
@ -142,7 +152,7 @@ async def test_devices(
_, component, entity_cls, unique_id, channels = call[0]
# the factory can return None. We filter these out to get an accurate created entity count
response = entity_cls.create_entity(unique_id, zha_dev, channels)
if response:
if response and not contains_ignored_suffix(response.name):
created_entity_count += 1
unique_id_head = UNIQUE_ID_HD.match(unique_id).group(
0
@ -178,7 +188,9 @@ async def test_devices(
await hass_disable_services.async_block_till_done()
zha_entity_ids = {
ent for ent in entity_ids if ent.split(".")[0] in zha_const.PLATFORMS
ent
for ent in entity_ids
if not contains_ignored_suffix(ent) and ent.split(".")[0] in zha_const.PLATFORMS
}
assert zha_entity_ids == {
e[DEV_SIG_ENT_MAP_ID] for e in device[DEV_SIG_ENT_MAP].values()
@ -319,7 +331,10 @@ async def test_discover_endpoint(device_info, channels_mock, hass):
ha_ent_info = {}
for call in new_ent.call_args_list:
component, entity_cls, unique_id, channels = call[0]
unique_id_head = UNIQUE_ID_HD.match(unique_id).group(0) # ieee + endpoint_id
if not contains_ignored_suffix(unique_id):
unique_id_head = UNIQUE_ID_HD.match(unique_id).group(
0
) # ieee + endpoint_id
ha_ent_info[(unique_id_head, entity_cls.__name__)] = (
component,
unique_id,

View file

@ -33,6 +33,28 @@ async def siren(hass, zigpy_device_mock, zha_device_joined_restored):
return zha_device, zigpy_device.endpoints[1].ias_wd
@pytest.fixture
async def light(hass, zigpy_device_mock):
"""Siren fixture."""
zigpy_device = zigpy_device_mock(
{
1: {
SIG_EP_PROFILE: zha.PROFILE_ID,
SIG_EP_TYPE: zha.DeviceType.ON_OFF_LIGHT,
SIG_EP_INPUT: [
general.Basic.cluster_id,
general.Identify.cluster_id,
general.OnOff.cluster_id,
],
SIG_EP_OUTPUT: [general.Ota.cluster_id],
}
},
)
return zigpy_device
@pytest.fixture
def core_rs(hass_storage):
"""Core.restore_state fixture."""
@ -149,3 +171,67 @@ async def test_select_restore_state(
state = hass.states.get(entity_id)
assert state
assert state.state == security.IasWd.Warning.WarningMode.Burglar.name
async def test_on_off_select(hass, light, zha_device_joined_restored):
"""Test zha on off select."""
entity_registry = er.async_get(hass)
on_off_cluster = light.endpoints[1].on_off
on_off_cluster.PLUGGED_ATTR_READS = {
"start_up_on_off": general.OnOff.StartUpOnOff.On
}
zha_device = await zha_device_joined_restored(light)
select_name = general.OnOff.StartUpOnOff.__name__
entity_id = await find_entity_id(
Platform.SELECT,
zha_device,
hass,
qualifier=select_name.lower(),
)
assert entity_id is not None
state = hass.states.get(entity_id)
assert state
assert state.state == STATE_UNKNOWN
assert state.attributes["options"] == ["Off", "On", "Toggle", "PreviousValue"]
entity_entry = entity_registry.async_get(entity_id)
assert entity_entry
assert entity_entry.entity_category == ENTITY_CATEGORY_CONFIG
# Test select option with string value
await hass.services.async_call(
"select",
"select_option",
{
"entity_id": entity_id,
"option": general.OnOff.StartUpOnOff.Off.name,
},
blocking=True,
)
assert on_off_cluster.write_attributes.call_count == 1
assert on_off_cluster.write_attributes.call_args[0][0] == {
"start_up_on_off": general.OnOff.StartUpOnOff.Off
}
state = hass.states.get(entity_id)
assert state
assert state.state == general.OnOff.StartUpOnOff.Off.name
async def test_on_off_select_unsupported(hass, light, zha_device_joined_restored):
"""Test zha on off select unsupported."""
on_off_cluster = light.endpoints[1].on_off
on_off_cluster.add_unsupported_attribute("start_up_on_off")
zha_device = await zha_device_joined_restored(light)
select_name = general.OnOff.StartUpOnOff.__name__
entity_id = await find_entity_id(
Platform.SELECT,
zha_device,
hass,
qualifier=select_name.lower(),
)
assert entity_id is None