Refactor MQTT to replace get_mqtt_data with HassKey (#117899)
This commit is contained in:
parent
b4d0562063
commit
4e3c4400a7
14 changed files with 74 additions and 74 deletions
|
@ -65,8 +65,6 @@ from .const import ( # noqa: F401
|
|||
CONF_WILL_MESSAGE,
|
||||
CONF_WS_HEADERS,
|
||||
CONF_WS_PATH,
|
||||
DATA_MQTT,
|
||||
DATA_MQTT_AVAILABLE,
|
||||
DEFAULT_DISCOVERY,
|
||||
DEFAULT_ENCODING,
|
||||
DEFAULT_PREFIX,
|
||||
|
@ -79,6 +77,8 @@ from .const import ( # noqa: F401
|
|||
TEMPLATE_ERRORS,
|
||||
)
|
||||
from .models import ( # noqa: F401
|
||||
DATA_MQTT,
|
||||
DATA_MQTT_AVAILABLE,
|
||||
MqttCommandTemplate,
|
||||
MqttData,
|
||||
MqttValueTemplate,
|
||||
|
@ -97,7 +97,6 @@ from .util import ( # noqa: F401
|
|||
async_create_certificate_temp_files,
|
||||
async_forward_entry_setup_and_setup_discovery,
|
||||
async_wait_for_mqtt_client,
|
||||
get_mqtt_data,
|
||||
mqtt_config_entry_enabled,
|
||||
platforms_from_config,
|
||||
valid_publish_topic,
|
||||
|
@ -194,7 +193,7 @@ async def async_check_config_schema(
|
|||
hass: HomeAssistant, config_yaml: ConfigType
|
||||
) -> None:
|
||||
"""Validate manually configured MQTT items."""
|
||||
mqtt_data = get_mqtt_data(hass)
|
||||
mqtt_data = hass.data[DATA_MQTT]
|
||||
mqtt_config: list[dict[str, list[ConfigType]]] = config_yaml.get(DOMAIN, {})
|
||||
for mqtt_config_item in mqtt_config:
|
||||
for domain, config_items in mqtt_config_item.items():
|
||||
|
@ -233,7 +232,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
await async_create_certificate_temp_files(hass, conf)
|
||||
client = MQTT(hass, entry, conf)
|
||||
if DOMAIN in hass.data:
|
||||
mqtt_data = get_mqtt_data(hass)
|
||||
mqtt_data = hass.data[DATA_MQTT]
|
||||
mqtt_data.config = mqtt_yaml
|
||||
mqtt_data.client = client
|
||||
else:
|
||||
|
@ -241,7 +240,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
websocket_api.async_register_command(hass, websocket_subscribe)
|
||||
websocket_api.async_register_command(hass, websocket_mqtt_info)
|
||||
hass.data[DATA_MQTT] = mqtt_data = MqttData(config=mqtt_yaml, client=client)
|
||||
get_mqtt_data.cache_clear()
|
||||
client.start(mqtt_data)
|
||||
|
||||
# Restore saved subscriptions
|
||||
|
@ -503,7 +501,7 @@ def async_subscribe_connection_status(
|
|||
|
||||
def is_connected(hass: HomeAssistant) -> bool:
|
||||
"""Return if MQTT client is connected."""
|
||||
mqtt_data = get_mqtt_data(hass)
|
||||
mqtt_data = hass.data[DATA_MQTT]
|
||||
return mqtt_data.client.connected
|
||||
|
||||
|
||||
|
@ -520,7 +518,7 @@ async def async_remove_config_entry_device(
|
|||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload MQTT dump and publish service when the config entry is unloaded."""
|
||||
mqtt_data = get_mqtt_data(hass)
|
||||
mqtt_data = hass.data[DATA_MQTT]
|
||||
mqtt_client = mqtt_data.client
|
||||
|
||||
# Unload publish and dump services.
|
||||
|
|
|
@ -66,6 +66,7 @@ from .const import (
|
|||
TRANSPORT_WEBSOCKETS,
|
||||
)
|
||||
from .models import (
|
||||
DATA_MQTT,
|
||||
AsyncMessageCallbackType,
|
||||
MessageCallbackType,
|
||||
MqttData,
|
||||
|
@ -73,7 +74,7 @@ from .models import (
|
|||
PublishPayloadType,
|
||||
ReceiveMessage,
|
||||
)
|
||||
from .util import get_file_path, get_mqtt_data, mqtt_config_entry_enabled
|
||||
from .util import get_file_path, mqtt_config_entry_enabled
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Only import for paho-mqtt type checking here, imports are done locally
|
||||
|
@ -132,7 +133,7 @@ async def async_publish(
|
|||
translation_domain=DOMAIN,
|
||||
translation_placeholders={"topic": topic},
|
||||
)
|
||||
mqtt_data = get_mqtt_data(hass)
|
||||
mqtt_data = hass.data[DATA_MQTT]
|
||||
outgoing_payload = payload
|
||||
if not isinstance(payload, bytes):
|
||||
if not encoding:
|
||||
|
@ -186,7 +187,7 @@ async def async_subscribe(
|
|||
translation_placeholders={"topic": topic},
|
||||
)
|
||||
try:
|
||||
mqtt_data = get_mqtt_data(hass)
|
||||
mqtt_data = hass.data[DATA_MQTT]
|
||||
except KeyError as exc:
|
||||
raise HomeAssistantError(
|
||||
f"Cannot subscribe to topic '{topic}', "
|
||||
|
|
|
@ -86,9 +86,6 @@ CONF_CONFIGURATION_URL = "configuration_url"
|
|||
CONF_OBJECT_ID = "object_id"
|
||||
CONF_SUPPORT_URL = "support_url"
|
||||
|
||||
DATA_MQTT = "mqtt"
|
||||
DATA_MQTT_AVAILABLE = "mqtt_client_available"
|
||||
|
||||
DEFAULT_PREFIX = "homeassistant"
|
||||
DEFAULT_BIRTH_WILL_TOPIC = DEFAULT_PREFIX + "/status"
|
||||
DEFAULT_DISCOVERY = True
|
||||
|
|
|
@ -16,8 +16,7 @@ from homeassistant.helpers.typing import DiscoveryInfoType
|
|||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from .const import ATTR_DISCOVERY_PAYLOAD, ATTR_DISCOVERY_TOPIC
|
||||
from .models import MessageCallbackType, PublishPayloadType
|
||||
from .util import get_mqtt_data
|
||||
from .models import DATA_MQTT, MessageCallbackType, PublishPayloadType
|
||||
|
||||
STORED_MESSAGES = 10
|
||||
|
||||
|
@ -27,7 +26,7 @@ def log_messages(
|
|||
) -> Callable[[MessageCallbackType], MessageCallbackType]:
|
||||
"""Wrap an MQTT message callback to support message logging."""
|
||||
|
||||
debug_info_entities = get_mqtt_data(hass).debug_info_entities
|
||||
debug_info_entities = hass.data[DATA_MQTT].debug_info_entities
|
||||
|
||||
def _log_message(msg: Any) -> None:
|
||||
"""Log message."""
|
||||
|
@ -70,7 +69,7 @@ def log_message(
|
|||
retain: bool,
|
||||
) -> None:
|
||||
"""Log an outgoing MQTT message."""
|
||||
entity_info = get_mqtt_data(hass).debug_info_entities.setdefault(
|
||||
entity_info = hass.data[DATA_MQTT].debug_info_entities.setdefault(
|
||||
entity_id, {"subscriptions": {}, "discovery_data": {}, "transmitted": {}}
|
||||
)
|
||||
if topic not in entity_info["transmitted"]:
|
||||
|
@ -90,7 +89,7 @@ def add_subscription(
|
|||
) -> None:
|
||||
"""Prepare debug data for subscription."""
|
||||
if entity_id := getattr(message_callback, "__entity_id", None):
|
||||
entity_info = get_mqtt_data(hass).debug_info_entities.setdefault(
|
||||
entity_info = hass.data[DATA_MQTT].debug_info_entities.setdefault(
|
||||
entity_id, {"subscriptions": {}, "discovery_data": {}, "transmitted": {}}
|
||||
)
|
||||
if subscription not in entity_info["subscriptions"]:
|
||||
|
@ -108,7 +107,7 @@ def remove_subscription(
|
|||
) -> None:
|
||||
"""Remove debug data for subscription if it exists."""
|
||||
if (entity_id := getattr(message_callback, "__entity_id", None)) and entity_id in (
|
||||
debug_info_entities := get_mqtt_data(hass).debug_info_entities
|
||||
debug_info_entities := hass.data[DATA_MQTT].debug_info_entities
|
||||
):
|
||||
debug_info_entities[entity_id]["subscriptions"][subscription]["count"] -= 1
|
||||
if not debug_info_entities[entity_id]["subscriptions"][subscription]["count"]:
|
||||
|
@ -119,7 +118,7 @@ def add_entity_discovery_data(
|
|||
hass: HomeAssistant, discovery_data: DiscoveryInfoType, entity_id: str
|
||||
) -> None:
|
||||
"""Add discovery data."""
|
||||
entity_info = get_mqtt_data(hass).debug_info_entities.setdefault(
|
||||
entity_info = hass.data[DATA_MQTT].debug_info_entities.setdefault(
|
||||
entity_id, {"subscriptions": {}, "discovery_data": {}, "transmitted": {}}
|
||||
)
|
||||
entity_info["discovery_data"] = discovery_data
|
||||
|
@ -129,7 +128,7 @@ def update_entity_discovery_data(
|
|||
hass: HomeAssistant, discovery_payload: DiscoveryInfoType, entity_id: str
|
||||
) -> None:
|
||||
"""Update discovery data."""
|
||||
discovery_data = get_mqtt_data(hass).debug_info_entities[entity_id][
|
||||
discovery_data = hass.data[DATA_MQTT].debug_info_entities[entity_id][
|
||||
"discovery_data"
|
||||
]
|
||||
if TYPE_CHECKING:
|
||||
|
@ -139,7 +138,7 @@ def update_entity_discovery_data(
|
|||
|
||||
def remove_entity_data(hass: HomeAssistant, entity_id: str) -> None:
|
||||
"""Remove discovery data."""
|
||||
if entity_id in (debug_info_entities := get_mqtt_data(hass).debug_info_entities):
|
||||
if entity_id in (debug_info_entities := hass.data[DATA_MQTT].debug_info_entities):
|
||||
debug_info_entities.pop(entity_id)
|
||||
|
||||
|
||||
|
@ -150,7 +149,7 @@ def add_trigger_discovery_data(
|
|||
device_id: str,
|
||||
) -> None:
|
||||
"""Add discovery data."""
|
||||
get_mqtt_data(hass).debug_info_triggers[discovery_hash] = {
|
||||
hass.data[DATA_MQTT].debug_info_triggers[discovery_hash] = {
|
||||
"device_id": device_id,
|
||||
"discovery_data": discovery_data,
|
||||
}
|
||||
|
@ -162,7 +161,7 @@ def update_trigger_discovery_data(
|
|||
discovery_payload: DiscoveryInfoType,
|
||||
) -> None:
|
||||
"""Update discovery data."""
|
||||
get_mqtt_data(hass).debug_info_triggers[discovery_hash]["discovery_data"][
|
||||
hass.data[DATA_MQTT].debug_info_triggers[discovery_hash]["discovery_data"][
|
||||
ATTR_DISCOVERY_PAYLOAD
|
||||
] = discovery_payload
|
||||
|
||||
|
@ -171,11 +170,11 @@ def remove_trigger_discovery_data(
|
|||
hass: HomeAssistant, discovery_hash: tuple[str, str]
|
||||
) -> None:
|
||||
"""Remove discovery data."""
|
||||
get_mqtt_data(hass).debug_info_triggers.pop(discovery_hash)
|
||||
hass.data[DATA_MQTT].debug_info_triggers.pop(discovery_hash)
|
||||
|
||||
|
||||
def _info_for_entity(hass: HomeAssistant, entity_id: str) -> dict[str, Any]:
|
||||
entity_info = get_mqtt_data(hass).debug_info_entities[entity_id]
|
||||
entity_info = hass.data[DATA_MQTT].debug_info_entities[entity_id]
|
||||
monotonic_time_diff = time.time() - time.monotonic()
|
||||
subscriptions = [
|
||||
{
|
||||
|
@ -231,7 +230,7 @@ def _info_for_entity(hass: HomeAssistant, entity_id: str) -> dict[str, Any]:
|
|||
def _info_for_trigger(
|
||||
hass: HomeAssistant, trigger_key: tuple[str, str]
|
||||
) -> dict[str, Any]:
|
||||
trigger = get_mqtt_data(hass).debug_info_triggers[trigger_key]
|
||||
trigger = hass.data[DATA_MQTT].debug_info_triggers[trigger_key]
|
||||
discovery_data = None
|
||||
if trigger["discovery_data"] is not None:
|
||||
discovery_data = {
|
||||
|
@ -244,7 +243,7 @@ def _info_for_trigger(
|
|||
def info_for_config_entry(hass: HomeAssistant) -> dict[str, list[Any]]:
|
||||
"""Get debug info for all entities and triggers."""
|
||||
|
||||
mqtt_data = get_mqtt_data(hass)
|
||||
mqtt_data = hass.data[DATA_MQTT]
|
||||
mqtt_info: dict[str, list[Any]] = {"entities": [], "triggers": []}
|
||||
|
||||
mqtt_info["entities"].extend(
|
||||
|
@ -262,7 +261,7 @@ def info_for_config_entry(hass: HomeAssistant) -> dict[str, list[Any]]:
|
|||
def info_for_device(hass: HomeAssistant, device_id: str) -> dict[str, list[Any]]:
|
||||
"""Get debug info for a device."""
|
||||
|
||||
mqtt_data = get_mqtt_data(hass)
|
||||
mqtt_data = hass.data[DATA_MQTT]
|
||||
|
||||
mqtt_info: dict[str, list[Any]] = {"entities": [], "triggers": []}
|
||||
entity_registry = er.async_get(hass)
|
||||
|
|
|
@ -42,7 +42,7 @@ from .mixins import (
|
|||
send_discovery_done,
|
||||
update_device,
|
||||
)
|
||||
from .util import get_mqtt_data
|
||||
from .models import DATA_MQTT
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -206,7 +206,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
|
|||
self.device_id = device_id
|
||||
self.discovery_data = discovery_data
|
||||
self.hass = hass
|
||||
self._mqtt_data = get_mqtt_data(hass)
|
||||
self._mqtt_data = hass.data[DATA_MQTT]
|
||||
self.trigger_id = f"{device_id}_{config[CONF_TYPE]}_{config[CONF_SUBTYPE]}"
|
||||
|
||||
MqttDiscoveryDeviceUpdate.__init__(
|
||||
|
@ -259,7 +259,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
|
|||
config = TRIGGER_DISCOVERY_SCHEMA(discovery_data)
|
||||
new_trigger_id = f"{self.device_id}_{config[CONF_TYPE]}_{config[CONF_SUBTYPE]}"
|
||||
if new_trigger_id != self.trigger_id:
|
||||
mqtt_data = get_mqtt_data(self.hass)
|
||||
mqtt_data = self.hass.data[DATA_MQTT]
|
||||
if new_trigger_id in mqtt_data.device_triggers:
|
||||
_LOGGER.error(
|
||||
"Cannot update device trigger %s due to an existing duplicate "
|
||||
|
@ -308,7 +308,7 @@ async def async_setup_trigger(
|
|||
trigger_type = config[CONF_TYPE]
|
||||
trigger_subtype = config[CONF_SUBTYPE]
|
||||
trigger_id = f"{device_id}_{trigger_type}_{trigger_subtype}"
|
||||
mqtt_data = get_mqtt_data(hass)
|
||||
mqtt_data = hass.data[DATA_MQTT]
|
||||
if (
|
||||
trigger_id in mqtt_data.device_triggers
|
||||
and mqtt_data.device_triggers[trigger_id].discovery_data is not None
|
||||
|
@ -334,7 +334,7 @@ async def async_setup_trigger(
|
|||
|
||||
async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None:
|
||||
"""Handle Mqtt removed from a device."""
|
||||
mqtt_data = get_mqtt_data(hass)
|
||||
mqtt_data = hass.data[DATA_MQTT]
|
||||
triggers = await async_get_triggers(hass, device_id)
|
||||
for trig in triggers:
|
||||
trigger_id = f"{device_id}_{trig[CONF_TYPE]}_{trig[CONF_SUBTYPE]}"
|
||||
|
@ -352,7 +352,7 @@ async def async_get_triggers(
|
|||
hass: HomeAssistant, device_id: str
|
||||
) -> list[dict[str, str]]:
|
||||
"""List device triggers for MQTT devices."""
|
||||
mqtt_data = get_mqtt_data(hass)
|
||||
mqtt_data = hass.data[DATA_MQTT]
|
||||
|
||||
if not mqtt_data.device_triggers:
|
||||
return []
|
||||
|
@ -377,7 +377,7 @@ async def async_attach_trigger(
|
|||
) -> CALLBACK_TYPE:
|
||||
"""Attach a trigger."""
|
||||
trigger_id: str | None = None
|
||||
mqtt_data = get_mqtt_data(hass)
|
||||
mqtt_data = hass.data[DATA_MQTT]
|
||||
device_id = config[CONF_DEVICE_ID]
|
||||
|
||||
# The use of CONF_DISCOVERY_ID was deprecated in HA Core 2024.2.
|
||||
|
|
|
@ -18,7 +18,7 @@ from homeassistant.helpers import device_registry as dr, entity_registry as er
|
|||
from homeassistant.helpers.device_registry import DeviceEntry
|
||||
|
||||
from . import debug_info, is_connected
|
||||
from .util import get_mqtt_data
|
||||
from .models import DATA_MQTT
|
||||
|
||||
REDACT_CONFIG = {CONF_PASSWORD, CONF_USERNAME}
|
||||
REDACT_STATE_DEVICE_TRACKER = {ATTR_LATITUDE, ATTR_LONGITUDE}
|
||||
|
@ -45,7 +45,7 @@ def _async_get_diagnostics(
|
|||
device: DeviceEntry | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Return diagnostics for a config entry."""
|
||||
mqtt_instance = get_mqtt_data(hass).client
|
||||
mqtt_instance = hass.data[DATA_MQTT].client
|
||||
if TYPE_CHECKING:
|
||||
assert mqtt_instance is not None
|
||||
|
||||
|
|
|
@ -40,8 +40,8 @@ from .const import (
|
|||
CONF_TOPIC,
|
||||
DOMAIN,
|
||||
)
|
||||
from .models import MqttOriginInfo, ReceiveMessage
|
||||
from .util import async_forward_entry_setup_and_setup_discovery, get_mqtt_data
|
||||
from .models import DATA_MQTT, MqttOriginInfo, ReceiveMessage
|
||||
from .util import async_forward_entry_setup_and_setup_discovery
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -113,12 +113,12 @@ class MQTTDiscoveryPayload(dict[str, Any]):
|
|||
|
||||
def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) -> None:
|
||||
"""Clear entry from already discovered list."""
|
||||
get_mqtt_data(hass).discovery_already_discovered.remove(discovery_hash)
|
||||
hass.data[DATA_MQTT].discovery_already_discovered.remove(discovery_hash)
|
||||
|
||||
|
||||
def set_discovery_hash(hass: HomeAssistant, discovery_hash: tuple[str, str]) -> None:
|
||||
"""Add entry to already discovered list."""
|
||||
get_mqtt_data(hass).discovery_already_discovered.add(discovery_hash)
|
||||
hass.data[DATA_MQTT].discovery_already_discovered.add(discovery_hash)
|
||||
|
||||
|
||||
@callback
|
||||
|
@ -150,7 +150,7 @@ async def async_start( # noqa: C901
|
|||
hass: HomeAssistant, discovery_topic: str, config_entry: ConfigEntry
|
||||
) -> None:
|
||||
"""Start MQTT Discovery."""
|
||||
mqtt_data = get_mqtt_data(hass)
|
||||
mqtt_data = hass.data[DATA_MQTT]
|
||||
platform_setup_lock: dict[str, asyncio.Lock] = {}
|
||||
|
||||
async def _async_component_setup(discovery_payload: MQTTDiscoveryPayload) -> None:
|
||||
|
@ -426,7 +426,7 @@ async def async_start( # noqa: C901
|
|||
|
||||
async def async_stop(hass: HomeAssistant) -> None:
|
||||
"""Stop MQTT Discovery."""
|
||||
mqtt_data = get_mqtt_data(hass)
|
||||
mqtt_data = hass.data[DATA_MQTT]
|
||||
for unsub in mqtt_data.discovery_unsubscribe:
|
||||
unsub()
|
||||
mqtt_data.discovery_unsubscribe = []
|
||||
|
|
|
@ -38,13 +38,13 @@ from .mixins import (
|
|||
async_setup_entity_entry_helper,
|
||||
)
|
||||
from .models import (
|
||||
DATA_MQTT,
|
||||
MqttValueTemplate,
|
||||
MqttValueTemplateException,
|
||||
PayloadSentinel,
|
||||
ReceiveMessage,
|
||||
ReceivePayloadType,
|
||||
)
|
||||
from .util import get_mqtt_data
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -194,7 +194,7 @@ class MqttEvent(MqttEntity, EventEntity):
|
|||
payload,
|
||||
)
|
||||
return
|
||||
mqtt_data = get_mqtt_data(self.hass)
|
||||
mqtt_data = self.hass.data[DATA_MQTT]
|
||||
mqtt_data.state_write_requests.write_state_request(self)
|
||||
|
||||
topics["state_topic"] = {
|
||||
|
|
|
@ -33,12 +33,13 @@ from .mixins import (
|
|||
async_setup_entity_entry_helper,
|
||||
)
|
||||
from .models import (
|
||||
DATA_MQTT,
|
||||
MessageCallbackType,
|
||||
MqttValueTemplate,
|
||||
MqttValueTemplateException,
|
||||
ReceiveMessage,
|
||||
)
|
||||
from .util import get_mqtt_data, valid_subscribe_topic
|
||||
from .util import valid_subscribe_topic
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -186,7 +187,7 @@ class MqttImage(MqttEntity, ImageEntity):
|
|||
)
|
||||
self._last_image = None
|
||||
self._attr_image_last_updated = dt_util.utcnow()
|
||||
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
|
||||
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
|
||||
|
||||
add_subscribe_topic(CONF_IMAGE_TOPIC, image_data_received)
|
||||
|
||||
|
@ -208,7 +209,7 @@ class MqttImage(MqttEntity, ImageEntity):
|
|||
)
|
||||
self._attr_image_last_updated = dt_util.utcnow()
|
||||
self._cached_image = None
|
||||
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
|
||||
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
|
||||
|
||||
add_subscribe_topic(CONF_URL_TOPIC, image_from_url_request_received)
|
||||
|
||||
|
|
|
@ -106,6 +106,7 @@ from .discovery import (
|
|||
set_discovery_hash,
|
||||
)
|
||||
from .models import (
|
||||
DATA_MQTT,
|
||||
MessageCallbackType,
|
||||
MqttValueTemplate,
|
||||
MqttValueTemplateException,
|
||||
|
@ -118,7 +119,7 @@ from .subscription import (
|
|||
async_subscribe_topics,
|
||||
async_unsubscribe_topics,
|
||||
)
|
||||
from .util import get_mqtt_data, mqtt_config_entry_enabled, valid_subscribe_topic
|
||||
from .util import mqtt_config_entry_enabled, valid_subscribe_topic
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -329,7 +330,7 @@ async def async_setup_non_entity_entry_helper(
|
|||
discovery_schema: vol.Schema,
|
||||
) -> None:
|
||||
"""Set up automation or tag creation dynamically through MQTT discovery."""
|
||||
mqtt_data = get_mqtt_data(hass)
|
||||
mqtt_data = hass.data[DATA_MQTT]
|
||||
|
||||
async def async_setup_from_discovery(
|
||||
discovery_payload: MQTTDiscoveryPayload,
|
||||
|
@ -360,7 +361,7 @@ async def async_setup_entity_entry_helper(
|
|||
schema_class_mapping: dict[str, type[MqttEntity]] | None = None,
|
||||
) -> None:
|
||||
"""Set up entity creation dynamically through MQTT discovery."""
|
||||
mqtt_data = get_mqtt_data(hass)
|
||||
mqtt_data = hass.data[DATA_MQTT]
|
||||
|
||||
@callback
|
||||
def async_setup_from_discovery(
|
||||
|
@ -391,7 +392,7 @@ async def async_setup_entity_entry_helper(
|
|||
def _async_setup_entities() -> None:
|
||||
"""Set up MQTT items from configuration.yaml."""
|
||||
nonlocal entity_class
|
||||
mqtt_data = get_mqtt_data(hass)
|
||||
mqtt_data = hass.data[DATA_MQTT]
|
||||
if not (config_yaml := mqtt_data.config):
|
||||
return
|
||||
yaml_configs: list[ConfigType] = [
|
||||
|
@ -496,7 +497,7 @@ def write_state_on_attr_change(
|
|||
if not _attrs_have_changed(tracked_attrs):
|
||||
return
|
||||
|
||||
mqtt_data = get_mqtt_data(entity.hass)
|
||||
mqtt_data = entity.hass.data[DATA_MQTT]
|
||||
mqtt_data.state_write_requests.write_state_request(entity)
|
||||
|
||||
return wrapper
|
||||
|
@ -695,7 +696,7 @@ class MqttAvailability(Entity):
|
|||
@property
|
||||
def available(self) -> bool:
|
||||
"""Return if the device is available."""
|
||||
mqtt_data = get_mqtt_data(self.hass)
|
||||
mqtt_data = self.hass.data[DATA_MQTT]
|
||||
client = mqtt_data.client
|
||||
if not client.connected and not self.hass.is_stopping:
|
||||
return False
|
||||
|
@ -936,7 +937,7 @@ class MqttDiscoveryUpdate(Entity):
|
|||
self._removed_from_hass = False
|
||||
if discovery_data is None:
|
||||
return
|
||||
mqtt_data = get_mqtt_data(hass)
|
||||
mqtt_data = hass.data[DATA_MQTT]
|
||||
self._registry_hooks = mqtt_data.discovery_registry_hooks
|
||||
discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH]
|
||||
if discovery_hash in self._registry_hooks:
|
||||
|
|
|
@ -20,6 +20,7 @@ from homeassistant.helpers import template
|
|||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.service_info.mqtt import ReceivePayloadType
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, TemplateVarsType
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from paho.mqtt.client import MQTTMessage
|
||||
|
@ -419,3 +420,7 @@ class MqttData:
|
|||
state_write_requests: EntityTopicState = field(default_factory=EntityTopicState)
|
||||
subscriptions_to_restore: list[Subscription] = field(default_factory=list)
|
||||
tags: dict[str, dict[str, MQTTTagScanner]] = field(default_factory=dict)
|
||||
|
||||
|
||||
DATA_MQTT: HassKey[MqttData] = HassKey("mqtt")
|
||||
DATA_MQTT_AVAILABLE: HassKey[asyncio.Future[bool]] = HassKey("mqtt_client_available")
|
||||
|
|
|
@ -28,13 +28,14 @@ from .mixins import (
|
|||
update_device,
|
||||
)
|
||||
from .models import (
|
||||
DATA_MQTT,
|
||||
MqttValueTemplate,
|
||||
MqttValueTemplateException,
|
||||
ReceiveMessage,
|
||||
ReceivePayloadType,
|
||||
)
|
||||
from .subscription import EntitySubscription
|
||||
from .util import get_mqtt_data, valid_subscribe_topic
|
||||
from .util import valid_subscribe_topic
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -70,7 +71,7 @@ async def _async_setup_tag(
|
|||
discovery_id = discovery_hash[1]
|
||||
|
||||
device_id = update_device(hass, config_entry, config)
|
||||
if device_id is not None and device_id not in (tags := get_mqtt_data(hass).tags):
|
||||
if device_id is not None and device_id not in (tags := hass.data[DATA_MQTT].tags):
|
||||
tags[device_id] = {}
|
||||
|
||||
tag_scanner = MQTTTagScanner(
|
||||
|
@ -91,7 +92,7 @@ async def _async_setup_tag(
|
|||
|
||||
def async_has_tags(hass: HomeAssistant, device_id: str) -> bool:
|
||||
"""Device has tag scanners."""
|
||||
if device_id not in (tags := get_mqtt_data(hass).tags):
|
||||
if device_id not in (tags := hass.data[DATA_MQTT].tags):
|
||||
return False
|
||||
return tags[device_id] != {}
|
||||
|
||||
|
@ -176,4 +177,4 @@ class MQTTTagScanner(MqttDiscoveryDeviceUpdate):
|
|||
self.hass, self._sub_state
|
||||
)
|
||||
if self.device_id:
|
||||
get_mqtt_data(self.hass).tags[self.device_id].pop(discovery_id)
|
||||
self.hass.data[DATA_MQTT].tags[self.device_id].pop(discovery_id)
|
||||
|
|
|
@ -26,14 +26,12 @@ from .const import (
|
|||
CONF_CERTIFICATE,
|
||||
CONF_CLIENT_CERT,
|
||||
CONF_CLIENT_KEY,
|
||||
DATA_MQTT,
|
||||
DATA_MQTT_AVAILABLE,
|
||||
DEFAULT_ENCODING,
|
||||
DEFAULT_QOS,
|
||||
DEFAULT_RETAIN,
|
||||
DOMAIN,
|
||||
)
|
||||
from .models import MqttData
|
||||
from .models import DATA_MQTT, DATA_MQTT_AVAILABLE
|
||||
|
||||
AVAILABILITY_TIMEOUT = 30.0
|
||||
|
||||
|
@ -51,7 +49,7 @@ async def async_forward_entry_setup_and_setup_discovery(
|
|||
hass: HomeAssistant, config_entry: ConfigEntry, platforms: set[Platform | str]
|
||||
) -> None:
|
||||
"""Forward the config entry setup to the platforms and set up discovery."""
|
||||
mqtt_data = get_mqtt_data(hass)
|
||||
mqtt_data = hass.data[DATA_MQTT]
|
||||
platforms_loaded = mqtt_data.platforms_loaded
|
||||
new_platforms: set[Platform | str] = platforms - platforms_loaded
|
||||
tasks: list[asyncio.Task] = []
|
||||
|
@ -85,7 +83,11 @@ async def async_forward_entry_setup_and_setup_discovery(
|
|||
|
||||
def mqtt_config_entry_enabled(hass: HomeAssistant) -> bool | None:
|
||||
"""Return true when the MQTT config entry is enabled."""
|
||||
return hass.config_entries.async_has_entries(
|
||||
# If the mqtt client is connected, skip the expensive config
|
||||
# entry check as its roughly two orders of magnitude faster.
|
||||
return (
|
||||
DATA_MQTT in hass.data and hass.data[DATA_MQTT].client.connected
|
||||
) or hass.config_entries.async_has_entries(
|
||||
DOMAIN, include_disabled=False, include_ignore=False
|
||||
)
|
||||
|
||||
|
@ -229,13 +231,6 @@ def valid_birth_will(config: ConfigType) -> ConfigType:
|
|||
return config
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_mqtt_data(hass: HomeAssistant) -> MqttData:
|
||||
"""Return typed MqttData from hass.data[DATA_MQTT]."""
|
||||
mqtt_data: MqttData = hass.data[DATA_MQTT]
|
||||
return mqtt_data
|
||||
|
||||
|
||||
async def async_create_certificate_temp_files(
|
||||
hass: HomeAssistant, config: ConfigType
|
||||
) -> None:
|
||||
|
|
|
@ -43,7 +43,7 @@ async def setup_comp(
|
|||
|
||||
|
||||
async def test_setup_fails_without_mqtt_being_setup(
|
||||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||
hass: HomeAssistant, mqtt_mock: MqttMockHAClient, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Ensure mqtt is started when we setup the component."""
|
||||
# Simulate MQTT is was removed
|
||||
|
@ -52,6 +52,8 @@ async def test_setup_fails_without_mqtt_being_setup(
|
|||
await hass.config_entries.async_set_disabled_by(
|
||||
mqtt_entry.entry_id, ConfigEntryDisabler.USER
|
||||
)
|
||||
# mqtt is mocked so we need to simulate it is not connected
|
||||
mqtt_mock.connected = False
|
||||
|
||||
dev_id = "zanzito"
|
||||
topic = "location/zanzito"
|
||||
|
|
Loading…
Add table
Reference in a new issue