Make hass.data["mqtt"] an instance of a DataClass (#77972)

* Use dataclass to reference hass.data globals

* Add discovery_registry_hooks to dataclass

* Move discovery registry hooks to dataclass

* Add device triggers to dataclass

* Cleanup DEVICE_TRIGGERS const

* Add last_discovery to data_class

* Simplify typing for class `Subscription`

* Follow up on comment

* Redo suggested typing change to sasisfy mypy

* Restore typing

* Add mypy version to CI check logging

* revert changes to ci.yaml

* Add docstr for protocol

Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>

* Mypy update after merging #78399

* Remove mypy ignore

* Correct return type

Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
This commit is contained in:
Jan Bouwhuis 2022-09-17 21:43:42 +02:00 committed by GitHub
parent 391d895426
commit 1f410e884a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 174 additions and 137 deletions

View file

@ -4,9 +4,10 @@ from __future__ import annotations
from abc import abstractmethod
import asyncio
from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
from functools import partial
import logging
from typing import Any, Protocol, cast, final
from typing import TYPE_CHECKING, Any, Protocol, cast, final
import voluptuous as vol
@ -60,7 +61,7 @@ from homeassistant.helpers.json import json_loads
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import debug_info, subscription
from .client import async_publish
from .client import MQTT, Subscription, async_publish
from .const import (
ATTR_DISCOVERY_HASH,
ATTR_DISCOVERY_PAYLOAD,
@ -70,11 +71,6 @@ from .const import (
CONF_QOS,
CONF_TOPIC,
DATA_MQTT,
DATA_MQTT_CONFIG,
DATA_MQTT_DISCOVERY_REGISTRY_HOOKS,
DATA_MQTT_RELOAD_DISPATCHERS,
DATA_MQTT_RELOAD_ENTRY,
DATA_MQTT_UPDATED_CONFIG,
DEFAULT_ENCODING,
DEFAULT_PAYLOAD_AVAILABLE,
DEFAULT_PAYLOAD_NOT_AVAILABLE,
@ -98,6 +94,9 @@ from .subscription import (
)
from .util import mqtt_config_entry_enabled, valid_subscribe_topic
if TYPE_CHECKING:
from .device_trigger import Trigger
_LOGGER = logging.getLogger(__name__)
AVAILABILITY_ALL = "all"
@ -274,6 +273,24 @@ def warn_for_legacy_schema(domain: str) -> Callable:
return validator
@dataclass
class MqttData:
"""Keep the MQTT entry data."""
client: MQTT | None = None
config: ConfigType | None = None
device_triggers: dict[str, Trigger] = field(default_factory=dict)
discovery_registry_hooks: dict[tuple[str, str], CALLBACK_TYPE] = field(
default_factory=dict
)
last_discovery: float = 0.0
reload_dispatchers: list[CALLBACK_TYPE] = field(default_factory=list)
reload_entry: bool = False
reload_needed: bool = False
subscriptions_to_restore: list[Subscription] = field(default_factory=list)
updated_config: ConfigType = field(default_factory=dict)
class SetupEntity(Protocol):
"""Protocol type for async_setup_entities."""
@ -292,11 +309,12 @@ async def async_discover_yaml_entities(
hass: HomeAssistant, platform_domain: str
) -> None:
"""Discover entities for a platform."""
if DATA_MQTT_UPDATED_CONFIG in hass.data:
mqtt_data: MqttData = hass.data[DATA_MQTT]
if mqtt_data.updated_config:
# The platform has been reloaded
config_yaml = hass.data[DATA_MQTT_UPDATED_CONFIG]
config_yaml = mqtt_data.updated_config
else:
config_yaml = hass.data.get(DATA_MQTT_CONFIG, {})
config_yaml = mqtt_data.config or {}
if not config_yaml:
return
if platform_domain not in config_yaml:
@ -318,8 +336,9 @@ async def async_get_platform_config_from_yaml(
) -> list[ConfigType]:
"""Return a list of validated configurations for the domain."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
if config_yaml is None:
config_yaml = hass.data.get(DATA_MQTT_CONFIG)
config_yaml = mqtt_data.config
if not config_yaml:
return []
if not (platform_configs := config_yaml.get(platform_domain)):
@ -334,6 +353,7 @@ async def async_setup_entry_helper(
schema: vol.Schema,
) -> None:
"""Set up entity, automation or tag creation dynamically through MQTT discovery."""
mqtt_data: MqttData = hass.data[DATA_MQTT]
async def async_discover(discovery_payload):
"""Discover and add an MQTT entity, automation or tag."""
@ -357,7 +377,7 @@ async def async_setup_entry_helper(
)
raise
hass.data.setdefault(DATA_MQTT_RELOAD_DISPATCHERS, []).append(
mqtt_data.reload_dispatchers.append(
async_dispatcher_connect(
hass, MQTT_DISCOVERY_NEW.format(domain, "mqtt"), async_discover
)
@ -372,7 +392,8 @@ async def async_setup_platform_helper(
async_setup_entities: SetupEntity,
) -> None:
"""Help to set up the platform for manual configured MQTT entities."""
if DATA_MQTT_RELOAD_ENTRY in hass.data:
mqtt_data: MqttData = hass.data[DATA_MQTT]
if mqtt_data.reload_entry:
_LOGGER.debug(
"MQTT integration is %s, skipping setup of manually configured MQTT items while unloading the config entry",
platform_domain,
@ -597,7 +618,10 @@ class MqttAvailability(Entity):
@property
def available(self) -> bool:
"""Return if the device is available."""
if not self.hass.data[DATA_MQTT].connected and not self.hass.is_stopping:
mqtt_data: MqttData = self.hass.data[DATA_MQTT]
assert mqtt_data.client is not None
client = mqtt_data.client
if not client.connected and not self.hass.is_stopping:
return False
if not self._avail_topics:
return True
@ -632,7 +656,7 @@ async def cleanup_device_registry(
)
def get_discovery_hash(discovery_data: dict) -> tuple:
def get_discovery_hash(discovery_data: dict) -> tuple[str, str]:
"""Get the discovery hash from the discovery data."""
return discovery_data[ATTR_DISCOVERY_HASH]
@ -817,9 +841,8 @@ class MqttDiscoveryUpdate(Entity):
self._removed_from_hass = False
if discovery_data is None:
return
self._registry_hooks: dict[tuple, CALLBACK_TYPE] = hass.data[
DATA_MQTT_DISCOVERY_REGISTRY_HOOKS
]
mqtt_data: MqttData = hass.data[DATA_MQTT]
self._registry_hooks = mqtt_data.discovery_registry_hooks
discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH]
if discovery_hash in self._registry_hooks:
self._registry_hooks.pop(discovery_hash)()
@ -897,7 +920,7 @@ class MqttDiscoveryUpdate(Entity):
def add_to_platform_abort(self) -> None:
"""Abort adding an entity to a platform."""
if self._discovery_data is not None:
discovery_hash: tuple = self._discovery_data[ATTR_DISCOVERY_HASH]
discovery_hash: tuple[str, str] = self._discovery_data[ATTR_DISCOVERY_HASH]
if self.registry_entry is not None:
self._registry_hooks[
discovery_hash