Refactor ZHA entity matching process (#60063)

* Group multi-matches by channels

* Group multi-matched by explicit groups

* Registryless AnalogInput and PowerConfiguration

* Refactor single cluster sensor registry

* Refactor single cluster cover and lock registry

* Refactor single cluster binary_sensor registry

* Pylint
This commit is contained in:
Alexei Chetroi 2021-12-12 07:52:49 -05:00 committed by GitHub
parent 359affb856
commit 997809c6c4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 80 additions and 94 deletions

View file

@ -36,6 +36,7 @@ CLASS_MAPPING = {
}
STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, Platform.BINARY_SENSOR)
MULTI_MATCH = functools.partial(ZHA_ENTITIES.multipass_match, Platform.BINARY_SENSOR)
async def async_setup_entry(
@ -103,7 +104,7 @@ class BinarySensor(ZhaEntity, BinarySensorEntity):
self._state = attr_value
@STRICT_MATCH(channel_names=CHANNEL_ACCELEROMETER)
@MULTI_MATCH(channel_names=CHANNEL_ACCELEROMETER)
class Accelerometer(BinarySensor):
"""ZHA BinarySensor."""
@ -111,7 +112,7 @@ class Accelerometer(BinarySensor):
_attr_device_class: BinarySensorDeviceClass = BinarySensorDeviceClass.MOVING
@STRICT_MATCH(channel_names=CHANNEL_OCCUPANCY)
@MULTI_MATCH(channel_names=CHANNEL_OCCUPANCY)
class Occupancy(BinarySensor):
"""ZHA BinarySensor."""
@ -127,7 +128,7 @@ class Opening(BinarySensor):
_attr_device_class: BinarySensorDeviceClass = BinarySensorDeviceClass.OPENING
@STRICT_MATCH(channel_names=CHANNEL_BINARY_INPUT)
@MULTI_MATCH(channel_names=CHANNEL_BINARY_INPUT)
class BinaryInput(BinarySensor):
"""ZHA BinarySensor."""
@ -153,7 +154,7 @@ class Motion(BinarySensor):
_attr_device_class: BinarySensorDeviceClass = BinarySensorDeviceClass.MOTION
@STRICT_MATCH(channel_names=CHANNEL_ZONE)
@MULTI_MATCH(channel_names=CHANNEL_ZONE)
class IASZone(BinarySensor):
"""ZHA IAS BinarySensor."""

View file

@ -172,7 +172,11 @@ async def async_setup_entry(
config_entry.async_on_unload(unsub)
@MULTI_MATCH(channel_names=CHANNEL_THERMOSTAT, aux_channels=CHANNEL_FAN)
@MULTI_MATCH(
channel_names=CHANNEL_THERMOSTAT,
aux_channels=CHANNEL_FAN,
stop_on_match_group=CHANNEL_THERMOSTAT,
)
class Thermostat(ZhaEntity, ClimateEntity):
"""Representation of a ZHA Thermostat device."""
@ -526,7 +530,7 @@ class Thermostat(ZhaEntity, ClimateEntity):
@MULTI_MATCH(
channel_names={CHANNEL_THERMOSTAT, "sinope_manufacturer_specific"},
manufacturers="Sinope Technologies",
stop_on_match=True,
stop_on_match_group=CHANNEL_THERMOSTAT,
)
class SinopeTechnologiesThermostat(Thermostat):
"""Sinope Technologies Thermostat."""
@ -579,7 +583,7 @@ class SinopeTechnologiesThermostat(Thermostat):
channel_names=CHANNEL_THERMOSTAT,
aux_channels=CHANNEL_FAN,
manufacturers="Zen Within",
stop_on_match=True,
stop_on_match_group=CHANNEL_THERMOSTAT,
)
class ZenWithinThermostat(Thermostat):
"""Zen Within Thermostat implementation."""
@ -609,7 +613,7 @@ class ZenWithinThermostat(Thermostat):
aux_channels=CHANNEL_FAN,
manufacturers="Centralite",
models={"3157100", "3157100-E"},
stop_on_match=True,
stop_on_match_group=CHANNEL_THERMOSTAT,
)
class CentralitePearl(ZenWithinThermostat):
"""Centralite Pearl Thermostat implementation."""

View file

@ -4,7 +4,6 @@ from __future__ import annotations
import collections
from collections.abc import Callable
import dataclasses
from typing import Dict, List
import attr
from zigpy import zcl
@ -53,29 +52,10 @@ REMOTE_DEVICE_TYPES = collections.defaultdict(list, REMOTE_DEVICE_TYPES)
SINGLE_INPUT_CLUSTER_DEVICE_CLASS = {
# this works for now but if we hit conflicts we can break it out to
# a different dict that is keyed by manufacturer
SMARTTHINGS_ACCELERATION_CLUSTER: Platform.BINARY_SENSOR,
SMARTTHINGS_HUMIDITY_CLUSTER: Platform.SENSOR,
VOC_LEVEL_CLUSTER: Platform.SENSOR,
zcl.clusters.closures.DoorLock.cluster_id: Platform.LOCK,
zcl.clusters.closures.WindowCovering.cluster_id: Platform.COVER,
zcl.clusters.general.BinaryInput.cluster_id: Platform.BINARY_SENSOR,
zcl.clusters.general.AnalogInput.cluster_id: Platform.SENSOR,
zcl.clusters.general.AnalogOutput.cluster_id: Platform.NUMBER,
zcl.clusters.general.MultistateInput.cluster_id: Platform.SENSOR,
zcl.clusters.general.OnOff.cluster_id: Platform.SWITCH,
zcl.clusters.general.PowerConfiguration.cluster_id: Platform.SENSOR,
zcl.clusters.hvac.Fan.cluster_id: Platform.FAN,
zcl.clusters.measurement.CarbonDioxideConcentration.cluster_id: Platform.SENSOR,
zcl.clusters.measurement.CarbonMonoxideConcentration.cluster_id: Platform.SENSOR,
zcl.clusters.measurement.FormaldehydeConcentration.cluster_id: Platform.SENSOR,
zcl.clusters.measurement.IlluminanceMeasurement.cluster_id: Platform.SENSOR,
zcl.clusters.measurement.OccupancySensing.cluster_id: Platform.BINARY_SENSOR,
zcl.clusters.measurement.PressureMeasurement.cluster_id: Platform.SENSOR,
zcl.clusters.measurement.RelativeHumidity.cluster_id: Platform.SENSOR,
zcl.clusters.measurement.SoilMoisture.cluster_id: Platform.SENSOR,
zcl.clusters.measurement.LeafWetness.cluster_id: Platform.SENSOR,
zcl.clusters.measurement.TemperatureMeasurement.cluster_id: Platform.SENSOR,
zcl.clusters.security.IasZone.cluster_id: Platform.BINARY_SENSOR,
}
SINGLE_OUTPUT_CLUSTER_DEVICE_CLASS = {
@ -136,12 +116,10 @@ def set_or_callable(value):
class MatchRule:
"""Match a ZHA Entity to a channel name or generic id."""
channel_names: Callable | set[str] | str = attr.ib(
factory=frozenset, converter=set_or_callable
)
generic_ids: Callable | set[str] | str = attr.ib(
channel_names: set[str] | str = attr.ib(
factory=frozenset, converter=set_or_callable
)
generic_ids: set[str] | str = attr.ib(factory=frozenset, converter=set_or_callable)
manufacturers: Callable | set[str] | str = attr.ib(
factory=frozenset, converter=set_or_callable
)
@ -151,8 +129,6 @@ class MatchRule:
aux_channels: Callable | set[str] | str = attr.ib(
factory=frozenset, converter=set_or_callable
)
# for multi entities, stop further processing on a match for a component
stop_on_match: bool = attr.ib(default=False)
@property
def weight(self) -> int:
@ -238,21 +214,20 @@ class EntityClassAndChannels:
claimed_channel: list[ChannelType]
RegistryDictType = Dict[str, Dict[MatchRule, CALLABLE_T]]
MultiRegistryDictType = Dict[str, Dict[MatchRule, List[CALLABLE_T]]]
GroupRegistryDictType = Dict[str, CALLABLE_T]
class ZHAEntityRegistry:
"""Channel to ZHA Entity mapping."""
def __init__(self):
"""Initialize Registry instance."""
self._strict_registry: RegistryDictType = collections.defaultdict(dict)
self._multi_entity_registry: MultiRegistryDictType = collections.defaultdict(
lambda: collections.defaultdict(list)
self._strict_registry: dict[
str, dict[MatchRule, CALLABLE_T]
] = collections.defaultdict(dict)
self._multi_entity_registry: dict[
str, dict[int | str | None, dict[MatchRule, list[CALLABLE_T]]]
] = collections.defaultdict(
lambda: collections.defaultdict(lambda: collections.defaultdict(list))
)
self._group_registry: GroupRegistryDictType = {}
self._group_registry: dict[str, CALLABLE_T] = {}
def get_entity(
self,
@ -276,22 +251,21 @@ class ZHAEntityRegistry:
manufacturer: str,
model: str,
channels: list[ChannelType],
components: set | None = None,
) -> tuple[dict[str, list[EntityClassAndChannels]], list[ChannelType]]:
"""Match ZHA Channels to potentially multiple ZHA Entity classes."""
result: dict[str, list[EntityClassAndChannels]] = collections.defaultdict(list)
all_claimed: set[ChannelType] = set()
for component in components or self._multi_entity_registry:
matches = self._multi_entity_registry[component]
for component, stop_match_groups in self._multi_entity_registry.items():
for stop_match_grp, matches in stop_match_groups.items():
sorted_matches = sorted(matches, key=lambda x: x.weight, reverse=True)
for match in sorted_matches:
if match.strict_matched(manufacturer, model, channels):
claimed = match.claim_channels(channels)
for ent_class in self._multi_entity_registry[component][match]:
for ent_class in stop_match_groups[stop_match_grp][match]:
ent_n_channels = EntityClassAndChannels(ent_class, claimed)
result[component].append(ent_n_channels)
all_claimed |= set(claimed)
if match.stop_on_match:
if stop_match_grp:
break
return result, list(all_claimed)
@ -303,8 +277,8 @@ class ZHAEntityRegistry:
def strict_match(
self,
component: str,
channel_names: Callable | set[str] | str = None,
generic_ids: Callable | set[str] | str = None,
channel_names: set[str] | str = None,
generic_ids: set[str] | str = None,
manufacturers: Callable | set[str] | str = None,
models: Callable | set[str] | str = None,
aux_channels: Callable | set[str] | str = None,
@ -328,12 +302,12 @@ class ZHAEntityRegistry:
def multipass_match(
self,
component: str,
channel_names: Callable | set[str] | str = None,
generic_ids: Callable | set[str] | str = None,
channel_names: set[str] | str = None,
generic_ids: set[str] | str = None,
manufacturers: Callable | set[str] | str = None,
models: Callable | set[str] | str = None,
aux_channels: Callable | set[str] | str = None,
stop_on_match: bool = False,
stop_on_match_group: int | str | None = None,
) -> Callable[[CALLABLE_T], CALLABLE_T]:
"""Decorate a loose match rule."""
@ -343,7 +317,6 @@ class ZHAEntityRegistry:
manufacturers,
models,
aux_channels,
stop_on_match,
)
def decorator(zha_entity: CALLABLE_T) -> CALLABLE_T:
@ -351,7 +324,10 @@ class ZHAEntityRegistry:
All non empty fields of a match rule must match.
"""
self._multi_entity_registry[component][rule].append(zha_entity)
# group the rules by channels
self._multi_entity_registry[component][stop_on_match_group][rule].append(
zha_entity
)
return zha_entity
return decorator

View file

@ -42,7 +42,7 @@ from .entity import ZhaEntity
_LOGGER = logging.getLogger(__name__)
STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, Platform.COVER)
MULTI_MATCH = functools.partial(ZHA_ENTITIES.multipass_match, Platform.COVER)
async def async_setup_entry(
@ -63,7 +63,7 @@ async def async_setup_entry(
config_entry.async_on_unload(unsub)
@STRICT_MATCH(channel_names=CHANNEL_COVER)
@MULTI_MATCH(channel_names=CHANNEL_COVER)
class ZhaCover(ZhaEntity, CoverEntity):
"""Representation of a ZHA cover."""
@ -182,7 +182,7 @@ class ZhaCover(ZhaEntity, CoverEntity):
self._state = None
@STRICT_MATCH(channel_names={CHANNEL_LEVEL, CHANNEL_ON_OFF, CHANNEL_SHADE})
@MULTI_MATCH(channel_names={CHANNEL_LEVEL, CHANNEL_ON_OFF, CHANNEL_SHADE})
class Shade(ZhaEntity, CoverEntity):
"""ZHA Shade."""
@ -289,7 +289,7 @@ class Shade(ZhaEntity, CoverEntity):
return
@STRICT_MATCH(
@MULTI_MATCH(
channel_names={CHANNEL_LEVEL, CHANNEL_ON_OFF}, manufacturers="Keen Home Inc"
)
class KeenVent(Shade):

View file

@ -23,7 +23,7 @@ from .entity import ZhaEntity
# The first state is Zigbee 'Not fully locked'
STATE_LIST = [STATE_UNLOCKED, STATE_LOCKED, STATE_UNLOCKED]
STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, Platform.LOCK)
MULTI_MATCH = functools.partial(ZHA_ENTITIES.multipass_match, Platform.LOCK)
VALUE_TO_STATE = dict(enumerate(STATE_LIST))
@ -86,7 +86,7 @@ async def async_setup_entry(
)
@STRICT_MATCH(channel_names=CHANNEL_DOORLOCK)
@MULTI_MATCH(channel_names=CHANNEL_DOORLOCK)
class ZhaDoorLock(ZhaEntity, LockEntity):
"""Representation of a ZHA lock."""

View file

@ -186,19 +186,24 @@ class Sensor(ZhaEntity, SensorEntity):
return round(float(value * self._multiplier) / self._divisor)
@STRICT_MATCH(
@MULTI_MATCH(
channel_names=CHANNEL_ANALOG_INPUT,
manufacturers="LUMI",
models={"lumi.plug", "lumi.plug.maus01", "lumi.plug.mmeu01"},
stop_on_match_group=CHANNEL_ANALOG_INPUT,
)
@MULTI_MATCH(
channel_names=CHANNEL_ANALOG_INPUT,
manufacturers="Digi",
stop_on_match_group=CHANNEL_ANALOG_INPUT,
)
@STRICT_MATCH(channel_names=CHANNEL_ANALOG_INPUT, manufacturers="Digi")
class AnalogInput(Sensor):
"""Sensor that displays analog input values."""
SENSOR_ATTR = "present_value"
@STRICT_MATCH(channel_names=CHANNEL_POWER_CONFIGURATION)
@MULTI_MATCH(channel_names=CHANNEL_POWER_CONFIGURATION)
class Battery(Sensor):
"""Battery sensor of power configuration cluster."""
@ -339,8 +344,10 @@ class ElectricalMeasurementRMSVoltage(ElectricalMeasurement, id_suffix="rms_volt
return False
@STRICT_MATCH(generic_ids=CHANNEL_ST_HUMIDITY_CLUSTER)
@STRICT_MATCH(channel_names=CHANNEL_HUMIDITY)
@MULTI_MATCH(
generic_ids=CHANNEL_ST_HUMIDITY_CLUSTER, stop_on_match_group=CHANNEL_HUMIDITY
)
@MULTI_MATCH(channel_names=CHANNEL_HUMIDITY, stop_on_match_group=CHANNEL_HUMIDITY)
class Humidity(Sensor):
"""Humidity sensor."""
@ -351,7 +358,7 @@ class Humidity(Sensor):
_unit = PERCENTAGE
@STRICT_MATCH(channel_names=CHANNEL_SOIL_MOISTURE)
@MULTI_MATCH(channel_names=CHANNEL_SOIL_MOISTURE)
class SoilMoisture(Sensor):
"""Soil Moisture sensor."""
@ -362,7 +369,7 @@ class SoilMoisture(Sensor):
_unit = PERCENTAGE
@STRICT_MATCH(channel_names=CHANNEL_LEAF_WETNESS)
@MULTI_MATCH(channel_names=CHANNEL_LEAF_WETNESS)
class LeafWetness(Sensor):
"""Leaf Wetness sensor."""
@ -373,7 +380,7 @@ class LeafWetness(Sensor):
_unit = PERCENTAGE
@STRICT_MATCH(channel_names=CHANNEL_ILLUMINANCE)
@MULTI_MATCH(channel_names=CHANNEL_ILLUMINANCE)
class Illuminance(Sensor):
"""Illuminance Sensor."""
@ -465,7 +472,7 @@ class SmartEnergySummation(SmartEnergyMetering, id_suffix="summation_delivered")
return round(cooked, 3)
@STRICT_MATCH(channel_names=CHANNEL_PRESSURE)
@MULTI_MATCH(channel_names=CHANNEL_PRESSURE)
class Pressure(Sensor):
"""Pressure sensor."""
@ -476,7 +483,7 @@ class Pressure(Sensor):
_unit = PRESSURE_HPA
@STRICT_MATCH(channel_names=CHANNEL_TEMPERATURE)
@MULTI_MATCH(channel_names=CHANNEL_TEMPERATURE)
class Temperature(Sensor):
"""Temperature Sensor."""
@ -487,7 +494,7 @@ class Temperature(Sensor):
_unit = TEMP_CELSIUS
@STRICT_MATCH(channel_names="carbon_dioxide_concentration")
@MULTI_MATCH(channel_names="carbon_dioxide_concentration")
class CarbonDioxideConcentration(Sensor):
"""Carbon Dioxide Concentration sensor."""
@ -499,7 +506,7 @@ class CarbonDioxideConcentration(Sensor):
_unit = CONCENTRATION_PARTS_PER_MILLION
@STRICT_MATCH(channel_names="carbon_monoxide_concentration")
@MULTI_MATCH(channel_names="carbon_monoxide_concentration")
class CarbonMonoxideConcentration(Sensor):
"""Carbon Monoxide Concentration sensor."""
@ -511,8 +518,8 @@ class CarbonMonoxideConcentration(Sensor):
_unit = CONCENTRATION_PARTS_PER_MILLION
@STRICT_MATCH(generic_ids="channel_0x042e")
@STRICT_MATCH(channel_names="voc_level")
@MULTI_MATCH(generic_ids="channel_0x042e", stop_on_match_group="voc_level")
@MULTI_MATCH(channel_names="voc_level", stop_on_match_group="voc_level")
class VOCLevel(Sensor):
"""VOC Level sensor."""
@ -524,7 +531,11 @@ class VOCLevel(Sensor):
_unit = CONCENTRATION_MICROGRAMS_PER_CUBIC_METER
@STRICT_MATCH(channel_names="voc_level", models="lumi.airmonitor.acn01")
@MULTI_MATCH(
channel_names="voc_level",
models="lumi.airmonitor.acn01",
stop_on_match_group="voc_level",
)
class PPBVOCLevel(Sensor):
"""VOC Level sensor."""
@ -536,7 +547,7 @@ class PPBVOCLevel(Sensor):
_unit = CONCENTRATION_PARTS_PER_BILLION
@STRICT_MATCH(channel_names="formaldehyde_concentration")
@MULTI_MATCH(channel_names="formaldehyde_concentration")
class FormaldehydeConcentration(Sensor):
"""Formaldehyde Concentration sensor."""
@ -547,7 +558,7 @@ class FormaldehydeConcentration(Sensor):
_unit = CONCENTRATION_PARTS_PER_MILLION
@MULTI_MATCH(channel_names=CHANNEL_THERMOSTAT)
@MULTI_MATCH(channel_names=CHANNEL_THERMOSTAT, stop_on_match_group=CHANNEL_THERMOSTAT)
class ThermostatHVACAction(Sensor, id_suffix="hvac_action"):
"""Thermostat HVAC action sensor."""
@ -626,12 +637,12 @@ class ThermostatHVACAction(Sensor, id_suffix="hvac_action"):
aux_channels=CHANNEL_FAN,
manufacturers="Centralite",
models={"3157100", "3157100-E"},
stop_on_match=True,
stop_on_match_group=CHANNEL_THERMOSTAT,
)
@MULTI_MATCH(
channel_names=CHANNEL_THERMOSTAT,
manufacturers="Zen Within",
stop_on_match=True,
stop_on_match_group=CHANNEL_THERMOSTAT,
)
class ZenHVACAction(ThermostatHVACAction):
"""Zen Within Thermostat HVAC Action."""

View file

@ -367,7 +367,6 @@ def _test_single_input_cluster_device_class(probe_mock):
cover_ch,
multistate_ch,
ias_ch,
analog_ch,
]
disc.ProbeEndpoint().discover_by_cluster_id(ch_pool)
@ -385,11 +384,6 @@ def _test_single_input_cluster_device_class(probe_mock):
assert call[0][1] == ch
def test_single_input_cluster_device_class():
"""Test SINGLE_INPUT_CLUSTER_DEVICE_CLASS matching by cluster id or class."""
_test_single_input_cluster_device_class()
def test_single_input_cluster_device_class_by_cluster_class():
"""Test SINGLE_INPUT_CLUSTER_DEVICE_CLASS matching by cluster id or class."""
mock_reg = {