Make hass.data["mqtt"] an instance of a DataClass (#77972)
* Use dataclass to reference hass.data globals * Add discovery_registry_hooks to dataclass * Move discovery registry hooks to dataclass * Add device triggers to dataclass * Cleanup DEVICE_TRIGGERS const * Add last_discovery to data_class * Simplify typing for class `Subscription` * Follow up on comment * Redo suggested typing change to sasisfy mypy * Restore typing * Add mypy version to CI check logging * revert changes to ci.yaml * Add docstr for protocol Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com> * Mypy update after merging #78399 * Remove mypy ignore * Correct return type Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
This commit is contained in:
parent
391d895426
commit
1f410e884a
12 changed files with 174 additions and 137 deletions
|
@ -4,9 +4,10 @@ from __future__ import annotations
|
|||
from abc import abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
import logging
|
||||
from typing import Any, Protocol, cast, final
|
||||
from typing import TYPE_CHECKING, Any, Protocol, cast, final
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -60,7 +61,7 @@ from homeassistant.helpers.json import json_loads
|
|||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
|
||||
from . import debug_info, subscription
|
||||
from .client import async_publish
|
||||
from .client import MQTT, Subscription, async_publish
|
||||
from .const import (
|
||||
ATTR_DISCOVERY_HASH,
|
||||
ATTR_DISCOVERY_PAYLOAD,
|
||||
|
@ -70,11 +71,6 @@ from .const import (
|
|||
CONF_QOS,
|
||||
CONF_TOPIC,
|
||||
DATA_MQTT,
|
||||
DATA_MQTT_CONFIG,
|
||||
DATA_MQTT_DISCOVERY_REGISTRY_HOOKS,
|
||||
DATA_MQTT_RELOAD_DISPATCHERS,
|
||||
DATA_MQTT_RELOAD_ENTRY,
|
||||
DATA_MQTT_UPDATED_CONFIG,
|
||||
DEFAULT_ENCODING,
|
||||
DEFAULT_PAYLOAD_AVAILABLE,
|
||||
DEFAULT_PAYLOAD_NOT_AVAILABLE,
|
||||
|
@ -98,6 +94,9 @@ from .subscription import (
|
|||
)
|
||||
from .util import mqtt_config_entry_enabled, valid_subscribe_topic
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .device_trigger import Trigger
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
AVAILABILITY_ALL = "all"
|
||||
|
@ -274,6 +273,24 @@ def warn_for_legacy_schema(domain: str) -> Callable:
|
|||
return validator
|
||||
|
||||
|
||||
@dataclass
|
||||
class MqttData:
|
||||
"""Keep the MQTT entry data."""
|
||||
|
||||
client: MQTT | None = None
|
||||
config: ConfigType | None = None
|
||||
device_triggers: dict[str, Trigger] = field(default_factory=dict)
|
||||
discovery_registry_hooks: dict[tuple[str, str], CALLBACK_TYPE] = field(
|
||||
default_factory=dict
|
||||
)
|
||||
last_discovery: float = 0.0
|
||||
reload_dispatchers: list[CALLBACK_TYPE] = field(default_factory=list)
|
||||
reload_entry: bool = False
|
||||
reload_needed: bool = False
|
||||
subscriptions_to_restore: list[Subscription] = field(default_factory=list)
|
||||
updated_config: ConfigType = field(default_factory=dict)
|
||||
|
||||
|
||||
class SetupEntity(Protocol):
|
||||
"""Protocol type for async_setup_entities."""
|
||||
|
||||
|
@ -292,11 +309,12 @@ async def async_discover_yaml_entities(
|
|||
hass: HomeAssistant, platform_domain: str
|
||||
) -> None:
|
||||
"""Discover entities for a platform."""
|
||||
if DATA_MQTT_UPDATED_CONFIG in hass.data:
|
||||
mqtt_data: MqttData = hass.data[DATA_MQTT]
|
||||
if mqtt_data.updated_config:
|
||||
# The platform has been reloaded
|
||||
config_yaml = hass.data[DATA_MQTT_UPDATED_CONFIG]
|
||||
config_yaml = mqtt_data.updated_config
|
||||
else:
|
||||
config_yaml = hass.data.get(DATA_MQTT_CONFIG, {})
|
||||
config_yaml = mqtt_data.config or {}
|
||||
if not config_yaml:
|
||||
return
|
||||
if platform_domain not in config_yaml:
|
||||
|
@ -318,8 +336,9 @@ async def async_get_platform_config_from_yaml(
|
|||
) -> list[ConfigType]:
|
||||
"""Return a list of validated configurations for the domain."""
|
||||
|
||||
mqtt_data: MqttData = hass.data[DATA_MQTT]
|
||||
if config_yaml is None:
|
||||
config_yaml = hass.data.get(DATA_MQTT_CONFIG)
|
||||
config_yaml = mqtt_data.config
|
||||
if not config_yaml:
|
||||
return []
|
||||
if not (platform_configs := config_yaml.get(platform_domain)):
|
||||
|
@ -334,6 +353,7 @@ async def async_setup_entry_helper(
|
|||
schema: vol.Schema,
|
||||
) -> None:
|
||||
"""Set up entity, automation or tag creation dynamically through MQTT discovery."""
|
||||
mqtt_data: MqttData = hass.data[DATA_MQTT]
|
||||
|
||||
async def async_discover(discovery_payload):
|
||||
"""Discover and add an MQTT entity, automation or tag."""
|
||||
|
@ -357,7 +377,7 @@ async def async_setup_entry_helper(
|
|||
)
|
||||
raise
|
||||
|
||||
hass.data.setdefault(DATA_MQTT_RELOAD_DISPATCHERS, []).append(
|
||||
mqtt_data.reload_dispatchers.append(
|
||||
async_dispatcher_connect(
|
||||
hass, MQTT_DISCOVERY_NEW.format(domain, "mqtt"), async_discover
|
||||
)
|
||||
|
@ -372,7 +392,8 @@ async def async_setup_platform_helper(
|
|||
async_setup_entities: SetupEntity,
|
||||
) -> None:
|
||||
"""Help to set up the platform for manual configured MQTT entities."""
|
||||
if DATA_MQTT_RELOAD_ENTRY in hass.data:
|
||||
mqtt_data: MqttData = hass.data[DATA_MQTT]
|
||||
if mqtt_data.reload_entry:
|
||||
_LOGGER.debug(
|
||||
"MQTT integration is %s, skipping setup of manually configured MQTT items while unloading the config entry",
|
||||
platform_domain,
|
||||
|
@ -597,7 +618,10 @@ class MqttAvailability(Entity):
|
|||
@property
|
||||
def available(self) -> bool:
|
||||
"""Return if the device is available."""
|
||||
if not self.hass.data[DATA_MQTT].connected and not self.hass.is_stopping:
|
||||
mqtt_data: MqttData = self.hass.data[DATA_MQTT]
|
||||
assert mqtt_data.client is not None
|
||||
client = mqtt_data.client
|
||||
if not client.connected and not self.hass.is_stopping:
|
||||
return False
|
||||
if not self._avail_topics:
|
||||
return True
|
||||
|
@ -632,7 +656,7 @@ async def cleanup_device_registry(
|
|||
)
|
||||
|
||||
|
||||
def get_discovery_hash(discovery_data: dict) -> tuple:
|
||||
def get_discovery_hash(discovery_data: dict) -> tuple[str, str]:
|
||||
"""Get the discovery hash from the discovery data."""
|
||||
return discovery_data[ATTR_DISCOVERY_HASH]
|
||||
|
||||
|
@ -817,9 +841,8 @@ class MqttDiscoveryUpdate(Entity):
|
|||
self._removed_from_hass = False
|
||||
if discovery_data is None:
|
||||
return
|
||||
self._registry_hooks: dict[tuple, CALLBACK_TYPE] = hass.data[
|
||||
DATA_MQTT_DISCOVERY_REGISTRY_HOOKS
|
||||
]
|
||||
mqtt_data: MqttData = hass.data[DATA_MQTT]
|
||||
self._registry_hooks = mqtt_data.discovery_registry_hooks
|
||||
discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH]
|
||||
if discovery_hash in self._registry_hooks:
|
||||
self._registry_hooks.pop(discovery_hash)()
|
||||
|
@ -897,7 +920,7 @@ class MqttDiscoveryUpdate(Entity):
|
|||
def add_to_platform_abort(self) -> None:
|
||||
"""Abort adding an entity to a platform."""
|
||||
if self._discovery_data is not None:
|
||||
discovery_hash: tuple = self._discovery_data[ATTR_DISCOVERY_HASH]
|
||||
discovery_hash: tuple[str, str] = self._discovery_data[ATTR_DISCOVERY_HASH]
|
||||
if self.registry_entry is not None:
|
||||
self._registry_hooks[
|
||||
discovery_hash
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue