diff --git a/homeassistant/components/zwave_js/api.py b/homeassistant/components/zwave_js/api.py index 4e68cc2e2dd..0e947de982b 100644 --- a/homeassistant/components/zwave_js/api.py +++ b/homeassistant/components/zwave_js/api.py @@ -60,8 +60,8 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.device_registry import DeviceEntry from homeassistant.helpers.dispatcher import async_dispatcher_connect +from .config_validation import BITMASK_SCHEMA from .const import ( - BITMASK_SCHEMA, CONF_DATA_COLLECTION_OPTED_IN, DATA_CLIENT, DOMAIN, diff --git a/homeassistant/components/zwave_js/config_validation.py b/homeassistant/components/zwave_js/config_validation.py new file mode 100644 index 00000000000..9fc502bdafb --- /dev/null +++ b/homeassistant/components/zwave_js/config_validation.py @@ -0,0 +1,41 @@ +"""Config validation for the Z-Wave JS integration.""" +from typing import Any + +import voluptuous as vol + +import homeassistant.helpers.config_validation as cv + +# Validates that a bitmask is provided in hex form and converts it to decimal +# int equivalent since that's what the library uses +BITMASK_SCHEMA = vol.All( + cv.string, + vol.Lower, + vol.Match( + r"^(0x)?[0-9a-f]+$", + msg="Must provide an integer (e.g. 255) or a bitmask in hex form (e.g. 0xff)", + ), + lambda value: int(value, 16), +) + + +def boolean(value: Any) -> bool: + """Validate and coerce a boolean value.""" + if isinstance(value, bool): + return value + if isinstance(value, str): + value = value.lower().strip() + if value in ("true", "yes", "on", "enable"): + return True + if value in ("false", "no", "off", "disable"): + return False + raise vol.Invalid(f"invalid boolean value {value}") + + +VALUE_SCHEMA = vol.Any( + boolean, + vol.Coerce(int), + vol.Coerce(float), + BITMASK_SCHEMA, + cv.string, + dict, +) diff --git a/homeassistant/components/zwave_js/const.py b/homeassistant/components/zwave_js/const.py index 8f6fada2106..d6d63487b8a 100644 --- a/homeassistant/components/zwave_js/const.py +++ b/homeassistant/components/zwave_js/const.py @@ -1,10 +1,6 @@ """Constants for the Z-Wave JS integration.""" import logging -import voluptuous as vol - -import homeassistant.helpers.config_validation as cv - CONF_ADDON_DEVICE = "device" CONF_ADDON_EMULATE_HARDWARE = "emulate_hardware" CONF_ADDON_LOG_LEVEL = "log_level" @@ -117,26 +113,3 @@ ENTITY_DESC_KEY_TEMPERATURE = "temperature" ENTITY_DESC_KEY_TARGET_TEMPERATURE = "target_temperature" ENTITY_DESC_KEY_MEASUREMENT = "measurement" ENTITY_DESC_KEY_TOTAL_INCREASING = "total_increasing" - -# Schema Constants - -# Validates that a bitmask is provided in hex form and converts it to decimal -# int equivalent since that's what the library uses -BITMASK_SCHEMA = vol.All( - cv.string, - vol.Lower, - vol.Match( - r"^(0x)?[0-9a-f]+$", - msg="Must provide an integer (e.g. 255) or a bitmask in hex form (e.g. 0xff)", - ), - lambda value: int(value, 16), -) - -VALUE_SCHEMA = vol.Any( - bool, - vol.Coerce(int), - vol.Coerce(float), - BITMASK_SCHEMA, - cv.string, - dict, -) diff --git a/homeassistant/components/zwave_js/device_action.py b/homeassistant/components/zwave_js/device_action.py index b81d675e6fd..7e5e8c6c78d 100644 --- a/homeassistant/components/zwave_js/device_action.py +++ b/homeassistant/components/zwave_js/device_action.py @@ -29,6 +29,7 @@ from homeassistant.helpers import entity_registry import homeassistant.helpers.config_validation as cv from homeassistant.helpers.typing import ConfigType +from .config_validation import VALUE_SCHEMA from .const import ( ATTR_COMMAND_CLASS, ATTR_CONFIG_PARAMETER, @@ -48,7 +49,6 @@ from .const import ( SERVICE_SET_CONFIG_PARAMETER, SERVICE_SET_LOCK_USERCODE, SERVICE_SET_VALUE, - VALUE_SCHEMA, ) from .device_automation_helpers import ( CONF_SUBTYPE, diff --git a/homeassistant/components/zwave_js/device_condition.py b/homeassistant/components/zwave_js/device_condition.py index 8bb151199d7..c70371d6f8a 100644 --- a/homeassistant/components/zwave_js/device_condition.py +++ b/homeassistant/components/zwave_js/device_condition.py @@ -16,6 +16,7 @@ from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import condition, config_validation as cv from homeassistant.helpers.typing import ConfigType, TemplateVarsType +from .config_validation import VALUE_SCHEMA from .const import ( ATTR_COMMAND_CLASS, ATTR_ENDPOINT, @@ -23,7 +24,6 @@ from .const import ( ATTR_PROPERTY_KEY, ATTR_VALUE, DOMAIN, - VALUE_SCHEMA, ) from .device_automation_helpers import ( CONF_SUBTYPE, diff --git a/homeassistant/components/zwave_js/device_trigger.py b/homeassistant/components/zwave_js/device_trigger.py index 0615668cccd..89379f9a953 100644 --- a/homeassistant/components/zwave_js/device_trigger.py +++ b/homeassistant/components/zwave_js/device_trigger.py @@ -32,6 +32,7 @@ from homeassistant.helpers import ( from homeassistant.helpers.typing import ConfigType from . import trigger +from .config_validation import VALUE_SCHEMA from .const import ( ATTR_COMMAND_CLASS, ATTR_DATA_TYPE, @@ -80,14 +81,6 @@ CONFIG_PARAMETER_VALUE_UPDATED = f"{VALUE_UPDATED_PLATFORM_TYPE}.config_paramete VALUE_VALUE_UPDATED = f"{VALUE_UPDATED_PLATFORM_TYPE}.value" NODE_STATUS = "state.node_status" -VALUE_SCHEMA = vol.Any( - bool, - vol.Coerce(int), - vol.Coerce(float), - cv.boolean, - cv.string, -) - NOTIFICATION_EVENT_CC_MAPPINGS = ( (ENTRY_CONTROL_NOTIFICATION, CommandClass.ENTRY_CONTROL), diff --git a/homeassistant/components/zwave_js/helpers.py b/homeassistant/components/zwave_js/helpers.py index 05df480a487..a84ddee300f 100644 --- a/homeassistant/components/zwave_js/helpers.py +++ b/homeassistant/components/zwave_js/helpers.py @@ -14,9 +14,16 @@ from zwave_js_server.model.value import ( get_value_id, ) +from homeassistant.components.group import expand_entity_ids from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN from homeassistant.config_entries import ConfigEntry, ConfigEntryState -from homeassistant.const import CONF_TYPE, __version__ as HA_VERSION +from homeassistant.const import ( + ATTR_AREA_ID, + ATTR_DEVICE_ID, + ATTR_ENTITY_ID, + CONF_TYPE, + __version__ as HA_VERSION, +) from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import device_registry as dr, entity_registry as er @@ -30,6 +37,7 @@ from .const import ( CONF_DATA_COLLECTION_OPTED_IN, DATA_CLIENT, DOMAIN, + LOGGER, ) @@ -221,6 +229,40 @@ def async_get_nodes_from_area_id( return nodes +@callback +def async_get_nodes_from_targets( + hass: HomeAssistant, + val: dict[str, Any], + ent_reg: er.EntityRegistry | None = None, + dev_reg: dr.DeviceRegistry | None = None, +) -> set[ZwaveNode]: + """ + Get nodes for all targets. + + Supports entity_id with group expansion, area_id, and device_id. + """ + nodes: set[ZwaveNode] = set() + # Convert all entity IDs to nodes + for entity_id in expand_entity_ids(hass, val.get(ATTR_ENTITY_ID, [])): + try: + nodes.add(async_get_node_from_entity_id(hass, entity_id, ent_reg, dev_reg)) + except ValueError as err: + LOGGER.warning(err.args[0]) + + # Convert all area IDs to nodes + for area_id in val.get(ATTR_AREA_ID, []): + nodes.update(async_get_nodes_from_area_id(hass, area_id, ent_reg, dev_reg)) + + # Convert all device IDs to nodes + for device_id in val.get(ATTR_DEVICE_ID, []): + try: + nodes.add(async_get_node_from_device_id(hass, device_id, dev_reg)) + except ValueError as err: + LOGGER.warning(err.args[0]) + + return nodes + + def get_zwave_value_from_config(node: ZwaveNode, config: ConfigType) -> ZwaveValue: """Get a Z-Wave JS Value from a config.""" endpoint = None diff --git a/homeassistant/components/zwave_js/services.py b/homeassistant/components/zwave_js/services.py index 767516cc17c..2d31bed108f 100644 --- a/homeassistant/components/zwave_js/services.py +++ b/homeassistant/components/zwave_js/services.py @@ -25,11 +25,8 @@ import homeassistant.helpers.config_validation as cv from homeassistant.helpers.dispatcher import async_dispatcher_send from . import const -from .helpers import ( - async_get_node_from_device_id, - async_get_node_from_entity_id, - async_get_nodes_from_area_id, -) +from .config_validation import BITMASK_SCHEMA, VALUE_SCHEMA +from .helpers import async_get_nodes_from_targets _LOGGER = logging.getLogger(__name__) @@ -80,38 +77,16 @@ class ZWaveServices: @callback def get_nodes_from_service_data(val: dict[str, Any]) -> dict[str, Any]: """Get nodes set from service data.""" - nodes: set[ZwaveNode] = set() - # Convert all entity IDs to nodes - for entity_id in expand_entity_ids(self._hass, val.pop(ATTR_ENTITY_ID, [])): - try: - nodes.add( - async_get_node_from_entity_id( - self._hass, entity_id, self._ent_reg, self._dev_reg - ) - ) - except ValueError as err: - const.LOGGER.warning(err.args[0]) + val[const.ATTR_NODES] = async_get_nodes_from_targets( + self._hass, val, self._ent_reg, self._dev_reg + ) + return val - # Convert all area IDs to nodes - for area_id in val.pop(ATTR_AREA_ID, []): - nodes.update( - async_get_nodes_from_area_id( - self._hass, area_id, self._ent_reg, self._dev_reg - ) - ) - - # Convert all device IDs to nodes - for device_id in val.pop(ATTR_DEVICE_ID, []): - try: - nodes.add( - async_get_node_from_device_id( - self._hass, device_id, self._dev_reg - ) - ) - except ValueError as err: - const.LOGGER.warning(err.args[0]) - - val[const.ATTR_NODES] = nodes + @callback + def has_at_least_one_node(val: dict[str, Any]) -> dict[str, Any]: + """Validate that at least one node is specified.""" + if not val.get(const.ATTR_NODES): + raise vol.Invalid(f"No {const.DOMAIN} nodes found for given targets") return val @callback @@ -120,6 +95,9 @@ class ZWaveServices: nodes: set[ZwaveNode] = val[const.ATTR_NODES] broadcast: bool = val[const.ATTR_BROADCAST] + if not broadcast: + has_at_least_one_node(val) + # User must specify a node if they are attempting a broadcast and have more # than one zwave-js network. if ( @@ -150,12 +128,20 @@ class ZWaveServices: def validate_entities(val: dict[str, Any]) -> dict[str, Any]: """Validate entities exist and are from the zwave_js platform.""" val[ATTR_ENTITY_ID] = expand_entity_ids(self._hass, val[ATTR_ENTITY_ID]) + invalid_entities = [] for entity_id in val[ATTR_ENTITY_ID]: entry = self._ent_reg.async_get(entity_id) if entry is None or entry.platform != const.DOMAIN: - raise vol.Invalid( - f"Entity {entity_id} is not a valid {const.DOMAIN} entity." + const.LOGGER.info( + "Entity %s is not a valid %s entity.", entity_id, const.DOMAIN ) + invalid_entities.append(entity_id) + + # Remove invalid entities + val[ATTR_ENTITY_ID] = list(set(val[ATTR_ENTITY_ID]) - set(invalid_entities)) + + if not val[ATTR_ENTITY_ID]: + raise vol.Invalid(f"No {const.DOMAIN} entities found in service call") return val @@ -177,10 +163,10 @@ class ZWaveServices: vol.Coerce(int), cv.string ), vol.Optional(const.ATTR_CONFIG_PARAMETER_BITMASK): vol.Any( - vol.Coerce(int), const.BITMASK_SCHEMA + vol.Coerce(int), BITMASK_SCHEMA ), vol.Required(const.ATTR_CONFIG_VALUE): vol.Any( - vol.Coerce(int), const.BITMASK_SCHEMA, cv.string + vol.Coerce(int), BITMASK_SCHEMA, cv.string ), }, cv.has_at_least_one_key( @@ -188,6 +174,7 @@ class ZWaveServices: ), parameter_name_does_not_need_bitmask, get_nodes_from_service_data, + has_at_least_one_node, ), ), ) @@ -211,10 +198,8 @@ class ZWaveServices: vol.Coerce(int), { vol.Any( - vol.Coerce(int), const.BITMASK_SCHEMA, cv.string - ): vol.Any( - vol.Coerce(int), const.BITMASK_SCHEMA, cv.string - ) + vol.Coerce(int), BITMASK_SCHEMA, cv.string + ): vol.Any(vol.Coerce(int), BITMASK_SCHEMA, cv.string) }, ), }, @@ -222,6 +207,7 @@ class ZWaveServices: ATTR_DEVICE_ID, ATTR_ENTITY_ID, ATTR_AREA_ID ), get_nodes_from_service_data, + has_at_least_one_node, ), ), ) @@ -265,16 +251,15 @@ class ZWaveServices: vol.Coerce(int), str ), vol.Optional(const.ATTR_ENDPOINT): vol.Coerce(int), - vol.Required(const.ATTR_VALUE): const.VALUE_SCHEMA, + vol.Required(const.ATTR_VALUE): VALUE_SCHEMA, vol.Optional(const.ATTR_WAIT_FOR_RESULT): cv.boolean, - vol.Optional(const.ATTR_OPTIONS): { - cv.string: const.VALUE_SCHEMA - }, + vol.Optional(const.ATTR_OPTIONS): {cv.string: VALUE_SCHEMA}, }, cv.has_at_least_one_key( ATTR_DEVICE_ID, ATTR_ENTITY_ID, ATTR_AREA_ID ), get_nodes_from_service_data, + has_at_least_one_node, ), ), ) @@ -302,10 +287,8 @@ class ZWaveServices: vol.Coerce(int), str ), vol.Optional(const.ATTR_ENDPOINT): vol.Coerce(int), - vol.Required(const.ATTR_VALUE): const.VALUE_SCHEMA, - vol.Optional(const.ATTR_OPTIONS): { - cv.string: const.VALUE_SCHEMA - }, + vol.Required(const.ATTR_VALUE): VALUE_SCHEMA, + vol.Optional(const.ATTR_OPTIONS): {cv.string: VALUE_SCHEMA}, }, vol.Any( cv.has_at_least_one_key( @@ -338,6 +321,7 @@ class ZWaveServices: ATTR_DEVICE_ID, ATTR_ENTITY_ID, ATTR_AREA_ID ), get_nodes_from_service_data, + has_at_least_one_node, ), ), ) diff --git a/homeassistant/components/zwave_js/triggers/event.py b/homeassistant/components/zwave_js/triggers/event.py index 110cd21294f..fd46c89832b 100644 --- a/homeassistant/components/zwave_js/triggers/event.py +++ b/homeassistant/components/zwave_js/triggers/event.py @@ -20,13 +20,13 @@ from homeassistant.components.zwave_js.const import ( ATTR_EVENT_DATA, ATTR_EVENT_SOURCE, ATTR_NODE_ID, + ATTR_NODES, ATTR_PARTIAL_DICT_MATCH, DATA_CLIENT, DOMAIN, ) from homeassistant.components.zwave_js.helpers import ( - async_get_node_from_device_id, - async_get_node_from_entity_id, + async_get_nodes_from_targets, get_device_id, get_home_and_node_id_from_device_entry, ) @@ -111,6 +111,13 @@ async def async_validate_trigger_config( """Validate config.""" config = TRIGGER_SCHEMA(config) + if config[ATTR_EVENT_SOURCE] == "node": + config[ATTR_NODES] = async_get_nodes_from_targets(hass, config) + if not config[ATTR_NODES]: + raise vol.Invalid( + f"No nodes found for given {ATTR_DEVICE_ID}s or {ATTR_ENTITY_ID}s." + ) + if ATTR_CONFIG_ENTRY_ID not in config: return config @@ -133,21 +140,7 @@ async def async_attach_trigger( platform_type: str = PLATFORM_TYPE, ) -> CALLBACK_TYPE: """Listen for state changes based on configuration.""" - nodes: set[Node] = set() - if ATTR_DEVICE_ID in config: - nodes.update( - { - async_get_node_from_device_id(hass, device_id) - for device_id in config[ATTR_DEVICE_ID] - } - ) - if ATTR_ENTITY_ID in config: - nodes.update( - { - async_get_node_from_entity_id(hass, entity_id) - for entity_id in config[ATTR_ENTITY_ID] - } - ) + nodes: set[Node] = config.get(ATTR_NODES, {}) event_source = config[ATTR_EVENT_SOURCE] event_name = config[ATTR_EVENT] diff --git a/homeassistant/components/zwave_js/triggers/value_updated.py b/homeassistant/components/zwave_js/triggers/value_updated.py index 71223c4ef1e..8a0b287c26b 100644 --- a/homeassistant/components/zwave_js/triggers/value_updated.py +++ b/homeassistant/components/zwave_js/triggers/value_updated.py @@ -20,6 +20,7 @@ from homeassistant.components.zwave_js.const import ( ATTR_CURRENT_VALUE_RAW, ATTR_ENDPOINT, ATTR_NODE_ID, + ATTR_NODES, ATTR_PREVIOUS_VALUE, ATTR_PREVIOUS_VALUE_RAW, ATTR_PROPERTY, @@ -29,8 +30,7 @@ from homeassistant.components.zwave_js.const import ( DOMAIN, ) from homeassistant.components.zwave_js.helpers import ( - async_get_node_from_device_id, - async_get_node_from_entity_id, + async_get_nodes_from_targets, get_device_id, ) from homeassistant.const import ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_PLATFORM, MATCH_ALL @@ -38,20 +38,14 @@ from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback from homeassistant.helpers import config_validation as cv, device_registry as dr from homeassistant.helpers.typing import ConfigType +from ..config_validation import VALUE_SCHEMA + # Platform type should be . PLATFORM_TYPE = f"{DOMAIN}.{__name__.rsplit('.', maxsplit=1)[-1]}" ATTR_FROM = "from" ATTR_TO = "to" -VALUE_SCHEMA = vol.Any( - bool, - vol.Coerce(int), - vol.Coerce(float), - cv.boolean, - cv.string, -) - TRIGGER_SCHEMA = vol.All( cv.TRIGGER_BASE_SCHEMA.extend( { @@ -76,6 +70,20 @@ TRIGGER_SCHEMA = vol.All( ) +async def async_validate_trigger_config( + hass: HomeAssistant, config: ConfigType +) -> ConfigType: + """Validate config.""" + config = TRIGGER_SCHEMA(config) + + config[ATTR_NODES] = async_get_nodes_from_targets(hass, config) + if not config[ATTR_NODES]: + raise vol.Invalid( + f"No nodes found for given {ATTR_DEVICE_ID}s or {ATTR_ENTITY_ID}s." + ) + return config + + async def async_attach_trigger( hass: HomeAssistant, config: ConfigType, @@ -85,21 +93,7 @@ async def async_attach_trigger( platform_type: str = PLATFORM_TYPE, ) -> CALLBACK_TYPE: """Listen for state changes based on configuration.""" - nodes: set[Node] = set() - if ATTR_DEVICE_ID in config: - nodes.update( - { - async_get_node_from_device_id(hass, device_id) - for device_id in config.get(ATTR_DEVICE_ID, []) - } - ) - if ATTR_ENTITY_ID in config: - nodes.update( - { - async_get_node_from_entity_id(hass, entity_id) - for entity_id in config.get(ATTR_ENTITY_ID, []) - } - ) + nodes: set[Node] = config[ATTR_NODES] from_value = config[ATTR_FROM] to_value = config[ATTR_TO] diff --git a/tests/components/zwave_js/test_config_validation.py b/tests/components/zwave_js/test_config_validation.py new file mode 100644 index 00000000000..5a390ff0290 --- /dev/null +++ b/tests/components/zwave_js/test_config_validation.py @@ -0,0 +1,26 @@ +"""Test the Z-Wave JS config validation helpers.""" +import pytest +import voluptuous as vol + +from homeassistant.components.zwave_js.config_validation import boolean + + +def test_boolean_validation(): + """Test boolean config validator.""" + # test bool + assert boolean(True) + assert not boolean(False) + # test strings + assert boolean("TRUE") + assert not boolean("FALSE") + assert boolean("ON") + assert not boolean("NO") + # ensure 1's and 0's don't get converted to bool + with pytest.raises(vol.Invalid): + boolean("1") + with pytest.raises(vol.Invalid): + boolean("0") + with pytest.raises(vol.Invalid): + boolean(1) + with pytest.raises(vol.Invalid): + boolean(0) diff --git a/tests/components/zwave_js/test_services.py b/tests/components/zwave_js/test_services.py index 571190bd35c..09189ad9230 100644 --- a/tests/components/zwave_js/test_services.py +++ b/tests/components/zwave_js/test_services.py @@ -498,6 +498,20 @@ async def test_set_config_parameter(hass, client, multisensor_6, integration): client.async_send_command.reset_mock() + # Test setting config parameter with no valid nodes raises Exception + with pytest.raises(vol.MultipleInvalid): + await hass.services.async_call( + DOMAIN, + SERVICE_SET_CONFIG_PARAMETER, + { + ATTR_ENTITY_ID: "sensor.fake", + ATTR_CONFIG_PARAMETER: 102, + ATTR_CONFIG_PARAMETER_BITMASK: 1, + ATTR_CONFIG_VALUE: 1, + }, + blocking=True, + ) + async def test_bulk_set_config_parameters(hass, client, multisensor_6, integration): """Test the bulk_set_partial_config_parameters service.""" @@ -1345,8 +1359,8 @@ async def test_multicast_set_value( diff_network_node.client.driver.controller.home_id.return_value = "diff_home_id" with pytest.raises(vol.MultipleInvalid), patch( - "homeassistant.components.zwave_js.services.async_get_node_from_device_id", - return_value=diff_network_node, + "homeassistant.components.zwave_js.helpers.async_get_node_from_device_id", + side_effect=(climate_danfoss_lc_13, diff_network_node), ): await hass.services.async_call( DOMAIN, diff --git a/tests/components/zwave_js/test_trigger.py b/tests/components/zwave_js/test_trigger.py index 5dbeff87a54..45de09e8b17 100644 --- a/tests/components/zwave_js/test_trigger.py +++ b/tests/components/zwave_js/test_trigger.py @@ -1,6 +1,8 @@ """The tests for Z-Wave JS automation triggers.""" from unittest.mock import AsyncMock, patch +import pytest +import voluptuous as vol from zwave_js_server.const import CommandClass from zwave_js_server.event import Event from zwave_js_server.model.node import Node @@ -708,3 +710,28 @@ async def test_async_validate_trigger_config(hass): mock_platform.async_validate_trigger_config.return_value = {} await async_validate_trigger_config(hass, {}) mock_platform.async_validate_trigger_config.assert_awaited() + + +async def test_invalid_trigger_configs(hass): + """Test invalid trigger configs.""" + with pytest.raises(vol.Invalid): + await async_validate_trigger_config( + hass, + { + "platform": f"{DOMAIN}.event", + "entity_id": "fake.entity", + "event_source": "node", + "event": "value updated", + }, + ) + + with pytest.raises(vol.Invalid): + await async_validate_trigger_config( + hass, + { + "platform": f"{DOMAIN}.value_updated", + "entity_id": "fake.entity", + "command_class": CommandClass.DOOR_LOCK.value, + "property": "latchStatus", + }, + )