Make hass.data["mqtt"] an instance of a DataClass (#77972)

* Use dataclass to reference hass.data globals

* Add discovery_registry_hooks to dataclass

* Move discovery registry hooks to dataclass

* Add device triggers to dataclass

* Cleanup DEVICE_TRIGGERS const

* Add last_discovery to data_class

* Simplify typing for class `Subscription`

* Follow up on comment

* Redo suggested typing change to sasisfy mypy

* Restore typing

* Add mypy version to CI check logging

* revert changes to ci.yaml

* Add docstr for protocol

Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>

* Mypy update after merging #78399

* Remove mypy ignore

* Correct return type

Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
This commit is contained in:
Jan Bouwhuis 2022-09-17 21:43:42 +02:00 committed by GitHub
parent 391d895426
commit 1f410e884a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 174 additions and 137 deletions

View file

@ -20,13 +20,7 @@ from homeassistant.const import (
CONF_USERNAME, CONF_USERNAME,
SERVICE_RELOAD, SERVICE_RELOAD,
) )
from homeassistant.core import ( from homeassistant.core import HassJob, HomeAssistant, ServiceCall, callback
CALLBACK_TYPE,
HassJob,
HomeAssistant,
ServiceCall,
callback,
)
from homeassistant.exceptions import TemplateError, Unauthorized from homeassistant.exceptions import TemplateError, Unauthorized
from homeassistant.helpers import ( from homeassistant.helpers import (
config_validation as cv, config_validation as cv,
@ -71,15 +65,7 @@ from .const import ( # noqa: F401
CONF_TLS_VERSION, CONF_TLS_VERSION,
CONF_TOPIC, CONF_TOPIC,
CONF_WILL_MESSAGE, CONF_WILL_MESSAGE,
CONFIG_ENTRY_IS_SETUP,
DATA_MQTT, DATA_MQTT,
DATA_MQTT_CONFIG,
DATA_MQTT_DISCOVERY_REGISTRY_HOOKS,
DATA_MQTT_RELOAD_DISPATCHERS,
DATA_MQTT_RELOAD_ENTRY,
DATA_MQTT_RELOAD_NEEDED,
DATA_MQTT_SUBSCRIPTIONS_TO_RESTORE,
DATA_MQTT_UPDATED_CONFIG,
DEFAULT_ENCODING, DEFAULT_ENCODING,
DEFAULT_QOS, DEFAULT_QOS,
DEFAULT_RETAIN, DEFAULT_RETAIN,
@ -89,7 +75,7 @@ from .const import ( # noqa: F401
PLATFORMS, PLATFORMS,
RELOADABLE_PLATFORMS, RELOADABLE_PLATFORMS,
) )
from .mixins import async_discover_yaml_entities from .mixins import MqttData, async_discover_yaml_entities
from .models import ( # noqa: F401 from .models import ( # noqa: F401
MqttCommandTemplate, MqttCommandTemplate,
MqttValueTemplate, MqttValueTemplate,
@ -177,6 +163,8 @@ async def _async_setup_discovery(
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Start the MQTT protocol service.""" """Start the MQTT protocol service."""
mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData())
conf: ConfigType | None = config.get(DOMAIN) conf: ConfigType | None = config.get(DOMAIN)
websocket_api.async_register_command(hass, websocket_subscribe) websocket_api.async_register_command(hass, websocket_subscribe)
@ -185,7 +173,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
if conf: if conf:
conf = dict(conf) conf = dict(conf)
hass.data[DATA_MQTT_CONFIG] = conf mqtt_data.config = conf
if (mqtt_entry_status := mqtt_config_entry_enabled(hass)) is None: if (mqtt_entry_status := mqtt_config_entry_enabled(hass)) is None:
# Create an import flow if the user has yaml configured entities etc. # Create an import flow if the user has yaml configured entities etc.
@ -197,12 +185,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY}, context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY},
data={}, data={},
) )
hass.data[DATA_MQTT_RELOAD_NEEDED] = True mqtt_data.reload_needed = True
elif mqtt_entry_status is False: elif mqtt_entry_status is False:
_LOGGER.info( _LOGGER.info(
"MQTT will be not available until the config entry is enabled", "MQTT will be not available until the config entry is enabled",
) )
hass.data[DATA_MQTT_RELOAD_NEEDED] = True mqtt_data.reload_needed = True
return True return True
@ -260,33 +248,34 @@ async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -
Causes for this is config entry options changing. Causes for this is config entry options changing.
""" """
mqtt_client = hass.data[DATA_MQTT] mqtt_data: MqttData = hass.data[DATA_MQTT]
assert (client := mqtt_data.client) is not None
if (conf := hass.data.get(DATA_MQTT_CONFIG)) is None: if (conf := mqtt_data.config) is None:
conf = CONFIG_SCHEMA_BASE(dict(entry.data)) conf = CONFIG_SCHEMA_BASE(dict(entry.data))
mqtt_client.conf = _merge_extended_config(entry, conf) mqtt_data.config = _merge_extended_config(entry, conf)
await mqtt_client.async_disconnect() await client.async_disconnect()
mqtt_client.init_client() client.init_client()
await mqtt_client.async_connect() await client.async_connect()
await discovery.async_stop(hass) await discovery.async_stop(hass)
if mqtt_client.conf.get(CONF_DISCOVERY): if client.conf.get(CONF_DISCOVERY):
await _async_setup_discovery(hass, mqtt_client.conf, entry) await _async_setup_discovery(hass, cast(ConfigType, mqtt_data.config), entry)
async def async_fetch_config(hass: HomeAssistant, entry: ConfigEntry) -> dict | None: async def async_fetch_config(hass: HomeAssistant, entry: ConfigEntry) -> dict | None:
"""Fetch fresh MQTT yaml config from the hass config when (re)loading the entry.""" """Fetch fresh MQTT yaml config from the hass config when (re)loading the entry."""
if DATA_MQTT_RELOAD_ENTRY in hass.data: mqtt_data: MqttData = hass.data[DATA_MQTT]
if mqtt_data.reload_entry:
hass_config = await conf_util.async_hass_config_yaml(hass) hass_config = await conf_util.async_hass_config_yaml(hass)
mqtt_config = CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {})) mqtt_data.config = CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {}))
hass.data[DATA_MQTT_CONFIG] = mqtt_config
# Remove unknown keys from config entry data # Remove unknown keys from config entry data
_filter_entry_config(hass, entry) _filter_entry_config(hass, entry)
# Merge basic configuration, and add missing defaults for basic options # Merge basic configuration, and add missing defaults for basic options
_merge_basic_config(hass, entry, hass.data.get(DATA_MQTT_CONFIG, {})) _merge_basic_config(hass, entry, mqtt_data.config or {})
# Bail out if broker setting is missing # Bail out if broker setting is missing
if CONF_BROKER not in entry.data: if CONF_BROKER not in entry.data:
_LOGGER.error("MQTT broker is not configured, please configure it") _LOGGER.error("MQTT broker is not configured, please configure it")
@ -294,7 +283,7 @@ async def async_fetch_config(hass: HomeAssistant, entry: ConfigEntry) -> dict |
# If user doesn't have configuration.yaml config, generate default values # If user doesn't have configuration.yaml config, generate default values
# for options not in config entry data # for options not in config entry data
if (conf := hass.data.get(DATA_MQTT_CONFIG)) is None: if (conf := mqtt_data.config) is None:
conf = CONFIG_SCHEMA_BASE(dict(entry.data)) conf = CONFIG_SCHEMA_BASE(dict(entry.data))
# User has configuration.yaml config, warn about config entry overrides # User has configuration.yaml config, warn about config entry overrides
@ -317,21 +306,20 @@ async def async_fetch_config(hass: HomeAssistant, entry: ConfigEntry) -> dict |
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Load a config entry.""" """Load a config entry."""
mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData())
# Merge basic configuration, and add missing defaults for basic options # Merge basic configuration, and add missing defaults for basic options
if (conf := await async_fetch_config(hass, entry)) is None: if (conf := await async_fetch_config(hass, entry)) is None:
# Bail out # Bail out
return False return False
mqtt_data.client = MQTT(hass, entry, conf)
hass.data[DATA_MQTT_DISCOVERY_REGISTRY_HOOKS] = {}
hass.data[DATA_MQTT] = MQTT(hass, entry, conf)
# Restore saved subscriptions # Restore saved subscriptions
if DATA_MQTT_SUBSCRIPTIONS_TO_RESTORE in hass.data: if mqtt_data.subscriptions_to_restore:
hass.data[DATA_MQTT].subscriptions = hass.data.pop( mqtt_data.client.subscriptions = mqtt_data.subscriptions_to_restore
DATA_MQTT_SUBSCRIPTIONS_TO_RESTORE mqtt_data.subscriptions_to_restore = []
)
entry.add_update_listener(_async_config_entry_updated) entry.add_update_listener(_async_config_entry_updated)
await hass.data[DATA_MQTT].async_connect() await mqtt_data.client.async_connect()
async def async_publish_service(call: ServiceCall) -> None: async def async_publish_service(call: ServiceCall) -> None:
"""Handle MQTT publish service calls.""" """Handle MQTT publish service calls."""
@ -380,7 +368,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
) )
return return
await hass.data[DATA_MQTT].async_publish(msg_topic, payload, qos, retain) assert mqtt_data.client is not None and msg_topic is not None
await mqtt_data.client.async_publish(msg_topic, payload, qos, retain)
hass.services.async_register( hass.services.async_register(
DOMAIN, SERVICE_PUBLISH, async_publish_service, schema=MQTT_PUBLISH_SCHEMA DOMAIN, SERVICE_PUBLISH, async_publish_service, schema=MQTT_PUBLISH_SCHEMA
@ -421,7 +410,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
) )
# setup platforms and discovery # setup platforms and discovery
hass.data[CONFIG_ENTRY_IS_SETUP] = set()
async def async_setup_reload_service() -> None: async def async_setup_reload_service() -> None:
"""Create the reload service for the MQTT domain.""" """Create the reload service for the MQTT domain."""
@ -435,7 +423,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
# Reload the modern yaml platforms # Reload the modern yaml platforms
config_yaml = await async_integration_yaml_config(hass, DOMAIN) or {} config_yaml = await async_integration_yaml_config(hass, DOMAIN) or {}
hass.data[DATA_MQTT_UPDATED_CONFIG] = config_yaml.get(DOMAIN, {}) mqtt_data.updated_config = config_yaml.get(DOMAIN, {})
await asyncio.gather( await asyncio.gather(
*( *(
[ [
@ -476,13 +464,13 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
# Setup reload service after all platforms have loaded # Setup reload service after all platforms have loaded
await async_setup_reload_service() await async_setup_reload_service()
# When the entry is reloaded, also reload manual set up items to enable MQTT # When the entry is reloaded, also reload manual set up items to enable MQTT
if DATA_MQTT_RELOAD_ENTRY in hass.data: if mqtt_data.reload_entry:
hass.data.pop(DATA_MQTT_RELOAD_ENTRY) mqtt_data.reload_entry = False
reload_manual_setup = True reload_manual_setup = True
# When the entry was disabled before, reload manual set up items to enable MQTT again # When the entry was disabled before, reload manual set up items to enable MQTT again
if DATA_MQTT_RELOAD_NEEDED in hass.data: if mqtt_data.reload_needed:
hass.data.pop(DATA_MQTT_RELOAD_NEEDED) mqtt_data.reload_needed = False
reload_manual_setup = True reload_manual_setup = True
if reload_manual_setup: if reload_manual_setup:
@ -592,7 +580,9 @@ def async_subscribe_connection_status(
def is_connected(hass: HomeAssistant) -> bool: def is_connected(hass: HomeAssistant) -> bool:
"""Return if MQTT client is connected.""" """Return if MQTT client is connected."""
return hass.data[DATA_MQTT].connected mqtt_data: MqttData = hass.data[DATA_MQTT]
assert mqtt_data.client is not None
return mqtt_data.client.connected
async def async_remove_config_entry_device( async def async_remove_config_entry_device(
@ -608,6 +598,10 @@ async def async_remove_config_entry_device(
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload MQTT dump and publish service when the config entry is unloaded.""" """Unload MQTT dump and publish service when the config entry is unloaded."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
assert mqtt_data.client is not None
mqtt_client = mqtt_data.client
# Unload publish and dump services. # Unload publish and dump services.
hass.services.async_remove( hass.services.async_remove(
DOMAIN, DOMAIN,
@ -620,7 +614,6 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
# Stop the discovery # Stop the discovery
await discovery.async_stop(hass) await discovery.async_stop(hass)
mqtt_client: MQTT = hass.data[DATA_MQTT]
# Unload the platforms # Unload the platforms
await asyncio.gather( await asyncio.gather(
*( *(
@ -630,26 +623,23 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
) )
await hass.async_block_till_done() await hass.async_block_till_done()
# Unsubscribe reload dispatchers # Unsubscribe reload dispatchers
while reload_dispatchers := hass.data.setdefault(DATA_MQTT_RELOAD_DISPATCHERS, []): while reload_dispatchers := mqtt_data.reload_dispatchers:
reload_dispatchers.pop()() reload_dispatchers.pop()()
hass.data[CONFIG_ENTRY_IS_SETUP] = set()
# Cleanup listeners # Cleanup listeners
mqtt_client.cleanup() mqtt_client.cleanup()
# Trigger reload manual MQTT items at entry setup # Trigger reload manual MQTT items at entry setup
if (mqtt_entry_status := mqtt_config_entry_enabled(hass)) is False: if (mqtt_entry_status := mqtt_config_entry_enabled(hass)) is False:
# The entry is disabled reload legacy manual items when the entry is enabled again # The entry is disabled reload legacy manual items when the entry is enabled again
hass.data[DATA_MQTT_RELOAD_NEEDED] = True mqtt_data.reload_needed = True
elif mqtt_entry_status is True: elif mqtt_entry_status is True:
# The entry is reloaded: # The entry is reloaded:
# Trigger re-fetching the yaml config at entry setup # Trigger re-fetching the yaml config at entry setup
hass.data[DATA_MQTT_RELOAD_ENTRY] = True mqtt_data.reload_entry = True
# Reload the legacy yaml platform to make entities unavailable # Reload the legacy yaml platform to make entities unavailable
await async_reload_integration_platforms(hass, DOMAIN, RELOADABLE_PLATFORMS) await async_reload_integration_platforms(hass, DOMAIN, RELOADABLE_PLATFORMS)
# Cleanup entity registry hooks # Cleanup entity registry hooks
registry_hooks: dict[tuple, CALLBACK_TYPE] = hass.data[ registry_hooks = mqtt_data.discovery_registry_hooks
DATA_MQTT_DISCOVERY_REGISTRY_HOOKS
]
while registry_hooks: while registry_hooks:
registry_hooks.popitem()[1]() registry_hooks.popitem()[1]()
# Wait for all ACKs and stop the loop # Wait for all ACKs and stop the loop
@ -657,6 +647,6 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
# Store remaining subscriptions to be able to restore or reload them # Store remaining subscriptions to be able to restore or reload them
# when the entry is set up again # when the entry is set up again
if mqtt_client.subscriptions: if mqtt_client.subscriptions:
hass.data[DATA_MQTT_SUBSCRIPTIONS_TO_RESTORE] = mqtt_client.subscriptions mqtt_data.subscriptions_to_restore = mqtt_client.subscriptions
return True return True

View file

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Awaitable, Callable, Coroutine, Iterable from collections.abc import Callable, Coroutine, Iterable
from functools import lru_cache, partial, wraps from functools import lru_cache, partial, wraps
import inspect import inspect
from itertools import groupby from itertools import groupby
@ -17,6 +17,7 @@ import attr
import certifi import certifi
from paho.mqtt.client import MQTTMessage from paho.mqtt.client import MQTTMessage
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( from homeassistant.const import (
CONF_CLIENT_ID, CONF_CLIENT_ID,
CONF_PASSWORD, CONF_PASSWORD,
@ -52,7 +53,6 @@ from .const import (
MQTT_DISCONNECTED, MQTT_DISCONNECTED,
PROTOCOL_31, PROTOCOL_31,
) )
from .discovery import LAST_DISCOVERY
from .models import ( from .models import (
AsyncMessageCallbackType, AsyncMessageCallbackType,
MessageCallbackType, MessageCallbackType,
@ -68,6 +68,9 @@ if TYPE_CHECKING:
# because integrations should be able to optionally rely on MQTT. # because integrations should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
from .mixins import MqttData
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DISCOVERY_COOLDOWN = 2 DISCOVERY_COOLDOWN = 2
@ -97,8 +100,12 @@ async def async_publish(
encoding: str | None = DEFAULT_ENCODING, encoding: str | None = DEFAULT_ENCODING,
) -> None: ) -> None:
"""Publish message to a MQTT topic.""" """Publish message to a MQTT topic."""
# Local import to avoid circular dependencies
# pylint: disable-next=import-outside-toplevel
from .mixins import MqttData
if DATA_MQTT not in hass.data or not mqtt_config_entry_enabled(hass): mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData())
if mqtt_data.client is None or not mqtt_config_entry_enabled(hass):
raise HomeAssistantError( raise HomeAssistantError(
f"Cannot publish to topic '{topic}', MQTT is not enabled" f"Cannot publish to topic '{topic}', MQTT is not enabled"
) )
@ -126,11 +133,13 @@ async def async_publish(
) )
return return
await hass.data[DATA_MQTT].async_publish(topic, outgoing_payload, qos, retain) await mqtt_data.client.async_publish(
topic, outgoing_payload, qos or 0, retain or False
)
AsyncDeprecatedMessageCallbackType = Callable[ AsyncDeprecatedMessageCallbackType = Callable[
[str, ReceivePayloadType, int], Awaitable[None] [str, ReceivePayloadType, int], Coroutine[Any, Any, None]
] ]
DeprecatedMessageCallbackType = Callable[[str, ReceivePayloadType, int], None] DeprecatedMessageCallbackType = Callable[[str, ReceivePayloadType, int], None]
@ -175,13 +184,18 @@ async def async_subscribe(
| DeprecatedMessageCallbackType | DeprecatedMessageCallbackType
| AsyncDeprecatedMessageCallbackType, | AsyncDeprecatedMessageCallbackType,
qos: int = DEFAULT_QOS, qos: int = DEFAULT_QOS,
encoding: str | None = "utf-8", encoding: str | None = DEFAULT_ENCODING,
): ):
"""Subscribe to an MQTT topic. """Subscribe to an MQTT topic.
Call the return value to unsubscribe. Call the return value to unsubscribe.
""" """
if DATA_MQTT not in hass.data or not mqtt_config_entry_enabled(hass): # Local import to avoid circular dependencies
# pylint: disable-next=import-outside-toplevel
from .mixins import MqttData
mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData())
if mqtt_data.client is None or not mqtt_config_entry_enabled(hass):
raise HomeAssistantError( raise HomeAssistantError(
f"Cannot subscribe to topic '{topic}', MQTT is not enabled" f"Cannot subscribe to topic '{topic}', MQTT is not enabled"
) )
@ -206,7 +220,7 @@ async def async_subscribe(
cast(DeprecatedMessageCallbackType, msg_callback) cast(DeprecatedMessageCallbackType, msg_callback)
) )
async_remove = await hass.data[DATA_MQTT].async_subscribe( async_remove = await mqtt_data.client.async_subscribe(
topic, topic,
catch_log_exception( catch_log_exception(
wrapped_msg_callback, wrapped_msg_callback,
@ -309,15 +323,17 @@ class MQTT:
def __init__( def __init__(
self, self,
hass, hass: HomeAssistant,
config_entry, config_entry: ConfigEntry,
conf, conf: ConfigType,
) -> None: ) -> None:
"""Initialize Home Assistant MQTT client.""" """Initialize Home Assistant MQTT client."""
# We don't import on the top because some integrations # We don't import on the top because some integrations
# should be able to optionally rely on MQTT. # should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
self._mqtt_data: MqttData = hass.data[DATA_MQTT]
self.hass = hass self.hass = hass
self.config_entry = config_entry self.config_entry = config_entry
self.conf = conf self.conf = conf
@ -635,7 +651,6 @@ class MQTT:
subscription.job, subscription.job,
) )
continue continue
self.hass.async_run_hass_job( self.hass.async_run_hass_job(
subscription.job, subscription.job,
ReceiveMessage( ReceiveMessage(
@ -695,10 +710,10 @@ class MQTT:
async def _discovery_cooldown(self): async def _discovery_cooldown(self):
now = time.time() now = time.time()
# Reset discovery and subscribe cooldowns # Reset discovery and subscribe cooldowns
self.hass.data[LAST_DISCOVERY] = now self._mqtt_data.last_discovery = now
self._last_subscribe = now self._last_subscribe = now
last_discovery = self.hass.data[LAST_DISCOVERY] last_discovery = self._mqtt_data.last_discovery
last_subscribe = self._last_subscribe last_subscribe = self._last_subscribe
wait_until = max( wait_until = max(
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN
@ -706,7 +721,7 @@ class MQTT:
while now < wait_until: while now < wait_until:
await asyncio.sleep(wait_until - now) await asyncio.sleep(wait_until - now)
now = time.time() now = time.time()
last_discovery = self.hass.data[LAST_DISCOVERY] last_discovery = self._mqtt_data.last_discovery
last_subscribe = self._last_subscribe last_subscribe = self._last_subscribe
wait_until = max( wait_until = max(
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN

View file

@ -18,7 +18,7 @@ from homeassistant.const import (
CONF_PROTOCOL, CONF_PROTOCOL,
CONF_USERNAME, CONF_USERNAME,
) )
from homeassistant.core import callback from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from .client import MqttClientSetup from .client import MqttClientSetup
@ -30,12 +30,13 @@ from .const import (
CONF_BIRTH_MESSAGE, CONF_BIRTH_MESSAGE,
CONF_BROKER, CONF_BROKER,
CONF_WILL_MESSAGE, CONF_WILL_MESSAGE,
DATA_MQTT_CONFIG, DATA_MQTT,
DEFAULT_BIRTH, DEFAULT_BIRTH,
DEFAULT_DISCOVERY, DEFAULT_DISCOVERY,
DEFAULT_WILL, DEFAULT_WILL,
DOMAIN, DOMAIN,
) )
from .mixins import MqttData
from .util import MQTT_WILL_BIRTH_SCHEMA from .util import MQTT_WILL_BIRTH_SCHEMA
MQTT_TIMEOUT = 5 MQTT_TIMEOUT = 5
@ -164,9 +165,10 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> FlowResult:
"""Manage the MQTT broker configuration.""" """Manage the MQTT broker configuration."""
mqtt_data: MqttData = self.hass.data.setdefault(DATA_MQTT, MqttData())
errors = {} errors = {}
current_config = self.config_entry.data current_config = self.config_entry.data
yaml_config = self.hass.data.get(DATA_MQTT_CONFIG, {}) yaml_config = mqtt_data.config or {}
if user_input is not None: if user_input is not None:
can_connect = await self.hass.async_add_executor_job( can_connect = await self.hass.async_add_executor_job(
try_connection, try_connection,
@ -214,9 +216,10 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> FlowResult:
"""Manage the MQTT options.""" """Manage the MQTT options."""
mqtt_data: MqttData = self.hass.data.setdefault(DATA_MQTT, MqttData())
errors = {} errors = {}
current_config = self.config_entry.data current_config = self.config_entry.data
yaml_config = self.hass.data.get(DATA_MQTT_CONFIG, {}) yaml_config = mqtt_data.config or {}
options_config: dict[str, Any] = {} options_config: dict[str, Any] = {}
if user_input is not None: if user_input is not None:
bad_birth = False bad_birth = False
@ -334,14 +337,22 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
) )
def try_connection(hass, broker, port, username, password, protocol="3.1"): def try_connection(
hass: HomeAssistant,
broker: str,
port: int,
username: str | None,
password: str | None,
protocol: str = "3.1",
) -> bool:
"""Test if we can connect to an MQTT broker.""" """Test if we can connect to an MQTT broker."""
# We don't import on the top because some integrations # We don't import on the top because some integrations
# should be able to optionally rely on MQTT. # should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
# Get the config from configuration.yaml # Get the config from configuration.yaml
yaml_config = hass.data.get(DATA_MQTT_CONFIG, {}) mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData())
yaml_config = mqtt_data.config or {}
entry_config = { entry_config = {
CONF_BROKER: broker, CONF_BROKER: broker,
CONF_PORT: port, CONF_PORT: port,
@ -351,7 +362,7 @@ def try_connection(hass, broker, port, username, password, protocol="3.1"):
} }
client = MqttClientSetup({**yaml_config, **entry_config}).client client = MqttClientSetup({**yaml_config, **entry_config}).client
result = queue.Queue(maxsize=1) result: queue.Queue[bool] = queue.Queue(maxsize=1)
def on_connect(client_, userdata, flags, result_code): def on_connect(client_, userdata, flags, result_code):
"""Handle connection result.""" """Handle connection result."""

View file

@ -30,16 +30,8 @@ CONF_CLIENT_CERT = "client_cert"
CONF_TLS_INSECURE = "tls_insecure" CONF_TLS_INSECURE = "tls_insecure"
CONF_TLS_VERSION = "tls_version" CONF_TLS_VERSION = "tls_version"
CONFIG_ENTRY_IS_SETUP = "mqtt_config_entry_is_setup"
DATA_MQTT = "mqtt" DATA_MQTT = "mqtt"
DATA_MQTT_SUBSCRIPTIONS_TO_RESTORE = "mqtt_client_subscriptions"
DATA_MQTT_DISCOVERY_REGISTRY_HOOKS = "mqtt_discovery_registry_hooks"
DATA_MQTT_CONFIG = "mqtt_config"
MQTT_DATA_DEVICE_TRACKER_LEGACY = "mqtt_device_tracker_legacy" MQTT_DATA_DEVICE_TRACKER_LEGACY = "mqtt_device_tracker_legacy"
DATA_MQTT_RELOAD_DISPATCHERS = "mqtt_reload_dispatchers"
DATA_MQTT_RELOAD_ENTRY = "mqtt_reload_entry"
DATA_MQTT_RELOAD_NEEDED = "mqtt_reload_needed"
DATA_MQTT_UPDATED_CONFIG = "mqtt_updated_config"
DEFAULT_PREFIX = "homeassistant" DEFAULT_PREFIX = "homeassistant"
DEFAULT_BIRTH_WILL_TOPIC = DEFAULT_PREFIX + "/status" DEFAULT_BIRTH_WILL_TOPIC = DEFAULT_PREFIX + "/status"

View file

@ -33,11 +33,13 @@ from .const import (
CONF_PAYLOAD, CONF_PAYLOAD,
CONF_QOS, CONF_QOS,
CONF_TOPIC, CONF_TOPIC,
DATA_MQTT,
DOMAIN, DOMAIN,
) )
from .discovery import MQTT_DISCOVERY_DONE from .discovery import MQTT_DISCOVERY_DONE
from .mixins import ( from .mixins import (
MQTT_ENTITY_DEVICE_INFO_SCHEMA, MQTT_ENTITY_DEVICE_INFO_SCHEMA,
MqttData,
MqttDiscoveryDeviceUpdate, MqttDiscoveryDeviceUpdate,
send_discovery_done, send_discovery_done,
update_device, update_device,
@ -81,8 +83,6 @@ TRIGGER_DISCOVERY_SCHEMA = MQTT_BASE_SCHEMA.extend(
extra=vol.REMOVE_EXTRA, extra=vol.REMOVE_EXTRA,
) )
DEVICE_TRIGGERS = "mqtt_device_triggers"
LOG_NAME = "Device trigger" LOG_NAME = "Device trigger"
@ -203,6 +203,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
self.device_id = device_id self.device_id = device_id
self.discovery_data = discovery_data self.discovery_data = discovery_data
self.hass = hass self.hass = hass
self._mqtt_data: MqttData = hass.data[DATA_MQTT]
MqttDiscoveryDeviceUpdate.__init__( MqttDiscoveryDeviceUpdate.__init__(
self, self,
@ -217,8 +218,8 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
"""Initialize the device trigger.""" """Initialize the device trigger."""
discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH] discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH]
discovery_id = discovery_hash[1] discovery_id = discovery_hash[1]
if discovery_id not in self.hass.data.setdefault(DEVICE_TRIGGERS, {}): if discovery_id not in self._mqtt_data.device_triggers:
self.hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger( self._mqtt_data.device_triggers[discovery_id] = Trigger(
hass=self.hass, hass=self.hass,
device_id=self.device_id, device_id=self.device_id,
discovery_data=self.discovery_data, discovery_data=self.discovery_data,
@ -230,7 +231,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
value_template=self._config[CONF_VALUE_TEMPLATE], value_template=self._config[CONF_VALUE_TEMPLATE],
) )
else: else:
await self.hass.data[DEVICE_TRIGGERS][discovery_id].update_trigger( await self._mqtt_data.device_triggers[discovery_id].update_trigger(
self._config self._config
) )
debug_info.add_trigger_discovery_data( debug_info.add_trigger_discovery_data(
@ -246,16 +247,16 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
) )
config = TRIGGER_DISCOVERY_SCHEMA(discovery_data) config = TRIGGER_DISCOVERY_SCHEMA(discovery_data)
update_device(self.hass, self._config_entry, config) update_device(self.hass, self._config_entry, config)
device_trigger: Trigger = self.hass.data[DEVICE_TRIGGERS][discovery_id] device_trigger: Trigger = self._mqtt_data.device_triggers[discovery_id]
await device_trigger.update_trigger(config) await device_trigger.update_trigger(config)
async def async_tear_down(self) -> None: async def async_tear_down(self) -> None:
"""Cleanup device trigger.""" """Cleanup device trigger."""
discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH] discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH]
discovery_id = discovery_hash[1] discovery_id = discovery_hash[1]
if discovery_id in self.hass.data[DEVICE_TRIGGERS]: if discovery_id in self._mqtt_data.device_triggers:
_LOGGER.info("Removing trigger: %s", discovery_hash) _LOGGER.info("Removing trigger: %s", discovery_hash)
trigger: Trigger = self.hass.data[DEVICE_TRIGGERS][discovery_id] trigger: Trigger = self._mqtt_data.device_triggers[discovery_id]
trigger.detach_trigger() trigger.detach_trigger()
debug_info.remove_trigger_discovery_data(self.hass, discovery_hash) debug_info.remove_trigger_discovery_data(self.hass, discovery_hash)
@ -280,11 +281,10 @@ async def async_setup_trigger(
async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None: async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None:
"""Handle Mqtt removed from a device.""" """Handle Mqtt removed from a device."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
triggers = await async_get_triggers(hass, device_id) triggers = await async_get_triggers(hass, device_id)
for trig in triggers: for trig in triggers:
device_trigger: Trigger = hass.data[DEVICE_TRIGGERS].pop( device_trigger: Trigger = mqtt_data.device_triggers.pop(trig[CONF_DISCOVERY_ID])
trig[CONF_DISCOVERY_ID]
)
if device_trigger: if device_trigger:
device_trigger.detach_trigger() device_trigger.detach_trigger()
discovery_data = cast(dict, device_trigger.discovery_data) discovery_data = cast(dict, device_trigger.discovery_data)
@ -296,12 +296,13 @@ async def async_get_triggers(
hass: HomeAssistant, device_id: str hass: HomeAssistant, device_id: str
) -> list[dict[str, str]]: ) -> list[dict[str, str]]:
"""List device triggers for MQTT devices.""" """List device triggers for MQTT devices."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
triggers: list[dict[str, str]] = [] triggers: list[dict[str, str]] = []
if DEVICE_TRIGGERS not in hass.data: if not mqtt_data.device_triggers:
return triggers return triggers
for discovery_id, trig in hass.data[DEVICE_TRIGGERS].items(): for discovery_id, trig in mqtt_data.device_triggers.items():
if trig.device_id != device_id or trig.topic is None: if trig.device_id != device_id or trig.topic is None:
continue continue
@ -324,12 +325,12 @@ async def async_attach_trigger(
trigger_info: TriggerInfo, trigger_info: TriggerInfo,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Attach a trigger.""" """Attach a trigger."""
hass.data.setdefault(DEVICE_TRIGGERS, {}) mqtt_data: MqttData = hass.data[DATA_MQTT]
device_id = config[CONF_DEVICE_ID] device_id = config[CONF_DEVICE_ID]
discovery_id = config[CONF_DISCOVERY_ID] discovery_id = config[CONF_DISCOVERY_ID]
if discovery_id not in hass.data[DEVICE_TRIGGERS]: if discovery_id not in mqtt_data.device_triggers:
hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger( mqtt_data.device_triggers[discovery_id] = Trigger(
hass=hass, hass=hass,
device_id=device_id, device_id=device_id,
discovery_data=None, discovery_data=None,
@ -340,6 +341,6 @@ async def async_attach_trigger(
qos=None, qos=None,
value_template=None, value_template=None,
) )
return await hass.data[DEVICE_TRIGGERS][discovery_id].add_trigger( return await mqtt_data.device_triggers[discovery_id].add_trigger(
action, trigger_info action, trigger_info
) )

View file

@ -43,7 +43,7 @@ def _async_get_diagnostics(
device: DeviceEntry | None = None, device: DeviceEntry | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Return diagnostics for a config entry.""" """Return diagnostics for a config entry."""
mqtt_instance: MQTT = hass.data[DATA_MQTT] mqtt_instance: MQTT = hass.data[DATA_MQTT].client
redacted_config = async_redact_data(mqtt_instance.conf, REDACT_CONFIG) redacted_config = async_redact_data(mqtt_instance.conf, REDACT_CONFIG)

View file

@ -7,6 +7,7 @@ import functools
import logging import logging
import re import re
import time import time
from typing import TYPE_CHECKING
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM from homeassistant.const import CONF_DEVICE, CONF_PLATFORM
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -28,9 +29,13 @@ from .const import (
ATTR_DISCOVERY_TOPIC, ATTR_DISCOVERY_TOPIC,
CONF_AVAILABILITY, CONF_AVAILABILITY,
CONF_TOPIC, CONF_TOPIC,
DATA_MQTT,
DOMAIN, DOMAIN,
) )
if TYPE_CHECKING:
from .mixins import MqttData
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
TOPIC_MATCHER = re.compile( TOPIC_MATCHER = re.compile(
@ -69,7 +74,6 @@ INTEGRATION_UNSUBSCRIBE = "mqtt_integration_discovery_unsubscribe"
MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}" MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}"
MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}" MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}"
MQTT_DISCOVERY_DONE = "mqtt_discovery_done_{}" MQTT_DISCOVERY_DONE = "mqtt_discovery_done_{}"
LAST_DISCOVERY = "mqtt_last_discovery"
TOPIC_BASE = "~" TOPIC_BASE = "~"
@ -80,12 +84,12 @@ class MQTTConfig(dict):
discovery_data: dict discovery_data: dict
def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple) -> None: def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) -> None:
"""Clear entry in ALREADY_DISCOVERED list.""" """Clear entry in ALREADY_DISCOVERED list."""
del hass.data[ALREADY_DISCOVERED][discovery_hash] del hass.data[ALREADY_DISCOVERED][discovery_hash]
def set_discovery_hash(hass: HomeAssistant, discovery_hash: tuple): def set_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]):
"""Clear entry in ALREADY_DISCOVERED list.""" """Clear entry in ALREADY_DISCOVERED list."""
hass.data[ALREADY_DISCOVERED][discovery_hash] = {} hass.data[ALREADY_DISCOVERED][discovery_hash] = {}
@ -94,11 +98,12 @@ async def async_start( # noqa: C901
hass: HomeAssistant, discovery_topic, config_entry=None hass: HomeAssistant, discovery_topic, config_entry=None
) -> None: ) -> None:
"""Start MQTT Discovery.""" """Start MQTT Discovery."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
mqtt_integrations = {} mqtt_integrations = {}
async def async_discovery_message_received(msg): async def async_discovery_message_received(msg):
"""Process the received message.""" """Process the received message."""
hass.data[LAST_DISCOVERY] = time.time() mqtt_data.last_discovery = time.time()
payload = msg.payload payload = msg.payload
topic = msg.topic topic = msg.topic
topic_trimmed = topic.replace(f"{discovery_topic}/", "", 1) topic_trimmed = topic.replace(f"{discovery_topic}/", "", 1)
@ -253,7 +258,7 @@ async def async_start( # noqa: C901
) )
) )
hass.data[LAST_DISCOVERY] = time.time() mqtt_data.last_discovery = time.time()
mqtt_integrations = await async_get_mqtt(hass) mqtt_integrations = await async_get_mqtt(hass)
hass.data[INTEGRATION_UNSUBSCRIBE] = {} hass.data[INTEGRATION_UNSUBSCRIBE] = {}

View file

@ -4,9 +4,10 @@ from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod
import asyncio import asyncio
from collections.abc import Callable, Coroutine from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
from functools import partial from functools import partial
import logging import logging
from typing import Any, Protocol, cast, final from typing import TYPE_CHECKING, Any, Protocol, cast, final
import voluptuous as vol import voluptuous as vol
@ -60,7 +61,7 @@ from homeassistant.helpers.json import json_loads
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import debug_info, subscription from . import debug_info, subscription
from .client import async_publish from .client import MQTT, Subscription, async_publish
from .const import ( from .const import (
ATTR_DISCOVERY_HASH, ATTR_DISCOVERY_HASH,
ATTR_DISCOVERY_PAYLOAD, ATTR_DISCOVERY_PAYLOAD,
@ -70,11 +71,6 @@ from .const import (
CONF_QOS, CONF_QOS,
CONF_TOPIC, CONF_TOPIC,
DATA_MQTT, DATA_MQTT,
DATA_MQTT_CONFIG,
DATA_MQTT_DISCOVERY_REGISTRY_HOOKS,
DATA_MQTT_RELOAD_DISPATCHERS,
DATA_MQTT_RELOAD_ENTRY,
DATA_MQTT_UPDATED_CONFIG,
DEFAULT_ENCODING, DEFAULT_ENCODING,
DEFAULT_PAYLOAD_AVAILABLE, DEFAULT_PAYLOAD_AVAILABLE,
DEFAULT_PAYLOAD_NOT_AVAILABLE, DEFAULT_PAYLOAD_NOT_AVAILABLE,
@ -98,6 +94,9 @@ from .subscription import (
) )
from .util import mqtt_config_entry_enabled, valid_subscribe_topic from .util import mqtt_config_entry_enabled, valid_subscribe_topic
if TYPE_CHECKING:
from .device_trigger import Trigger
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
AVAILABILITY_ALL = "all" AVAILABILITY_ALL = "all"
@ -274,6 +273,24 @@ def warn_for_legacy_schema(domain: str) -> Callable:
return validator return validator
@dataclass
class MqttData:
"""Keep the MQTT entry data."""
client: MQTT | None = None
config: ConfigType | None = None
device_triggers: dict[str, Trigger] = field(default_factory=dict)
discovery_registry_hooks: dict[tuple[str, str], CALLBACK_TYPE] = field(
default_factory=dict
)
last_discovery: float = 0.0
reload_dispatchers: list[CALLBACK_TYPE] = field(default_factory=list)
reload_entry: bool = False
reload_needed: bool = False
subscriptions_to_restore: list[Subscription] = field(default_factory=list)
updated_config: ConfigType = field(default_factory=dict)
class SetupEntity(Protocol): class SetupEntity(Protocol):
"""Protocol type for async_setup_entities.""" """Protocol type for async_setup_entities."""
@ -292,11 +309,12 @@ async def async_discover_yaml_entities(
hass: HomeAssistant, platform_domain: str hass: HomeAssistant, platform_domain: str
) -> None: ) -> None:
"""Discover entities for a platform.""" """Discover entities for a platform."""
if DATA_MQTT_UPDATED_CONFIG in hass.data: mqtt_data: MqttData = hass.data[DATA_MQTT]
if mqtt_data.updated_config:
# The platform has been reloaded # The platform has been reloaded
config_yaml = hass.data[DATA_MQTT_UPDATED_CONFIG] config_yaml = mqtt_data.updated_config
else: else:
config_yaml = hass.data.get(DATA_MQTT_CONFIG, {}) config_yaml = mqtt_data.config or {}
if not config_yaml: if not config_yaml:
return return
if platform_domain not in config_yaml: if platform_domain not in config_yaml:
@ -318,8 +336,9 @@ async def async_get_platform_config_from_yaml(
) -> list[ConfigType]: ) -> list[ConfigType]:
"""Return a list of validated configurations for the domain.""" """Return a list of validated configurations for the domain."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
if config_yaml is None: if config_yaml is None:
config_yaml = hass.data.get(DATA_MQTT_CONFIG) config_yaml = mqtt_data.config
if not config_yaml: if not config_yaml:
return [] return []
if not (platform_configs := config_yaml.get(platform_domain)): if not (platform_configs := config_yaml.get(platform_domain)):
@ -334,6 +353,7 @@ async def async_setup_entry_helper(
schema: vol.Schema, schema: vol.Schema,
) -> None: ) -> None:
"""Set up entity, automation or tag creation dynamically through MQTT discovery.""" """Set up entity, automation or tag creation dynamically through MQTT discovery."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
async def async_discover(discovery_payload): async def async_discover(discovery_payload):
"""Discover and add an MQTT entity, automation or tag.""" """Discover and add an MQTT entity, automation or tag."""
@ -357,7 +377,7 @@ async def async_setup_entry_helper(
) )
raise raise
hass.data.setdefault(DATA_MQTT_RELOAD_DISPATCHERS, []).append( mqtt_data.reload_dispatchers.append(
async_dispatcher_connect( async_dispatcher_connect(
hass, MQTT_DISCOVERY_NEW.format(domain, "mqtt"), async_discover hass, MQTT_DISCOVERY_NEW.format(domain, "mqtt"), async_discover
) )
@ -372,7 +392,8 @@ async def async_setup_platform_helper(
async_setup_entities: SetupEntity, async_setup_entities: SetupEntity,
) -> None: ) -> None:
"""Help to set up the platform for manual configured MQTT entities.""" """Help to set up the platform for manual configured MQTT entities."""
if DATA_MQTT_RELOAD_ENTRY in hass.data: mqtt_data: MqttData = hass.data[DATA_MQTT]
if mqtt_data.reload_entry:
_LOGGER.debug( _LOGGER.debug(
"MQTT integration is %s, skipping setup of manually configured MQTT items while unloading the config entry", "MQTT integration is %s, skipping setup of manually configured MQTT items while unloading the config entry",
platform_domain, platform_domain,
@ -597,7 +618,10 @@ class MqttAvailability(Entity):
@property @property
def available(self) -> bool: def available(self) -> bool:
"""Return if the device is available.""" """Return if the device is available."""
if not self.hass.data[DATA_MQTT].connected and not self.hass.is_stopping: mqtt_data: MqttData = self.hass.data[DATA_MQTT]
assert mqtt_data.client is not None
client = mqtt_data.client
if not client.connected and not self.hass.is_stopping:
return False return False
if not self._avail_topics: if not self._avail_topics:
return True return True
@ -632,7 +656,7 @@ async def cleanup_device_registry(
) )
def get_discovery_hash(discovery_data: dict) -> tuple: def get_discovery_hash(discovery_data: dict) -> tuple[str, str]:
"""Get the discovery hash from the discovery data.""" """Get the discovery hash from the discovery data."""
return discovery_data[ATTR_DISCOVERY_HASH] return discovery_data[ATTR_DISCOVERY_HASH]
@ -817,9 +841,8 @@ class MqttDiscoveryUpdate(Entity):
self._removed_from_hass = False self._removed_from_hass = False
if discovery_data is None: if discovery_data is None:
return return
self._registry_hooks: dict[tuple, CALLBACK_TYPE] = hass.data[ mqtt_data: MqttData = hass.data[DATA_MQTT]
DATA_MQTT_DISCOVERY_REGISTRY_HOOKS self._registry_hooks = mqtt_data.discovery_registry_hooks
]
discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH] discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH]
if discovery_hash in self._registry_hooks: if discovery_hash in self._registry_hooks:
self._registry_hooks.pop(discovery_hash)() self._registry_hooks.pop(discovery_hash)()
@ -897,7 +920,7 @@ class MqttDiscoveryUpdate(Entity):
def add_to_platform_abort(self) -> None: def add_to_platform_abort(self) -> None:
"""Abort adding an entity to a platform.""" """Abort adding an entity to a platform."""
if self._discovery_data is not None: if self._discovery_data is not None:
discovery_hash: tuple = self._discovery_data[ATTR_DISCOVERY_HASH] discovery_hash: tuple[str, str] = self._discovery_data[ATTR_DISCOVERY_HASH]
if self.registry_entry is not None: if self.registry_entry is not None:
self._registry_hooks[ self._registry_hooks[
discovery_hash discovery_hash

View file

@ -369,7 +369,7 @@ def async_fire_mqtt_message(hass, topic, payload, qos=0, retain=False):
if isinstance(payload, str): if isinstance(payload, str):
payload = payload.encode("utf-8") payload = payload.encode("utf-8")
msg = ReceiveMessage(topic, payload, qos, retain) msg = ReceiveMessage(topic, payload, qos, retain)
hass.data["mqtt"]._mqtt_handle_message(msg) hass.data["mqtt"].client._mqtt_handle_message(msg)
fire_mqtt_message = threadsafe_callback_factory(async_fire_mqtt_message) fire_mqtt_message = threadsafe_callback_factory(async_fire_mqtt_message)

View file

@ -155,7 +155,7 @@ async def test_manual_config_set(
assert await async_setup_component(hass, "mqtt", {"mqtt": {"broker": "bla"}}) assert await async_setup_component(hass, "mqtt", {"mqtt": {"broker": "bla"}})
await hass.async_block_till_done() await hass.async_block_till_done()
# do not try to reload # do not try to reload
del hass.data["mqtt_reload_needed"] hass.data["mqtt"].reload_needed = False
assert len(mock_finish_setup.mock_calls) == 0 assert len(mock_finish_setup.mock_calls) == 0
mock_try_connection.return_value = True mock_try_connection.return_value = True

View file

@ -1438,7 +1438,7 @@ async def test_clean_up_registry_monitoring(
): ):
"""Test registry monitoring hook is removed after a reload.""" """Test registry monitoring hook is removed after a reload."""
await mqtt_mock_entry_no_yaml_config() await mqtt_mock_entry_no_yaml_config()
hooks: dict = hass.data[mqtt.const.DATA_MQTT_DISCOVERY_REGISTRY_HOOKS] hooks: dict = hass.data["mqtt"].discovery_registry_hooks
# discover an entity that is not enabled by default # discover an entity that is not enabled by default
config1 = { config1 = {
"name": "sbfspot_12345", "name": "sbfspot_12345",

View file

@ -1776,14 +1776,14 @@ async def test_delayed_birth_message(
await hass.async_block_till_done() await hass.async_block_till_done()
mqtt_component_mock = MagicMock( mqtt_component_mock = MagicMock(
return_value=hass.data["mqtt"], return_value=hass.data["mqtt"].client,
spec_set=hass.data["mqtt"], spec_set=hass.data["mqtt"].client,
wraps=hass.data["mqtt"], wraps=hass.data["mqtt"].client,
) )
mqtt_component_mock._mqttc = mqtt_client_mock mqtt_component_mock._mqttc = mqtt_client_mock
hass.data["mqtt"] = mqtt_component_mock hass.data["mqtt"].client = mqtt_component_mock
mqtt_mock = hass.data["mqtt"] mqtt_mock = hass.data["mqtt"].client
mqtt_mock.reset_mock() mqtt_mock.reset_mock()
async def wait_birth(topic, payload, qos): async def wait_birth(topic, payload, qos):