Add MQTT WS command to remove device (#31989)

* Add MQTT WS command to remove device

* Review comments, fix test

* Fix tests
This commit is contained in:
Erik Montnemery 2020-02-25 05:46:02 +01:00 committed by GitHub
parent 4236d62b44
commit 7e387f93d6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 473 additions and 124 deletions

View file

@ -37,6 +37,7 @@ from homeassistant.exceptions import (
Unauthorized, Unauthorized,
) )
from homeassistant.helpers import config_validation as cv, event, template from homeassistant.helpers import config_validation as cv, event, template
from homeassistant.helpers.device_registry import async_get_registry as get_dev_reg
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.helpers.typing import ConfigType, HomeAssistantType, ServiceDataType from homeassistant.helpers.typing import ConfigType, HomeAssistantType, ServiceDataType
@ -48,6 +49,7 @@ from homeassistant.util.logging import catch_log_exception
from . import config_flow, discovery, server # noqa: F401 pylint: disable=unused-import from . import config_flow, discovery, server # noqa: F401 pylint: disable=unused-import
from .const import ( from .const import (
ATTR_DISCOVERY_HASH, ATTR_DISCOVERY_HASH,
ATTR_DISCOVERY_TOPIC,
CONF_BROKER, CONF_BROKER,
CONF_DISCOVERY, CONF_DISCOVERY,
CONF_STATE_TOPIC, CONF_STATE_TOPIC,
@ -510,6 +512,7 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool:
hass.data[DATA_MQTT_HASS_CONFIG] = config hass.data[DATA_MQTT_HASS_CONFIG] = config
websocket_api.async_register_command(hass, websocket_subscribe) websocket_api.async_register_command(hass, websocket_subscribe)
websocket_api.async_register_command(hass, websocket_remove_device)
if conf is None: if conf is None:
# If we have a config entry, setup is done by that config entry. # If we have a config entry, setup is done by that config entry.
@ -1156,43 +1159,55 @@ class MqttAvailability(Entity):
class MqttDiscoveryUpdate(Entity): class MqttDiscoveryUpdate(Entity):
"""Mixin used to handle updated discovery message.""" """Mixin used to handle updated discovery message."""
def __init__(self, discovery_hash, discovery_update=None) -> None: def __init__(self, discovery_data, discovery_update=None) -> None:
"""Initialize the discovery update mixin.""" """Initialize the discovery update mixin."""
self._discovery_hash = discovery_hash self._discovery_data = discovery_data
self._discovery_update = discovery_update self._discovery_update = discovery_update
self._remove_signal = None self._remove_signal = None
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Subscribe to discovery updates.""" """Subscribe to discovery updates."""
await super().async_added_to_hass() await super().async_added_to_hass()
discovery_hash = (
self._discovery_data[ATTR_DISCOVERY_HASH] if self._discovery_data else None
)
@callback @callback
def discovery_callback(payload): def discovery_callback(payload):
"""Handle discovery update.""" """Handle discovery update."""
_LOGGER.info( _LOGGER.info(
"Got update for entity with hash: %s '%s'", "Got update for entity with hash: %s '%s'", discovery_hash, payload,
self._discovery_hash,
payload,
) )
if not payload: if not payload:
# Empty payload: Remove component # Empty payload: Remove component
_LOGGER.info("Removing component: %s", self.entity_id) _LOGGER.info("Removing component: %s", self.entity_id)
self.hass.async_create_task(self.async_remove()) self.hass.async_create_task(self.async_remove())
clear_discovery_hash(self.hass, self._discovery_hash) clear_discovery_hash(self.hass, discovery_hash)
self._remove_signal() self._remove_signal()
elif self._discovery_update: elif self._discovery_update:
# Non-empty payload: Notify component # Non-empty payload: Notify component
_LOGGER.info("Updating component: %s", self.entity_id) _LOGGER.info("Updating component: %s", self.entity_id)
payload.pop(ATTR_DISCOVERY_HASH)
self.hass.async_create_task(self._discovery_update(payload)) self.hass.async_create_task(self._discovery_update(payload))
if self._discovery_hash: if discovery_hash:
self._remove_signal = async_dispatcher_connect( self._remove_signal = async_dispatcher_connect(
self.hass, self.hass,
MQTT_DISCOVERY_UPDATED.format(self._discovery_hash), MQTT_DISCOVERY_UPDATED.format(discovery_hash),
discovery_callback, discovery_callback,
) )
async def async_removed_from_registry(self) -> None:
"""Clear retained discovery topic in broker."""
discovery_topic = self._discovery_data[ATTR_DISCOVERY_TOPIC]
publish(
self.hass, discovery_topic, "", retain=True,
)
async def async_will_remove_from_hass(self) -> None:
"""Stop listening to signal."""
if self._remove_signal:
self._remove_signal()
def device_info_from_config(config): def device_info_from_config(config):
"""Return a device description for device registry.""" """Return a device description for device registry."""
@ -1247,6 +1262,25 @@ class MqttEntityDeviceInfo(Entity):
return device_info_from_config(self._device_config) return device_info_from_config(self._device_config)
@websocket_api.websocket_command(
{vol.Required("type"): "mqtt/device/remove", vol.Required("device_id"): str}
)
@websocket_api.async_response
async def websocket_remove_device(hass, connection, msg):
"""Delete device."""
device_id = msg["device_id"]
dev_registry = await get_dev_reg(hass)
device = dev_registry.async_get(device_id)
for config_entry in device.config_entries:
config_entry = hass.config_entries.async_get_entry(config_entry)
# Only delete the device if it belongs to an MQTT device entry
if config_entry.domain == DOMAIN:
dev_registry.async_remove_device(device_id)
connection.send_message(websocket_api.result_message(msg["id"]))
break
@websocket_api.async_response @websocket_api.async_response
@websocket_api.websocket_command( @websocket_api.websocket_command(
{ {

View file

@ -98,15 +98,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def async_discover(discovery_payload): async def async_discover(discovery_payload):
"""Discover and add an MQTT alarm control panel.""" """Discover and add an MQTT alarm control panel."""
discovery_data = discovery_payload.discovery_data
try: try:
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH)
config = PLATFORM_SCHEMA(discovery_payload) config = PLATFORM_SCHEMA(discovery_payload)
await _async_setup_entity( await _async_setup_entity(
config, async_add_entities, config_entry, discovery_hash config, async_add_entities, config_entry, discovery_data
) )
except Exception: except Exception:
if discovery_hash: clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
clear_discovery_hash(hass, discovery_hash)
raise raise
async_dispatcher_connect( async_dispatcher_connect(
@ -115,10 +114,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def _async_setup_entity( async def _async_setup_entity(
config, async_add_entities, config_entry=None, discovery_hash=None config, async_add_entities, config_entry=None, discovery_data=None
): ):
"""Set up the MQTT Alarm Control Panel platform.""" """Set up the MQTT Alarm Control Panel platform."""
async_add_entities([MqttAlarm(config, config_entry, discovery_hash)]) async_add_entities([MqttAlarm(config, config_entry, discovery_data)])
class MqttAlarm( class MqttAlarm(
@ -130,7 +129,7 @@ class MqttAlarm(
): ):
"""Representation of a MQTT alarm status.""" """Representation of a MQTT alarm status."""
def __init__(self, config, config_entry, discovery_hash): def __init__(self, config, config_entry, discovery_data):
"""Init the MQTT Alarm Control Panel.""" """Init the MQTT Alarm Control Panel."""
self._state = None self._state = None
self._config = config self._config = config
@ -141,7 +140,7 @@ class MqttAlarm(
MqttAttributes.__init__(self, config) MqttAttributes.__init__(self, config)
MqttAvailability.__init__(self, config) MqttAvailability.__init__(self, config)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry) MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self): async def async_added_to_hass(self):
@ -207,6 +206,7 @@ class MqttAlarm(
) )
await MqttAttributes.async_will_remove_from_hass(self) await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
@property @property
def should_poll(self): def should_poll(self):

View file

@ -79,15 +79,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def async_discover(discovery_payload): async def async_discover(discovery_payload):
"""Discover and add a MQTT binary sensor.""" """Discover and add a MQTT binary sensor."""
discovery_data = discovery_payload.discovery_data
try: try:
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH)
config = PLATFORM_SCHEMA(discovery_payload) config = PLATFORM_SCHEMA(discovery_payload)
await _async_setup_entity( await _async_setup_entity(
config, async_add_entities, config_entry, discovery_hash config, async_add_entities, config_entry, discovery_data
) )
except Exception: except Exception:
if discovery_hash: clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
clear_discovery_hash(hass, discovery_hash)
raise raise
async_dispatcher_connect( async_dispatcher_connect(
@ -96,10 +95,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def _async_setup_entity( async def _async_setup_entity(
config, async_add_entities, config_entry=None, discovery_hash=None config, async_add_entities, config_entry=None, discovery_data=None
): ):
"""Set up the MQTT binary sensor.""" """Set up the MQTT binary sensor."""
async_add_entities([MqttBinarySensor(config, config_entry, discovery_hash)]) async_add_entities([MqttBinarySensor(config, config_entry, discovery_data)])
class MqttBinarySensor( class MqttBinarySensor(
@ -111,7 +110,7 @@ class MqttBinarySensor(
): ):
"""Representation a binary sensor that is updated by MQTT.""" """Representation a binary sensor that is updated by MQTT."""
def __init__(self, config, config_entry, discovery_hash): def __init__(self, config, config_entry, discovery_data):
"""Initialize the MQTT binary sensor.""" """Initialize the MQTT binary sensor."""
self._config = config self._config = config
self._unique_id = config.get(CONF_UNIQUE_ID) self._unique_id = config.get(CONF_UNIQUE_ID)
@ -124,7 +123,7 @@ class MqttBinarySensor(
MqttAttributes.__init__(self, config) MqttAttributes.__init__(self, config)
MqttAvailability.__init__(self, config) MqttAvailability.__init__(self, config)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry) MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self): async def async_added_to_hass(self):
@ -229,6 +228,7 @@ class MqttBinarySensor(
) )
await MqttAttributes.async_will_remove_from_hass(self) await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
@callback @callback
def value_is_expired(self, *_): def value_is_expired(self, *_):

View file

@ -47,15 +47,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def async_discover(discovery_payload): async def async_discover(discovery_payload):
"""Discover and add a MQTT camera.""" """Discover and add a MQTT camera."""
discovery_data = discovery_payload.discovery_data
try: try:
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH)
config = PLATFORM_SCHEMA(discovery_payload) config = PLATFORM_SCHEMA(discovery_payload)
await _async_setup_entity( await _async_setup_entity(
config, async_add_entities, config_entry, discovery_hash config, async_add_entities, config_entry, discovery_data
) )
except Exception: except Exception:
if discovery_hash: clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
clear_discovery_hash(hass, discovery_hash)
raise raise
async_dispatcher_connect( async_dispatcher_connect(
@ -64,16 +63,16 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def _async_setup_entity( async def _async_setup_entity(
config, async_add_entities, config_entry=None, discovery_hash=None config, async_add_entities, config_entry=None, discovery_data=None
): ):
"""Set up the MQTT Camera.""" """Set up the MQTT Camera."""
async_add_entities([MqttCamera(config, config_entry, discovery_hash)]) async_add_entities([MqttCamera(config, config_entry, discovery_data)])
class MqttCamera(MqttDiscoveryUpdate, MqttEntityDeviceInfo, Camera): class MqttCamera(MqttDiscoveryUpdate, MqttEntityDeviceInfo, Camera):
"""representation of a MQTT camera.""" """representation of a MQTT camera."""
def __init__(self, config, config_entry, discovery_hash): def __init__(self, config, config_entry, discovery_data):
"""Initialize the MQTT Camera.""" """Initialize the MQTT Camera."""
self._config = config self._config = config
self._unique_id = config.get(CONF_UNIQUE_ID) self._unique_id = config.get(CONF_UNIQUE_ID)
@ -85,7 +84,7 @@ class MqttCamera(MqttDiscoveryUpdate, MqttEntityDeviceInfo, Camera):
device_config = config.get(CONF_DEVICE) device_config = config.get(CONF_DEVICE)
Camera.__init__(self) Camera.__init__(self)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry) MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self): async def async_added_to_hass(self):
@ -127,6 +126,7 @@ class MqttCamera(MqttDiscoveryUpdate, MqttEntityDeviceInfo, Camera):
self._sub_state = await subscription.async_unsubscribe_topics( self._sub_state = await subscription.async_unsubscribe_topics(
self.hass, self._sub_state self.hass, self._sub_state
) )
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
async def async_camera_image(self): async def async_camera_image(self):
"""Return image response.""" """Return image response."""

View file

@ -243,15 +243,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def async_discover(discovery_payload): async def async_discover(discovery_payload):
"""Discover and add a MQTT climate device.""" """Discover and add a MQTT climate device."""
discovery_data = discovery_payload.discovery_data
try: try:
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH)
config = PLATFORM_SCHEMA(discovery_payload) config = PLATFORM_SCHEMA(discovery_payload)
await _async_setup_entity( await _async_setup_entity(
hass, config, async_add_entities, config_entry, discovery_hash hass, config, async_add_entities, config_entry, discovery_data
) )
except Exception: except Exception:
if discovery_hash: clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
clear_discovery_hash(hass, discovery_hash)
raise raise
async_dispatcher_connect( async_dispatcher_connect(
@ -260,10 +259,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def _async_setup_entity( async def _async_setup_entity(
hass, config, async_add_entities, config_entry=None, discovery_hash=None hass, config, async_add_entities, config_entry=None, discovery_data=None
): ):
"""Set up the MQTT climate devices.""" """Set up the MQTT climate devices."""
async_add_entities([MqttClimate(hass, config, config_entry, discovery_hash)]) async_add_entities([MqttClimate(hass, config, config_entry, discovery_data)])
class MqttClimate( class MqttClimate(
@ -275,7 +274,7 @@ class MqttClimate(
): ):
"""Representation of an MQTT climate device.""" """Representation of an MQTT climate device."""
def __init__(self, hass, config, config_entry, discovery_hash): def __init__(self, hass, config, config_entry, discovery_data):
"""Initialize the climate device.""" """Initialize the climate device."""
self._config = config self._config = config
self._unique_id = config.get(CONF_UNIQUE_ID) self._unique_id = config.get(CONF_UNIQUE_ID)
@ -303,7 +302,7 @@ class MqttClimate(
MqttAttributes.__init__(self, config) MqttAttributes.__init__(self, config)
MqttAvailability.__init__(self, config) MqttAvailability.__init__(self, config)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry) MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self): async def async_added_to_hass(self):
@ -552,6 +551,7 @@ class MqttClimate(
) )
await MqttAttributes.async_will_remove_from_hass(self) await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
@property @property
def should_poll(self): def should_poll(self):

View file

@ -4,6 +4,7 @@ CONF_DISCOVERY = "discovery"
DEFAULT_DISCOVERY = False DEFAULT_DISCOVERY = False
ATTR_DISCOVERY_HASH = "discovery_hash" ATTR_DISCOVERY_HASH = "discovery_hash"
ATTR_DISCOVERY_TOPIC = "discovery_topic"
CONF_STATE_TOPIC = "state_topic" CONF_STATE_TOPIC = "state_topic"
PROTOCOL_311 = "3.1.1" PROTOCOL_311 = "3.1.1"
DEFAULT_QOS = 0 DEFAULT_QOS = 0

View file

@ -178,14 +178,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def async_discover(discovery_payload): async def async_discover(discovery_payload):
"""Discover and add an MQTT cover.""" """Discover and add an MQTT cover."""
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH) discovery_data = discovery_payload.discovery_data
try: try:
config = PLATFORM_SCHEMA(discovery_payload) config = PLATFORM_SCHEMA(discovery_payload)
await _async_setup_entity( await _async_setup_entity(
config, async_add_entities, config_entry, discovery_hash config, async_add_entities, config_entry, discovery_data
) )
except Exception: except Exception:
clear_discovery_hash(hass, discovery_hash) clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
raise raise
async_dispatcher_connect( async_dispatcher_connect(
@ -194,10 +194,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def _async_setup_entity( async def _async_setup_entity(
config, async_add_entities, config_entry=None, discovery_hash=None config, async_add_entities, config_entry=None, discovery_data=None
): ):
"""Set up the MQTT Cover.""" """Set up the MQTT Cover."""
async_add_entities([MqttCover(config, config_entry, discovery_hash)]) async_add_entities([MqttCover(config, config_entry, discovery_data)])
class MqttCover( class MqttCover(
@ -209,7 +209,7 @@ class MqttCover(
): ):
"""Representation of a cover that can be controlled using MQTT.""" """Representation of a cover that can be controlled using MQTT."""
def __init__(self, config, config_entry, discovery_hash): def __init__(self, config, config_entry, discovery_data):
"""Initialize the cover.""" """Initialize the cover."""
self._unique_id = config.get(CONF_UNIQUE_ID) self._unique_id = config.get(CONF_UNIQUE_ID)
self._position = None self._position = None
@ -227,7 +227,7 @@ class MqttCover(
MqttAttributes.__init__(self, config) MqttAttributes.__init__(self, config)
MqttAvailability.__init__(self, config) MqttAvailability.__init__(self, config)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry) MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self): async def async_added_to_hass(self):
@ -369,6 +369,7 @@ class MqttCover(
) )
await MqttAttributes.async_will_remove_from_hass(self) await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
@property @property
def should_poll(self): def should_poll(self):

View file

@ -4,6 +4,7 @@ import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.components import mqtt from homeassistant.components import mqtt
from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from . import ATTR_DISCOVERY_HASH, device_trigger from . import ATTR_DISCOVERY_HASH, device_trigger
@ -25,20 +26,26 @@ PLATFORM_SCHEMA = mqtt.MQTT_BASE_PLATFORM_SCHEMA.extend(
async def async_setup_entry(hass, config_entry): async def async_setup_entry(hass, config_entry):
"""Set up MQTT device automation dynamically through MQTT discovery.""" """Set up MQTT device automation dynamically through MQTT discovery."""
async def async_device_removed(event):
"""Handle the removal of a device."""
if event.data["action"] != "remove":
return
await device_trigger.async_device_removed(hass, event.data["device_id"])
async def async_discover(discovery_payload): async def async_discover(discovery_payload):
"""Discover and add an MQTT device automation.""" """Discover and add an MQTT device automation."""
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH) discovery_data = discovery_payload.discovery_data
try: try:
config = PLATFORM_SCHEMA(discovery_payload) config = PLATFORM_SCHEMA(discovery_payload)
if config[CONF_AUTOMATION_TYPE] == AUTOMATION_TYPE_TRIGGER: if config[CONF_AUTOMATION_TYPE] == AUTOMATION_TYPE_TRIGGER:
await device_trigger.async_setup_trigger( await device_trigger.async_setup_trigger(
hass, config, config_entry, discovery_hash hass, config, config_entry, discovery_data
) )
except Exception: except Exception:
if discovery_hash: clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
clear_discovery_hash(hass, discovery_hash)
raise raise
async_dispatcher_connect( async_dispatcher_connect(
hass, MQTT_DISCOVERY_NEW.format("device_automation", "mqtt"), async_discover hass, MQTT_DISCOVERY_NEW.format("device_automation", "mqtt"), async_discover
) )
hass.bus.async_listen(EVENT_DEVICE_REGISTRY_UPDATED, async_device_removed)

View file

@ -1,6 +1,6 @@
"""Provides device automations for MQTT.""" """Provides device automations for MQTT."""
import logging import logging
from typing import List from typing import Callable, List
import attr import attr
import voluptuous as vol import voluptuous as vol
@ -99,9 +99,11 @@ class Trigger:
"""Device trigger settings.""" """Device trigger settings."""
device_id = attr.ib(type=str) device_id = attr.ib(type=str)
discovery_hash = attr.ib(type=tuple)
hass = attr.ib(type=HomeAssistantType) hass = attr.ib(type=HomeAssistantType)
payload = attr.ib(type=str) payload = attr.ib(type=str)
qos = attr.ib(type=int) qos = attr.ib(type=int)
remove_signal = attr.ib(type=Callable[[], None])
subtype = attr.ib(type=str) subtype = attr.ib(type=str)
topic = attr.ib(type=str) topic = attr.ib(type=str)
type = attr.ib(type=str) type = attr.ib(type=str)
@ -128,8 +130,10 @@ class Trigger:
return async_remove return async_remove
async def update_trigger(self, config): async def update_trigger(self, config, discovery_hash, remove_signal):
"""Update MQTT device trigger.""" """Update MQTT device trigger."""
self.discovery_hash = discovery_hash
self.remove_signal = remove_signal
self.type = config[CONF_TYPE] self.type = config[CONF_TYPE]
self.subtype = config[CONF_SUBTYPE] self.subtype = config[CONF_SUBTYPE]
self.topic = config[CONF_TOPIC] self.topic = config[CONF_TOPIC]
@ -143,8 +147,8 @@ class Trigger:
def detach_trigger(self): def detach_trigger(self):
"""Remove MQTT device trigger.""" """Remove MQTT device trigger."""
# Mark trigger as unknown # Mark trigger as unknown
self.topic = None self.topic = None
# Unsubscribe if this trigger is in use # Unsubscribe if this trigger is in use
for trig in self.trigger_instances: for trig in self.trigger_instances:
if trig.remove: if trig.remove:
@ -163,9 +167,10 @@ async def _update_device(hass, config_entry, config):
device_registry.async_get_or_create(**device_info) device_registry.async_get_or_create(**device_info)
async def async_setup_trigger(hass, config, config_entry, discovery_hash): async def async_setup_trigger(hass, config, config_entry, discovery_data):
"""Set up the MQTT device trigger.""" """Set up the MQTT device trigger."""
config = TRIGGER_DISCOVERY_SCHEMA(config) config = TRIGGER_DISCOVERY_SCHEMA(config)
discovery_hash = discovery_data[ATTR_DISCOVERY_HASH]
discovery_id = discovery_hash[1] discovery_id = discovery_hash[1]
remove_signal = None remove_signal = None
@ -185,11 +190,10 @@ async def async_setup_trigger(hass, config, config_entry, discovery_hash):
else: else:
# Non-empty payload: Update trigger # Non-empty payload: Update trigger
_LOGGER.info("Updating trigger: %s", discovery_hash) _LOGGER.info("Updating trigger: %s", discovery_hash)
payload.pop(ATTR_DISCOVERY_HASH)
config = TRIGGER_DISCOVERY_SCHEMA(payload) config = TRIGGER_DISCOVERY_SCHEMA(payload)
await _update_device(hass, config_entry, config) await _update_device(hass, config_entry, config)
device_trigger = hass.data[DEVICE_TRIGGERS][discovery_id] device_trigger = hass.data[DEVICE_TRIGGERS][discovery_id]
await device_trigger.update_trigger(config) await device_trigger.update_trigger(config, discovery_hash, remove_signal)
remove_signal = async_dispatcher_connect( remove_signal = async_dispatcher_connect(
hass, MQTT_DISCOVERY_UPDATED.format(discovery_hash), discovery_update hass, MQTT_DISCOVERY_UPDATED.format(discovery_hash), discovery_update
@ -212,14 +216,29 @@ async def async_setup_trigger(hass, config, config_entry, discovery_hash):
hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger( hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger(
hass=hass, hass=hass,
device_id=device.id, device_id=device.id,
discovery_hash=discovery_hash,
type=config[CONF_TYPE], type=config[CONF_TYPE],
subtype=config[CONF_SUBTYPE], subtype=config[CONF_SUBTYPE],
topic=config[CONF_TOPIC], topic=config[CONF_TOPIC],
payload=config[CONF_PAYLOAD], payload=config[CONF_PAYLOAD],
qos=config[CONF_QOS], qos=config[CONF_QOS],
remove_signal=remove_signal,
) )
else: else:
await hass.data[DEVICE_TRIGGERS][discovery_id].update_trigger(config) await hass.data[DEVICE_TRIGGERS][discovery_id].update_trigger(
config, discovery_hash, remove_signal
)
async def async_device_removed(hass: HomeAssistant, device_id: str):
"""Handle the removal of a device."""
triggers = await async_get_triggers(hass, device_id)
for trig in triggers:
device_trigger = hass.data[DEVICE_TRIGGERS].pop(trig[CONF_DISCOVERY_ID])
if device_trigger:
device_trigger.detach_trigger()
clear_discovery_hash(hass, device_trigger.discovery_hash)
device_trigger.remove_signal()
async def async_get_triggers(hass: HomeAssistant, device_id: str) -> List[dict]: async def async_get_triggers(hass: HomeAssistant, device_id: str) -> List[dict]:
@ -262,6 +281,8 @@ async def async_attach_trigger(
hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger( hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger(
hass=hass, hass=hass,
device_id=device_id, device_id=device_id,
discovery_hash=None,
remove_signal=None,
type=config[CONF_TYPE], type=config[CONF_TYPE],
subtype=config[CONF_SUBTYPE], subtype=config[CONF_SUBTYPE],
topic=None, topic=None,

View file

@ -11,7 +11,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.typing import HomeAssistantType from homeassistant.helpers.typing import HomeAssistantType
from .abbreviations import ABBREVIATIONS, DEVICE_ABBREVIATIONS from .abbreviations import ABBREVIATIONS, DEVICE_ABBREVIATIONS
from .const import ATTR_DISCOVERY_HASH, CONF_STATE_TOPIC from .const import ATTR_DISCOVERY_HASH, ATTR_DISCOVERY_TOPIC, CONF_STATE_TOPIC
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -137,6 +137,11 @@ async def async_start(
if payload: if payload:
# Attach MQTT topic to the payload, used for debug prints # Attach MQTT topic to the payload, used for debug prints
setattr(payload, "__configuration_source__", f"MQTT (topic: '{topic}')") setattr(payload, "__configuration_source__", f"MQTT (topic: '{topic}')")
discovery_data = {
ATTR_DISCOVERY_HASH: discovery_hash,
ATTR_DISCOVERY_TOPIC: topic,
}
setattr(payload, "discovery_data", discovery_data)
if CONF_PLATFORM in payload and "schema" not in payload: if CONF_PLATFORM in payload and "schema" not in payload:
platform = payload[CONF_PLATFORM] platform = payload[CONF_PLATFORM]
@ -173,8 +178,6 @@ async def async_start(
topic, topic,
) )
payload[ATTR_DISCOVERY_HASH] = discovery_hash
if ALREADY_DISCOVERED not in hass.data: if ALREADY_DISCOVERED not in hass.data:
hass.data[ALREADY_DISCOVERED] = {} hass.data[ALREADY_DISCOVERED] = {}
if discovery_hash in hass.data[ALREADY_DISCOVERED]: if discovery_hash in hass.data[ALREADY_DISCOVERED]:

View file

@ -118,15 +118,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def async_discover(discovery_payload): async def async_discover(discovery_payload):
"""Discover and add a MQTT fan.""" """Discover and add a MQTT fan."""
discovery_data = discovery_payload.discovery_data
try: try:
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH)
config = PLATFORM_SCHEMA(discovery_payload) config = PLATFORM_SCHEMA(discovery_payload)
await _async_setup_entity( await _async_setup_entity(
config, async_add_entities, config_entry, discovery_hash config, async_add_entities, config_entry, discovery_data
) )
except Exception: except Exception:
if discovery_hash: clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
clear_discovery_hash(hass, discovery_hash)
raise raise
async_dispatcher_connect( async_dispatcher_connect(
@ -135,10 +134,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def _async_setup_entity( async def _async_setup_entity(
config, async_add_entities, config_entry=None, discovery_hash=None config, async_add_entities, config_entry=None, discovery_data=None
): ):
"""Set up the MQTT fan.""" """Set up the MQTT fan."""
async_add_entities([MqttFan(config, config_entry, discovery_hash)]) async_add_entities([MqttFan(config, config_entry, discovery_data)])
class MqttFan( class MqttFan(
@ -150,7 +149,7 @@ class MqttFan(
): ):
"""A MQTT fan component.""" """A MQTT fan component."""
def __init__(self, config, config_entry, discovery_hash): def __init__(self, config, config_entry, discovery_data):
"""Initialize the MQTT fan.""" """Initialize the MQTT fan."""
self._unique_id = config.get(CONF_UNIQUE_ID) self._unique_id = config.get(CONF_UNIQUE_ID)
self._state = False self._state = False
@ -173,7 +172,7 @@ class MqttFan(
MqttAttributes.__init__(self, config) MqttAttributes.__init__(self, config)
MqttAvailability.__init__(self, config) MqttAvailability.__init__(self, config)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry) MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self): async def async_added_to_hass(self):
@ -317,6 +316,7 @@ class MqttFan(
) )
await MqttAttributes.async_will_remove_from_hass(self) await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
@property @property
def should_poll(self): def should_poll(self):

View file

@ -47,15 +47,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def async_discover(discovery_payload): async def async_discover(discovery_payload):
"""Discover and add a MQTT light.""" """Discover and add a MQTT light."""
discovery_data = discovery_payload.discovery_data
try: try:
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH)
config = PLATFORM_SCHEMA(discovery_payload) config = PLATFORM_SCHEMA(discovery_payload)
await _async_setup_entity( await _async_setup_entity(
config, async_add_entities, config_entry, discovery_hash config, async_add_entities, config_entry, discovery_data
) )
except Exception: except Exception:
if discovery_hash: clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
clear_discovery_hash(hass, discovery_hash)
raise raise
async_dispatcher_connect( async_dispatcher_connect(
@ -64,7 +63,7 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def _async_setup_entity( async def _async_setup_entity(
config, async_add_entities, config_entry=None, discovery_hash=None config, async_add_entities, config_entry=None, discovery_data=None
): ):
"""Set up a MQTT Light.""" """Set up a MQTT Light."""
setup_entity = { setup_entity = {
@ -73,5 +72,5 @@ async def _async_setup_entity(
"template": async_setup_entity_template, "template": async_setup_entity_template,
} }
await setup_entity[config[CONF_SCHEMA]]( await setup_entity[config[CONF_SCHEMA]](
config, async_add_entities, config_entry, discovery_hash config, async_add_entities, config_entry, discovery_data
) )

View file

@ -146,12 +146,12 @@ PLATFORM_SCHEMA_BASIC = (
async def async_setup_entity_basic( async def async_setup_entity_basic(
config, async_add_entities, config_entry, discovery_hash=None config, async_add_entities, config_entry, discovery_data=None
): ):
"""Set up a MQTT Light.""" """Set up a MQTT Light."""
config.setdefault(CONF_STATE_VALUE_TEMPLATE, config.get(CONF_VALUE_TEMPLATE)) config.setdefault(CONF_STATE_VALUE_TEMPLATE, config.get(CONF_VALUE_TEMPLATE))
async_add_entities([MqttLight(config, config_entry, discovery_hash)]) async_add_entities([MqttLight(config, config_entry, discovery_data)])
class MqttLight( class MqttLight(
@ -164,7 +164,7 @@ class MqttLight(
): ):
"""Representation of a MQTT light.""" """Representation of a MQTT light."""
def __init__(self, config, config_entry, discovery_hash): def __init__(self, config, config_entry, discovery_data):
"""Initialize MQTT light.""" """Initialize MQTT light."""
self._state = False self._state = False
self._sub_state = None self._sub_state = None
@ -194,7 +194,7 @@ class MqttLight(
MqttAttributes.__init__(self, config) MqttAttributes.__init__(self, config)
MqttAvailability.__init__(self, config) MqttAvailability.__init__(self, config)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry) MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self): async def async_added_to_hass(self):
@ -535,6 +535,7 @@ class MqttLight(
) )
await MqttAttributes.async_will_remove_from_hass(self) await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
@property @property
def brightness(self): def brightness(self):

View file

@ -119,10 +119,10 @@ PLATFORM_SCHEMA_JSON = (
async def async_setup_entity_json( async def async_setup_entity_json(
config: ConfigType, async_add_entities, config_entry, discovery_hash config: ConfigType, async_add_entities, config_entry, discovery_data
): ):
"""Set up a MQTT JSON Light.""" """Set up a MQTT JSON Light."""
async_add_entities([MqttLightJson(config, config_entry, discovery_hash)]) async_add_entities([MqttLightJson(config, config_entry, discovery_data)])
class MqttLightJson( class MqttLightJson(
@ -135,7 +135,7 @@ class MqttLightJson(
): ):
"""Representation of a MQTT JSON light.""" """Representation of a MQTT JSON light."""
def __init__(self, config, config_entry, discovery_hash): def __init__(self, config, config_entry, discovery_data):
"""Initialize MQTT JSON light.""" """Initialize MQTT JSON light."""
self._state = False self._state = False
self._sub_state = None self._sub_state = None
@ -158,7 +158,7 @@ class MqttLightJson(
MqttAttributes.__init__(self, config) MqttAttributes.__init__(self, config)
MqttAvailability.__init__(self, config) MqttAvailability.__init__(self, config)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry) MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self): async def async_added_to_hass(self):
@ -346,6 +346,7 @@ class MqttLightJson(
) )
await MqttAttributes.async_will_remove_from_hass(self) await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
@property @property
def brightness(self): def brightness(self):

View file

@ -93,10 +93,10 @@ PLATFORM_SCHEMA_TEMPLATE = (
async def async_setup_entity_template( async def async_setup_entity_template(
config, async_add_entities, config_entry, discovery_hash config, async_add_entities, config_entry, discovery_data
): ):
"""Set up a MQTT Template light.""" """Set up a MQTT Template light."""
async_add_entities([MqttTemplate(config, config_entry, discovery_hash)]) async_add_entities([MqttTemplate(config, config_entry, discovery_data)])
class MqttTemplate( class MqttTemplate(
@ -109,7 +109,7 @@ class MqttTemplate(
): ):
"""Representation of a MQTT Template light.""" """Representation of a MQTT Template light."""
def __init__(self, config, config_entry, discovery_hash): def __init__(self, config, config_entry, discovery_data):
"""Initialize a MQTT Template light.""" """Initialize a MQTT Template light."""
self._state = False self._state = False
self._sub_state = None self._sub_state = None
@ -133,7 +133,7 @@ class MqttTemplate(
MqttAttributes.__init__(self, config) MqttAttributes.__init__(self, config)
MqttAvailability.__init__(self, config) MqttAvailability.__init__(self, config)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry) MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self): async def async_added_to_hass(self):
@ -323,6 +323,7 @@ class MqttTemplate(
) )
await MqttAttributes.async_will_remove_from_hass(self) await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
@property @property
def brightness(self): def brightness(self):

View file

@ -80,15 +80,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def async_discover(discovery_payload): async def async_discover(discovery_payload):
"""Discover and add an MQTT lock.""" """Discover and add an MQTT lock."""
discovery_data = discovery_payload.discovery_data
try: try:
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH)
config = PLATFORM_SCHEMA(discovery_payload) config = PLATFORM_SCHEMA(discovery_payload)
await _async_setup_entity( await _async_setup_entity(
config, async_add_entities, config_entry, discovery_hash config, async_add_entities, config_entry, discovery_data
) )
except Exception: except Exception:
if discovery_hash: clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
clear_discovery_hash(hass, discovery_hash)
raise raise
async_dispatcher_connect( async_dispatcher_connect(
@ -97,10 +96,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def _async_setup_entity( async def _async_setup_entity(
config, async_add_entities, config_entry=None, discovery_hash=None config, async_add_entities, config_entry=None, discovery_data=None
): ):
"""Set up the MQTT Lock platform.""" """Set up the MQTT Lock platform."""
async_add_entities([MqttLock(config, config_entry, discovery_hash)]) async_add_entities([MqttLock(config, config_entry, discovery_data)])
class MqttLock( class MqttLock(
@ -112,7 +111,7 @@ class MqttLock(
): ):
"""Representation of a lock that can be toggled using MQTT.""" """Representation of a lock that can be toggled using MQTT."""
def __init__(self, config, config_entry, discovery_hash): def __init__(self, config, config_entry, discovery_data):
"""Initialize the lock.""" """Initialize the lock."""
self._unique_id = config.get(CONF_UNIQUE_ID) self._unique_id = config.get(CONF_UNIQUE_ID)
self._state = False self._state = False
@ -126,7 +125,7 @@ class MqttLock(
MqttAttributes.__init__(self, config) MqttAttributes.__init__(self, config)
MqttAvailability.__init__(self, config) MqttAvailability.__init__(self, config)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry) MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self): async def async_added_to_hass(self):
@ -192,6 +191,7 @@ class MqttLock(
) )
await MqttAttributes.async_will_remove_from_hass(self) await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
@property @property
def should_poll(self): def should_poll(self):

View file

@ -76,15 +76,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def async_discover_sensor(discovery_payload): async def async_discover_sensor(discovery_payload):
"""Discover and add a discovered MQTT sensor.""" """Discover and add a discovered MQTT sensor."""
discovery_data = discovery_payload.discovery_data
try: try:
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH)
config = PLATFORM_SCHEMA(discovery_payload) config = PLATFORM_SCHEMA(discovery_payload)
await _async_setup_entity( await _async_setup_entity(
config, async_add_entities, config_entry, discovery_hash config, async_add_entities, config_entry, discovery_data
) )
except Exception: except Exception:
if discovery_hash: clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
clear_discovery_hash(hass, discovery_hash)
raise raise
async_dispatcher_connect( async_dispatcher_connect(
@ -93,10 +92,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def _async_setup_entity( async def _async_setup_entity(
config: ConfigType, async_add_entities, config_entry=None, discovery_hash=None config: ConfigType, async_add_entities, config_entry=None, discovery_data=None
): ):
"""Set up MQTT sensor.""" """Set up MQTT sensor."""
async_add_entities([MqttSensor(config, config_entry, discovery_hash)]) async_add_entities([MqttSensor(config, config_entry, discovery_data)])
class MqttSensor( class MqttSensor(
@ -104,7 +103,7 @@ class MqttSensor(
): ):
"""Representation of a sensor that can be updated using MQTT.""" """Representation of a sensor that can be updated using MQTT."""
def __init__(self, config, config_entry, discovery_hash): def __init__(self, config, config_entry, discovery_data):
"""Initialize the sensor.""" """Initialize the sensor."""
self._config = config self._config = config
self._unique_id = config.get(CONF_UNIQUE_ID) self._unique_id = config.get(CONF_UNIQUE_ID)
@ -123,7 +122,7 @@ class MqttSensor(
MqttAttributes.__init__(self, config) MqttAttributes.__init__(self, config)
MqttAvailability.__init__(self, config) MqttAvailability.__init__(self, config)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry) MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self): async def async_added_to_hass(self):
@ -208,6 +207,7 @@ class MqttSensor(
) )
await MqttAttributes.async_will_remove_from_hass(self) await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
@callback @callback
def value_is_expired(self, *_): def value_is_expired(self, *_):

View file

@ -76,15 +76,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def async_discover(discovery_payload): async def async_discover(discovery_payload):
"""Discover and add a MQTT switch.""" """Discover and add a MQTT switch."""
discovery_data = discovery_payload.discovery_data
try: try:
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH)
config = PLATFORM_SCHEMA(discovery_payload) config = PLATFORM_SCHEMA(discovery_payload)
await _async_setup_entity( await _async_setup_entity(
config, async_add_entities, config_entry, discovery_hash config, async_add_entities, config_entry, discovery_data
) )
except Exception: except Exception:
if discovery_hash: clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
clear_discovery_hash(hass, discovery_hash)
raise raise
async_dispatcher_connect( async_dispatcher_connect(
@ -93,10 +92,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def _async_setup_entity( async def _async_setup_entity(
config, async_add_entities, config_entry=None, discovery_hash=None config, async_add_entities, config_entry=None, discovery_data=None
): ):
"""Set up the MQTT switch.""" """Set up the MQTT switch."""
async_add_entities([MqttSwitch(config, config_entry, discovery_hash)]) async_add_entities([MqttSwitch(config, config_entry, discovery_data)])
class MqttSwitch( class MqttSwitch(
@ -109,7 +108,7 @@ class MqttSwitch(
): ):
"""Representation of a switch that can be toggled using MQTT.""" """Representation of a switch that can be toggled using MQTT."""
def __init__(self, config, config_entry, discovery_hash): def __init__(self, config, config_entry, discovery_data):
"""Initialize the MQTT switch.""" """Initialize the MQTT switch."""
self._state = False self._state = False
self._sub_state = None self._sub_state = None
@ -126,7 +125,7 @@ class MqttSwitch(
MqttAttributes.__init__(self, config) MqttAttributes.__init__(self, config)
MqttAvailability.__init__(self, config) MqttAvailability.__init__(self, config)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update) MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry) MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self): async def async_added_to_hass(self):
@ -203,6 +202,7 @@ class MqttSwitch(
) )
await MqttAttributes.async_will_remove_from_hass(self) await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
@property @property
def should_poll(self): def should_poll(self):

View file

@ -39,15 +39,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def async_discover(discovery_payload): async def async_discover(discovery_payload):
"""Discover and add a MQTT vacuum.""" """Discover and add a MQTT vacuum."""
discovery_data = discovery_payload.discovery_data
try: try:
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH)
config = PLATFORM_SCHEMA(discovery_payload) config = PLATFORM_SCHEMA(discovery_payload)
await _async_setup_entity( await _async_setup_entity(
config, async_add_entities, config_entry, discovery_hash config, async_add_entities, config_entry, discovery_data
) )
except Exception: except Exception:
if discovery_hash: clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
clear_discovery_hash(hass, discovery_hash)
raise raise
async_dispatcher_connect( async_dispatcher_connect(
@ -56,10 +55,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def _async_setup_entity( async def _async_setup_entity(
config, async_add_entities, config_entry, discovery_hash=None config, async_add_entities, config_entry, discovery_data=None
): ):
"""Set up the MQTT vacuum.""" """Set up the MQTT vacuum."""
setup_entity = {LEGACY: async_setup_entity_legacy, STATE: async_setup_entity_state} setup_entity = {LEGACY: async_setup_entity_legacy, STATE: async_setup_entity_state}
await setup_entity[config[CONF_SCHEMA]]( await setup_entity[config[CONF_SCHEMA]](
config, async_add_entities, config_entry, discovery_hash config, async_add_entities, config_entry, discovery_data
) )

View file

@ -162,10 +162,10 @@ PLATFORM_SCHEMA_LEGACY = (
async def async_setup_entity_legacy( async def async_setup_entity_legacy(
config, async_add_entities, config_entry, discovery_hash config, async_add_entities, config_entry, discovery_data
): ):
"""Set up a MQTT Vacuum Legacy.""" """Set up a MQTT Vacuum Legacy."""
async_add_entities([MqttVacuum(config, config_entry, discovery_hash)]) async_add_entities([MqttVacuum(config, config_entry, discovery_data)])
class MqttVacuum( class MqttVacuum(
@ -269,6 +269,7 @@ class MqttVacuum(
await subscription.async_unsubscribe_topics(self.hass, self._sub_state) await subscription.async_unsubscribe_topics(self.hass, self._sub_state)
await MqttAttributes.async_will_remove_from_hass(self) await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
async def _subscribe_topics(self): async def _subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""

View file

@ -157,10 +157,10 @@ PLATFORM_SCHEMA_STATE = (
async def async_setup_entity_state( async def async_setup_entity_state(
config, async_add_entities, config_entry, discovery_hash config, async_add_entities, config_entry, discovery_data
): ):
"""Set up a State MQTT Vacuum.""" """Set up a State MQTT Vacuum."""
async_add_entities([MqttStateVacuum(config, config_entry, discovery_hash)]) async_add_entities([MqttStateVacuum(config, config_entry, discovery_data)])
class MqttStateVacuum( class MqttStateVacuum(
@ -234,6 +234,7 @@ class MqttStateVacuum(
await subscription.async_unsubscribe_topics(self.hass, self._sub_state) await subscription.async_unsubscribe_topics(self.hass, self._sub_state)
await MqttAttributes.async_will_remove_from_hass(self) await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self) await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
async def _subscribe_topics(self): async def _subscribe_topics(self):
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""

View file

@ -488,6 +488,12 @@ class Entity(ABC):
self._on_remove = [] self._on_remove = []
self._on_remove.append(func) self._on_remove.append(func)
async def async_removed_from_registry(self) -> None:
"""Run when entity has been removed from entity registry.
To be extended by integrations.
"""
async def async_remove(self) -> None: async def async_remove(self) -> None:
"""Remove entity from Home Assistant.""" """Remove entity from Home Assistant."""
assert self.hass is not None assert self.hass is not None
@ -534,6 +540,9 @@ class Entity(ABC):
async def _async_registry_updated(self, event): async def _async_registry_updated(self, event):
"""Handle entity registry update.""" """Handle entity registry update."""
data = event.data data = event.data
if data["action"] == "remove" and data["entity_id"] == self.entity_id:
await self.async_removed_from_registry()
if ( if (
data["action"] != "update" data["action"] != "update"
or data.get("old_entity_id", data["entity_id"]) != self.entity_id or data.get("old_entity_id", data["entity_id"]) != self.entity_id

View file

@ -323,11 +323,15 @@ async def async_mock_mqtt_component(hass, config=None):
if config is None: if config is None:
config = {mqtt.CONF_BROKER: "mock-broker"} config = {mqtt.CONF_BROKER: "mock-broker"}
async def _async_fire_mqtt_message(topic, payload, qos, retain):
async_fire_mqtt_message(hass, topic, payload, qos, retain)
with patch("paho.mqtt.client.Client") as mock_client: with patch("paho.mqtt.client.Client") as mock_client:
mock_client().connect.return_value = 0 mock_client().connect.return_value = 0
mock_client().subscribe.return_value = (0, 0) mock_client().subscribe.return_value = (0, 0)
mock_client().unsubscribe.return_value = (0, 0) mock_client().unsubscribe.return_value = (0, 0)
mock_client().publish.return_value = (0, 0) mock_client().publish.return_value = (0, 0)
mock_client().publish.side_effect = _async_fire_mqtt_message
result = await async_setup_component(hass, mqtt.DOMAIN, {mqtt.DOMAIN: config}) result = await async_setup_component(hass, mqtt.DOMAIN, {mqtt.DOMAIN: config})
assert result assert result

View file

@ -468,7 +468,7 @@ async def test_if_fires_on_mqtt_message_after_update(
assert len(calls) == 2 assert len(calls) == 2
async def test_not_fires_on_mqtt_message_after_remove( async def test_not_fires_on_mqtt_message_after_remove_by_mqtt(
hass, device_reg, calls, mqtt_mock hass, device_reg, calls, mqtt_mock
): ):
"""Test triggers not firing after removal.""" """Test triggers not firing after removal."""
@ -532,6 +532,62 @@ async def test_not_fires_on_mqtt_message_after_remove(
assert len(calls) == 2 assert len(calls) == 2
async def test_not_fires_on_mqtt_message_after_remove_from_registry(
hass, device_reg, calls, mqtt_mock
):
"""Test triggers not firing after removal."""
config_entry = MockConfigEntry(domain=DOMAIN, data={})
config_entry.add_to_hass(hass)
await async_start(hass, "homeassistant", {}, config_entry)
data1 = (
'{ "automation_type":"trigger",'
' "device":{"identifiers":["0AFFD2"]},'
' "topic": "foobar/triggers/button1",'
' "type": "button_short_press",'
' "subtype": "button_1" }'
)
async_fire_mqtt_message(hass, "homeassistant/device_automation/bla1/config", data1)
await hass.async_block_till_done()
device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set())
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: [
{
"trigger": {
"platform": "device",
"domain": DOMAIN,
"device_id": device_entry.id,
"discovery_id": "bla1",
"type": "button_short_press",
"subtype": "button_1",
},
"action": {
"service": "test.automation",
"data_template": {"some": ("short_press")},
},
},
]
},
)
# Fake short press.
async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press")
await hass.async_block_till_done()
assert len(calls) == 1
# Remove the device
device_reg.async_remove_device(device_entry.id)
await hass.async_block_till_done()
async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press")
await hass.async_block_till_done()
assert len(calls) == 1
async def test_attach_remove(hass, device_reg, mqtt_mock): async def test_attach_remove(hass, device_reg, mqtt_mock):
"""Test attach and removal of trigger.""" """Test attach and removal of trigger."""
config_entry = MockConfigEntry(domain=DOMAIN, data={}) config_entry = MockConfigEntry(domain=DOMAIN, data={})

View file

@ -3,6 +3,8 @@ from pathlib import Path
import re import re
from unittest.mock import patch from unittest.mock import patch
import pytest
from homeassistant.components import mqtt from homeassistant.components import mqtt
from homeassistant.components.mqtt.abbreviations import ( from homeassistant.components.mqtt.abbreviations import (
ABBREVIATIONS, ABBREVIATIONS,
@ -11,7 +13,25 @@ from homeassistant.components.mqtt.abbreviations import (
from homeassistant.components.mqtt.discovery import ALREADY_DISCOVERED, async_start from homeassistant.components.mqtt.discovery import ALREADY_DISCOVERED, async_start
from homeassistant.const import STATE_OFF, STATE_ON from homeassistant.const import STATE_OFF, STATE_ON
from tests.common import MockConfigEntry, async_fire_mqtt_message, mock_coro from tests.common import (
MockConfigEntry,
async_fire_mqtt_message,
mock_coro,
mock_device_registry,
mock_registry,
)
@pytest.fixture
def device_reg(hass):
"""Return an empty, loaded, registry."""
return mock_device_registry(hass)
@pytest.fixture
def entity_reg(hass):
"""Return an empty, loaded, registry."""
return mock_registry(hass)
async def test_subscribing_config_topic(hass, mqtt_mock): async def test_subscribing_config_topic(hass, mqtt_mock):
@ -213,6 +233,114 @@ async def test_non_duplicate_discovery(hass, mqtt_mock, caplog):
assert "Component has already been discovered: binary_sensor bla" in caplog.text assert "Component has already been discovered: binary_sensor bla" in caplog.text
async def test_removal(hass, mqtt_mock, caplog):
"""Test removal of component through empty discovery message."""
entry = MockConfigEntry(domain=mqtt.DOMAIN)
await async_start(hass, "homeassistant", {}, entry)
async_fire_mqtt_message(
hass, "homeassistant/binary_sensor/bla/config", '{ "name": "Beer" }'
)
await hass.async_block_till_done()
state = hass.states.get("binary_sensor.beer")
assert state is not None
async_fire_mqtt_message(hass, "homeassistant/binary_sensor/bla/config", "")
await hass.async_block_till_done()
state = hass.states.get("binary_sensor.beer")
assert state is None
async def test_rediscover(hass, mqtt_mock, caplog):
"""Test rediscover of removed component."""
entry = MockConfigEntry(domain=mqtt.DOMAIN)
await async_start(hass, "homeassistant", {}, entry)
async_fire_mqtt_message(
hass, "homeassistant/binary_sensor/bla/config", '{ "name": "Beer" }'
)
await hass.async_block_till_done()
state = hass.states.get("binary_sensor.beer")
assert state is not None
async_fire_mqtt_message(hass, "homeassistant/binary_sensor/bla/config", "")
await hass.async_block_till_done()
state = hass.states.get("binary_sensor.beer")
assert state is None
async_fire_mqtt_message(
hass, "homeassistant/binary_sensor/bla/config", '{ "name": "Beer" }'
)
await hass.async_block_till_done()
state = hass.states.get("binary_sensor.beer")
assert state is not None
async def test_duplicate_removal(hass, mqtt_mock, caplog):
"""Test for a non duplicate component."""
entry = MockConfigEntry(domain=mqtt.DOMAIN)
await async_start(hass, "homeassistant", {}, entry)
async_fire_mqtt_message(
hass, "homeassistant/binary_sensor/bla/config", '{ "name": "Beer" }'
)
await hass.async_block_till_done()
async_fire_mqtt_message(hass, "homeassistant/binary_sensor/bla/config", "")
await hass.async_block_till_done()
assert "Component has already been discovered: binary_sensor bla" in caplog.text
caplog.clear()
async_fire_mqtt_message(hass, "homeassistant/binary_sensor/bla/config", "")
await hass.async_block_till_done()
assert "Component has already been discovered: binary_sensor bla" not in caplog.text
async def test_cleanup_device(hass, device_reg, entity_reg, mqtt_mock):
"""Test discvered device is cleaned up when removed from registry."""
config_entry = MockConfigEntry(domain=mqtt.DOMAIN)
config_entry.add_to_hass(hass)
await async_start(hass, "homeassistant", {}, config_entry)
data = (
'{ "device":{"identifiers":["0AFFD2"]},'
' "state_topic": "foobar/sensor",'
' "unique_id": "unique" }'
)
async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data)
await hass.async_block_till_done()
# Verify device and registry entries are created
device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set())
assert device_entry is not None
entity_entry = entity_reg.async_get("sensor.mqtt_sensor")
assert entity_entry is not None
state = hass.states.get("sensor.mqtt_sensor")
assert state is not None
device_reg.async_remove_device(device_entry.id)
await hass.async_block_till_done()
# Verify device and registry entries are cleared
device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set())
assert device_entry is None
entity_entry = entity_reg.async_get("sensor.mqtt_sensor")
assert entity_entry is None
# Verify state is removed
state = hass.states.get("sensor.mqtt_sensor")
assert state is None
# Verify retained discovery topic has been cleared
mqtt_mock.async_publish.assert_called_once_with(
"homeassistant/sensor/bla/config", "", 0, True
)
async def test_discovery_expansion(hass, mqtt_mock, caplog): async def test_discovery_expansion(hass, mqtt_mock, caplog):
"""Test expansion of abbreviated discovery payload.""" """Test expansion of abbreviated discovery payload."""
entry = MockConfigEntry(domain=mqtt.DOMAIN) entry = MockConfigEntry(domain=mqtt.DOMAIN)

View file

@ -8,6 +8,7 @@ import pytest
import voluptuous as vol import voluptuous as vol
from homeassistant.components import mqtt from homeassistant.components import mqtt
from homeassistant.components.mqtt.discovery import async_start
from homeassistant.const import ( from homeassistant.const import (
ATTR_DOMAIN, ATTR_DOMAIN,
ATTR_SERVICE, ATTR_SERVICE,
@ -27,11 +28,25 @@ from tests.common import (
fire_mqtt_message, fire_mqtt_message,
get_test_home_assistant, get_test_home_assistant,
mock_coro, mock_coro,
mock_device_registry,
mock_mqtt_component, mock_mqtt_component,
mock_registry,
threadsafe_coroutine_factory, threadsafe_coroutine_factory,
) )
@pytest.fixture
def device_reg(hass):
"""Return an empty, loaded, registry."""
return mock_device_registry(hass)
@pytest.fixture
def entity_reg(hass):
"""Return an empty, loaded, registry."""
return mock_registry(hass)
@pytest.fixture @pytest.fixture
def mock_MQTT(): def mock_MQTT():
"""Make sure connection is established.""" """Make sure connection is established."""
@ -828,3 +843,70 @@ async def test_dump_service(hass):
assert len(writes) == 2 assert len(writes) == 2
assert writes[0][1][0] == "bla/1,test1\n" assert writes[0][1][0] == "bla/1,test1\n"
assert writes[1][1][0] == "bla/2,test2\n" assert writes[1][1][0] == "bla/2,test2\n"
async def test_mqtt_ws_remove_discovered_device(
hass, device_reg, entity_reg, hass_ws_client, mqtt_mock
):
"""Test MQTT websocket device removal."""
config_entry = MockConfigEntry(domain=mqtt.DOMAIN)
config_entry.add_to_hass(hass)
await async_start(hass, "homeassistant", {}, config_entry)
data = (
'{ "device":{"identifiers":["0AFFD2"]},'
' "state_topic": "foobar/sensor",'
' "unique_id": "unique" }'
)
async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data)
await hass.async_block_till_done()
# Verify device entry is created
device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set())
assert device_entry is not None
client = await hass_ws_client(hass)
await client.send_json(
{"id": 5, "type": "mqtt/device/remove", "device_id": device_entry.id}
)
response = await client.receive_json()
assert response["success"]
# Verify device entry is cleared
device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set())
assert device_entry is None
async def test_mqtt_ws_remove_discovered_device_twice(
hass, device_reg, hass_ws_client, mqtt_mock
):
"""Test MQTT websocket device removal."""
config_entry = MockConfigEntry(domain=mqtt.DOMAIN)
config_entry.add_to_hass(hass)
await async_start(hass, "homeassistant", {}, config_entry)
data = (
'{ "device":{"identifiers":["0AFFD2"]},'
' "state_topic": "foobar/sensor",'
' "unique_id": "unique" }'
)
async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data)
await hass.async_block_till_done()
device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set())
assert device_entry is not None
client = await hass_ws_client(hass)
await client.send_json(
{"id": 5, "type": "mqtt/device/remove", "device_id": device_entry.id}
)
response = await client.receive_json()
assert response["success"]
await client.send_json(
{"id": 5, "type": "mqtt/device/remove", "device_id": device_entry.id}
)
response = await client.receive_json()
assert not response["success"]