diff --git a/homeassistant/components/zha/binary_sensor.py b/homeassistant/components/zha/binary_sensor.py index 7e2b34d3613..4bc7c156f37 100644 --- a/homeassistant/components/zha/binary_sensor.py +++ b/homeassistant/components/zha/binary_sensor.py @@ -36,6 +36,7 @@ CLASS_MAPPING = { } STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, Platform.BINARY_SENSOR) +MULTI_MATCH = functools.partial(ZHA_ENTITIES.multipass_match, Platform.BINARY_SENSOR) async def async_setup_entry( @@ -103,7 +104,7 @@ class BinarySensor(ZhaEntity, BinarySensorEntity): self._state = attr_value -@STRICT_MATCH(channel_names=CHANNEL_ACCELEROMETER) +@MULTI_MATCH(channel_names=CHANNEL_ACCELEROMETER) class Accelerometer(BinarySensor): """ZHA BinarySensor.""" @@ -111,7 +112,7 @@ class Accelerometer(BinarySensor): _attr_device_class: BinarySensorDeviceClass = BinarySensorDeviceClass.MOVING -@STRICT_MATCH(channel_names=CHANNEL_OCCUPANCY) +@MULTI_MATCH(channel_names=CHANNEL_OCCUPANCY) class Occupancy(BinarySensor): """ZHA BinarySensor.""" @@ -127,7 +128,7 @@ class Opening(BinarySensor): _attr_device_class: BinarySensorDeviceClass = BinarySensorDeviceClass.OPENING -@STRICT_MATCH(channel_names=CHANNEL_BINARY_INPUT) +@MULTI_MATCH(channel_names=CHANNEL_BINARY_INPUT) class BinaryInput(BinarySensor): """ZHA BinarySensor.""" @@ -153,7 +154,7 @@ class Motion(BinarySensor): _attr_device_class: BinarySensorDeviceClass = BinarySensorDeviceClass.MOTION -@STRICT_MATCH(channel_names=CHANNEL_ZONE) +@MULTI_MATCH(channel_names=CHANNEL_ZONE) class IASZone(BinarySensor): """ZHA IAS BinarySensor.""" diff --git a/homeassistant/components/zha/climate.py b/homeassistant/components/zha/climate.py index 9c01f6630db..3473a6d8f9e 100644 --- a/homeassistant/components/zha/climate.py +++ b/homeassistant/components/zha/climate.py @@ -172,7 +172,11 @@ async def async_setup_entry( config_entry.async_on_unload(unsub) -@MULTI_MATCH(channel_names=CHANNEL_THERMOSTAT, aux_channels=CHANNEL_FAN) +@MULTI_MATCH( + channel_names=CHANNEL_THERMOSTAT, + aux_channels=CHANNEL_FAN, + stop_on_match_group=CHANNEL_THERMOSTAT, +) class Thermostat(ZhaEntity, ClimateEntity): """Representation of a ZHA Thermostat device.""" @@ -526,7 +530,7 @@ class Thermostat(ZhaEntity, ClimateEntity): @MULTI_MATCH( channel_names={CHANNEL_THERMOSTAT, "sinope_manufacturer_specific"}, manufacturers="Sinope Technologies", - stop_on_match=True, + stop_on_match_group=CHANNEL_THERMOSTAT, ) class SinopeTechnologiesThermostat(Thermostat): """Sinope Technologies Thermostat.""" @@ -579,7 +583,7 @@ class SinopeTechnologiesThermostat(Thermostat): channel_names=CHANNEL_THERMOSTAT, aux_channels=CHANNEL_FAN, manufacturers="Zen Within", - stop_on_match=True, + stop_on_match_group=CHANNEL_THERMOSTAT, ) class ZenWithinThermostat(Thermostat): """Zen Within Thermostat implementation.""" @@ -609,7 +613,7 @@ class ZenWithinThermostat(Thermostat): aux_channels=CHANNEL_FAN, manufacturers="Centralite", models={"3157100", "3157100-E"}, - stop_on_match=True, + stop_on_match_group=CHANNEL_THERMOSTAT, ) class CentralitePearl(ZenWithinThermostat): """Centralite Pearl Thermostat implementation.""" diff --git a/homeassistant/components/zha/core/registries.py b/homeassistant/components/zha/core/registries.py index 36a77b841b6..146b7d43f0f 100644 --- a/homeassistant/components/zha/core/registries.py +++ b/homeassistant/components/zha/core/registries.py @@ -4,7 +4,6 @@ from __future__ import annotations import collections from collections.abc import Callable import dataclasses -from typing import Dict, List import attr from zigpy import zcl @@ -53,29 +52,10 @@ REMOTE_DEVICE_TYPES = collections.defaultdict(list, REMOTE_DEVICE_TYPES) SINGLE_INPUT_CLUSTER_DEVICE_CLASS = { # this works for now but if we hit conflicts we can break it out to # a different dict that is keyed by manufacturer - SMARTTHINGS_ACCELERATION_CLUSTER: Platform.BINARY_SENSOR, - SMARTTHINGS_HUMIDITY_CLUSTER: Platform.SENSOR, - VOC_LEVEL_CLUSTER: Platform.SENSOR, - zcl.clusters.closures.DoorLock.cluster_id: Platform.LOCK, - zcl.clusters.closures.WindowCovering.cluster_id: Platform.COVER, - zcl.clusters.general.BinaryInput.cluster_id: Platform.BINARY_SENSOR, - zcl.clusters.general.AnalogInput.cluster_id: Platform.SENSOR, zcl.clusters.general.AnalogOutput.cluster_id: Platform.NUMBER, zcl.clusters.general.MultistateInput.cluster_id: Platform.SENSOR, zcl.clusters.general.OnOff.cluster_id: Platform.SWITCH, - zcl.clusters.general.PowerConfiguration.cluster_id: Platform.SENSOR, zcl.clusters.hvac.Fan.cluster_id: Platform.FAN, - zcl.clusters.measurement.CarbonDioxideConcentration.cluster_id: Platform.SENSOR, - zcl.clusters.measurement.CarbonMonoxideConcentration.cluster_id: Platform.SENSOR, - zcl.clusters.measurement.FormaldehydeConcentration.cluster_id: Platform.SENSOR, - zcl.clusters.measurement.IlluminanceMeasurement.cluster_id: Platform.SENSOR, - zcl.clusters.measurement.OccupancySensing.cluster_id: Platform.BINARY_SENSOR, - zcl.clusters.measurement.PressureMeasurement.cluster_id: Platform.SENSOR, - zcl.clusters.measurement.RelativeHumidity.cluster_id: Platform.SENSOR, - zcl.clusters.measurement.SoilMoisture.cluster_id: Platform.SENSOR, - zcl.clusters.measurement.LeafWetness.cluster_id: Platform.SENSOR, - zcl.clusters.measurement.TemperatureMeasurement.cluster_id: Platform.SENSOR, - zcl.clusters.security.IasZone.cluster_id: Platform.BINARY_SENSOR, } SINGLE_OUTPUT_CLUSTER_DEVICE_CLASS = { @@ -136,12 +116,10 @@ def set_or_callable(value): class MatchRule: """Match a ZHA Entity to a channel name or generic id.""" - channel_names: Callable | set[str] | str = attr.ib( - factory=frozenset, converter=set_or_callable - ) - generic_ids: Callable | set[str] | str = attr.ib( + channel_names: set[str] | str = attr.ib( factory=frozenset, converter=set_or_callable ) + generic_ids: set[str] | str = attr.ib(factory=frozenset, converter=set_or_callable) manufacturers: Callable | set[str] | str = attr.ib( factory=frozenset, converter=set_or_callable ) @@ -151,8 +129,6 @@ class MatchRule: aux_channels: Callable | set[str] | str = attr.ib( factory=frozenset, converter=set_or_callable ) - # for multi entities, stop further processing on a match for a component - stop_on_match: bool = attr.ib(default=False) @property def weight(self) -> int: @@ -238,21 +214,20 @@ class EntityClassAndChannels: claimed_channel: list[ChannelType] -RegistryDictType = Dict[str, Dict[MatchRule, CALLABLE_T]] -MultiRegistryDictType = Dict[str, Dict[MatchRule, List[CALLABLE_T]]] -GroupRegistryDictType = Dict[str, CALLABLE_T] - - class ZHAEntityRegistry: """Channel to ZHA Entity mapping.""" def __init__(self): """Initialize Registry instance.""" - self._strict_registry: RegistryDictType = collections.defaultdict(dict) - self._multi_entity_registry: MultiRegistryDictType = collections.defaultdict( - lambda: collections.defaultdict(list) + self._strict_registry: dict[ + str, dict[MatchRule, CALLABLE_T] + ] = collections.defaultdict(dict) + self._multi_entity_registry: dict[ + str, dict[int | str | None, dict[MatchRule, list[CALLABLE_T]]] + ] = collections.defaultdict( + lambda: collections.defaultdict(lambda: collections.defaultdict(list)) ) - self._group_registry: GroupRegistryDictType = {} + self._group_registry: dict[str, CALLABLE_T] = {} def get_entity( self, @@ -276,23 +251,22 @@ class ZHAEntityRegistry: manufacturer: str, model: str, channels: list[ChannelType], - components: set | None = None, ) -> tuple[dict[str, list[EntityClassAndChannels]], list[ChannelType]]: """Match ZHA Channels to potentially multiple ZHA Entity classes.""" result: dict[str, list[EntityClassAndChannels]] = collections.defaultdict(list) all_claimed: set[ChannelType] = set() - for component in components or self._multi_entity_registry: - matches = self._multi_entity_registry[component] - sorted_matches = sorted(matches, key=lambda x: x.weight, reverse=True) - for match in sorted_matches: - if match.strict_matched(manufacturer, model, channels): - claimed = match.claim_channels(channels) - for ent_class in self._multi_entity_registry[component][match]: - ent_n_channels = EntityClassAndChannels(ent_class, claimed) - result[component].append(ent_n_channels) - all_claimed |= set(claimed) - if match.stop_on_match: - break + for component, stop_match_groups in self._multi_entity_registry.items(): + for stop_match_grp, matches in stop_match_groups.items(): + sorted_matches = sorted(matches, key=lambda x: x.weight, reverse=True) + for match in sorted_matches: + if match.strict_matched(manufacturer, model, channels): + claimed = match.claim_channels(channels) + for ent_class in stop_match_groups[stop_match_grp][match]: + ent_n_channels = EntityClassAndChannels(ent_class, claimed) + result[component].append(ent_n_channels) + all_claimed |= set(claimed) + if stop_match_grp: + break return result, list(all_claimed) @@ -303,8 +277,8 @@ class ZHAEntityRegistry: def strict_match( self, component: str, - channel_names: Callable | set[str] | str = None, - generic_ids: Callable | set[str] | str = None, + channel_names: set[str] | str = None, + generic_ids: set[str] | str = None, manufacturers: Callable | set[str] | str = None, models: Callable | set[str] | str = None, aux_channels: Callable | set[str] | str = None, @@ -328,12 +302,12 @@ class ZHAEntityRegistry: def multipass_match( self, component: str, - channel_names: Callable | set[str] | str = None, - generic_ids: Callable | set[str] | str = None, + channel_names: set[str] | str = None, + generic_ids: set[str] | str = None, manufacturers: Callable | set[str] | str = None, models: Callable | set[str] | str = None, aux_channels: Callable | set[str] | str = None, - stop_on_match: bool = False, + stop_on_match_group: int | str | None = None, ) -> Callable[[CALLABLE_T], CALLABLE_T]: """Decorate a loose match rule.""" @@ -343,7 +317,6 @@ class ZHAEntityRegistry: manufacturers, models, aux_channels, - stop_on_match, ) def decorator(zha_entity: CALLABLE_T) -> CALLABLE_T: @@ -351,7 +324,10 @@ class ZHAEntityRegistry: All non empty fields of a match rule must match. """ - self._multi_entity_registry[component][rule].append(zha_entity) + # group the rules by channels + self._multi_entity_registry[component][stop_on_match_group][rule].append( + zha_entity + ) return zha_entity return decorator diff --git a/homeassistant/components/zha/cover.py b/homeassistant/components/zha/cover.py index e5ff1f6e95d..2e8c7ab45ea 100644 --- a/homeassistant/components/zha/cover.py +++ b/homeassistant/components/zha/cover.py @@ -42,7 +42,7 @@ from .entity import ZhaEntity _LOGGER = logging.getLogger(__name__) -STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, Platform.COVER) +MULTI_MATCH = functools.partial(ZHA_ENTITIES.multipass_match, Platform.COVER) async def async_setup_entry( @@ -63,7 +63,7 @@ async def async_setup_entry( config_entry.async_on_unload(unsub) -@STRICT_MATCH(channel_names=CHANNEL_COVER) +@MULTI_MATCH(channel_names=CHANNEL_COVER) class ZhaCover(ZhaEntity, CoverEntity): """Representation of a ZHA cover.""" @@ -182,7 +182,7 @@ class ZhaCover(ZhaEntity, CoverEntity): self._state = None -@STRICT_MATCH(channel_names={CHANNEL_LEVEL, CHANNEL_ON_OFF, CHANNEL_SHADE}) +@MULTI_MATCH(channel_names={CHANNEL_LEVEL, CHANNEL_ON_OFF, CHANNEL_SHADE}) class Shade(ZhaEntity, CoverEntity): """ZHA Shade.""" @@ -289,7 +289,7 @@ class Shade(ZhaEntity, CoverEntity): return -@STRICT_MATCH( +@MULTI_MATCH( channel_names={CHANNEL_LEVEL, CHANNEL_ON_OFF}, manufacturers="Keen Home Inc" ) class KeenVent(Shade): diff --git a/homeassistant/components/zha/lock.py b/homeassistant/components/zha/lock.py index 61104a8b7e9..4eb9752f355 100644 --- a/homeassistant/components/zha/lock.py +++ b/homeassistant/components/zha/lock.py @@ -23,7 +23,7 @@ from .entity import ZhaEntity # The first state is Zigbee 'Not fully locked' STATE_LIST = [STATE_UNLOCKED, STATE_LOCKED, STATE_UNLOCKED] -STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, Platform.LOCK) +MULTI_MATCH = functools.partial(ZHA_ENTITIES.multipass_match, Platform.LOCK) VALUE_TO_STATE = dict(enumerate(STATE_LIST)) @@ -86,7 +86,7 @@ async def async_setup_entry( ) -@STRICT_MATCH(channel_names=CHANNEL_DOORLOCK) +@MULTI_MATCH(channel_names=CHANNEL_DOORLOCK) class ZhaDoorLock(ZhaEntity, LockEntity): """Representation of a ZHA lock.""" diff --git a/homeassistant/components/zha/sensor.py b/homeassistant/components/zha/sensor.py index eed85417b67..406bcbfa7aa 100644 --- a/homeassistant/components/zha/sensor.py +++ b/homeassistant/components/zha/sensor.py @@ -186,19 +186,24 @@ class Sensor(ZhaEntity, SensorEntity): return round(float(value * self._multiplier) / self._divisor) -@STRICT_MATCH( +@MULTI_MATCH( channel_names=CHANNEL_ANALOG_INPUT, manufacturers="LUMI", models={"lumi.plug", "lumi.plug.maus01", "lumi.plug.mmeu01"}, + stop_on_match_group=CHANNEL_ANALOG_INPUT, +) +@MULTI_MATCH( + channel_names=CHANNEL_ANALOG_INPUT, + manufacturers="Digi", + stop_on_match_group=CHANNEL_ANALOG_INPUT, ) -@STRICT_MATCH(channel_names=CHANNEL_ANALOG_INPUT, manufacturers="Digi") class AnalogInput(Sensor): """Sensor that displays analog input values.""" SENSOR_ATTR = "present_value" -@STRICT_MATCH(channel_names=CHANNEL_POWER_CONFIGURATION) +@MULTI_MATCH(channel_names=CHANNEL_POWER_CONFIGURATION) class Battery(Sensor): """Battery sensor of power configuration cluster.""" @@ -339,8 +344,10 @@ class ElectricalMeasurementRMSVoltage(ElectricalMeasurement, id_suffix="rms_volt return False -@STRICT_MATCH(generic_ids=CHANNEL_ST_HUMIDITY_CLUSTER) -@STRICT_MATCH(channel_names=CHANNEL_HUMIDITY) +@MULTI_MATCH( + generic_ids=CHANNEL_ST_HUMIDITY_CLUSTER, stop_on_match_group=CHANNEL_HUMIDITY +) +@MULTI_MATCH(channel_names=CHANNEL_HUMIDITY, stop_on_match_group=CHANNEL_HUMIDITY) class Humidity(Sensor): """Humidity sensor.""" @@ -351,7 +358,7 @@ class Humidity(Sensor): _unit = PERCENTAGE -@STRICT_MATCH(channel_names=CHANNEL_SOIL_MOISTURE) +@MULTI_MATCH(channel_names=CHANNEL_SOIL_MOISTURE) class SoilMoisture(Sensor): """Soil Moisture sensor.""" @@ -362,7 +369,7 @@ class SoilMoisture(Sensor): _unit = PERCENTAGE -@STRICT_MATCH(channel_names=CHANNEL_LEAF_WETNESS) +@MULTI_MATCH(channel_names=CHANNEL_LEAF_WETNESS) class LeafWetness(Sensor): """Leaf Wetness sensor.""" @@ -373,7 +380,7 @@ class LeafWetness(Sensor): _unit = PERCENTAGE -@STRICT_MATCH(channel_names=CHANNEL_ILLUMINANCE) +@MULTI_MATCH(channel_names=CHANNEL_ILLUMINANCE) class Illuminance(Sensor): """Illuminance Sensor.""" @@ -465,7 +472,7 @@ class SmartEnergySummation(SmartEnergyMetering, id_suffix="summation_delivered") return round(cooked, 3) -@STRICT_MATCH(channel_names=CHANNEL_PRESSURE) +@MULTI_MATCH(channel_names=CHANNEL_PRESSURE) class Pressure(Sensor): """Pressure sensor.""" @@ -476,7 +483,7 @@ class Pressure(Sensor): _unit = PRESSURE_HPA -@STRICT_MATCH(channel_names=CHANNEL_TEMPERATURE) +@MULTI_MATCH(channel_names=CHANNEL_TEMPERATURE) class Temperature(Sensor): """Temperature Sensor.""" @@ -487,7 +494,7 @@ class Temperature(Sensor): _unit = TEMP_CELSIUS -@STRICT_MATCH(channel_names="carbon_dioxide_concentration") +@MULTI_MATCH(channel_names="carbon_dioxide_concentration") class CarbonDioxideConcentration(Sensor): """Carbon Dioxide Concentration sensor.""" @@ -499,7 +506,7 @@ class CarbonDioxideConcentration(Sensor): _unit = CONCENTRATION_PARTS_PER_MILLION -@STRICT_MATCH(channel_names="carbon_monoxide_concentration") +@MULTI_MATCH(channel_names="carbon_monoxide_concentration") class CarbonMonoxideConcentration(Sensor): """Carbon Monoxide Concentration sensor.""" @@ -511,8 +518,8 @@ class CarbonMonoxideConcentration(Sensor): _unit = CONCENTRATION_PARTS_PER_MILLION -@STRICT_MATCH(generic_ids="channel_0x042e") -@STRICT_MATCH(channel_names="voc_level") +@MULTI_MATCH(generic_ids="channel_0x042e", stop_on_match_group="voc_level") +@MULTI_MATCH(channel_names="voc_level", stop_on_match_group="voc_level") class VOCLevel(Sensor): """VOC Level sensor.""" @@ -524,7 +531,11 @@ class VOCLevel(Sensor): _unit = CONCENTRATION_MICROGRAMS_PER_CUBIC_METER -@STRICT_MATCH(channel_names="voc_level", models="lumi.airmonitor.acn01") +@MULTI_MATCH( + channel_names="voc_level", + models="lumi.airmonitor.acn01", + stop_on_match_group="voc_level", +) class PPBVOCLevel(Sensor): """VOC Level sensor.""" @@ -536,7 +547,7 @@ class PPBVOCLevel(Sensor): _unit = CONCENTRATION_PARTS_PER_BILLION -@STRICT_MATCH(channel_names="formaldehyde_concentration") +@MULTI_MATCH(channel_names="formaldehyde_concentration") class FormaldehydeConcentration(Sensor): """Formaldehyde Concentration sensor.""" @@ -547,7 +558,7 @@ class FormaldehydeConcentration(Sensor): _unit = CONCENTRATION_PARTS_PER_MILLION -@MULTI_MATCH(channel_names=CHANNEL_THERMOSTAT) +@MULTI_MATCH(channel_names=CHANNEL_THERMOSTAT, stop_on_match_group=CHANNEL_THERMOSTAT) class ThermostatHVACAction(Sensor, id_suffix="hvac_action"): """Thermostat HVAC action sensor.""" @@ -626,12 +637,12 @@ class ThermostatHVACAction(Sensor, id_suffix="hvac_action"): aux_channels=CHANNEL_FAN, manufacturers="Centralite", models={"3157100", "3157100-E"}, - stop_on_match=True, + stop_on_match_group=CHANNEL_THERMOSTAT, ) @MULTI_MATCH( channel_names=CHANNEL_THERMOSTAT, manufacturers="Zen Within", - stop_on_match=True, + stop_on_match_group=CHANNEL_THERMOSTAT, ) class ZenHVACAction(ThermostatHVACAction): """Zen Within Thermostat HVAC Action.""" diff --git a/tests/components/zha/test_discover.py b/tests/components/zha/test_discover.py index 6b34590f4ad..9480f4b1e65 100644 --- a/tests/components/zha/test_discover.py +++ b/tests/components/zha/test_discover.py @@ -367,7 +367,6 @@ def _test_single_input_cluster_device_class(probe_mock): cover_ch, multistate_ch, ias_ch, - analog_ch, ] disc.ProbeEndpoint().discover_by_cluster_id(ch_pool) @@ -385,11 +384,6 @@ def _test_single_input_cluster_device_class(probe_mock): assert call[0][1] == ch -def test_single_input_cluster_device_class(): - """Test SINGLE_INPUT_CLUSTER_DEVICE_CLASS matching by cluster id or class.""" - _test_single_input_cluster_device_class() - - def test_single_input_cluster_device_class_by_cluster_class(): """Test SINGLE_INPUT_CLUSTER_DEVICE_CLASS matching by cluster id or class.""" mock_reg = {