Refactor MQTT to replace get_mqtt_data with HassKey (#117899)

This commit is contained in:
J. Nick Koston 2024-05-21 23:21:51 -10:00 committed by GitHub
parent b4d0562063
commit 4e3c4400a7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 74 additions and 74 deletions

View file

@ -65,8 +65,6 @@ from .const import ( # noqa: F401
CONF_WILL_MESSAGE,
CONF_WS_HEADERS,
CONF_WS_PATH,
DATA_MQTT,
DATA_MQTT_AVAILABLE,
DEFAULT_DISCOVERY,
DEFAULT_ENCODING,
DEFAULT_PREFIX,
@ -79,6 +77,8 @@ from .const import ( # noqa: F401
TEMPLATE_ERRORS,
)
from .models import ( # noqa: F401
DATA_MQTT,
DATA_MQTT_AVAILABLE,
MqttCommandTemplate,
MqttData,
MqttValueTemplate,
@ -97,7 +97,6 @@ from .util import ( # noqa: F401
async_create_certificate_temp_files,
async_forward_entry_setup_and_setup_discovery,
async_wait_for_mqtt_client,
get_mqtt_data,
mqtt_config_entry_enabled,
platforms_from_config,
valid_publish_topic,
@ -194,7 +193,7 @@ async def async_check_config_schema(
hass: HomeAssistant, config_yaml: ConfigType
) -> None:
"""Validate manually configured MQTT items."""
mqtt_data = get_mqtt_data(hass)
mqtt_data = hass.data[DATA_MQTT]
mqtt_config: list[dict[str, list[ConfigType]]] = config_yaml.get(DOMAIN, {})
for mqtt_config_item in mqtt_config:
for domain, config_items in mqtt_config_item.items():
@ -233,7 +232,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
await async_create_certificate_temp_files(hass, conf)
client = MQTT(hass, entry, conf)
if DOMAIN in hass.data:
mqtt_data = get_mqtt_data(hass)
mqtt_data = hass.data[DATA_MQTT]
mqtt_data.config = mqtt_yaml
mqtt_data.client = client
else:
@ -241,7 +240,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
websocket_api.async_register_command(hass, websocket_subscribe)
websocket_api.async_register_command(hass, websocket_mqtt_info)
hass.data[DATA_MQTT] = mqtt_data = MqttData(config=mqtt_yaml, client=client)
get_mqtt_data.cache_clear()
client.start(mqtt_data)
# Restore saved subscriptions
@ -503,7 +501,7 @@ def async_subscribe_connection_status(
def is_connected(hass: HomeAssistant) -> bool:
"""Return if MQTT client is connected."""
mqtt_data = get_mqtt_data(hass)
mqtt_data = hass.data[DATA_MQTT]
return mqtt_data.client.connected
@ -520,7 +518,7 @@ async def async_remove_config_entry_device(
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload MQTT dump and publish service when the config entry is unloaded."""
mqtt_data = get_mqtt_data(hass)
mqtt_data = hass.data[DATA_MQTT]
mqtt_client = mqtt_data.client
# Unload publish and dump services.

View file

@ -66,6 +66,7 @@ from .const import (
TRANSPORT_WEBSOCKETS,
)
from .models import (
DATA_MQTT,
AsyncMessageCallbackType,
MessageCallbackType,
MqttData,
@ -73,7 +74,7 @@ from .models import (
PublishPayloadType,
ReceiveMessage,
)
from .util import get_file_path, get_mqtt_data, mqtt_config_entry_enabled
from .util import get_file_path, mqtt_config_entry_enabled
if TYPE_CHECKING:
# Only import for paho-mqtt type checking here, imports are done locally
@ -132,7 +133,7 @@ async def async_publish(
translation_domain=DOMAIN,
translation_placeholders={"topic": topic},
)
mqtt_data = get_mqtt_data(hass)
mqtt_data = hass.data[DATA_MQTT]
outgoing_payload = payload
if not isinstance(payload, bytes):
if not encoding:
@ -186,7 +187,7 @@ async def async_subscribe(
translation_placeholders={"topic": topic},
)
try:
mqtt_data = get_mqtt_data(hass)
mqtt_data = hass.data[DATA_MQTT]
except KeyError as exc:
raise HomeAssistantError(
f"Cannot subscribe to topic '{topic}', "

View file

@ -86,9 +86,6 @@ CONF_CONFIGURATION_URL = "configuration_url"
CONF_OBJECT_ID = "object_id"
CONF_SUPPORT_URL = "support_url"
DATA_MQTT = "mqtt"
DATA_MQTT_AVAILABLE = "mqtt_client_available"
DEFAULT_PREFIX = "homeassistant"
DEFAULT_BIRTH_WILL_TOPIC = DEFAULT_PREFIX + "/status"
DEFAULT_DISCOVERY = True

View file

@ -16,8 +16,7 @@ from homeassistant.helpers.typing import DiscoveryInfoType
from homeassistant.util import dt as dt_util
from .const import ATTR_DISCOVERY_PAYLOAD, ATTR_DISCOVERY_TOPIC
from .models import MessageCallbackType, PublishPayloadType
from .util import get_mqtt_data
from .models import DATA_MQTT, MessageCallbackType, PublishPayloadType
STORED_MESSAGES = 10
@ -27,7 +26,7 @@ def log_messages(
) -> Callable[[MessageCallbackType], MessageCallbackType]:
"""Wrap an MQTT message callback to support message logging."""
debug_info_entities = get_mqtt_data(hass).debug_info_entities
debug_info_entities = hass.data[DATA_MQTT].debug_info_entities
def _log_message(msg: Any) -> None:
"""Log message."""
@ -70,7 +69,7 @@ def log_message(
retain: bool,
) -> None:
"""Log an outgoing MQTT message."""
entity_info = get_mqtt_data(hass).debug_info_entities.setdefault(
entity_info = hass.data[DATA_MQTT].debug_info_entities.setdefault(
entity_id, {"subscriptions": {}, "discovery_data": {}, "transmitted": {}}
)
if topic not in entity_info["transmitted"]:
@ -90,7 +89,7 @@ def add_subscription(
) -> None:
"""Prepare debug data for subscription."""
if entity_id := getattr(message_callback, "__entity_id", None):
entity_info = get_mqtt_data(hass).debug_info_entities.setdefault(
entity_info = hass.data[DATA_MQTT].debug_info_entities.setdefault(
entity_id, {"subscriptions": {}, "discovery_data": {}, "transmitted": {}}
)
if subscription not in entity_info["subscriptions"]:
@ -108,7 +107,7 @@ def remove_subscription(
) -> None:
"""Remove debug data for subscription if it exists."""
if (entity_id := getattr(message_callback, "__entity_id", None)) and entity_id in (
debug_info_entities := get_mqtt_data(hass).debug_info_entities
debug_info_entities := hass.data[DATA_MQTT].debug_info_entities
):
debug_info_entities[entity_id]["subscriptions"][subscription]["count"] -= 1
if not debug_info_entities[entity_id]["subscriptions"][subscription]["count"]:
@ -119,7 +118,7 @@ def add_entity_discovery_data(
hass: HomeAssistant, discovery_data: DiscoveryInfoType, entity_id: str
) -> None:
"""Add discovery data."""
entity_info = get_mqtt_data(hass).debug_info_entities.setdefault(
entity_info = hass.data[DATA_MQTT].debug_info_entities.setdefault(
entity_id, {"subscriptions": {}, "discovery_data": {}, "transmitted": {}}
)
entity_info["discovery_data"] = discovery_data
@ -129,7 +128,7 @@ def update_entity_discovery_data(
hass: HomeAssistant, discovery_payload: DiscoveryInfoType, entity_id: str
) -> None:
"""Update discovery data."""
discovery_data = get_mqtt_data(hass).debug_info_entities[entity_id][
discovery_data = hass.data[DATA_MQTT].debug_info_entities[entity_id][
"discovery_data"
]
if TYPE_CHECKING:
@ -139,7 +138,7 @@ def update_entity_discovery_data(
def remove_entity_data(hass: HomeAssistant, entity_id: str) -> None:
"""Remove discovery data."""
if entity_id in (debug_info_entities := get_mqtt_data(hass).debug_info_entities):
if entity_id in (debug_info_entities := hass.data[DATA_MQTT].debug_info_entities):
debug_info_entities.pop(entity_id)
@ -150,7 +149,7 @@ def add_trigger_discovery_data(
device_id: str,
) -> None:
"""Add discovery data."""
get_mqtt_data(hass).debug_info_triggers[discovery_hash] = {
hass.data[DATA_MQTT].debug_info_triggers[discovery_hash] = {
"device_id": device_id,
"discovery_data": discovery_data,
}
@ -162,7 +161,7 @@ def update_trigger_discovery_data(
discovery_payload: DiscoveryInfoType,
) -> None:
"""Update discovery data."""
get_mqtt_data(hass).debug_info_triggers[discovery_hash]["discovery_data"][
hass.data[DATA_MQTT].debug_info_triggers[discovery_hash]["discovery_data"][
ATTR_DISCOVERY_PAYLOAD
] = discovery_payload
@ -171,11 +170,11 @@ def remove_trigger_discovery_data(
hass: HomeAssistant, discovery_hash: tuple[str, str]
) -> None:
"""Remove discovery data."""
get_mqtt_data(hass).debug_info_triggers.pop(discovery_hash)
hass.data[DATA_MQTT].debug_info_triggers.pop(discovery_hash)
def _info_for_entity(hass: HomeAssistant, entity_id: str) -> dict[str, Any]:
entity_info = get_mqtt_data(hass).debug_info_entities[entity_id]
entity_info = hass.data[DATA_MQTT].debug_info_entities[entity_id]
monotonic_time_diff = time.time() - time.monotonic()
subscriptions = [
{
@ -231,7 +230,7 @@ def _info_for_entity(hass: HomeAssistant, entity_id: str) -> dict[str, Any]:
def _info_for_trigger(
hass: HomeAssistant, trigger_key: tuple[str, str]
) -> dict[str, Any]:
trigger = get_mqtt_data(hass).debug_info_triggers[trigger_key]
trigger = hass.data[DATA_MQTT].debug_info_triggers[trigger_key]
discovery_data = None
if trigger["discovery_data"] is not None:
discovery_data = {
@ -244,7 +243,7 @@ def _info_for_trigger(
def info_for_config_entry(hass: HomeAssistant) -> dict[str, list[Any]]:
"""Get debug info for all entities and triggers."""
mqtt_data = get_mqtt_data(hass)
mqtt_data = hass.data[DATA_MQTT]
mqtt_info: dict[str, list[Any]] = {"entities": [], "triggers": []}
mqtt_info["entities"].extend(
@ -262,7 +261,7 @@ def info_for_config_entry(hass: HomeAssistant) -> dict[str, list[Any]]:
def info_for_device(hass: HomeAssistant, device_id: str) -> dict[str, list[Any]]:
"""Get debug info for a device."""
mqtt_data = get_mqtt_data(hass)
mqtt_data = hass.data[DATA_MQTT]
mqtt_info: dict[str, list[Any]] = {"entities": [], "triggers": []}
entity_registry = er.async_get(hass)

View file

@ -42,7 +42,7 @@ from .mixins import (
send_discovery_done,
update_device,
)
from .util import get_mqtt_data
from .models import DATA_MQTT
_LOGGER = logging.getLogger(__name__)
@ -206,7 +206,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
self.device_id = device_id
self.discovery_data = discovery_data
self.hass = hass
self._mqtt_data = get_mqtt_data(hass)
self._mqtt_data = hass.data[DATA_MQTT]
self.trigger_id = f"{device_id}_{config[CONF_TYPE]}_{config[CONF_SUBTYPE]}"
MqttDiscoveryDeviceUpdate.__init__(
@ -259,7 +259,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
config = TRIGGER_DISCOVERY_SCHEMA(discovery_data)
new_trigger_id = f"{self.device_id}_{config[CONF_TYPE]}_{config[CONF_SUBTYPE]}"
if new_trigger_id != self.trigger_id:
mqtt_data = get_mqtt_data(self.hass)
mqtt_data = self.hass.data[DATA_MQTT]
if new_trigger_id in mqtt_data.device_triggers:
_LOGGER.error(
"Cannot update device trigger %s due to an existing duplicate "
@ -308,7 +308,7 @@ async def async_setup_trigger(
trigger_type = config[CONF_TYPE]
trigger_subtype = config[CONF_SUBTYPE]
trigger_id = f"{device_id}_{trigger_type}_{trigger_subtype}"
mqtt_data = get_mqtt_data(hass)
mqtt_data = hass.data[DATA_MQTT]
if (
trigger_id in mqtt_data.device_triggers
and mqtt_data.device_triggers[trigger_id].discovery_data is not None
@ -334,7 +334,7 @@ async def async_setup_trigger(
async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None:
"""Handle Mqtt removed from a device."""
mqtt_data = get_mqtt_data(hass)
mqtt_data = hass.data[DATA_MQTT]
triggers = await async_get_triggers(hass, device_id)
for trig in triggers:
trigger_id = f"{device_id}_{trig[CONF_TYPE]}_{trig[CONF_SUBTYPE]}"
@ -352,7 +352,7 @@ async def async_get_triggers(
hass: HomeAssistant, device_id: str
) -> list[dict[str, str]]:
"""List device triggers for MQTT devices."""
mqtt_data = get_mqtt_data(hass)
mqtt_data = hass.data[DATA_MQTT]
if not mqtt_data.device_triggers:
return []
@ -377,7 +377,7 @@ async def async_attach_trigger(
) -> CALLBACK_TYPE:
"""Attach a trigger."""
trigger_id: str | None = None
mqtt_data = get_mqtt_data(hass)
mqtt_data = hass.data[DATA_MQTT]
device_id = config[CONF_DEVICE_ID]
# The use of CONF_DISCOVERY_ID was deprecated in HA Core 2024.2.

View file

@ -18,7 +18,7 @@ from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.device_registry import DeviceEntry
from . import debug_info, is_connected
from .util import get_mqtt_data
from .models import DATA_MQTT
REDACT_CONFIG = {CONF_PASSWORD, CONF_USERNAME}
REDACT_STATE_DEVICE_TRACKER = {ATTR_LATITUDE, ATTR_LONGITUDE}
@ -45,7 +45,7 @@ def _async_get_diagnostics(
device: DeviceEntry | None = None,
) -> dict[str, Any]:
"""Return diagnostics for a config entry."""
mqtt_instance = get_mqtt_data(hass).client
mqtt_instance = hass.data[DATA_MQTT].client
if TYPE_CHECKING:
assert mqtt_instance is not None

View file

@ -40,8 +40,8 @@ from .const import (
CONF_TOPIC,
DOMAIN,
)
from .models import MqttOriginInfo, ReceiveMessage
from .util import async_forward_entry_setup_and_setup_discovery, get_mqtt_data
from .models import DATA_MQTT, MqttOriginInfo, ReceiveMessage
from .util import async_forward_entry_setup_and_setup_discovery
_LOGGER = logging.getLogger(__name__)
@ -113,12 +113,12 @@ class MQTTDiscoveryPayload(dict[str, Any]):
def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) -> None:
"""Clear entry from already discovered list."""
get_mqtt_data(hass).discovery_already_discovered.remove(discovery_hash)
hass.data[DATA_MQTT].discovery_already_discovered.remove(discovery_hash)
def set_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) -> None:
"""Add entry to already discovered list."""
get_mqtt_data(hass).discovery_already_discovered.add(discovery_hash)
hass.data[DATA_MQTT].discovery_already_discovered.add(discovery_hash)
@callback
@ -150,7 +150,7 @@ async def async_start( # noqa: C901
hass: HomeAssistant, discovery_topic: str, config_entry: ConfigEntry
) -> None:
"""Start MQTT Discovery."""
mqtt_data = get_mqtt_data(hass)
mqtt_data = hass.data[DATA_MQTT]
platform_setup_lock: dict[str, asyncio.Lock] = {}
async def _async_component_setup(discovery_payload: MQTTDiscoveryPayload) -> None:
@ -426,7 +426,7 @@ async def async_start( # noqa: C901
async def async_stop(hass: HomeAssistant) -> None:
"""Stop MQTT Discovery."""
mqtt_data = get_mqtt_data(hass)
mqtt_data = hass.data[DATA_MQTT]
for unsub in mqtt_data.discovery_unsubscribe:
unsub()
mqtt_data.discovery_unsubscribe = []

View file

@ -38,13 +38,13 @@ from .mixins import (
async_setup_entity_entry_helper,
)
from .models import (
DATA_MQTT,
MqttValueTemplate,
MqttValueTemplateException,
PayloadSentinel,
ReceiveMessage,
ReceivePayloadType,
)
from .util import get_mqtt_data
_LOGGER = logging.getLogger(__name__)
@ -194,7 +194,7 @@ class MqttEvent(MqttEntity, EventEntity):
payload,
)
return
mqtt_data = get_mqtt_data(self.hass)
mqtt_data = self.hass.data[DATA_MQTT]
mqtt_data.state_write_requests.write_state_request(self)
topics["state_topic"] = {

View file

@ -33,12 +33,13 @@ from .mixins import (
async_setup_entity_entry_helper,
)
from .models import (
DATA_MQTT,
MessageCallbackType,
MqttValueTemplate,
MqttValueTemplateException,
ReceiveMessage,
)
from .util import get_mqtt_data, valid_subscribe_topic
from .util import valid_subscribe_topic
_LOGGER = logging.getLogger(__name__)
@ -186,7 +187,7 @@ class MqttImage(MqttEntity, ImageEntity):
)
self._last_image = None
self._attr_image_last_updated = dt_util.utcnow()
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
add_subscribe_topic(CONF_IMAGE_TOPIC, image_data_received)
@ -208,7 +209,7 @@ class MqttImage(MqttEntity, ImageEntity):
)
self._attr_image_last_updated = dt_util.utcnow()
self._cached_image = None
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
add_subscribe_topic(CONF_URL_TOPIC, image_from_url_request_received)

View file

@ -106,6 +106,7 @@ from .discovery import (
set_discovery_hash,
)
from .models import (
DATA_MQTT,
MessageCallbackType,
MqttValueTemplate,
MqttValueTemplateException,
@ -118,7 +119,7 @@ from .subscription import (
async_subscribe_topics,
async_unsubscribe_topics,
)
from .util import get_mqtt_data, mqtt_config_entry_enabled, valid_subscribe_topic
from .util import mqtt_config_entry_enabled, valid_subscribe_topic
_LOGGER = logging.getLogger(__name__)
@ -329,7 +330,7 @@ async def async_setup_non_entity_entry_helper(
discovery_schema: vol.Schema,
) -> None:
"""Set up automation or tag creation dynamically through MQTT discovery."""
mqtt_data = get_mqtt_data(hass)
mqtt_data = hass.data[DATA_MQTT]
async def async_setup_from_discovery(
discovery_payload: MQTTDiscoveryPayload,
@ -360,7 +361,7 @@ async def async_setup_entity_entry_helper(
schema_class_mapping: dict[str, type[MqttEntity]] | None = None,
) -> None:
"""Set up entity creation dynamically through MQTT discovery."""
mqtt_data = get_mqtt_data(hass)
mqtt_data = hass.data[DATA_MQTT]
@callback
def async_setup_from_discovery(
@ -391,7 +392,7 @@ async def async_setup_entity_entry_helper(
def _async_setup_entities() -> None:
"""Set up MQTT items from configuration.yaml."""
nonlocal entity_class
mqtt_data = get_mqtt_data(hass)
mqtt_data = hass.data[DATA_MQTT]
if not (config_yaml := mqtt_data.config):
return
yaml_configs: list[ConfigType] = [
@ -496,7 +497,7 @@ def write_state_on_attr_change(
if not _attrs_have_changed(tracked_attrs):
return
mqtt_data = get_mqtt_data(entity.hass)
mqtt_data = entity.hass.data[DATA_MQTT]
mqtt_data.state_write_requests.write_state_request(entity)
return wrapper
@ -695,7 +696,7 @@ class MqttAvailability(Entity):
@property
def available(self) -> bool:
"""Return if the device is available."""
mqtt_data = get_mqtt_data(self.hass)
mqtt_data = self.hass.data[DATA_MQTT]
client = mqtt_data.client
if not client.connected and not self.hass.is_stopping:
return False
@ -936,7 +937,7 @@ class MqttDiscoveryUpdate(Entity):
self._removed_from_hass = False
if discovery_data is None:
return
mqtt_data = get_mqtt_data(hass)
mqtt_data = hass.data[DATA_MQTT]
self._registry_hooks = mqtt_data.discovery_registry_hooks
discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH]
if discovery_hash in self._registry_hooks:

View file

@ -20,6 +20,7 @@ from homeassistant.helpers import template
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.service_info.mqtt import ReceivePayloadType
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, TemplateVarsType
from homeassistant.util.hass_dict import HassKey
if TYPE_CHECKING:
from paho.mqtt.client import MQTTMessage
@ -419,3 +420,7 @@ class MqttData:
state_write_requests: EntityTopicState = field(default_factory=EntityTopicState)
subscriptions_to_restore: list[Subscription] = field(default_factory=list)
tags: dict[str, dict[str, MQTTTagScanner]] = field(default_factory=dict)
DATA_MQTT: HassKey[MqttData] = HassKey("mqtt")
DATA_MQTT_AVAILABLE: HassKey[asyncio.Future[bool]] = HassKey("mqtt_client_available")

View file

@ -28,13 +28,14 @@ from .mixins import (
update_device,
)
from .models import (
DATA_MQTT,
MqttValueTemplate,
MqttValueTemplateException,
ReceiveMessage,
ReceivePayloadType,
)
from .subscription import EntitySubscription
from .util import get_mqtt_data, valid_subscribe_topic
from .util import valid_subscribe_topic
_LOGGER = logging.getLogger(__name__)
@ -70,7 +71,7 @@ async def _async_setup_tag(
discovery_id = discovery_hash[1]
device_id = update_device(hass, config_entry, config)
if device_id is not None and device_id not in (tags := get_mqtt_data(hass).tags):
if device_id is not None and device_id not in (tags := hass.data[DATA_MQTT].tags):
tags[device_id] = {}
tag_scanner = MQTTTagScanner(
@ -91,7 +92,7 @@ async def _async_setup_tag(
def async_has_tags(hass: HomeAssistant, device_id: str) -> bool:
"""Device has tag scanners."""
if device_id not in (tags := get_mqtt_data(hass).tags):
if device_id not in (tags := hass.data[DATA_MQTT].tags):
return False
return tags[device_id] != {}
@ -176,4 +177,4 @@ class MQTTTagScanner(MqttDiscoveryDeviceUpdate):
self.hass, self._sub_state
)
if self.device_id:
get_mqtt_data(self.hass).tags[self.device_id].pop(discovery_id)
self.hass.data[DATA_MQTT].tags[self.device_id].pop(discovery_id)

View file

@ -26,14 +26,12 @@ from .const import (
CONF_CERTIFICATE,
CONF_CLIENT_CERT,
CONF_CLIENT_KEY,
DATA_MQTT,
DATA_MQTT_AVAILABLE,
DEFAULT_ENCODING,
DEFAULT_QOS,
DEFAULT_RETAIN,
DOMAIN,
)
from .models import MqttData
from .models import DATA_MQTT, DATA_MQTT_AVAILABLE
AVAILABILITY_TIMEOUT = 30.0
@ -51,7 +49,7 @@ async def async_forward_entry_setup_and_setup_discovery(
hass: HomeAssistant, config_entry: ConfigEntry, platforms: set[Platform | str]
) -> None:
"""Forward the config entry setup to the platforms and set up discovery."""
mqtt_data = get_mqtt_data(hass)
mqtt_data = hass.data[DATA_MQTT]
platforms_loaded = mqtt_data.platforms_loaded
new_platforms: set[Platform | str] = platforms - platforms_loaded
tasks: list[asyncio.Task] = []
@ -85,7 +83,11 @@ async def async_forward_entry_setup_and_setup_discovery(
def mqtt_config_entry_enabled(hass: HomeAssistant) -> bool | None:
"""Return true when the MQTT config entry is enabled."""
return hass.config_entries.async_has_entries(
# If the mqtt client is connected, skip the expensive config
# entry check as its roughly two orders of magnitude faster.
return (
DATA_MQTT in hass.data and hass.data[DATA_MQTT].client.connected
) or hass.config_entries.async_has_entries(
DOMAIN, include_disabled=False, include_ignore=False
)
@ -229,13 +231,6 @@ def valid_birth_will(config: ConfigType) -> ConfigType:
return config
@lru_cache(maxsize=1)
def get_mqtt_data(hass: HomeAssistant) -> MqttData:
"""Return typed MqttData from hass.data[DATA_MQTT]."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
return mqtt_data
async def async_create_certificate_temp_files(
hass: HomeAssistant, config: ConfigType
) -> None:

View file

@ -43,7 +43,7 @@ async def setup_comp(
async def test_setup_fails_without_mqtt_being_setup(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
hass: HomeAssistant, mqtt_mock: MqttMockHAClient, caplog: pytest.LogCaptureFixture
) -> None:
"""Ensure mqtt is started when we setup the component."""
# Simulate MQTT is was removed
@ -52,6 +52,8 @@ async def test_setup_fails_without_mqtt_being_setup(
await hass.config_entries.async_set_disabled_by(
mqtt_entry.entry_id, ConfigEntryDisabler.USER
)
# mqtt is mocked so we need to simulate it is not connected
mqtt_mock.connected = False
dev_id = "zanzito"
topic = "location/zanzito"