Add zwave_js.multicast_set_value service (#51115)

* Add zwave_js.multicast_set_value service

* comment

* Add test for multiple config entries validation

* additional validation test

* brevity

* wrap schema in vol.Schema

* Update homeassistant/components/zwave_js/services.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* do node transform and multicast validation in schema validation

* move poll value entity validation into schema validation, pass helper functions dev and ent reg instead of retrieving it every time

* make validators nested functions since they don't neeed to be externally accessible

* Update homeassistant/components/zwave_js/services.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Remove errant ALLOW_EXTRA

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Raman Gupta 2021-05-27 21:57:35 -04:00 committed by GitHub
parent 93ada0a675
commit ca8d09e5e1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 432 additions and 69 deletions

View file

@ -2,12 +2,15 @@
from __future__ import annotations
import logging
from typing import Any
import voluptuous as vol
from zwave_js_server.client import Client as ZwaveClient
from zwave_js_server.const import CommandStatus
from zwave_js_server.exceptions import SetValueFailed
from zwave_js_server.model.node import Node as ZwaveNode
from zwave_js_server.model.value import get_value_id
from zwave_js_server.util.multicast import async_multicast_set_value
from zwave_js_server.util.node import (
async_bulk_set_partial_config_parameters,
async_set_config_parameter,
@ -16,6 +19,7 @@ from zwave_js_server.util.node import (
from homeassistant.const import ATTR_DEVICE_ID, ATTR_ENTITY_ID
from homeassistant.core import HomeAssistant, ServiceCall, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.device_registry import DeviceRegistry
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.entity_registry import EntityRegistry
@ -26,8 +30,8 @@ _LOGGER = logging.getLogger(__name__)
def parameter_name_does_not_need_bitmask(
val: dict[str, int | str]
) -> dict[str, int | str]:
val: dict[str, int | str | list[str]]
) -> dict[str, int | str | list[str]]:
"""Validate that if a parameter name is provided, bitmask is not as well."""
if isinstance(val[const.ATTR_CONFIG_PARAMETER], str) and (
val.get(const.ATTR_CONFIG_PARAMETER_BITMASK)
@ -39,6 +43,16 @@ def parameter_name_does_not_need_bitmask(
return val
def broadcast_command(val: dict[str, Any]) -> dict[str, Any]:
"""Validate that the service call is for a broadcast command."""
if val.get(const.ATTR_BROADCAST):
return val
raise vol.Invalid(
"Either `broadcast` must be set to True or multiple devices/entities must be "
"specified"
)
# Validates that a bitmask is provided in hex form and converts it to decimal
# int equivalent since that's what the library uses
BITMASK_SCHEMA = vol.All(
@ -55,14 +69,95 @@ BITMASK_SCHEMA = vol.All(
class ZWaveServices:
"""Class that holds our services (Zwave Commands) that should be published to hass."""
def __init__(self, hass: HomeAssistant, ent_reg: EntityRegistry) -> None:
def __init__(
self, hass: HomeAssistant, ent_reg: EntityRegistry, dev_reg: DeviceRegistry
) -> None:
"""Initialize with hass object."""
self._hass = hass
self._ent_reg = ent_reg
self._dev_reg = dev_reg
@callback
def async_register(self) -> None:
"""Register all our services."""
@callback
def get_nodes_from_service_data(val: dict[str, Any]) -> dict[str, Any]:
"""Get nodes set from service data."""
nodes: set[ZwaveNode] = set()
try:
if ATTR_ENTITY_ID in val:
nodes |= {
async_get_node_from_entity_id(
self._hass, entity_id, self._ent_reg, self._dev_reg
)
for entity_id in val[ATTR_ENTITY_ID]
}
val.pop(ATTR_ENTITY_ID)
if ATTR_DEVICE_ID in val:
nodes |= {
async_get_node_from_device_id(
self._hass, device_id, self._dev_reg
)
for device_id in val[ATTR_DEVICE_ID]
}
val.pop(ATTR_DEVICE_ID)
except ValueError as err:
raise vol.Invalid(err.args[0]) from err
val[const.ATTR_NODES] = nodes
return val
@callback
def validate_multicast_nodes(val: dict[str, Any]) -> dict[str, Any]:
"""Validate the input nodes for multicast."""
nodes: set[ZwaveNode] = val[const.ATTR_NODES]
broadcast: bool = val[const.ATTR_BROADCAST]
# User must specify a node if they are attempting a broadcast and have more
# than one zwave-js network. We know it's a broadcast if the nodes list is
# empty because of schema validation.
if (
not nodes
and len(self._hass.config_entries.async_entries(const.DOMAIN)) > 1
):
raise vol.Invalid(
"You must include at least one entity or device in the service call"
)
# When multicasting, user must specify at least two nodes
if not broadcast and len(nodes) < 2:
raise vol.Invalid(
"To set a value on a single node, use the zwave_js.set_value service"
)
first_node = next((node for node in nodes), None)
# If any nodes don't have matching home IDs, we can't run the command because
# we can't multicast across multiple networks
if first_node and any(
node.client.driver.controller.home_id
!= first_node.client.driver.controller.home_id
for node in nodes
):
raise vol.Invalid(
"Multicast commands only work on devices in the same network"
)
return val
@callback
def validate_entities(val: dict[str, Any]) -> dict[str, Any]:
"""Validate entities exist and are from the zwave_js platform."""
for entity_id in val[ATTR_ENTITY_ID]:
entry = self._ent_reg.async_get(entity_id)
if entry is None or entry.platform != const.DOMAIN:
raise vol.Invalid(
f"Entity {entity_id} is not a valid {const.DOMAIN} entity."
)
return val
self._hass.services.async_register(
const.DOMAIN,
const.SERVICE_SET_CONFIG_PARAMETER,
@ -86,6 +181,7 @@ class ZWaveServices:
},
cv.has_at_least_one_key(ATTR_DEVICE_ID, ATTR_ENTITY_ID),
parameter_name_does_not_need_bitmask,
get_nodes_from_service_data,
),
),
)
@ -112,6 +208,7 @@ class ZWaveServices:
),
},
cv.has_at_least_one_key(ATTR_DEVICE_ID, ATTR_ENTITY_ID),
get_nodes_from_service_data,
),
),
)
@ -121,10 +218,15 @@ class ZWaveServices:
const.SERVICE_REFRESH_VALUE,
self.async_poll_value,
schema=vol.Schema(
{
vol.Required(ATTR_ENTITY_ID): cv.entity_ids,
vol.Optional(const.ATTR_REFRESH_ALL_VALUES, default=False): bool,
}
vol.All(
{
vol.Required(ATTR_ENTITY_ID): cv.entity_ids,
vol.Optional(
const.ATTR_REFRESH_ALL_VALUES, default=False
): bool,
},
validate_entities,
)
),
)
@ -153,23 +255,48 @@ class ZWaveServices:
vol.Optional(const.ATTR_WAIT_FOR_RESULT): vol.Coerce(bool),
},
cv.has_at_least_one_key(ATTR_DEVICE_ID, ATTR_ENTITY_ID),
get_nodes_from_service_data,
),
),
)
self._hass.services.async_register(
const.DOMAIN,
const.SERVICE_MULTICAST_SET_VALUE,
self.async_multicast_set_value,
schema=vol.Schema(
vol.All(
{
vol.Optional(ATTR_DEVICE_ID): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
vol.Optional(const.ATTR_BROADCAST, default=False): cv.boolean,
vol.Required(const.ATTR_COMMAND_CLASS): vol.Coerce(int),
vol.Required(const.ATTR_PROPERTY): vol.Any(
vol.Coerce(int), str
),
vol.Optional(const.ATTR_PROPERTY_KEY): vol.Any(
vol.Coerce(int), str
),
vol.Optional(const.ATTR_ENDPOINT): vol.Coerce(int),
vol.Required(const.ATTR_VALUE): vol.Any(
bool, vol.Coerce(int), vol.Coerce(float), cv.string
),
},
vol.Any(
cv.has_at_least_one_key(ATTR_DEVICE_ID, ATTR_ENTITY_ID),
broadcast_command,
),
get_nodes_from_service_data,
validate_multicast_nodes,
),
),
)
async def async_set_config_parameter(self, service: ServiceCall) -> None:
"""Set a config value on a node."""
nodes: set[ZwaveNode] = set()
if ATTR_ENTITY_ID in service.data:
nodes |= {
async_get_node_from_entity_id(self._hass, entity_id)
for entity_id in service.data[ATTR_ENTITY_ID]
}
if ATTR_DEVICE_ID in service.data:
nodes |= {
async_get_node_from_device_id(self._hass, device_id)
for device_id in service.data[ATTR_DEVICE_ID]
}
nodes = service.data[const.ATTR_NODES]
property_or_property_name = service.data[const.ATTR_CONFIG_PARAMETER]
property_key = service.data.get(const.ATTR_CONFIG_PARAMETER_BITMASK)
new_value = service.data[const.ATTR_CONFIG_VALUE]
@ -196,17 +323,7 @@ class ZWaveServices:
self, service: ServiceCall
) -> None:
"""Bulk set multiple partial config values on a node."""
nodes: set[ZwaveNode] = set()
if ATTR_ENTITY_ID in service.data:
nodes |= {
async_get_node_from_entity_id(self._hass, entity_id)
for entity_id in service.data[ATTR_ENTITY_ID]
}
if ATTR_DEVICE_ID in service.data:
nodes |= {
async_get_node_from_device_id(self._hass, device_id)
for device_id in service.data[ATTR_DEVICE_ID]
}
nodes = service.data[const.ATTR_NODES]
property_ = service.data[const.ATTR_CONFIG_PARAMETER]
new_value = service.data[const.ATTR_CONFIG_VALUE]
@ -231,10 +348,7 @@ class ZWaveServices:
"""Poll value on a node."""
for entity_id in service.data[ATTR_ENTITY_ID]:
entry = self._ent_reg.async_get(entity_id)
if entry is None or entry.platform != const.DOMAIN:
raise ValueError(
f"Entity {entity_id} is not a valid {const.DOMAIN} entity."
)
assert entry # Schema validation would have failed if we can't do this
async_dispatcher_send(
self._hass,
f"{const.DOMAIN}_{entry.unique_id}_poll_value",
@ -243,17 +357,7 @@ class ZWaveServices:
async def async_set_value(self, service: ServiceCall) -> None:
"""Set a value on a node."""
nodes: set[ZwaveNode] = set()
if ATTR_ENTITY_ID in service.data:
nodes |= {
async_get_node_from_entity_id(self._hass, entity_id)
for entity_id in service.data[ATTR_ENTITY_ID]
}
if ATTR_DEVICE_ID in service.data:
nodes |= {
async_get_node_from_device_id(self._hass, device_id)
for device_id in service.data[ATTR_DEVICE_ID]
}
nodes = service.data[const.ATTR_NODES]
command_class = service.data[const.ATTR_COMMAND_CLASS]
property_ = service.data[const.ATTR_PROPERTY]
property_key = service.data.get(const.ATTR_PROPERTY_KEY)
@ -280,3 +384,37 @@ class ZWaveServices:
"https://zwave-js.github.io/node-zwave-js/#/api/node?id=setvalue "
"for possible reasons"
)
async def async_multicast_set_value(self, service: ServiceCall) -> None:
"""Set a value via multicast to multiple nodes."""
nodes = service.data[const.ATTR_NODES]
broadcast: bool = service.data[const.ATTR_BROADCAST]
value = {
"commandClass": service.data[const.ATTR_COMMAND_CLASS],
"property": service.data[const.ATTR_PROPERTY],
"propertyKey": service.data.get(const.ATTR_PROPERTY_KEY),
"endpoint": service.data.get(const.ATTR_ENDPOINT),
}
new_value = service.data[const.ATTR_VALUE]
# If there are no nodes, we can assume there is only one config entry due to
# schema validation and can use that to get the client, otherwise we can just
# get the client from the node.
client: ZwaveClient = None
first_node = next((node for node in nodes), None)
if first_node:
client = first_node.client
else:
entry_id = self._hass.config_entries.async_entries(const.DOMAIN)[0].entry_id
client = self._hass.data[const.DOMAIN][entry_id][const.DATA_CLIENT]
success = await async_multicast_set_value(
client,
new_value,
{k: v for k, v in value.items() if v is not None},
None if broadcast else list(nodes),
)
if success is False:
raise SetValueFailed("Unable to set value via multicast")