Allow MQTT device based auto discovery (#118757)
* Allow MQTT device based auto discovery * Fix merge error * Remove unused import * Fix discovery device based topics * Fix cannot delete twice * Improve cleanup test * Follow up comment * Typo Co-authored-by: Erik Montnemery <erik@montnemery.com> * Explain more * Use tuple * Default a device payload to have priority over a platform based payload * Add unique_id to sensor test data * Set migration flag to mark a discovery topic for migration * Correct type hint * Make unique_id required for components in device based discovery payload * Remove CONF_MIGRATE_DISCOVERY from platform schema * Unload discovered MQTT item to allow migration * Follow up comments from code review * ruff * Subscribe to platform discovery wildcards first * Use normal dict * Use dict to persist wildcard subscription order * Remove missed unused parameter * Add a comment to explain we use a dict to preserve the subscription order * Add wildcard subscription order test * Remove discovery flag from test * Improve discovery migration origin logging * Assert initial wildcard discovery topics subscription order and after reconnect * Improve log messages --------- Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
parent
cb1b72d6ba
commit
1773f2aadc
15 changed files with 1770 additions and 159 deletions
|
@ -76,8 +76,8 @@ from .const import ( # noqa: F401
|
|||
DEFAULT_QOS,
|
||||
DEFAULT_RETAIN,
|
||||
DOMAIN,
|
||||
ENTITY_PLATFORMS,
|
||||
MQTT_CONNECTION_STATE,
|
||||
RELOADABLE_PLATFORMS,
|
||||
TEMPLATE_ERRORS,
|
||||
)
|
||||
from .models import ( # noqa: F401
|
||||
|
@ -438,7 +438,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
for entity in list(mqtt_platform.entities.values())
|
||||
if getattr(entity, "_discovery_data", None) is None
|
||||
and mqtt_platform.config_entry
|
||||
and mqtt_platform.domain in RELOADABLE_PLATFORMS
|
||||
and mqtt_platform.domain in ENTITY_PLATFORMS
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@ ABBREVIATIONS = {
|
|||
"cmd_on_tpl": "command_on_template",
|
||||
"cmd_t": "command_topic",
|
||||
"cmd_tpl": "command_template",
|
||||
"cmps": "components",
|
||||
"cod_arm_req": "code_arm_required",
|
||||
"cod_dis_req": "code_disarm_required",
|
||||
"cod_form": "code_format",
|
||||
|
@ -92,6 +93,7 @@ ABBREVIATIONS = {
|
|||
"min_mirs": "min_mireds",
|
||||
"max_temp": "max_temp",
|
||||
"min_temp": "min_temp",
|
||||
"migr_discvry": "migrate_discovery",
|
||||
"mode": "mode",
|
||||
"mode_cmd_tpl": "mode_command_template",
|
||||
"mode_cmd_t": "mode_command_topic",
|
||||
|
@ -109,6 +111,7 @@ ABBREVIATIONS = {
|
|||
"osc_cmd_tpl": "oscillation_command_template",
|
||||
"osc_stat_t": "oscillation_state_topic",
|
||||
"osc_val_tpl": "oscillation_value_template",
|
||||
"p": "platform",
|
||||
"pause_cmd_t": "pause_command_topic",
|
||||
"pause_mw_cmd_tpl": "pause_command_template",
|
||||
"pct_cmd_t": "percentage_command_topic",
|
||||
|
|
|
@ -376,7 +376,9 @@ class MQTT:
|
|||
self._simple_subscriptions: defaultdict[str, set[Subscription]] = defaultdict(
|
||||
set
|
||||
)
|
||||
self._wildcard_subscriptions: set[Subscription] = set()
|
||||
# To ensure the wildcard subscriptions order is preserved, we use a dict
|
||||
# with `None` values instead of a set.
|
||||
self._wildcard_subscriptions: dict[Subscription, None] = {}
|
||||
# _retained_topics prevents a Subscription from receiving a
|
||||
# retained message more than once per topic. This prevents flooding
|
||||
# already active subscribers when new subscribers subscribe to a topic
|
||||
|
@ -754,7 +756,7 @@ class MQTT:
|
|||
if subscription.is_simple_match:
|
||||
self._simple_subscriptions[subscription.topic].add(subscription)
|
||||
else:
|
||||
self._wildcard_subscriptions.add(subscription)
|
||||
self._wildcard_subscriptions[subscription] = None
|
||||
|
||||
@callback
|
||||
def _async_untrack_subscription(self, subscription: Subscription) -> None:
|
||||
|
@ -772,7 +774,7 @@ class MQTT:
|
|||
if not simple_subscriptions[topic]:
|
||||
del simple_subscriptions[topic]
|
||||
else:
|
||||
self._wildcard_subscriptions.remove(subscription)
|
||||
del self._wildcard_subscriptions[subscription]
|
||||
except (KeyError, ValueError) as exc:
|
||||
raise HomeAssistantError("Can't remove subscription twice") from exc
|
||||
|
||||
|
|
|
@ -90,6 +90,7 @@ CONF_TEMP_MIN = "min_temp"
|
|||
CONF_CERTIFICATE = "certificate"
|
||||
CONF_CLIENT_KEY = "client_key"
|
||||
CONF_CLIENT_CERT = "client_cert"
|
||||
CONF_COMPONENTS = "components"
|
||||
CONF_TLS_INSECURE = "tls_insecure"
|
||||
|
||||
# Device and integration info options
|
||||
|
@ -159,7 +160,7 @@ MQTT_CONNECTION_STATE = "mqtt_connection_state"
|
|||
PAYLOAD_EMPTY_JSON = "{}"
|
||||
PAYLOAD_NONE = "None"
|
||||
|
||||
RELOADABLE_PLATFORMS = [
|
||||
ENTITY_PLATFORMS = [
|
||||
Platform.ALARM_CONTROL_PANEL,
|
||||
Platform.BINARY_SENSOR,
|
||||
Platform.BUTTON,
|
||||
|
@ -190,7 +191,7 @@ RELOADABLE_PLATFORMS = [
|
|||
|
||||
TEMPLATE_ERRORS = (jinja2.TemplateError, TemplateError, TypeError, ValueError)
|
||||
|
||||
SUPPORTED_COMPONENTS = {
|
||||
SUPPORTED_COMPONENTS = (
|
||||
"alarm_control_panel",
|
||||
"binary_sensor",
|
||||
"button",
|
||||
|
@ -219,4 +220,4 @@ SUPPORTED_COMPONENTS = {
|
|||
"vacuum",
|
||||
"valve",
|
||||
"water_heater",
|
||||
}
|
||||
)
|
||||
|
|
|
@ -12,6 +12,8 @@ import re
|
|||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.config_entries import (
|
||||
SOURCE_MQTT,
|
||||
ConfigEntry,
|
||||
|
@ -25,7 +27,7 @@ from homeassistant.helpers.dispatcher import (
|
|||
async_dispatcher_connect,
|
||||
async_dispatcher_send,
|
||||
)
|
||||
from homeassistant.helpers.service_info.mqtt import MqttServiceInfo
|
||||
from homeassistant.helpers.service_info.mqtt import MqttServiceInfo, ReceivePayloadType
|
||||
from homeassistant.helpers.typing import DiscoveryInfoType
|
||||
from homeassistant.loader import async_get_mqtt
|
||||
from homeassistant.util.json import json_loads_object
|
||||
|
@ -38,13 +40,14 @@ from .const import (
|
|||
ATTR_DISCOVERY_PAYLOAD,
|
||||
ATTR_DISCOVERY_TOPIC,
|
||||
CONF_AVAILABILITY,
|
||||
CONF_COMPONENTS,
|
||||
CONF_ORIGIN,
|
||||
CONF_TOPIC,
|
||||
DOMAIN,
|
||||
SUPPORTED_COMPONENTS,
|
||||
)
|
||||
from .models import DATA_MQTT, MqttOriginInfo, ReceiveMessage
|
||||
from .schemas import MQTT_ORIGIN_INFO_SCHEMA
|
||||
from .models import DATA_MQTT, MqttComponentConfig, MqttOriginInfo, ReceiveMessage
|
||||
from .schemas import DEVICE_DISCOVERY_SCHEMA, MQTT_ORIGIN_INFO_SCHEMA, SHARED_OPTIONS
|
||||
from .util import async_forward_entry_setup_and_setup_discovery
|
||||
|
||||
ABBREVIATIONS_SET = set(ABBREVIATIONS)
|
||||
|
@ -70,10 +73,18 @@ MQTT_DISCOVERY_DONE: SignalTypeFormat[Any] = SignalTypeFormat(
|
|||
|
||||
TOPIC_BASE = "~"
|
||||
|
||||
CONF_MIGRATE_DISCOVERY = "migrate_discovery"
|
||||
|
||||
MIGRATE_DISCOVERY_SCHEMA = vol.Schema(
|
||||
{vol.Optional(CONF_MIGRATE_DISCOVERY): True},
|
||||
)
|
||||
|
||||
|
||||
class MQTTDiscoveryPayload(dict[str, Any]):
|
||||
"""Class to hold and MQTT discovery payload and discovery data."""
|
||||
|
||||
device_discovery: bool = False
|
||||
migrate_discovery: bool = False
|
||||
discovery_data: DiscoveryInfoType
|
||||
|
||||
|
||||
|
@ -85,6 +96,24 @@ class MQTTIntegrationDiscoveryConfig:
|
|||
msg: ReceiveMessage
|
||||
|
||||
|
||||
@callback
|
||||
def _async_process_discovery_migration(payload: MQTTDiscoveryPayload) -> bool:
|
||||
"""Process a discovery migration request in the discovery payload."""
|
||||
# Allow abbreviation
|
||||
if migr_discvry := (payload.pop("migr_discvry", None)):
|
||||
payload[CONF_MIGRATE_DISCOVERY] = migr_discvry
|
||||
if CONF_MIGRATE_DISCOVERY in payload:
|
||||
try:
|
||||
MIGRATE_DISCOVERY_SCHEMA(payload)
|
||||
except vol.Invalid as exc:
|
||||
_LOGGER.warning(exc)
|
||||
return False
|
||||
payload.migrate_discovery = True
|
||||
payload.clear()
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) -> None:
|
||||
"""Clear entry from already discovered list."""
|
||||
hass.data[DATA_MQTT].discovery_already_discovered.discard(discovery_hash)
|
||||
|
@ -96,36 +125,51 @@ def set_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) ->
|
|||
|
||||
|
||||
@callback
|
||||
def async_log_discovery_origin_info(
|
||||
message: str, discovery_payload: MQTTDiscoveryPayload, level: int = logging.INFO
|
||||
) -> None:
|
||||
"""Log information about the discovery and origin."""
|
||||
if not _LOGGER.isEnabledFor(level):
|
||||
# bail early if logging is disabled
|
||||
return
|
||||
def get_origin_log_string(
|
||||
discovery_payload: MQTTDiscoveryPayload, *, include_url: bool
|
||||
) -> str:
|
||||
"""Get the origin information from a discovery payload for logging."""
|
||||
if CONF_ORIGIN not in discovery_payload:
|
||||
_LOGGER.log(level, message)
|
||||
return
|
||||
return ""
|
||||
origin_info: MqttOriginInfo = discovery_payload[CONF_ORIGIN]
|
||||
sw_version_log = ""
|
||||
if sw_version := origin_info.get("sw_version"):
|
||||
sw_version_log = f", version: {sw_version}"
|
||||
support_url_log = ""
|
||||
if support_url := origin_info.get("support_url"):
|
||||
if include_url and (support_url := get_origin_support_url(discovery_payload)):
|
||||
support_url_log = f", support URL: {support_url}"
|
||||
return f" from external application {origin_info["name"]}{sw_version_log}{support_url_log}"
|
||||
|
||||
|
||||
@callback
|
||||
def get_origin_support_url(discovery_payload: MQTTDiscoveryPayload) -> str | None:
|
||||
"""Get the origin information support URL from a discovery payload."""
|
||||
if CONF_ORIGIN not in discovery_payload:
|
||||
return ""
|
||||
origin_info: MqttOriginInfo = discovery_payload[CONF_ORIGIN]
|
||||
return origin_info.get("support_url")
|
||||
|
||||
|
||||
@callback
|
||||
def async_log_discovery_origin_info(
|
||||
message: str, discovery_payload: MQTTDiscoveryPayload, level: int = logging.INFO
|
||||
) -> None:
|
||||
"""Log information about the discovery and origin."""
|
||||
# We only log origin info once per device discovery
|
||||
if not _LOGGER.isEnabledFor(level):
|
||||
# bail out early if logging is disabled
|
||||
return
|
||||
_LOGGER.log(
|
||||
level,
|
||||
"%s from external application %s%s%s",
|
||||
"%s%s",
|
||||
message,
|
||||
origin_info["name"],
|
||||
sw_version_log,
|
||||
support_url_log,
|
||||
get_origin_log_string(discovery_payload, include_url=True),
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
def _replace_abbreviations(
|
||||
payload: Any | dict[str, Any],
|
||||
payload: dict[str, Any] | str,
|
||||
abbreviations: dict[str, str],
|
||||
abbreviations_set: set[str],
|
||||
) -> None:
|
||||
|
@ -137,11 +181,20 @@ def _replace_abbreviations(
|
|||
|
||||
|
||||
@callback
|
||||
def _replace_all_abbreviations(discovery_payload: Any | dict[str, Any]) -> None:
|
||||
def _replace_all_abbreviations(
|
||||
discovery_payload: dict[str, Any], component_only: bool = False
|
||||
) -> None:
|
||||
"""Replace all abbreviations in an MQTT discovery payload."""
|
||||
|
||||
_replace_abbreviations(discovery_payload, ABBREVIATIONS, ABBREVIATIONS_SET)
|
||||
|
||||
if CONF_AVAILABILITY in discovery_payload:
|
||||
for availability_conf in cv.ensure_list(discovery_payload[CONF_AVAILABILITY]):
|
||||
_replace_abbreviations(availability_conf, ABBREVIATIONS, ABBREVIATIONS_SET)
|
||||
|
||||
if component_only:
|
||||
return
|
||||
|
||||
if CONF_ORIGIN in discovery_payload:
|
||||
_replace_abbreviations(
|
||||
discovery_payload[CONF_ORIGIN],
|
||||
|
@ -156,13 +209,15 @@ def _replace_all_abbreviations(discovery_payload: Any | dict[str, Any]) -> None:
|
|||
DEVICE_ABBREVIATIONS_SET,
|
||||
)
|
||||
|
||||
if CONF_AVAILABILITY in discovery_payload:
|
||||
for availability_conf in cv.ensure_list(discovery_payload[CONF_AVAILABILITY]):
|
||||
_replace_abbreviations(availability_conf, ABBREVIATIONS, ABBREVIATIONS_SET)
|
||||
if CONF_COMPONENTS in discovery_payload:
|
||||
if not isinstance(discovery_payload[CONF_COMPONENTS], dict):
|
||||
return
|
||||
for comp_conf in discovery_payload[CONF_COMPONENTS].values():
|
||||
_replace_all_abbreviations(comp_conf, component_only=True)
|
||||
|
||||
|
||||
@callback
|
||||
def _replace_topic_base(discovery_payload: dict[str, Any]) -> None:
|
||||
def _replace_topic_base(discovery_payload: MQTTDiscoveryPayload) -> None:
|
||||
"""Replace topic base in MQTT discovery data."""
|
||||
base = discovery_payload.pop(TOPIC_BASE)
|
||||
for key, value in discovery_payload.items():
|
||||
|
@ -182,6 +237,79 @@ def _replace_topic_base(discovery_payload: dict[str, Any]) -> None:
|
|||
availability_conf[CONF_TOPIC] = f"{topic[:-1]}{base}"
|
||||
|
||||
|
||||
@callback
|
||||
def _generate_device_config(
|
||||
hass: HomeAssistant,
|
||||
object_id: str,
|
||||
node_id: str | None,
|
||||
migrate_discovery: bool = False,
|
||||
) -> MQTTDiscoveryPayload:
|
||||
"""Generate a cleanup or discovery migration message on device cleanup.
|
||||
|
||||
If an empty payload, or a migrate discovery request is received for a device,
|
||||
we forward an empty payload for all previously discovered components.
|
||||
"""
|
||||
mqtt_data = hass.data[DATA_MQTT]
|
||||
device_node_id: str = f"{node_id} {object_id}" if node_id else object_id
|
||||
config = MQTTDiscoveryPayload({CONF_DEVICE: {}, CONF_COMPONENTS: {}})
|
||||
config.migrate_discovery = migrate_discovery
|
||||
comp_config = config[CONF_COMPONENTS]
|
||||
for platform, discover_id in mqtt_data.discovery_already_discovered:
|
||||
ids = discover_id.split(" ")
|
||||
component_node_id = ids.pop(0)
|
||||
component_object_id = " ".join(ids)
|
||||
if not ids:
|
||||
continue
|
||||
if device_node_id == component_node_id:
|
||||
comp_config[component_object_id] = {CONF_PLATFORM: platform}
|
||||
|
||||
return config if comp_config else MQTTDiscoveryPayload({})
|
||||
|
||||
|
||||
@callback
|
||||
def _parse_device_payload(
|
||||
hass: HomeAssistant,
|
||||
payload: ReceivePayloadType,
|
||||
object_id: str,
|
||||
node_id: str | None,
|
||||
) -> MQTTDiscoveryPayload:
|
||||
"""Parse a device discovery payload.
|
||||
|
||||
The device discovery payload is translated info the config payloads for every single
|
||||
component inside the device based configuration.
|
||||
An empty payload is translated in a cleanup, which forwards an empty payload to all
|
||||
removed components.
|
||||
"""
|
||||
device_payload = MQTTDiscoveryPayload()
|
||||
if payload == "":
|
||||
if not (device_payload := _generate_device_config(hass, object_id, node_id)):
|
||||
_LOGGER.warning(
|
||||
"No device components to cleanup for %s, node_id '%s'",
|
||||
object_id,
|
||||
node_id,
|
||||
)
|
||||
return device_payload
|
||||
try:
|
||||
device_payload = MQTTDiscoveryPayload(json_loads_object(payload))
|
||||
except ValueError:
|
||||
_LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload)
|
||||
return device_payload
|
||||
if _async_process_discovery_migration(device_payload):
|
||||
return _generate_device_config(hass, object_id, node_id, migrate_discovery=True)
|
||||
_replace_all_abbreviations(device_payload)
|
||||
try:
|
||||
DEVICE_DISCOVERY_SCHEMA(device_payload)
|
||||
except vol.Invalid as exc:
|
||||
_LOGGER.warning(
|
||||
"Invalid MQTT device discovery payload for %s, %s: '%s'",
|
||||
object_id,
|
||||
exc,
|
||||
payload,
|
||||
)
|
||||
return MQTTDiscoveryPayload({})
|
||||
return device_payload
|
||||
|
||||
|
||||
@callback
|
||||
def _valid_origin_info(discovery_payload: MQTTDiscoveryPayload) -> bool:
|
||||
"""Parse and validate origin info from a single component discovery payload."""
|
||||
|
@ -199,6 +327,30 @@ def _valid_origin_info(discovery_payload: MQTTDiscoveryPayload) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
@callback
|
||||
def _merge_common_device_options(
|
||||
component_config: MQTTDiscoveryPayload, device_config: dict[str, Any]
|
||||
) -> None:
|
||||
"""Merge common device options with the component config options.
|
||||
|
||||
Common options are:
|
||||
CONF_AVAILABILITY,
|
||||
CONF_AVAILABILITY_MODE,
|
||||
CONF_AVAILABILITY_TEMPLATE,
|
||||
CONF_AVAILABILITY_TOPIC,
|
||||
CONF_COMMAND_TOPIC,
|
||||
CONF_PAYLOAD_AVAILABLE,
|
||||
CONF_PAYLOAD_NOT_AVAILABLE,
|
||||
CONF_STATE_TOPIC,
|
||||
Common options in the body of the device based config are inherited into
|
||||
the component. Unless the option is explicitly specified at component level,
|
||||
in that case the option at component level will override the common option.
|
||||
"""
|
||||
for option in SHARED_OPTIONS:
|
||||
if option in device_config and option not in component_config:
|
||||
component_config[option] = device_config.get(option)
|
||||
|
||||
|
||||
async def async_start( # noqa: C901
|
||||
hass: HomeAssistant, discovery_topic: str, config_entry: ConfigEntry
|
||||
) -> None:
|
||||
|
@ -243,8 +395,7 @@ async def async_start( # noqa: C901
|
|||
_LOGGER.warning(
|
||||
(
|
||||
"Received message on illegal discovery topic '%s'. The topic"
|
||||
" contains "
|
||||
"not allowed characters. For more information see "
|
||||
" contains non allowed characters. For more information see "
|
||||
"https://www.home-assistant.io/integrations/mqtt/#discovery-topic"
|
||||
),
|
||||
topic,
|
||||
|
@ -253,42 +404,109 @@ async def async_start( # noqa: C901
|
|||
|
||||
component, node_id, object_id = match.groups()
|
||||
|
||||
if payload:
|
||||
discovered_components: list[MqttComponentConfig] = []
|
||||
if component == CONF_DEVICE:
|
||||
# Process device based discovery message and regenerate
|
||||
# cleanup config for the all the components that are being removed.
|
||||
# This is done when a component in the device config is omitted and detected
|
||||
# as being removed, or when the device config update payload is empty.
|
||||
# In that case this will regenerate a cleanup message for all every already
|
||||
# discovered components that were linked to the initial device discovery.
|
||||
device_discovery_payload = _parse_device_payload(
|
||||
hass, payload, object_id, node_id
|
||||
)
|
||||
if not device_discovery_payload:
|
||||
return
|
||||
device_config: dict[str, Any]
|
||||
origin_config: dict[str, Any] | None
|
||||
component_configs: dict[str, dict[str, Any]]
|
||||
device_config = device_discovery_payload[CONF_DEVICE]
|
||||
origin_config = device_discovery_payload.get(CONF_ORIGIN)
|
||||
component_configs = device_discovery_payload[CONF_COMPONENTS]
|
||||
for component_id, config in component_configs.items():
|
||||
component = config.pop(CONF_PLATFORM)
|
||||
# The object_id in the device discovery topic is the unique identifier.
|
||||
# It is used as node_id for the components it contains.
|
||||
component_node_id = object_id
|
||||
# The component_id in the discovery playload is used as object_id
|
||||
# If we have an additional node_id in the discovery topic,
|
||||
# we extend the component_id with it.
|
||||
component_object_id = (
|
||||
f"{node_id} {component_id}" if node_id else component_id
|
||||
)
|
||||
# We add wrapper to the discovery payload with the discovery data.
|
||||
# If the dict is empty after removing the platform, the payload is
|
||||
# assumed to remove the existing config and we do not want to add
|
||||
# device or orig or shared availability attributes.
|
||||
if discovery_payload := MQTTDiscoveryPayload(config):
|
||||
discovery_payload[CONF_DEVICE] = device_config
|
||||
discovery_payload[CONF_ORIGIN] = origin_config
|
||||
# Only assign shared config options
|
||||
# when they are not set at entity level
|
||||
_merge_common_device_options(
|
||||
discovery_payload, device_discovery_payload
|
||||
)
|
||||
discovery_payload.device_discovery = True
|
||||
discovery_payload.migrate_discovery = (
|
||||
device_discovery_payload.migrate_discovery
|
||||
)
|
||||
discovered_components.append(
|
||||
MqttComponentConfig(
|
||||
component,
|
||||
component_object_id,
|
||||
component_node_id,
|
||||
discovery_payload,
|
||||
)
|
||||
)
|
||||
_LOGGER.debug(
|
||||
"Process device discovery payload %s", device_discovery_payload
|
||||
)
|
||||
device_discovery_id = f"{node_id} {object_id}" if node_id else object_id
|
||||
message = f"Processing device discovery for '{device_discovery_id}'"
|
||||
async_log_discovery_origin_info(
|
||||
message, MQTTDiscoveryPayload(device_discovery_payload)
|
||||
)
|
||||
|
||||
else:
|
||||
# Process component based discovery message
|
||||
try:
|
||||
discovery_payload = MQTTDiscoveryPayload(json_loads_object(payload))
|
||||
discovery_payload = MQTTDiscoveryPayload(
|
||||
json_loads_object(payload) if payload else {}
|
||||
)
|
||||
except ValueError:
|
||||
_LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload)
|
||||
return
|
||||
if not _async_process_discovery_migration(discovery_payload):
|
||||
_replace_all_abbreviations(discovery_payload)
|
||||
if not _valid_origin_info(discovery_payload):
|
||||
return
|
||||
discovered_components.append(
|
||||
MqttComponentConfig(component, object_id, node_id, discovery_payload)
|
||||
)
|
||||
|
||||
discovery_pending_discovered = mqtt_data.discovery_pending_discovered
|
||||
for component_config in discovered_components:
|
||||
component = component_config.component
|
||||
node_id = component_config.node_id
|
||||
object_id = component_config.object_id
|
||||
discovery_payload = component_config.discovery_payload
|
||||
|
||||
if TOPIC_BASE in discovery_payload:
|
||||
_replace_topic_base(discovery_payload)
|
||||
else:
|
||||
discovery_payload = MQTTDiscoveryPayload({})
|
||||
|
||||
# If present, the node_id will be included in the discovered object id
|
||||
# If present, the node_id will be included in the discovery_id.
|
||||
discovery_id = f"{node_id} {object_id}" if node_id else object_id
|
||||
discovery_hash = (component, discovery_id)
|
||||
|
||||
if discovery_payload:
|
||||
# Attach MQTT topic to the payload, used for debug prints
|
||||
setattr(
|
||||
discovery_payload,
|
||||
"__configuration_source__",
|
||||
f"MQTT (topic: '{topic}')",
|
||||
)
|
||||
discovery_data = {
|
||||
discovery_payload.discovery_data = {
|
||||
ATTR_DISCOVERY_HASH: discovery_hash,
|
||||
ATTR_DISCOVERY_PAYLOAD: discovery_payload,
|
||||
ATTR_DISCOVERY_TOPIC: topic,
|
||||
}
|
||||
setattr(discovery_payload, "discovery_data", discovery_data)
|
||||
|
||||
discovery_payload[CONF_PLATFORM] = "mqtt"
|
||||
|
||||
if discovery_hash in mqtt_data.discovery_pending_discovered:
|
||||
pending = mqtt_data.discovery_pending_discovered[discovery_hash]["pending"]
|
||||
if discovery_hash in discovery_pending_discovered:
|
||||
pending = discovery_pending_discovered[discovery_hash]["pending"]
|
||||
pending.appendleft(discovery_payload)
|
||||
_LOGGER.debug(
|
||||
"Component has already been discovered: %s %s, queuing update",
|
||||
|
@ -305,7 +523,7 @@ async def async_start( # noqa: C901
|
|||
) -> None:
|
||||
"""Process the payload of a new discovery."""
|
||||
|
||||
_LOGGER.debug("Process discovery payload %s", payload)
|
||||
_LOGGER.debug("Process component discovery payload %s", payload)
|
||||
discovery_hash = (component, discovery_id)
|
||||
|
||||
already_discovered = discovery_hash in mqtt_data.discovery_already_discovered
|
||||
|
@ -362,6 +580,8 @@ async def async_start( # noqa: C901
|
|||
0,
|
||||
job_type=HassJobType.Callback,
|
||||
)
|
||||
# Subscribe first for platform discovery wildcard topics first,
|
||||
# and then subscribe device discovery wildcard topics.
|
||||
for topic in chain(
|
||||
(
|
||||
f"{discovery_topic}/{component}/+/config"
|
||||
|
@ -371,6 +591,10 @@ async def async_start( # noqa: C901
|
|||
f"{discovery_topic}/{component}/+/+/config"
|
||||
for component in SUPPORTED_COMPONENTS
|
||||
),
|
||||
(
|
||||
f"{discovery_topic}/device/+/config",
|
||||
f"{discovery_topic}/device/+/+/config",
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
|
|
@ -104,6 +104,8 @@ from .discovery import (
|
|||
MQTT_DISCOVERY_UPDATED,
|
||||
MQTTDiscoveryPayload,
|
||||
clear_discovery_hash,
|
||||
get_origin_log_string,
|
||||
get_origin_support_url,
|
||||
set_discovery_hash,
|
||||
)
|
||||
from .models import (
|
||||
|
@ -591,6 +593,7 @@ async def cleanup_device_registry(
|
|||
entity_registry = er.async_get(hass)
|
||||
if (
|
||||
device_id
|
||||
and device_id not in device_registry.deleted_devices
|
||||
and config_entry_id
|
||||
and not er.async_entries_for_device(
|
||||
entity_registry, device_id, include_disabled_entities=False
|
||||
|
@ -672,6 +675,7 @@ class MqttDiscoveryDeviceUpdateMixin(ABC):
|
|||
self._config_entry = config_entry
|
||||
self._config_entry_id = config_entry.entry_id
|
||||
self._skip_device_removal: bool = False
|
||||
self._migrate_discovery: str | None = None
|
||||
|
||||
discovery_hash = get_discovery_hash(discovery_data)
|
||||
self._remove_discovery_updated = async_dispatcher_connect(
|
||||
|
@ -704,12 +708,95 @@ class MqttDiscoveryDeviceUpdateMixin(ABC):
|
|||
) -> None:
|
||||
"""Handle discovery update."""
|
||||
discovery_hash = get_discovery_hash(self._discovery_data)
|
||||
# Start discovery migration or rollback if migrate_discovery flag is set
|
||||
# and the discovery topic is valid and not yet migrating
|
||||
if (
|
||||
discovery_payload.migrate_discovery
|
||||
and self._migrate_discovery is None
|
||||
and self._discovery_data[ATTR_DISCOVERY_TOPIC]
|
||||
== discovery_payload.discovery_data[ATTR_DISCOVERY_TOPIC]
|
||||
):
|
||||
self._migrate_discovery = self._discovery_data[ATTR_DISCOVERY_TOPIC]
|
||||
discovery_hash = self._discovery_data[ATTR_DISCOVERY_HASH]
|
||||
origin_info = get_origin_log_string(
|
||||
self._discovery_data[ATTR_DISCOVERY_PAYLOAD], include_url=False
|
||||
)
|
||||
action = "Rollback" if discovery_payload.device_discovery else "Migration"
|
||||
schema_type = "platform" if discovery_payload.device_discovery else "device"
|
||||
_LOGGER.info(
|
||||
"%s to MQTT %s discovery schema started for %s '%s'"
|
||||
"%s on topic %s. To complete %s, publish a %s discovery "
|
||||
"message with %s '%s'. After completed %s, "
|
||||
"publish an empty (retained) payload to %s",
|
||||
action,
|
||||
schema_type,
|
||||
discovery_hash[0],
|
||||
discovery_hash[1],
|
||||
origin_info,
|
||||
self._migrate_discovery,
|
||||
action.lower(),
|
||||
schema_type,
|
||||
discovery_hash[0],
|
||||
discovery_hash[1],
|
||||
action.lower(),
|
||||
self._migrate_discovery,
|
||||
)
|
||||
|
||||
# Cleanup platform resources
|
||||
await self.async_tear_down()
|
||||
# Unregister and clean discovery
|
||||
stop_discovery_updates(
|
||||
self.hass, self._discovery_data, self._remove_discovery_updated
|
||||
)
|
||||
send_discovery_done(self.hass, self._discovery_data)
|
||||
return
|
||||
|
||||
_LOGGER.debug(
|
||||
"Got update for %s with hash: %s '%s'",
|
||||
self.log_name,
|
||||
discovery_hash,
|
||||
discovery_payload,
|
||||
)
|
||||
new_discovery_topic = discovery_payload.discovery_data[ATTR_DISCOVERY_TOPIC]
|
||||
|
||||
# Abort early if an update is not received via the registered discovery topic.
|
||||
# This can happen if a device and single component discovery payload
|
||||
# share the same discovery ID.
|
||||
if self._discovery_data[ATTR_DISCOVERY_TOPIC] != new_discovery_topic:
|
||||
# Prevent illegal updates
|
||||
old_origin_info = get_origin_log_string(
|
||||
self._discovery_data[ATTR_DISCOVERY_PAYLOAD], include_url=False
|
||||
)
|
||||
new_origin_info = get_origin_log_string(
|
||||
discovery_payload.discovery_data[ATTR_DISCOVERY_PAYLOAD],
|
||||
include_url=False,
|
||||
)
|
||||
new_origin_support_url = get_origin_support_url(
|
||||
discovery_payload.discovery_data[ATTR_DISCOVERY_PAYLOAD]
|
||||
)
|
||||
if new_origin_support_url:
|
||||
get_support = f"for support visit {new_origin_support_url}"
|
||||
else:
|
||||
get_support = (
|
||||
"for documentation on migration to device schema or rollback to "
|
||||
"discovery schema, visit https://www.home-assistant.io/integrations/"
|
||||
"mqtt/#migration-from-single-component-to-device-based-discovery"
|
||||
)
|
||||
_LOGGER.warning(
|
||||
"Received a conflicting MQTT discovery message for %s '%s' which was "
|
||||
"previously discovered on topic %s%s; the conflicting discovery "
|
||||
"message was received on topic %s%s; %s",
|
||||
discovery_hash[0],
|
||||
discovery_hash[1],
|
||||
self._discovery_data[ATTR_DISCOVERY_TOPIC],
|
||||
old_origin_info,
|
||||
new_discovery_topic,
|
||||
new_origin_info,
|
||||
get_support,
|
||||
)
|
||||
send_discovery_done(self.hass, self._discovery_data)
|
||||
return
|
||||
|
||||
if (
|
||||
discovery_payload
|
||||
and discovery_payload != self._discovery_data[ATTR_DISCOVERY_PAYLOAD]
|
||||
|
@ -806,6 +893,7 @@ class MqttDiscoveryUpdateMixin(Entity):
|
|||
mqtt_data = hass.data[DATA_MQTT]
|
||||
self._registry_hooks = mqtt_data.discovery_registry_hooks
|
||||
discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH]
|
||||
self._migrate_discovery: str | None = None
|
||||
if discovery_hash in self._registry_hooks:
|
||||
self._registry_hooks.pop(discovery_hash)()
|
||||
|
||||
|
@ -863,7 +951,12 @@ class MqttDiscoveryUpdateMixin(Entity):
|
|||
if TYPE_CHECKING:
|
||||
assert self._discovery_data
|
||||
self._cleanup_discovery_on_remove()
|
||||
if self._migrate_discovery is None:
|
||||
# Unload and cleanup registry
|
||||
await self._async_remove_state_and_registry_entry()
|
||||
else:
|
||||
# Only unload the entity
|
||||
await self.async_remove(force_remove=True)
|
||||
send_discovery_done(self.hass, self._discovery_data)
|
||||
|
||||
@callback
|
||||
|
@ -878,18 +971,102 @@ class MqttDiscoveryUpdateMixin(Entity):
|
|||
"""
|
||||
if TYPE_CHECKING:
|
||||
assert self._discovery_data
|
||||
discovery_hash: tuple[str, str] = self._discovery_data[ATTR_DISCOVERY_HASH]
|
||||
discovery_hash = get_discovery_hash(self._discovery_data)
|
||||
# Start discovery migration or rollback if migrate_discovery flag is set
|
||||
# and the discovery topic is valid and not yet migrating
|
||||
if (
|
||||
payload.migrate_discovery
|
||||
and self._migrate_discovery is None
|
||||
and self._discovery_data[ATTR_DISCOVERY_TOPIC]
|
||||
== payload.discovery_data[ATTR_DISCOVERY_TOPIC]
|
||||
):
|
||||
if self.unique_id is None or self.device_info is None:
|
||||
_LOGGER.error(
|
||||
"Discovery migration is not possible for "
|
||||
"for entity %s on topic %s. A unique_id "
|
||||
"and device context is required, got unique_id: %s, device: %s",
|
||||
self.entity_id,
|
||||
self._discovery_data[ATTR_DISCOVERY_TOPIC],
|
||||
self.unique_id,
|
||||
self.device_info,
|
||||
)
|
||||
send_discovery_done(self.hass, self._discovery_data)
|
||||
return
|
||||
|
||||
self._migrate_discovery = self._discovery_data[ATTR_DISCOVERY_TOPIC]
|
||||
discovery_hash = self._discovery_data[ATTR_DISCOVERY_HASH]
|
||||
origin_info = get_origin_log_string(
|
||||
self._discovery_data[ATTR_DISCOVERY_PAYLOAD], include_url=False
|
||||
)
|
||||
action = "Rollback" if payload.device_discovery else "Migration"
|
||||
schema_type = "platform" if payload.device_discovery else "device"
|
||||
_LOGGER.info(
|
||||
"%s to MQTT %s discovery schema started for entity %s"
|
||||
"%s on topic %s. To complete %s, publish a %s discovery "
|
||||
"message with %s entity '%s'. After completed %s, "
|
||||
"publish an empty (retained) payload to %s",
|
||||
action,
|
||||
schema_type,
|
||||
self.entity_id,
|
||||
origin_info,
|
||||
self._migrate_discovery,
|
||||
action.lower(),
|
||||
schema_type,
|
||||
discovery_hash[0],
|
||||
discovery_hash[1],
|
||||
action.lower(),
|
||||
self._migrate_discovery,
|
||||
)
|
||||
old_payload = self._discovery_data[ATTR_DISCOVERY_PAYLOAD]
|
||||
_LOGGER.debug(
|
||||
"Got update for entity with hash: %s '%s'",
|
||||
discovery_hash,
|
||||
payload,
|
||||
)
|
||||
old_payload: DiscoveryInfoType
|
||||
old_payload = self._discovery_data[ATTR_DISCOVERY_PAYLOAD]
|
||||
new_discovery_topic = payload.discovery_data[ATTR_DISCOVERY_TOPIC]
|
||||
# Abort early if an update is not received via the registered discovery topic.
|
||||
# This can happen if a device and single component discovery payload
|
||||
# share the same discovery ID.
|
||||
if self._discovery_data[ATTR_DISCOVERY_TOPIC] != new_discovery_topic:
|
||||
# Prevent illegal updates
|
||||
old_origin_info = get_origin_log_string(
|
||||
self._discovery_data[ATTR_DISCOVERY_PAYLOAD], include_url=False
|
||||
)
|
||||
new_origin_info = get_origin_log_string(
|
||||
payload.discovery_data[ATTR_DISCOVERY_PAYLOAD], include_url=False
|
||||
)
|
||||
new_origin_support_url = get_origin_support_url(
|
||||
payload.discovery_data[ATTR_DISCOVERY_PAYLOAD]
|
||||
)
|
||||
if new_origin_support_url:
|
||||
get_support = f"for support visit {new_origin_support_url}"
|
||||
else:
|
||||
get_support = (
|
||||
"for documentation on migration to device schema or rollback to "
|
||||
"discovery schema, visit https://www.home-assistant.io/integrations/"
|
||||
"mqtt/#migration-from-single-component-to-device-based-discovery"
|
||||
)
|
||||
_LOGGER.warning(
|
||||
"Received a conflicting MQTT discovery message for entity %s; the "
|
||||
"entity was previously discovered on topic %s%s; the conflicting "
|
||||
"discovery message was received on topic %s%s; %s",
|
||||
self.entity_id,
|
||||
self._discovery_data[ATTR_DISCOVERY_TOPIC],
|
||||
old_origin_info,
|
||||
new_discovery_topic,
|
||||
new_origin_info,
|
||||
get_support,
|
||||
)
|
||||
send_discovery_done(self.hass, self._discovery_data)
|
||||
return
|
||||
|
||||
debug_info.update_entity_discovery_data(self.hass, payload, self.entity_id)
|
||||
if not payload:
|
||||
# Empty payload: Remove component
|
||||
if self._migrate_discovery is None:
|
||||
_LOGGER.info("Removing component: %s", self.entity_id)
|
||||
else:
|
||||
_LOGGER.info("Unloading component: %s", self.entity_id)
|
||||
self.hass.async_create_task(
|
||||
self._async_process_discovery_update_and_remove()
|
||||
)
|
||||
|
|
|
@ -410,5 +410,15 @@ class MqttData:
|
|||
tags: dict[str, dict[str, MQTTTagScanner]] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class MqttComponentConfig:
|
||||
"""(component, object_id, node_id, discovery_payload)."""
|
||||
|
||||
component: str
|
||||
object_id: str
|
||||
node_id: str | None
|
||||
discovery_payload: MQTTDiscoveryPayload
|
||||
|
||||
|
||||
DATA_MQTT: HassKey[MqttData] = HassKey("mqtt")
|
||||
DATA_MQTT_AVAILABLE: HassKey[asyncio.Future[bool]] = HassKey("mqtt_client_available")
|
||||
|
|
|
@ -2,6 +2,8 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import (
|
||||
|
@ -11,6 +13,7 @@ from homeassistant.const import (
|
|||
CONF_MODEL,
|
||||
CONF_MODEL_ID,
|
||||
CONF_NAME,
|
||||
CONF_PLATFORM,
|
||||
CONF_UNIQUE_ID,
|
||||
CONF_VALUE_TEMPLATE,
|
||||
)
|
||||
|
@ -25,10 +28,13 @@ from .const import (
|
|||
CONF_AVAILABILITY_MODE,
|
||||
CONF_AVAILABILITY_TEMPLATE,
|
||||
CONF_AVAILABILITY_TOPIC,
|
||||
CONF_COMMAND_TOPIC,
|
||||
CONF_COMPONENTS,
|
||||
CONF_CONFIGURATION_URL,
|
||||
CONF_CONNECTIONS,
|
||||
CONF_DEPRECATED_VIA_HUB,
|
||||
CONF_ENABLED_BY_DEFAULT,
|
||||
CONF_ENCODING,
|
||||
CONF_ENTITY_PICTURE,
|
||||
CONF_HW_VERSION,
|
||||
CONF_IDENTIFIERS,
|
||||
|
@ -39,7 +45,9 @@ from .const import (
|
|||
CONF_ORIGIN,
|
||||
CONF_PAYLOAD_AVAILABLE,
|
||||
CONF_PAYLOAD_NOT_AVAILABLE,
|
||||
CONF_QOS,
|
||||
CONF_SERIAL_NUMBER,
|
||||
CONF_STATE_TOPIC,
|
||||
CONF_SUGGESTED_AREA,
|
||||
CONF_SUPPORT_URL,
|
||||
CONF_SW_VERSION,
|
||||
|
@ -47,10 +55,34 @@ from .const import (
|
|||
CONF_VIA_DEVICE,
|
||||
DEFAULT_PAYLOAD_AVAILABLE,
|
||||
DEFAULT_PAYLOAD_NOT_AVAILABLE,
|
||||
ENTITY_PLATFORMS,
|
||||
SUPPORTED_COMPONENTS,
|
||||
)
|
||||
from .util import valid_subscribe_topic
|
||||
from .util import valid_publish_topic, valid_qos_schema, valid_subscribe_topic
|
||||
|
||||
MQTT_AVAILABILITY_SINGLE_SCHEMA = vol.Schema(
|
||||
# Device discovery options that are also available at entity component level
|
||||
SHARED_OPTIONS = [
|
||||
CONF_AVAILABILITY,
|
||||
CONF_AVAILABILITY_MODE,
|
||||
CONF_AVAILABILITY_TEMPLATE,
|
||||
CONF_AVAILABILITY_TOPIC,
|
||||
CONF_COMMAND_TOPIC,
|
||||
CONF_PAYLOAD_AVAILABLE,
|
||||
CONF_PAYLOAD_NOT_AVAILABLE,
|
||||
CONF_STATE_TOPIC,
|
||||
]
|
||||
|
||||
MQTT_ORIGIN_INFO_SCHEMA = vol.All(
|
||||
vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_NAME): cv.string,
|
||||
vol.Optional(CONF_SW_VERSION): cv.string,
|
||||
vol.Optional(CONF_SUPPORT_URL): cv.configuration_url,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
_MQTT_AVAILABILITY_SINGLE_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Exclusive(CONF_AVAILABILITY_TOPIC, "availability"): valid_subscribe_topic,
|
||||
vol.Optional(CONF_AVAILABILITY_TEMPLATE): cv.template,
|
||||
|
@ -63,7 +95,7 @@ MQTT_AVAILABILITY_SINGLE_SCHEMA = vol.Schema(
|
|||
}
|
||||
)
|
||||
|
||||
MQTT_AVAILABILITY_LIST_SCHEMA = vol.Schema(
|
||||
_MQTT_AVAILABILITY_LIST_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Optional(CONF_AVAILABILITY_MODE, default=AVAILABILITY_LATEST): vol.All(
|
||||
cv.string, vol.In(AVAILABILITY_MODES)
|
||||
|
@ -87,8 +119,8 @@ MQTT_AVAILABILITY_LIST_SCHEMA = vol.Schema(
|
|||
}
|
||||
)
|
||||
|
||||
MQTT_AVAILABILITY_SCHEMA = MQTT_AVAILABILITY_SINGLE_SCHEMA.extend(
|
||||
MQTT_AVAILABILITY_LIST_SCHEMA.schema
|
||||
_MQTT_AVAILABILITY_SCHEMA = _MQTT_AVAILABILITY_SINGLE_SCHEMA.extend(
|
||||
_MQTT_AVAILABILITY_LIST_SCHEMA.schema
|
||||
)
|
||||
|
||||
|
||||
|
@ -138,7 +170,7 @@ MQTT_ORIGIN_INFO_SCHEMA = vol.All(
|
|||
),
|
||||
)
|
||||
|
||||
MQTT_ENTITY_COMMON_SCHEMA = MQTT_AVAILABILITY_SCHEMA.extend(
|
||||
MQTT_ENTITY_COMMON_SCHEMA = _MQTT_AVAILABILITY_SCHEMA.extend(
|
||||
{
|
||||
vol.Optional(CONF_DEVICE): MQTT_ENTITY_DEVICE_INFO_SCHEMA,
|
||||
vol.Optional(CONF_ENTITY_PICTURE): cv.url,
|
||||
|
@ -152,3 +184,35 @@ MQTT_ENTITY_COMMON_SCHEMA = MQTT_AVAILABILITY_SCHEMA.extend(
|
|||
vol.Optional(CONF_UNIQUE_ID): cv.string,
|
||||
}
|
||||
)
|
||||
|
||||
_UNIQUE_ID_SCHEMA = vol.Schema(
|
||||
{vol.Required(CONF_UNIQUE_ID): cv.string},
|
||||
).extend({}, extra=True)
|
||||
|
||||
|
||||
def check_unique_id(config: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Check if a unique ID is set in case an entity platform is configured."""
|
||||
platform = config[CONF_PLATFORM]
|
||||
if platform in ENTITY_PLATFORMS and len(config.keys()) > 1:
|
||||
_UNIQUE_ID_SCHEMA(config)
|
||||
return config
|
||||
|
||||
|
||||
_COMPONENT_CONFIG_SCHEMA = vol.All(
|
||||
vol.Schema(
|
||||
{vol.Required(CONF_PLATFORM): vol.In(SUPPORTED_COMPONENTS)},
|
||||
).extend({}, extra=True),
|
||||
check_unique_id,
|
||||
)
|
||||
|
||||
DEVICE_DISCOVERY_SCHEMA = _MQTT_AVAILABILITY_SCHEMA.extend(
|
||||
{
|
||||
vol.Required(CONF_DEVICE): MQTT_ENTITY_DEVICE_INFO_SCHEMA,
|
||||
vol.Required(CONF_COMPONENTS): vol.Schema({str: _COMPONENT_CONFIG_SCHEMA}),
|
||||
vol.Required(CONF_ORIGIN): MQTT_ORIGIN_INFO_SCHEMA,
|
||||
vol.Optional(CONF_STATE_TOPIC): valid_subscribe_topic,
|
||||
vol.Optional(CONF_COMMAND_TOPIC): valid_publish_topic,
|
||||
vol.Optional(CONF_QOS): valid_qos_schema,
|
||||
vol.Optional(CONF_ENCODING): cv.string,
|
||||
}
|
||||
)
|
||||
|
|
|
@ -4,7 +4,7 @@ import asyncio
|
|||
from collections.abc import AsyncGenerator, Generator
|
||||
from random import getrandbits
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -122,3 +122,10 @@ def record_calls(recorded_calls: list[ReceiveMessage]) -> MessageCallbackType:
|
|||
recorded_calls.append(msg)
|
||||
|
||||
return record_calls
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tag_mock() -> Generator[AsyncMock]:
|
||||
"""Fixture to mock tag."""
|
||||
with patch("homeassistant.components.tag.async_scan_tag") as mock_tag:
|
||||
yield mock_tag
|
||||
|
|
|
@ -1716,6 +1716,64 @@ async def test_mqtt_subscribes_topics_on_connect(
|
|||
assert ("still/pending", 1) in subscribe_calls
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mqtt_config_entry_data", [ENTRY_DEFAULT_BIRTH_MESSAGE])
|
||||
async def test_mqtt_subscribes_wildcard_topics_in_correct_order(
|
||||
hass: HomeAssistant,
|
||||
mock_debouncer: asyncio.Event,
|
||||
setup_with_birth_msg_client_mock: MqttMockPahoClient,
|
||||
record_calls: MessageCallbackType,
|
||||
) -> None:
|
||||
"""Test subscription to wildcard topics on connect in the order of subscription."""
|
||||
mqtt_client_mock = setup_with_birth_msg_client_mock
|
||||
|
||||
mock_debouncer.clear()
|
||||
await mqtt.async_subscribe(hass, "integration/test#", record_calls)
|
||||
await mqtt.async_subscribe(hass, "integration/kitchen_sink#", record_calls)
|
||||
await mock_debouncer.wait()
|
||||
|
||||
def _assert_subscription_order():
|
||||
discovery_subscribes = [
|
||||
f"homeassistant/{platform}/+/config" for platform in SUPPORTED_COMPONENTS
|
||||
]
|
||||
discovery_subscribes.extend(
|
||||
[
|
||||
f"homeassistant/{platform}/+/+/config"
|
||||
for platform in SUPPORTED_COMPONENTS
|
||||
]
|
||||
)
|
||||
discovery_subscribes.extend(
|
||||
["homeassistant/device/+/config", "homeassistant/device/+/+/config"]
|
||||
)
|
||||
discovery_subscribes.extend(["integration/test#", "integration/kitchen_sink#"])
|
||||
|
||||
expected_discovery_subscribes = discovery_subscribes.copy()
|
||||
|
||||
# Assert we see the expected subscribes and in the correct order
|
||||
actual_subscribes = [
|
||||
discovery_subscribes.pop(0)
|
||||
for call in help_all_subscribe_calls(mqtt_client_mock)
|
||||
if discovery_subscribes and discovery_subscribes[0] == call[0]
|
||||
]
|
||||
|
||||
# Assert we have processed all items and that they are in the correct order
|
||||
assert len(discovery_subscribes) == 0
|
||||
assert actual_subscribes == expected_discovery_subscribes
|
||||
|
||||
# Assert the initial wildcard topic subscription order
|
||||
_assert_subscription_order()
|
||||
|
||||
mqtt_client_mock.on_disconnect(Mock(), None, 0)
|
||||
|
||||
mqtt_client_mock.reset_mock()
|
||||
|
||||
mock_debouncer.clear()
|
||||
mqtt_client_mock.on_connect(Mock(), None, 0, 0)
|
||||
await mock_debouncer.wait()
|
||||
|
||||
# Assert the wildcard topic subscription order after a reconnect
|
||||
_assert_subscription_order()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mqtt_config_entry_data",
|
||||
[ENTRY_DEFAULT_BIRTH_MESSAGE | {mqtt.CONF_DISCOVERY: False}],
|
||||
|
|
|
@ -69,6 +69,7 @@ DEFAULT_CONFIG_DEVICE_INFO_MAC = {
|
|||
_SENTINEL = object()
|
||||
|
||||
DISCOVERY_COUNT = len(MQTT)
|
||||
DEVICE_DISCOVERY_COUNT = 2
|
||||
|
||||
type _MqttMessageType = list[tuple[str, str]]
|
||||
type _AttributesType = list[tuple[str, Any]]
|
||||
|
@ -1189,7 +1190,10 @@ async def help_test_entity_id_update_subscriptions(
|
|||
assert state is not None
|
||||
assert (
|
||||
mqtt_mock.async_subscribe.call_count
|
||||
== len(topics) + 2 * len(SUPPORTED_COMPONENTS) + DISCOVERY_COUNT
|
||||
== len(topics)
|
||||
+ 2 * len(SUPPORTED_COMPONENTS)
|
||||
+ DISCOVERY_COUNT
|
||||
+ DEVICE_DISCOVERY_COUNT
|
||||
)
|
||||
for topic in topics:
|
||||
mqtt_mock.async_subscribe.assert_any_call(
|
||||
|
|
|
@ -26,22 +26,42 @@ def stub_blueprint_populate_autouse(stub_blueprint_populate: None) -> None:
|
|||
"""Stub copying the blueprints to the config folder."""
|
||||
|
||||
|
||||
async def test_get_triggers(
|
||||
hass: HomeAssistant,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
mqtt_mock_entry: MqttMockHAClientGenerator,
|
||||
) -> None:
|
||||
"""Test we get the expected triggers from a discovered mqtt device."""
|
||||
await mqtt_mock_entry()
|
||||
data1 = (
|
||||
@pytest.mark.parametrize(
|
||||
("discovery_topic", "data"),
|
||||
[
|
||||
(
|
||||
"homeassistant/device_automation/0AFFD2/bla/config",
|
||||
'{ "automation_type":"trigger",'
|
||||
' "device":{"identifiers":["0AFFD2"]},'
|
||||
' "payload": "short_press",'
|
||||
' "topic": "foobar/triggers/button1",'
|
||||
' "type": "button_short_press",'
|
||||
' "subtype": "button_1" }'
|
||||
' "subtype": "button_1" }',
|
||||
),
|
||||
(
|
||||
"homeassistant/device/0AFFD2/config",
|
||||
'{ "device":{"identifiers":["0AFFD2"]},'
|
||||
' "o": {"name": "foobar"}, "cmps": '
|
||||
'{ "bla": {'
|
||||
' "automation_type":"trigger", '
|
||||
' "payload": "short_press",'
|
||||
' "topic": "foobar/triggers/button1",'
|
||||
' "type": "button_short_press",'
|
||||
' "subtype": "button_1",'
|
||||
' "platform":"device_automation"}}}',
|
||||
),
|
||||
],
|
||||
)
|
||||
async_fire_mqtt_message(hass, "homeassistant/device_automation/bla/config", data1)
|
||||
async def test_get_triggers(
|
||||
hass: HomeAssistant,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
mqtt_mock_entry: MqttMockHAClientGenerator,
|
||||
discovery_topic: str,
|
||||
data: str,
|
||||
) -> None:
|
||||
"""Test we get the expected triggers from a discovered mqtt device."""
|
||||
await mqtt_mock_entry()
|
||||
async_fire_mqtt_message(hass, discovery_topic, data)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
device_entry = device_registry.async_get_device(identifiers={("mqtt", "0AFFD2")})
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1197,7 +1197,6 @@ async def test_mqtt_ws_get_device_debug_info(
|
|||
}
|
||||
data_sensor = json.dumps(config_sensor)
|
||||
data_trigger = json.dumps(config_trigger)
|
||||
config_sensor["platform"] = config_trigger["platform"] = mqtt.DOMAIN
|
||||
|
||||
async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data_sensor)
|
||||
async_fire_mqtt_message(
|
||||
|
@ -1254,7 +1253,6 @@ async def test_mqtt_ws_get_device_debug_info_binary(
|
|||
"unique_id": "unique",
|
||||
}
|
||||
data = json.dumps(config)
|
||||
config["platform"] = mqtt.DOMAIN
|
||||
|
||||
async_fire_mqtt_message(hass, "homeassistant/camera/bla/config", data)
|
||||
await hass.async_block_till_done()
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
"""The tests for MQTT tag scanner."""
|
||||
|
||||
from collections.abc import Generator
|
||||
import copy
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import ANY, AsyncMock, patch
|
||||
from unittest.mock import ANY, AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -47,13 +46,6 @@ DEFAULT_TAG_SCAN_JSON = (
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tag_mock() -> Generator[AsyncMock]:
|
||||
"""Fixture to mock tag."""
|
||||
with patch("homeassistant.components.tag.async_scan_tag") as mock_tag:
|
||||
yield mock_tag
|
||||
|
||||
|
||||
@pytest.mark.no_fail_on_log_exception
|
||||
async def test_discover_bad_tag(
|
||||
hass: HomeAssistant,
|
||||
|
|
Loading…
Add table
Reference in a new issue