hass-core/homeassistant/components/mqtt/__init__.py
J. Nick Koston ed0568c655
Ensure config entries are not unloaded while their platforms are setting up (#118767)
* Report non-awaited/non-locked config entry platform forwards

It's currently possible for config entries to be reloaded while their platforms
are being forwarded if platform forwards are not awaited or are done after the
config entry is set up, since the lock will not be held in this case.

In https://developers.home-assistant.io/blog/2022/07/08/config_entry_forwards
we advised to await platform forwards to ensure this does not happen, however
for sleeping devices and late discovered devices, platform forwards may happen
later.

If config entry platform forwards are happening during setup, they should be awaited.

If config entry platform forwards are not happening during setup,
async_late_forward_entry_setups should be used instead; it holds the lock to
prevent the config entry from being unloaded while its platforms are being
set up.

* Report non-awaited/non-locked config entry platform forwards

It's currently possible for config entries to be reloaded while their platforms
are being forwarded if platform forwards are not awaited or are done after the
config entry is set up, since the lock will not be held in this case.

In https://developers.home-assistant.io/blog/2022/07/08/config_entry_forwards
we advised to await platform forwards to ensure this does not happen, however
for sleeping devices and late discovered devices, platform forwards may happen
later.

If config entry platform forwards are happening during setup, they should be awaited.

If config entry platform forwards are not happening during setup,
async_late_forward_entry_setups should be used instead; it holds the lock to
prevent the config entry from being unloaded while its platforms are being
set up.

* run with error on to find them

* cert_exp, hold lock

* cert_exp, hold lock

* shelly async_late_forward_entry_setups

* compact

* compact

* found another

* patch up mobileapp

* patch up hue tests

* patch up smartthings

* fix mqtt

* fix esphome

* zwave_js

* mqtt

* rework

* fixes

* fix mocking

* fix mocking

* do not call async_forward_entry_setup directly

* docstrings

* docstrings

* docstrings

* add comments

* doc strings

* fixed all in core, turn off strict

* coverage

* coverage

* missing

* coverage
2024-06-04 21:34:39 -04:00

553 lines
18 KiB
Python

"""Support for MQTT message handling."""
from __future__ import annotations
import asyncio
from collections.abc import Callable
from datetime import datetime
import logging
from typing import TYPE_CHECKING, Any, cast
import voluptuous as vol
from homeassistant import config as conf_util
from homeassistant.components import websocket_api
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_DISCOVERY, CONF_PAYLOAD, SERVICE_RELOAD
from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import (
ConfigValidationError,
ServiceValidationError,
Unauthorized,
)
from homeassistant.helpers import (
config_validation as cv,
entity_registry as er,
event as ev,
issue_registry as ir,
template,
)
from homeassistant.helpers.device_registry import DeviceEntry
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity_platform import async_get_platforms
from homeassistant.helpers.reload import async_integration_yaml_config
from homeassistant.helpers.service import async_register_admin_service
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import async_get_integration, async_get_loaded_integration
from homeassistant.setup import SetupPhases, async_pause_setup
from homeassistant.util.async_ import create_eager_task
# Loading the config flow file will register the flow
from . import debug_info, discovery
from .client import ( # noqa: F401
MQTT,
async_publish,
async_subscribe,
async_subscribe_internal,
publish,
subscribe,
)
from .config import MQTT_BASE_SCHEMA, MQTT_RO_SCHEMA, MQTT_RW_SCHEMA # noqa: F401
from .config_integration import CONFIG_SCHEMA_BASE
from .const import ( # noqa: F401
ATTR_PAYLOAD,
ATTR_QOS,
ATTR_RETAIN,
ATTR_TOPIC,
CONF_BIRTH_MESSAGE,
CONF_BROKER,
CONF_CERTIFICATE,
CONF_CLIENT_CERT,
CONF_CLIENT_KEY,
CONF_COMMAND_TOPIC,
CONF_DISCOVERY_PREFIX,
CONF_KEEPALIVE,
CONF_QOS,
CONF_STATE_TOPIC,
CONF_TLS_INSECURE,
CONF_TOPIC,
CONF_TRANSPORT,
CONF_WILL_MESSAGE,
CONF_WS_HEADERS,
CONF_WS_PATH,
DEFAULT_DISCOVERY,
DEFAULT_ENCODING,
DEFAULT_PREFIX,
DEFAULT_QOS,
DEFAULT_RETAIN,
DOMAIN,
MQTT_CONNECTION_STATE,
RELOADABLE_PLATFORMS,
TEMPLATE_ERRORS,
)
from .models import ( # noqa: F401
DATA_MQTT,
DATA_MQTT_AVAILABLE,
MqttCommandTemplate,
MqttData,
MqttValueTemplate,
PayloadSentinel,
PublishPayloadType,
ReceiveMessage,
ReceivePayloadType,
)
from .subscription import ( # noqa: F401
EntitySubscription,
async_prepare_subscribe_topics,
async_subscribe_topics,
async_unsubscribe_topics,
)
from .util import ( # noqa: F401
async_create_certificate_temp_files,
async_forward_entry_setup_and_setup_discovery,
async_wait_for_mqtt_client,
mqtt_config_entry_enabled,
platforms_from_config,
valid_publish_topic,
valid_qos_schema,
valid_subscribe_topic,
)
_LOGGER = logging.getLogger(__name__)

# Service names registered under the MQTT domain by async_setup_entry and
# removed again by async_unload_entry.
SERVICE_PUBLISH = "publish"
SERVICE_DUMP = "dump"

# Templated alternatives to ATTR_TOPIC / ATTR_PAYLOAD for the publish service;
# MQTT_PUBLISH_SCHEMA makes each mutually exclusive with its literal form.
ATTR_TOPIC_TEMPLATE = "topic_template"
ATTR_PAYLOAD_TEMPLATE = "payload_template"

MAX_RECONNECT_WAIT = 300  # seconds

# Connection result identifiers. They are not referenced in this module;
# NOTE(review): presumably consumed elsewhere in the integration — verify
# before renaming or removing.
CONNECTION_SUCCESS = "connection_success"
CONNECTION_FAILED = "connection_failed"
CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable"
# We accept 2 schemes for configuring manual MQTT items.
#
# Preferred style:
#
# ```yaml
# mqtt:
#   - {domain}:
#       name: ""
#       ...
#   - {domain}:
#       name: ""
#       ...
# ```
#
# Legacy supported style:
#
# ```yaml
# mqtt:
#   {domain}:
#     - name: ""
#       ...
#     - name: ""
#       ...
# ```
#
# The value under the MQTT domain is passed through cv.ensure_list and
# cv.remove_falsy (presumably: wrap a single mapping in a list, then drop empty
# entries), and every remaining item is validated against CONFIG_SCHEMA_BASE.
# extra=vol.ALLOW_EXTRA is needed because async_setup_entry applies this schema
# to the whole Home Assistant YAML configuration, not just the MQTT section.
CONFIG_SCHEMA = vol.Schema(
    {
        DOMAIN: vol.All(
            cv.ensure_list,
            cv.remove_falsy,
            [CONFIG_SCHEMA_BASE],
        )
    },
    extra=vol.ALLOW_EXTRA,
)
# Service call validation schema for the publish service.
# - The topic is given either literally (ATTR_TOPIC) or as a template
#   (ATTR_TOPIC_TEMPLATE): both share the CONF_TOPIC exclusion group, and
#   has_at_least_one_key requires one of them to be present.
# - The payload may likewise be literal (ATTR_PAYLOAD) or templated
#   (ATTR_PAYLOAD_TEMPLATE), but not both; it may also be omitted.
# - qos and retain fall back to DEFAULT_QOS / DEFAULT_RETAIN.
MQTT_PUBLISH_SCHEMA = vol.All(
    vol.Schema(
        {
            vol.Exclusive(ATTR_TOPIC, CONF_TOPIC): valid_publish_topic,
            vol.Exclusive(ATTR_TOPIC_TEMPLATE, CONF_TOPIC): cv.string,
            vol.Exclusive(ATTR_PAYLOAD, CONF_PAYLOAD): cv.string,
            vol.Exclusive(ATTR_PAYLOAD_TEMPLATE, CONF_PAYLOAD): cv.string,
            vol.Optional(ATTR_QOS, default=DEFAULT_QOS): valid_qos_schema,
            vol.Optional(ATTR_RETAIN, default=DEFAULT_RETAIN): cv.boolean,
        },
        required=True,
    ),
    cv.has_at_least_one_key(ATTR_TOPIC, ATTR_TOPIC_TEMPLATE),
)
async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -> None:
    """Reload the MQTT config entry after it has been updated.

    Registered as the entry's update listener, so it runs when the entry
    changes, for example when its options are edited.
    """
    updated_entry_id = entry.entry_id
    await hass.config_entries.async_reload(updated_entry_id)
@callback
def _async_remove_mqtt_issues(hass: HomeAssistant, mqtt_data: MqttData) -> None:
    """Delete every open MQTT repair issue about an invalid platform config.

    ``mqtt_data`` is part of the call signature but is not used here.
    """
    registry = ir.async_get(hass)
    # Gather the ids first and delete afterwards, so the registry is not
    # modified while its ``issues`` mapping is being iterated.
    stale_issue_ids: list[str] = []
    for (issue_domain, issue_id), issue in registry.issues.items():
        if issue_domain != DOMAIN:
            continue
        if issue.translation_key != "invalid_platform_config":
            continue
        stale_issue_ids.append(issue_id)
    for issue_id in stale_issue_ids:
        ir.async_delete_issue(hass, DOMAIN, issue_id)
async def async_check_config_schema(
    hass: HomeAssistant, config_yaml: ConfigType
) -> None:
    """Validate manually configured (YAML) MQTT items.

    Every item listed under the MQTT domain is validated against the reload
    schema stored for its platform domain in ``mqtt_data.reload_schema``.

    Raises:
        ServiceValidationError: on the first item that fails validation. The
            message is the formatted schema error, and the ``domain``
            translation placeholder names the offending platform.
    """
    mqtt_data = hass.data[DATA_MQTT]
    # The MQTT section is a list of {domain: [item, ...]} mappings. Default to
    # an empty list so the fallback matches the annotation and the lookup done
    # in _reload_config.
    mqtt_config: list[dict[str, list[ConfigType]]] = config_yaml.get(DOMAIN, [])
    for mqtt_config_item in mqtt_config:
        for domain, config_items in mqtt_config_item.items():
            # NOTE(review): this raises KeyError if no reload schema was stored
            # for ``domain``. _reload_config sets up newly used platforms before
            # calling this function.
            schema = mqtt_data.reload_schema[domain]
            for config in config_items:
                try:
                    schema(config)
                except vol.Invalid as exc:
                    # Only resolve the integration (for its docs URL) on failure.
                    integration = await async_get_integration(hass, DOMAIN)
                    message = conf_util.format_schema_error(
                        hass, exc, domain, config, integration.documentation
                    )
                    raise ServiceValidationError(
                        message,
                        translation_domain=DOMAIN,
                        translation_key="invalid_platform_config",
                        translation_placeholders={
                            "domain": domain,
                        },
                    ) from exc
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Load a config entry.

    Creates (or reuses) the shared MqttData and starts the MQTT client. Then it
    preloads the platforms in use, connects, and registers the publish and dump
    services. It also forwards the entry to its platforms, registers the admin
    reload service once, and starts discovery when enabled. Returns True.
    """
    conf: dict[str, Any]
    mqtt_data: MqttData

    async def _setup_client() -> tuple[MqttData, dict[str, Any]]:
        """Set up the MQTT client; return the shared MqttData and entry config."""
        # Fetch configuration
        conf = dict(entry.data)
        hass_config = await conf_util.async_hass_config_yaml(hass)
        mqtt_yaml = CONFIG_SCHEMA(hass_config).get(DOMAIN, [])
        await async_create_certificate_temp_files(hass, conf)
        client = MQTT(hass, entry, conf)
        if DOMAIN in hass.data:
            # Set up again: reuse the MqttData kept by async_unload_entry so the
            # subscriptions it saved can be restored below.
            mqtt_data = hass.data[DATA_MQTT]
            mqtt_data.config = mqtt_yaml
            mqtt_data.client = client
        else:
            # Initial setup
            websocket_api.async_register_command(hass, websocket_subscribe)
            websocket_api.async_register_command(hass, websocket_mqtt_info)
            hass.data[DATA_MQTT] = mqtt_data = MqttData(config=mqtt_yaml, client=client)
        await client.async_start(mqtt_data)

        # Restore saved subscriptions
        if mqtt_data.subscriptions_to_restore:
            mqtt_data.client.async_restore_tracked_subscriptions(
                mqtt_data.subscriptions_to_restore
            )
            mqtt_data.subscriptions_to_restore = []
        # Reload the entry whenever it is updated; the returned remover is
        # called by async_unload_entry.
        mqtt_data.reload_dispatchers.append(
            entry.add_update_listener(_async_config_entry_updated)
        )

        return (mqtt_data, conf)

    # Reuse the availability future if one already exists in hass.data.
    client_available: asyncio.Future[bool]
    if DATA_MQTT_AVAILABLE not in hass.data:
        client_available = hass.data[DATA_MQTT_AVAILABLE] = hass.loop.create_future()
    else:
        client_available = hass.data[DATA_MQTT_AVAILABLE]

    mqtt_data, conf = await _setup_client()
    platforms_used = platforms_from_config(mqtt_data.config)
    # Also include the platforms of entities already registered to this entry.
    # The loop variable is deliberately not called ``entry`` so the ConfigEntry
    # parameter is never shadowed.
    platforms_used.update(
        registry_entry.domain
        for registry_entry in er.async_entries_for_config_entry(
            er.async_get(hass), entry.entry_id
        )
    )
    integration = async_get_loaded_integration(hass, DOMAIN)
    # Preload platforms we know we are going to use so
    # discovery can setup each platform synchronously
    # and avoid creating a flood of tasks at startup
    # while waiting for the imports to complete
    if not integration.platforms_are_loaded(platforms_used):
        with async_pause_setup(hass, SetupPhases.WAIT_IMPORT_PLATFORMS):
            await integration.async_get_platforms(platforms_used)

    # Wait to connect until the platforms are loaded so
    # we can be sure discovery does not have to wait for
    # each platform to load when we get the flood of retained
    # messages on connect
    await mqtt_data.client.async_connect(client_available)

    async def async_publish_service(call: ServiceCall) -> None:
        """Handle MQTT publish service calls."""
        msg_topic: str | None = call.data.get(ATTR_TOPIC)
        msg_topic_template: str | None = call.data.get(ATTR_TOPIC_TEMPLATE)
        payload: PublishPayloadType = call.data.get(ATTR_PAYLOAD)
        payload_template: str | None = call.data.get(ATTR_PAYLOAD_TEMPLATE)
        qos: int = call.data[ATTR_QOS]
        retain: bool = call.data[ATTR_RETAIN]
        if msg_topic_template is not None:
            rendered_topic: Any = MqttCommandTemplate(
                template.Template(msg_topic_template),
                hass=hass,
            ).async_render()
            try:
                msg_topic = valid_publish_topic(rendered_topic)
            except vol.Invalid as err:
                err_str = str(err)
                raise ServiceValidationError(
                    translation_domain=DOMAIN,
                    translation_key="invalid_publish_topic",
                    translation_placeholders={
                        "error": err_str,
                        "topic": str(rendered_topic),
                        "topic_template": str(msg_topic_template),
                    },
                ) from err

        if payload_template is not None:
            payload = MqttCommandTemplate(
                template.Template(payload_template), hass=hass
            ).async_render()

        # MQTT_PUBLISH_SCHEMA guarantees a topic or a topic template was given.
        if TYPE_CHECKING:
            assert msg_topic is not None
        await mqtt_data.client.async_publish(msg_topic, payload, qos, retain)

    hass.services.async_register(
        DOMAIN, SERVICE_PUBLISH, async_publish_service, schema=MQTT_PUBLISH_SCHEMA
    )

    async def async_dump_service(call: ServiceCall) -> None:
        """Handle MQTT dump service calls."""
        messages: list[tuple[str, str]] = []

        @callback
        def collect_msg(msg: ReceiveMessage) -> None:
            messages.append((msg.topic, str(msg.payload).replace("\n", "")))

        unsub = async_subscribe_internal(hass, call.data["topic"], collect_msg)

        def write_dump() -> None:
            with open(hass.config.path("mqtt_dump.txt"), "w", encoding="utf8") as fp:
                for msg in messages:
                    fp.write(",".join(msg) + "\n")

        async def finish_dump(_: datetime) -> None:
            """Write dump to file."""
            unsub()
            # File I/O is blocking, so it runs in the executor.
            await hass.async_add_executor_job(write_dump)

        ev.async_call_later(hass, call.data["duration"], finish_dump)

    hass.services.async_register(
        DOMAIN,
        SERVICE_DUMP,
        async_dump_service,
        schema=vol.Schema(
            {
                vol.Required("topic"): valid_subscribe_topic,
                vol.Optional("duration", default=5): int,
            }
        ),
    )

    # setup platforms and discovery

    async def _reload_config(call: ServiceCall) -> None:
        """Reload the platforms."""
        # Fetch updated manually configured items and validate
        try:
            config_yaml = await async_integration_yaml_config(
                hass, DOMAIN, raise_on_failure=True
            )
        except ConfigValidationError as ex:
            raise ServiceValidationError(
                translation_domain=ex.translation_domain,
                translation_key=ex.translation_key,
                translation_placeholders=ex.translation_placeholders,
            ) from ex

        new_config: list[ConfigType] = config_yaml.get(DOMAIN, [])
        platforms_used = platforms_from_config(new_config)
        new_platforms = platforms_used - mqtt_data.platforms_loaded
        # Set up newly referenced platforms before the schema check.
        # NOTE(review): presumably this is what stores their entries in
        # mqtt_data.reload_schema, which async_check_config_schema looks up.
        await async_forward_entry_setup_and_setup_discovery(
            hass, entry, new_platforms, late=True
        )
        # Check the schema before continuing reload
        await async_check_config_schema(hass, config_yaml)

        # Remove repair issues
        _async_remove_mqtt_issues(hass, mqtt_data)

        mqtt_data.config = new_config

        # Reload the modern yaml platforms: remove entities that did not come
        # from discovery and belong to a reloadable platform of a config entry.
        mqtt_platforms = async_get_platforms(hass, DOMAIN)
        tasks = [
            create_eager_task(entity.async_remove())
            for mqtt_platform in mqtt_platforms
            for entity in list(mqtt_platform.entities.values())
            if getattr(entity, "_discovery_data", None) is None
            and mqtt_platform.config_entry
            and mqtt_platform.domain in RELOADABLE_PLATFORMS
        ]
        await asyncio.gather(*tasks)

        for component in mqtt_data.reload_handlers.values():
            component()

        # Fire event
        hass.bus.async_fire(f"event_{DOMAIN}_reloaded", context=call.context)

    await async_forward_entry_setup_and_setup_discovery(hass, entry, platforms_used)
    # Setup reload service after all platforms have loaded
    if not hass.services.has_service(DOMAIN, SERVICE_RELOAD):
        async_register_admin_service(hass, DOMAIN, SERVICE_RELOAD, _reload_config)
    # Setup discovery
    if conf.get(CONF_DISCOVERY, DEFAULT_DISCOVERY):
        await discovery.async_start(
            hass, conf.get(CONF_DISCOVERY_PREFIX, DEFAULT_PREFIX), entry
        )

    return True
@websocket_api.websocket_command(
    {vol.Required("type"): "mqtt/device/debug_info", vol.Required("device_id"): str}
)
@callback
def websocket_mqtt_info(
    hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
) -> None:
    """Reply to a websocket request with the MQTT debug info of one device."""
    requested_device = msg["device_id"]
    device_debug_info = debug_info.info_for_device(hass, requested_device)
    connection.send_result(msg["id"], device_debug_info)
@websocket_api.websocket_command(
    {
        vol.Required("type"): "mqtt/subscribe",
        vol.Required("topic"): valid_subscribe_topic,
        vol.Optional("qos"): valid_qos_schema,
    }
)
@websocket_api.async_response
async def websocket_subscribe(
    hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
) -> None:
    """Subscribe the websocket connection to an MQTT topic (admin only).

    Each message received on the topic is relayed to the client as an event
    message. The subscription is stored in ``connection.subscriptions`` under
    the request id.
    """
    if not connection.user.is_admin:
        raise Unauthorized

    @callback
    def forward_messages(mqttmsg: ReceiveMessage) -> None:
        """Relay one received MQTT message to the websocket client."""
        raw_payload = mqttmsg.payload
        # The subscription below passes encoding=None, so the raw payload is
        # decoded here. If it is not bytes or not valid UTF-8, its str() form
        # is sent instead.
        try:
            text_payload = cast(bytes, raw_payload).decode(DEFAULT_ENCODING)
        except (AttributeError, UnicodeDecodeError):
            text_payload = str(raw_payload)
        event_data = {
            "topic": mqttmsg.topic,
            "payload": text_payload,
            "qos": mqttmsg.qos,
            "retain": mqttmsg.retain,
        }
        connection.send_message(websocket_api.event_message(msg["id"], event_data))

    requested_qos: int = msg.get("qos", DEFAULT_QOS)
    connection.subscriptions[msg["id"]] = async_subscribe_internal(
        hass, msg["topic"], forward_messages, encoding=None, qos=requested_qos
    )
    connection.send_message(websocket_api.result_message(msg["id"]))
type ConnectionStatusCallback = Callable[[bool], None]
@callback
def async_subscribe_connection_status(
hass: HomeAssistant, connection_status_callback: ConnectionStatusCallback
) -> Callable[[], None]:
"""Subscribe to MQTT connection changes."""
return async_dispatcher_connect(
hass, MQTT_CONNECTION_STATE, connection_status_callback
)
def is_connected(hass: HomeAssistant) -> bool:
    """Report whether the MQTT client stored in ``hass.data`` is connected."""
    return hass.data[DATA_MQTT].client.connected
async def async_remove_config_entry_device(
    hass: HomeAssistant, config_entry: ConfigEntry, device_entry: DeviceEntry
) -> bool:
    """Remove the MQTT config entry from a device.

    Notifies ``device_automation`` that the device was removed, then always
    allows the removal by returning True.
    """
    # Function-scope import. NOTE(review): presumably this avoids an import
    # cycle with the device_automation module — verify before hoisting.
    # pylint: disable-next=import-outside-toplevel
    from . import device_automation

    removed_device_id = device_entry.id
    await device_automation.async_removed_from_device(hass, removed_device_id)
    return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Unload the MQTT config entry.

    Teardown happens in this order:
    1. Remove the publish and dump services and stop discovery.
    2. Unload the forwarded platforms.
    3. Remove the update listeners and the discovery registry hooks, and clean
       up the client.
    4. Disconnect the client.

    The MqttData in ``hass.data`` is kept; the client's remaining
    subscriptions are saved on it so the next async_setup_entry can restore
    them. Always returns True.
    """
    mqtt_data = hass.data[DATA_MQTT]
    mqtt_client = mqtt_data.client

    # Unload publish and dump services.
    hass.services.async_remove(DOMAIN, SERVICE_PUBLISH)
    hass.services.async_remove(DOMAIN, SERVICE_DUMP)

    # Stop the discovery
    await discovery.async_stop(hass)
    # Unload the platforms
    # NOTE(review): the result of async_unload_platforms is ignored, so this
    # function reports success even if a platform fails to unload.
    await hass.config_entries.async_unload_platforms(entry, mqtt_data.platforms_loaded)
    mqtt_data.platforms_loaded = set()
    # Yield to the event loop once. NOTE(review): presumably this lets
    # callbacks scheduled by the platform unload run before the rest of the
    # teardown.
    await asyncio.sleep(0)
    # Unsubscribe reload dispatchers (including the entry update listener
    # added in async_setup_entry), draining the list in place
    while reload_dispatchers := mqtt_data.reload_dispatchers:
        reload_dispatchers.pop()()
    # Cleanup listeners
    mqtt_client.cleanup()

    # Cleanup entity registry hooks: call each stored remover while emptying
    # the mapping
    registry_hooks = mqtt_data.discovery_registry_hooks
    while registry_hooks:
        registry_hooks.popitem()[1]()
    # Wait for all ACKs and stop the loop
    await mqtt_client.async_disconnect()

    # Cleanup MQTT client availability
    hass.data.pop(DATA_MQTT_AVAILABLE, None)
    # Store remaining subscriptions to be able to restore or reload them
    # when the entry is set up again
    if subscriptions := mqtt_client.subscriptions:
        mqtt_data.subscriptions_to_restore = subscriptions

    # Remove repair issues
    _async_remove_mqtt_issues(hass, mqtt_data)
    return True