diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 842e5b6405f..315f116ed92 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -20,13 +20,7 @@ from homeassistant.const import ( CONF_USERNAME, SERVICE_RELOAD, ) -from homeassistant.core import ( - CALLBACK_TYPE, - HassJob, - HomeAssistant, - ServiceCall, - callback, -) +from homeassistant.core import HassJob, HomeAssistant, ServiceCall, callback from homeassistant.exceptions import TemplateError, Unauthorized from homeassistant.helpers import ( config_validation as cv, @@ -71,15 +65,7 @@ from .const import ( # noqa: F401 CONF_TLS_VERSION, CONF_TOPIC, CONF_WILL_MESSAGE, - CONFIG_ENTRY_IS_SETUP, DATA_MQTT, - DATA_MQTT_CONFIG, - DATA_MQTT_DISCOVERY_REGISTRY_HOOKS, - DATA_MQTT_RELOAD_DISPATCHERS, - DATA_MQTT_RELOAD_ENTRY, - DATA_MQTT_RELOAD_NEEDED, - DATA_MQTT_SUBSCRIPTIONS_TO_RESTORE, - DATA_MQTT_UPDATED_CONFIG, DEFAULT_ENCODING, DEFAULT_QOS, DEFAULT_RETAIN, @@ -89,7 +75,7 @@ from .const import ( # noqa: F401 PLATFORMS, RELOADABLE_PLATFORMS, ) -from .mixins import async_discover_yaml_entities +from .mixins import MqttData, async_discover_yaml_entities from .models import ( # noqa: F401 MqttCommandTemplate, MqttValueTemplate, @@ -177,6 +163,8 @@ async def _async_setup_discovery( async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Start the MQTT protocol service.""" + mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData()) + conf: ConfigType | None = config.get(DOMAIN) websocket_api.async_register_command(hass, websocket_subscribe) @@ -185,7 +173,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: if conf: conf = dict(conf) - hass.data[DATA_MQTT_CONFIG] = conf + mqtt_data.config = conf if (mqtt_entry_status := mqtt_config_entry_enabled(hass)) is None: # Create an import flow if the user has yaml configured entities etc. @@ -197,12 +185,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, data={}, ) - hass.data[DATA_MQTT_RELOAD_NEEDED] = True + mqtt_data.reload_needed = True elif mqtt_entry_status is False: _LOGGER.info( "MQTT will be not available until the config entry is enabled", ) - hass.data[DATA_MQTT_RELOAD_NEEDED] = True + mqtt_data.reload_needed = True return True @@ -260,33 +248,34 @@ async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) - Causes for this is config entry options changing. """ - mqtt_client = hass.data[DATA_MQTT] + mqtt_data: MqttData = hass.data[DATA_MQTT] + assert (client := mqtt_data.client) is not None - if (conf := hass.data.get(DATA_MQTT_CONFIG)) is None: + if (conf := mqtt_data.config) is None: conf = CONFIG_SCHEMA_BASE(dict(entry.data)) - mqtt_client.conf = _merge_extended_config(entry, conf) - await mqtt_client.async_disconnect() - mqtt_client.init_client() - await mqtt_client.async_connect() + mqtt_data.config = _merge_extended_config(entry, conf) + await client.async_disconnect() + client.init_client() + await client.async_connect() await discovery.async_stop(hass) - if mqtt_client.conf.get(CONF_DISCOVERY): - await _async_setup_discovery(hass, mqtt_client.conf, entry) + if client.conf.get(CONF_DISCOVERY): + await _async_setup_discovery(hass, cast(ConfigType, mqtt_data.config), entry) async def async_fetch_config(hass: HomeAssistant, entry: ConfigEntry) -> dict | None: """Fetch fresh MQTT yaml config from the hass config when (re)loading the entry.""" - if DATA_MQTT_RELOAD_ENTRY in hass.data: + mqtt_data: MqttData = hass.data[DATA_MQTT] + if mqtt_data.reload_entry: hass_config = await conf_util.async_hass_config_yaml(hass) - mqtt_config = CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {})) - hass.data[DATA_MQTT_CONFIG] = mqtt_config + mqtt_data.config = CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {})) # Remove unknown keys from config entry data _filter_entry_config(hass, entry) # Merge basic configuration, and add missing defaults for basic options - _merge_basic_config(hass, entry, hass.data.get(DATA_MQTT_CONFIG, {})) + _merge_basic_config(hass, entry, mqtt_data.config or {}) # Bail out if broker setting is missing if CONF_BROKER not in entry.data: _LOGGER.error("MQTT broker is not configured, please configure it") @@ -294,7 +283,7 @@ async def async_fetch_config(hass: HomeAssistant, entry: ConfigEntry) -> dict | # If user doesn't have configuration.yaml config, generate default values # for options not in config entry data - if (conf := hass.data.get(DATA_MQTT_CONFIG)) is None: + if (conf := mqtt_data.config) is None: conf = CONFIG_SCHEMA_BASE(dict(entry.data)) # User has configuration.yaml config, warn about config entry overrides @@ -317,21 +306,20 @@ async def async_fetch_config(hass: HomeAssistant, entry: ConfigEntry) -> dict | async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Load a config entry.""" + mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData()) + # Merge basic configuration, and add missing defaults for basic options if (conf := await async_fetch_config(hass, entry)) is None: # Bail out return False - - hass.data[DATA_MQTT_DISCOVERY_REGISTRY_HOOKS] = {} - hass.data[DATA_MQTT] = MQTT(hass, entry, conf) + mqtt_data.client = MQTT(hass, entry, conf) # Restore saved subscriptions - if DATA_MQTT_SUBSCRIPTIONS_TO_RESTORE in hass.data: - hass.data[DATA_MQTT].subscriptions = hass.data.pop( - DATA_MQTT_SUBSCRIPTIONS_TO_RESTORE - ) + if mqtt_data.subscriptions_to_restore: + mqtt_data.client.subscriptions = mqtt_data.subscriptions_to_restore + mqtt_data.subscriptions_to_restore = [] entry.add_update_listener(_async_config_entry_updated) - await hass.data[DATA_MQTT].async_connect() + await mqtt_data.client.async_connect() async def async_publish_service(call: ServiceCall) -> None: """Handle MQTT publish service calls.""" @@ -380,7 +368,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ) return - await hass.data[DATA_MQTT].async_publish(msg_topic, payload, qos, retain) + assert mqtt_data.client is not None and msg_topic is not None + await mqtt_data.client.async_publish(msg_topic, payload, qos, retain) hass.services.async_register( DOMAIN, SERVICE_PUBLISH, async_publish_service, schema=MQTT_PUBLISH_SCHEMA @@ -421,7 +410,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ) # setup platforms and discovery - hass.data[CONFIG_ENTRY_IS_SETUP] = set() async def async_setup_reload_service() -> None: """Create the reload service for the MQTT domain.""" @@ -435,7 +423,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: # Reload the modern yaml platforms config_yaml = await async_integration_yaml_config(hass, DOMAIN) or {} - hass.data[DATA_MQTT_UPDATED_CONFIG] = config_yaml.get(DOMAIN, {}) + mqtt_data.updated_config = config_yaml.get(DOMAIN, {}) await asyncio.gather( *( [ @@ -476,13 +464,13 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: # Setup reload service after all platforms have loaded await async_setup_reload_service() # When the entry is reloaded, also reload manual set up items to enable MQTT - if DATA_MQTT_RELOAD_ENTRY in hass.data: - hass.data.pop(DATA_MQTT_RELOAD_ENTRY) + if mqtt_data.reload_entry: + mqtt_data.reload_entry = False reload_manual_setup = True # When the entry was disabled before, reload manual set up items to enable MQTT again - if DATA_MQTT_RELOAD_NEEDED in hass.data: - hass.data.pop(DATA_MQTT_RELOAD_NEEDED) + if mqtt_data.reload_needed: + mqtt_data.reload_needed = False reload_manual_setup = True if reload_manual_setup: @@ -592,7 +580,9 @@ def async_subscribe_connection_status( def is_connected(hass: HomeAssistant) -> bool: """Return if MQTT client is connected.""" - return hass.data[DATA_MQTT].connected + mqtt_data: MqttData = hass.data[DATA_MQTT] + assert mqtt_data.client is not None + return mqtt_data.client.connected async def async_remove_config_entry_device( @@ -608,6 +598,10 @@ async def async_remove_config_entry_device( async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload MQTT dump and publish service when the config entry is unloaded.""" + mqtt_data: MqttData = hass.data[DATA_MQTT] + assert mqtt_data.client is not None + mqtt_client = mqtt_data.client + # Unload publish and dump services. hass.services.async_remove( DOMAIN, @@ -620,7 +614,6 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: # Stop the discovery await discovery.async_stop(hass) - mqtt_client: MQTT = hass.data[DATA_MQTT] # Unload the platforms await asyncio.gather( *( @@ -630,26 +623,23 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ) await hass.async_block_till_done() # Unsubscribe reload dispatchers - while reload_dispatchers := hass.data.setdefault(DATA_MQTT_RELOAD_DISPATCHERS, []): + while reload_dispatchers := mqtt_data.reload_dispatchers: reload_dispatchers.pop()() - hass.data[CONFIG_ENTRY_IS_SETUP] = set() # Cleanup listeners mqtt_client.cleanup() # Trigger reload manual MQTT items at entry setup if (mqtt_entry_status := mqtt_config_entry_enabled(hass)) is False: # The entry is disabled reload legacy manual items when the entry is enabled again - hass.data[DATA_MQTT_RELOAD_NEEDED] = True + mqtt_data.reload_needed = True elif mqtt_entry_status is True: # The entry is reloaded: # Trigger re-fetching the yaml config at entry setup - hass.data[DATA_MQTT_RELOAD_ENTRY] = True + mqtt_data.reload_entry = True # Reload the legacy yaml platform to make entities unavailable await async_reload_integration_platforms(hass, DOMAIN, RELOADABLE_PLATFORMS) # Cleanup entity registry hooks - registry_hooks: dict[tuple, CALLBACK_TYPE] = hass.data[ - DATA_MQTT_DISCOVERY_REGISTRY_HOOKS - ] + registry_hooks = mqtt_data.discovery_registry_hooks while registry_hooks: registry_hooks.popitem()[1]() # Wait for all ACKs and stop the loop @@ -657,6 +647,6 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: # Store remaining subscriptions to be able to restore or reload them # when the entry is set up again if mqtt_client.subscriptions: - hass.data[DATA_MQTT_SUBSCRIPTIONS_TO_RESTORE] = mqtt_client.subscriptions + mqtt_data.subscriptions_to_restore = mqtt_client.subscriptions return True diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py index 884e589ba05..28887818133 100644 --- a/homeassistant/components/mqtt/client.py +++ b/homeassistant/components/mqtt/client.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio -from collections.abc import Awaitable, Callable, Coroutine, Iterable +from collections.abc import Callable, Coroutine, Iterable from functools import lru_cache, partial, wraps import inspect from itertools import groupby @@ -17,6 +17,7 @@ import attr import certifi from paho.mqtt.client import MQTTMessage +from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( CONF_CLIENT_ID, CONF_PASSWORD, @@ -52,7 +53,6 @@ from .const import ( MQTT_DISCONNECTED, PROTOCOL_31, ) -from .discovery import LAST_DISCOVERY from .models import ( AsyncMessageCallbackType, MessageCallbackType, @@ -68,6 +68,9 @@ if TYPE_CHECKING: # because integrations should be able to optionally rely on MQTT. import paho.mqtt.client as mqtt + from .mixins import MqttData + + _LOGGER = logging.getLogger(__name__) DISCOVERY_COOLDOWN = 2 @@ -97,8 +100,12 @@ async def async_publish( encoding: str | None = DEFAULT_ENCODING, ) -> None: """Publish message to a MQTT topic.""" + # Local import to avoid circular dependencies + # pylint: disable-next=import-outside-toplevel + from .mixins import MqttData - if DATA_MQTT not in hass.data or not mqtt_config_entry_enabled(hass): + mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData()) + if mqtt_data.client is None or not mqtt_config_entry_enabled(hass): raise HomeAssistantError( f"Cannot publish to topic '{topic}', MQTT is not enabled" ) @@ -126,11 +133,13 @@ async def async_publish( ) return - await hass.data[DATA_MQTT].async_publish(topic, outgoing_payload, qos, retain) + await mqtt_data.client.async_publish( + topic, outgoing_payload, qos or 0, retain or False + ) AsyncDeprecatedMessageCallbackType = Callable[ - [str, ReceivePayloadType, int], Awaitable[None] + [str, ReceivePayloadType, int], Coroutine[Any, Any, None] ] DeprecatedMessageCallbackType = Callable[[str, ReceivePayloadType, int], None] @@ -175,13 +184,18 @@ async def async_subscribe( | DeprecatedMessageCallbackType | AsyncDeprecatedMessageCallbackType, qos: int = DEFAULT_QOS, - encoding: str | None = "utf-8", + encoding: str | None = DEFAULT_ENCODING, ): """Subscribe to an MQTT topic. Call the return value to unsubscribe. """ - if DATA_MQTT not in hass.data or not mqtt_config_entry_enabled(hass): + # Local import to avoid circular dependencies + # pylint: disable-next=import-outside-toplevel + from .mixins import MqttData + + mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData()) + if mqtt_data.client is None or not mqtt_config_entry_enabled(hass): raise HomeAssistantError( f"Cannot subscribe to topic '{topic}', MQTT is not enabled" ) @@ -206,7 +220,7 @@ async def async_subscribe( cast(DeprecatedMessageCallbackType, msg_callback) ) - async_remove = await hass.data[DATA_MQTT].async_subscribe( + async_remove = await mqtt_data.client.async_subscribe( topic, catch_log_exception( wrapped_msg_callback, @@ -309,15 +323,17 @@ class MQTT: def __init__( self, - hass, - config_entry, - conf, + hass: HomeAssistant, + config_entry: ConfigEntry, + conf: ConfigType, ) -> None: """Initialize Home Assistant MQTT client.""" # We don't import on the top because some integrations # should be able to optionally rely on MQTT. import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel + self._mqtt_data: MqttData = hass.data[DATA_MQTT] + self.hass = hass self.config_entry = config_entry self.conf = conf @@ -635,7 +651,6 @@ class MQTT: subscription.job, ) continue - self.hass.async_run_hass_job( subscription.job, ReceiveMessage( @@ -695,10 +710,10 @@ class MQTT: async def _discovery_cooldown(self): now = time.time() # Reset discovery and subscribe cooldowns - self.hass.data[LAST_DISCOVERY] = now + self._mqtt_data.last_discovery = now self._last_subscribe = now - last_discovery = self.hass.data[LAST_DISCOVERY] + last_discovery = self._mqtt_data.last_discovery last_subscribe = self._last_subscribe wait_until = max( last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN @@ -706,7 +721,7 @@ class MQTT: while now < wait_until: await asyncio.sleep(wait_until - now) now = time.time() - last_discovery = self.hass.data[LAST_DISCOVERY] + last_discovery = self._mqtt_data.last_discovery last_subscribe = self._last_subscribe wait_until = max( last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN diff --git a/homeassistant/components/mqtt/config_flow.py b/homeassistant/components/mqtt/config_flow.py index 538c12d258c..12d97b41a74 100644 --- a/homeassistant/components/mqtt/config_flow.py +++ b/homeassistant/components/mqtt/config_flow.py @@ -18,7 +18,7 @@ from homeassistant.const import ( CONF_PROTOCOL, CONF_USERNAME, ) -from homeassistant.core import callback +from homeassistant.core import HomeAssistant, callback from homeassistant.data_entry_flow import FlowResult from .client import MqttClientSetup @@ -30,12 +30,13 @@ from .const import ( CONF_BIRTH_MESSAGE, CONF_BROKER, CONF_WILL_MESSAGE, - DATA_MQTT_CONFIG, + DATA_MQTT, DEFAULT_BIRTH, DEFAULT_DISCOVERY, DEFAULT_WILL, DOMAIN, ) +from .mixins import MqttData from .util import MQTT_WILL_BIRTH_SCHEMA MQTT_TIMEOUT = 5 @@ -164,9 +165,10 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow): self, user_input: dict[str, Any] | None = None ) -> FlowResult: """Manage the MQTT broker configuration.""" + mqtt_data: MqttData = self.hass.data.setdefault(DATA_MQTT, MqttData()) errors = {} current_config = self.config_entry.data - yaml_config = self.hass.data.get(DATA_MQTT_CONFIG, {}) + yaml_config = mqtt_data.config or {} if user_input is not None: can_connect = await self.hass.async_add_executor_job( try_connection, @@ -214,9 +216,10 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow): self, user_input: dict[str, Any] | None = None ) -> FlowResult: """Manage the MQTT options.""" + mqtt_data: MqttData = self.hass.data.setdefault(DATA_MQTT, MqttData()) errors = {} current_config = self.config_entry.data - yaml_config = self.hass.data.get(DATA_MQTT_CONFIG, {}) + yaml_config = mqtt_data.config or {} options_config: dict[str, Any] = {} if user_input is not None: bad_birth = False @@ -334,14 +337,22 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow): ) -def try_connection(hass, broker, port, username, password, protocol="3.1"): +def try_connection( + hass: HomeAssistant, + broker: str, + port: int, + username: str | None, + password: str | None, + protocol: str = "3.1", +) -> bool: """Test if we can connect to an MQTT broker.""" # We don't import on the top because some integrations # should be able to optionally rely on MQTT. import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel # Get the config from configuration.yaml - yaml_config = hass.data.get(DATA_MQTT_CONFIG, {}) + mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData()) + yaml_config = mqtt_data.config or {} entry_config = { CONF_BROKER: broker, CONF_PORT: port, @@ -351,7 +362,7 @@ def try_connection(hass, broker, port, username, password, protocol="3.1"): } client = MqttClientSetup({**yaml_config, **entry_config}).client - result = queue.Queue(maxsize=1) + result: queue.Queue[bool] = queue.Queue(maxsize=1) def on_connect(client_, userdata, flags, result_code): """Handle connection result.""" diff --git a/homeassistant/components/mqtt/const.py b/homeassistant/components/mqtt/const.py index c8af58862e0..93410f0c792 100644 --- a/homeassistant/components/mqtt/const.py +++ b/homeassistant/components/mqtt/const.py @@ -30,16 +30,8 @@ CONF_CLIENT_CERT = "client_cert" CONF_TLS_INSECURE = "tls_insecure" CONF_TLS_VERSION = "tls_version" -CONFIG_ENTRY_IS_SETUP = "mqtt_config_entry_is_setup" DATA_MQTT = "mqtt" -DATA_MQTT_SUBSCRIPTIONS_TO_RESTORE = "mqtt_client_subscriptions" -DATA_MQTT_DISCOVERY_REGISTRY_HOOKS = "mqtt_discovery_registry_hooks" -DATA_MQTT_CONFIG = "mqtt_config" MQTT_DATA_DEVICE_TRACKER_LEGACY = "mqtt_device_tracker_legacy" -DATA_MQTT_RELOAD_DISPATCHERS = "mqtt_reload_dispatchers" -DATA_MQTT_RELOAD_ENTRY = "mqtt_reload_entry" -DATA_MQTT_RELOAD_NEEDED = "mqtt_reload_needed" -DATA_MQTT_UPDATED_CONFIG = "mqtt_updated_config" DEFAULT_PREFIX = "homeassistant" DEFAULT_BIRTH_WILL_TOPIC = DEFAULT_PREFIX + "/status" diff --git a/homeassistant/components/mqtt/device_trigger.py b/homeassistant/components/mqtt/device_trigger.py index 30d6fdea05f..7e37ed72821 100644 --- a/homeassistant/components/mqtt/device_trigger.py +++ b/homeassistant/components/mqtt/device_trigger.py @@ -33,11 +33,13 @@ from .const import ( CONF_PAYLOAD, CONF_QOS, CONF_TOPIC, + DATA_MQTT, DOMAIN, ) from .discovery import MQTT_DISCOVERY_DONE from .mixins import ( MQTT_ENTITY_DEVICE_INFO_SCHEMA, + MqttData, MqttDiscoveryDeviceUpdate, send_discovery_done, update_device, @@ -81,8 +83,6 @@ TRIGGER_DISCOVERY_SCHEMA = MQTT_BASE_SCHEMA.extend( extra=vol.REMOVE_EXTRA, ) -DEVICE_TRIGGERS = "mqtt_device_triggers" - LOG_NAME = "Device trigger" @@ -203,6 +203,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate): self.device_id = device_id self.discovery_data = discovery_data self.hass = hass + self._mqtt_data: MqttData = hass.data[DATA_MQTT] MqttDiscoveryDeviceUpdate.__init__( self, @@ -217,8 +218,8 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate): """Initialize the device trigger.""" discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH] discovery_id = discovery_hash[1] - if discovery_id not in self.hass.data.setdefault(DEVICE_TRIGGERS, {}): - self.hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger( + if discovery_id not in self._mqtt_data.device_triggers: + self._mqtt_data.device_triggers[discovery_id] = Trigger( hass=self.hass, device_id=self.device_id, discovery_data=self.discovery_data, @@ -230,7 +231,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate): value_template=self._config[CONF_VALUE_TEMPLATE], ) else: - await self.hass.data[DEVICE_TRIGGERS][discovery_id].update_trigger( + await self._mqtt_data.device_triggers[discovery_id].update_trigger( self._config ) debug_info.add_trigger_discovery_data( @@ -246,16 +247,16 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate): ) config = TRIGGER_DISCOVERY_SCHEMA(discovery_data) update_device(self.hass, self._config_entry, config) - device_trigger: Trigger = self.hass.data[DEVICE_TRIGGERS][discovery_id] + device_trigger: Trigger = self._mqtt_data.device_triggers[discovery_id] await device_trigger.update_trigger(config) async def async_tear_down(self) -> None: """Cleanup device trigger.""" discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH] discovery_id = discovery_hash[1] - if discovery_id in self.hass.data[DEVICE_TRIGGERS]: + if discovery_id in self._mqtt_data.device_triggers: _LOGGER.info("Removing trigger: %s", discovery_hash) - trigger: Trigger = self.hass.data[DEVICE_TRIGGERS][discovery_id] + trigger: Trigger = self._mqtt_data.device_triggers[discovery_id] trigger.detach_trigger() debug_info.remove_trigger_discovery_data(self.hass, discovery_hash) @@ -280,11 +281,10 @@ async def async_setup_trigger( async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None: """Handle Mqtt removed from a device.""" + mqtt_data: MqttData = hass.data[DATA_MQTT] triggers = await async_get_triggers(hass, device_id) for trig in triggers: - device_trigger: Trigger = hass.data[DEVICE_TRIGGERS].pop( - trig[CONF_DISCOVERY_ID] - ) + device_trigger: Trigger = mqtt_data.device_triggers.pop(trig[CONF_DISCOVERY_ID]) if device_trigger: device_trigger.detach_trigger() discovery_data = cast(dict, device_trigger.discovery_data) @@ -296,12 +296,13 @@ async def async_get_triggers( hass: HomeAssistant, device_id: str ) -> list[dict[str, str]]: """List device triggers for MQTT devices.""" + mqtt_data: MqttData = hass.data[DATA_MQTT] triggers: list[dict[str, str]] = [] - if DEVICE_TRIGGERS not in hass.data: + if not mqtt_data.device_triggers: return triggers - for discovery_id, trig in hass.data[DEVICE_TRIGGERS].items(): + for discovery_id, trig in mqtt_data.device_triggers.items(): if trig.device_id != device_id or trig.topic is None: continue @@ -324,12 +325,12 @@ async def async_attach_trigger( trigger_info: TriggerInfo, ) -> CALLBACK_TYPE: """Attach a trigger.""" - hass.data.setdefault(DEVICE_TRIGGERS, {}) + mqtt_data: MqttData = hass.data[DATA_MQTT] device_id = config[CONF_DEVICE_ID] discovery_id = config[CONF_DISCOVERY_ID] - if discovery_id not in hass.data[DEVICE_TRIGGERS]: - hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger( + if discovery_id not in mqtt_data.device_triggers: + mqtt_data.device_triggers[discovery_id] = Trigger( hass=hass, device_id=device_id, discovery_data=None, @@ -340,6 +341,6 @@ async def async_attach_trigger( qos=None, value_template=None, ) - return await hass.data[DEVICE_TRIGGERS][discovery_id].add_trigger( + return await mqtt_data.device_triggers[discovery_id].add_trigger( action, trigger_info ) diff --git a/homeassistant/components/mqtt/diagnostics.py b/homeassistant/components/mqtt/diagnostics.py index ea490783fc0..2a6322cac63 100644 --- a/homeassistant/components/mqtt/diagnostics.py +++ b/homeassistant/components/mqtt/diagnostics.py @@ -43,7 +43,7 @@ def _async_get_diagnostics( device: DeviceEntry | None = None, ) -> dict[str, Any]: """Return diagnostics for a config entry.""" - mqtt_instance: MQTT = hass.data[DATA_MQTT] + mqtt_instance: MQTT = hass.data[DATA_MQTT].client redacted_config = async_redact_data(mqtt_instance.conf, REDACT_CONFIG) diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py index 8a4c4d0c542..65051ce54fc 100644 --- a/homeassistant/components/mqtt/discovery.py +++ b/homeassistant/components/mqtt/discovery.py @@ -7,6 +7,7 @@ import functools import logging import re import time +from typing import TYPE_CHECKING from homeassistant.const import CONF_DEVICE, CONF_PLATFORM from homeassistant.core import HomeAssistant @@ -28,9 +29,13 @@ from .const import ( ATTR_DISCOVERY_TOPIC, CONF_AVAILABILITY, CONF_TOPIC, + DATA_MQTT, DOMAIN, ) +if TYPE_CHECKING: + from .mixins import MqttData + _LOGGER = logging.getLogger(__name__) TOPIC_MATCHER = re.compile( @@ -69,7 +74,6 @@ INTEGRATION_UNSUBSCRIBE = "mqtt_integration_discovery_unsubscribe" MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}" MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}" MQTT_DISCOVERY_DONE = "mqtt_discovery_done_{}" -LAST_DISCOVERY = "mqtt_last_discovery" TOPIC_BASE = "~" @@ -80,12 +84,12 @@ class MQTTConfig(dict): discovery_data: dict -def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple) -> None: +def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) -> None: """Clear entry in ALREADY_DISCOVERED list.""" del hass.data[ALREADY_DISCOVERED][discovery_hash] -def set_discovery_hash(hass: HomeAssistant, discovery_hash: tuple): +def set_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]): """Clear entry in ALREADY_DISCOVERED list.""" hass.data[ALREADY_DISCOVERED][discovery_hash] = {} @@ -94,11 +98,12 @@ async def async_start( # noqa: C901 hass: HomeAssistant, discovery_topic, config_entry=None ) -> None: """Start MQTT Discovery.""" + mqtt_data: MqttData = hass.data[DATA_MQTT] mqtt_integrations = {} async def async_discovery_message_received(msg): """Process the received message.""" - hass.data[LAST_DISCOVERY] = time.time() + mqtt_data.last_discovery = time.time() payload = msg.payload topic = msg.topic topic_trimmed = topic.replace(f"{discovery_topic}/", "", 1) @@ -253,7 +258,7 @@ async def async_start( # noqa: C901 ) ) - hass.data[LAST_DISCOVERY] = time.time() + mqtt_data.last_discovery = time.time() mqtt_integrations = await async_get_mqtt(hass) hass.data[INTEGRATION_UNSUBSCRIBE] = {} diff --git a/homeassistant/components/mqtt/mixins.py b/homeassistant/components/mqtt/mixins.py index fddbe838303..a16394667d8 100644 --- a/homeassistant/components/mqtt/mixins.py +++ b/homeassistant/components/mqtt/mixins.py @@ -4,9 +4,10 @@ from __future__ import annotations from abc import abstractmethod import asyncio from collections.abc import Callable, Coroutine +from dataclasses import dataclass, field from functools import partial import logging -from typing import Any, Protocol, cast, final +from typing import TYPE_CHECKING, Any, Protocol, cast, final import voluptuous as vol @@ -60,7 +61,7 @@ from homeassistant.helpers.json import json_loads from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from . import debug_info, subscription -from .client import async_publish +from .client import MQTT, Subscription, async_publish from .const import ( ATTR_DISCOVERY_HASH, ATTR_DISCOVERY_PAYLOAD, @@ -70,11 +71,6 @@ from .const import ( CONF_QOS, CONF_TOPIC, DATA_MQTT, - DATA_MQTT_CONFIG, - DATA_MQTT_DISCOVERY_REGISTRY_HOOKS, - DATA_MQTT_RELOAD_DISPATCHERS, - DATA_MQTT_RELOAD_ENTRY, - DATA_MQTT_UPDATED_CONFIG, DEFAULT_ENCODING, DEFAULT_PAYLOAD_AVAILABLE, DEFAULT_PAYLOAD_NOT_AVAILABLE, @@ -98,6 +94,9 @@ from .subscription import ( ) from .util import mqtt_config_entry_enabled, valid_subscribe_topic +if TYPE_CHECKING: + from .device_trigger import Trigger + _LOGGER = logging.getLogger(__name__) AVAILABILITY_ALL = "all" @@ -274,6 +273,24 @@ def warn_for_legacy_schema(domain: str) -> Callable: return validator +@dataclass +class MqttData: + """Keep the MQTT entry data.""" + + client: MQTT | None = None + config: ConfigType | None = None + device_triggers: dict[str, Trigger] = field(default_factory=dict) + discovery_registry_hooks: dict[tuple[str, str], CALLBACK_TYPE] = field( + default_factory=dict + ) + last_discovery: float = 0.0 + reload_dispatchers: list[CALLBACK_TYPE] = field(default_factory=list) + reload_entry: bool = False + reload_needed: bool = False + subscriptions_to_restore: list[Subscription] = field(default_factory=list) + updated_config: ConfigType = field(default_factory=dict) + + class SetupEntity(Protocol): """Protocol type for async_setup_entities.""" @@ -292,11 +309,12 @@ async def async_discover_yaml_entities( hass: HomeAssistant, platform_domain: str ) -> None: """Discover entities for a platform.""" - if DATA_MQTT_UPDATED_CONFIG in hass.data: + mqtt_data: MqttData = hass.data[DATA_MQTT] + if mqtt_data.updated_config: # The platform has been reloaded - config_yaml = hass.data[DATA_MQTT_UPDATED_CONFIG] + config_yaml = mqtt_data.updated_config else: - config_yaml = hass.data.get(DATA_MQTT_CONFIG, {}) + config_yaml = mqtt_data.config or {} if not config_yaml: return if platform_domain not in config_yaml: @@ -318,8 +336,9 @@ async def async_get_platform_config_from_yaml( ) -> list[ConfigType]: """Return a list of validated configurations for the domain.""" + mqtt_data: MqttData = hass.data[DATA_MQTT] if config_yaml is None: - config_yaml = hass.data.get(DATA_MQTT_CONFIG) + config_yaml = mqtt_data.config if not config_yaml: return [] if not (platform_configs := config_yaml.get(platform_domain)): @@ -334,6 +353,7 @@ async def async_setup_entry_helper( schema: vol.Schema, ) -> None: """Set up entity, automation or tag creation dynamically through MQTT discovery.""" + mqtt_data: MqttData = hass.data[DATA_MQTT] async def async_discover(discovery_payload): """Discover and add an MQTT entity, automation or tag.""" @@ -357,7 +377,7 @@ async def async_setup_entry_helper( ) raise - hass.data.setdefault(DATA_MQTT_RELOAD_DISPATCHERS, []).append( + mqtt_data.reload_dispatchers.append( async_dispatcher_connect( hass, MQTT_DISCOVERY_NEW.format(domain, "mqtt"), async_discover ) @@ -372,7 +392,8 @@ async def async_setup_platform_helper( async_setup_entities: SetupEntity, ) -> None: """Help to set up the platform for manual configured MQTT entities.""" - if DATA_MQTT_RELOAD_ENTRY in hass.data: + mqtt_data: MqttData = hass.data[DATA_MQTT] + if mqtt_data.reload_entry: _LOGGER.debug( "MQTT integration is %s, skipping setup of manually configured MQTT items while unloading the config entry", platform_domain, @@ -597,7 +618,10 @@ class MqttAvailability(Entity): @property def available(self) -> bool: """Return if the device is available.""" - if not self.hass.data[DATA_MQTT].connected and not self.hass.is_stopping: + mqtt_data: MqttData = self.hass.data[DATA_MQTT] + assert mqtt_data.client is not None + client = mqtt_data.client + if not client.connected and not self.hass.is_stopping: return False if not self._avail_topics: return True @@ -632,7 +656,7 @@ async def cleanup_device_registry( ) -def get_discovery_hash(discovery_data: dict) -> tuple: +def get_discovery_hash(discovery_data: dict) -> tuple[str, str]: """Get the discovery hash from the discovery data.""" return discovery_data[ATTR_DISCOVERY_HASH] @@ -817,9 +841,8 @@ class MqttDiscoveryUpdate(Entity): self._removed_from_hass = False if discovery_data is None: return - self._registry_hooks: dict[tuple, CALLBACK_TYPE] = hass.data[ - DATA_MQTT_DISCOVERY_REGISTRY_HOOKS - ] + mqtt_data: MqttData = hass.data[DATA_MQTT] + self._registry_hooks = mqtt_data.discovery_registry_hooks discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH] if discovery_hash in self._registry_hooks: self._registry_hooks.pop(discovery_hash)() @@ -897,7 +920,7 @@ class MqttDiscoveryUpdate(Entity): def add_to_platform_abort(self) -> None: """Abort adding an entity to a platform.""" if self._discovery_data is not None: - discovery_hash: tuple = self._discovery_data[ATTR_DISCOVERY_HASH] + discovery_hash: tuple[str, str] = self._discovery_data[ATTR_DISCOVERY_HASH] if self.registry_entry is not None: self._registry_hooks[ discovery_hash diff --git a/tests/common.py b/tests/common.py index 232701bd746..cc2bc454810 100644 --- a/tests/common.py +++ b/tests/common.py @@ -369,7 +369,7 @@ def async_fire_mqtt_message(hass, topic, payload, qos=0, retain=False): if isinstance(payload, str): payload = payload.encode("utf-8") msg = ReceiveMessage(topic, payload, qos, retain) - hass.data["mqtt"]._mqtt_handle_message(msg) + hass.data["mqtt"].client._mqtt_handle_message(msg) fire_mqtt_message = threadsafe_callback_factory(async_fire_mqtt_message) diff --git a/tests/components/mqtt/test_config_flow.py b/tests/components/mqtt/test_config_flow.py index e40397fd1d4..dba06e5cd5b 100644 --- a/tests/components/mqtt/test_config_flow.py +++ b/tests/components/mqtt/test_config_flow.py @@ -155,7 +155,7 @@ async def test_manual_config_set( assert await async_setup_component(hass, "mqtt", {"mqtt": {"broker": "bla"}}) await hass.async_block_till_done() # do not try to reload - del hass.data["mqtt_reload_needed"] + hass.data["mqtt"].reload_needed = False assert len(mock_finish_setup.mock_calls) == 0 mock_try_connection.return_value = True diff --git a/tests/components/mqtt/test_discovery.py b/tests/components/mqtt/test_discovery.py index c625d0a21f9..a9ac66f8851 100644 --- a/tests/components/mqtt/test_discovery.py +++ b/tests/components/mqtt/test_discovery.py @@ -1438,7 +1438,7 @@ async def test_clean_up_registry_monitoring( ): """Test registry monitoring hook is removed after a reload.""" await mqtt_mock_entry_no_yaml_config() - hooks: dict = hass.data[mqtt.const.DATA_MQTT_DISCOVERY_REGISTRY_HOOKS] + hooks: dict = hass.data["mqtt"].discovery_registry_hooks # discover an entity that is not enabled by default config1 = { "name": "sbfspot_12345", diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index b76979cc990..46649bf703f 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -1776,14 +1776,14 @@ async def test_delayed_birth_message( await hass.async_block_till_done() mqtt_component_mock = MagicMock( - return_value=hass.data["mqtt"], - spec_set=hass.data["mqtt"], - wraps=hass.data["mqtt"], + return_value=hass.data["mqtt"].client, + spec_set=hass.data["mqtt"].client, + wraps=hass.data["mqtt"].client, ) mqtt_component_mock._mqttc = mqtt_client_mock - hass.data["mqtt"] = mqtt_component_mock - mqtt_mock = hass.data["mqtt"] + hass.data["mqtt"].client = mqtt_component_mock + mqtt_mock = hass.data["mqtt"].client mqtt_mock.reset_mock() async def wait_birth(topic, payload, qos):