Add MQTT WS command to remove device (#31989)

* Add MQTT WS command to remove device

* Review comments, fix test

* Fix tests
This commit is contained in:
Erik Montnemery 2020-02-25 05:46:02 +01:00 committed by GitHub
parent 4236d62b44
commit 7e387f93d6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 473 additions and 124 deletions

View file

@ -37,6 +37,7 @@ from homeassistant.exceptions import (
Unauthorized,
)
from homeassistant.helpers import config_validation as cv, event, template
from homeassistant.helpers.device_registry import async_get_registry as get_dev_reg
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.typing import ConfigType, HomeAssistantType, ServiceDataType
@ -48,6 +49,7 @@ from homeassistant.util.logging import catch_log_exception
from . import config_flow, discovery, server # noqa: F401 pylint: disable=unused-import
from .const import (
ATTR_DISCOVERY_HASH,
ATTR_DISCOVERY_TOPIC,
CONF_BROKER,
CONF_DISCOVERY,
CONF_STATE_TOPIC,
@ -510,6 +512,7 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool:
hass.data[DATA_MQTT_HASS_CONFIG] = config
websocket_api.async_register_command(hass, websocket_subscribe)
websocket_api.async_register_command(hass, websocket_remove_device)
if conf is None:
# If we have a config entry, setup is done by that config entry.
@ -1156,43 +1159,55 @@ class MqttAvailability(Entity):
class MqttDiscoveryUpdate(Entity):
"""Mixin used to handle updated discovery message."""
def __init__(self, discovery_hash, discovery_update=None) -> None:
def __init__(self, discovery_data, discovery_update=None) -> None:
"""Initialize the discovery update mixin."""
self._discovery_hash = discovery_hash
self._discovery_data = discovery_data
self._discovery_update = discovery_update
self._remove_signal = None
async def async_added_to_hass(self) -> None:
"""Subscribe to discovery updates."""
await super().async_added_to_hass()
discovery_hash = (
self._discovery_data[ATTR_DISCOVERY_HASH] if self._discovery_data else None
)
@callback
def discovery_callback(payload):
"""Handle discovery update."""
_LOGGER.info(
"Got update for entity with hash: %s '%s'",
self._discovery_hash,
payload,
"Got update for entity with hash: %s '%s'", discovery_hash, payload,
)
if not payload:
# Empty payload: Remove component
_LOGGER.info("Removing component: %s", self.entity_id)
self.hass.async_create_task(self.async_remove())
clear_discovery_hash(self.hass, self._discovery_hash)
clear_discovery_hash(self.hass, discovery_hash)
self._remove_signal()
elif self._discovery_update:
# Non-empty payload: Notify component
_LOGGER.info("Updating component: %s", self.entity_id)
payload.pop(ATTR_DISCOVERY_HASH)
self.hass.async_create_task(self._discovery_update(payload))
if self._discovery_hash:
if discovery_hash:
self._remove_signal = async_dispatcher_connect(
self.hass,
MQTT_DISCOVERY_UPDATED.format(self._discovery_hash),
MQTT_DISCOVERY_UPDATED.format(discovery_hash),
discovery_callback,
)
async def async_removed_from_registry(self) -> None:
"""Clear retained discovery topic in broker."""
discovery_topic = self._discovery_data[ATTR_DISCOVERY_TOPIC]
publish(
self.hass, discovery_topic, "", retain=True,
)
async def async_will_remove_from_hass(self) -> None:
"""Stop listening to signal."""
if self._remove_signal:
self._remove_signal()
def device_info_from_config(config):
"""Return a device description for device registry."""
@ -1247,6 +1262,25 @@ class MqttEntityDeviceInfo(Entity):
return device_info_from_config(self._device_config)
@websocket_api.websocket_command(
{vol.Required("type"): "mqtt/device/remove", vol.Required("device_id"): str}
)
@websocket_api.async_response
async def websocket_remove_device(hass, connection, msg):
"""Delete device."""
device_id = msg["device_id"]
dev_registry = await get_dev_reg(hass)
device = dev_registry.async_get(device_id)
for config_entry in device.config_entries:
config_entry = hass.config_entries.async_get_entry(config_entry)
# Only delete the device if it belongs to an MQTT device entry
if config_entry.domain == DOMAIN:
dev_registry.async_remove_device(device_id)
connection.send_message(websocket_api.result_message(msg["id"]))
break
@websocket_api.async_response
@websocket_api.websocket_command(
{

View file

@ -98,15 +98,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def async_discover(discovery_payload):
"""Discover and add an MQTT alarm control panel."""
discovery_data = discovery_payload.discovery_data
try:
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH)
config = PLATFORM_SCHEMA(discovery_payload)
await _async_setup_entity(
config, async_add_entities, config_entry, discovery_hash
config, async_add_entities, config_entry, discovery_data
)
except Exception:
if discovery_hash:
clear_discovery_hash(hass, discovery_hash)
clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
raise
async_dispatcher_connect(
@ -115,10 +114,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def _async_setup_entity(
config, async_add_entities, config_entry=None, discovery_hash=None
config, async_add_entities, config_entry=None, discovery_data=None
):
"""Set up the MQTT Alarm Control Panel platform."""
async_add_entities([MqttAlarm(config, config_entry, discovery_hash)])
async_add_entities([MqttAlarm(config, config_entry, discovery_data)])
class MqttAlarm(
@ -130,7 +129,7 @@ class MqttAlarm(
):
"""Representation of a MQTT alarm status."""
def __init__(self, config, config_entry, discovery_hash):
def __init__(self, config, config_entry, discovery_data):
"""Init the MQTT Alarm Control Panel."""
self._state = None
self._config = config
@ -141,7 +140,7 @@ class MqttAlarm(
MqttAttributes.__init__(self, config)
MqttAvailability.__init__(self, config)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update)
MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self):
@ -207,6 +206,7 @@ class MqttAlarm(
)
await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
@property
def should_poll(self):

View file

@ -79,15 +79,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def async_discover(discovery_payload):
"""Discover and add a MQTT binary sensor."""
discovery_data = discovery_payload.discovery_data
try:
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH)
config = PLATFORM_SCHEMA(discovery_payload)
await _async_setup_entity(
config, async_add_entities, config_entry, discovery_hash
config, async_add_entities, config_entry, discovery_data
)
except Exception:
if discovery_hash:
clear_discovery_hash(hass, discovery_hash)
clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
raise
async_dispatcher_connect(
@ -96,10 +95,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def _async_setup_entity(
config, async_add_entities, config_entry=None, discovery_hash=None
config, async_add_entities, config_entry=None, discovery_data=None
):
"""Set up the MQTT binary sensor."""
async_add_entities([MqttBinarySensor(config, config_entry, discovery_hash)])
async_add_entities([MqttBinarySensor(config, config_entry, discovery_data)])
class MqttBinarySensor(
@ -111,7 +110,7 @@ class MqttBinarySensor(
):
"""Representation a binary sensor that is updated by MQTT."""
def __init__(self, config, config_entry, discovery_hash):
def __init__(self, config, config_entry, discovery_data):
"""Initialize the MQTT binary sensor."""
self._config = config
self._unique_id = config.get(CONF_UNIQUE_ID)
@ -124,7 +123,7 @@ class MqttBinarySensor(
MqttAttributes.__init__(self, config)
MqttAvailability.__init__(self, config)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update)
MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self):
@ -229,6 +228,7 @@ class MqttBinarySensor(
)
await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
@callback
def value_is_expired(self, *_):

View file

@ -47,15 +47,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def async_discover(discovery_payload):
"""Discover and add a MQTT camera."""
discovery_data = discovery_payload.discovery_data
try:
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH)
config = PLATFORM_SCHEMA(discovery_payload)
await _async_setup_entity(
config, async_add_entities, config_entry, discovery_hash
config, async_add_entities, config_entry, discovery_data
)
except Exception:
if discovery_hash:
clear_discovery_hash(hass, discovery_hash)
clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
raise
async_dispatcher_connect(
@ -64,16 +63,16 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def _async_setup_entity(
config, async_add_entities, config_entry=None, discovery_hash=None
config, async_add_entities, config_entry=None, discovery_data=None
):
"""Set up the MQTT Camera."""
async_add_entities([MqttCamera(config, config_entry, discovery_hash)])
async_add_entities([MqttCamera(config, config_entry, discovery_data)])
class MqttCamera(MqttDiscoveryUpdate, MqttEntityDeviceInfo, Camera):
"""representation of a MQTT camera."""
def __init__(self, config, config_entry, discovery_hash):
def __init__(self, config, config_entry, discovery_data):
"""Initialize the MQTT Camera."""
self._config = config
self._unique_id = config.get(CONF_UNIQUE_ID)
@ -85,7 +84,7 @@ class MqttCamera(MqttDiscoveryUpdate, MqttEntityDeviceInfo, Camera):
device_config = config.get(CONF_DEVICE)
Camera.__init__(self)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update)
MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self):
@ -127,6 +126,7 @@ class MqttCamera(MqttDiscoveryUpdate, MqttEntityDeviceInfo, Camera):
self._sub_state = await subscription.async_unsubscribe_topics(
self.hass, self._sub_state
)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
async def async_camera_image(self):
"""Return image response."""

View file

@ -243,15 +243,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def async_discover(discovery_payload):
"""Discover and add a MQTT climate device."""
discovery_data = discovery_payload.discovery_data
try:
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH)
config = PLATFORM_SCHEMA(discovery_payload)
await _async_setup_entity(
hass, config, async_add_entities, config_entry, discovery_hash
hass, config, async_add_entities, config_entry, discovery_data
)
except Exception:
if discovery_hash:
clear_discovery_hash(hass, discovery_hash)
clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
raise
async_dispatcher_connect(
@ -260,10 +259,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def _async_setup_entity(
hass, config, async_add_entities, config_entry=None, discovery_hash=None
hass, config, async_add_entities, config_entry=None, discovery_data=None
):
"""Set up the MQTT climate devices."""
async_add_entities([MqttClimate(hass, config, config_entry, discovery_hash)])
async_add_entities([MqttClimate(hass, config, config_entry, discovery_data)])
class MqttClimate(
@ -275,7 +274,7 @@ class MqttClimate(
):
"""Representation of an MQTT climate device."""
def __init__(self, hass, config, config_entry, discovery_hash):
def __init__(self, hass, config, config_entry, discovery_data):
"""Initialize the climate device."""
self._config = config
self._unique_id = config.get(CONF_UNIQUE_ID)
@ -303,7 +302,7 @@ class MqttClimate(
MqttAttributes.__init__(self, config)
MqttAvailability.__init__(self, config)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update)
MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self):
@ -552,6 +551,7 @@ class MqttClimate(
)
await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
@property
def should_poll(self):

View file

@ -4,6 +4,7 @@ CONF_DISCOVERY = "discovery"
DEFAULT_DISCOVERY = False
ATTR_DISCOVERY_HASH = "discovery_hash"
ATTR_DISCOVERY_TOPIC = "discovery_topic"
CONF_STATE_TOPIC = "state_topic"
PROTOCOL_311 = "3.1.1"
DEFAULT_QOS = 0

View file

@ -178,14 +178,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def async_discover(discovery_payload):
"""Discover and add an MQTT cover."""
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH)
discovery_data = discovery_payload.discovery_data
try:
config = PLATFORM_SCHEMA(discovery_payload)
await _async_setup_entity(
config, async_add_entities, config_entry, discovery_hash
config, async_add_entities, config_entry, discovery_data
)
except Exception:
clear_discovery_hash(hass, discovery_hash)
clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
raise
async_dispatcher_connect(
@ -194,10 +194,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def _async_setup_entity(
config, async_add_entities, config_entry=None, discovery_hash=None
config, async_add_entities, config_entry=None, discovery_data=None
):
"""Set up the MQTT Cover."""
async_add_entities([MqttCover(config, config_entry, discovery_hash)])
async_add_entities([MqttCover(config, config_entry, discovery_data)])
class MqttCover(
@ -209,7 +209,7 @@ class MqttCover(
):
"""Representation of a cover that can be controlled using MQTT."""
def __init__(self, config, config_entry, discovery_hash):
def __init__(self, config, config_entry, discovery_data):
"""Initialize the cover."""
self._unique_id = config.get(CONF_UNIQUE_ID)
self._position = None
@ -227,7 +227,7 @@ class MqttCover(
MqttAttributes.__init__(self, config)
MqttAvailability.__init__(self, config)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update)
MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self):
@ -369,6 +369,7 @@ class MqttCover(
)
await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
@property
def should_poll(self):

View file

@ -4,6 +4,7 @@ import logging
import voluptuous as vol
from homeassistant.components import mqtt
from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from . import ATTR_DISCOVERY_HASH, device_trigger
@ -25,20 +26,26 @@ PLATFORM_SCHEMA = mqtt.MQTT_BASE_PLATFORM_SCHEMA.extend(
async def async_setup_entry(hass, config_entry):
"""Set up MQTT device automation dynamically through MQTT discovery."""
async def async_device_removed(event):
"""Handle the removal of a device."""
if event.data["action"] != "remove":
return
await device_trigger.async_device_removed(hass, event.data["device_id"])
async def async_discover(discovery_payload):
"""Discover and add an MQTT device automation."""
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH)
discovery_data = discovery_payload.discovery_data
try:
config = PLATFORM_SCHEMA(discovery_payload)
if config[CONF_AUTOMATION_TYPE] == AUTOMATION_TYPE_TRIGGER:
await device_trigger.async_setup_trigger(
hass, config, config_entry, discovery_hash
hass, config, config_entry, discovery_data
)
except Exception:
if discovery_hash:
clear_discovery_hash(hass, discovery_hash)
clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
raise
async_dispatcher_connect(
hass, MQTT_DISCOVERY_NEW.format("device_automation", "mqtt"), async_discover
)
hass.bus.async_listen(EVENT_DEVICE_REGISTRY_UPDATED, async_device_removed)

View file

@ -1,6 +1,6 @@
"""Provides device automations for MQTT."""
import logging
from typing import List
from typing import Callable, List
import attr
import voluptuous as vol
@ -99,9 +99,11 @@ class Trigger:
"""Device trigger settings."""
device_id = attr.ib(type=str)
discovery_hash = attr.ib(type=tuple)
hass = attr.ib(type=HomeAssistantType)
payload = attr.ib(type=str)
qos = attr.ib(type=int)
remove_signal = attr.ib(type=Callable[[], None])
subtype = attr.ib(type=str)
topic = attr.ib(type=str)
type = attr.ib(type=str)
@ -128,8 +130,10 @@ class Trigger:
return async_remove
async def update_trigger(self, config):
async def update_trigger(self, config, discovery_hash, remove_signal):
"""Update MQTT device trigger."""
self.discovery_hash = discovery_hash
self.remove_signal = remove_signal
self.type = config[CONF_TYPE]
self.subtype = config[CONF_SUBTYPE]
self.topic = config[CONF_TOPIC]
@ -143,8 +147,8 @@ class Trigger:
def detach_trigger(self):
"""Remove MQTT device trigger."""
# Mark trigger as unknown
self.topic = None
# Unsubscribe if this trigger is in use
for trig in self.trigger_instances:
if trig.remove:
@ -163,9 +167,10 @@ async def _update_device(hass, config_entry, config):
device_registry.async_get_or_create(**device_info)
async def async_setup_trigger(hass, config, config_entry, discovery_hash):
async def async_setup_trigger(hass, config, config_entry, discovery_data):
"""Set up the MQTT device trigger."""
config = TRIGGER_DISCOVERY_SCHEMA(config)
discovery_hash = discovery_data[ATTR_DISCOVERY_HASH]
discovery_id = discovery_hash[1]
remove_signal = None
@ -185,11 +190,10 @@ async def async_setup_trigger(hass, config, config_entry, discovery_hash):
else:
# Non-empty payload: Update trigger
_LOGGER.info("Updating trigger: %s", discovery_hash)
payload.pop(ATTR_DISCOVERY_HASH)
config = TRIGGER_DISCOVERY_SCHEMA(payload)
await _update_device(hass, config_entry, config)
device_trigger = hass.data[DEVICE_TRIGGERS][discovery_id]
await device_trigger.update_trigger(config)
await device_trigger.update_trigger(config, discovery_hash, remove_signal)
remove_signal = async_dispatcher_connect(
hass, MQTT_DISCOVERY_UPDATED.format(discovery_hash), discovery_update
@ -212,14 +216,29 @@ async def async_setup_trigger(hass, config, config_entry, discovery_hash):
hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger(
hass=hass,
device_id=device.id,
discovery_hash=discovery_hash,
type=config[CONF_TYPE],
subtype=config[CONF_SUBTYPE],
topic=config[CONF_TOPIC],
payload=config[CONF_PAYLOAD],
qos=config[CONF_QOS],
remove_signal=remove_signal,
)
else:
await hass.data[DEVICE_TRIGGERS][discovery_id].update_trigger(config)
await hass.data[DEVICE_TRIGGERS][discovery_id].update_trigger(
config, discovery_hash, remove_signal
)
async def async_device_removed(hass: HomeAssistant, device_id: str):
"""Handle the removal of a device."""
triggers = await async_get_triggers(hass, device_id)
for trig in triggers:
device_trigger = hass.data[DEVICE_TRIGGERS].pop(trig[CONF_DISCOVERY_ID])
if device_trigger:
device_trigger.detach_trigger()
clear_discovery_hash(hass, device_trigger.discovery_hash)
device_trigger.remove_signal()
async def async_get_triggers(hass: HomeAssistant, device_id: str) -> List[dict]:
@ -262,6 +281,8 @@ async def async_attach_trigger(
hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger(
hass=hass,
device_id=device_id,
discovery_hash=None,
remove_signal=None,
type=config[CONF_TYPE],
subtype=config[CONF_SUBTYPE],
topic=None,

View file

@ -11,7 +11,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.typing import HomeAssistantType
from .abbreviations import ABBREVIATIONS, DEVICE_ABBREVIATIONS
from .const import ATTR_DISCOVERY_HASH, CONF_STATE_TOPIC
from .const import ATTR_DISCOVERY_HASH, ATTR_DISCOVERY_TOPIC, CONF_STATE_TOPIC
_LOGGER = logging.getLogger(__name__)
@ -137,6 +137,11 @@ async def async_start(
if payload:
# Attach MQTT topic to the payload, used for debug prints
setattr(payload, "__configuration_source__", f"MQTT (topic: '{topic}')")
discovery_data = {
ATTR_DISCOVERY_HASH: discovery_hash,
ATTR_DISCOVERY_TOPIC: topic,
}
setattr(payload, "discovery_data", discovery_data)
if CONF_PLATFORM in payload and "schema" not in payload:
platform = payload[CONF_PLATFORM]
@ -173,8 +178,6 @@ async def async_start(
topic,
)
payload[ATTR_DISCOVERY_HASH] = discovery_hash
if ALREADY_DISCOVERED not in hass.data:
hass.data[ALREADY_DISCOVERED] = {}
if discovery_hash in hass.data[ALREADY_DISCOVERED]:

View file

@ -118,15 +118,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def async_discover(discovery_payload):
"""Discover and add a MQTT fan."""
discovery_data = discovery_payload.discovery_data
try:
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH)
config = PLATFORM_SCHEMA(discovery_payload)
await _async_setup_entity(
config, async_add_entities, config_entry, discovery_hash
config, async_add_entities, config_entry, discovery_data
)
except Exception:
if discovery_hash:
clear_discovery_hash(hass, discovery_hash)
clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
raise
async_dispatcher_connect(
@ -135,10 +134,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def _async_setup_entity(
config, async_add_entities, config_entry=None, discovery_hash=None
config, async_add_entities, config_entry=None, discovery_data=None
):
"""Set up the MQTT fan."""
async_add_entities([MqttFan(config, config_entry, discovery_hash)])
async_add_entities([MqttFan(config, config_entry, discovery_data)])
class MqttFan(
@ -150,7 +149,7 @@ class MqttFan(
):
"""A MQTT fan component."""
def __init__(self, config, config_entry, discovery_hash):
def __init__(self, config, config_entry, discovery_data):
"""Initialize the MQTT fan."""
self._unique_id = config.get(CONF_UNIQUE_ID)
self._state = False
@ -173,7 +172,7 @@ class MqttFan(
MqttAttributes.__init__(self, config)
MqttAvailability.__init__(self, config)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update)
MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self):
@ -317,6 +316,7 @@ class MqttFan(
)
await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
@property
def should_poll(self):

View file

@ -47,15 +47,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def async_discover(discovery_payload):
"""Discover and add a MQTT light."""
discovery_data = discovery_payload.discovery_data
try:
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH)
config = PLATFORM_SCHEMA(discovery_payload)
await _async_setup_entity(
config, async_add_entities, config_entry, discovery_hash
config, async_add_entities, config_entry, discovery_data
)
except Exception:
if discovery_hash:
clear_discovery_hash(hass, discovery_hash)
clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
raise
async_dispatcher_connect(
@ -64,7 +63,7 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def _async_setup_entity(
config, async_add_entities, config_entry=None, discovery_hash=None
config, async_add_entities, config_entry=None, discovery_data=None
):
"""Set up a MQTT Light."""
setup_entity = {
@ -73,5 +72,5 @@ async def _async_setup_entity(
"template": async_setup_entity_template,
}
await setup_entity[config[CONF_SCHEMA]](
config, async_add_entities, config_entry, discovery_hash
config, async_add_entities, config_entry, discovery_data
)

View file

@ -146,12 +146,12 @@ PLATFORM_SCHEMA_BASIC = (
async def async_setup_entity_basic(
config, async_add_entities, config_entry, discovery_hash=None
config, async_add_entities, config_entry, discovery_data=None
):
"""Set up a MQTT Light."""
config.setdefault(CONF_STATE_VALUE_TEMPLATE, config.get(CONF_VALUE_TEMPLATE))
async_add_entities([MqttLight(config, config_entry, discovery_hash)])
async_add_entities([MqttLight(config, config_entry, discovery_data)])
class MqttLight(
@ -164,7 +164,7 @@ class MqttLight(
):
"""Representation of a MQTT light."""
def __init__(self, config, config_entry, discovery_hash):
def __init__(self, config, config_entry, discovery_data):
"""Initialize MQTT light."""
self._state = False
self._sub_state = None
@ -194,7 +194,7 @@ class MqttLight(
MqttAttributes.__init__(self, config)
MqttAvailability.__init__(self, config)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update)
MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self):
@ -535,6 +535,7 @@ class MqttLight(
)
await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
@property
def brightness(self):

View file

@ -119,10 +119,10 @@ PLATFORM_SCHEMA_JSON = (
async def async_setup_entity_json(
config: ConfigType, async_add_entities, config_entry, discovery_hash
config: ConfigType, async_add_entities, config_entry, discovery_data
):
"""Set up a MQTT JSON Light."""
async_add_entities([MqttLightJson(config, config_entry, discovery_hash)])
async_add_entities([MqttLightJson(config, config_entry, discovery_data)])
class MqttLightJson(
@ -135,7 +135,7 @@ class MqttLightJson(
):
"""Representation of a MQTT JSON light."""
def __init__(self, config, config_entry, discovery_hash):
def __init__(self, config, config_entry, discovery_data):
"""Initialize MQTT JSON light."""
self._state = False
self._sub_state = None
@ -158,7 +158,7 @@ class MqttLightJson(
MqttAttributes.__init__(self, config)
MqttAvailability.__init__(self, config)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update)
MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self):
@ -346,6 +346,7 @@ class MqttLightJson(
)
await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
@property
def brightness(self):

View file

@ -93,10 +93,10 @@ PLATFORM_SCHEMA_TEMPLATE = (
async def async_setup_entity_template(
config, async_add_entities, config_entry, discovery_hash
config, async_add_entities, config_entry, discovery_data
):
"""Set up a MQTT Template light."""
async_add_entities([MqttTemplate(config, config_entry, discovery_hash)])
async_add_entities([MqttTemplate(config, config_entry, discovery_data)])
class MqttTemplate(
@ -109,7 +109,7 @@ class MqttTemplate(
):
"""Representation of a MQTT Template light."""
def __init__(self, config, config_entry, discovery_hash):
def __init__(self, config, config_entry, discovery_data):
"""Initialize a MQTT Template light."""
self._state = False
self._sub_state = None
@ -133,7 +133,7 @@ class MqttTemplate(
MqttAttributes.__init__(self, config)
MqttAvailability.__init__(self, config)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update)
MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self):
@ -323,6 +323,7 @@ class MqttTemplate(
)
await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
@property
def brightness(self):

View file

@ -80,15 +80,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def async_discover(discovery_payload):
"""Discover and add an MQTT lock."""
discovery_data = discovery_payload.discovery_data
try:
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH)
config = PLATFORM_SCHEMA(discovery_payload)
await _async_setup_entity(
config, async_add_entities, config_entry, discovery_hash
config, async_add_entities, config_entry, discovery_data
)
except Exception:
if discovery_hash:
clear_discovery_hash(hass, discovery_hash)
clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
raise
async_dispatcher_connect(
@ -97,10 +96,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def _async_setup_entity(
config, async_add_entities, config_entry=None, discovery_hash=None
config, async_add_entities, config_entry=None, discovery_data=None
):
"""Set up the MQTT Lock platform."""
async_add_entities([MqttLock(config, config_entry, discovery_hash)])
async_add_entities([MqttLock(config, config_entry, discovery_data)])
class MqttLock(
@ -112,7 +111,7 @@ class MqttLock(
):
"""Representation of a lock that can be toggled using MQTT."""
def __init__(self, config, config_entry, discovery_hash):
def __init__(self, config, config_entry, discovery_data):
"""Initialize the lock."""
self._unique_id = config.get(CONF_UNIQUE_ID)
self._state = False
@ -126,7 +125,7 @@ class MqttLock(
MqttAttributes.__init__(self, config)
MqttAvailability.__init__(self, config)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update)
MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self):
@ -192,6 +191,7 @@ class MqttLock(
)
await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
@property
def should_poll(self):

View file

@ -76,15 +76,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def async_discover_sensor(discovery_payload):
"""Discover and add a discovered MQTT sensor."""
discovery_data = discovery_payload.discovery_data
try:
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH)
config = PLATFORM_SCHEMA(discovery_payload)
await _async_setup_entity(
config, async_add_entities, config_entry, discovery_hash
config, async_add_entities, config_entry, discovery_data
)
except Exception:
if discovery_hash:
clear_discovery_hash(hass, discovery_hash)
clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
raise
async_dispatcher_connect(
@ -93,10 +92,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def _async_setup_entity(
config: ConfigType, async_add_entities, config_entry=None, discovery_hash=None
config: ConfigType, async_add_entities, config_entry=None, discovery_data=None
):
"""Set up MQTT sensor."""
async_add_entities([MqttSensor(config, config_entry, discovery_hash)])
async_add_entities([MqttSensor(config, config_entry, discovery_data)])
class MqttSensor(
@ -104,7 +103,7 @@ class MqttSensor(
):
"""Representation of a sensor that can be updated using MQTT."""
def __init__(self, config, config_entry, discovery_hash):
def __init__(self, config, config_entry, discovery_data):
"""Initialize the sensor."""
self._config = config
self._unique_id = config.get(CONF_UNIQUE_ID)
@ -123,7 +122,7 @@ class MqttSensor(
MqttAttributes.__init__(self, config)
MqttAvailability.__init__(self, config)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update)
MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self):
@ -208,6 +207,7 @@ class MqttSensor(
)
await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
@callback
def value_is_expired(self, *_):

View file

@ -76,15 +76,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def async_discover(discovery_payload):
"""Discover and add a MQTT switch."""
discovery_data = discovery_payload.discovery_data
try:
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH)
config = PLATFORM_SCHEMA(discovery_payload)
await _async_setup_entity(
config, async_add_entities, config_entry, discovery_hash
config, async_add_entities, config_entry, discovery_data
)
except Exception:
if discovery_hash:
clear_discovery_hash(hass, discovery_hash)
clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
raise
async_dispatcher_connect(
@ -93,10 +92,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def _async_setup_entity(
config, async_add_entities, config_entry=None, discovery_hash=None
config, async_add_entities, config_entry=None, discovery_data=None
):
"""Set up the MQTT switch."""
async_add_entities([MqttSwitch(config, config_entry, discovery_hash)])
async_add_entities([MqttSwitch(config, config_entry, discovery_data)])
class MqttSwitch(
@ -109,7 +108,7 @@ class MqttSwitch(
):
"""Representation of a switch that can be toggled using MQTT."""
def __init__(self, config, config_entry, discovery_hash):
def __init__(self, config, config_entry, discovery_data):
"""Initialize the MQTT switch."""
self._state = False
self._sub_state = None
@ -126,7 +125,7 @@ class MqttSwitch(
MqttAttributes.__init__(self, config)
MqttAvailability.__init__(self, config)
MqttDiscoveryUpdate.__init__(self, discovery_hash, self.discovery_update)
MqttDiscoveryUpdate.__init__(self, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, device_config, config_entry)
async def async_added_to_hass(self):
@ -203,6 +202,7 @@ class MqttSwitch(
)
await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
@property
def should_poll(self):

View file

@ -39,15 +39,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def async_discover(discovery_payload):
"""Discover and add a MQTT vacuum."""
discovery_data = discovery_payload.discovery_data
try:
discovery_hash = discovery_payload.pop(ATTR_DISCOVERY_HASH)
config = PLATFORM_SCHEMA(discovery_payload)
await _async_setup_entity(
config, async_add_entities, config_entry, discovery_hash
config, async_add_entities, config_entry, discovery_data
)
except Exception:
if discovery_hash:
clear_discovery_hash(hass, discovery_hash)
clear_discovery_hash(hass, discovery_data[ATTR_DISCOVERY_HASH])
raise
async_dispatcher_connect(
@ -56,10 +55,10 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async def _async_setup_entity(
config, async_add_entities, config_entry, discovery_hash=None
config, async_add_entities, config_entry, discovery_data=None
):
"""Set up the MQTT vacuum."""
setup_entity = {LEGACY: async_setup_entity_legacy, STATE: async_setup_entity_state}
await setup_entity[config[CONF_SCHEMA]](
config, async_add_entities, config_entry, discovery_hash
config, async_add_entities, config_entry, discovery_data
)

View file

@ -162,10 +162,10 @@ PLATFORM_SCHEMA_LEGACY = (
async def async_setup_entity_legacy(
config, async_add_entities, config_entry, discovery_hash
config, async_add_entities, config_entry, discovery_data
):
"""Set up a MQTT Vacuum Legacy."""
async_add_entities([MqttVacuum(config, config_entry, discovery_hash)])
async_add_entities([MqttVacuum(config, config_entry, discovery_data)])
class MqttVacuum(
@ -269,6 +269,7 @@ class MqttVacuum(
await subscription.async_unsubscribe_topics(self.hass, self._sub_state)
await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""

View file

@ -157,10 +157,10 @@ PLATFORM_SCHEMA_STATE = (
async def async_setup_entity_state(
config, async_add_entities, config_entry, discovery_hash
config, async_add_entities, config_entry, discovery_data
):
"""Set up a State MQTT Vacuum."""
async_add_entities([MqttStateVacuum(config, config_entry, discovery_hash)])
async_add_entities([MqttStateVacuum(config, config_entry, discovery_data)])
class MqttStateVacuum(
@ -234,6 +234,7 @@ class MqttStateVacuum(
await subscription.async_unsubscribe_topics(self.hass, self._sub_state)
await MqttAttributes.async_will_remove_from_hass(self)
await MqttAvailability.async_will_remove_from_hass(self)
await MqttDiscoveryUpdate.async_will_remove_from_hass(self)
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""

View file

@ -488,6 +488,12 @@ class Entity(ABC):
self._on_remove = []
self._on_remove.append(func)
async def async_removed_from_registry(self) -> None:
"""Run when entity has been removed from entity registry.
To be extended by integrations.
"""
async def async_remove(self) -> None:
"""Remove entity from Home Assistant."""
assert self.hass is not None
@ -534,6 +540,9 @@ class Entity(ABC):
async def _async_registry_updated(self, event):
"""Handle entity registry update."""
data = event.data
if data["action"] == "remove" and data["entity_id"] == self.entity_id:
await self.async_removed_from_registry()
if (
data["action"] != "update"
or data.get("old_entity_id", data["entity_id"]) != self.entity_id

View file

@ -323,11 +323,15 @@ async def async_mock_mqtt_component(hass, config=None):
if config is None:
config = {mqtt.CONF_BROKER: "mock-broker"}
async def _async_fire_mqtt_message(topic, payload, qos, retain):
async_fire_mqtt_message(hass, topic, payload, qos, retain)
with patch("paho.mqtt.client.Client") as mock_client:
mock_client().connect.return_value = 0
mock_client().subscribe.return_value = (0, 0)
mock_client().unsubscribe.return_value = (0, 0)
mock_client().publish.return_value = (0, 0)
mock_client().publish.side_effect = _async_fire_mqtt_message
result = await async_setup_component(hass, mqtt.DOMAIN, {mqtt.DOMAIN: config})
assert result

View file

@ -468,7 +468,7 @@ async def test_if_fires_on_mqtt_message_after_update(
assert len(calls) == 2
async def test_not_fires_on_mqtt_message_after_remove(
async def test_not_fires_on_mqtt_message_after_remove_by_mqtt(
hass, device_reg, calls, mqtt_mock
):
"""Test triggers not firing after removal."""
@ -532,6 +532,62 @@ async def test_not_fires_on_mqtt_message_after_remove(
assert len(calls) == 2
async def test_not_fires_on_mqtt_message_after_remove_from_registry(
hass, device_reg, calls, mqtt_mock
):
"""Test triggers not firing after removal."""
config_entry = MockConfigEntry(domain=DOMAIN, data={})
config_entry.add_to_hass(hass)
await async_start(hass, "homeassistant", {}, config_entry)
data1 = (
'{ "automation_type":"trigger",'
' "device":{"identifiers":["0AFFD2"]},'
' "topic": "foobar/triggers/button1",'
' "type": "button_short_press",'
' "subtype": "button_1" }'
)
async_fire_mqtt_message(hass, "homeassistant/device_automation/bla1/config", data1)
await hass.async_block_till_done()
device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set())
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: [
{
"trigger": {
"platform": "device",
"domain": DOMAIN,
"device_id": device_entry.id,
"discovery_id": "bla1",
"type": "button_short_press",
"subtype": "button_1",
},
"action": {
"service": "test.automation",
"data_template": {"some": ("short_press")},
},
},
]
},
)
# Fake short press.
async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press")
await hass.async_block_till_done()
assert len(calls) == 1
# Remove the device
device_reg.async_remove_device(device_entry.id)
await hass.async_block_till_done()
async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press")
await hass.async_block_till_done()
assert len(calls) == 1
async def test_attach_remove(hass, device_reg, mqtt_mock):
"""Test attach and removal of trigger."""
config_entry = MockConfigEntry(domain=DOMAIN, data={})

View file

@ -3,6 +3,8 @@ from pathlib import Path
import re
from unittest.mock import patch
import pytest
from homeassistant.components import mqtt
from homeassistant.components.mqtt.abbreviations import (
ABBREVIATIONS,
@ -11,7 +13,25 @@ from homeassistant.components.mqtt.abbreviations import (
from homeassistant.components.mqtt.discovery import ALREADY_DISCOVERED, async_start
from homeassistant.const import STATE_OFF, STATE_ON
from tests.common import MockConfigEntry, async_fire_mqtt_message, mock_coro
from tests.common import (
MockConfigEntry,
async_fire_mqtt_message,
mock_coro,
mock_device_registry,
mock_registry,
)
@pytest.fixture
def device_reg(hass):
"""Return an empty, loaded, registry."""
return mock_device_registry(hass)
@pytest.fixture
def entity_reg(hass):
"""Return an empty, loaded, registry."""
return mock_registry(hass)
async def test_subscribing_config_topic(hass, mqtt_mock):
@ -213,6 +233,114 @@ async def test_non_duplicate_discovery(hass, mqtt_mock, caplog):
assert "Component has already been discovered: binary_sensor bla" in caplog.text
async def test_removal(hass, mqtt_mock, caplog):
"""Test removal of component through empty discovery message."""
entry = MockConfigEntry(domain=mqtt.DOMAIN)
await async_start(hass, "homeassistant", {}, entry)
async_fire_mqtt_message(
hass, "homeassistant/binary_sensor/bla/config", '{ "name": "Beer" }'
)
await hass.async_block_till_done()
state = hass.states.get("binary_sensor.beer")
assert state is not None
async_fire_mqtt_message(hass, "homeassistant/binary_sensor/bla/config", "")
await hass.async_block_till_done()
state = hass.states.get("binary_sensor.beer")
assert state is None
async def test_rediscover(hass, mqtt_mock, caplog):
"""Test rediscover of removed component."""
entry = MockConfigEntry(domain=mqtt.DOMAIN)
await async_start(hass, "homeassistant", {}, entry)
async_fire_mqtt_message(
hass, "homeassistant/binary_sensor/bla/config", '{ "name": "Beer" }'
)
await hass.async_block_till_done()
state = hass.states.get("binary_sensor.beer")
assert state is not None
async_fire_mqtt_message(hass, "homeassistant/binary_sensor/bla/config", "")
await hass.async_block_till_done()
state = hass.states.get("binary_sensor.beer")
assert state is None
async_fire_mqtt_message(
hass, "homeassistant/binary_sensor/bla/config", '{ "name": "Beer" }'
)
await hass.async_block_till_done()
state = hass.states.get("binary_sensor.beer")
assert state is not None
async def test_duplicate_removal(hass, mqtt_mock, caplog):
"""Test for a non duplicate component."""
entry = MockConfigEntry(domain=mqtt.DOMAIN)
await async_start(hass, "homeassistant", {}, entry)
async_fire_mqtt_message(
hass, "homeassistant/binary_sensor/bla/config", '{ "name": "Beer" }'
)
await hass.async_block_till_done()
async_fire_mqtt_message(hass, "homeassistant/binary_sensor/bla/config", "")
await hass.async_block_till_done()
assert "Component has already been discovered: binary_sensor bla" in caplog.text
caplog.clear()
async_fire_mqtt_message(hass, "homeassistant/binary_sensor/bla/config", "")
await hass.async_block_till_done()
assert "Component has already been discovered: binary_sensor bla" not in caplog.text
async def test_cleanup_device(hass, device_reg, entity_reg, mqtt_mock):
"""Test discvered device is cleaned up when removed from registry."""
config_entry = MockConfigEntry(domain=mqtt.DOMAIN)
config_entry.add_to_hass(hass)
await async_start(hass, "homeassistant", {}, config_entry)
data = (
'{ "device":{"identifiers":["0AFFD2"]},'
' "state_topic": "foobar/sensor",'
' "unique_id": "unique" }'
)
async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data)
await hass.async_block_till_done()
# Verify device and registry entries are created
device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set())
assert device_entry is not None
entity_entry = entity_reg.async_get("sensor.mqtt_sensor")
assert entity_entry is not None
state = hass.states.get("sensor.mqtt_sensor")
assert state is not None
device_reg.async_remove_device(device_entry.id)
await hass.async_block_till_done()
# Verify device and registry entries are cleared
device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set())
assert device_entry is None
entity_entry = entity_reg.async_get("sensor.mqtt_sensor")
assert entity_entry is None
# Verify state is removed
state = hass.states.get("sensor.mqtt_sensor")
assert state is None
# Verify retained discovery topic has been cleared
mqtt_mock.async_publish.assert_called_once_with(
"homeassistant/sensor/bla/config", "", 0, True
)
async def test_discovery_expansion(hass, mqtt_mock, caplog):
"""Test expansion of abbreviated discovery payload."""
entry = MockConfigEntry(domain=mqtt.DOMAIN)

View file

@ -8,6 +8,7 @@ import pytest
import voluptuous as vol
from homeassistant.components import mqtt
from homeassistant.components.mqtt.discovery import async_start
from homeassistant.const import (
ATTR_DOMAIN,
ATTR_SERVICE,
@ -27,11 +28,25 @@ from tests.common import (
fire_mqtt_message,
get_test_home_assistant,
mock_coro,
mock_device_registry,
mock_mqtt_component,
mock_registry,
threadsafe_coroutine_factory,
)
@pytest.fixture
def device_reg(hass):
"""Return an empty, loaded, registry."""
return mock_device_registry(hass)
@pytest.fixture
def entity_reg(hass):
"""Return an empty, loaded, registry."""
return mock_registry(hass)
@pytest.fixture
def mock_MQTT():
"""Make sure connection is established."""
@ -828,3 +843,70 @@ async def test_dump_service(hass):
assert len(writes) == 2
assert writes[0][1][0] == "bla/1,test1\n"
assert writes[1][1][0] == "bla/2,test2\n"
async def test_mqtt_ws_remove_discovered_device(
hass, device_reg, entity_reg, hass_ws_client, mqtt_mock
):
"""Test MQTT websocket device removal."""
config_entry = MockConfigEntry(domain=mqtt.DOMAIN)
config_entry.add_to_hass(hass)
await async_start(hass, "homeassistant", {}, config_entry)
data = (
'{ "device":{"identifiers":["0AFFD2"]},'
' "state_topic": "foobar/sensor",'
' "unique_id": "unique" }'
)
async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data)
await hass.async_block_till_done()
# Verify device entry is created
device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set())
assert device_entry is not None
client = await hass_ws_client(hass)
await client.send_json(
{"id": 5, "type": "mqtt/device/remove", "device_id": device_entry.id}
)
response = await client.receive_json()
assert response["success"]
# Verify device entry is cleared
device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set())
assert device_entry is None
async def test_mqtt_ws_remove_discovered_device_twice(
hass, device_reg, hass_ws_client, mqtt_mock
):
"""Test MQTT websocket device removal."""
config_entry = MockConfigEntry(domain=mqtt.DOMAIN)
config_entry.add_to_hass(hass)
await async_start(hass, "homeassistant", {}, config_entry)
data = (
'{ "device":{"identifiers":["0AFFD2"]},'
' "state_topic": "foobar/sensor",'
' "unique_id": "unique" }'
)
async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data)
await hass.async_block_till_done()
device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set())
assert device_entry is not None
client = await hass_ws_client(hass)
await client.send_json(
{"id": 5, "type": "mqtt/device/remove", "device_id": device_entry.id}
)
response = await client.receive_json()
assert response["success"]
await client.send_json(
{"id": 5, "type": "mqtt/device/remove", "device_id": device_entry.id}
)
response = await client.receive_json()
assert not response["success"]