Move MQTT discovery hass.data globals to dataclass (#78706)

* Add MQTT discovery hass.data globals to dataclass

* isort

* Additional rework

* Add hass.data["mqtt_tags"] to dataclass

* Follow-up comment

* Corrections
This commit is contained in:
Jan Bouwhuis 2022-09-28 14:13:44 +02:00 committed by GitHub
parent a38c125765
commit 84b2c74746
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 120 additions and 86 deletions

View file

@ -27,7 +27,14 @@ from homeassistant.const import (
EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STARTED,
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_STOP,
) )
from homeassistant.core import CoreState, Event, HassJob, HomeAssistant, callback from homeassistant.core import (
CALLBACK_TYPE,
CoreState,
Event,
HassJob,
HomeAssistant,
callback,
)
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.dispatcher import dispatcher_send from homeassistant.helpers.dispatcher import dispatcher_send
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
@ -178,7 +185,7 @@ async def async_subscribe(
| AsyncDeprecatedMessageCallbackType, | AsyncDeprecatedMessageCallbackType,
qos: int = DEFAULT_QOS, qos: int = DEFAULT_QOS,
encoding: str | None = DEFAULT_ENCODING, encoding: str | None = DEFAULT_ENCODING,
): ) -> CALLBACK_TYPE:
"""Subscribe to an MQTT topic. """Subscribe to an MQTT topic.
Call the return value to unsubscribe. Call the return value to unsubscribe.
@ -357,12 +364,12 @@ class MQTT:
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt)
) )
def cleanup(self): def cleanup(self) -> None:
"""Clean up listeners.""" """Clean up listeners."""
while self._cleanup_on_unload: while self._cleanup_on_unload:
self._cleanup_on_unload.pop()() self._cleanup_on_unload.pop()()
def init_client(self): def init_client(self) -> None:
"""Initialize paho client.""" """Initialize paho client."""
self._mqttc = MqttClientSetup(self.conf).client self._mqttc = MqttClientSetup(self.conf).client
self._mqttc.on_connect = self._mqtt_on_connect self._mqttc.on_connect = self._mqtt_on_connect
@ -429,10 +436,10 @@ class MQTT:
self._mqttc.loop_start() self._mqttc.loop_start()
async def async_disconnect(self): async def async_disconnect(self) -> None:
"""Stop the MQTT client.""" """Stop the MQTT client."""
def stop(): def stop() -> None:
"""Stop the MQTT client.""" """Stop the MQTT client."""
# Do not disconnect, we want the broker to always publish will # Do not disconnect, we want the broker to always publish will
self._mqttc.loop_stop() self._mqttc.loop_stop()

View file

@ -18,8 +18,9 @@ from homeassistant.const import (
CONF_PROTOCOL, CONF_PROTOCOL,
CONF_USERNAME, CONF_USERNAME,
) )
from homeassistant.core import HomeAssistant, callback from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.typing import ConfigType
from .client import MqttClientSetup from .client import MqttClientSetup
from .const import ( from .const import (
@ -73,7 +74,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
if user_input is not None: if user_input is not None:
can_connect = await self.hass.async_add_executor_job( can_connect = await self.hass.async_add_executor_job(
try_connection, try_connection,
self.hass, get_mqtt_data(self.hass, True).config or {},
user_input[CONF_BROKER], user_input[CONF_BROKER],
user_input[CONF_PORT], user_input[CONF_PORT],
user_input.get(CONF_USERNAME), user_input.get(CONF_USERNAME),
@ -117,7 +118,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
data = self._hassio_discovery data = self._hassio_discovery
can_connect = await self.hass.async_add_executor_job( can_connect = await self.hass.async_add_executor_job(
try_connection, try_connection,
self.hass, get_mqtt_data(self.hass, True).config or {},
data[CONF_HOST], data[CONF_HOST],
data[CONF_PORT], data[CONF_PORT],
data.get(CONF_USERNAME), data.get(CONF_USERNAME),
@ -164,13 +165,13 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
) -> FlowResult: ) -> FlowResult:
"""Manage the MQTT broker configuration.""" """Manage the MQTT broker configuration."""
mqtt_data = get_mqtt_data(self.hass, True) mqtt_data = get_mqtt_data(self.hass, True)
yaml_config = mqtt_data.config or {}
errors = {} errors = {}
current_config = self.config_entry.data current_config = self.config_entry.data
yaml_config = mqtt_data.config or {}
if user_input is not None: if user_input is not None:
can_connect = await self.hass.async_add_executor_job( can_connect = await self.hass.async_add_executor_job(
try_connection, try_connection,
self.hass, yaml_config,
user_input[CONF_BROKER], user_input[CONF_BROKER],
user_input[CONF_PORT], user_input[CONF_PORT],
user_input.get(CONF_USERNAME), user_input.get(CONF_USERNAME),
@ -338,7 +339,7 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
def try_connection( def try_connection(
hass: HomeAssistant, yaml_config: ConfigType,
broker: str, broker: str,
port: int, port: int,
username: str | None, username: str | None,
@ -351,8 +352,6 @@ def try_connection(
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
# Get the config from configuration.yaml # Get the config from configuration.yaml
mqtt_data = get_mqtt_data(hass, True)
yaml_config = mqtt_data.config or {}
entry_config = { entry_config = {
CONF_BROKER: broker, CONF_BROKER: broker,
CONF_PORT: port, CONF_PORT: port,

View file

@ -8,6 +8,7 @@ import logging
import re import re
import time import time
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM from homeassistant.const import CONF_DEVICE, CONF_PLATFORM
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import FlowResultType
@ -18,6 +19,7 @@ from homeassistant.helpers.dispatcher import (
) )
from homeassistant.helpers.json import json_loads from homeassistant.helpers.json import json_loads
from homeassistant.helpers.service_info.mqtt import MqttServiceInfo from homeassistant.helpers.service_info.mqtt import MqttServiceInfo
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.loader import async_get_mqtt from homeassistant.loader import async_get_mqtt
from .. import mqtt from .. import mqtt
@ -30,6 +32,7 @@ from .const import (
CONF_TOPIC, CONF_TOPIC,
DOMAIN, DOMAIN,
) )
from .models import ReceiveMessage
from .util import get_mqtt_data from .util import get_mqtt_data
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -62,11 +65,6 @@ SUPPORTED_COMPONENTS = [
"vacuum", "vacuum",
] ]
ALREADY_DISCOVERED = "mqtt_discovered_components"
PENDING_DISCOVERED = "mqtt_pending_components"
DATA_CONFIG_FLOW_LOCK = "mqtt_discovery_config_flow_lock"
DISCOVERY_UNSUBSCRIBE = "mqtt_discovery_unsubscribe"
INTEGRATION_UNSUBSCRIBE = "mqtt_integration_discovery_unsubscribe"
MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}" MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}"
MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}" MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}"
MQTT_DISCOVERY_DONE = "mqtt_discovery_done_{}" MQTT_DISCOVERY_DONE = "mqtt_discovery_done_{}"
@ -77,21 +75,21 @@ TOPIC_BASE = "~"
class MQTTConfig(dict): class MQTTConfig(dict):
"""Dummy class to allow adding attributes.""" """Dummy class to allow adding attributes."""
discovery_data: dict discovery_data: DiscoveryInfoType
def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) -> None: def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) -> None:
"""Clear entry in ALREADY_DISCOVERED list.""" """Clear entry from already discovered list."""
del hass.data[ALREADY_DISCOVERED][discovery_hash] get_mqtt_data(hass).discovery_already_discovered.remove(discovery_hash)
def set_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]): def set_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) -> None:
"""Clear entry in ALREADY_DISCOVERED list.""" """Add entry to already discovered list."""
hass.data[ALREADY_DISCOVERED][discovery_hash] = {} get_mqtt_data(hass).discovery_already_discovered.add(discovery_hash)
async def async_start( # noqa: C901 async def async_start( # noqa: C901
hass: HomeAssistant, discovery_topic, config_entry=None hass: HomeAssistant, discovery_topic: str, config_entry: ConfigEntry
) -> None: ) -> None:
"""Start MQTT Discovery.""" """Start MQTT Discovery."""
mqtt_data = get_mqtt_data(hass) mqtt_data = get_mqtt_data(hass)
@ -175,8 +173,8 @@ async def async_start( # noqa: C901
payload[CONF_PLATFORM] = "mqtt" payload[CONF_PLATFORM] = "mqtt"
if discovery_hash in hass.data[PENDING_DISCOVERED]: if discovery_hash in mqtt_data.discovery_pending_discovered:
pending = hass.data[PENDING_DISCOVERED][discovery_hash]["pending"] pending = mqtt_data.discovery_pending_discovered[discovery_hash]["pending"]
pending.appendleft(payload) pending.appendleft(payload)
_LOGGER.info( _LOGGER.info(
"Component has already been discovered: %s %s, queuing update", "Component has already been discovered: %s %s, queuing update",
@ -187,27 +185,31 @@ async def async_start( # noqa: C901
await async_process_discovery_payload(component, discovery_id, payload) await async_process_discovery_payload(component, discovery_id, payload)
async def async_process_discovery_payload(component, discovery_id, payload): async def async_process_discovery_payload(
component: str, discovery_id: str, payload: ConfigType
) -> None:
"""Process the payload of a new discovery.""" """Process the payload of a new discovery."""
_LOGGER.debug("Process discovery payload %s", payload) _LOGGER.debug("Process discovery payload %s", payload)
discovery_hash = (component, discovery_id) discovery_hash = (component, discovery_id)
if discovery_hash in hass.data[ALREADY_DISCOVERED] or payload: if discovery_hash in mqtt_data.discovery_already_discovered or payload:
async def discovery_done(_): async def discovery_done(_) -> None:
pending = hass.data[PENDING_DISCOVERED][discovery_hash]["pending"] pending = mqtt_data.discovery_pending_discovered[discovery_hash][
"pending"
]
_LOGGER.debug("Pending discovery for %s: %s", discovery_hash, pending) _LOGGER.debug("Pending discovery for %s: %s", discovery_hash, pending)
if not pending: if not pending:
hass.data[PENDING_DISCOVERED][discovery_hash]["unsub"]() mqtt_data.discovery_pending_discovered[discovery_hash]["unsub"]()
hass.data[PENDING_DISCOVERED].pop(discovery_hash) mqtt_data.discovery_pending_discovered.pop(discovery_hash)
else: else:
payload = pending.pop() payload = pending.pop()
await async_process_discovery_payload( await async_process_discovery_payload(
component, discovery_id, payload component, discovery_id, payload
) )
if discovery_hash not in hass.data[PENDING_DISCOVERED]: if discovery_hash not in mqtt_data.discovery_pending_discovered:
hass.data[PENDING_DISCOVERED][discovery_hash] = { mqtt_data.discovery_pending_discovered[discovery_hash] = {
"unsub": async_dispatcher_connect( "unsub": async_dispatcher_connect(
hass, hass,
MQTT_DISCOVERY_DONE.format(discovery_hash), MQTT_DISCOVERY_DONE.format(discovery_hash),
@ -216,7 +218,7 @@ async def async_start( # noqa: C901
"pending": deque([]), "pending": deque([]),
} }
if discovery_hash in hass.data[ALREADY_DISCOVERED]: if discovery_hash in mqtt_data.discovery_already_discovered:
# Dispatch update # Dispatch update
_LOGGER.info( _LOGGER.info(
"Component has already been discovered: %s %s, sending update", "Component has already been discovered: %s %s, sending update",
@ -229,7 +231,7 @@ async def async_start( # noqa: C901
elif payload: elif payload:
# Add component # Add component
_LOGGER.info("Found new component: %s %s", component, discovery_id) _LOGGER.info("Found new component: %s %s", component, discovery_id)
hass.data[ALREADY_DISCOVERED][discovery_hash] = None mqtt_data.discovery_already_discovered.add(discovery_hash)
async_dispatcher_send( async_dispatcher_send(
hass, MQTT_DISCOVERY_NEW.format(component, "mqtt"), payload hass, MQTT_DISCOVERY_NEW.format(component, "mqtt"), payload
) )
@ -239,15 +241,11 @@ async def async_start( # noqa: C901
hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None
) )
hass.data.setdefault(DATA_CONFIG_FLOW_LOCK, asyncio.Lock())
hass.data[ALREADY_DISCOVERED] = {}
hass.data[PENDING_DISCOVERED] = {}
discovery_topics = [ discovery_topics = [
f"{discovery_topic}/+/+/config", f"{discovery_topic}/+/+/config",
f"{discovery_topic}/+/+/+/config", f"{discovery_topic}/+/+/+/config",
] ]
hass.data[DISCOVERY_UNSUBSCRIBE] = await asyncio.gather( mqtt_data.discovery_unsubscribe = await asyncio.gather(
*( *(
mqtt.async_subscribe(hass, topic, async_discovery_message_received, 0) mqtt.async_subscribe(hass, topic, async_discovery_message_received, 0)
for topic in discovery_topics for topic in discovery_topics
@ -257,19 +255,20 @@ async def async_start( # noqa: C901
mqtt_data.last_discovery = time.time() mqtt_data.last_discovery = time.time()
mqtt_integrations = await async_get_mqtt(hass) mqtt_integrations = await async_get_mqtt(hass)
hass.data[INTEGRATION_UNSUBSCRIBE] = {}
for (integration, topics) in mqtt_integrations.items(): for (integration, topics) in mqtt_integrations.items():
async def async_integration_message_received(integration, msg): async def async_integration_message_received(
integration: str, msg: ReceiveMessage
) -> None:
"""Process the received message.""" """Process the received message."""
assert mqtt_data.data_config_flow_lock
key = f"{integration}_{msg.subscribed_topic}" key = f"{integration}_{msg.subscribed_topic}"
# Lock to prevent initiating many parallel config flows. # Lock to prevent initiating many parallel config flows.
# Note: The lock is not intended to prevent a race, only for performance # Note: The lock is not intended to prevent a race, only for performance
async with hass.data[DATA_CONFIG_FLOW_LOCK]: async with mqtt_data.data_config_flow_lock:
# Already unsubscribed # Already unsubscribed
if key not in hass.data[INTEGRATION_UNSUBSCRIBE]: if key not in mqtt_data.integration_unsubscribe:
return return
data = MqttServiceInfo( data = MqttServiceInfo(
@ -289,14 +288,14 @@ async def async_start( # noqa: C901
and result["reason"] and result["reason"]
in ("already_configured", "single_instance_allowed") in ("already_configured", "single_instance_allowed")
): ):
unsub = hass.data[INTEGRATION_UNSUBSCRIBE].pop(key, None) unsub = mqtt_data.integration_unsubscribe.pop(key, None)
if unsub is None: if unsub is None:
return return
unsub() unsub()
for topic in topics: for topic in topics:
key = f"{integration}_{topic}" key = f"{integration}_{topic}"
hass.data[INTEGRATION_UNSUBSCRIBE][key] = await mqtt.async_subscribe( mqtt_data.integration_unsubscribe[key] = await mqtt.async_subscribe(
hass, hass,
topic, topic,
functools.partial(async_integration_message_received, integration), functools.partial(async_integration_message_received, integration),
@ -306,11 +305,10 @@ async def async_start( # noqa: C901
async def async_stop(hass: HomeAssistant) -> None: async def async_stop(hass: HomeAssistant) -> None:
"""Stop MQTT Discovery.""" """Stop MQTT Discovery."""
if DISCOVERY_UNSUBSCRIBE in hass.data: mqtt_data = get_mqtt_data(hass)
for unsub in hass.data[DISCOVERY_UNSUBSCRIBE]: for unsub in mqtt_data.discovery_unsubscribe:
unsub() unsub()
hass.data[DISCOVERY_UNSUBSCRIBE] = [] mqtt_data.discovery_unsubscribe = []
if INTEGRATION_UNSUBSCRIBE in hass.data: for key, unsub in list(mqtt_data.integration_unsubscribe.items()):
for key, unsub in list(hass.data[INTEGRATION_UNSUBSCRIBE].items()):
unsub() unsub()
hass.data[INTEGRATION_UNSUBSCRIBE].pop(key) mqtt_data.integration_unsubscribe.pop(key)

View file

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from ast import literal_eval from ast import literal_eval
import asyncio
from collections import deque from collections import deque
from collections.abc import Callable, Coroutine from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field from dataclasses import dataclass, field
@ -22,6 +23,8 @@ if TYPE_CHECKING:
from .client import MQTT, Subscription from .client import MQTT, Subscription
from .debug_info import TimestampedPublishMessage from .debug_info import TimestampedPublishMessage
from .device_trigger import Trigger from .device_trigger import Trigger
from .discovery import MQTTConfig
from .tag import MQTTTagScanner
_SENTINEL = object() _SENTINEL = object()
@ -80,6 +83,13 @@ class TriggerDebugInfo(TypedDict):
discovery_data: DiscoveryInfoType discovery_data: DiscoveryInfoType
class PendingDiscovered(TypedDict):
"""Pending discovered items."""
pending: deque[MQTTConfig]
unsub: CALLBACK_TYPE
class MqttCommandTemplate: class MqttCommandTemplate:
"""Class for rendering MQTT payload with command templates.""" """Class for rendering MQTT payload with command templates."""
@ -237,9 +247,16 @@ class MqttData:
default_factory=dict default_factory=dict
) )
device_triggers: dict[str, Trigger] = field(default_factory=dict) device_triggers: dict[str, Trigger] = field(default_factory=dict)
data_config_flow_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
discovery_already_discovered: set[tuple[str, str]] = field(default_factory=set)
discovery_pending_discovered: dict[tuple[str, str], PendingDiscovered] = field(
default_factory=dict
)
discovery_registry_hooks: dict[tuple[str, str], CALLBACK_TYPE] = field( discovery_registry_hooks: dict[tuple[str, str], CALLBACK_TYPE] = field(
default_factory=dict default_factory=dict
) )
discovery_unsubscribe: list[CALLBACK_TYPE] = field(default_factory=list)
integration_unsubscribe: dict[str, CALLBACK_TYPE] = field(default_factory=dict)
last_discovery: float = 0.0 last_discovery: float = 0.0
reload_dispatchers: list[CALLBACK_TYPE] = field(default_factory=list) reload_dispatchers: list[CALLBACK_TYPE] = field(default_factory=list)
reload_entry: bool = False reload_entry: bool = False
@ -248,4 +265,5 @@ class MqttData:
) )
reload_needed: bool = False reload_needed: bool = False
subscriptions_to_restore: list[Subscription] = field(default_factory=list) subscriptions_to_restore: list[Subscription] = field(default_factory=list)
tags: dict[str, dict[str, MQTTTagScanner]] = field(default_factory=dict)
updated_config: ConfigType = field(default_factory=dict) updated_config: ConfigType = field(default_factory=dict)

View file

@ -23,12 +23,11 @@ from .mixins import (
) )
from .models import MqttValueTemplate, ReceiveMessage from .models import MqttValueTemplate, ReceiveMessage
from .subscription import EntitySubscription from .subscription import EntitySubscription
from .util import valid_subscribe_topic from .util import get_mqtt_data, valid_subscribe_topic
LOG_NAME = "Tag" LOG_NAME = "Tag"
TAG = "tag" TAG = "tag"
TAGS = "mqtt_tags"
PLATFORM_SCHEMA = MQTT_BASE_SCHEMA.extend( PLATFORM_SCHEMA = MQTT_BASE_SCHEMA.extend(
{ {
@ -59,9 +58,8 @@ async def _async_setup_tag(
discovery_id = discovery_hash[1] discovery_id = discovery_hash[1]
device_id = update_device(hass, config_entry, config) device_id = update_device(hass, config_entry, config)
hass.data.setdefault(TAGS, {}) if device_id is not None and device_id not in (tags := get_mqtt_data(hass).tags):
if device_id not in hass.data[TAGS]: tags[device_id] = {}
hass.data[TAGS][device_id] = {}
tag_scanner = MQTTTagScanner( tag_scanner = MQTTTagScanner(
hass, hass,
@ -74,16 +72,16 @@ async def _async_setup_tag(
await tag_scanner.subscribe_topics() await tag_scanner.subscribe_topics()
if device_id: if device_id:
hass.data[TAGS][device_id][discovery_id] = tag_scanner tags[device_id][discovery_id] = tag_scanner
send_discovery_done(hass, discovery_data) send_discovery_done(hass, discovery_data)
def async_has_tags(hass: HomeAssistant, device_id: str) -> bool: def async_has_tags(hass: HomeAssistant, device_id: str) -> bool:
"""Device has tag scanners.""" """Device has tag scanners."""
if TAGS not in hass.data or device_id not in hass.data[TAGS]: if device_id not in (tags := get_mqtt_data(hass).tags):
return False return False
return hass.data[TAGS][device_id] != {} return tags[device_id] != {}
class MQTTTagScanner(MqttDiscoveryDeviceUpdate): class MQTTTagScanner(MqttDiscoveryDeviceUpdate):
@ -159,4 +157,4 @@ class MQTTTagScanner(MqttDiscoveryDeviceUpdate):
self.hass, self._sub_state self.hass, self._sub_state
) )
if self.device_id: if self.device_id:
self.hass.data[TAGS][self.device_id].pop(discovery_id) get_mqtt_data(self.hass).tags[self.device_id].pop(discovery_id)

View file

@ -133,9 +133,7 @@ class MQTTRoomSensor(SensorEntity):
): ):
update_state(**device) update_state(**device)
return await mqtt.async_subscribe( await mqtt.async_subscribe(self.hass, self._state_topic, message_received, 1)
self.hass, self._state_topic, message_received, 1
)
@property @property
def name(self): def name(self):

View file

@ -186,7 +186,18 @@ async def test_manual_config_set(
"discovery": True, "discovery": True,
} }
# Check we tried the connection, with precedence for config entry settings # Check we tried the connection, with precedence for config entry settings
mock_try_connection.assert_called_once_with(hass, "127.0.0.1", 1883, None, None) mock_try_connection.assert_called_once_with(
{
"broker": "bla",
"keepalive": 60,
"discovery_prefix": "homeassistant",
"protocol": "3.1.1",
},
"127.0.0.1",
1883,
None,
None,
)
# Check config entry got setup # Check config entry got setup
assert len(mock_finish_setup.mock_calls) == 1 assert len(mock_finish_setup.mock_calls) == 1
config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0] config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0]

View file

@ -6,7 +6,6 @@ import pytest
from homeassistant.components import device_tracker, mqtt from homeassistant.components import device_tracker, mqtt
from homeassistant.components.mqtt.const import DOMAIN as MQTT_DOMAIN from homeassistant.components.mqtt.const import DOMAIN as MQTT_DOMAIN
from homeassistant.components.mqtt.discovery import ALREADY_DISCOVERED
from homeassistant.const import STATE_HOME, STATE_NOT_HOME, STATE_UNKNOWN, Platform from homeassistant.const import STATE_HOME, STATE_NOT_HOME, STATE_UNKNOWN, Platform
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -60,7 +59,7 @@ async def test_discover_device_tracker(hass, mqtt_mock_entry_no_yaml_config, cap
assert state is not None assert state is not None
assert state.name == "test" assert state.name == "test"
assert ("device_tracker", "bla") in hass.data[ALREADY_DISCOVERED] assert ("device_tracker", "bla") in hass.data["mqtt"].discovery_already_discovered
@pytest.mark.no_fail_on_log_exception @pytest.mark.no_fail_on_log_exception

View file

@ -13,7 +13,7 @@ from homeassistant.components.mqtt.abbreviations import (
ABBREVIATIONS, ABBREVIATIONS,
DEVICE_ABBREVIATIONS, DEVICE_ABBREVIATIONS,
) )
from homeassistant.components.mqtt.discovery import ALREADY_DISCOVERED, async_start from homeassistant.components.mqtt.discovery import async_start
from homeassistant.const import ( from homeassistant.const import (
EVENT_STATE_CHANGED, EVENT_STATE_CHANGED,
STATE_ON, STATE_ON,
@ -152,7 +152,7 @@ async def test_correct_config_discovery(hass, mqtt_mock_entry_no_yaml_config, ca
assert state is not None assert state is not None
assert state.name == "Beer" assert state.name == "Beer"
assert ("binary_sensor", "bla") in hass.data[ALREADY_DISCOVERED] assert ("binary_sensor", "bla") in hass.data["mqtt"].discovery_already_discovered
@patch("homeassistant.components.mqtt.PLATFORMS", [Platform.FAN]) @patch("homeassistant.components.mqtt.PLATFORMS", [Platform.FAN])
@ -170,7 +170,7 @@ async def test_discover_fan(hass, mqtt_mock_entry_no_yaml_config, caplog):
assert state is not None assert state is not None
assert state.name == "Beer" assert state.name == "Beer"
assert ("fan", "bla") in hass.data[ALREADY_DISCOVERED] assert ("fan", "bla") in hass.data["mqtt"].discovery_already_discovered
@patch("homeassistant.components.mqtt.PLATFORMS", [Platform.CLIMATE]) @patch("homeassistant.components.mqtt.PLATFORMS", [Platform.CLIMATE])
@ -190,7 +190,7 @@ async def test_discover_climate(hass, mqtt_mock_entry_no_yaml_config, caplog):
assert state is not None assert state is not None
assert state.name == "ClimateTest" assert state.name == "ClimateTest"
assert ("climate", "bla") in hass.data[ALREADY_DISCOVERED] assert ("climate", "bla") in hass.data["mqtt"].discovery_already_discovered
@patch("homeassistant.components.mqtt.PLATFORMS", [Platform.ALARM_CONTROL_PANEL]) @patch("homeassistant.components.mqtt.PLATFORMS", [Platform.ALARM_CONTROL_PANEL])
@ -212,7 +212,9 @@ async def test_discover_alarm_control_panel(
assert state is not None assert state is not None
assert state.name == "AlarmControlPanelTest" assert state.name == "AlarmControlPanelTest"
assert ("alarm_control_panel", "bla") in hass.data[ALREADY_DISCOVERED] assert ("alarm_control_panel", "bla") in hass.data[
"mqtt"
].discovery_already_discovered
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -372,7 +374,7 @@ async def test_discovery_with_object_id(
assert state is not None assert state is not None
assert state.name == name assert state.name == name
assert (domain, "object bla") in hass.data[ALREADY_DISCOVERED] assert (domain, "object bla") in hass.data["mqtt"].discovery_already_discovered
@patch("homeassistant.components.mqtt.PLATFORMS", [Platform.BINARY_SENSOR]) @patch("homeassistant.components.mqtt.PLATFORMS", [Platform.BINARY_SENSOR])
@ -390,7 +392,9 @@ async def test_discovery_incl_nodeid(hass, mqtt_mock_entry_no_yaml_config, caplo
assert state is not None assert state is not None
assert state.name == "Beer" assert state.name == "Beer"
assert ("binary_sensor", "my_node_id bla") in hass.data[ALREADY_DISCOVERED] assert ("binary_sensor", "my_node_id bla") in hass.data[
"mqtt"
].discovery_already_discovered
@patch("homeassistant.components.mqtt.PLATFORMS", [Platform.BINARY_SENSOR]) @patch("homeassistant.components.mqtt.PLATFORMS", [Platform.BINARY_SENSOR])
@ -970,7 +974,7 @@ async def test_discovery_expansion(hass, mqtt_mock_entry_no_yaml_config, caplog)
state = hass.states.get("switch.DiscoveryExpansionTest1") state = hass.states.get("switch.DiscoveryExpansionTest1")
assert state is not None assert state is not None
assert state.name == "DiscoveryExpansionTest1" assert state.name == "DiscoveryExpansionTest1"
assert ("switch", "bla") in hass.data[ALREADY_DISCOVERED] assert ("switch", "bla") in hass.data["mqtt"].discovery_already_discovered
assert state.state == STATE_UNKNOWN assert state.state == STATE_UNKNOWN
async_fire_mqtt_message(hass, "test_topic/some/base/topic", "ON") async_fire_mqtt_message(hass, "test_topic/some/base/topic", "ON")
@ -1023,7 +1027,7 @@ async def test_discovery_expansion_2(hass, mqtt_mock_entry_no_yaml_config, caplo
state = hass.states.get("switch.DiscoveryExpansionTest1") state = hass.states.get("switch.DiscoveryExpansionTest1")
assert state is not None assert state is not None
assert state.name == "DiscoveryExpansionTest1" assert state.name == "DiscoveryExpansionTest1"
assert ("switch", "bla") in hass.data[ALREADY_DISCOVERED] assert ("switch", "bla") in hass.data["mqtt"].discovery_already_discovered
assert state.state == STATE_UNKNOWN assert state.state == STATE_UNKNOWN
@ -1102,7 +1106,7 @@ async def test_discovery_expansion_without_encoding_and_value_template_1(
state = hass.states.get("switch.DiscoveryExpansionTest1") state = hass.states.get("switch.DiscoveryExpansionTest1")
assert state is not None assert state is not None
assert state.name == "DiscoveryExpansionTest1" assert state.name == "DiscoveryExpansionTest1"
assert ("switch", "bla") in hass.data[ALREADY_DISCOVERED] assert ("switch", "bla") in hass.data["mqtt"].discovery_already_discovered
assert state.state == STATE_UNKNOWN assert state.state == STATE_UNKNOWN
async_fire_mqtt_message(hass, "some/base/topic/avail_item1", b"\x00") async_fire_mqtt_message(hass, "some/base/topic/avail_item1", b"\x00")
@ -1151,7 +1155,7 @@ async def test_discovery_expansion_without_encoding_and_value_template_2(
state = hass.states.get("switch.DiscoveryExpansionTest1") state = hass.states.get("switch.DiscoveryExpansionTest1")
assert state is not None assert state is not None
assert state.name == "DiscoveryExpansionTest1" assert state.name == "DiscoveryExpansionTest1"
assert ("switch", "bla") in hass.data[ALREADY_DISCOVERED] assert ("switch", "bla") in hass.data["mqtt"].discovery_already_discovered
assert state.state == STATE_UNKNOWN assert state.state == STATE_UNKNOWN
async_fire_mqtt_message(hass, "some/base/topic/avail_item1", b"\x00") async_fire_mqtt_message(hass, "some/base/topic/avail_item1", b"\x00")
@ -1236,7 +1240,7 @@ async def test_no_implicit_state_topic_switch(
state = hass.states.get("switch.Test1") state = hass.states.get("switch.Test1")
assert state is not None assert state is not None
assert state.name == "Test1" assert state.name == "Test1"
assert ("switch", "bla") in hass.data[ALREADY_DISCOVERED] assert ("switch", "bla") in hass.data["mqtt"].discovery_already_discovered
assert state.state == STATE_UNKNOWN assert state.state == STATE_UNKNOWN
assert state.attributes["assumed_state"] is True assert state.attributes["assumed_state"] is True
@ -1280,7 +1284,9 @@ async def test_complex_discovery_topic_prefix(
assert state is not None assert state is not None
assert state.name == "Beer" assert state.name == "Beer"
assert ("binary_sensor", "node1 object1") in hass.data[ALREADY_DISCOVERED] assert ("binary_sensor", "node1 object1") in hass.data[
"mqtt"
].discovery_already_discovered
@patch("homeassistant.components.mqtt.PLATFORMS", []) @patch("homeassistant.components.mqtt.PLATFORMS", [])