diff --git a/homeassistant/components/zwave_js/__init__.py b/homeassistant/components/zwave_js/__init__.py index 71c5b2bf592..7a8f284787f 100644 --- a/homeassistant/components/zwave_js/__init__.py +++ b/homeassistant/components/zwave_js/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio from collections import defaultdict +from typing import Callable from async_timeout import timeout from zwave_js_server.client import Client as ZwaveClient @@ -97,11 +98,22 @@ def register_node_in_dev_reg( dev_reg: device_registry.DeviceRegistry, client: ZwaveClient, node: ZwaveNode, + remove_device_func: Callable[[device_registry.DeviceEntry], None], ) -> device_registry.DeviceEntry: """Register node in dev reg.""" + device_id = get_device_id(client, node) + # If a device already exists but it doesn't match the new node, it means the node + # was replaced with a different device and the device needs to be removeed so the + # new device can be created. Otherwise if the device exists and the node is the same, + # the node was replaced with the same device model and we can reuse the device. + if (device := dev_reg.async_get_device({device_id})) and ( + device.model != node.device_config.label + or device.manufacturer != node.device_config.manufacturer + ): + remove_device_func(device) params = { "config_entry_id": entry.entry_id, - "identifiers": {get_device_id(client, node)}, + "identifiers": {device_id}, "sw_version": node.firmware_version, "name": node.name or node.device_config.description or f"Node {node.node_id}", "model": node.device_config.label, @@ -135,6 +147,14 @@ async def async_setup_entry( # noqa: C901 registered_unique_ids: dict[str, dict[str, set[str]]] = defaultdict(dict) discovered_value_ids: dict[str, set[str]] = defaultdict(set) + @callback + def remove_device(device: device_registry.DeviceEntry) -> None: + """Remove device from registry.""" + # note: removal of entity registry entry is handled by core + dev_reg.async_remove_device(device.id) + registered_unique_ids.pop(device.id, None) + discovered_value_ids.pop(device.id, None) + async def async_handle_discovery_info( device: device_registry.DeviceEntry, disc_info: ZwaveDiscoveryInfo, @@ -188,7 +208,9 @@ async def async_setup_entry( # noqa: C901 """Handle node ready event.""" LOGGER.debug("Processing node %s", node) # register (or update) node in device registry - device = register_node_in_dev_reg(hass, entry, dev_reg, client, node) + device = register_node_in_dev_reg( + hass, entry, dev_reg, client, node, remove_device + ) # We only want to create the defaultdict once, even on reinterviews if device.id not in registered_unique_ids: registered_unique_ids[device.id] = defaultdict(set) @@ -265,7 +287,7 @@ async def async_setup_entry( # noqa: C901 ) # we do submit the node to device registry so user has # some visual feedback that something is (in the process of) being added - register_node_in_dev_reg(hass, entry, dev_reg, client, node) + register_node_in_dev_reg(hass, entry, dev_reg, client, node, remove_device) async def async_on_value_added( value_updates_disc_info: dict[str, ZwaveDiscoveryInfo], value: Value @@ -293,20 +315,24 @@ async def async_setup_entry( # noqa: C901 ) @callback - def async_on_node_removed(node: ZwaveNode) -> None: + def async_on_node_removed(event: dict) -> None: """Handle node removed event.""" + node: ZwaveNode = event["node"] + replaced: bool = event.get("replaced", False) # grab device in device registry attached to this node dev_id = get_device_id(client, node) device = dev_reg.async_get_device({dev_id}) - # note: removal of entity registry entry is handled by core - dev_reg.async_remove_device(device.id) # type: ignore - registered_unique_ids.pop(device.id, None) # type: ignore - discovered_value_ids.pop(device.id, None) # type: ignore + # We assert because we know the device exists + assert device + if not replaced: + remove_device(device) @callback def async_on_value_notification(notification: ValueNotification) -> None: """Relay stateless value notification events from Z-Wave nodes to hass.""" device = dev_reg.async_get_device({get_device_id(client, notification.node)}) + # We assert because we know the device exists + assert device raw_value = value = notification.value if notification.metadata.states: value = notification.metadata.states.get(str(value), value) @@ -317,7 +343,7 @@ async def async_setup_entry( # noqa: C901 ATTR_NODE_ID: notification.node.node_id, ATTR_HOME_ID: client.driver.controller.home_id, ATTR_ENDPOINT: notification.endpoint, - ATTR_DEVICE_ID: device.id, # type: ignore + ATTR_DEVICE_ID: device.id, ATTR_COMMAND_CLASS: notification.command_class, ATTR_COMMAND_CLASS_NAME: notification.command_class_name, ATTR_LABEL: notification.metadata.label, @@ -336,11 +362,13 @@ async def async_setup_entry( # noqa: C901 ) -> None: """Relay stateless notification events from Z-Wave nodes to hass.""" device = dev_reg.async_get_device({get_device_id(client, notification.node)}) + # We assert because we know the device exists + assert device event_data = { ATTR_DOMAIN: DOMAIN, ATTR_NODE_ID: notification.node.node_id, ATTR_HOME_ID: client.driver.controller.home_id, - ATTR_DEVICE_ID: device.id, # type: ignore + ATTR_DEVICE_ID: device.id, ATTR_COMMAND_CLASS: notification.command_class, } @@ -379,6 +407,8 @@ async def async_setup_entry( # noqa: C901 disc_info = value_updates_disc_info[value.value_id] device = dev_reg.async_get_device({get_device_id(client, value.node)}) + # We assert because we know the device exists + assert device unique_id = get_unique_id( client.driver.controller.home_id, disc_info.primary_value.value_id @@ -394,7 +424,7 @@ async def async_setup_entry( # noqa: C901 { ATTR_NODE_ID: value.node.node_id, ATTR_HOME_ID: client.driver.controller.home_id, - ATTR_DEVICE_ID: device.id, # type: ignore + ATTR_DEVICE_ID: device.id, ATTR_ENTITY_ID: entity_id, ATTR_COMMAND_CLASS: value.command_class, ATTR_COMMAND_CLASS_NAME: value.command_class_name, @@ -500,9 +530,7 @@ async def async_setup_entry( # noqa: C901 # listen for nodes being removed from the mesh # NOTE: This will not remove nodes that were removed when HA was not running entry.async_on_unload( - client.driver.controller.on( - "node removed", lambda event: async_on_node_removed(event["node"]) - ) + client.driver.controller.on("node removed", async_on_node_removed) ) platform_task = hass.async_create_task(start_platforms()) diff --git a/tests/components/zwave_js/test_init.py b/tests/components/zwave_js/test_init.py index 5fed86c4d81..2e6a64f456e 100644 --- a/tests/components/zwave_js/test_init.py +++ b/tests/components/zwave_js/test_init.py @@ -161,7 +161,7 @@ async def test_new_entity_on_value_added(hass, multisensor_6, client, integratio async def test_on_node_added_ready(hass, multisensor_6_state, client, integration): """Test we handle a ready node added event.""" dev_reg = dr.async_get(hass) - node = Node(client, multisensor_6_state) + node = Node(client, deepcopy(multisensor_6_state)) event = {"node": node} air_temperature_device_id = f"{client.driver.controller.home_id}-{node.node_id}" @@ -656,3 +656,85 @@ async def test_suggested_area(hass, client, eaton_rf9640_dimmer): entity = ent_reg.async_get(EATON_RF9640_ENTITY) assert dev_reg.async_get(entity.device_id).area_id is not None + + +async def test_node_removed(hass, multisensor_6_state, client, integration): + """Test that device gets removed when node gets removed.""" + dev_reg = dr.async_get(hass) + node = Node(client, deepcopy(multisensor_6_state)) + device_id = f"{client.driver.controller.home_id}-{node.node_id}" + event = {"node": node} + + client.driver.controller.emit("node added", event) + await hass.async_block_till_done() + old_device = dev_reg.async_get_device(identifiers={(DOMAIN, device_id)}) + assert old_device.id + + event = {"node": node, "replaced": False} + + client.driver.controller.emit("node removed", event) + await hass.async_block_till_done() + # Assert device has been removed + assert not dev_reg.async_get(old_device.id) + + +async def test_replace_same_node(hass, multisensor_6_state, client, integration): + """Test when a node is replaced with itself that the device remains.""" + dev_reg = dr.async_get(hass) + node = Node(client, deepcopy(multisensor_6_state)) + device_id = f"{client.driver.controller.home_id}-{node.node_id}" + event = {"node": node} + + client.driver.controller.emit("node added", event) + await hass.async_block_till_done() + old_device = dev_reg.async_get_device(identifiers={(DOMAIN, device_id)}) + assert old_device.id + + event = {"node": node, "replaced": True} + + client.driver.controller.emit("node removed", event) + await hass.async_block_till_done() + # Assert device has remained + assert dev_reg.async_get(old_device.id) + + event = {"node": node} + + client.driver.controller.emit("node added", event) + await hass.async_block_till_done() + # Assert device has remained + assert dev_reg.async_get(old_device.id) + + +async def test_replace_different_node( + hass, multisensor_6_state, hank_binary_switch_state, client, integration +): + """Test when a node is replaced with a different node.""" + hank_binary_switch_state = deepcopy(hank_binary_switch_state) + multisensor_6_state = deepcopy(multisensor_6_state) + hank_binary_switch_state["nodeId"] = multisensor_6_state["nodeId"] + dev_reg = dr.async_get(hass) + old_node = Node(client, multisensor_6_state) + device_id = f"{client.driver.controller.home_id}-{old_node.node_id}" + new_node = Node(client, hank_binary_switch_state) + event = {"node": old_node} + + client.driver.controller.emit("node added", event) + await hass.async_block_till_done() + device = dev_reg.async_get_device(identifiers={(DOMAIN, device_id)}) + assert device + + event = {"node": old_node, "replaced": True} + + client.driver.controller.emit("node removed", event) + await hass.async_block_till_done() + # Device should still be there after the node was removed + assert device + + event = {"node": new_node} + + client.driver.controller.emit("node added", event) + await hass.async_block_till_done() + device = dev_reg.async_get(device.id) + # assert device is new + assert device + assert device.manufacturer == "HANK Electronics Ltd." diff --git a/tests/components/zwave_js/test_migrate.py b/tests/components/zwave_js/test_migrate.py index a1f60c31fce..37c53700d95 100644 --- a/tests/components/zwave_js/test_migrate.py +++ b/tests/components/zwave_js/test_migrate.py @@ -1,4 +1,6 @@ """Test the Z-Wave JS migration module.""" +import copy + import pytest from zwave_js_server.model.node import Node @@ -48,7 +50,7 @@ async def test_unique_id_migration_dupes( assert entity_entry.unique_id == old_unique_id_2 # Add a ready node, unique ID should be migrated - node = Node(client, multisensor_6_state) + node = Node(client, copy.deepcopy(multisensor_6_state)) event = {"node": node} client.driver.controller.emit("node added", event) @@ -91,7 +93,7 @@ async def test_unique_id_migration(hass, multisensor_6_state, client, integratio assert entity_entry.unique_id == old_unique_id # Add a ready node, unique ID should be migrated - node = Node(client, multisensor_6_state) + node = Node(client, copy.deepcopy(multisensor_6_state)) event = {"node": node} client.driver.controller.emit("node added", event) @@ -135,7 +137,7 @@ async def test_unique_id_migration_property_key( assert entity_entry.unique_id == old_unique_id # Add a ready node, unique ID should be migrated - node = Node(client, hank_binary_switch_state) + node = Node(client, copy.deepcopy(hank_binary_switch_state)) event = {"node": node} client.driver.controller.emit("node added", event) @@ -170,7 +172,7 @@ async def test_unique_id_migration_notification_binary_sensor( assert entity_entry.unique_id == old_unique_id # Add a ready node, unique ID should be migrated - node = Node(client, multisensor_6_state) + node = Node(client, copy.deepcopy(multisensor_6_state)) event = {"node": node} client.driver.controller.emit("node added", event) @@ -187,12 +189,15 @@ async def test_old_entity_migration( hass, hank_binary_switch_state, client, integration ): """Test old entity on a different endpoint is migrated to a new one.""" - node = Node(client, hank_binary_switch_state) + node = Node(client, copy.deepcopy(hank_binary_switch_state)) ent_reg = er.async_get(hass) dev_reg = dr.async_get(hass) device = dev_reg.async_get_or_create( - config_entry_id=integration.entry_id, identifiers={get_device_id(client, node)} + config_entry_id=integration.entry_id, + identifiers={get_device_id(client, node)}, + manufacturer=hank_binary_switch_state["deviceConfig"]["manufacturer"], + model=hank_binary_switch_state["deviceConfig"]["label"], ) SENSOR_NAME = "sensor.smart_plug_with_two_usb_ports_value_electric_consumed" @@ -230,12 +235,15 @@ async def test_different_endpoint_migration_status_sensor( hass, hank_binary_switch_state, client, integration ): """Test that the different endpoint migration logic skips over the status sensor.""" - node = Node(client, hank_binary_switch_state) + node = Node(client, copy.deepcopy(hank_binary_switch_state)) ent_reg = er.async_get(hass) dev_reg = dr.async_get(hass) device = dev_reg.async_get_or_create( - config_entry_id=integration.entry_id, identifiers={get_device_id(client, node)} + config_entry_id=integration.entry_id, + identifiers={get_device_id(client, node)}, + manufacturer=hank_binary_switch_state["deviceConfig"]["manufacturer"], + model=hank_binary_switch_state["deviceConfig"]["label"], ) SENSOR_NAME = "sensor.smart_plug_with_two_usb_ports_status_sensor" @@ -271,12 +279,15 @@ async def test_skip_old_entity_migration_for_multiple( hass, hank_binary_switch_state, client, integration ): """Test that multiple entities of the same value but on a different endpoint get skipped.""" - node = Node(client, hank_binary_switch_state) + node = Node(client, copy.deepcopy(hank_binary_switch_state)) ent_reg = er.async_get(hass) dev_reg = dr.async_get(hass) device = dev_reg.async_get_or_create( - config_entry_id=integration.entry_id, identifiers={get_device_id(client, node)} + config_entry_id=integration.entry_id, + identifiers={get_device_id(client, node)}, + manufacturer=hank_binary_switch_state["deviceConfig"]["manufacturer"], + model=hank_binary_switch_state["deviceConfig"]["label"], ) SENSOR_NAME = "sensor.smart_plug_with_two_usb_ports_value_electric_consumed" @@ -328,12 +339,15 @@ async def test_old_entity_migration_notification_binary_sensor( hass, multisensor_6_state, client, integration ): """Test old entity on a different endpoint is migrated to a new one for a notification binary sensor.""" - node = Node(client, multisensor_6_state) + node = Node(client, copy.deepcopy(multisensor_6_state)) ent_reg = er.async_get(hass) dev_reg = dr.async_get(hass) device = dev_reg.async_get_or_create( - config_entry_id=integration.entry_id, identifiers={get_device_id(client, node)} + config_entry_id=integration.entry_id, + identifiers={get_device_id(client, node)}, + manufacturer=multisensor_6_state["deviceConfig"]["manufacturer"], + model=multisensor_6_state["deviceConfig"]["label"], ) entity_name = NOTIFICATION_MOTION_BINARY_SENSOR.split(".")[1]