diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index 8b78e985ff9..d3eea95f545 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -42,7 +42,7 @@ from . import ( service, ) from .device_registry import DeviceRegistry -from .entity_registry import DISABLED_INTEGRATION, EntityRegistry +from .entity_registry import DISABLED_DEVICE, DISABLED_INTEGRATION, EntityRegistry from .event import async_call_later, async_track_time_interval from .typing import ConfigType, DiscoveryInfoType @@ -456,6 +456,7 @@ class EntityPlatform: device_info = entity.device_info device_id = None + device = None if config_entry_id is not None and device_info is not None: processed_dev_info: dict[str, str | None] = { @@ -523,6 +524,11 @@ class EntityPlatform: unit_of_measurement=entity.unit_of_measurement, ) + if device and device.disabled and not entry.disabled: + entry = entity_registry.async_update_entity( + entry.entity_id, disabled_by=DISABLED_DEVICE + ) + entity.registry_entry = entry entity.entity_id = entry.entity_id diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py index e367fa248d5..baaea35b62c 100644 --- a/tests/helpers/test_entity_platform.py +++ b/tests/helpers/test_entity_platform.py @@ -7,14 +7,14 @@ from unittest.mock import ANY, Mock, patch import pytest from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, PERCENTAGE -from homeassistant.core import CoreState, callback +from homeassistant.core import CoreState, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError, PlatformNotReady from homeassistant.helpers import ( device_registry as dr, entity_platform, entity_registry as er, ) -from homeassistant.helpers.entity import async_generate_entity_id +from homeassistant.helpers.entity import DeviceInfo, async_generate_entity_id from homeassistant.helpers.entity_component import ( DEFAULT_SCAN_INTERVAL, EntityComponent, @@ -1080,6 +1080,44 @@ async def test_entity_disabled_by_integration(hass): assert entry_disabled.disabled_by == er.DISABLED_INTEGRATION +async def test_entity_disabled_by_device(hass: HomeAssistant): + """Test entity disabled by device.""" + + connections = {(dr.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")} + entity_disabled = MockEntity( + unique_id="disabled", device_info=DeviceInfo(connections=connections) + ) + + async def async_setup_entry(hass, config_entry, async_add_entities): + """Mock setup entry method.""" + async_add_entities([entity_disabled]) + return True + + platform = MockPlatform(async_setup_entry=async_setup_entry) + config_entry = MockConfigEntry(entry_id="super-mock-id", domain=DOMAIN) + entity_platform = MockEntityPlatform( + hass, platform_name=config_entry.domain, platform=platform + ) + + device_registry = dr.async_get(hass) + device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections=connections, + disabled_by=dr.DeviceEntryDisabler.USER, + ) + + assert await entity_platform.async_setup_entry(config_entry) + await hass.async_block_till_done() + + assert entity_disabled.hass is None + assert entity_disabled.platform is None + + registry = er.async_get(hass) + + entry_disabled = registry.async_get_or_create(DOMAIN, DOMAIN, "disabled") + assert entry_disabled.disabled_by == er.DISABLED_DEVICE + + async def test_entity_info_added_to_entity_registry(hass): """Test entity info is written to entity registry.""" component = EntityComponent(_LOGGER, DOMAIN, hass, timedelta(seconds=20))