Allow MQTT device based auto discovery (#118757)

* Allow MQTT device based auto discovery

* Fix merge error

* Remove unused import

* Fix discovery device based topics

* Fix cannot delete twice

* Improve cleanup test

* Follow up comment

* Typo

Co-authored-by: Erik Montnemery <erik@montnemery.com>

* Explain more

* Use tuple

* Default a device payload to have priority over a platform based payload

* Add unique_id to sensor test data

* Set migration flag to mark a discovery topic for migration

* Correct type hint

* Make unique_id required for components in device based discovery payload

* Remove CONF_MIGRATE_DISCOVERY from platform schema

* Unload discovered MQTT item to allow migration

* Follow up comments from code review

* ruff

* Subscribe to platform discovery wildcards first

* Use normal dict

* Use dict to persist wildcard subscription order

* Remove missed unused parameter

* Add a comment to explain we use a dict  to preserve the subscription order

* Add wildcard subscription order test

* Remove discovery flag from test

* Improve discovery migration origin logging

* Assert initial  wildcard discovery topics subscription order and after reconnect

* Improve log messages

---------

Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
Jan Bouwhuis 2024-10-30 17:10:15 +01:00 committed by GitHub
parent cb1b72d6ba
commit 1773f2aadc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 1770 additions and 159 deletions

View file

@ -76,8 +76,8 @@ from .const import ( # noqa: F401
DEFAULT_QOS,
DEFAULT_RETAIN,
DOMAIN,
ENTITY_PLATFORMS,
MQTT_CONNECTION_STATE,
RELOADABLE_PLATFORMS,
TEMPLATE_ERRORS,
)
from .models import ( # noqa: F401
@ -438,7 +438,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
for entity in list(mqtt_platform.entities.values())
if getattr(entity, "_discovery_data", None) is None
and mqtt_platform.config_entry
and mqtt_platform.domain in RELOADABLE_PLATFORMS
and mqtt_platform.domain in ENTITY_PLATFORMS
]
await asyncio.gather(*tasks)

View file

@ -30,6 +30,7 @@ ABBREVIATIONS = {
"cmd_on_tpl": "command_on_template",
"cmd_t": "command_topic",
"cmd_tpl": "command_template",
"cmps": "components",
"cod_arm_req": "code_arm_required",
"cod_dis_req": "code_disarm_required",
"cod_form": "code_format",
@ -92,6 +93,7 @@ ABBREVIATIONS = {
"min_mirs": "min_mireds",
"max_temp": "max_temp",
"min_temp": "min_temp",
"migr_discvry": "migrate_discovery",
"mode": "mode",
"mode_cmd_tpl": "mode_command_template",
"mode_cmd_t": "mode_command_topic",
@ -109,6 +111,7 @@ ABBREVIATIONS = {
"osc_cmd_tpl": "oscillation_command_template",
"osc_stat_t": "oscillation_state_topic",
"osc_val_tpl": "oscillation_value_template",
"p": "platform",
"pause_cmd_t": "pause_command_topic",
"pause_mw_cmd_tpl": "pause_command_template",
"pct_cmd_t": "percentage_command_topic",

View file

@ -376,7 +376,9 @@ class MQTT:
self._simple_subscriptions: defaultdict[str, set[Subscription]] = defaultdict(
set
)
self._wildcard_subscriptions: set[Subscription] = set()
# To ensure the wildcard subscriptions order is preserved, we use a dict
# with `None` values instead of a set.
self._wildcard_subscriptions: dict[Subscription, None] = {}
# _retained_topics prevents a Subscription from receiving a
# retained message more than once per topic. This prevents flooding
# already active subscribers when new subscribers subscribe to a topic
@ -754,7 +756,7 @@ class MQTT:
if subscription.is_simple_match:
self._simple_subscriptions[subscription.topic].add(subscription)
else:
self._wildcard_subscriptions.add(subscription)
self._wildcard_subscriptions[subscription] = None
@callback
def _async_untrack_subscription(self, subscription: Subscription) -> None:
@ -772,7 +774,7 @@ class MQTT:
if not simple_subscriptions[topic]:
del simple_subscriptions[topic]
else:
self._wildcard_subscriptions.remove(subscription)
del self._wildcard_subscriptions[subscription]
except (KeyError, ValueError) as exc:
raise HomeAssistantError("Can't remove subscription twice") from exc

View file

@ -90,6 +90,7 @@ CONF_TEMP_MIN = "min_temp"
CONF_CERTIFICATE = "certificate"
CONF_CLIENT_KEY = "client_key"
CONF_CLIENT_CERT = "client_cert"
CONF_COMPONENTS = "components"
CONF_TLS_INSECURE = "tls_insecure"
# Device and integration info options
@ -159,7 +160,7 @@ MQTT_CONNECTION_STATE = "mqtt_connection_state"
PAYLOAD_EMPTY_JSON = "{}"
PAYLOAD_NONE = "None"
RELOADABLE_PLATFORMS = [
ENTITY_PLATFORMS = [
Platform.ALARM_CONTROL_PANEL,
Platform.BINARY_SENSOR,
Platform.BUTTON,
@ -190,7 +191,7 @@ RELOADABLE_PLATFORMS = [
TEMPLATE_ERRORS = (jinja2.TemplateError, TemplateError, TypeError, ValueError)
SUPPORTED_COMPONENTS = {
SUPPORTED_COMPONENTS = (
"alarm_control_panel",
"binary_sensor",
"button",
@ -219,4 +220,4 @@ SUPPORTED_COMPONENTS = {
"vacuum",
"valve",
"water_heater",
}
)

View file

@ -12,6 +12,8 @@ import re
import time
from typing import TYPE_CHECKING, Any
import voluptuous as vol
from homeassistant.config_entries import (
SOURCE_MQTT,
ConfigEntry,
@ -25,7 +27,7 @@ from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
async_dispatcher_send,
)
from homeassistant.helpers.service_info.mqtt import MqttServiceInfo
from homeassistant.helpers.service_info.mqtt import MqttServiceInfo, ReceivePayloadType
from homeassistant.helpers.typing import DiscoveryInfoType
from homeassistant.loader import async_get_mqtt
from homeassistant.util.json import json_loads_object
@ -38,13 +40,14 @@ from .const import (
ATTR_DISCOVERY_PAYLOAD,
ATTR_DISCOVERY_TOPIC,
CONF_AVAILABILITY,
CONF_COMPONENTS,
CONF_ORIGIN,
CONF_TOPIC,
DOMAIN,
SUPPORTED_COMPONENTS,
)
from .models import DATA_MQTT, MqttOriginInfo, ReceiveMessage
from .schemas import MQTT_ORIGIN_INFO_SCHEMA
from .models import DATA_MQTT, MqttComponentConfig, MqttOriginInfo, ReceiveMessage
from .schemas import DEVICE_DISCOVERY_SCHEMA, MQTT_ORIGIN_INFO_SCHEMA, SHARED_OPTIONS
from .util import async_forward_entry_setup_and_setup_discovery
ABBREVIATIONS_SET = set(ABBREVIATIONS)
@ -70,10 +73,18 @@ MQTT_DISCOVERY_DONE: SignalTypeFormat[Any] = SignalTypeFormat(
TOPIC_BASE = "~"
CONF_MIGRATE_DISCOVERY = "migrate_discovery"
MIGRATE_DISCOVERY_SCHEMA = vol.Schema(
{vol.Optional(CONF_MIGRATE_DISCOVERY): True},
)
class MQTTDiscoveryPayload(dict[str, Any]):
"""Class to hold and MQTT discovery payload and discovery data."""
device_discovery: bool = False
migrate_discovery: bool = False
discovery_data: DiscoveryInfoType
@ -85,6 +96,24 @@ class MQTTIntegrationDiscoveryConfig:
msg: ReceiveMessage
@callback
def _async_process_discovery_migration(payload: MQTTDiscoveryPayload) -> bool:
"""Process a discovery migration request in the discovery payload."""
# Allow abbreviation
if migr_discvry := (payload.pop("migr_discvry", None)):
payload[CONF_MIGRATE_DISCOVERY] = migr_discvry
if CONF_MIGRATE_DISCOVERY in payload:
try:
MIGRATE_DISCOVERY_SCHEMA(payload)
except vol.Invalid as exc:
_LOGGER.warning(exc)
return False
payload.migrate_discovery = True
payload.clear()
return True
return False
def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) -> None:
"""Clear entry from already discovered list."""
hass.data[DATA_MQTT].discovery_already_discovered.discard(discovery_hash)
@ -96,36 +125,51 @@ def set_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) ->
@callback
def async_log_discovery_origin_info(
message: str, discovery_payload: MQTTDiscoveryPayload, level: int = logging.INFO
) -> None:
"""Log information about the discovery and origin."""
if not _LOGGER.isEnabledFor(level):
# bail early if logging is disabled
return
def get_origin_log_string(
discovery_payload: MQTTDiscoveryPayload, *, include_url: bool
) -> str:
"""Get the origin information from a discovery payload for logging."""
if CONF_ORIGIN not in discovery_payload:
_LOGGER.log(level, message)
return
return ""
origin_info: MqttOriginInfo = discovery_payload[CONF_ORIGIN]
sw_version_log = ""
if sw_version := origin_info.get("sw_version"):
sw_version_log = f", version: {sw_version}"
support_url_log = ""
if support_url := origin_info.get("support_url"):
if include_url and (support_url := get_origin_support_url(discovery_payload)):
support_url_log = f", support URL: {support_url}"
return f" from external application {origin_info["name"]}{sw_version_log}{support_url_log}"
@callback
def get_origin_support_url(discovery_payload: MQTTDiscoveryPayload) -> str | None:
"""Get the origin information support URL from a discovery payload."""
if CONF_ORIGIN not in discovery_payload:
return ""
origin_info: MqttOriginInfo = discovery_payload[CONF_ORIGIN]
return origin_info.get("support_url")
@callback
def async_log_discovery_origin_info(
message: str, discovery_payload: MQTTDiscoveryPayload, level: int = logging.INFO
) -> None:
"""Log information about the discovery and origin."""
# We only log origin info once per device discovery
if not _LOGGER.isEnabledFor(level):
# bail out early if logging is disabled
return
_LOGGER.log(
level,
"%s from external application %s%s%s",
"%s%s",
message,
origin_info["name"],
sw_version_log,
support_url_log,
get_origin_log_string(discovery_payload, include_url=True),
)
@callback
def _replace_abbreviations(
payload: Any | dict[str, Any],
payload: dict[str, Any] | str,
abbreviations: dict[str, str],
abbreviations_set: set[str],
) -> None:
@ -137,11 +181,20 @@ def _replace_abbreviations(
@callback
def _replace_all_abbreviations(discovery_payload: Any | dict[str, Any]) -> None:
def _replace_all_abbreviations(
discovery_payload: dict[str, Any], component_only: bool = False
) -> None:
"""Replace all abbreviations in an MQTT discovery payload."""
_replace_abbreviations(discovery_payload, ABBREVIATIONS, ABBREVIATIONS_SET)
if CONF_AVAILABILITY in discovery_payload:
for availability_conf in cv.ensure_list(discovery_payload[CONF_AVAILABILITY]):
_replace_abbreviations(availability_conf, ABBREVIATIONS, ABBREVIATIONS_SET)
if component_only:
return
if CONF_ORIGIN in discovery_payload:
_replace_abbreviations(
discovery_payload[CONF_ORIGIN],
@ -156,13 +209,15 @@ def _replace_all_abbreviations(discovery_payload: Any | dict[str, Any]) -> None:
DEVICE_ABBREVIATIONS_SET,
)
if CONF_AVAILABILITY in discovery_payload:
for availability_conf in cv.ensure_list(discovery_payload[CONF_AVAILABILITY]):
_replace_abbreviations(availability_conf, ABBREVIATIONS, ABBREVIATIONS_SET)
if CONF_COMPONENTS in discovery_payload:
if not isinstance(discovery_payload[CONF_COMPONENTS], dict):
return
for comp_conf in discovery_payload[CONF_COMPONENTS].values():
_replace_all_abbreviations(comp_conf, component_only=True)
@callback
def _replace_topic_base(discovery_payload: dict[str, Any]) -> None:
def _replace_topic_base(discovery_payload: MQTTDiscoveryPayload) -> None:
"""Replace topic base in MQTT discovery data."""
base = discovery_payload.pop(TOPIC_BASE)
for key, value in discovery_payload.items():
@ -182,6 +237,79 @@ def _replace_topic_base(discovery_payload: dict[str, Any]) -> None:
availability_conf[CONF_TOPIC] = f"{topic[:-1]}{base}"
@callback
def _generate_device_config(
hass: HomeAssistant,
object_id: str,
node_id: str | None,
migrate_discovery: bool = False,
) -> MQTTDiscoveryPayload:
"""Generate a cleanup or discovery migration message on device cleanup.
If an empty payload, or a migrate discovery request is received for a device,
we forward an empty payload for all previously discovered components.
"""
mqtt_data = hass.data[DATA_MQTT]
device_node_id: str = f"{node_id} {object_id}" if node_id else object_id
config = MQTTDiscoveryPayload({CONF_DEVICE: {}, CONF_COMPONENTS: {}})
config.migrate_discovery = migrate_discovery
comp_config = config[CONF_COMPONENTS]
for platform, discover_id in mqtt_data.discovery_already_discovered:
ids = discover_id.split(" ")
component_node_id = ids.pop(0)
component_object_id = " ".join(ids)
if not ids:
continue
if device_node_id == component_node_id:
comp_config[component_object_id] = {CONF_PLATFORM: platform}
return config if comp_config else MQTTDiscoveryPayload({})
@callback
def _parse_device_payload(
hass: HomeAssistant,
payload: ReceivePayloadType,
object_id: str,
node_id: str | None,
) -> MQTTDiscoveryPayload:
"""Parse a device discovery payload.
The device discovery payload is translated info the config payloads for every single
component inside the device based configuration.
An empty payload is translated in a cleanup, which forwards an empty payload to all
removed components.
"""
device_payload = MQTTDiscoveryPayload()
if payload == "":
if not (device_payload := _generate_device_config(hass, object_id, node_id)):
_LOGGER.warning(
"No device components to cleanup for %s, node_id '%s'",
object_id,
node_id,
)
return device_payload
try:
device_payload = MQTTDiscoveryPayload(json_loads_object(payload))
except ValueError:
_LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload)
return device_payload
if _async_process_discovery_migration(device_payload):
return _generate_device_config(hass, object_id, node_id, migrate_discovery=True)
_replace_all_abbreviations(device_payload)
try:
DEVICE_DISCOVERY_SCHEMA(device_payload)
except vol.Invalid as exc:
_LOGGER.warning(
"Invalid MQTT device discovery payload for %s, %s: '%s'",
object_id,
exc,
payload,
)
return MQTTDiscoveryPayload({})
return device_payload
@callback
def _valid_origin_info(discovery_payload: MQTTDiscoveryPayload) -> bool:
"""Parse and validate origin info from a single component discovery payload."""
@ -199,6 +327,30 @@ def _valid_origin_info(discovery_payload: MQTTDiscoveryPayload) -> bool:
return True
@callback
def _merge_common_device_options(
component_config: MQTTDiscoveryPayload, device_config: dict[str, Any]
) -> None:
"""Merge common device options with the component config options.
Common options are:
CONF_AVAILABILITY,
CONF_AVAILABILITY_MODE,
CONF_AVAILABILITY_TEMPLATE,
CONF_AVAILABILITY_TOPIC,
CONF_COMMAND_TOPIC,
CONF_PAYLOAD_AVAILABLE,
CONF_PAYLOAD_NOT_AVAILABLE,
CONF_STATE_TOPIC,
Common options in the body of the device based config are inherited into
the component. Unless the option is explicitly specified at component level,
in that case the option at component level will override the common option.
"""
for option in SHARED_OPTIONS:
if option in device_config and option not in component_config:
component_config[option] = device_config.get(option)
async def async_start( # noqa: C901
hass: HomeAssistant, discovery_topic: str, config_entry: ConfigEntry
) -> None:
@ -243,8 +395,7 @@ async def async_start( # noqa: C901
_LOGGER.warning(
(
"Received message on illegal discovery topic '%s'. The topic"
" contains "
"not allowed characters. For more information see "
" contains non allowed characters. For more information see "
"https://www.home-assistant.io/integrations/mqtt/#discovery-topic"
),
topic,
@ -253,51 +404,118 @@ async def async_start( # noqa: C901
component, node_id, object_id = match.groups()
if payload:
discovered_components: list[MqttComponentConfig] = []
if component == CONF_DEVICE:
# Process device based discovery message and regenerate
# cleanup config for the all the components that are being removed.
# This is done when a component in the device config is omitted and detected
# as being removed, or when the device config update payload is empty.
# In that case this will regenerate a cleanup message for all every already
# discovered components that were linked to the initial device discovery.
device_discovery_payload = _parse_device_payload(
hass, payload, object_id, node_id
)
if not device_discovery_payload:
return
device_config: dict[str, Any]
origin_config: dict[str, Any] | None
component_configs: dict[str, dict[str, Any]]
device_config = device_discovery_payload[CONF_DEVICE]
origin_config = device_discovery_payload.get(CONF_ORIGIN)
component_configs = device_discovery_payload[CONF_COMPONENTS]
for component_id, config in component_configs.items():
component = config.pop(CONF_PLATFORM)
# The object_id in the device discovery topic is the unique identifier.
# It is used as node_id for the components it contains.
component_node_id = object_id
# The component_id in the discovery playload is used as object_id
# If we have an additional node_id in the discovery topic,
# we extend the component_id with it.
component_object_id = (
f"{node_id} {component_id}" if node_id else component_id
)
# We add wrapper to the discovery payload with the discovery data.
# If the dict is empty after removing the platform, the payload is
# assumed to remove the existing config and we do not want to add
# device or orig or shared availability attributes.
if discovery_payload := MQTTDiscoveryPayload(config):
discovery_payload[CONF_DEVICE] = device_config
discovery_payload[CONF_ORIGIN] = origin_config
# Only assign shared config options
# when they are not set at entity level
_merge_common_device_options(
discovery_payload, device_discovery_payload
)
discovery_payload.device_discovery = True
discovery_payload.migrate_discovery = (
device_discovery_payload.migrate_discovery
)
discovered_components.append(
MqttComponentConfig(
component,
component_object_id,
component_node_id,
discovery_payload,
)
)
_LOGGER.debug(
"Process device discovery payload %s", device_discovery_payload
)
device_discovery_id = f"{node_id} {object_id}" if node_id else object_id
message = f"Processing device discovery for '{device_discovery_id}'"
async_log_discovery_origin_info(
message, MQTTDiscoveryPayload(device_discovery_payload)
)
else:
# Process component based discovery message
try:
discovery_payload = MQTTDiscoveryPayload(json_loads_object(payload))
discovery_payload = MQTTDiscoveryPayload(
json_loads_object(payload) if payload else {}
)
except ValueError:
_LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload)
return
_replace_all_abbreviations(discovery_payload)
if not _valid_origin_info(discovery_payload):
return
if not _async_process_discovery_migration(discovery_payload):
_replace_all_abbreviations(discovery_payload)
if not _valid_origin_info(discovery_payload):
return
discovered_components.append(
MqttComponentConfig(component, object_id, node_id, discovery_payload)
)
discovery_pending_discovered = mqtt_data.discovery_pending_discovered
for component_config in discovered_components:
component = component_config.component
node_id = component_config.node_id
object_id = component_config.object_id
discovery_payload = component_config.discovery_payload
if TOPIC_BASE in discovery_payload:
_replace_topic_base(discovery_payload)
else:
discovery_payload = MQTTDiscoveryPayload({})
# If present, the node_id will be included in the discovered object id
discovery_id = f"{node_id} {object_id}" if node_id else object_id
discovery_hash = (component, discovery_id)
# If present, the node_id will be included in the discovery_id.
discovery_id = f"{node_id} {object_id}" if node_id else object_id
discovery_hash = (component, discovery_id)
if discovery_payload:
# Attach MQTT topic to the payload, used for debug prints
setattr(
discovery_payload,
"__configuration_source__",
f"MQTT (topic: '{topic}')",
)
discovery_data = {
discovery_payload.discovery_data = {
ATTR_DISCOVERY_HASH: discovery_hash,
ATTR_DISCOVERY_PAYLOAD: discovery_payload,
ATTR_DISCOVERY_TOPIC: topic,
}
setattr(discovery_payload, "discovery_data", discovery_data)
discovery_payload[CONF_PLATFORM] = "mqtt"
if discovery_hash in discovery_pending_discovered:
pending = discovery_pending_discovered[discovery_hash]["pending"]
pending.appendleft(discovery_payload)
_LOGGER.debug(
"Component has already been discovered: %s %s, queuing update",
component,
discovery_id,
)
return
if discovery_hash in mqtt_data.discovery_pending_discovered:
pending = mqtt_data.discovery_pending_discovered[discovery_hash]["pending"]
pending.appendleft(discovery_payload)
_LOGGER.debug(
"Component has already been discovered: %s %s, queuing update",
component,
discovery_id,
)
return
async_process_discovery_payload(component, discovery_id, discovery_payload)
async_process_discovery_payload(component, discovery_id, discovery_payload)
@callback
def async_process_discovery_payload(
@ -305,7 +523,7 @@ async def async_start( # noqa: C901
) -> None:
"""Process the payload of a new discovery."""
_LOGGER.debug("Process discovery payload %s", payload)
_LOGGER.debug("Process component discovery payload %s", payload)
discovery_hash = (component, discovery_id)
already_discovered = discovery_hash in mqtt_data.discovery_already_discovered
@ -362,6 +580,8 @@ async def async_start( # noqa: C901
0,
job_type=HassJobType.Callback,
)
# Subscribe first for platform discovery wildcard topics first,
# and then subscribe device discovery wildcard topics.
for topic in chain(
(
f"{discovery_topic}/{component}/+/config"
@ -371,6 +591,10 @@ async def async_start( # noqa: C901
f"{discovery_topic}/{component}/+/+/config"
for component in SUPPORTED_COMPONENTS
),
(
f"{discovery_topic}/device/+/config",
f"{discovery_topic}/device/+/+/config",
),
)
]

View file

@ -104,6 +104,8 @@ from .discovery import (
MQTT_DISCOVERY_UPDATED,
MQTTDiscoveryPayload,
clear_discovery_hash,
get_origin_log_string,
get_origin_support_url,
set_discovery_hash,
)
from .models import (
@ -591,6 +593,7 @@ async def cleanup_device_registry(
entity_registry = er.async_get(hass)
if (
device_id
and device_id not in device_registry.deleted_devices
and config_entry_id
and not er.async_entries_for_device(
entity_registry, device_id, include_disabled_entities=False
@ -672,6 +675,7 @@ class MqttDiscoveryDeviceUpdateMixin(ABC):
self._config_entry = config_entry
self._config_entry_id = config_entry.entry_id
self._skip_device_removal: bool = False
self._migrate_discovery: str | None = None
discovery_hash = get_discovery_hash(discovery_data)
self._remove_discovery_updated = async_dispatcher_connect(
@ -704,12 +708,95 @@ class MqttDiscoveryDeviceUpdateMixin(ABC):
) -> None:
"""Handle discovery update."""
discovery_hash = get_discovery_hash(self._discovery_data)
# Start discovery migration or rollback if migrate_discovery flag is set
# and the discovery topic is valid and not yet migrating
if (
discovery_payload.migrate_discovery
and self._migrate_discovery is None
and self._discovery_data[ATTR_DISCOVERY_TOPIC]
== discovery_payload.discovery_data[ATTR_DISCOVERY_TOPIC]
):
self._migrate_discovery = self._discovery_data[ATTR_DISCOVERY_TOPIC]
discovery_hash = self._discovery_data[ATTR_DISCOVERY_HASH]
origin_info = get_origin_log_string(
self._discovery_data[ATTR_DISCOVERY_PAYLOAD], include_url=False
)
action = "Rollback" if discovery_payload.device_discovery else "Migration"
schema_type = "platform" if discovery_payload.device_discovery else "device"
_LOGGER.info(
"%s to MQTT %s discovery schema started for %s '%s'"
"%s on topic %s. To complete %s, publish a %s discovery "
"message with %s '%s'. After completed %s, "
"publish an empty (retained) payload to %s",
action,
schema_type,
discovery_hash[0],
discovery_hash[1],
origin_info,
self._migrate_discovery,
action.lower(),
schema_type,
discovery_hash[0],
discovery_hash[1],
action.lower(),
self._migrate_discovery,
)
# Cleanup platform resources
await self.async_tear_down()
# Unregister and clean discovery
stop_discovery_updates(
self.hass, self._discovery_data, self._remove_discovery_updated
)
send_discovery_done(self.hass, self._discovery_data)
return
_LOGGER.debug(
"Got update for %s with hash: %s '%s'",
self.log_name,
discovery_hash,
discovery_payload,
)
new_discovery_topic = discovery_payload.discovery_data[ATTR_DISCOVERY_TOPIC]
# Abort early if an update is not received via the registered discovery topic.
# This can happen if a device and single component discovery payload
# share the same discovery ID.
if self._discovery_data[ATTR_DISCOVERY_TOPIC] != new_discovery_topic:
# Prevent illegal updates
old_origin_info = get_origin_log_string(
self._discovery_data[ATTR_DISCOVERY_PAYLOAD], include_url=False
)
new_origin_info = get_origin_log_string(
discovery_payload.discovery_data[ATTR_DISCOVERY_PAYLOAD],
include_url=False,
)
new_origin_support_url = get_origin_support_url(
discovery_payload.discovery_data[ATTR_DISCOVERY_PAYLOAD]
)
if new_origin_support_url:
get_support = f"for support visit {new_origin_support_url}"
else:
get_support = (
"for documentation on migration to device schema or rollback to "
"discovery schema, visit https://www.home-assistant.io/integrations/"
"mqtt/#migration-from-single-component-to-device-based-discovery"
)
_LOGGER.warning(
"Received a conflicting MQTT discovery message for %s '%s' which was "
"previously discovered on topic %s%s; the conflicting discovery "
"message was received on topic %s%s; %s",
discovery_hash[0],
discovery_hash[1],
self._discovery_data[ATTR_DISCOVERY_TOPIC],
old_origin_info,
new_discovery_topic,
new_origin_info,
get_support,
)
send_discovery_done(self.hass, self._discovery_data)
return
if (
discovery_payload
and discovery_payload != self._discovery_data[ATTR_DISCOVERY_PAYLOAD]
@ -806,6 +893,7 @@ class MqttDiscoveryUpdateMixin(Entity):
mqtt_data = hass.data[DATA_MQTT]
self._registry_hooks = mqtt_data.discovery_registry_hooks
discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH]
self._migrate_discovery: str | None = None
if discovery_hash in self._registry_hooks:
self._registry_hooks.pop(discovery_hash)()
@ -863,7 +951,12 @@ class MqttDiscoveryUpdateMixin(Entity):
if TYPE_CHECKING:
assert self._discovery_data
self._cleanup_discovery_on_remove()
await self._async_remove_state_and_registry_entry()
if self._migrate_discovery is None:
# Unload and cleanup registry
await self._async_remove_state_and_registry_entry()
else:
# Only unload the entity
await self.async_remove(force_remove=True)
send_discovery_done(self.hass, self._discovery_data)
@callback
@ -878,18 +971,102 @@ class MqttDiscoveryUpdateMixin(Entity):
"""
if TYPE_CHECKING:
assert self._discovery_data
discovery_hash: tuple[str, str] = self._discovery_data[ATTR_DISCOVERY_HASH]
discovery_hash = get_discovery_hash(self._discovery_data)
# Start discovery migration or rollback if migrate_discovery flag is set
# and the discovery topic is valid and not yet migrating
if (
payload.migrate_discovery
and self._migrate_discovery is None
and self._discovery_data[ATTR_DISCOVERY_TOPIC]
== payload.discovery_data[ATTR_DISCOVERY_TOPIC]
):
if self.unique_id is None or self.device_info is None:
_LOGGER.error(
"Discovery migration is not possible for "
"for entity %s on topic %s. A unique_id "
"and device context is required, got unique_id: %s, device: %s",
self.entity_id,
self._discovery_data[ATTR_DISCOVERY_TOPIC],
self.unique_id,
self.device_info,
)
send_discovery_done(self.hass, self._discovery_data)
return
self._migrate_discovery = self._discovery_data[ATTR_DISCOVERY_TOPIC]
discovery_hash = self._discovery_data[ATTR_DISCOVERY_HASH]
origin_info = get_origin_log_string(
self._discovery_data[ATTR_DISCOVERY_PAYLOAD], include_url=False
)
action = "Rollback" if payload.device_discovery else "Migration"
schema_type = "platform" if payload.device_discovery else "device"
_LOGGER.info(
"%s to MQTT %s discovery schema started for entity %s"
"%s on topic %s. To complete %s, publish a %s discovery "
"message with %s entity '%s'. After completed %s, "
"publish an empty (retained) payload to %s",
action,
schema_type,
self.entity_id,
origin_info,
self._migrate_discovery,
action.lower(),
schema_type,
discovery_hash[0],
discovery_hash[1],
action.lower(),
self._migrate_discovery,
)
old_payload = self._discovery_data[ATTR_DISCOVERY_PAYLOAD]
_LOGGER.debug(
"Got update for entity with hash: %s '%s'",
discovery_hash,
payload,
)
old_payload: DiscoveryInfoType
old_payload = self._discovery_data[ATTR_DISCOVERY_PAYLOAD]
new_discovery_topic = payload.discovery_data[ATTR_DISCOVERY_TOPIC]
# Abort early if an update is not received via the registered discovery topic.
# This can happen if a device and single component discovery payload
# share the same discovery ID.
if self._discovery_data[ATTR_DISCOVERY_TOPIC] != new_discovery_topic:
# Prevent illegal updates
old_origin_info = get_origin_log_string(
self._discovery_data[ATTR_DISCOVERY_PAYLOAD], include_url=False
)
new_origin_info = get_origin_log_string(
payload.discovery_data[ATTR_DISCOVERY_PAYLOAD], include_url=False
)
new_origin_support_url = get_origin_support_url(
payload.discovery_data[ATTR_DISCOVERY_PAYLOAD]
)
if new_origin_support_url:
get_support = f"for support visit {new_origin_support_url}"
else:
get_support = (
"for documentation on migration to device schema or rollback to "
"discovery schema, visit https://www.home-assistant.io/integrations/"
"mqtt/#migration-from-single-component-to-device-based-discovery"
)
_LOGGER.warning(
"Received a conflicting MQTT discovery message for entity %s; the "
"entity was previously discovered on topic %s%s; the conflicting "
"discovery message was received on topic %s%s; %s",
self.entity_id,
self._discovery_data[ATTR_DISCOVERY_TOPIC],
old_origin_info,
new_discovery_topic,
new_origin_info,
get_support,
)
send_discovery_done(self.hass, self._discovery_data)
return
debug_info.update_entity_discovery_data(self.hass, payload, self.entity_id)
if not payload:
# Empty payload: Remove component
_LOGGER.info("Removing component: %s", self.entity_id)
if self._migrate_discovery is None:
_LOGGER.info("Removing component: %s", self.entity_id)
else:
_LOGGER.info("Unloading component: %s", self.entity_id)
self.hass.async_create_task(
self._async_process_discovery_update_and_remove()
)

View file

@ -410,5 +410,15 @@ class MqttData:
tags: dict[str, dict[str, MQTTTagScanner]] = field(default_factory=dict)
@dataclass(slots=True)
class MqttComponentConfig:
"""(component, object_id, node_id, discovery_payload)."""
component: str
object_id: str
node_id: str | None
discovery_payload: MQTTDiscoveryPayload
DATA_MQTT: HassKey[MqttData] = HassKey("mqtt")
DATA_MQTT_AVAILABLE: HassKey[asyncio.Future[bool]] = HassKey("mqtt_client_available")

View file

@ -2,6 +2,8 @@
from __future__ import annotations
from typing import Any
import voluptuous as vol
from homeassistant.const import (
@ -11,6 +13,7 @@ from homeassistant.const import (
CONF_MODEL,
CONF_MODEL_ID,
CONF_NAME,
CONF_PLATFORM,
CONF_UNIQUE_ID,
CONF_VALUE_TEMPLATE,
)
@ -25,10 +28,13 @@ from .const import (
CONF_AVAILABILITY_MODE,
CONF_AVAILABILITY_TEMPLATE,
CONF_AVAILABILITY_TOPIC,
CONF_COMMAND_TOPIC,
CONF_COMPONENTS,
CONF_CONFIGURATION_URL,
CONF_CONNECTIONS,
CONF_DEPRECATED_VIA_HUB,
CONF_ENABLED_BY_DEFAULT,
CONF_ENCODING,
CONF_ENTITY_PICTURE,
CONF_HW_VERSION,
CONF_IDENTIFIERS,
@ -39,7 +45,9 @@ from .const import (
CONF_ORIGIN,
CONF_PAYLOAD_AVAILABLE,
CONF_PAYLOAD_NOT_AVAILABLE,
CONF_QOS,
CONF_SERIAL_NUMBER,
CONF_STATE_TOPIC,
CONF_SUGGESTED_AREA,
CONF_SUPPORT_URL,
CONF_SW_VERSION,
@ -47,10 +55,34 @@ from .const import (
CONF_VIA_DEVICE,
DEFAULT_PAYLOAD_AVAILABLE,
DEFAULT_PAYLOAD_NOT_AVAILABLE,
ENTITY_PLATFORMS,
SUPPORTED_COMPONENTS,
)
from .util import valid_subscribe_topic
from .util import valid_publish_topic, valid_qos_schema, valid_subscribe_topic
MQTT_AVAILABILITY_SINGLE_SCHEMA = vol.Schema(
# Device discovery options that are also available at entity component level
SHARED_OPTIONS = [
CONF_AVAILABILITY,
CONF_AVAILABILITY_MODE,
CONF_AVAILABILITY_TEMPLATE,
CONF_AVAILABILITY_TOPIC,
CONF_COMMAND_TOPIC,
CONF_PAYLOAD_AVAILABLE,
CONF_PAYLOAD_NOT_AVAILABLE,
CONF_STATE_TOPIC,
]
MQTT_ORIGIN_INFO_SCHEMA = vol.All(
vol.Schema(
{
vol.Required(CONF_NAME): cv.string,
vol.Optional(CONF_SW_VERSION): cv.string,
vol.Optional(CONF_SUPPORT_URL): cv.configuration_url,
}
),
)
_MQTT_AVAILABILITY_SINGLE_SCHEMA = vol.Schema(
{
vol.Exclusive(CONF_AVAILABILITY_TOPIC, "availability"): valid_subscribe_topic,
vol.Optional(CONF_AVAILABILITY_TEMPLATE): cv.template,
@ -63,7 +95,7 @@ MQTT_AVAILABILITY_SINGLE_SCHEMA = vol.Schema(
}
)
MQTT_AVAILABILITY_LIST_SCHEMA = vol.Schema(
_MQTT_AVAILABILITY_LIST_SCHEMA = vol.Schema(
{
vol.Optional(CONF_AVAILABILITY_MODE, default=AVAILABILITY_LATEST): vol.All(
cv.string, vol.In(AVAILABILITY_MODES)
@ -87,8 +119,8 @@ MQTT_AVAILABILITY_LIST_SCHEMA = vol.Schema(
}
)
MQTT_AVAILABILITY_SCHEMA = MQTT_AVAILABILITY_SINGLE_SCHEMA.extend(
MQTT_AVAILABILITY_LIST_SCHEMA.schema
_MQTT_AVAILABILITY_SCHEMA = _MQTT_AVAILABILITY_SINGLE_SCHEMA.extend(
_MQTT_AVAILABILITY_LIST_SCHEMA.schema
)
@ -138,7 +170,7 @@ MQTT_ORIGIN_INFO_SCHEMA = vol.All(
),
)
MQTT_ENTITY_COMMON_SCHEMA = MQTT_AVAILABILITY_SCHEMA.extend(
MQTT_ENTITY_COMMON_SCHEMA = _MQTT_AVAILABILITY_SCHEMA.extend(
{
vol.Optional(CONF_DEVICE): MQTT_ENTITY_DEVICE_INFO_SCHEMA,
vol.Optional(CONF_ENTITY_PICTURE): cv.url,
@ -152,3 +184,35 @@ MQTT_ENTITY_COMMON_SCHEMA = MQTT_AVAILABILITY_SCHEMA.extend(
vol.Optional(CONF_UNIQUE_ID): cv.string,
}
)
_UNIQUE_ID_SCHEMA = vol.Schema(
{vol.Required(CONF_UNIQUE_ID): cv.string},
).extend({}, extra=True)
def check_unique_id(config: dict[str, Any]) -> dict[str, Any]:
"""Check if a unique ID is set in case an entity platform is configured."""
platform = config[CONF_PLATFORM]
if platform in ENTITY_PLATFORMS and len(config.keys()) > 1:
_UNIQUE_ID_SCHEMA(config)
return config
_COMPONENT_CONFIG_SCHEMA = vol.All(
vol.Schema(
{vol.Required(CONF_PLATFORM): vol.In(SUPPORTED_COMPONENTS)},
).extend({}, extra=True),
check_unique_id,
)
DEVICE_DISCOVERY_SCHEMA = _MQTT_AVAILABILITY_SCHEMA.extend(
{
vol.Required(CONF_DEVICE): MQTT_ENTITY_DEVICE_INFO_SCHEMA,
vol.Required(CONF_COMPONENTS): vol.Schema({str: _COMPONENT_CONFIG_SCHEMA}),
vol.Required(CONF_ORIGIN): MQTT_ORIGIN_INFO_SCHEMA,
vol.Optional(CONF_STATE_TOPIC): valid_subscribe_topic,
vol.Optional(CONF_COMMAND_TOPIC): valid_publish_topic,
vol.Optional(CONF_QOS): valid_qos_schema,
vol.Optional(CONF_ENCODING): cv.string,
}
)

View file

@ -4,7 +4,7 @@ import asyncio
from collections.abc import AsyncGenerator, Generator
from random import getrandbits
from typing import Any
from unittest.mock import patch
from unittest.mock import AsyncMock, patch
import pytest
@ -122,3 +122,10 @@ def record_calls(recorded_calls: list[ReceiveMessage]) -> MessageCallbackType:
recorded_calls.append(msg)
return record_calls
@pytest.fixture
def tag_mock() -> Generator[AsyncMock]:
"""Fixture to mock tag."""
with patch("homeassistant.components.tag.async_scan_tag") as mock_tag:
yield mock_tag

View file

@ -1716,6 +1716,64 @@ async def test_mqtt_subscribes_topics_on_connect(
assert ("still/pending", 1) in subscribe_calls
@pytest.mark.parametrize("mqtt_config_entry_data", [ENTRY_DEFAULT_BIRTH_MESSAGE])
async def test_mqtt_subscribes_wildcard_topics_in_correct_order(
hass: HomeAssistant,
mock_debouncer: asyncio.Event,
setup_with_birth_msg_client_mock: MqttMockPahoClient,
record_calls: MessageCallbackType,
) -> None:
"""Test subscription to wildcard topics on connect in the order of subscription."""
mqtt_client_mock = setup_with_birth_msg_client_mock
mock_debouncer.clear()
await mqtt.async_subscribe(hass, "integration/test#", record_calls)
await mqtt.async_subscribe(hass, "integration/kitchen_sink#", record_calls)
await mock_debouncer.wait()
def _assert_subscription_order():
discovery_subscribes = [
f"homeassistant/{platform}/+/config" for platform in SUPPORTED_COMPONENTS
]
discovery_subscribes.extend(
[
f"homeassistant/{platform}/+/+/config"
for platform in SUPPORTED_COMPONENTS
]
)
discovery_subscribes.extend(
["homeassistant/device/+/config", "homeassistant/device/+/+/config"]
)
discovery_subscribes.extend(["integration/test#", "integration/kitchen_sink#"])
expected_discovery_subscribes = discovery_subscribes.copy()
# Assert we see the expected subscribes and in the correct order
actual_subscribes = [
discovery_subscribes.pop(0)
for call in help_all_subscribe_calls(mqtt_client_mock)
if discovery_subscribes and discovery_subscribes[0] == call[0]
]
# Assert we have processed all items and that they are in the correct order
assert len(discovery_subscribes) == 0
assert actual_subscribes == expected_discovery_subscribes
# Assert the initial wildcard topic subscription order
_assert_subscription_order()
mqtt_client_mock.on_disconnect(Mock(), None, 0)
mqtt_client_mock.reset_mock()
mock_debouncer.clear()
mqtt_client_mock.on_connect(Mock(), None, 0, 0)
await mock_debouncer.wait()
# Assert the wildcard topic subscription order after a reconnect
_assert_subscription_order()
@pytest.mark.parametrize(
"mqtt_config_entry_data",
[ENTRY_DEFAULT_BIRTH_MESSAGE | {mqtt.CONF_DISCOVERY: False}],

View file

@ -69,6 +69,7 @@ DEFAULT_CONFIG_DEVICE_INFO_MAC = {
_SENTINEL = object()
DISCOVERY_COUNT = len(MQTT)
DEVICE_DISCOVERY_COUNT = 2
type _MqttMessageType = list[tuple[str, str]]
type _AttributesType = list[tuple[str, Any]]
@ -1189,7 +1190,10 @@ async def help_test_entity_id_update_subscriptions(
assert state is not None
assert (
mqtt_mock.async_subscribe.call_count
== len(topics) + 2 * len(SUPPORTED_COMPONENTS) + DISCOVERY_COUNT
== len(topics)
+ 2 * len(SUPPORTED_COMPONENTS)
+ DISCOVERY_COUNT
+ DEVICE_DISCOVERY_COUNT
)
for topic in topics:
mqtt_mock.async_subscribe.assert_any_call(

View file

@ -26,22 +26,42 @@ def stub_blueprint_populate_autouse(stub_blueprint_populate: None) -> None:
"""Stub copying the blueprints to the config folder."""
@pytest.mark.parametrize(
("discovery_topic", "data"),
[
(
"homeassistant/device_automation/0AFFD2/bla/config",
'{ "automation_type":"trigger",'
' "device":{"identifiers":["0AFFD2"]},'
' "payload": "short_press",'
' "topic": "foobar/triggers/button1",'
' "type": "button_short_press",'
' "subtype": "button_1" }',
),
(
"homeassistant/device/0AFFD2/config",
'{ "device":{"identifiers":["0AFFD2"]},'
' "o": {"name": "foobar"}, "cmps": '
'{ "bla": {'
' "automation_type":"trigger", '
' "payload": "short_press",'
' "topic": "foobar/triggers/button1",'
' "type": "button_short_press",'
' "subtype": "button_1",'
' "platform":"device_automation"}}}',
),
],
)
async def test_get_triggers(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
mqtt_mock_entry: MqttMockHAClientGenerator,
discovery_topic: str,
data: str,
) -> None:
"""Test we get the expected triggers from a discovered mqtt device."""
await mqtt_mock_entry()
data1 = (
'{ "automation_type":"trigger",'
' "device":{"identifiers":["0AFFD2"]},'
' "payload": "short_press",'
' "topic": "foobar/triggers/button1",'
' "type": "button_short_press",'
' "subtype": "button_1" }'
)
async_fire_mqtt_message(hass, "homeassistant/device_automation/bla/config", data1)
async_fire_mqtt_message(hass, discovery_topic, data)
await hass.async_block_till_done()
device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")})

File diff suppressed because it is too large Load diff

View file

@ -1197,7 +1197,6 @@ async def test_mqtt_ws_get_device_debug_info(
}
data_sensor = json.dumps(config_sensor)
data_trigger = json.dumps(config_trigger)
config_sensor["platform"] = config_trigger["platform"] = mqtt.DOMAIN
async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data_sensor)
async_fire_mqtt_message(
@ -1254,7 +1253,6 @@ async def test_mqtt_ws_get_device_debug_info_binary(
"unique_id": "unique",
}
data = json.dumps(config)
config["platform"] = mqtt.DOMAIN
async_fire_mqtt_message(hass, "homeassistant/camera/bla/config", data)
await hass.async_block_till_done()

View file

@ -1,10 +1,9 @@
"""The tests for MQTT tag scanner."""
from collections.abc import Generator
import copy
import json
from typing import Any
from unittest.mock import ANY, AsyncMock, patch
from unittest.mock import ANY, AsyncMock
import pytest
@ -47,13 +46,6 @@ DEFAULT_TAG_SCAN_JSON = (
)
@pytest.fixture
def tag_mock() -> Generator[AsyncMock]:
"""Fixture to mock tag."""
with patch("homeassistant.components.tag.async_scan_tag") as mock_tag:
yield mock_tag
@pytest.mark.no_fail_on_log_exception
async def test_discover_bad_tag(
hass: HomeAssistant,