diff --git a/homeassistant/components/statistics/__init__.py b/homeassistant/components/statistics/__init__.py index 70739c618f7..f71274e0ee7 100644 --- a/homeassistant/components/statistics/__init__.py +++ b/homeassistant/components/statistics/__init__.py @@ -1,8 +1,11 @@ """The statistics component.""" from homeassistant.config_entries import ConfigEntry -from homeassistant.const import Platform +from homeassistant.const import CONF_ENTITY_ID, Platform from homeassistant.core import HomeAssistant +from homeassistant.helpers.device import ( + async_remove_stale_devices_links_keep_entity_device, +) DOMAIN = "statistics" PLATFORMS = [Platform.SENSOR] @@ -11,6 +14,12 @@ PLATFORMS = [Platform.SENSOR] async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up Statistics from a config entry.""" + async_remove_stale_devices_links_keep_entity_device( + hass, + entry.entry_id, + entry.options[CONF_ENTITY_ID], + ) + await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) entry.async_on_unload(entry.add_update_listener(update_listener)) diff --git a/homeassistant/components/statistics/sensor.py b/homeassistant/components/statistics/sensor.py index 8d28254ad61..ca1d75b57ed 100644 --- a/homeassistant/components/statistics/sensor.py +++ b/homeassistant/components/statistics/sensor.py @@ -43,6 +43,7 @@ from homeassistant.core import ( split_entity_id, ) from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.device import async_device_info_to_link_from_entity from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.event import ( async_track_point_in_utc_time, @@ -268,6 +269,7 @@ async def async_setup_platform( async_add_entities( new_entities=[ StatisticsSensor( + hass=hass, source_entity_id=config[CONF_ENTITY_ID], name=config[CONF_NAME], unique_id=config.get(CONF_UNIQUE_ID), @@ -304,6 +306,7 @@ async def async_setup_entry( async_add_entities( [ StatisticsSensor( + hass=hass, source_entity_id=entry.options[CONF_ENTITY_ID], name=entry.options[CONF_NAME], unique_id=entry.entry_id, @@ -327,6 +330,7 @@ class StatisticsSensor(SensorEntity): def __init__( self, + hass: HomeAssistant, source_entity_id: str, name: str, unique_id: str | None, @@ -341,6 +345,10 @@ class StatisticsSensor(SensorEntity): self._attr_name: str = name self._attr_unique_id: str | None = unique_id self._source_entity_id: str = source_entity_id + self._attr_device_info = async_device_info_to_link_from_entity( + hass, + source_entity_id, + ) self.is_binary: bool = ( split_entity_id(self._source_entity_id)[0] == BINARY_SENSOR_DOMAIN ) diff --git a/tests/components/statistics/test_init.py b/tests/components/statistics/test_init.py index 6cb943c0687..64829ea7d66 100644 --- a/tests/components/statistics/test_init.py +++ b/tests/components/statistics/test_init.py @@ -2,8 +2,10 @@ from __future__ import annotations +from homeassistant.components.statistics import DOMAIN as STATISTICS_DOMAIN from homeassistant.config_entries import ConfigEntryState from homeassistant.core import HomeAssistant +from homeassistant.helpers import device_registry as dr, entity_registry as er from tests.common import MockConfigEntry @@ -15,3 +17,93 @@ async def test_unload_entry(hass: HomeAssistant, loaded_entry: MockConfigEntry) assert await hass.config_entries.async_unload(loaded_entry.entry_id) await hass.async_block_till_done() assert loaded_entry.state is ConfigEntryState.NOT_LOADED + + +async def test_device_cleaning( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, +) -> None: + """Test the cleaning of devices linked to the helper Statistics.""" + + # Source entity device config entry + source_config_entry = MockConfigEntry() + source_config_entry.add_to_hass(hass) + + # Device entry of the source entity + source_device1_entry = device_registry.async_get_or_create( + config_entry_id=source_config_entry.entry_id, + identifiers={("sensor", "identifier_test1")}, + connections={("mac", "30:31:32:33:34:01")}, + ) + + # Source entity registry + source_entity = entity_registry.async_get_or_create( + "sensor", + "test", + "source", + config_entry=source_config_entry, + device_id=source_device1_entry.id, + ) + await hass.async_block_till_done() + assert entity_registry.async_get("sensor.test_source") is not None + + # Configure the configuration entry for Statistics + statistics_config_entry = MockConfigEntry( + data={}, + domain=STATISTICS_DOMAIN, + options={ + "name": "Statistics", + "entity_id": "sensor.test_source", + "state_characteristic": "mean", + "keep_last_sample": False, + "percentile": 50.0, + "precision": 2.0, + "sampling_size": 20.0, + }, + title="Statistics", + ) + statistics_config_entry.add_to_hass(hass) + assert await hass.config_entries.async_setup(statistics_config_entry.entry_id) + await hass.async_block_till_done() + + # Confirm the link between the source entity device and the statistics sensor + statistics_entity = entity_registry.async_get("sensor.statistics") + assert statistics_entity is not None + assert statistics_entity.device_id == source_entity.device_id + + # Device entry incorrectly linked to Statistics config entry + device_registry.async_get_or_create( + config_entry_id=statistics_config_entry.entry_id, + identifiers={("sensor", "identifier_test2")}, + connections={("mac", "30:31:32:33:34:02")}, + ) + device_registry.async_get_or_create( + config_entry_id=statistics_config_entry.entry_id, + identifiers={("sensor", "identifier_test3")}, + connections={("mac", "30:31:32:33:34:03")}, + ) + await hass.async_block_till_done() + + # Before reloading the config entry, two devices are expected to be linked + devices_before_reload = device_registry.devices.get_devices_for_config_entry_id( + statistics_config_entry.entry_id + ) + assert len(devices_before_reload) == 3 + + # Config entry reload + await hass.config_entries.async_reload(statistics_config_entry.entry_id) + await hass.async_block_till_done() + + # Confirm the link between the source entity device and the statistics sensor + statistics_entity = entity_registry.async_get("sensor.statistics") + assert statistics_entity is not None + assert statistics_entity.device_id == source_entity.device_id + + # After reloading the config entry, only one linked device is expected + devices_after_reload = device_registry.devices.get_devices_for_config_entry_id( + statistics_config_entry.entry_id + ) + assert len(devices_after_reload) == 1 + + assert devices_after_reload[0].id == source_device1_entry.id diff --git a/tests/components/statistics/test_sensor.py b/tests/components/statistics/test_sensor.py index 269c17e34b9..c90d685714c 100644 --- a/tests/components/statistics/test_sensor.py +++ b/tests/components/statistics/test_sensor.py @@ -41,7 +41,7 @@ from homeassistant.const import ( UnitOfTemperature, ) from homeassistant.core import HomeAssistant -from homeassistant.helpers import entity_registry as er +from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.setup import async_setup_component from homeassistant.util import dt as dt_util @@ -1654,3 +1654,50 @@ async def test_reload(recorder_mock: Recorder, hass: HomeAssistant) -> None: assert hass.states.get("sensor.test") is None assert hass.states.get("sensor.cputest") + + +async def test_device_id( + hass: HomeAssistant, + entity_registry: er.EntityRegistry, + device_registry: dr.DeviceRegistry, +) -> None: + """Test for source entity device for Statistics.""" + source_config_entry = MockConfigEntry() + source_config_entry.add_to_hass(hass) + source_device_entry = device_registry.async_get_or_create( + config_entry_id=source_config_entry.entry_id, + identifiers={("sensor", "identifier_test")}, + connections={("mac", "30:31:32:33:34:35")}, + ) + source_entity = entity_registry.async_get_or_create( + "sensor", + "test", + "source", + config_entry=source_config_entry, + device_id=source_device_entry.id, + ) + await hass.async_block_till_done() + assert entity_registry.async_get("sensor.test_source") is not None + + statistics_config_entry = MockConfigEntry( + data={}, + domain=STATISTICS_DOMAIN, + options={ + "name": "Statistics", + "entity_id": "sensor.test_source", + "state_characteristic": "mean", + "keep_last_sample": False, + "percentile": 50.0, + "precision": 2.0, + "sampling_size": 20.0, + }, + title="Statistics", + ) + statistics_config_entry.add_to_hass(hass) + + assert await hass.config_entries.async_setup(statistics_config_entry.entry_id) + await hass.async_block_till_done() + + statistics_entity = entity_registry.async_get("sensor.statistics") + assert statistics_entity is not None + assert statistics_entity.device_id == source_entity.device_id