Refactor MQTT to replace get_mqtt_data with HassKey (#117899)

This commit is contained in:
J. Nick Koston 2024-05-21 23:21:51 -10:00 committed by GitHub
parent b4d0562063
commit 4e3c4400a7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 74 additions and 74 deletions

View file

@ -65,8 +65,6 @@ from .const import ( # noqa: F401
CONF_WILL_MESSAGE, CONF_WILL_MESSAGE,
CONF_WS_HEADERS, CONF_WS_HEADERS,
CONF_WS_PATH, CONF_WS_PATH,
DATA_MQTT,
DATA_MQTT_AVAILABLE,
DEFAULT_DISCOVERY, DEFAULT_DISCOVERY,
DEFAULT_ENCODING, DEFAULT_ENCODING,
DEFAULT_PREFIX, DEFAULT_PREFIX,
@ -79,6 +77,8 @@ from .const import ( # noqa: F401
TEMPLATE_ERRORS, TEMPLATE_ERRORS,
) )
from .models import ( # noqa: F401 from .models import ( # noqa: F401
DATA_MQTT,
DATA_MQTT_AVAILABLE,
MqttCommandTemplate, MqttCommandTemplate,
MqttData, MqttData,
MqttValueTemplate, MqttValueTemplate,
@ -97,7 +97,6 @@ from .util import ( # noqa: F401
async_create_certificate_temp_files, async_create_certificate_temp_files,
async_forward_entry_setup_and_setup_discovery, async_forward_entry_setup_and_setup_discovery,
async_wait_for_mqtt_client, async_wait_for_mqtt_client,
get_mqtt_data,
mqtt_config_entry_enabled, mqtt_config_entry_enabled,
platforms_from_config, platforms_from_config,
valid_publish_topic, valid_publish_topic,
@ -194,7 +193,7 @@ async def async_check_config_schema(
hass: HomeAssistant, config_yaml: ConfigType hass: HomeAssistant, config_yaml: ConfigType
) -> None: ) -> None:
"""Validate manually configured MQTT items.""" """Validate manually configured MQTT items."""
mqtt_data = get_mqtt_data(hass) mqtt_data = hass.data[DATA_MQTT]
mqtt_config: list[dict[str, list[ConfigType]]] = config_yaml.get(DOMAIN, {}) mqtt_config: list[dict[str, list[ConfigType]]] = config_yaml.get(DOMAIN, {})
for mqtt_config_item in mqtt_config: for mqtt_config_item in mqtt_config:
for domain, config_items in mqtt_config_item.items(): for domain, config_items in mqtt_config_item.items():
@ -233,7 +232,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
await async_create_certificate_temp_files(hass, conf) await async_create_certificate_temp_files(hass, conf)
client = MQTT(hass, entry, conf) client = MQTT(hass, entry, conf)
if DOMAIN in hass.data: if DOMAIN in hass.data:
mqtt_data = get_mqtt_data(hass) mqtt_data = hass.data[DATA_MQTT]
mqtt_data.config = mqtt_yaml mqtt_data.config = mqtt_yaml
mqtt_data.client = client mqtt_data.client = client
else: else:
@ -241,7 +240,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
websocket_api.async_register_command(hass, websocket_subscribe) websocket_api.async_register_command(hass, websocket_subscribe)
websocket_api.async_register_command(hass, websocket_mqtt_info) websocket_api.async_register_command(hass, websocket_mqtt_info)
hass.data[DATA_MQTT] = mqtt_data = MqttData(config=mqtt_yaml, client=client) hass.data[DATA_MQTT] = mqtt_data = MqttData(config=mqtt_yaml, client=client)
get_mqtt_data.cache_clear()
client.start(mqtt_data) client.start(mqtt_data)
# Restore saved subscriptions # Restore saved subscriptions
@ -503,7 +501,7 @@ def async_subscribe_connection_status(
def is_connected(hass: HomeAssistant) -> bool: def is_connected(hass: HomeAssistant) -> bool:
"""Return if MQTT client is connected.""" """Return if MQTT client is connected."""
mqtt_data = get_mqtt_data(hass) mqtt_data = hass.data[DATA_MQTT]
return mqtt_data.client.connected return mqtt_data.client.connected
@ -520,7 +518,7 @@ async def async_remove_config_entry_device(
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload MQTT dump and publish service when the config entry is unloaded.""" """Unload MQTT dump and publish service when the config entry is unloaded."""
mqtt_data = get_mqtt_data(hass) mqtt_data = hass.data[DATA_MQTT]
mqtt_client = mqtt_data.client mqtt_client = mqtt_data.client
# Unload publish and dump services. # Unload publish and dump services.

View file

@ -66,6 +66,7 @@ from .const import (
TRANSPORT_WEBSOCKETS, TRANSPORT_WEBSOCKETS,
) )
from .models import ( from .models import (
DATA_MQTT,
AsyncMessageCallbackType, AsyncMessageCallbackType,
MessageCallbackType, MessageCallbackType,
MqttData, MqttData,
@ -73,7 +74,7 @@ from .models import (
PublishPayloadType, PublishPayloadType,
ReceiveMessage, ReceiveMessage,
) )
from .util import get_file_path, get_mqtt_data, mqtt_config_entry_enabled from .util import get_file_path, mqtt_config_entry_enabled
if TYPE_CHECKING: if TYPE_CHECKING:
# Only import for paho-mqtt type checking here, imports are done locally # Only import for paho-mqtt type checking here, imports are done locally
@ -132,7 +133,7 @@ async def async_publish(
translation_domain=DOMAIN, translation_domain=DOMAIN,
translation_placeholders={"topic": topic}, translation_placeholders={"topic": topic},
) )
mqtt_data = get_mqtt_data(hass) mqtt_data = hass.data[DATA_MQTT]
outgoing_payload = payload outgoing_payload = payload
if not isinstance(payload, bytes): if not isinstance(payload, bytes):
if not encoding: if not encoding:
@ -186,7 +187,7 @@ async def async_subscribe(
translation_placeholders={"topic": topic}, translation_placeholders={"topic": topic},
) )
try: try:
mqtt_data = get_mqtt_data(hass) mqtt_data = hass.data[DATA_MQTT]
except KeyError as exc: except KeyError as exc:
raise HomeAssistantError( raise HomeAssistantError(
f"Cannot subscribe to topic '{topic}', " f"Cannot subscribe to topic '{topic}', "

View file

@ -86,9 +86,6 @@ CONF_CONFIGURATION_URL = "configuration_url"
CONF_OBJECT_ID = "object_id" CONF_OBJECT_ID = "object_id"
CONF_SUPPORT_URL = "support_url" CONF_SUPPORT_URL = "support_url"
DATA_MQTT = "mqtt"
DATA_MQTT_AVAILABLE = "mqtt_client_available"
DEFAULT_PREFIX = "homeassistant" DEFAULT_PREFIX = "homeassistant"
DEFAULT_BIRTH_WILL_TOPIC = DEFAULT_PREFIX + "/status" DEFAULT_BIRTH_WILL_TOPIC = DEFAULT_PREFIX + "/status"
DEFAULT_DISCOVERY = True DEFAULT_DISCOVERY = True

View file

@ -16,8 +16,7 @@ from homeassistant.helpers.typing import DiscoveryInfoType
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from .const import ATTR_DISCOVERY_PAYLOAD, ATTR_DISCOVERY_TOPIC from .const import ATTR_DISCOVERY_PAYLOAD, ATTR_DISCOVERY_TOPIC
from .models import MessageCallbackType, PublishPayloadType from .models import DATA_MQTT, MessageCallbackType, PublishPayloadType
from .util import get_mqtt_data
STORED_MESSAGES = 10 STORED_MESSAGES = 10
@ -27,7 +26,7 @@ def log_messages(
) -> Callable[[MessageCallbackType], MessageCallbackType]: ) -> Callable[[MessageCallbackType], MessageCallbackType]:
"""Wrap an MQTT message callback to support message logging.""" """Wrap an MQTT message callback to support message logging."""
debug_info_entities = get_mqtt_data(hass).debug_info_entities debug_info_entities = hass.data[DATA_MQTT].debug_info_entities
def _log_message(msg: Any) -> None: def _log_message(msg: Any) -> None:
"""Log message.""" """Log message."""
@ -70,7 +69,7 @@ def log_message(
retain: bool, retain: bool,
) -> None: ) -> None:
"""Log an outgoing MQTT message.""" """Log an outgoing MQTT message."""
entity_info = get_mqtt_data(hass).debug_info_entities.setdefault( entity_info = hass.data[DATA_MQTT].debug_info_entities.setdefault(
entity_id, {"subscriptions": {}, "discovery_data": {}, "transmitted": {}} entity_id, {"subscriptions": {}, "discovery_data": {}, "transmitted": {}}
) )
if topic not in entity_info["transmitted"]: if topic not in entity_info["transmitted"]:
@ -90,7 +89,7 @@ def add_subscription(
) -> None: ) -> None:
"""Prepare debug data for subscription.""" """Prepare debug data for subscription."""
if entity_id := getattr(message_callback, "__entity_id", None): if entity_id := getattr(message_callback, "__entity_id", None):
entity_info = get_mqtt_data(hass).debug_info_entities.setdefault( entity_info = hass.data[DATA_MQTT].debug_info_entities.setdefault(
entity_id, {"subscriptions": {}, "discovery_data": {}, "transmitted": {}} entity_id, {"subscriptions": {}, "discovery_data": {}, "transmitted": {}}
) )
if subscription not in entity_info["subscriptions"]: if subscription not in entity_info["subscriptions"]:
@ -108,7 +107,7 @@ def remove_subscription(
) -> None: ) -> None:
"""Remove debug data for subscription if it exists.""" """Remove debug data for subscription if it exists."""
if (entity_id := getattr(message_callback, "__entity_id", None)) and entity_id in ( if (entity_id := getattr(message_callback, "__entity_id", None)) and entity_id in (
debug_info_entities := get_mqtt_data(hass).debug_info_entities debug_info_entities := hass.data[DATA_MQTT].debug_info_entities
): ):
debug_info_entities[entity_id]["subscriptions"][subscription]["count"] -= 1 debug_info_entities[entity_id]["subscriptions"][subscription]["count"] -= 1
if not debug_info_entities[entity_id]["subscriptions"][subscription]["count"]: if not debug_info_entities[entity_id]["subscriptions"][subscription]["count"]:
@ -119,7 +118,7 @@ def add_entity_discovery_data(
hass: HomeAssistant, discovery_data: DiscoveryInfoType, entity_id: str hass: HomeAssistant, discovery_data: DiscoveryInfoType, entity_id: str
) -> None: ) -> None:
"""Add discovery data.""" """Add discovery data."""
entity_info = get_mqtt_data(hass).debug_info_entities.setdefault( entity_info = hass.data[DATA_MQTT].debug_info_entities.setdefault(
entity_id, {"subscriptions": {}, "discovery_data": {}, "transmitted": {}} entity_id, {"subscriptions": {}, "discovery_data": {}, "transmitted": {}}
) )
entity_info["discovery_data"] = discovery_data entity_info["discovery_data"] = discovery_data
@ -129,7 +128,7 @@ def update_entity_discovery_data(
hass: HomeAssistant, discovery_payload: DiscoveryInfoType, entity_id: str hass: HomeAssistant, discovery_payload: DiscoveryInfoType, entity_id: str
) -> None: ) -> None:
"""Update discovery data.""" """Update discovery data."""
discovery_data = get_mqtt_data(hass).debug_info_entities[entity_id][ discovery_data = hass.data[DATA_MQTT].debug_info_entities[entity_id][
"discovery_data" "discovery_data"
] ]
if TYPE_CHECKING: if TYPE_CHECKING:
@ -139,7 +138,7 @@ def update_entity_discovery_data(
def remove_entity_data(hass: HomeAssistant, entity_id: str) -> None: def remove_entity_data(hass: HomeAssistant, entity_id: str) -> None:
"""Remove discovery data.""" """Remove discovery data."""
if entity_id in (debug_info_entities := get_mqtt_data(hass).debug_info_entities): if entity_id in (debug_info_entities := hass.data[DATA_MQTT].debug_info_entities):
debug_info_entities.pop(entity_id) debug_info_entities.pop(entity_id)
@ -150,7 +149,7 @@ def add_trigger_discovery_data(
device_id: str, device_id: str,
) -> None: ) -> None:
"""Add discovery data.""" """Add discovery data."""
get_mqtt_data(hass).debug_info_triggers[discovery_hash] = { hass.data[DATA_MQTT].debug_info_triggers[discovery_hash] = {
"device_id": device_id, "device_id": device_id,
"discovery_data": discovery_data, "discovery_data": discovery_data,
} }
@ -162,7 +161,7 @@ def update_trigger_discovery_data(
discovery_payload: DiscoveryInfoType, discovery_payload: DiscoveryInfoType,
) -> None: ) -> None:
"""Update discovery data.""" """Update discovery data."""
get_mqtt_data(hass).debug_info_triggers[discovery_hash]["discovery_data"][ hass.data[DATA_MQTT].debug_info_triggers[discovery_hash]["discovery_data"][
ATTR_DISCOVERY_PAYLOAD ATTR_DISCOVERY_PAYLOAD
] = discovery_payload ] = discovery_payload
@ -171,11 +170,11 @@ def remove_trigger_discovery_data(
hass: HomeAssistant, discovery_hash: tuple[str, str] hass: HomeAssistant, discovery_hash: tuple[str, str]
) -> None: ) -> None:
"""Remove discovery data.""" """Remove discovery data."""
get_mqtt_data(hass).debug_info_triggers.pop(discovery_hash) hass.data[DATA_MQTT].debug_info_triggers.pop(discovery_hash)
def _info_for_entity(hass: HomeAssistant, entity_id: str) -> dict[str, Any]: def _info_for_entity(hass: HomeAssistant, entity_id: str) -> dict[str, Any]:
entity_info = get_mqtt_data(hass).debug_info_entities[entity_id] entity_info = hass.data[DATA_MQTT].debug_info_entities[entity_id]
monotonic_time_diff = time.time() - time.monotonic() monotonic_time_diff = time.time() - time.monotonic()
subscriptions = [ subscriptions = [
{ {
@ -231,7 +230,7 @@ def _info_for_entity(hass: HomeAssistant, entity_id: str) -> dict[str, Any]:
def _info_for_trigger( def _info_for_trigger(
hass: HomeAssistant, trigger_key: tuple[str, str] hass: HomeAssistant, trigger_key: tuple[str, str]
) -> dict[str, Any]: ) -> dict[str, Any]:
trigger = get_mqtt_data(hass).debug_info_triggers[trigger_key] trigger = hass.data[DATA_MQTT].debug_info_triggers[trigger_key]
discovery_data = None discovery_data = None
if trigger["discovery_data"] is not None: if trigger["discovery_data"] is not None:
discovery_data = { discovery_data = {
@ -244,7 +243,7 @@ def _info_for_trigger(
def info_for_config_entry(hass: HomeAssistant) -> dict[str, list[Any]]: def info_for_config_entry(hass: HomeAssistant) -> dict[str, list[Any]]:
"""Get debug info for all entities and triggers.""" """Get debug info for all entities and triggers."""
mqtt_data = get_mqtt_data(hass) mqtt_data = hass.data[DATA_MQTT]
mqtt_info: dict[str, list[Any]] = {"entities": [], "triggers": []} mqtt_info: dict[str, list[Any]] = {"entities": [], "triggers": []}
mqtt_info["entities"].extend( mqtt_info["entities"].extend(
@ -262,7 +261,7 @@ def info_for_config_entry(hass: HomeAssistant) -> dict[str, list[Any]]:
def info_for_device(hass: HomeAssistant, device_id: str) -> dict[str, list[Any]]: def info_for_device(hass: HomeAssistant, device_id: str) -> dict[str, list[Any]]:
"""Get debug info for a device.""" """Get debug info for a device."""
mqtt_data = get_mqtt_data(hass) mqtt_data = hass.data[DATA_MQTT]
mqtt_info: dict[str, list[Any]] = {"entities": [], "triggers": []} mqtt_info: dict[str, list[Any]] = {"entities": [], "triggers": []}
entity_registry = er.async_get(hass) entity_registry = er.async_get(hass)

View file

@ -42,7 +42,7 @@ from .mixins import (
send_discovery_done, send_discovery_done,
update_device, update_device,
) )
from .util import get_mqtt_data from .models import DATA_MQTT
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -206,7 +206,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
self.device_id = device_id self.device_id = device_id
self.discovery_data = discovery_data self.discovery_data = discovery_data
self.hass = hass self.hass = hass
self._mqtt_data = get_mqtt_data(hass) self._mqtt_data = hass.data[DATA_MQTT]
self.trigger_id = f"{device_id}_{config[CONF_TYPE]}_{config[CONF_SUBTYPE]}" self.trigger_id = f"{device_id}_{config[CONF_TYPE]}_{config[CONF_SUBTYPE]}"
MqttDiscoveryDeviceUpdate.__init__( MqttDiscoveryDeviceUpdate.__init__(
@ -259,7 +259,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
config = TRIGGER_DISCOVERY_SCHEMA(discovery_data) config = TRIGGER_DISCOVERY_SCHEMA(discovery_data)
new_trigger_id = f"{self.device_id}_{config[CONF_TYPE]}_{config[CONF_SUBTYPE]}" new_trigger_id = f"{self.device_id}_{config[CONF_TYPE]}_{config[CONF_SUBTYPE]}"
if new_trigger_id != self.trigger_id: if new_trigger_id != self.trigger_id:
mqtt_data = get_mqtt_data(self.hass) mqtt_data = self.hass.data[DATA_MQTT]
if new_trigger_id in mqtt_data.device_triggers: if new_trigger_id in mqtt_data.device_triggers:
_LOGGER.error( _LOGGER.error(
"Cannot update device trigger %s due to an existing duplicate " "Cannot update device trigger %s due to an existing duplicate "
@ -308,7 +308,7 @@ async def async_setup_trigger(
trigger_type = config[CONF_TYPE] trigger_type = config[CONF_TYPE]
trigger_subtype = config[CONF_SUBTYPE] trigger_subtype = config[CONF_SUBTYPE]
trigger_id = f"{device_id}_{trigger_type}_{trigger_subtype}" trigger_id = f"{device_id}_{trigger_type}_{trigger_subtype}"
mqtt_data = get_mqtt_data(hass) mqtt_data = hass.data[DATA_MQTT]
if ( if (
trigger_id in mqtt_data.device_triggers trigger_id in mqtt_data.device_triggers
and mqtt_data.device_triggers[trigger_id].discovery_data is not None and mqtt_data.device_triggers[trigger_id].discovery_data is not None
@ -334,7 +334,7 @@ async def async_setup_trigger(
async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None: async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None:
"""Handle Mqtt removed from a device.""" """Handle Mqtt removed from a device."""
mqtt_data = get_mqtt_data(hass) mqtt_data = hass.data[DATA_MQTT]
triggers = await async_get_triggers(hass, device_id) triggers = await async_get_triggers(hass, device_id)
for trig in triggers: for trig in triggers:
trigger_id = f"{device_id}_{trig[CONF_TYPE]}_{trig[CONF_SUBTYPE]}" trigger_id = f"{device_id}_{trig[CONF_TYPE]}_{trig[CONF_SUBTYPE]}"
@ -352,7 +352,7 @@ async def async_get_triggers(
hass: HomeAssistant, device_id: str hass: HomeAssistant, device_id: str
) -> list[dict[str, str]]: ) -> list[dict[str, str]]:
"""List device triggers for MQTT devices.""" """List device triggers for MQTT devices."""
mqtt_data = get_mqtt_data(hass) mqtt_data = hass.data[DATA_MQTT]
if not mqtt_data.device_triggers: if not mqtt_data.device_triggers:
return [] return []
@ -377,7 +377,7 @@ async def async_attach_trigger(
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Attach a trigger.""" """Attach a trigger."""
trigger_id: str | None = None trigger_id: str | None = None
mqtt_data = get_mqtt_data(hass) mqtt_data = hass.data[DATA_MQTT]
device_id = config[CONF_DEVICE_ID] device_id = config[CONF_DEVICE_ID]
# The use of CONF_DISCOVERY_ID was deprecated in HA Core 2024.2. # The use of CONF_DISCOVERY_ID was deprecated in HA Core 2024.2.

View file

@ -18,7 +18,7 @@ from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.device_registry import DeviceEntry from homeassistant.helpers.device_registry import DeviceEntry
from . import debug_info, is_connected from . import debug_info, is_connected
from .util import get_mqtt_data from .models import DATA_MQTT
REDACT_CONFIG = {CONF_PASSWORD, CONF_USERNAME} REDACT_CONFIG = {CONF_PASSWORD, CONF_USERNAME}
REDACT_STATE_DEVICE_TRACKER = {ATTR_LATITUDE, ATTR_LONGITUDE} REDACT_STATE_DEVICE_TRACKER = {ATTR_LATITUDE, ATTR_LONGITUDE}
@ -45,7 +45,7 @@ def _async_get_diagnostics(
device: DeviceEntry | None = None, device: DeviceEntry | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Return diagnostics for a config entry.""" """Return diagnostics for a config entry."""
mqtt_instance = get_mqtt_data(hass).client mqtt_instance = hass.data[DATA_MQTT].client
if TYPE_CHECKING: if TYPE_CHECKING:
assert mqtt_instance is not None assert mqtt_instance is not None

View file

@ -40,8 +40,8 @@ from .const import (
CONF_TOPIC, CONF_TOPIC,
DOMAIN, DOMAIN,
) )
from .models import MqttOriginInfo, ReceiveMessage from .models import DATA_MQTT, MqttOriginInfo, ReceiveMessage
from .util import async_forward_entry_setup_and_setup_discovery, get_mqtt_data from .util import async_forward_entry_setup_and_setup_discovery
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -113,12 +113,12 @@ class MQTTDiscoveryPayload(dict[str, Any]):
def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) -> None: def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) -> None:
"""Clear entry from already discovered list.""" """Clear entry from already discovered list."""
get_mqtt_data(hass).discovery_already_discovered.remove(discovery_hash) hass.data[DATA_MQTT].discovery_already_discovered.remove(discovery_hash)
def set_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) -> None: def set_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) -> None:
"""Add entry to already discovered list.""" """Add entry to already discovered list."""
get_mqtt_data(hass).discovery_already_discovered.add(discovery_hash) hass.data[DATA_MQTT].discovery_already_discovered.add(discovery_hash)
@callback @callback
@ -150,7 +150,7 @@ async def async_start( # noqa: C901
hass: HomeAssistant, discovery_topic: str, config_entry: ConfigEntry hass: HomeAssistant, discovery_topic: str, config_entry: ConfigEntry
) -> None: ) -> None:
"""Start MQTT Discovery.""" """Start MQTT Discovery."""
mqtt_data = get_mqtt_data(hass) mqtt_data = hass.data[DATA_MQTT]
platform_setup_lock: dict[str, asyncio.Lock] = {} platform_setup_lock: dict[str, asyncio.Lock] = {}
async def _async_component_setup(discovery_payload: MQTTDiscoveryPayload) -> None: async def _async_component_setup(discovery_payload: MQTTDiscoveryPayload) -> None:
@ -426,7 +426,7 @@ async def async_start( # noqa: C901
async def async_stop(hass: HomeAssistant) -> None: async def async_stop(hass: HomeAssistant) -> None:
"""Stop MQTT Discovery.""" """Stop MQTT Discovery."""
mqtt_data = get_mqtt_data(hass) mqtt_data = hass.data[DATA_MQTT]
for unsub in mqtt_data.discovery_unsubscribe: for unsub in mqtt_data.discovery_unsubscribe:
unsub() unsub()
mqtt_data.discovery_unsubscribe = [] mqtt_data.discovery_unsubscribe = []

View file

@ -38,13 +38,13 @@ from .mixins import (
async_setup_entity_entry_helper, async_setup_entity_entry_helper,
) )
from .models import ( from .models import (
DATA_MQTT,
MqttValueTemplate, MqttValueTemplate,
MqttValueTemplateException, MqttValueTemplateException,
PayloadSentinel, PayloadSentinel,
ReceiveMessage, ReceiveMessage,
ReceivePayloadType, ReceivePayloadType,
) )
from .util import get_mqtt_data
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -194,7 +194,7 @@ class MqttEvent(MqttEntity, EventEntity):
payload, payload,
) )
return return
mqtt_data = get_mqtt_data(self.hass) mqtt_data = self.hass.data[DATA_MQTT]
mqtt_data.state_write_requests.write_state_request(self) mqtt_data.state_write_requests.write_state_request(self)
topics["state_topic"] = { topics["state_topic"] = {

View file

@ -33,12 +33,13 @@ from .mixins import (
async_setup_entity_entry_helper, async_setup_entity_entry_helper,
) )
from .models import ( from .models import (
DATA_MQTT,
MessageCallbackType, MessageCallbackType,
MqttValueTemplate, MqttValueTemplate,
MqttValueTemplateException, MqttValueTemplateException,
ReceiveMessage, ReceiveMessage,
) )
from .util import get_mqtt_data, valid_subscribe_topic from .util import valid_subscribe_topic
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -186,7 +187,7 @@ class MqttImage(MqttEntity, ImageEntity):
) )
self._last_image = None self._last_image = None
self._attr_image_last_updated = dt_util.utcnow() self._attr_image_last_updated = dt_util.utcnow()
get_mqtt_data(self.hass).state_write_requests.write_state_request(self) self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
add_subscribe_topic(CONF_IMAGE_TOPIC, image_data_received) add_subscribe_topic(CONF_IMAGE_TOPIC, image_data_received)
@ -208,7 +209,7 @@ class MqttImage(MqttEntity, ImageEntity):
) )
self._attr_image_last_updated = dt_util.utcnow() self._attr_image_last_updated = dt_util.utcnow()
self._cached_image = None self._cached_image = None
get_mqtt_data(self.hass).state_write_requests.write_state_request(self) self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
add_subscribe_topic(CONF_URL_TOPIC, image_from_url_request_received) add_subscribe_topic(CONF_URL_TOPIC, image_from_url_request_received)

View file

@ -106,6 +106,7 @@ from .discovery import (
set_discovery_hash, set_discovery_hash,
) )
from .models import ( from .models import (
DATA_MQTT,
MessageCallbackType, MessageCallbackType,
MqttValueTemplate, MqttValueTemplate,
MqttValueTemplateException, MqttValueTemplateException,
@ -118,7 +119,7 @@ from .subscription import (
async_subscribe_topics, async_subscribe_topics,
async_unsubscribe_topics, async_unsubscribe_topics,
) )
from .util import get_mqtt_data, mqtt_config_entry_enabled, valid_subscribe_topic from .util import mqtt_config_entry_enabled, valid_subscribe_topic
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -329,7 +330,7 @@ async def async_setup_non_entity_entry_helper(
discovery_schema: vol.Schema, discovery_schema: vol.Schema,
) -> None: ) -> None:
"""Set up automation or tag creation dynamically through MQTT discovery.""" """Set up automation or tag creation dynamically through MQTT discovery."""
mqtt_data = get_mqtt_data(hass) mqtt_data = hass.data[DATA_MQTT]
async def async_setup_from_discovery( async def async_setup_from_discovery(
discovery_payload: MQTTDiscoveryPayload, discovery_payload: MQTTDiscoveryPayload,
@ -360,7 +361,7 @@ async def async_setup_entity_entry_helper(
schema_class_mapping: dict[str, type[MqttEntity]] | None = None, schema_class_mapping: dict[str, type[MqttEntity]] | None = None,
) -> None: ) -> None:
"""Set up entity creation dynamically through MQTT discovery.""" """Set up entity creation dynamically through MQTT discovery."""
mqtt_data = get_mqtt_data(hass) mqtt_data = hass.data[DATA_MQTT]
@callback @callback
def async_setup_from_discovery( def async_setup_from_discovery(
@ -391,7 +392,7 @@ async def async_setup_entity_entry_helper(
def _async_setup_entities() -> None: def _async_setup_entities() -> None:
"""Set up MQTT items from configuration.yaml.""" """Set up MQTT items from configuration.yaml."""
nonlocal entity_class nonlocal entity_class
mqtt_data = get_mqtt_data(hass) mqtt_data = hass.data[DATA_MQTT]
if not (config_yaml := mqtt_data.config): if not (config_yaml := mqtt_data.config):
return return
yaml_configs: list[ConfigType] = [ yaml_configs: list[ConfigType] = [
@ -496,7 +497,7 @@ def write_state_on_attr_change(
if not _attrs_have_changed(tracked_attrs): if not _attrs_have_changed(tracked_attrs):
return return
mqtt_data = get_mqtt_data(entity.hass) mqtt_data = entity.hass.data[DATA_MQTT]
mqtt_data.state_write_requests.write_state_request(entity) mqtt_data.state_write_requests.write_state_request(entity)
return wrapper return wrapper
@ -695,7 +696,7 @@ class MqttAvailability(Entity):
@property @property
def available(self) -> bool: def available(self) -> bool:
"""Return if the device is available.""" """Return if the device is available."""
mqtt_data = get_mqtt_data(self.hass) mqtt_data = self.hass.data[DATA_MQTT]
client = mqtt_data.client client = mqtt_data.client
if not client.connected and not self.hass.is_stopping: if not client.connected and not self.hass.is_stopping:
return False return False
@ -936,7 +937,7 @@ class MqttDiscoveryUpdate(Entity):
self._removed_from_hass = False self._removed_from_hass = False
if discovery_data is None: if discovery_data is None:
return return
mqtt_data = get_mqtt_data(hass) mqtt_data = hass.data[DATA_MQTT]
self._registry_hooks = mqtt_data.discovery_registry_hooks self._registry_hooks = mqtt_data.discovery_registry_hooks
discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH] discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH]
if discovery_hash in self._registry_hooks: if discovery_hash in self._registry_hooks:

View file

@ -20,6 +20,7 @@ from homeassistant.helpers import template
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.helpers.service_info.mqtt import ReceivePayloadType from homeassistant.helpers.service_info.mqtt import ReceivePayloadType
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, TemplateVarsType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, TemplateVarsType
from homeassistant.util.hass_dict import HassKey
if TYPE_CHECKING: if TYPE_CHECKING:
from paho.mqtt.client import MQTTMessage from paho.mqtt.client import MQTTMessage
@ -419,3 +420,7 @@ class MqttData:
state_write_requests: EntityTopicState = field(default_factory=EntityTopicState) state_write_requests: EntityTopicState = field(default_factory=EntityTopicState)
subscriptions_to_restore: list[Subscription] = field(default_factory=list) subscriptions_to_restore: list[Subscription] = field(default_factory=list)
tags: dict[str, dict[str, MQTTTagScanner]] = field(default_factory=dict) tags: dict[str, dict[str, MQTTTagScanner]] = field(default_factory=dict)
DATA_MQTT: HassKey[MqttData] = HassKey("mqtt")
DATA_MQTT_AVAILABLE: HassKey[asyncio.Future[bool]] = HassKey("mqtt_client_available")

View file

@ -28,13 +28,14 @@ from .mixins import (
update_device, update_device,
) )
from .models import ( from .models import (
DATA_MQTT,
MqttValueTemplate, MqttValueTemplate,
MqttValueTemplateException, MqttValueTemplateException,
ReceiveMessage, ReceiveMessage,
ReceivePayloadType, ReceivePayloadType,
) )
from .subscription import EntitySubscription from .subscription import EntitySubscription
from .util import get_mqtt_data, valid_subscribe_topic from .util import valid_subscribe_topic
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -70,7 +71,7 @@ async def _async_setup_tag(
discovery_id = discovery_hash[1] discovery_id = discovery_hash[1]
device_id = update_device(hass, config_entry, config) device_id = update_device(hass, config_entry, config)
if device_id is not None and device_id not in (tags := get_mqtt_data(hass).tags): if device_id is not None and device_id not in (tags := hass.data[DATA_MQTT].tags):
tags[device_id] = {} tags[device_id] = {}
tag_scanner = MQTTTagScanner( tag_scanner = MQTTTagScanner(
@ -91,7 +92,7 @@ async def _async_setup_tag(
def async_has_tags(hass: HomeAssistant, device_id: str) -> bool: def async_has_tags(hass: HomeAssistant, device_id: str) -> bool:
"""Device has tag scanners.""" """Device has tag scanners."""
if device_id not in (tags := get_mqtt_data(hass).tags): if device_id not in (tags := hass.data[DATA_MQTT].tags):
return False return False
return tags[device_id] != {} return tags[device_id] != {}
@ -176,4 +177,4 @@ class MQTTTagScanner(MqttDiscoveryDeviceUpdate):
self.hass, self._sub_state self.hass, self._sub_state
) )
if self.device_id: if self.device_id:
get_mqtt_data(self.hass).tags[self.device_id].pop(discovery_id) self.hass.data[DATA_MQTT].tags[self.device_id].pop(discovery_id)

View file

@ -26,14 +26,12 @@ from .const import (
CONF_CERTIFICATE, CONF_CERTIFICATE,
CONF_CLIENT_CERT, CONF_CLIENT_CERT,
CONF_CLIENT_KEY, CONF_CLIENT_KEY,
DATA_MQTT,
DATA_MQTT_AVAILABLE,
DEFAULT_ENCODING, DEFAULT_ENCODING,
DEFAULT_QOS, DEFAULT_QOS,
DEFAULT_RETAIN, DEFAULT_RETAIN,
DOMAIN, DOMAIN,
) )
from .models import MqttData from .models import DATA_MQTT, DATA_MQTT_AVAILABLE
AVAILABILITY_TIMEOUT = 30.0 AVAILABILITY_TIMEOUT = 30.0
@ -51,7 +49,7 @@ async def async_forward_entry_setup_and_setup_discovery(
hass: HomeAssistant, config_entry: ConfigEntry, platforms: set[Platform | str] hass: HomeAssistant, config_entry: ConfigEntry, platforms: set[Platform | str]
) -> None: ) -> None:
"""Forward the config entry setup to the platforms and set up discovery.""" """Forward the config entry setup to the platforms and set up discovery."""
mqtt_data = get_mqtt_data(hass) mqtt_data = hass.data[DATA_MQTT]
platforms_loaded = mqtt_data.platforms_loaded platforms_loaded = mqtt_data.platforms_loaded
new_platforms: set[Platform | str] = platforms - platforms_loaded new_platforms: set[Platform | str] = platforms - platforms_loaded
tasks: list[asyncio.Task] = [] tasks: list[asyncio.Task] = []
@ -85,7 +83,11 @@ async def async_forward_entry_setup_and_setup_discovery(
def mqtt_config_entry_enabled(hass: HomeAssistant) -> bool | None: def mqtt_config_entry_enabled(hass: HomeAssistant) -> bool | None:
"""Return true when the MQTT config entry is enabled.""" """Return true when the MQTT config entry is enabled."""
return hass.config_entries.async_has_entries( # If the mqtt client is connected, skip the expensive config
# entry check as its roughly two orders of magnitude faster.
return (
DATA_MQTT in hass.data and hass.data[DATA_MQTT].client.connected
) or hass.config_entries.async_has_entries(
DOMAIN, include_disabled=False, include_ignore=False DOMAIN, include_disabled=False, include_ignore=False
) )
@ -229,13 +231,6 @@ def valid_birth_will(config: ConfigType) -> ConfigType:
return config return config
@lru_cache(maxsize=1)
def get_mqtt_data(hass: HomeAssistant) -> MqttData:
"""Return typed MqttData from hass.data[DATA_MQTT]."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
return mqtt_data
async def async_create_certificate_temp_files( async def async_create_certificate_temp_files(
hass: HomeAssistant, config: ConfigType hass: HomeAssistant, config: ConfigType
) -> None: ) -> None:

View file

@ -43,7 +43,7 @@ async def setup_comp(
async def test_setup_fails_without_mqtt_being_setup( async def test_setup_fails_without_mqtt_being_setup(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture hass: HomeAssistant, mqtt_mock: MqttMockHAClient, caplog: pytest.LogCaptureFixture
) -> None: ) -> None:
"""Ensure mqtt is started when we setup the component.""" """Ensure mqtt is started when we setup the component."""
# Simulate MQTT is was removed # Simulate MQTT is was removed
@ -52,6 +52,8 @@ async def test_setup_fails_without_mqtt_being_setup(
await hass.config_entries.async_set_disabled_by( await hass.config_entries.async_set_disabled_by(
mqtt_entry.entry_id, ConfigEntryDisabler.USER mqtt_entry.entry_id, ConfigEntryDisabler.USER
) )
# mqtt is mocked so we need to simulate it is not connected
mqtt_mock.connected = False
dev_id = "zanzito" dev_id = "zanzito"
topic = "location/zanzito" topic = "location/zanzito"