Reuse zwave_js device when replacing removed node with same node (#56599)

* Reuse zwave_js device when a removed node is replaced with the same node

* Ensure change is backwards compatible with servers that don't include replaced

* Remove lambda

* Add assertions to remove type ignores

* fix tests by always copying state and setting manufacturer/label attributes
This commit is contained in:
Raman Gupta 2021-09-25 04:43:37 -04:00 committed by GitHub
parent 5d3d6fa1cd
commit b1f4ccfd6b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 151 additions and 27 deletions

View file

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections import defaultdict from collections import defaultdict
from typing import Callable
from async_timeout import timeout from async_timeout import timeout
from zwave_js_server.client import Client as ZwaveClient from zwave_js_server.client import Client as ZwaveClient
@ -97,11 +98,22 @@ def register_node_in_dev_reg(
dev_reg: device_registry.DeviceRegistry, dev_reg: device_registry.DeviceRegistry,
client: ZwaveClient, client: ZwaveClient,
node: ZwaveNode, node: ZwaveNode,
remove_device_func: Callable[[device_registry.DeviceEntry], None],
) -> device_registry.DeviceEntry: ) -> device_registry.DeviceEntry:
"""Register node in dev reg.""" """Register node in dev reg."""
device_id = get_device_id(client, node)
# If a device already exists but it doesn't match the new node, it means the node
# was replaced with a different device and the device needs to be removeed so the
# new device can be created. Otherwise if the device exists and the node is the same,
# the node was replaced with the same device model and we can reuse the device.
if (device := dev_reg.async_get_device({device_id})) and (
device.model != node.device_config.label
or device.manufacturer != node.device_config.manufacturer
):
remove_device_func(device)
params = { params = {
"config_entry_id": entry.entry_id, "config_entry_id": entry.entry_id,
"identifiers": {get_device_id(client, node)}, "identifiers": {device_id},
"sw_version": node.firmware_version, "sw_version": node.firmware_version,
"name": node.name or node.device_config.description or f"Node {node.node_id}", "name": node.name or node.device_config.description or f"Node {node.node_id}",
"model": node.device_config.label, "model": node.device_config.label,
@ -135,6 +147,14 @@ async def async_setup_entry( # noqa: C901
registered_unique_ids: dict[str, dict[str, set[str]]] = defaultdict(dict) registered_unique_ids: dict[str, dict[str, set[str]]] = defaultdict(dict)
discovered_value_ids: dict[str, set[str]] = defaultdict(set) discovered_value_ids: dict[str, set[str]] = defaultdict(set)
@callback
def remove_device(device: device_registry.DeviceEntry) -> None:
"""Remove device from registry."""
# note: removal of entity registry entry is handled by core
dev_reg.async_remove_device(device.id)
registered_unique_ids.pop(device.id, None)
discovered_value_ids.pop(device.id, None)
async def async_handle_discovery_info( async def async_handle_discovery_info(
device: device_registry.DeviceEntry, device: device_registry.DeviceEntry,
disc_info: ZwaveDiscoveryInfo, disc_info: ZwaveDiscoveryInfo,
@ -188,7 +208,9 @@ async def async_setup_entry( # noqa: C901
"""Handle node ready event.""" """Handle node ready event."""
LOGGER.debug("Processing node %s", node) LOGGER.debug("Processing node %s", node)
# register (or update) node in device registry # register (or update) node in device registry
device = register_node_in_dev_reg(hass, entry, dev_reg, client, node) device = register_node_in_dev_reg(
hass, entry, dev_reg, client, node, remove_device
)
# We only want to create the defaultdict once, even on reinterviews # We only want to create the defaultdict once, even on reinterviews
if device.id not in registered_unique_ids: if device.id not in registered_unique_ids:
registered_unique_ids[device.id] = defaultdict(set) registered_unique_ids[device.id] = defaultdict(set)
@ -265,7 +287,7 @@ async def async_setup_entry( # noqa: C901
) )
# we do submit the node to device registry so user has # we do submit the node to device registry so user has
# some visual feedback that something is (in the process of) being added # some visual feedback that something is (in the process of) being added
register_node_in_dev_reg(hass, entry, dev_reg, client, node) register_node_in_dev_reg(hass, entry, dev_reg, client, node, remove_device)
async def async_on_value_added( async def async_on_value_added(
value_updates_disc_info: dict[str, ZwaveDiscoveryInfo], value: Value value_updates_disc_info: dict[str, ZwaveDiscoveryInfo], value: Value
@ -293,20 +315,24 @@ async def async_setup_entry( # noqa: C901
) )
@callback @callback
def async_on_node_removed(node: ZwaveNode) -> None: def async_on_node_removed(event: dict) -> None:
"""Handle node removed event.""" """Handle node removed event."""
node: ZwaveNode = event["node"]
replaced: bool = event.get("replaced", False)
# grab device in device registry attached to this node # grab device in device registry attached to this node
dev_id = get_device_id(client, node) dev_id = get_device_id(client, node)
device = dev_reg.async_get_device({dev_id}) device = dev_reg.async_get_device({dev_id})
# note: removal of entity registry entry is handled by core # We assert because we know the device exists
dev_reg.async_remove_device(device.id) # type: ignore assert device
registered_unique_ids.pop(device.id, None) # type: ignore if not replaced:
discovered_value_ids.pop(device.id, None) # type: ignore remove_device(device)
@callback @callback
def async_on_value_notification(notification: ValueNotification) -> None: def async_on_value_notification(notification: ValueNotification) -> None:
"""Relay stateless value notification events from Z-Wave nodes to hass.""" """Relay stateless value notification events from Z-Wave nodes to hass."""
device = dev_reg.async_get_device({get_device_id(client, notification.node)}) device = dev_reg.async_get_device({get_device_id(client, notification.node)})
# We assert because we know the device exists
assert device
raw_value = value = notification.value raw_value = value = notification.value
if notification.metadata.states: if notification.metadata.states:
value = notification.metadata.states.get(str(value), value) value = notification.metadata.states.get(str(value), value)
@ -317,7 +343,7 @@ async def async_setup_entry( # noqa: C901
ATTR_NODE_ID: notification.node.node_id, ATTR_NODE_ID: notification.node.node_id,
ATTR_HOME_ID: client.driver.controller.home_id, ATTR_HOME_ID: client.driver.controller.home_id,
ATTR_ENDPOINT: notification.endpoint, ATTR_ENDPOINT: notification.endpoint,
ATTR_DEVICE_ID: device.id, # type: ignore ATTR_DEVICE_ID: device.id,
ATTR_COMMAND_CLASS: notification.command_class, ATTR_COMMAND_CLASS: notification.command_class,
ATTR_COMMAND_CLASS_NAME: notification.command_class_name, ATTR_COMMAND_CLASS_NAME: notification.command_class_name,
ATTR_LABEL: notification.metadata.label, ATTR_LABEL: notification.metadata.label,
@ -336,11 +362,13 @@ async def async_setup_entry( # noqa: C901
) -> None: ) -> None:
"""Relay stateless notification events from Z-Wave nodes to hass.""" """Relay stateless notification events from Z-Wave nodes to hass."""
device = dev_reg.async_get_device({get_device_id(client, notification.node)}) device = dev_reg.async_get_device({get_device_id(client, notification.node)})
# We assert because we know the device exists
assert device
event_data = { event_data = {
ATTR_DOMAIN: DOMAIN, ATTR_DOMAIN: DOMAIN,
ATTR_NODE_ID: notification.node.node_id, ATTR_NODE_ID: notification.node.node_id,
ATTR_HOME_ID: client.driver.controller.home_id, ATTR_HOME_ID: client.driver.controller.home_id,
ATTR_DEVICE_ID: device.id, # type: ignore ATTR_DEVICE_ID: device.id,
ATTR_COMMAND_CLASS: notification.command_class, ATTR_COMMAND_CLASS: notification.command_class,
} }
@ -379,6 +407,8 @@ async def async_setup_entry( # noqa: C901
disc_info = value_updates_disc_info[value.value_id] disc_info = value_updates_disc_info[value.value_id]
device = dev_reg.async_get_device({get_device_id(client, value.node)}) device = dev_reg.async_get_device({get_device_id(client, value.node)})
# We assert because we know the device exists
assert device
unique_id = get_unique_id( unique_id = get_unique_id(
client.driver.controller.home_id, disc_info.primary_value.value_id client.driver.controller.home_id, disc_info.primary_value.value_id
@ -394,7 +424,7 @@ async def async_setup_entry( # noqa: C901
{ {
ATTR_NODE_ID: value.node.node_id, ATTR_NODE_ID: value.node.node_id,
ATTR_HOME_ID: client.driver.controller.home_id, ATTR_HOME_ID: client.driver.controller.home_id,
ATTR_DEVICE_ID: device.id, # type: ignore ATTR_DEVICE_ID: device.id,
ATTR_ENTITY_ID: entity_id, ATTR_ENTITY_ID: entity_id,
ATTR_COMMAND_CLASS: value.command_class, ATTR_COMMAND_CLASS: value.command_class,
ATTR_COMMAND_CLASS_NAME: value.command_class_name, ATTR_COMMAND_CLASS_NAME: value.command_class_name,
@ -500,9 +530,7 @@ async def async_setup_entry( # noqa: C901
# listen for nodes being removed from the mesh # listen for nodes being removed from the mesh
# NOTE: This will not remove nodes that were removed when HA was not running # NOTE: This will not remove nodes that were removed when HA was not running
entry.async_on_unload( entry.async_on_unload(
client.driver.controller.on( client.driver.controller.on("node removed", async_on_node_removed)
"node removed", lambda event: async_on_node_removed(event["node"])
)
) )
platform_task = hass.async_create_task(start_platforms()) platform_task = hass.async_create_task(start_platforms())

View file

@ -161,7 +161,7 @@ async def test_new_entity_on_value_added(hass, multisensor_6, client, integratio
async def test_on_node_added_ready(hass, multisensor_6_state, client, integration): async def test_on_node_added_ready(hass, multisensor_6_state, client, integration):
"""Test we handle a ready node added event.""" """Test we handle a ready node added event."""
dev_reg = dr.async_get(hass) dev_reg = dr.async_get(hass)
node = Node(client, multisensor_6_state) node = Node(client, deepcopy(multisensor_6_state))
event = {"node": node} event = {"node": node}
air_temperature_device_id = f"{client.driver.controller.home_id}-{node.node_id}" air_temperature_device_id = f"{client.driver.controller.home_id}-{node.node_id}"
@ -656,3 +656,85 @@ async def test_suggested_area(hass, client, eaton_rf9640_dimmer):
entity = ent_reg.async_get(EATON_RF9640_ENTITY) entity = ent_reg.async_get(EATON_RF9640_ENTITY)
assert dev_reg.async_get(entity.device_id).area_id is not None assert dev_reg.async_get(entity.device_id).area_id is not None
async def test_node_removed(hass, multisensor_6_state, client, integration):
"""Test that device gets removed when node gets removed."""
dev_reg = dr.async_get(hass)
node = Node(client, deepcopy(multisensor_6_state))
device_id = f"{client.driver.controller.home_id}-{node.node_id}"
event = {"node": node}
client.driver.controller.emit("node added", event)
await hass.async_block_till_done()
old_device = dev_reg.async_get_device(identifiers={(DOMAIN, device_id)})
assert old_device.id
event = {"node": node, "replaced": False}
client.driver.controller.emit("node removed", event)
await hass.async_block_till_done()
# Assert device has been removed
assert not dev_reg.async_get(old_device.id)
async def test_replace_same_node(hass, multisensor_6_state, client, integration):
"""Test when a node is replaced with itself that the device remains."""
dev_reg = dr.async_get(hass)
node = Node(client, deepcopy(multisensor_6_state))
device_id = f"{client.driver.controller.home_id}-{node.node_id}"
event = {"node": node}
client.driver.controller.emit("node added", event)
await hass.async_block_till_done()
old_device = dev_reg.async_get_device(identifiers={(DOMAIN, device_id)})
assert old_device.id
event = {"node": node, "replaced": True}
client.driver.controller.emit("node removed", event)
await hass.async_block_till_done()
# Assert device has remained
assert dev_reg.async_get(old_device.id)
event = {"node": node}
client.driver.controller.emit("node added", event)
await hass.async_block_till_done()
# Assert device has remained
assert dev_reg.async_get(old_device.id)
async def test_replace_different_node(
hass, multisensor_6_state, hank_binary_switch_state, client, integration
):
"""Test when a node is replaced with a different node."""
hank_binary_switch_state = deepcopy(hank_binary_switch_state)
multisensor_6_state = deepcopy(multisensor_6_state)
hank_binary_switch_state["nodeId"] = multisensor_6_state["nodeId"]
dev_reg = dr.async_get(hass)
old_node = Node(client, multisensor_6_state)
device_id = f"{client.driver.controller.home_id}-{old_node.node_id}"
new_node = Node(client, hank_binary_switch_state)
event = {"node": old_node}
client.driver.controller.emit("node added", event)
await hass.async_block_till_done()
device = dev_reg.async_get_device(identifiers={(DOMAIN, device_id)})
assert device
event = {"node": old_node, "replaced": True}
client.driver.controller.emit("node removed", event)
await hass.async_block_till_done()
# Device should still be there after the node was removed
assert device
event = {"node": new_node}
client.driver.controller.emit("node added", event)
await hass.async_block_till_done()
device = dev_reg.async_get(device.id)
# assert device is new
assert device
assert device.manufacturer == "HANK Electronics Ltd."

View file

@ -1,4 +1,6 @@
"""Test the Z-Wave JS migration module.""" """Test the Z-Wave JS migration module."""
import copy
import pytest import pytest
from zwave_js_server.model.node import Node from zwave_js_server.model.node import Node
@ -48,7 +50,7 @@ async def test_unique_id_migration_dupes(
assert entity_entry.unique_id == old_unique_id_2 assert entity_entry.unique_id == old_unique_id_2
# Add a ready node, unique ID should be migrated # Add a ready node, unique ID should be migrated
node = Node(client, multisensor_6_state) node = Node(client, copy.deepcopy(multisensor_6_state))
event = {"node": node} event = {"node": node}
client.driver.controller.emit("node added", event) client.driver.controller.emit("node added", event)
@ -91,7 +93,7 @@ async def test_unique_id_migration(hass, multisensor_6_state, client, integratio
assert entity_entry.unique_id == old_unique_id assert entity_entry.unique_id == old_unique_id
# Add a ready node, unique ID should be migrated # Add a ready node, unique ID should be migrated
node = Node(client, multisensor_6_state) node = Node(client, copy.deepcopy(multisensor_6_state))
event = {"node": node} event = {"node": node}
client.driver.controller.emit("node added", event) client.driver.controller.emit("node added", event)
@ -135,7 +137,7 @@ async def test_unique_id_migration_property_key(
assert entity_entry.unique_id == old_unique_id assert entity_entry.unique_id == old_unique_id
# Add a ready node, unique ID should be migrated # Add a ready node, unique ID should be migrated
node = Node(client, hank_binary_switch_state) node = Node(client, copy.deepcopy(hank_binary_switch_state))
event = {"node": node} event = {"node": node}
client.driver.controller.emit("node added", event) client.driver.controller.emit("node added", event)
@ -170,7 +172,7 @@ async def test_unique_id_migration_notification_binary_sensor(
assert entity_entry.unique_id == old_unique_id assert entity_entry.unique_id == old_unique_id
# Add a ready node, unique ID should be migrated # Add a ready node, unique ID should be migrated
node = Node(client, multisensor_6_state) node = Node(client, copy.deepcopy(multisensor_6_state))
event = {"node": node} event = {"node": node}
client.driver.controller.emit("node added", event) client.driver.controller.emit("node added", event)
@ -187,12 +189,15 @@ async def test_old_entity_migration(
hass, hank_binary_switch_state, client, integration hass, hank_binary_switch_state, client, integration
): ):
"""Test old entity on a different endpoint is migrated to a new one.""" """Test old entity on a different endpoint is migrated to a new one."""
node = Node(client, hank_binary_switch_state) node = Node(client, copy.deepcopy(hank_binary_switch_state))
ent_reg = er.async_get(hass) ent_reg = er.async_get(hass)
dev_reg = dr.async_get(hass) dev_reg = dr.async_get(hass)
device = dev_reg.async_get_or_create( device = dev_reg.async_get_or_create(
config_entry_id=integration.entry_id, identifiers={get_device_id(client, node)} config_entry_id=integration.entry_id,
identifiers={get_device_id(client, node)},
manufacturer=hank_binary_switch_state["deviceConfig"]["manufacturer"],
model=hank_binary_switch_state["deviceConfig"]["label"],
) )
SENSOR_NAME = "sensor.smart_plug_with_two_usb_ports_value_electric_consumed" SENSOR_NAME = "sensor.smart_plug_with_two_usb_ports_value_electric_consumed"
@ -230,12 +235,15 @@ async def test_different_endpoint_migration_status_sensor(
hass, hank_binary_switch_state, client, integration hass, hank_binary_switch_state, client, integration
): ):
"""Test that the different endpoint migration logic skips over the status sensor.""" """Test that the different endpoint migration logic skips over the status sensor."""
node = Node(client, hank_binary_switch_state) node = Node(client, copy.deepcopy(hank_binary_switch_state))
ent_reg = er.async_get(hass) ent_reg = er.async_get(hass)
dev_reg = dr.async_get(hass) dev_reg = dr.async_get(hass)
device = dev_reg.async_get_or_create( device = dev_reg.async_get_or_create(
config_entry_id=integration.entry_id, identifiers={get_device_id(client, node)} config_entry_id=integration.entry_id,
identifiers={get_device_id(client, node)},
manufacturer=hank_binary_switch_state["deviceConfig"]["manufacturer"],
model=hank_binary_switch_state["deviceConfig"]["label"],
) )
SENSOR_NAME = "sensor.smart_plug_with_two_usb_ports_status_sensor" SENSOR_NAME = "sensor.smart_plug_with_two_usb_ports_status_sensor"
@ -271,12 +279,15 @@ async def test_skip_old_entity_migration_for_multiple(
hass, hank_binary_switch_state, client, integration hass, hank_binary_switch_state, client, integration
): ):
"""Test that multiple entities of the same value but on a different endpoint get skipped.""" """Test that multiple entities of the same value but on a different endpoint get skipped."""
node = Node(client, hank_binary_switch_state) node = Node(client, copy.deepcopy(hank_binary_switch_state))
ent_reg = er.async_get(hass) ent_reg = er.async_get(hass)
dev_reg = dr.async_get(hass) dev_reg = dr.async_get(hass)
device = dev_reg.async_get_or_create( device = dev_reg.async_get_or_create(
config_entry_id=integration.entry_id, identifiers={get_device_id(client, node)} config_entry_id=integration.entry_id,
identifiers={get_device_id(client, node)},
manufacturer=hank_binary_switch_state["deviceConfig"]["manufacturer"],
model=hank_binary_switch_state["deviceConfig"]["label"],
) )
SENSOR_NAME = "sensor.smart_plug_with_two_usb_ports_value_electric_consumed" SENSOR_NAME = "sensor.smart_plug_with_two_usb_ports_value_electric_consumed"
@ -328,12 +339,15 @@ async def test_old_entity_migration_notification_binary_sensor(
hass, multisensor_6_state, client, integration hass, multisensor_6_state, client, integration
): ):
"""Test old entity on a different endpoint is migrated to a new one for a notification binary sensor.""" """Test old entity on a different endpoint is migrated to a new one for a notification binary sensor."""
node = Node(client, multisensor_6_state) node = Node(client, copy.deepcopy(multisensor_6_state))
ent_reg = er.async_get(hass) ent_reg = er.async_get(hass)
dev_reg = dr.async_get(hass) dev_reg = dr.async_get(hass)
device = dev_reg.async_get_or_create( device = dev_reg.async_get_or_create(
config_entry_id=integration.entry_id, identifiers={get_device_id(client, node)} config_entry_id=integration.entry_id,
identifiers={get_device_id(client, node)},
manufacturer=multisensor_6_state["deviceConfig"]["manufacturer"],
model=multisensor_6_state["deviceConfig"]["label"],
) )
entity_name = NOTIFICATION_MOTION_BINARY_SENSOR.split(".")[1] entity_name = NOTIFICATION_MOTION_BINARY_SENSOR.split(".")[1]