diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 540d09d7c9f..61c62a1eaa4 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -37,6 +37,7 @@ from homeassistant.exceptions import ( Unauthorized, ) from homeassistant.helpers import config_validation as cv, event, template +from homeassistant.helpers.device_registry import async_get_registry as get_dev_reg from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity import Entity from homeassistant.helpers.typing import ConfigType, HomeAssistantType, ServiceDataType @@ -48,6 +49,7 @@ from homeassistant.util.logging import catch_log_exception from . import config_flow, discovery, server # noqa: F401 pylint: disable=unused-import from .const import ( ATTR_DISCOVERY_HASH, + ATTR_DISCOVERY_TOPIC, CONF_BROKER, CONF_DISCOVERY, CONF_STATE_TOPIC, @@ -510,6 +512,7 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool: hass.data[DATA_MQTT_HASS_CONFIG] = config websocket_api.async_register_command(hass, websocket_subscribe) + websocket_api.async_register_command(hass, websocket_remove_device) if conf is None: # If we have a config entry, setup is done by that config entry. @@ -1156,43 +1159,55 @@ class MqttAvailability(Entity): class MqttDiscoveryUpdate(Entity): """Mixin used to handle updated discovery message.""" - def __init__(self, discovery_hash, discovery_update=None) -> None: + def __init__(self, discovery_data, discovery_update=None) -> None: """Initialize the discovery update mixin.""" - self._discovery_hash = discovery_hash + self._discovery_data = discovery_data self._discovery_update = discovery_update self._remove_signal = None async def async_added_to_hass(self) -> None: """Subscribe to discovery updates.""" await super().async_added_to_hass() + discovery_hash = ( + self._discovery_data[ATTR_DISCOVERY_HASH] if self._discovery_data else None + ) @callback def discovery_callback(payload): """Handle discovery update.""" _LOGGER.info( - "Got update for entity with hash: %s '%s'", - self._discovery_hash, - payload, + "Got update for entity with hash: %s '%s'", discovery_hash, payload, ) if not payload: # Empty payload: Remove component _LOGGER.info("Removing component: %s", self.entity_id) self.hass.async_create_task(self.async_remove()) - clear_discovery_hash(self.hass, self._discovery_hash) + clear_discovery_hash(self.hass, discovery_hash) self._remove_signal() elif self._discovery_update: # Non-empty payload: Notify component _LOGGER.info("Updating component: %s", self.entity_id) - payload.pop(ATTR_DISCOVERY_HASH) self.hass.async_create_task(self._discovery_update(payload)) - if self._discovery_hash: + if discovery_hash: self._remove_signal = async_dispatcher_connect( self.hass, - MQTT_DISCOVERY_UPDATED.format(self._discovery_hash), + MQTT_DISCOVERY_UPDATED.format(discovery_hash), discovery_callback, ) + async def async_removed_from_registry(self) -> None: + """Clear retained discovery topic in broker.""" + discovery_topic = self._discovery_data[ATTR_DISCOVERY_TOPIC] + publish( + self.hass, discovery_topic, "", retain=True, + ) + + async def async_will_remove_from_hass(self) -> None: + """Stop listening to signal.""" + if self._remove_signal: + self._remove_signal() + def device_info_from_config(config): """Return a device description for device registry.""" @@ -1247,6 +1262,25 @@ class MqttEntityDeviceInfo(Entity): return device_info_from_config(self._device_config) +@websocket_api.websocket_command( + {vol.Required("type"): "mqtt/device/remove", vol.Required("device_id"): str} +) +@websocket_api.async_response +async def websocket_remove_device(hass, connection, msg): + """Delete device.""" + device_id = msg["device_id"] + dev_registry = await get_dev_reg(hass) + + device = dev_registry.async_get(device_id) + for config_entry in device.config_entries: + config_entry = hass.config_entries.async_get_entry(config_entry) + # Only delete the device if it belongs to an MQTT device entry + if config_entry.domain == DOMAIN: + dev_registry.async_remove_device(device_id) + connection.send_message(websocket_api.result_message(msg["id"])) + break + + @websocket_api.async_response @websocket_api.websocket_command( { diff --git a/homeassistant/components/mqtt/alarm_control_panel.py b/homeassistant/components/mqtt/alarm_control_panel.py index 43d0bb570a8..043fa62f6ef 100644 --- a/homeassistant/components/mqtt/alarm_control_panel.py +++ b/homeassistant/components/mqtt/alarm_control_panel.py @@ -98,15 +98,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async def async_discover(discovery_payload): """Discover and add an MQTT alarm control panel.""" + discovery_data = discovery_payload.discovery_data try: - discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH) config = PLATFORM_SCHEMA(discovery_payload) await _async_setup_entity( - config, async_add_entities, config_entry, discovery_hash + config, async_add_entities, config_entry, discovery_data ) except Exception: - if discovery_hash: - clear_discovery_hash(hass, discovery_hash) + clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH]) raise async_dispatcher_connect( @@ -115,10 +114,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async def _async_setup_entity( - config, async_add_entities, config_entry=None, discovery_hash=None + config, async_add_entities, config_entry=None, discovery_data=None ): """Set up the MQTT Alarm Control Panel platform.""" - async_add_entities([MqttAlarm(config, config_entry, discovery_hash)]) + async_add_entities([MqttAlarm(config, config_entry, discovery_data)]) class MqttAlarm( @@ -130,7 +129,7 @@ class MqttAlarm( ): """Representation of a MQTT alarm status.""" - def __init__(self, config, config_entry, discovery_hash): + def __init__(self, config, config_entry, discovery_data): """Init the MQTT Alarm Control Panel.""" self._state = None self._config = config @@ -141,7 +140,7 @@ class MqttAlarm( MqttAttributes.__init__(self, config) MqttAvailability.__init__(self, config) - MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) + MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update) MqttEntityDeviceInfo.__init__(self, device_config, config_entry) async def async_added_to_hass(self): @@ -207,6 +206,7 @@ class MqttAlarm( ) await MqttAttributes.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self) + await MqttDiscoveryUpdate.async_will_remove_from_hass(self) @property def should_poll(self): diff --git a/homeassistant/components/mqtt/binary_sensor.py b/homeassistant/components/mqtt/binary_sensor.py index fe47729561d..d268c12aa87 100644 --- a/homeassistant/components/mqtt/binary_sensor.py +++ b/homeassistant/components/mqtt/binary_sensor.py @@ -79,15 +79,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async def async_discover(discovery_payload): """Discover and add a MQTT binary sensor.""" + discovery_data = discovery_payload.discovery_data try: - discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH) config = PLATFORM_SCHEMA(discovery_payload) await _async_setup_entity( - config, async_add_entities, config_entry, discovery_hash + config, async_add_entities, config_entry, discovery_data ) except Exception: - if discovery_hash: - clear_discovery_hash(hass, discovery_hash) + clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH]) raise async_dispatcher_connect( @@ -96,10 +95,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async def _async_setup_entity( - config, async_add_entities, config_entry=None, discovery_hash=None + config, async_add_entities, config_entry=None, discovery_data=None ): """Set up the MQTT binary sensor.""" - async_add_entities([MqttBinarySensor(config, config_entry, discovery_hash)]) + async_add_entities([MqttBinarySensor(config, config_entry, discovery_data)]) class MqttBinarySensor( @@ -111,7 +110,7 @@ class MqttBinarySensor( ): """Representation a binary sensor that is updated by MQTT.""" - def __init__(self, config, config_entry, discovery_hash): + def __init__(self, config, config_entry, discovery_data): """Initialize the MQTT binary sensor.""" self._config = config self._unique_id = config.get(CONF_UNIQUE_ID) @@ -124,7 +123,7 @@ class MqttBinarySensor( MqttAttributes.__init__(self, config) MqttAvailability.__init__(self, config) - MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) + MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update) MqttEntityDeviceInfo.__init__(self, device_config, config_entry) async def async_added_to_hass(self): @@ -229,6 +228,7 @@ class MqttBinarySensor( ) await MqttAttributes.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self) + await MqttDiscoveryUpdate.async_will_remove_from_hass(self) @callback def value_is_expired(self, *_): diff --git a/homeassistant/components/mqtt/camera.py b/homeassistant/components/mqtt/camera.py index 6cf0865ff6a..9bbb1503196 100644 --- a/homeassistant/components/mqtt/camera.py +++ b/homeassistant/components/mqtt/camera.py @@ -47,15 +47,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async def async_discover(discovery_payload): """Discover and add a MQTT camera.""" + discovery_data = discovery_payload.discovery_data try: - discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH) config = PLATFORM_SCHEMA(discovery_payload) await _async_setup_entity( - config, async_add_entities, config_entry, discovery_hash + config, async_add_entities, config_entry, discovery_data ) except Exception: - if discovery_hash: - clear_discovery_hash(hass, discovery_hash) + clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH]) raise async_dispatcher_connect( @@ -64,16 +63,16 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async def _async_setup_entity( - config, async_add_entities, config_entry=None, discovery_hash=None + config, async_add_entities, config_entry=None, discovery_data=None ): """Set up the MQTT Camera.""" - async_add_entities([MqttCamera(config, config_entry, discovery_hash)]) + async_add_entities([MqttCamera(config, config_entry, discovery_data)]) class MqttCamera(MqttDiscoveryUpdate, MqttEntityDeviceInfo, Camera): """representation of a MQTT camera.""" - def __init__(self, config, config_entry, discovery_hash): + def __init__(self, config, config_entry, discovery_data): """Initialize the MQTT Camera.""" self._config = config self._unique_id = config.get(CONF_UNIQUE_ID) @@ -85,7 +84,7 @@ class MqttCamera(MqttDiscoveryUpdate, MqttEntityDeviceInfo, Camera): device_config = config.get(CONF_DEVICE) Camera.__init__(self) - MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) + MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update) MqttEntityDeviceInfo.__init__(self, device_config, config_entry) async def async_added_to_hass(self): @@ -127,6 +126,7 @@ class MqttCamera(MqttDiscoveryUpdate, MqttEntityDeviceInfo, Camera): self._sub_state = await subscription.async_unsubscribe_topics( self.hass, self._sub_state ) + await MqttDiscoveryUpdate.async_will_remove_from_hass(self) async def async_camera_image(self): """Return image response.""" diff --git a/homeassistant/components/mqtt/climate.py b/homeassistant/components/mqtt/climate.py index 91a36a310cb..46404de0c8a 100644 --- a/homeassistant/components/mqtt/climate.py +++ b/homeassistant/components/mqtt/climate.py @@ -243,15 +243,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async def async_discover(discovery_payload): """Discover and add a MQTT climate device.""" + discovery_data = discovery_payload.discovery_data try: - discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH) config = PLATFORM_SCHEMA(discovery_payload) await _async_setup_entity( - hass, config, async_add_entities, config_entry, discovery_hash + hass, config, async_add_entities, config_entry, discovery_data ) except Exception: - if discovery_hash: - clear_discovery_hash(hass, discovery_hash) + clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH]) raise async_dispatcher_connect( @@ -260,10 +259,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async def _async_setup_entity( - hass, config, async_add_entities, config_entry=None, discovery_hash=None + hass, config, async_add_entities, config_entry=None, discovery_data=None ): """Set up the MQTT climate devices.""" - async_add_entities([MqttClimate(hass, config, config_entry, discovery_hash)]) + async_add_entities([MqttClimate(hass, config, config_entry, discovery_data)]) class MqttClimate( @@ -275,7 +274,7 @@ class MqttClimate( ): """Representation of an MQTT climate device.""" - def __init__(self, hass, config, config_entry, discovery_hash): + def __init__(self, hass, config, config_entry, discovery_data): """Initialize the climate device.""" self._config = config self._unique_id = config.get(CONF_UNIQUE_ID) @@ -303,7 +302,7 @@ class MqttClimate( MqttAttributes.__init__(self, config) MqttAvailability.__init__(self, config) - MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) + MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update) MqttEntityDeviceInfo.__init__(self, device_config, config_entry) async def async_added_to_hass(self): @@ -552,6 +551,7 @@ class MqttClimate( ) await MqttAttributes.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self) + await MqttDiscoveryUpdate.async_will_remove_from_hass(self) @property def should_poll(self): diff --git a/homeassistant/components/mqtt/const.py b/homeassistant/components/mqtt/const.py index 3234bebbfc1..6044ec2af6e 100644 --- a/homeassistant/components/mqtt/const.py +++ b/homeassistant/components/mqtt/const.py @@ -4,6 +4,7 @@ CONF_DISCOVERY = "discovery" DEFAULT_DISCOVERY = False ATTR_DISCOVERY_HASH = "discovery_hash" +ATTR_DISCOVERY_TOPIC = "discovery_topic" CONF_STATE_TOPIC = "state_topic" PROTOCOL_311 = "3.1.1" DEFAULT_QOS = 0 diff --git a/homeassistant/components/mqtt/cover.py b/homeassistant/components/mqtt/cover.py index 885343b7090..a7a39678192 100644 --- a/homeassistant/components/mqtt/cover.py +++ b/homeassistant/components/mqtt/cover.py @@ -178,14 +178,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async def async_discover(discovery_payload): """Discover and add an MQTT cover.""" - discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH) + discovery_data = discovery_payload.discovery_data try: config = PLATFORM_SCHEMA(discovery_payload) await _async_setup_entity( - config, async_add_entities, config_entry, discovery_hash + config, async_add_entities, config_entry, discovery_data ) except Exception: - clear_discovery_hash(hass, discovery_hash) + clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH]) raise async_dispatcher_connect( @@ -194,10 +194,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async def _async_setup_entity( - config, async_add_entities, config_entry=None, discovery_hash=None + config, async_add_entities, config_entry=None, discovery_data=None ): """Set up the MQTT Cover.""" - async_add_entities([MqttCover(config, config_entry, discovery_hash)]) + async_add_entities([MqttCover(config, config_entry, discovery_data)]) class MqttCover( @@ -209,7 +209,7 @@ class MqttCover( ): """Representation of a cover that can be controlled using MQTT.""" - def __init__(self, config, config_entry, discovery_hash): + def __init__(self, config, config_entry, discovery_data): """Initialize the cover.""" self._unique_id = config.get(CONF_UNIQUE_ID) self._position = None @@ -227,7 +227,7 @@ class MqttCover( MqttAttributes.__init__(self, config) MqttAvailability.__init__(self, config) - MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) + MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update) MqttEntityDeviceInfo.__init__(self, device_config, config_entry) async def async_added_to_hass(self): @@ -369,6 +369,7 @@ class MqttCover( ) await MqttAttributes.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self) + await MqttDiscoveryUpdate.async_will_remove_from_hass(self) @property def should_poll(self): diff --git a/homeassistant/components/mqtt/device_automation.py b/homeassistant/components/mqtt/device_automation.py index 3f0889875d0..4fcfd8f66f2 100644 --- a/homeassistant/components/mqtt/device_automation.py +++ b/homeassistant/components/mqtt/device_automation.py @@ -4,6 +4,7 @@ import logging import voluptuous as vol from homeassistant.components import mqtt +from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED from homeassistant.helpers.dispatcher import async_dispatcher_connect from . import ATTR_DISCOVERY_HASH, device_trigger @@ -25,20 +26,26 @@ PLATFORM_SCHEMA = mqtt.MQTT_BASE_PLATFORM_SCHEMA.extend( async def async_setup_entry(hass, config_entry): """Set up MQTT device automation dynamically through MQTT discovery.""" + async def async_device_removed(event): + """Handle the removal of a device.""" + if event.data["action"] != "remove": + return + await device_trigger.async_device_removed(hass, event.data["device_id"]) + async def async_discover(discovery_payload): """Discover and add an MQTT device automation.""" - discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH) + discovery_data = discovery_payload.discovery_data try: config = PLATFORM_SCHEMA(discovery_payload) if config[CONF_AUTOMATION_TYPE] == AUTOMATION_TYPE_TRIGGER: await device_trigger.async_setup_trigger( - hass, config, config_entry, discovery_hash + hass, config, config_entry, discovery_data ) except Exception: - if discovery_hash: - clear_discovery_hash(hass, discovery_hash) + clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH]) raise async_dispatcher_connect( hass, MQTT_DISCOVERY_NEW.format("device_automation", "mqtt"), async_discover ) + hass.bus.async_listen(EVENT_DEVICE_REGISTRY_UPDATED, async_device_removed) diff --git a/homeassistant/components/mqtt/device_trigger.py b/homeassistant/components/mqtt/device_trigger.py index 2149024266d..92bef0578c9 100644 --- a/homeassistant/components/mqtt/device_trigger.py +++ b/homeassistant/components/mqtt/device_trigger.py @@ -1,6 +1,6 @@ """Provides device automations for MQTT.""" import logging -from typing import List +from typing import Callable, List import attr import voluptuous as vol @@ -99,9 +99,11 @@ class Trigger: """Device trigger settings.""" device_id = attr.ib(type=str) + discovery_hash = attr.ib(type=tuple) hass = attr.ib(type=HomeAssistantType) payload = attr.ib(type=str) qos = attr.ib(type=int) + remove_signal = attr.ib(type=Callable[[], None]) subtype = attr.ib(type=str) topic = attr.ib(type=str) type = attr.ib(type=str) @@ -128,8 +130,10 @@ class Trigger: return async_remove - async def update_trigger(self, config): + async def update_trigger(self, config, discovery_hash, remove_signal): """Update MQTT device trigger.""" + self.discovery_hash = discovery_hash + self.remove_signal = remove_signal self.type = config[CONF_TYPE] self.subtype = config[CONF_SUBTYPE] self.topic = config[CONF_TOPIC] @@ -143,8 +147,8 @@ class Trigger: def detach_trigger(self): """Remove MQTT device trigger.""" # Mark trigger as unknown - self.topic = None + # Unsubscribe if this trigger is in use for trig in self.trigger_instances: if trig.remove: @@ -163,9 +167,10 @@ async def _update_device(hass, config_entry, config): device_registry.async_get_or_create(**device_info) -async def async_setup_trigger(hass, config, config_entry, discovery_hash): +async def async_setup_trigger(hass, config, config_entry, discovery_data): """Set up the MQTT device trigger.""" config = TRIGGER_DISCOVERY_SCHEMA(config) + discovery_hash = discovery_data[ATTR_DISCOVERY_HASH] discovery_id = discovery_hash[1] remove_signal = None @@ -185,11 +190,10 @@ async def async_setup_trigger(hass, config, config_entry, discovery_hash): else: # Non-empty payload: Update trigger _LOGGER.info("Updating trigger: %s", discovery_hash) - payload.pop(ATTR_DISCOVERY_HASH) config = TRIGGER_DISCOVERY_SCHEMA(payload) await _update_device(hass, config_entry, config) device_trigger = hass.data[DEVICE_TRIGGERS][discovery_id] - await device_trigger.update_trigger(config) + await device_trigger.update_trigger(config, discovery_hash, remove_signal) remove_signal = async_dispatcher_connect( hass, MQTT_DISCOVERY_UPDATED.format(discovery_hash), discovery_update @@ -212,14 +216,29 @@ async def async_setup_trigger(hass, config, config_entry, discovery_hash): hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger( hass=hass, device_id=device.id, + discovery_hash=discovery_hash, type=config[CONF_TYPE], subtype=config[CONF_SUBTYPE], topic=config[CONF_TOPIC], payload=config[CONF_PAYLOAD], qos=config[CONF_QOS], + remove_signal=remove_signal, ) else: - await hass.data[DEVICE_TRIGGERS][discovery_id].update_trigger(config) + await hass.data[DEVICE_TRIGGERS][discovery_id].update_trigger( + config, discovery_hash, remove_signal + ) + + +async def async_device_removed(hass: HomeAssistant, device_id: str): + """Handle the removal of a device.""" + triggers = await async_get_triggers(hass, device_id) + for trig in triggers: + device_trigger = hass.data[DEVICE_TRIGGERS].pop(trig[CONF_DISCOVERY_ID]) + if device_trigger: + device_trigger.detach_trigger() + clear_discovery_hash(hass, device_trigger.discovery_hash) + device_trigger.remove_signal() async def async_get_triggers(hass: HomeAssistant, device_id: str) -> List[dict]: @@ -262,6 +281,8 @@ async def async_attach_trigger( hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger( hass=hass, device_id=device_id, + discovery_hash=None, + remove_signal=None, type=config[CONF_TYPE], subtype=config[CONF_SUBTYPE], topic=None, diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py index 418f648564d..c54ab395c94 100644 --- a/homeassistant/components/mqtt/discovery.py +++ b/homeassistant/components/mqtt/discovery.py @@ -11,7 +11,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.typing import HomeAssistantType from .abbreviations import ABBREVIATIONS, DEVICE_ABBREVIATIONS -from .const import ATTR_DISCOVERY_HASH, CONF_STATE_TOPIC +from .const import ATTR_DISCOVERY_HASH, ATTR_DISCOVERY_TOPIC, CONF_STATE_TOPIC _LOGGER = logging.getLogger(__name__) @@ -137,6 +137,11 @@ async def async_start( if payload: # Attach MQTT topic to the payload, used for debug prints setattr(payload, "__configuration_source__", f"MQTT (topic: '{topic}')") + discovery_data = { + ATTR_DISCOVERY_HASH: discovery_hash, + ATTR_DISCOVERY_TOPIC: topic, + } + setattr(payload, "discovery_data", discovery_data) if CONF_PLATFORM in payload and "schema" not in payload: platform = payload[CONF_PLATFORM] @@ -173,8 +178,6 @@ async def async_start( topic, ) - payload[ATTR_DISCOVERY_HASH] = discovery_hash - if ALREADY_DISCOVERED not in hass.data: hass.data[ALREADY_DISCOVERED] = {} if discovery_hash in hass.data[ALREADY_DISCOVERED]: diff --git a/homeassistant/components/mqtt/fan.py b/homeassistant/components/mqtt/fan.py index c5e4b3145de..b50bdf9734b 100644 --- a/homeassistant/components/mqtt/fan.py +++ b/homeassistant/components/mqtt/fan.py @@ -118,15 +118,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async def async_discover(discovery_payload): """Discover and add a MQTT fan.""" + discovery_data = discovery_payload.discovery_data try: - discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH) config = PLATFORM_SCHEMA(discovery_payload) await _async_setup_entity( - config, async_add_entities, config_entry, discovery_hash + config, async_add_entities, config_entry, discovery_data ) except Exception: - if discovery_hash: - clear_discovery_hash(hass, discovery_hash) + clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH]) raise async_dispatcher_connect( @@ -135,10 +134,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async def _async_setup_entity( - config, async_add_entities, config_entry=None, discovery_hash=None + config, async_add_entities, config_entry=None, discovery_data=None ): """Set up the MQTT fan.""" - async_add_entities([MqttFan(config, config_entry, discovery_hash)]) + async_add_entities([MqttFan(config, config_entry, discovery_data)]) class MqttFan( @@ -150,7 +149,7 @@ class MqttFan( ): """A MQTT fan component.""" - def __init__(self, config, config_entry, discovery_hash): + def __init__(self, config, config_entry, discovery_data): """Initialize the MQTT fan.""" self._unique_id = config.get(CONF_UNIQUE_ID) self._state = False @@ -173,7 +172,7 @@ class MqttFan( MqttAttributes.__init__(self, config) MqttAvailability.__init__(self, config) - MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) + MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update) MqttEntityDeviceInfo.__init__(self, device_config, config_entry) async def async_added_to_hass(self): @@ -317,6 +316,7 @@ class MqttFan( ) await MqttAttributes.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self) + await MqttDiscoveryUpdate.async_will_remove_from_hass(self) @property def should_poll(self): diff --git a/homeassistant/components/mqtt/light/__init__.py b/homeassistant/components/mqtt/light/__init__.py index 511ee6049df..d48b4ae4762 100644 --- a/homeassistant/components/mqtt/light/__init__.py +++ b/homeassistant/components/mqtt/light/__init__.py @@ -47,15 +47,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async def async_discover(discovery_payload): """Discover and add a MQTT light.""" + discovery_data = discovery_payload.discovery_data try: - discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH) config = PLATFORM_SCHEMA(discovery_payload) await _async_setup_entity( - config, async_add_entities, config_entry, discovery_hash + config, async_add_entities, config_entry, discovery_data ) except Exception: - if discovery_hash: - clear_discovery_hash(hass, discovery_hash) + clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH]) raise async_dispatcher_connect( @@ -64,7 +63,7 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async def _async_setup_entity( - config, async_add_entities, config_entry=None, discovery_hash=None + config, async_add_entities, config_entry=None, discovery_data=None ): """Set up a MQTT Light.""" setup_entity = { @@ -73,5 +72,5 @@ async def _async_setup_entity( "template": async_setup_entity_template, } await setup_entity[config[CONF_SCHEMA]]( - config, async_add_entities, config_entry, discovery_hash + config, async_add_entities, config_entry, discovery_data ) diff --git a/homeassistant/components/mqtt/light/schema_basic.py b/homeassistant/components/mqtt/light/schema_basic.py index 23f8684cf46..a9ea21b4b0a 100644 --- a/homeassistant/components/mqtt/light/schema_basic.py +++ b/homeassistant/components/mqtt/light/schema_basic.py @@ -146,12 +146,12 @@ PLATFORM_SCHEMA_BASIC = ( async def async_setup_entity_basic( - config, async_add_entities, config_entry, discovery_hash=None + config, async_add_entities, config_entry, discovery_data=None ): """Set up a MQTT Light.""" config.setdefault(CONF_STATE_VALUE_TEMPLATE, config.get(CONF_VALUE_TEMPLATE)) - async_add_entities([MqttLight(config, config_entry, discovery_hash)]) + async_add_entities([MqttLight(config, config_entry, discovery_data)]) class MqttLight( @@ -164,7 +164,7 @@ class MqttLight( ): """Representation of a MQTT light.""" - def __init__(self, config, config_entry, discovery_hash): + def __init__(self, config, config_entry, discovery_data): """Initialize MQTT light.""" self._state = False self._sub_state = None @@ -194,7 +194,7 @@ class MqttLight( MqttAttributes.__init__(self, config) MqttAvailability.__init__(self, config) - MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) + MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update) MqttEntityDeviceInfo.__init__(self, device_config, config_entry) async def async_added_to_hass(self): @@ -535,6 +535,7 @@ class MqttLight( ) await MqttAttributes.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self) + await MqttDiscoveryUpdate.async_will_remove_from_hass(self) @property def brightness(self): diff --git a/homeassistant/components/mqtt/light/schema_json.py b/homeassistant/components/mqtt/light/schema_json.py index e7256614002..60ecf80fb63 100644 --- a/homeassistant/components/mqtt/light/schema_json.py +++ b/homeassistant/components/mqtt/light/schema_json.py @@ -119,10 +119,10 @@ PLATFORM_SCHEMA_JSON = ( async def async_setup_entity_json( - config: ConfigType, async_add_entities, config_entry, discovery_hash + config: ConfigType, async_add_entities, config_entry, discovery_data ): """Set up a MQTT JSON Light.""" - async_add_entities([MqttLightJson(config, config_entry, discovery_hash)]) + async_add_entities([MqttLightJson(config, config_entry, discovery_data)]) class MqttLightJson( @@ -135,7 +135,7 @@ class MqttLightJson( ): """Representation of a MQTT JSON light.""" - def __init__(self, config, config_entry, discovery_hash): + def __init__(self, config, config_entry, discovery_data): """Initialize MQTT JSON light.""" self._state = False self._sub_state = None @@ -158,7 +158,7 @@ class MqttLightJson( MqttAttributes.__init__(self, config) MqttAvailability.__init__(self, config) - MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) + MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update) MqttEntityDeviceInfo.__init__(self, device_config, config_entry) async def async_added_to_hass(self): @@ -346,6 +346,7 @@ class MqttLightJson( ) await MqttAttributes.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self) + await MqttDiscoveryUpdate.async_will_remove_from_hass(self) @property def brightness(self): diff --git a/homeassistant/components/mqtt/light/schema_template.py b/homeassistant/components/mqtt/light/schema_template.py index 6bbf5ee1572..853e7f4411f 100644 --- a/homeassistant/components/mqtt/light/schema_template.py +++ b/homeassistant/components/mqtt/light/schema_template.py @@ -93,10 +93,10 @@ PLATFORM_SCHEMA_TEMPLATE = ( async def async_setup_entity_template( - config, async_add_entities, config_entry, discovery_hash + config, async_add_entities, config_entry, discovery_data ): """Set up a MQTT Template light.""" - async_add_entities([MqttTemplate(config, config_entry, discovery_hash)]) + async_add_entities([MqttTemplate(config, config_entry, discovery_data)]) class MqttTemplate( @@ -109,7 +109,7 @@ class MqttTemplate( ): """Representation of a MQTT Template light.""" - def __init__(self, config, config_entry, discovery_hash): + def __init__(self, config, config_entry, discovery_data): """Initialize a MQTT Template light.""" self._state = False self._sub_state = None @@ -133,7 +133,7 @@ class MqttTemplate( MqttAttributes.__init__(self, config) MqttAvailability.__init__(self, config) - MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) + MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update) MqttEntityDeviceInfo.__init__(self, device_config, config_entry) async def async_added_to_hass(self): @@ -323,6 +323,7 @@ class MqttTemplate( ) await MqttAttributes.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self) + await MqttDiscoveryUpdate.async_will_remove_from_hass(self) @property def brightness(self): diff --git a/homeassistant/components/mqtt/lock.py b/homeassistant/components/mqtt/lock.py index 6910e955288..89f005b7469 100644 --- a/homeassistant/components/mqtt/lock.py +++ b/homeassistant/components/mqtt/lock.py @@ -80,15 +80,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async def async_discover(discovery_payload): """Discover and add an MQTT lock.""" + discovery_data = discovery_payload.discovery_data try: - discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH) config = PLATFORM_SCHEMA(discovery_payload) await _async_setup_entity( - config, async_add_entities, config_entry, discovery_hash + config, async_add_entities, config_entry, discovery_data ) except Exception: - if discovery_hash: - clear_discovery_hash(hass, discovery_hash) + clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH]) raise async_dispatcher_connect( @@ -97,10 +96,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async def _async_setup_entity( - config, async_add_entities, config_entry=None, discovery_hash=None + config, async_add_entities, config_entry=None, discovery_data=None ): """Set up the MQTT Lock platform.""" - async_add_entities([MqttLock(config, config_entry, discovery_hash)]) + async_add_entities([MqttLock(config, config_entry, discovery_data)]) class MqttLock( @@ -112,7 +111,7 @@ class MqttLock( ): """Representation of a lock that can be toggled using MQTT.""" - def __init__(self, config, config_entry, discovery_hash): + def __init__(self, config, config_entry, discovery_data): """Initialize the lock.""" self._unique_id = config.get(CONF_UNIQUE_ID) self._state = False @@ -126,7 +125,7 @@ class MqttLock( MqttAttributes.__init__(self, config) MqttAvailability.__init__(self, config) - MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) + MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update) MqttEntityDeviceInfo.__init__(self, device_config, config_entry) async def async_added_to_hass(self): @@ -192,6 +191,7 @@ class MqttLock( ) await MqttAttributes.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self) + await MqttDiscoveryUpdate.async_will_remove_from_hass(self) @property def should_poll(self): diff --git a/homeassistant/components/mqtt/sensor.py b/homeassistant/components/mqtt/sensor.py index 967a434c9d5..07910697d21 100644 --- a/homeassistant/components/mqtt/sensor.py +++ b/homeassistant/components/mqtt/sensor.py @@ -76,15 +76,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async def async_discover_sensor(discovery_payload): """Discover and add a discovered MQTT sensor.""" + discovery_data = discovery_payload.discovery_data try: - discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH) config = PLATFORM_SCHEMA(discovery_payload) await _async_setup_entity( - config, async_add_entities, config_entry, discovery_hash + config, async_add_entities, config_entry, discovery_data ) except Exception: - if discovery_hash: - clear_discovery_hash(hass, discovery_hash) + clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH]) raise async_dispatcher_connect( @@ -93,10 +92,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async def _async_setup_entity( - config: ConfigType, async_add_entities, config_entry=None, discovery_hash=None + config: ConfigType, async_add_entities, config_entry=None, discovery_data=None ): """Set up MQTT sensor.""" - async_add_entities([MqttSensor(config, config_entry, discovery_hash)]) + async_add_entities([MqttSensor(config, config_entry, discovery_data)]) class MqttSensor( @@ -104,7 +103,7 @@ class MqttSensor( ): """Representation of a sensor that can be updated using MQTT.""" - def __init__(self, config, config_entry, discovery_hash): + def __init__(self, config, config_entry, discovery_data): """Initialize the sensor.""" self._config = config self._unique_id = config.get(CONF_UNIQUE_ID) @@ -123,7 +122,7 @@ class MqttSensor( MqttAttributes.__init__(self, config) MqttAvailability.__init__(self, config) - MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) + MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update) MqttEntityDeviceInfo.__init__(self, device_config, config_entry) async def async_added_to_hass(self): @@ -208,6 +207,7 @@ class MqttSensor( ) await MqttAttributes.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self) + await MqttDiscoveryUpdate.async_will_remove_from_hass(self) @callback def value_is_expired(self, *_): diff --git a/homeassistant/components/mqtt/switch.py b/homeassistant/components/mqtt/switch.py index 65b43f6bf53..32066c67b7a 100644 --- a/homeassistant/components/mqtt/switch.py +++ b/homeassistant/components/mqtt/switch.py @@ -76,15 +76,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async def async_discover(discovery_payload): """Discover and add a MQTT switch.""" + discovery_data = discovery_payload.discovery_data try: - discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH) config = PLATFORM_SCHEMA(discovery_payload) await _async_setup_entity( - config, async_add_entities, config_entry, discovery_hash + config, async_add_entities, config_entry, discovery_data ) except Exception: - if discovery_hash: - clear_discovery_hash(hass, discovery_hash) + clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH]) raise async_dispatcher_connect( @@ -93,10 +92,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async def _async_setup_entity( - config, async_add_entities, config_entry=None, discovery_hash=None + config, async_add_entities, config_entry=None, discovery_data=None ): """Set up the MQTT switch.""" - async_add_entities([MqttSwitch(config, config_entry, discovery_hash)]) + async_add_entities([MqttSwitch(config, config_entry, discovery_data)]) class MqttSwitch( @@ -109,7 +108,7 @@ class MqttSwitch( ): """Representation of a switch that can be toggled using MQTT.""" - def __init__(self, config, config_entry, discovery_hash): + def __init__(self, config, config_entry, discovery_data): """Initialize the MQTT switch.""" self._state = False self._sub_state = None @@ -126,7 +125,7 @@ class MqttSwitch( MqttAttributes.__init__(self, config) MqttAvailability.__init__(self, config) - MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) + MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update) MqttEntityDeviceInfo.__init__(self, device_config, config_entry) async def async_added_to_hass(self): @@ -203,6 +202,7 @@ class MqttSwitch( ) await MqttAttributes.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self) + await MqttDiscoveryUpdate.async_will_remove_from_hass(self) @property def should_poll(self): diff --git a/homeassistant/components/mqtt/vacuum/__init__.py b/homeassistant/components/mqtt/vacuum/__init__.py index d33a23f3a6d..b16ec7aaf74 100644 --- a/homeassistant/components/mqtt/vacuum/__init__.py +++ b/homeassistant/components/mqtt/vacuum/__init__.py @@ -39,15 +39,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async def async_discover(discovery_payload): """Discover and add a MQTT vacuum.""" + discovery_data = discovery_payload.discovery_data try: - discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH) config = PLATFORM_SCHEMA(discovery_payload) await _async_setup_entity( - config, async_add_entities, config_entry, discovery_hash + config, async_add_entities, config_entry, discovery_data ) except Exception: - if discovery_hash: - clear_discovery_hash(hass, discovery_hash) + clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH]) raise async_dispatcher_connect( @@ -56,10 +55,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async def _async_setup_entity( - config, async_add_entities, config_entry, discovery_hash=None + config, async_add_entities, config_entry, discovery_data=None ): """Set up the MQTT vacuum.""" setup_entity = {LEGACY: async_setup_entity_legacy, STATE: async_setup_entity_state} await setup_entity[config[CONF_SCHEMA]]( - config, async_add_entities, config_entry, discovery_hash + config, async_add_entities, config_entry, discovery_data ) diff --git a/homeassistant/components/mqtt/vacuum/schema_legacy.py b/homeassistant/components/mqtt/vacuum/schema_legacy.py index c6322d9fec5..eff7cc1b039 100644 --- a/homeassistant/components/mqtt/vacuum/schema_legacy.py +++ b/homeassistant/components/mqtt/vacuum/schema_legacy.py @@ -162,10 +162,10 @@ PLATFORM_SCHEMA_LEGACY = ( async def async_setup_entity_legacy( - config, async_add_entities, config_entry, discovery_hash + config, async_add_entities, config_entry, discovery_data ): """Set up a MQTT Vacuum Legacy.""" - async_add_entities([MqttVacuum(config, config_entry, discovery_hash)]) + async_add_entities([MqttVacuum(config, config_entry, discovery_data)]) class MqttVacuum( @@ -269,6 +269,7 @@ class MqttVacuum( await subscription.async_unsubscribe_topics(self.hass, self._sub_state) await MqttAttributes.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self) + await MqttDiscoveryUpdate.async_will_remove_from_hass(self) async def _subscribe_topics(self): """(Re)Subscribe to topics.""" diff --git a/homeassistant/components/mqtt/vacuum/schema_state.py b/homeassistant/components/mqtt/vacuum/schema_state.py index 0399e66c0ad..f9bcc7e845e 100644 --- a/homeassistant/components/mqtt/vacuum/schema_state.py +++ b/homeassistant/components/mqtt/vacuum/schema_state.py @@ -157,10 +157,10 @@ PLATFORM_SCHEMA_STATE = ( async def async_setup_entity_state( - config, async_add_entities, config_entry, discovery_hash + config, async_add_entities, config_entry, discovery_data ): """Set up a State MQTT Vacuum.""" - async_add_entities([MqttStateVacuum(config, config_entry, discovery_hash)]) + async_add_entities([MqttStateVacuum(config, config_entry, discovery_data)]) class MqttStateVacuum( @@ -234,6 +234,7 @@ class MqttStateVacuum( await subscription.async_unsubscribe_topics(self.hass, self._sub_state) await MqttAttributes.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self) + await MqttDiscoveryUpdate.async_will_remove_from_hass(self) async def _subscribe_topics(self): """(Re)Subscribe to topics.""" diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 49ed0f4a567..186aecd78f4 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -488,6 +488,12 @@ class Entity(ABC): self._on_remove = [] self._on_remove.append(func) + async def async_removed_from_registry(self) -> None: + """Run when entity has been removed from entity registry. + + To be extended by integrations. + """ + async def async_remove(self) -> None: """Remove entity from Home Assistant.""" assert self.hass is not None @@ -534,6 +540,9 @@ class Entity(ABC): async def _async_registry_updated(self, event): """Handle entity registry update.""" data = event.data + if data["action"] == "remove" and data["entity_id"] == self.entity_id: + await self.async_removed_from_registry() + if ( data["action"] != "update" or data.get("old_entity_id", data["entity_id"]) != self.entity_id diff --git a/tests/common.py b/tests/common.py index 5a00a2bc7df..4581c96b52a 100644 --- a/tests/common.py +++ b/tests/common.py @@ -323,11 +323,15 @@ async def async_mock_mqtt_component(hass, config=None): if config is None: config = {mqtt.CONF_BROKER: "mock-broker"} + async def _async_fire_mqtt_message(topic, payload, qos, retain): + async_fire_mqtt_message(hass, topic, payload, qos, retain) + with patch("paho.mqtt.client.Client") as mock_client: mock_client().connect.return_value = 0 mock_client().subscribe.return_value = (0, 0) mock_client().unsubscribe.return_value = (0, 0) mock_client().publish.return_value = (0, 0) + mock_client().publish.side_effect = _async_fire_mqtt_message result = await async_setup_component(hass, mqtt.DOMAIN, {mqtt.DOMAIN: config}) assert result diff --git a/tests/components/mqtt/test_device_trigger.py b/tests/components/mqtt/test_device_trigger.py index c3ba6eebadd..c9d9ec4ad08 100644 --- a/tests/components/mqtt/test_device_trigger.py +++ b/tests/components/mqtt/test_device_trigger.py @@ -468,7 +468,7 @@ async def test_if_fires_on_mqtt_message_after_update( assert len(calls) == 2 -async def test_not_fires_on_mqtt_message_after_remove( +async def test_not_fires_on_mqtt_message_after_remove_by_mqtt( hass, device_reg, calls, mqtt_mock ): """Test triggers not firing after removal.""" @@ -532,6 +532,62 @@ async def test_not_fires_on_mqtt_message_after_remove( assert len(calls) == 2 +async def test_not_fires_on_mqtt_message_after_remove_from_registry( + hass, device_reg, calls, mqtt_mock +): + """Test triggers not firing after removal.""" + config_entry = MockConfigEntry(domain=DOMAIN, data={}) + config_entry.add_to_hass(hass) + await async_start(hass, "homeassistant", {}, config_entry) + + data1 = ( + '{ "automation_type":"trigger",' + ' "device":{"identifiers":["0AFFD2"]},' + ' "topic": "foobar/triggers/button1",' + ' "type": "button_short_press",' + ' "subtype": "button_1" }' + ) + async_fire_mqtt_message(hass, "homeassistant/device_automation/bla1/config", data1) + await hass.async_block_till_done() + device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set()) + + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: [ + { + "trigger": { + "platform": "device", + "domain": DOMAIN, + "device_id": device_entry.id, + "discovery_id": "bla1", + "type": "button_short_press", + "subtype": "button_1", + }, + "action": { + "service": "test.automation", + "data_template": {"some": ("short_press")}, + }, + }, + ] + }, + ) + + # Fake short press. + async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press") + await hass.async_block_till_done() + assert len(calls) == 1 + + # Remove the device + device_reg.async_remove_device(device_entry.id) + await hass.async_block_till_done() + + async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press") + await hass.async_block_till_done() + assert len(calls) == 1 + + async def test_attach_remove(hass, device_reg, mqtt_mock): """Test attach and removal of trigger.""" config_entry = MockConfigEntry(domain=DOMAIN, data={}) diff --git a/tests/components/mqtt/test_discovery.py b/tests/components/mqtt/test_discovery.py index e09b4d786a6..4a28b95e32c 100644 --- a/tests/components/mqtt/test_discovery.py +++ b/tests/components/mqtt/test_discovery.py @@ -3,6 +3,8 @@ from pathlib import Path import re from unittest.mock import patch +import pytest + from homeassistant.components import mqtt from homeassistant.components.mqtt.abbreviations import ( ABBREVIATIONS, @@ -11,7 +13,25 @@ from homeassistant.components.mqtt.abbreviations import ( from homeassistant.components.mqtt.discovery import ALREADY_DISCOVERED, async_start from homeassistant.const import STATE_OFF, STATE_ON -from tests.common import MockConfigEntry, async_fire_mqtt_message, mock_coro +from tests.common import ( + MockConfigEntry, + async_fire_mqtt_message, + mock_coro, + mock_device_registry, + mock_registry, +) + + +@pytest.fixture +def device_reg(hass): + """Return an empty, loaded, registry.""" + return mock_device_registry(hass) + + +@pytest.fixture +def entity_reg(hass): + """Return an empty, loaded, registry.""" + return mock_registry(hass) async def test_subscribing_config_topic(hass, mqtt_mock): @@ -213,6 +233,114 @@ async def test_non_duplicate_discovery(hass, mqtt_mock, caplog): assert "Component has already been discovered: binary_sensor bla" in caplog.text +async def test_removal(hass, mqtt_mock, caplog): + """Test removal of component through empty discovery message.""" + entry = MockConfigEntry(domain=mqtt.DOMAIN) + + await async_start(hass, "homeassistant", {}, entry) + + async_fire_mqtt_message( + hass, "homeassistant/binary_sensor/bla/config", '{ "name": "Beer" }' + ) + await hass.async_block_till_done() + state = hass.states.get("binary_sensor.beer") + assert state is not None + + async_fire_mqtt_message(hass, "homeassistant/binary_sensor/bla/config", "") + await hass.async_block_till_done() + state = hass.states.get("binary_sensor.beer") + assert state is None + + +async def test_rediscover(hass, mqtt_mock, caplog): + """Test rediscover of removed component.""" + entry = MockConfigEntry(domain=mqtt.DOMAIN) + + await async_start(hass, "homeassistant", {}, entry) + + async_fire_mqtt_message( + hass, "homeassistant/binary_sensor/bla/config", '{ "name": "Beer" }' + ) + await hass.async_block_till_done() + state = hass.states.get("binary_sensor.beer") + assert state is not None + + async_fire_mqtt_message(hass, "homeassistant/binary_sensor/bla/config", "") + await hass.async_block_till_done() + state = hass.states.get("binary_sensor.beer") + assert state is None + + async_fire_mqtt_message( + hass, "homeassistant/binary_sensor/bla/config", '{ "name": "Beer" }' + ) + await hass.async_block_till_done() + state = hass.states.get("binary_sensor.beer") + assert state is not None + + +async def test_duplicate_removal(hass, mqtt_mock, caplog): + """Test for a non duplicate component.""" + entry = MockConfigEntry(domain=mqtt.DOMAIN) + + await async_start(hass, "homeassistant", {}, entry) + + async_fire_mqtt_message( + hass, "homeassistant/binary_sensor/bla/config", '{ "name": "Beer" }' + ) + await hass.async_block_till_done() + async_fire_mqtt_message(hass, "homeassistant/binary_sensor/bla/config", "") + await hass.async_block_till_done() + assert "Component has already been discovered: binary_sensor bla" in caplog.text + caplog.clear() + async_fire_mqtt_message(hass, "homeassistant/binary_sensor/bla/config", "") + await hass.async_block_till_done() + + assert "Component has already been discovered: binary_sensor bla" not in caplog.text + + +async def test_cleanup_device(hass, device_reg, entity_reg, mqtt_mock): + """Test discvered device is cleaned up when removed from registry.""" + config_entry = MockConfigEntry(domain=mqtt.DOMAIN) + config_entry.add_to_hass(hass) + await async_start(hass, "homeassistant", {}, config_entry) + + data = ( + '{ "device":{"identifiers":["0AFFD2"]},' + ' "state_topic": "foobar/sensor",' + ' "unique_id": "unique" }' + ) + + async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data) + await hass.async_block_till_done() + + # Verify device and registry entries are created + device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set()) + assert device_entry is not None + entity_entry = entity_reg.async_get("sensor.mqtt_sensor") + assert entity_entry is not None + + state = hass.states.get("sensor.mqtt_sensor") + assert state is not None + + device_reg.async_remove_device(device_entry.id) + await hass.async_block_till_done() + + # Verify device and registry entries are cleared + device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set()) + assert device_entry is None + entity_entry = entity_reg.async_get("sensor.mqtt_sensor") + assert entity_entry is None + + # Verify state is removed + state = hass.states.get("sensor.mqtt_sensor") + assert state is None + + # Verify retained discovery topic has been cleared + mqtt_mock.async_publish.assert_called_once_with( + "homeassistant/sensor/bla/config", "", 0, True + ) + + async def test_discovery_expansion(hass, mqtt_mock, caplog): """Test expansion of abbreviated discovery payload.""" entry = MockConfigEntry(domain=mqtt.DOMAIN) diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index dc79cb8a2e7..5dc05a95a55 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -8,6 +8,7 @@ import pytest import voluptuous as vol from homeassistant.components import mqtt +from homeassistant.components.mqtt.discovery import async_start from homeassistant.const import ( ATTR_DOMAIN, ATTR_SERVICE, @@ -27,11 +28,25 @@ from tests.common import ( fire_mqtt_message, get_test_home_assistant, mock_coro, + mock_device_registry, mock_mqtt_component, + mock_registry, threadsafe_coroutine_factory, ) +@pytest.fixture +def device_reg(hass): + """Return an empty, loaded, registry.""" + return mock_device_registry(hass) + + +@pytest.fixture +def entity_reg(hass): + """Return an empty, loaded, registry.""" + return mock_registry(hass) + + @pytest.fixture def mock_MQTT(): """Make sure connection is established.""" @@ -828,3 +843,70 @@ async def test_dump_service(hass): assert len(writes) == 2 assert writes[0][1][0] == "bla/1,test1\n" assert writes[1][1][0] == "bla/2,test2\n" + + +async def test_mqtt_ws_remove_discovered_device( + hass, device_reg, entity_reg, hass_ws_client, mqtt_mock +): + """Test MQTT websocket device removal.""" + config_entry = MockConfigEntry(domain=mqtt.DOMAIN) + config_entry.add_to_hass(hass) + await async_start(hass, "homeassistant", {}, config_entry) + + data = ( + '{ "device":{"identifiers":["0AFFD2"]},' + ' "state_topic": "foobar/sensor",' + ' "unique_id": "unique" }' + ) + + async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data) + await hass.async_block_till_done() + + # Verify device entry is created + device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set()) + assert device_entry is not None + + client = await hass_ws_client(hass) + await client.send_json( + {"id": 5, "type": "mqtt/device/remove", "device_id": device_entry.id} + ) + response = await client.receive_json() + assert response["success"] + + # Verify device entry is cleared + device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set()) + assert device_entry is None + + +async def test_mqtt_ws_remove_discovered_device_twice( + hass, device_reg, hass_ws_client, mqtt_mock +): + """Test MQTT websocket device removal.""" + config_entry = MockConfigEntry(domain=mqtt.DOMAIN) + config_entry.add_to_hass(hass) + await async_start(hass, "homeassistant", {}, config_entry) + + data = ( + '{ "device":{"identifiers":["0AFFD2"]},' + ' "state_topic": "foobar/sensor",' + ' "unique_id": "unique" }' + ) + + async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data) + await hass.async_block_till_done() + + device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set()) + assert device_entry is not None + + client = await hass_ws_client(hass) + await client.send_json( + {"id": 5, "type": "mqtt/device/remove", "device_id": device_entry.id} + ) + response = await client.receive_json() + assert response["success"] + + await client.send_json( + {"id": 5, "type": "mqtt/device/remove", "device_id": device_entry.id} + ) + response = await client.receive_json() + assert not response["success"]