Improve zwave_js custom triggers and services (#67461)

* Improve zwave_js custom triggers and services

* Switch from pop to get

* Support string boolean values

* refactor and add coverage

* comments and additional assertions
This commit is contained in:
Raman Gupta 2022-03-05 03:00:31 -05:00 committed by GitHub
parent cdb463ea55
commit 9632cbeffa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 222 additions and 135 deletions

View file

@ -60,8 +60,8 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.device_registry import DeviceEntry
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from .config_validation import BITMASK_SCHEMA
from .const import (
BITMASK_SCHEMA,
CONF_DATA_COLLECTION_OPTED_IN,
DATA_CLIENT,
DOMAIN,

View file

@ -0,0 +1,41 @@
"""Config validation for the Z-Wave JS integration."""
from typing import Any
import voluptuous as vol
import homeassistant.helpers.config_validation as cv
# Validates that a bitmask is provided in hex form and converts it to decimal
# int equivalent since that's what the library uses
BITMASK_SCHEMA = vol.All(
cv.string,
vol.Lower,
vol.Match(
r"^(0x)?[0-9a-f]+$",
msg="Must provide an integer (e.g. 255) or a bitmask in hex form (e.g. 0xff)",
),
lambda value: int(value, 16),
)
def boolean(value: Any) -> bool:
"""Validate and coerce a boolean value."""
if isinstance(value, bool):
return value
if isinstance(value, str):
value = value.lower().strip()
if value in ("true", "yes", "on", "enable"):
return True
if value in ("false", "no", "off", "disable"):
return False
raise vol.Invalid(f"invalid boolean value {value}")
VALUE_SCHEMA = vol.Any(
boolean,
vol.Coerce(int),
vol.Coerce(float),
BITMASK_SCHEMA,
cv.string,
dict,
)

View file

@ -1,10 +1,6 @@
"""Constants for the Z-Wave JS integration."""
import logging
import voluptuous as vol
import homeassistant.helpers.config_validation as cv
CONF_ADDON_DEVICE = "device"
CONF_ADDON_EMULATE_HARDWARE = "emulate_hardware"
CONF_ADDON_LOG_LEVEL = "log_level"
@ -117,26 +113,3 @@ ENTITY_DESC_KEY_TEMPERATURE = "temperature"
ENTITY_DESC_KEY_TARGET_TEMPERATURE = "target_temperature"
ENTITY_DESC_KEY_MEASUREMENT = "measurement"
ENTITY_DESC_KEY_TOTAL_INCREASING = "total_increasing"
# Schema Constants
# Validates that a bitmask is provided in hex form and converts it to decimal
# int equivalent since that's what the library uses
BITMASK_SCHEMA = vol.All(
cv.string,
vol.Lower,
vol.Match(
r"^(0x)?[0-9a-f]+$",
msg="Must provide an integer (e.g. 255) or a bitmask in hex form (e.g. 0xff)",
),
lambda value: int(value, 16),
)
VALUE_SCHEMA = vol.Any(
bool,
vol.Coerce(int),
vol.Coerce(float),
BITMASK_SCHEMA,
cv.string,
dict,
)

View file

@ -29,6 +29,7 @@ from homeassistant.helpers import entity_registry
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.typing import ConfigType
from .config_validation import VALUE_SCHEMA
from .const import (
ATTR_COMMAND_CLASS,
ATTR_CONFIG_PARAMETER,
@ -48,7 +49,6 @@ from .const import (
SERVICE_SET_CONFIG_PARAMETER,
SERVICE_SET_LOCK_USERCODE,
SERVICE_SET_VALUE,
VALUE_SCHEMA,
)
from .device_automation_helpers import (
CONF_SUBTYPE,

View file

@ -16,6 +16,7 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import condition, config_validation as cv
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
from .config_validation import VALUE_SCHEMA
from .const import (
ATTR_COMMAND_CLASS,
ATTR_ENDPOINT,
@ -23,7 +24,6 @@ from .const import (
ATTR_PROPERTY_KEY,
ATTR_VALUE,
DOMAIN,
VALUE_SCHEMA,
)
from .device_automation_helpers import (
CONF_SUBTYPE,

View file

@ -32,6 +32,7 @@ from homeassistant.helpers import (
from homeassistant.helpers.typing import ConfigType
from . import trigger
from .config_validation import VALUE_SCHEMA
from .const import (
ATTR_COMMAND_CLASS,
ATTR_DATA_TYPE,
@ -80,14 +81,6 @@ CONFIG_PARAMETER_VALUE_UPDATED = f"{VALUE_UPDATED_PLATFORM_TYPE}.config_paramete
VALUE_VALUE_UPDATED = f"{VALUE_UPDATED_PLATFORM_TYPE}.value"
NODE_STATUS = "state.node_status"
VALUE_SCHEMA = vol.Any(
bool,
vol.Coerce(int),
vol.Coerce(float),
cv.boolean,
cv.string,
)
NOTIFICATION_EVENT_CC_MAPPINGS = (
(ENTRY_CONTROL_NOTIFICATION, CommandClass.ENTRY_CONTROL),

View file

@ -14,9 +14,16 @@ from zwave_js_server.model.value import (
get_value_id,
)
from homeassistant.components.group import expand_entity_ids
from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import CONF_TYPE, __version__ as HA_VERSION
from homeassistant.const import (
ATTR_AREA_ID,
ATTR_DEVICE_ID,
ATTR_ENTITY_ID,
CONF_TYPE,
__version__ as HA_VERSION,
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, entity_registry as er
@ -30,6 +37,7 @@ from .const import (
CONF_DATA_COLLECTION_OPTED_IN,
DATA_CLIENT,
DOMAIN,
LOGGER,
)
@ -221,6 +229,40 @@ def async_get_nodes_from_area_id(
return nodes
@callback
def async_get_nodes_from_targets(
hass: HomeAssistant,
val: dict[str, Any],
ent_reg: er.EntityRegistry | None = None,
dev_reg: dr.DeviceRegistry | None = None,
) -> set[ZwaveNode]:
"""
Get nodes for all targets.
Supports entity_id with group expansion, area_id, and device_id.
"""
nodes: set[ZwaveNode] = set()
# Convert all entity IDs to nodes
for entity_id in expand_entity_ids(hass, val.get(ATTR_ENTITY_ID, [])):
try:
nodes.add(async_get_node_from_entity_id(hass, entity_id, ent_reg, dev_reg))
except ValueError as err:
LOGGER.warning(err.args[0])
# Convert all area IDs to nodes
for area_id in val.get(ATTR_AREA_ID, []):
nodes.update(async_get_nodes_from_area_id(hass, area_id, ent_reg, dev_reg))
# Convert all device IDs to nodes
for device_id in val.get(ATTR_DEVICE_ID, []):
try:
nodes.add(async_get_node_from_device_id(hass, device_id, dev_reg))
except ValueError as err:
LOGGER.warning(err.args[0])
return nodes
def get_zwave_value_from_config(node: ZwaveNode, config: ConfigType) -> ZwaveValue:
"""Get a Z-Wave JS Value from a config."""
endpoint = None

View file

@ -25,11 +25,8 @@ import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import async_dispatcher_send
from . import const
from .helpers import (
async_get_node_from_device_id,
async_get_node_from_entity_id,
async_get_nodes_from_area_id,
)
from .config_validation import BITMASK_SCHEMA, VALUE_SCHEMA
from .helpers import async_get_nodes_from_targets
_LOGGER = logging.getLogger(__name__)
@ -80,38 +77,16 @@ class ZWaveServices:
@callback
def get_nodes_from_service_data(val: dict[str, Any]) -> dict[str, Any]:
"""Get nodes set from service data."""
nodes: set[ZwaveNode] = set()
# Convert all entity IDs to nodes
for entity_id in expand_entity_ids(self._hass, val.pop(ATTR_ENTITY_ID, [])):
try:
nodes.add(
async_get_node_from_entity_id(
self._hass, entity_id, self._ent_reg, self._dev_reg
)
)
except ValueError as err:
const.LOGGER.warning(err.args[0])
val[const.ATTR_NODES] = async_get_nodes_from_targets(
self._hass, val, self._ent_reg, self._dev_reg
)
return val
# Convert all area IDs to nodes
for area_id in val.pop(ATTR_AREA_ID, []):
nodes.update(
async_get_nodes_from_area_id(
self._hass, area_id, self._ent_reg, self._dev_reg
)
)
# Convert all device IDs to nodes
for device_id in val.pop(ATTR_DEVICE_ID, []):
try:
nodes.add(
async_get_node_from_device_id(
self._hass, device_id, self._dev_reg
)
)
except ValueError as err:
const.LOGGER.warning(err.args[0])
val[const.ATTR_NODES] = nodes
@callback
def has_at_least_one_node(val: dict[str, Any]) -> dict[str, Any]:
"""Validate that at least one node is specified."""
if not val.get(const.ATTR_NODES):
raise vol.Invalid(f"No {const.DOMAIN} nodes found for given targets")
return val
@callback
@ -120,6 +95,9 @@ class ZWaveServices:
nodes: set[ZwaveNode] = val[const.ATTR_NODES]
broadcast: bool = val[const.ATTR_BROADCAST]
if not broadcast:
has_at_least_one_node(val)
# User must specify a node if they are attempting a broadcast and have more
# than one zwave-js network.
if (
@ -150,12 +128,20 @@ class ZWaveServices:
def validate_entities(val: dict[str, Any]) -> dict[str, Any]:
"""Validate entities exist and are from the zwave_js platform."""
val[ATTR_ENTITY_ID] = expand_entity_ids(self._hass, val[ATTR_ENTITY_ID])
invalid_entities = []
for entity_id in val[ATTR_ENTITY_ID]:
entry = self._ent_reg.async_get(entity_id)
if entry is None or entry.platform != const.DOMAIN:
raise vol.Invalid(
f"Entity {entity_id} is not a valid {const.DOMAIN} entity."
const.LOGGER.info(
"Entity %s is not a valid %s entity.", entity_id, const.DOMAIN
)
invalid_entities.append(entity_id)
# Remove invalid entities
val[ATTR_ENTITY_ID] = list(set(val[ATTR_ENTITY_ID]) - set(invalid_entities))
if not val[ATTR_ENTITY_ID]:
raise vol.Invalid(f"No {const.DOMAIN} entities found in service call")
return val
@ -177,10 +163,10 @@ class ZWaveServices:
vol.Coerce(int), cv.string
),
vol.Optional(const.ATTR_CONFIG_PARAMETER_BITMASK): vol.Any(
vol.Coerce(int), const.BITMASK_SCHEMA
vol.Coerce(int), BITMASK_SCHEMA
),
vol.Required(const.ATTR_CONFIG_VALUE): vol.Any(
vol.Coerce(int), const.BITMASK_SCHEMA, cv.string
vol.Coerce(int), BITMASK_SCHEMA, cv.string
),
},
cv.has_at_least_one_key(
@ -188,6 +174,7 @@ class ZWaveServices:
),
parameter_name_does_not_need_bitmask,
get_nodes_from_service_data,
has_at_least_one_node,
),
),
)
@ -211,10 +198,8 @@ class ZWaveServices:
vol.Coerce(int),
{
vol.Any(
vol.Coerce(int), const.BITMASK_SCHEMA, cv.string
): vol.Any(
vol.Coerce(int), const.BITMASK_SCHEMA, cv.string
)
vol.Coerce(int), BITMASK_SCHEMA, cv.string
): vol.Any(vol.Coerce(int), BITMASK_SCHEMA, cv.string)
},
),
},
@ -222,6 +207,7 @@ class ZWaveServices:
ATTR_DEVICE_ID, ATTR_ENTITY_ID, ATTR_AREA_ID
),
get_nodes_from_service_data,
has_at_least_one_node,
),
),
)
@ -265,16 +251,15 @@ class ZWaveServices:
vol.Coerce(int), str
),
vol.Optional(const.ATTR_ENDPOINT): vol.Coerce(int),
vol.Required(const.ATTR_VALUE): const.VALUE_SCHEMA,
vol.Required(const.ATTR_VALUE): VALUE_SCHEMA,
vol.Optional(const.ATTR_WAIT_FOR_RESULT): cv.boolean,
vol.Optional(const.ATTR_OPTIONS): {
cv.string: const.VALUE_SCHEMA
},
vol.Optional(const.ATTR_OPTIONS): {cv.string: VALUE_SCHEMA},
},
cv.has_at_least_one_key(
ATTR_DEVICE_ID, ATTR_ENTITY_ID, ATTR_AREA_ID
),
get_nodes_from_service_data,
has_at_least_one_node,
),
),
)
@ -302,10 +287,8 @@ class ZWaveServices:
vol.Coerce(int), str
),
vol.Optional(const.ATTR_ENDPOINT): vol.Coerce(int),
vol.Required(const.ATTR_VALUE): const.VALUE_SCHEMA,
vol.Optional(const.ATTR_OPTIONS): {
cv.string: const.VALUE_SCHEMA
},
vol.Required(const.ATTR_VALUE): VALUE_SCHEMA,
vol.Optional(const.ATTR_OPTIONS): {cv.string: VALUE_SCHEMA},
},
vol.Any(
cv.has_at_least_one_key(
@ -338,6 +321,7 @@ class ZWaveServices:
ATTR_DEVICE_ID, ATTR_ENTITY_ID, ATTR_AREA_ID
),
get_nodes_from_service_data,
has_at_least_one_node,
),
),
)

View file

@ -20,13 +20,13 @@ from homeassistant.components.zwave_js.const import (
ATTR_EVENT_DATA,
ATTR_EVENT_SOURCE,
ATTR_NODE_ID,
ATTR_NODES,
ATTR_PARTIAL_DICT_MATCH,
DATA_CLIENT,
DOMAIN,
)
from homeassistant.components.zwave_js.helpers import (
async_get_node_from_device_id,
async_get_node_from_entity_id,
async_get_nodes_from_targets,
get_device_id,
get_home_and_node_id_from_device_entry,
)
@ -111,6 +111,13 @@ async def async_validate_trigger_config(
"""Validate config."""
config = TRIGGER_SCHEMA(config)
if config[ATTR_EVENT_SOURCE] == "node":
config[ATTR_NODES] = async_get_nodes_from_targets(hass, config)
if not config[ATTR_NODES]:
raise vol.Invalid(
f"No nodes found for given {ATTR_DEVICE_ID}s or {ATTR_ENTITY_ID}s."
)
if ATTR_CONFIG_ENTRY_ID not in config:
return config
@ -133,21 +140,7 @@ async def async_attach_trigger(
platform_type: str = PLATFORM_TYPE,
) -> CALLBACK_TYPE:
"""Listen for state changes based on configuration."""
nodes: set[Node] = set()
if ATTR_DEVICE_ID in config:
nodes.update(
{
async_get_node_from_device_id(hass, device_id)
for device_id in config[ATTR_DEVICE_ID]
}
)
if ATTR_ENTITY_ID in config:
nodes.update(
{
async_get_node_from_entity_id(hass, entity_id)
for entity_id in config[ATTR_ENTITY_ID]
}
)
nodes: set[Node] = config.get(ATTR_NODES, {})
event_source = config[ATTR_EVENT_SOURCE]
event_name = config[ATTR_EVENT]

View file

@ -20,6 +20,7 @@ from homeassistant.components.zwave_js.const import (
ATTR_CURRENT_VALUE_RAW,
ATTR_ENDPOINT,
ATTR_NODE_ID,
ATTR_NODES,
ATTR_PREVIOUS_VALUE,
ATTR_PREVIOUS_VALUE_RAW,
ATTR_PROPERTY,
@ -29,8 +30,7 @@ from homeassistant.components.zwave_js.const import (
DOMAIN,
)
from homeassistant.components.zwave_js.helpers import (
async_get_node_from_device_id,
async_get_node_from_entity_id,
async_get_nodes_from_targets,
get_device_id,
)
from homeassistant.const import ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_PLATFORM, MATCH_ALL
@ -38,20 +38,14 @@ from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, device_registry as dr
from homeassistant.helpers.typing import ConfigType
from ..config_validation import VALUE_SCHEMA
# Platform type should be <DOMAIN>.<SUBMODULE_NAME>
PLATFORM_TYPE = f"{DOMAIN}.{__name__.rsplit('.', maxsplit=1)[-1]}"
ATTR_FROM = "from"
ATTR_TO = "to"
VALUE_SCHEMA = vol.Any(
bool,
vol.Coerce(int),
vol.Coerce(float),
cv.boolean,
cv.string,
)
TRIGGER_SCHEMA = vol.All(
cv.TRIGGER_BASE_SCHEMA.extend(
{
@ -76,6 +70,20 @@ TRIGGER_SCHEMA = vol.All(
)
async def async_validate_trigger_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
config = TRIGGER_SCHEMA(config)
config[ATTR_NODES] = async_get_nodes_from_targets(hass, config)
if not config[ATTR_NODES]:
raise vol.Invalid(
f"No nodes found for given {ATTR_DEVICE_ID}s or {ATTR_ENTITY_ID}s."
)
return config
async def async_attach_trigger(
hass: HomeAssistant,
config: ConfigType,
@ -85,21 +93,7 @@ async def async_attach_trigger(
platform_type: str = PLATFORM_TYPE,
) -> CALLBACK_TYPE:
"""Listen for state changes based on configuration."""
nodes: set[Node] = set()
if ATTR_DEVICE_ID in config:
nodes.update(
{
async_get_node_from_device_id(hass, device_id)
for device_id in config.get(ATTR_DEVICE_ID, [])
}
)
if ATTR_ENTITY_ID in config:
nodes.update(
{
async_get_node_from_entity_id(hass, entity_id)
for entity_id in config.get(ATTR_ENTITY_ID, [])
}
)
nodes: set[Node] = config[ATTR_NODES]
from_value = config[ATTR_FROM]
to_value = config[ATTR_TO]

View file

@ -0,0 +1,26 @@
"""Test the Z-Wave JS config validation helpers."""
import pytest
import voluptuous as vol
from homeassistant.components.zwave_js.config_validation import boolean
def test_boolean_validation():
"""Test boolean config validator."""
# test bool
assert boolean(True)
assert not boolean(False)
# test strings
assert boolean("TRUE")
assert not boolean("FALSE")
assert boolean("ON")
assert not boolean("NO")
# ensure 1's and 0's don't get converted to bool
with pytest.raises(vol.Invalid):
boolean("1")
with pytest.raises(vol.Invalid):
boolean("0")
with pytest.raises(vol.Invalid):
boolean(1)
with pytest.raises(vol.Invalid):
boolean(0)

View file

@ -498,6 +498,20 @@ async def test_set_config_parameter(hass, client, multisensor_6, integration):
client.async_send_command.reset_mock()
# Test setting config parameter with no valid nodes raises Exception
with pytest.raises(vol.MultipleInvalid):
await hass.services.async_call(
DOMAIN,
SERVICE_SET_CONFIG_PARAMETER,
{
ATTR_ENTITY_ID: "sensor.fake",
ATTR_CONFIG_PARAMETER: 102,
ATTR_CONFIG_PARAMETER_BITMASK: 1,
ATTR_CONFIG_VALUE: 1,
},
blocking=True,
)
async def test_bulk_set_config_parameters(hass, client, multisensor_6, integration):
"""Test the bulk_set_partial_config_parameters service."""
@ -1345,8 +1359,8 @@ async def test_multicast_set_value(
diff_network_node.client.driver.controller.home_id.return_value = "diff_home_id"
with pytest.raises(vol.MultipleInvalid), patch(
"homeassistant.components.zwave_js.services.async_get_node_from_device_id",
return_value=diff_network_node,
"homeassistant.components.zwave_js.helpers.async_get_node_from_device_id",
side_effect=(climate_danfoss_lc_13, diff_network_node),
):
await hass.services.async_call(
DOMAIN,

View file

@ -1,6 +1,8 @@
"""The tests for Z-Wave JS automation triggers."""
from unittest.mock import AsyncMock, patch
import pytest
import voluptuous as vol
from zwave_js_server.const import CommandClass
from zwave_js_server.event import Event
from zwave_js_server.model.node import Node
@ -708,3 +710,28 @@ async def test_async_validate_trigger_config(hass):
mock_platform.async_validate_trigger_config.return_value = {}
await async_validate_trigger_config(hass, {})
mock_platform.async_validate_trigger_config.assert_awaited()
async def test_invalid_trigger_configs(hass):
"""Test invalid trigger configs."""
with pytest.raises(vol.Invalid):
await async_validate_trigger_config(
hass,
{
"platform": f"{DOMAIN}.event",
"entity_id": "fake.entity",
"event_source": "node",
"event": "value updated",
},
)
with pytest.raises(vol.Invalid):
await async_validate_trigger_config(
hass,
{
"platform": f"{DOMAIN}.value_updated",
"entity_id": "fake.entity",
"command_class": CommandClass.DOOR_LOCK.value,
"property": "latchStatus",
},
)