Correct formatting mqtt MQTT_DISCOVERY_DONE and MQTT_DISCOVERY_UPDATED message (#116947)

This commit is contained in:
Jan Bouwhuis 2024-05-06 22:32:46 +02:00 committed by GitHub
parent e65f2f1984
commit 821c7d813d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 56 additions and 11 deletions

View file

@ -82,13 +82,15 @@ SUPPORTED_COMPONENTS = {
} }
MQTT_DISCOVERY_UPDATED: SignalTypeFormat[MQTTDiscoveryPayload] = SignalTypeFormat( MQTT_DISCOVERY_UPDATED: SignalTypeFormat[MQTTDiscoveryPayload] = SignalTypeFormat(
"mqtt_discovery_updated_{}" "mqtt_discovery_updated_{}_{}"
) )
MQTT_DISCOVERY_NEW: SignalTypeFormat[MQTTDiscoveryPayload] = SignalTypeFormat( MQTT_DISCOVERY_NEW: SignalTypeFormat[MQTTDiscoveryPayload] = SignalTypeFormat(
"mqtt_discovery_new_{}_{}" "mqtt_discovery_new_{}_{}"
) )
MQTT_DISCOVERY_NEW_COMPONENT = "mqtt_discovery_new_component" MQTT_DISCOVERY_NEW_COMPONENT = "mqtt_discovery_new_component"
MQTT_DISCOVERY_DONE: SignalTypeFormat[Any] = SignalTypeFormat("mqtt_discovery_done_{}") MQTT_DISCOVERY_DONE: SignalTypeFormat[Any] = SignalTypeFormat(
"mqtt_discovery_done_{}_{}"
)
TOPIC_BASE = "~" TOPIC_BASE = "~"
@ -329,7 +331,7 @@ async def async_start( # noqa: C901
discovery_pending_discovered[discovery_hash] = { discovery_pending_discovered[discovery_hash] = {
"unsub": async_dispatcher_connect( "unsub": async_dispatcher_connect(
hass, hass,
MQTT_DISCOVERY_DONE.format(discovery_hash), MQTT_DISCOVERY_DONE.format(*discovery_hash),
discovery_done, discovery_done,
), ),
"pending": deque([]), "pending": deque([]),
@ -343,7 +345,7 @@ async def async_start( # noqa: C901
message = f"Component has already been discovered: {component} {discovery_id}, sending update" message = f"Component has already been discovered: {component} {discovery_id}, sending update"
async_log_discovery_origin_info(message, payload) async_log_discovery_origin_info(message, payload)
async_dispatcher_send( async_dispatcher_send(
hass, MQTT_DISCOVERY_UPDATED.format(discovery_hash), payload hass, MQTT_DISCOVERY_UPDATED.format(*discovery_hash), payload
) )
elif payload: elif payload:
# Add component # Add component
@ -356,7 +358,7 @@ async def async_start( # noqa: C901
else: else:
# Unhandled discovery message # Unhandled discovery message
async_dispatcher_send( async_dispatcher_send(
hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None hass, MQTT_DISCOVERY_DONE.format(*discovery_hash), None
) )
discovery_topics = [ discovery_topics = [

View file

@ -305,12 +305,12 @@ async def _async_discover(
except vol.Invalid as err: except vol.Invalid as err:
discovery_hash = discovery_data[ATTR_DISCOVERY_HASH] discovery_hash = discovery_data[ATTR_DISCOVERY_HASH]
clear_discovery_hash(hass, discovery_hash) clear_discovery_hash(hass, discovery_hash)
async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None) async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(*discovery_hash), None)
async_handle_schema_error(discovery_payload, err) async_handle_schema_error(discovery_payload, err)
except Exception: except Exception:
discovery_hash = discovery_data[ATTR_DISCOVERY_HASH] discovery_hash = discovery_data[ATTR_DISCOVERY_HASH]
clear_discovery_hash(hass, discovery_hash) clear_discovery_hash(hass, discovery_hash)
async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None) async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(*discovery_hash), None)
raise raise
@ -745,7 +745,7 @@ def get_discovery_hash(discovery_data: DiscoveryInfoType) -> tuple[str, str]:
def send_discovery_done(hass: HomeAssistant, discovery_data: DiscoveryInfoType) -> None: def send_discovery_done(hass: HomeAssistant, discovery_data: DiscoveryInfoType) -> None:
"""Acknowledge a discovery message has been handled.""" """Acknowledge a discovery message has been handled."""
discovery_hash = get_discovery_hash(discovery_data) discovery_hash = get_discovery_hash(discovery_data)
async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None) async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(*discovery_hash), None)
def stop_discovery_updates( def stop_discovery_updates(
@ -809,7 +809,7 @@ class MqttDiscoveryDeviceUpdate(ABC):
discovery_hash = get_discovery_hash(discovery_data) discovery_hash = get_discovery_hash(discovery_data)
self._remove_discovery_updated = async_dispatcher_connect( self._remove_discovery_updated = async_dispatcher_connect(
hass, hass,
MQTT_DISCOVERY_UPDATED.format(discovery_hash), MQTT_DISCOVERY_UPDATED.format(*discovery_hash),
self.async_discovery_update, self.async_discovery_update,
) )
config_entry.async_on_unload(self._entry_unload) config_entry.async_on_unload(self._entry_unload)
@ -1044,7 +1044,7 @@ class MqttDiscoveryUpdate(Entity):
set_discovery_hash(self.hass, discovery_hash) set_discovery_hash(self.hass, discovery_hash)
self._remove_discovery_updated = async_dispatcher_connect( self._remove_discovery_updated = async_dispatcher_connect(
self.hass, self.hass,
MQTT_DISCOVERY_UPDATED.format(discovery_hash), MQTT_DISCOVERY_UPDATED.format(*discovery_hash),
discovery_callback, discovery_callback,
) )

View file

@ -15,7 +15,14 @@ from homeassistant.components.mqtt.abbreviations import (
ABBREVIATIONS, ABBREVIATIONS,
DEVICE_ABBREVIATIONS, DEVICE_ABBREVIATIONS,
) )
from homeassistant.components.mqtt.discovery import async_start from homeassistant.components.mqtt.discovery import (
MQTT_DISCOVERY_DONE,
MQTT_DISCOVERY_NEW,
MQTT_DISCOVERY_NEW_COMPONENT,
MQTT_DISCOVERY_UPDATED,
MQTTDiscoveryPayload,
async_start,
)
from homeassistant.const import ( from homeassistant.const import (
EVENT_STATE_CHANGED, EVENT_STATE_CHANGED,
STATE_ON, STATE_ON,
@ -26,8 +33,13 @@ from homeassistant.const import (
from homeassistant.core import Event, HomeAssistant, callback from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
async_dispatcher_send,
)
from homeassistant.helpers.service_info.mqtt import MqttServiceInfo from homeassistant.helpers.service_info.mqtt import MqttServiceInfo
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.util.signal_type import SignalTypeFormat
from .test_common import help_all_subscribe_calls, help_test_unload_config_entry from .test_common import help_all_subscribe_calls, help_test_unload_config_entry
@ -1765,3 +1777,34 @@ async def test_update_with_bad_config_not_breaks_discovery(
state = hass.states.get("sensor.sbfspot_12345") state = hass.states.get("sensor.sbfspot_12345")
assert state and state.state == "new_value" assert state and state.state == "new_value"
@pytest.mark.parametrize(
"signal_message",
[
MQTT_DISCOVERY_NEW,
MQTT_DISCOVERY_NEW_COMPONENT,
MQTT_DISCOVERY_UPDATED,
MQTT_DISCOVERY_DONE,
],
)
async def test_discovery_dispatcher_signal_type_messages(
hass: HomeAssistant, signal_message: SignalTypeFormat[MQTTDiscoveryPayload]
) -> None:
"""Test discovery dispatcher messages."""
domain_id_tuple = ("sensor", "very_unique")
test_data = {"name": "test", "state_topic": "test-topic"}
calls = []
def _callback(*args) -> None:
calls.append(*args)
unsub = async_dispatcher_connect(
hass, signal_message.format(*domain_id_tuple), _callback
)
async_dispatcher_send(hass, signal_message.format(*domain_id_tuple), test_data)
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0] == test_data
unsub()