diff --git a/homeassistant/components/mqtt/abbreviations.py b/homeassistant/components/mqtt/abbreviations.py index 9a8d80461ae..7053e607161 100644 --- a/homeassistant/components/mqtt/abbreviations.py +++ b/homeassistant/components/mqtt/abbreviations.py @@ -42,6 +42,7 @@ ABBREVIATIONS = { "dev_cla": "device_class", "dock_t": "docked_topic", "dock_tpl": "docked_template", + "en": "enabled_by_default", "err_t": "error_topic", "err_tpl": "error_template", "fanspd_t": "fan_speed_topic", diff --git a/homeassistant/components/mqtt/mixins.py b/homeassistant/components/mqtt/mixins.py index 332632f4e0f..9b1c7a9fb21 100644 --- a/homeassistant/components/mqtt/mixins.py +++ b/homeassistant/components/mqtt/mixins.py @@ -52,6 +52,7 @@ AVAILABILITY_MODES = [AVAILABILITY_ALL, AVAILABILITY_ANY, AVAILABILITY_LATEST] CONF_AVAILABILITY = "availability" CONF_AVAILABILITY_MODE = "availability_mode" CONF_AVAILABILITY_TOPIC = "availability_topic" +CONF_ENABLED_BY_DEFAULT = "enabled_by_default" CONF_PAYLOAD_AVAILABLE = "payload_available" CONF_PAYLOAD_NOT_AVAILABLE = "payload_not_available" CONF_JSON_ATTRS_TOPIC = "json_attributes_topic" @@ -140,6 +141,7 @@ MQTT_ENTITY_DEVICE_INFO_SCHEMA = vol.All( MQTT_ENTITY_COMMON_SCHEMA = MQTT_AVAILABILITY_SCHEMA.extend( { vol.Optional(CONF_DEVICE): MQTT_ENTITY_DEVICE_INFO_SCHEMA, + vol.Optional(CONF_ENABLED_BY_DEFAULT, default=True): cv.boolean, vol.Optional(CONF_ICON): cv.icon, vol.Optional(CONF_JSON_ATTRS_TOPIC): valid_subscribe_topic, vol.Optional(CONF_JSON_ATTRS_TEMPLATE): cv.template, @@ -353,7 +355,7 @@ async def cleanup_device_registry(hass, device_id): if ( device_id and not hass.helpers.entity_registry.async_entries_for_device( - entity_registry, device_id, include_disabled_entities=True + entity_registry, device_id, include_disabled_entities=False ) and not await device_trigger.async_get_triggers(hass, device_id) and not tag.async_has_tags(hass, device_id) @@ -586,6 +588,11 @@ class MqttEntity( async def _subscribe_topics(self): """(Re)Subscribe to topics.""" + @property + def entity_registry_enabled_default(self) -> bool: + """Return if the entity should be enabled when first added to the entity registry.""" + return self._config[CONF_ENABLED_BY_DEFAULT] + @property def icon(self): """Return icon of the entity if any.""" diff --git a/tests/components/mqtt/test_common.py b/tests/components/mqtt/test_common.py index 43f27373a3e..3d58cf834e9 100644 --- a/tests/components/mqtt/test_common.py +++ b/tests/components/mqtt/test_common.py @@ -1146,3 +1146,39 @@ async def help_test_entity_debug_info_update_entity_id(hass, mqtt_mock, domain, assert ( f"{domain}.test" not in hass.data[debug_info.DATA_MQTT_DEBUG_INFO]["entities"] ) + + +async def help_test_entity_disabled_by_default(hass, mqtt_mock, domain, config): + """Test device registry remove.""" + # Add device settings to config + config = copy.deepcopy(config[domain]) + config["device"] = copy.deepcopy(DEFAULT_CONFIG_DEVICE_INFO_ID) + config["enabled_by_default"] = False + config["unique_id"] = "veryunique1" + + dev_registry = dr.async_get(hass) + ent_registry = er.async_get(hass) + + # Discover a disabled entity + data = json.dumps(config) + async_fire_mqtt_message(hass, f"homeassistant/{domain}/bla1/config", data) + await hass.async_block_till_done() + entity_id = ent_registry.async_get_entity_id(domain, mqtt.DOMAIN, "veryunique1") + assert not hass.states.get(entity_id) + assert dev_registry.async_get_device({("mqtt", "helloworld")}) + + # Discover an enabled entity, tied to the same device + config["enabled_by_default"] = True + config["unique_id"] = "veryunique2" + data = json.dumps(config) + async_fire_mqtt_message(hass, f"homeassistant/{domain}/bla2/config", data) + await hass.async_block_till_done() + entity_id = ent_registry.async_get_entity_id(domain, mqtt.DOMAIN, "veryunique2") + assert hass.states.get(entity_id) + + # Remove the enabled entity, both entities and the device should be removed + async_fire_mqtt_message(hass, f"homeassistant/{domain}/bla2/config", "") + await hass.async_block_till_done() + assert not ent_registry.async_get_entity_id(domain, mqtt.DOMAIN, "veryunique1") + assert not ent_registry.async_get_entity_id(domain, mqtt.DOMAIN, "veryunique2") + assert not dev_registry.async_get_device({("mqtt", "helloworld")}) diff --git a/tests/components/mqtt/test_sensor.py b/tests/components/mqtt/test_sensor.py index 4e3634ebfa8..373048f6f1a 100644 --- a/tests/components/mqtt/test_sensor.py +++ b/tests/components/mqtt/test_sensor.py @@ -37,6 +37,7 @@ from .test_common import ( help_test_entity_device_info_update, help_test_entity_device_info_with_connection, help_test_entity_device_info_with_identifier, + help_test_entity_disabled_by_default, help_test_entity_id_update_discovery_update, help_test_entity_id_update_subscriptions, help_test_setting_attribute_via_mqtt_json_message, @@ -632,3 +633,10 @@ async def test_entity_debug_info_update_entity_id(hass, mqtt_mock): await help_test_entity_debug_info_update_entity_id( hass, mqtt_mock, sensor.DOMAIN, DEFAULT_CONFIG ) + + +async def test_entity_disabled_by_default(hass, mqtt_mock): + """Test entity disabled by default.""" + await help_test_entity_disabled_by_default( + hass, mqtt_mock, sensor.DOMAIN, DEFAULT_CONFIG + )