Rework MQTT config merging and adding defaults (#90529)

* Cleanup config merging and adding defaults

* Optimize and update tests

* Do not mix entry and yaml config

* Make sure hass.data is initilized

* remove check on get_mqtt_data

* Tweaks to MQTT client

* Remove None assigment mqtt client and fix mock
This commit is contained in:
Jan Bouwhuis 2023-04-04 18:12:18 +02:00 committed by GitHub
parent 690a0f34e5
commit 4a0d3e881a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 77 additions and 180 deletions

View file

@ -24,7 +24,7 @@ from homeassistant.const import (
SERVICE_RELOAD,
)
from homeassistant.core import HassJob, HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import TemplateError, Unauthorized
from homeassistant.exceptions import ConfigEntryError, TemplateError, Unauthorized
from homeassistant.helpers import config_validation as cv, event, template
from homeassistant.helpers.device_registry import DeviceEntry
from homeassistant.helpers.dispatcher import async_dispatcher_connect
@ -45,11 +45,7 @@ from .client import ( # noqa: F401
publish,
subscribe,
)
from .config_integration import (
CONFIG_SCHEMA_ENTRY,
DEFAULT_VALUES,
PLATFORM_CONFIG_SCHEMA_BASE,
)
from .config_integration import CONFIG_SCHEMA_ENTRY, PLATFORM_CONFIG_SCHEMA_BASE
from .const import ( # noqa: F401
ATTR_PAYLOAD,
ATTR_QOS,
@ -83,6 +79,7 @@ from .const import ( # noqa: F401
)
from .models import ( # noqa: F401
MqttCommandTemplate,
MqttData,
MqttValueTemplate,
PublishPayloadType,
ReceiveMessage,
@ -102,8 +99,6 @@ _LOGGER = logging.getLogger(__name__)
SERVICE_PUBLISH = "publish"
SERVICE_DUMP = "dump"
MANDATORY_DEFAULT_VALUES = (CONF_PORT, CONF_DISCOVERY_PREFIX)
ATTR_TOPIC_TEMPLATE = "topic_template"
ATTR_PAYLOAD_TEMPLATE = "payload_template"
@ -193,50 +188,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
def _filter_entry_config(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Remove unknown keys from config entry data.
Extra keys may have been added when importing MQTT yaml configuration.
"""
filtered_data = {
k: entry.data[k] for k in CONFIG_ENTRY_CONFIG_KEYS if k in entry.data
}
if entry.data.keys() != filtered_data.keys():
_LOGGER.warning(
(
"The following unsupported configuration options were removed from the "
"MQTT config entry: %s"
),
entry.data.keys() - filtered_data.keys(),
)
hass.config_entries.async_update_entry(entry, data=filtered_data)
async def _async_auto_mend_config(
hass: HomeAssistant, entry: ConfigEntry, yaml_config: dict[str, Any]
) -> None:
"""Mends config fetched from config entry and adds missing values.
This mends incomplete migration from old version of HA Core.
"""
entry_updated = False
entry_config = {**entry.data}
for key in MANDATORY_DEFAULT_VALUES:
if key not in entry_config:
entry_config[key] = DEFAULT_VALUES[key]
entry_updated = True
if entry_updated:
hass.config_entries.async_update_entry(entry, data=entry_config)
def _merge_extended_config(entry: ConfigEntry, conf: ConfigType) -> dict[str, Any]:
"""Merge advanced options in configuration.yaml config with config entry."""
# Add default values
conf = {**DEFAULT_VALUES, **conf}
return {**conf, **entry.data}
async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Handle signals of config entry being updated.
@ -245,45 +196,29 @@ async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -
await hass.config_entries.async_reload(entry.entry_id)
async def async_fetch_config(
hass: HomeAssistant, entry: ConfigEntry
) -> dict[str, Any] | None:
"""Fetch fresh MQTT yaml config from the hass config."""
mqtt_data = get_mqtt_data(hass)
hass_config = await conf_util.async_hass_config_yaml(hass)
mqtt_data.config = PLATFORM_CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {}))
# Remove unknown keys from config entry data
_filter_entry_config(hass, entry)
# Add missing defaults to migrate older config entries
await _async_auto_mend_config(hass, entry, mqtt_data.config or {})
# Bail out if broker setting is missing
if CONF_BROKER not in entry.data:
_LOGGER.error("MQTT broker is not configured, please configure it")
return None
# If user doesn't have configuration.yaml config, generate default values
# for options not in config entry data
if (conf := mqtt_data.config) is None:
conf = CONFIG_SCHEMA_ENTRY(dict(entry.data))
# Merge advanced configuration values from configuration.yaml
conf = _merge_extended_config(entry, conf)
return conf
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Load a config entry."""
mqtt_data = get_mqtt_data(hass, True)
# validate entry config
try:
conf = CONFIG_SCHEMA_ENTRY(dict(entry.data))
except vol.MultipleInvalid as ex:
raise ConfigEntryError(
f"The MQTT config entry is invalid, please correct it: {ex}"
) from ex
# Fetch configuration and add missing defaults for basic options
if (conf := await async_fetch_config(hass, entry)) is None:
# Bail out
return False
# Fetch configuration and add default values
hass_config = await conf_util.async_hass_config_yaml(hass)
mqtt_yaml = PLATFORM_CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {}))
client = MQTT(hass, entry, conf)
if DOMAIN in hass.data:
mqtt_data = get_mqtt_data(hass)
mqtt_data.config = mqtt_yaml
mqtt_data.client = client
else:
hass.data[DATA_MQTT] = mqtt_data = MqttData(config=mqtt_yaml, client=client)
client.start(mqtt_data)
await async_create_certificate_temp_files(hass, dict(entry.data))
mqtt_data.client = MQTT(hass, entry, conf)
# Restore saved subscriptions
if mqtt_data.subscriptions_to_restore:
mqtt_data.client.async_restore_tracked_subscriptions(
@ -349,7 +284,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
)
return
assert mqtt_data.client is not None and msg_topic is not None
assert msg_topic is not None
await mqtt_data.client.async_publish(msg_topic, payload, qos, retain)
hass.services.async_register(
@ -585,7 +520,6 @@ def async_subscribe_connection_status(
def is_connected(hass: HomeAssistant) -> bool:
"""Return if MQTT client is connected."""
mqtt_data = get_mqtt_data(hass)
assert mqtt_data.client is not None
return mqtt_data.client.connected
@ -603,7 +537,6 @@ async def async_remove_config_entry_device(
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload MQTT dump and publish service when the config entry is unloaded."""
mqtt_data = get_mqtt_data(hass)
assert mqtt_data.client is not None
mqtt_client = mqtt_data.client
# Unload publish and dump services.

View file

@ -69,6 +69,7 @@ from .const import (
from .models import (
AsyncMessageCallbackType,
MessageCallbackType,
MqttData,
PublishMessage,
PublishPayloadType,
ReceiveMessage,
@ -111,11 +112,11 @@ async def async_publish(
encoding: str | None = DEFAULT_ENCODING,
) -> None:
"""Publish message to a MQTT topic."""
mqtt_data = get_mqtt_data(hass, True)
if mqtt_data.client is None or not mqtt_config_entry_enabled(hass):
if not mqtt_config_entry_enabled(hass):
raise HomeAssistantError(
f"Cannot publish to topic '{topic}', MQTT is not enabled"
)
mqtt_data = get_mqtt_data(hass)
outgoing_payload = payload
if not isinstance(payload, bytes):
if not encoding:
@ -161,11 +162,11 @@ async def async_subscribe(
Call the return value to unsubscribe.
"""
mqtt_data = get_mqtt_data(hass, True)
if mqtt_data.client is None or not mqtt_config_entry_enabled(hass):
if not mqtt_config_entry_enabled(hass):
raise HomeAssistantError(
f"Cannot subscribe to topic '{topic}', MQTT is not enabled"
)
mqtt_data = get_mqtt_data(hass)
# Support for a deprecated callback type was removed with HA core 2023.3.0
# The signature validation code can be removed from HA core 2023.5.0
non_default = 0
@ -377,19 +378,16 @@ class MQTT:
_mqttc: mqtt.Client
_last_subscribe: float
_mqtt_data: MqttData
def __init__(
self,
hass: HomeAssistant,
config_entry: ConfigEntry,
conf: ConfigType,
self, hass: HomeAssistant, config_entry: ConfigEntry, conf: ConfigType
) -> None:
"""Initialize Home Assistant MQTT client."""
self._mqtt_data = get_mqtt_data(hass)
self.hass = hass
self.config_entry = config_entry
self.conf = conf
self._simple_subscriptions: dict[str, list[Subscription]] = {}
self._wildcard_subscriptions: list[Subscription] = []
self.connected = False
@ -415,8 +413,6 @@ class MQTT:
self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, ha_started)
self.init_client()
async def async_stop_mqtt(_event: Event) -> None:
"""Stop MQTT component."""
await self.async_disconnect()
@ -425,6 +421,14 @@ class MQTT:
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt)
)
def start(
self,
mqtt_data: MqttData,
) -> None:
"""Start Home Assistant MQTT client."""
self._mqtt_data = mqtt_data
self.init_client()
@property
def subscriptions(self) -> list[Subscription]:
"""Return the tracked subscriptions."""

View file

@ -65,17 +65,6 @@ from .util import valid_birth_will, valid_publish_topic
DEFAULT_TLS_PROTOCOL = "auto"
DEFAULT_VALUES = {
CONF_BIRTH_MESSAGE: DEFAULT_BIRTH,
CONF_DISCOVERY: DEFAULT_DISCOVERY,
CONF_DISCOVERY_PREFIX: DEFAULT_PREFIX,
CONF_PORT: DEFAULT_PORT,
CONF_PROTOCOL: DEFAULT_PROTOCOL,
CONF_TRANSPORT: DEFAULT_TRANSPORT,
CONF_WILL_MESSAGE: DEFAULT_WILL,
CONF_KEEPALIVE: DEFAULT_KEEPALIVE,
}
PLATFORM_CONFIG_SCHEMA_BASE = vol.Schema(
{
Platform.ALARM_CONTROL_PANEL.value: vol.All(
@ -169,9 +158,11 @@ CLIENT_KEY_AUTH_MSG = (
CONFIG_SCHEMA_ENTRY = vol.Schema(
{
vol.Optional(CONF_CLIENT_ID): cv.string,
vol.Optional(CONF_KEEPALIVE): vol.All(vol.Coerce(int), vol.Range(min=15)),
vol.Optional(CONF_BROKER): cv.string,
vol.Optional(CONF_PORT): cv.port,
vol.Optional(CONF_KEEPALIVE, default=DEFAULT_KEEPALIVE): vol.All(
vol.Coerce(int), vol.Range(min=15)
),
vol.Required(CONF_BROKER): cv.string,
vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port,
vol.Optional(CONF_USERNAME): cv.string,
vol.Optional(CONF_PASSWORD): cv.string,
vol.Optional(CONF_CERTIFICATE): str,
@ -180,13 +171,17 @@ CONFIG_SCHEMA_ENTRY = vol.Schema(
CONF_CLIENT_CERT, "client_key_auth", msg=CLIENT_KEY_AUTH_MSG
): str,
vol.Optional(CONF_TLS_INSECURE): cv.boolean,
vol.Optional(CONF_PROTOCOL): vol.All(cv.string, vol.In(SUPPORTED_PROTOCOLS)),
vol.Optional(CONF_WILL_MESSAGE): valid_birth_will,
vol.Optional(CONF_BIRTH_MESSAGE): valid_birth_will,
vol.Optional(CONF_DISCOVERY): cv.boolean,
vol.Optional(CONF_PROTOCOL, default=DEFAULT_PROTOCOL): vol.All(
cv.string, vol.In(SUPPORTED_PROTOCOLS)
),
vol.Optional(CONF_WILL_MESSAGE, default=DEFAULT_WILL): valid_birth_will,
vol.Optional(CONF_BIRTH_MESSAGE, default=DEFAULT_BIRTH): valid_birth_will,
vol.Optional(CONF_DISCOVERY, default=DEFAULT_DISCOVERY): cv.boolean,
# discovery_prefix must be a valid publish topic because if no
# state topic is specified, it will be created with the given prefix.
vol.Optional(CONF_DISCOVERY_PREFIX): valid_publish_topic,
vol.Optional(
CONF_DISCOVERY_PREFIX, default=DEFAULT_PREFIX
): valid_publish_topic,
vol.Optional(CONF_TRANSPORT, default=DEFAULT_TRANSPORT): vol.All(
cv.string, vol.In([TRANSPORT_TCP, TRANSPORT_WEBSOCKETS])
),
@ -195,32 +190,6 @@ CONFIG_SCHEMA_ENTRY = vol.Schema(
}
)
CONFIG_SCHEMA_BASE = PLATFORM_CONFIG_SCHEMA_BASE.extend(
{
vol.Optional(CONF_CLIENT_ID): cv.string,
vol.Optional(CONF_KEEPALIVE): vol.All(vol.Coerce(int), vol.Range(min=15)),
vol.Optional(CONF_BROKER): cv.string,
vol.Optional(CONF_PORT): cv.port,
vol.Optional(CONF_USERNAME): cv.string,
vol.Optional(CONF_PASSWORD): cv.string,
vol.Optional(CONF_CERTIFICATE): vol.Any("auto", cv.isfile),
vol.Inclusive(
CONF_CLIENT_KEY, "client_key_auth", msg=CLIENT_KEY_AUTH_MSG
): cv.isfile,
vol.Inclusive(
CONF_CLIENT_CERT, "client_key_auth", msg=CLIENT_KEY_AUTH_MSG
): cv.isfile,
vol.Optional(CONF_TLS_INSECURE): cv.boolean,
vol.Optional(CONF_PROTOCOL): vol.All(cv.string, vol.In(SUPPORTED_PROTOCOLS)),
vol.Optional(CONF_WILL_MESSAGE): valid_birth_will,
vol.Optional(CONF_BIRTH_MESSAGE): valid_birth_will,
vol.Optional(CONF_DISCOVERY): cv.boolean,
# discovery_prefix must be a valid publish topic because if no
# state topic is specified, it will be created with the given prefix.
vol.Optional(CONF_DISCOVERY_PREFIX): valid_publish_topic,
}
)
DEPRECATED_CONFIG_KEYS = [
CONF_BIRTH_MESSAGE,
CONF_BROKER,

View file

@ -562,7 +562,6 @@ class MqttAvailability(Entity):
def available(self) -> bool:
"""Return if the device is available."""
mqtt_data = get_mqtt_data(self.hass)
assert mqtt_data.client is not None
client = mqtt_data.client
if not client.connected and not self.hass.is_stopping:
return False

View file

@ -288,8 +288,8 @@ class EntityTopicState:
class MqttData:
"""Keep the MQTT entry data."""
client: MQTT | None = None
config: ConfigType | None = None
client: MQTT
config: ConfigType
debug_info_entities: dict[str, EntityDebugInfo] = field(default_factory=dict)
debug_info_triggers: dict[tuple[str, str], TriggerDebugInfo] = field(
default_factory=dict

View file

@ -136,12 +136,9 @@ def valid_birth_will(config: ConfigType) -> ConfigType:
return config
def get_mqtt_data(hass: HomeAssistant, ensure_exists: bool = False) -> MqttData:
def get_mqtt_data(hass: HomeAssistant) -> MqttData:
"""Return typed MqttData from hass.data[DATA_MQTT]."""
mqtt_data: MqttData
if ensure_exists:
mqtt_data = hass.data.setdefault(DATA_MQTT, MqttData())
return mqtt_data
mqtt_data = hass.data[DATA_MQTT]
return mqtt_data

View file

@ -183,7 +183,7 @@ async def test_user_connection_works(
assert result["type"] == "form"
result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"broker": "127.0.0.1", "advanced_options": False}
result["flow_id"], {"broker": "127.0.0.1"}
)
assert result["type"] == "create_entry"
@ -191,7 +191,6 @@ async def test_user_connection_works(
"broker": "127.0.0.1",
"port": 1883,
"discovery": True,
"discovery_prefix": "homeassistant",
}
# Check we tried the connection
assert len(mock_try_connection.mock_calls) == 1
@ -231,7 +230,6 @@ async def test_user_v5_connection_works(
assert result["result"].data == {
"broker": "another-broker",
"discovery": True,
"discovery_prefix": "homeassistant",
"port": 2345,
"protocol": "5",
}
@ -283,7 +281,7 @@ async def test_manual_config_set(
assert result["type"] == "form"
result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"broker": "127.0.0.1"}
result["flow_id"], {"broker": "127.0.0.1", "port": "1883"}
)
assert result["type"] == "create_entry"
@ -291,7 +289,6 @@ async def test_manual_config_set(
"broker": "127.0.0.1",
"port": 1883,
"discovery": True,
"discovery_prefix": "homeassistant",
}
# Check we tried the connection, with precedence for config entry settings
mock_try_connection.assert_called_once_with(
@ -395,7 +392,6 @@ async def test_hassio_confirm(
"username": "mock-user",
"password": "mock-pass",
"discovery": True,
"discovery_prefix": "homeassistant",
}
# Check we tried the connection
assert len(mock_try_connection_success.mock_calls)

View file

@ -31,6 +31,8 @@ default_config = {
"retain": False,
"topic": "homeassistant/status",
},
"ws_headers": {},
"ws_path": "/",
}
@ -265,6 +267,7 @@ async def test_redact_diagnostics(
"name_by_user": None,
}
await get_diagnostics_for_config_entry(hass, hass_client, config_entry)
assert await get_diagnostics_for_config_entry(hass, hass_client, config_entry) == {
"connected": True,
"devices": [expected_device],

View file

@ -952,6 +952,7 @@ async def test_subscribe_with_deprecated_callback_fails(
hass: HomeAssistant, mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator
) -> None:
"""Test the subscription of a topic using deprecated callback signature fails."""
await mqtt_mock_entry_no_yaml_config()
async def record_calls(topic: str, payload: ReceivePayloadType, qos: int) -> None:
"""Record calls."""
@ -969,6 +970,7 @@ async def test_subscribe_deprecated_async_fails(
hass: HomeAssistant, mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator
) -> None:
"""Test the subscription of a topic using deprecated coroutine signature fails."""
await mqtt_mock_entry_no_yaml_config()
@callback
def async_record_calls(topic: str, payload: ReceivePayloadType, qos: int) -> None:
@ -2256,34 +2258,29 @@ async def test_mqtt_subscribes_topics_on_connect(
mqtt_client_mock.subscribe.assert_any_call("still/pending", 1)
async def test_update_incomplete_entry(
async def test_default_entry_setting_are_applied(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator,
mqtt_client_mock: MqttMockPahoClient,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test if the MQTT component loads when config entry data is incomplete."""
"""Test if the MQTT component loads when config entry data not has all default settings."""
data = (
'{ "device":{"identifiers":["0AFFD2"]},'
' "state_topic": "foobar/sensor",'
' "unique_id": "unique" }'
)
# Config entry data is incomplete
# Config entry data is incomplete but valid according the schema
entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0]
entry.data = {"broker": "test-broker", "port": 1234}
await mqtt_mock_entry_no_yaml_config()
await hass.async_block_till_done()
# Config entry data should now be updated
assert dict(entry.data) == {
"broker": "test-broker",
"port": 1234,
"discovery_prefix": "homeassistant",
}
# Discover a device to verify the entry was setup correctly
# The discovery prefix should be the default
# And that the default settings were merged
async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data)
await hass.async_block_till_done()
@ -2297,12 +2294,15 @@ async def test_fail_no_broker(
mqtt_client_mock: MqttMockPahoClient,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test if the MQTT component loads when broker configuration is missing."""
"""Test the MQTT entry setup when broker configuration is missing."""
# Config entry data is incomplete
entry = MockConfigEntry(domain=mqtt.DOMAIN, data={})
entry.add_to_hass(hass)
assert not await hass.config_entries.async_setup(entry.entry_id)
assert "MQTT broker is not configured, please configure it" in caplog.text
assert (
"The MQTT config entry is invalid, please correct it: required key not provided @ data['broker']"
in caplog.text
)
@pytest.mark.no_fail_on_log_exception
@ -3312,7 +3312,7 @@ async def test_setup_manual_items_with_unique_ids(
assert bool("Platform mqtt does not generate unique IDs." in caplog.text) != unique
async def test_remove_unknown_conf_entry_options(
async def test_fail_with_unknown_conf_entry_options(
hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient,
caplog: pytest.LogCaptureFixture,
@ -3331,14 +3331,9 @@ async def test_remove_unknown_conf_entry_options(
)
entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
assert await hass.config_entries.async_setup(entry.entry_id) is False
assert mqtt.client.CONF_PROTOCOL not in entry.data
assert (
"The following unsupported configuration options were removed from the "
"MQTT config entry: {'old_option'}"
) in caplog.text
assert ("extra keys not allowed @ data['old_option']") in caplog.text
@patch("homeassistant.components.mqtt.PLATFORMS", [Platform.LIGHT])

View file

@ -967,9 +967,10 @@ async def _mqtt_mock_entry(
nonlocal mock_mqtt_instance
nonlocal real_mqtt_instance
real_mqtt_instance = real_mqtt(*args, **kwargs)
spec = dir(real_mqtt_instance) + ["_mqttc"]
mock_mqtt_instance = MqttMockHAClient(
return_value=real_mqtt_instance,
spec_set=real_mqtt_instance,
spec_set=spec,
wraps=real_mqtt_instance,
)
return mock_mqtt_instance