Rework MQTT config merging and adding defaults (#90529)

* Cleanup config merging and adding defaults

* Optimize and update tests

* Do not mix entry and yaml config

* Make sure hass.data is initilized

* remove check on get_mqtt_data

* Tweaks to MQTT client

* Remove None assigment mqtt client and fix mock
This commit is contained in:
Jan Bouwhuis 2023-04-04 18:12:18 +02:00 committed by GitHub
parent 690a0f34e5
commit 4a0d3e881a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 77 additions and 180 deletions

View file

@ -24,7 +24,7 @@ from homeassistant.const import (
SERVICE_RELOAD, SERVICE_RELOAD,
) )
from homeassistant.core import HassJob, HomeAssistant, ServiceCall, callback from homeassistant.core import HassJob, HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import TemplateError, Unauthorized from homeassistant.exceptions import ConfigEntryError, TemplateError, Unauthorized
from homeassistant.helpers import config_validation as cv, event, template from homeassistant.helpers import config_validation as cv, event, template
from homeassistant.helpers.device_registry import DeviceEntry from homeassistant.helpers.device_registry import DeviceEntry
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
@ -45,11 +45,7 @@ from .client import ( # noqa: F401
publish, publish,
subscribe, subscribe,
) )
from .config_integration import ( from .config_integration import CONFIG_SCHEMA_ENTRY, PLATFORM_CONFIG_SCHEMA_BASE
CONFIG_SCHEMA_ENTRY,
DEFAULT_VALUES,
PLATFORM_CONFIG_SCHEMA_BASE,
)
from .const import ( # noqa: F401 from .const import ( # noqa: F401
ATTR_PAYLOAD, ATTR_PAYLOAD,
ATTR_QOS, ATTR_QOS,
@ -83,6 +79,7 @@ from .const import ( # noqa: F401
) )
from .models import ( # noqa: F401 from .models import ( # noqa: F401
MqttCommandTemplate, MqttCommandTemplate,
MqttData,
MqttValueTemplate, MqttValueTemplate,
PublishPayloadType, PublishPayloadType,
ReceiveMessage, ReceiveMessage,
@ -102,8 +99,6 @@ _LOGGER = logging.getLogger(__name__)
SERVICE_PUBLISH = "publish" SERVICE_PUBLISH = "publish"
SERVICE_DUMP = "dump" SERVICE_DUMP = "dump"
MANDATORY_DEFAULT_VALUES = (CONF_PORT, CONF_DISCOVERY_PREFIX)
ATTR_TOPIC_TEMPLATE = "topic_template" ATTR_TOPIC_TEMPLATE = "topic_template"
ATTR_PAYLOAD_TEMPLATE = "payload_template" ATTR_PAYLOAD_TEMPLATE = "payload_template"
@ -193,50 +188,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True return True
def _filter_entry_config(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Remove unknown keys from config entry data.
Extra keys may have been added when importing MQTT yaml configuration.
"""
filtered_data = {
k: entry.data[k] for k in CONFIG_ENTRY_CONFIG_KEYS if k in entry.data
}
if entry.data.keys() != filtered_data.keys():
_LOGGER.warning(
(
"The following unsupported configuration options were removed from the "
"MQTT config entry: %s"
),
entry.data.keys() - filtered_data.keys(),
)
hass.config_entries.async_update_entry(entry, data=filtered_data)
async def _async_auto_mend_config(
hass: HomeAssistant, entry: ConfigEntry, yaml_config: dict[str, Any]
) -> None:
"""Mends config fetched from config entry and adds missing values.
This mends incomplete migration from old version of HA Core.
"""
entry_updated = False
entry_config = {**entry.data}
for key in MANDATORY_DEFAULT_VALUES:
if key not in entry_config:
entry_config[key] = DEFAULT_VALUES[key]
entry_updated = True
if entry_updated:
hass.config_entries.async_update_entry(entry, data=entry_config)
def _merge_extended_config(entry: ConfigEntry, conf: ConfigType) -> dict[str, Any]:
"""Merge advanced options in configuration.yaml config with config entry."""
# Add default values
conf = {**DEFAULT_VALUES, **conf}
return {**conf, **entry.data}
async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -> None: async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Handle signals of config entry being updated. """Handle signals of config entry being updated.
@ -245,45 +196,29 @@ async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -
await hass.config_entries.async_reload(entry.entry_id) await hass.config_entries.async_reload(entry.entry_id)
async def async_fetch_config(
hass: HomeAssistant, entry: ConfigEntry
) -> dict[str, Any] | None:
"""Fetch fresh MQTT yaml config from the hass config."""
mqtt_data = get_mqtt_data(hass)
hass_config = await conf_util.async_hass_config_yaml(hass)
mqtt_data.config = PLATFORM_CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {}))
# Remove unknown keys from config entry data
_filter_entry_config(hass, entry)
# Add missing defaults to migrate older config entries
await _async_auto_mend_config(hass, entry, mqtt_data.config or {})
# Bail out if broker setting is missing
if CONF_BROKER not in entry.data:
_LOGGER.error("MQTT broker is not configured, please configure it")
return None
# If user doesn't have configuration.yaml config, generate default values
# for options not in config entry data
if (conf := mqtt_data.config) is None:
conf = CONFIG_SCHEMA_ENTRY(dict(entry.data))
# Merge advanced configuration values from configuration.yaml
conf = _merge_extended_config(entry, conf)
return conf
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Load a config entry.""" """Load a config entry."""
mqtt_data = get_mqtt_data(hass, True) # validate entry config
try:
conf = CONFIG_SCHEMA_ENTRY(dict(entry.data))
except vol.MultipleInvalid as ex:
raise ConfigEntryError(
f"The MQTT config entry is invalid, please correct it: {ex}"
) from ex
# Fetch configuration and add missing defaults for basic options # Fetch configuration and add default values
if (conf := await async_fetch_config(hass, entry)) is None: hass_config = await conf_util.async_hass_config_yaml(hass)
# Bail out mqtt_yaml = PLATFORM_CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {}))
return False client = MQTT(hass, entry, conf)
if DOMAIN in hass.data:
mqtt_data = get_mqtt_data(hass)
mqtt_data.config = mqtt_yaml
mqtt_data.client = client
else:
hass.data[DATA_MQTT] = mqtt_data = MqttData(config=mqtt_yaml, client=client)
client.start(mqtt_data)
await async_create_certificate_temp_files(hass, dict(entry.data)) await async_create_certificate_temp_files(hass, dict(entry.data))
mqtt_data.client = MQTT(hass, entry, conf)
# Restore saved subscriptions # Restore saved subscriptions
if mqtt_data.subscriptions_to_restore: if mqtt_data.subscriptions_to_restore:
mqtt_data.client.async_restore_tracked_subscriptions( mqtt_data.client.async_restore_tracked_subscriptions(
@ -349,7 +284,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
) )
return return
assert mqtt_data.client is not None and msg_topic is not None assert msg_topic is not None
await mqtt_data.client.async_publish(msg_topic, payload, qos, retain) await mqtt_data.client.async_publish(msg_topic, payload, qos, retain)
hass.services.async_register( hass.services.async_register(
@ -585,7 +520,6 @@ def async_subscribe_connection_status(
def is_connected(hass: HomeAssistant) -> bool: def is_connected(hass: HomeAssistant) -> bool:
"""Return if MQTT client is connected.""" """Return if MQTT client is connected."""
mqtt_data = get_mqtt_data(hass) mqtt_data = get_mqtt_data(hass)
assert mqtt_data.client is not None
return mqtt_data.client.connected return mqtt_data.client.connected
@ -603,7 +537,6 @@ async def async_remove_config_entry_device(
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload MQTT dump and publish service when the config entry is unloaded.""" """Unload MQTT dump and publish service when the config entry is unloaded."""
mqtt_data = get_mqtt_data(hass) mqtt_data = get_mqtt_data(hass)
assert mqtt_data.client is not None
mqtt_client = mqtt_data.client mqtt_client = mqtt_data.client
# Unload publish and dump services. # Unload publish and dump services.

View file

@ -69,6 +69,7 @@ from .const import (
from .models import ( from .models import (
AsyncMessageCallbackType, AsyncMessageCallbackType,
MessageCallbackType, MessageCallbackType,
MqttData,
PublishMessage, PublishMessage,
PublishPayloadType, PublishPayloadType,
ReceiveMessage, ReceiveMessage,
@ -111,11 +112,11 @@ async def async_publish(
encoding: str | None = DEFAULT_ENCODING, encoding: str | None = DEFAULT_ENCODING,
) -> None: ) -> None:
"""Publish message to a MQTT topic.""" """Publish message to a MQTT topic."""
mqtt_data = get_mqtt_data(hass, True) if not mqtt_config_entry_enabled(hass):
if mqtt_data.client is None or not mqtt_config_entry_enabled(hass):
raise HomeAssistantError( raise HomeAssistantError(
f"Cannot publish to topic '{topic}', MQTT is not enabled" f"Cannot publish to topic '{topic}', MQTT is not enabled"
) )
mqtt_data = get_mqtt_data(hass)
outgoing_payload = payload outgoing_payload = payload
if not isinstance(payload, bytes): if not isinstance(payload, bytes):
if not encoding: if not encoding:
@ -161,11 +162,11 @@ async def async_subscribe(
Call the return value to unsubscribe. Call the return value to unsubscribe.
""" """
mqtt_data = get_mqtt_data(hass, True) if not mqtt_config_entry_enabled(hass):
if mqtt_data.client is None or not mqtt_config_entry_enabled(hass):
raise HomeAssistantError( raise HomeAssistantError(
f"Cannot subscribe to topic '{topic}', MQTT is not enabled" f"Cannot subscribe to topic '{topic}', MQTT is not enabled"
) )
mqtt_data = get_mqtt_data(hass)
# Support for a deprecated callback type was removed with HA core 2023.3.0 # Support for a deprecated callback type was removed with HA core 2023.3.0
# The signature validation code can be removed from HA core 2023.5.0 # The signature validation code can be removed from HA core 2023.5.0
non_default = 0 non_default = 0
@ -377,19 +378,16 @@ class MQTT:
_mqttc: mqtt.Client _mqttc: mqtt.Client
_last_subscribe: float _last_subscribe: float
_mqtt_data: MqttData
def __init__( def __init__(
self, self, hass: HomeAssistant, config_entry: ConfigEntry, conf: ConfigType
hass: HomeAssistant,
config_entry: ConfigEntry,
conf: ConfigType,
) -> None: ) -> None:
"""Initialize Home Assistant MQTT client.""" """Initialize Home Assistant MQTT client."""
self._mqtt_data = get_mqtt_data(hass)
self.hass = hass self.hass = hass
self.config_entry = config_entry self.config_entry = config_entry
self.conf = conf self.conf = conf
self._simple_subscriptions: dict[str, list[Subscription]] = {} self._simple_subscriptions: dict[str, list[Subscription]] = {}
self._wildcard_subscriptions: list[Subscription] = [] self._wildcard_subscriptions: list[Subscription] = []
self.connected = False self.connected = False
@ -415,8 +413,6 @@ class MQTT:
self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, ha_started) self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, ha_started)
self.init_client()
async def async_stop_mqtt(_event: Event) -> None: async def async_stop_mqtt(_event: Event) -> None:
"""Stop MQTT component.""" """Stop MQTT component."""
await self.async_disconnect() await self.async_disconnect()
@ -425,6 +421,14 @@ class MQTT:
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt)
) )
def start(
self,
mqtt_data: MqttData,
) -> None:
"""Start Home Assistant MQTT client."""
self._mqtt_data = mqtt_data
self.init_client()
@property @property
def subscriptions(self) -> list[Subscription]: def subscriptions(self) -> list[Subscription]:
"""Return the tracked subscriptions.""" """Return the tracked subscriptions."""

View file

@ -65,17 +65,6 @@ from .util import valid_birth_will, valid_publish_topic
DEFAULT_TLS_PROTOCOL = "auto" DEFAULT_TLS_PROTOCOL = "auto"
DEFAULT_VALUES = {
CONF_BIRTH_MESSAGE: DEFAULT_BIRTH,
CONF_DISCOVERY: DEFAULT_DISCOVERY,
CONF_DISCOVERY_PREFIX: DEFAULT_PREFIX,
CONF_PORT: DEFAULT_PORT,
CONF_PROTOCOL: DEFAULT_PROTOCOL,
CONF_TRANSPORT: DEFAULT_TRANSPORT,
CONF_WILL_MESSAGE: DEFAULT_WILL,
CONF_KEEPALIVE: DEFAULT_KEEPALIVE,
}
PLATFORM_CONFIG_SCHEMA_BASE = vol.Schema( PLATFORM_CONFIG_SCHEMA_BASE = vol.Schema(
{ {
Platform.ALARM_CONTROL_PANEL.value: vol.All( Platform.ALARM_CONTROL_PANEL.value: vol.All(
@ -169,9 +158,11 @@ CLIENT_KEY_AUTH_MSG = (
CONFIG_SCHEMA_ENTRY = vol.Schema( CONFIG_SCHEMA_ENTRY = vol.Schema(
{ {
vol.Optional(CONF_CLIENT_ID): cv.string, vol.Optional(CONF_CLIENT_ID): cv.string,
vol.Optional(CONF_KEEPALIVE): vol.All(vol.Coerce(int), vol.Range(min=15)), vol.Optional(CONF_KEEPALIVE, default=DEFAULT_KEEPALIVE): vol.All(
vol.Optional(CONF_BROKER): cv.string, vol.Coerce(int), vol.Range(min=15)
vol.Optional(CONF_PORT): cv.port, ),
vol.Required(CONF_BROKER): cv.string,
vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port,
vol.Optional(CONF_USERNAME): cv.string, vol.Optional(CONF_USERNAME): cv.string,
vol.Optional(CONF_PASSWORD): cv.string, vol.Optional(CONF_PASSWORD): cv.string,
vol.Optional(CONF_CERTIFICATE): str, vol.Optional(CONF_CERTIFICATE): str,
@ -180,13 +171,17 @@ CONFIG_SCHEMA_ENTRY = vol.Schema(
CONF_CLIENT_CERT, "client_key_auth", msg=CLIENT_KEY_AUTH_MSG CONF_CLIENT_CERT, "client_key_auth", msg=CLIENT_KEY_AUTH_MSG
): str, ): str,
vol.Optional(CONF_TLS_INSECURE): cv.boolean, vol.Optional(CONF_TLS_INSECURE): cv.boolean,
vol.Optional(CONF_PROTOCOL): vol.All(cv.string, vol.In(SUPPORTED_PROTOCOLS)), vol.Optional(CONF_PROTOCOL, default=DEFAULT_PROTOCOL): vol.All(
vol.Optional(CONF_WILL_MESSAGE): valid_birth_will, cv.string, vol.In(SUPPORTED_PROTOCOLS)
vol.Optional(CONF_BIRTH_MESSAGE): valid_birth_will, ),
vol.Optional(CONF_DISCOVERY): cv.boolean, vol.Optional(CONF_WILL_MESSAGE, default=DEFAULT_WILL): valid_birth_will,
vol.Optional(CONF_BIRTH_MESSAGE, default=DEFAULT_BIRTH): valid_birth_will,
vol.Optional(CONF_DISCOVERY, default=DEFAULT_DISCOVERY): cv.boolean,
# discovery_prefix must be a valid publish topic because if no # discovery_prefix must be a valid publish topic because if no
# state topic is specified, it will be created with the given prefix. # state topic is specified, it will be created with the given prefix.
vol.Optional(CONF_DISCOVERY_PREFIX): valid_publish_topic, vol.Optional(
CONF_DISCOVERY_PREFIX, default=DEFAULT_PREFIX
): valid_publish_topic,
vol.Optional(CONF_TRANSPORT, default=DEFAULT_TRANSPORT): vol.All( vol.Optional(CONF_TRANSPORT, default=DEFAULT_TRANSPORT): vol.All(
cv.string, vol.In([TRANSPORT_TCP, TRANSPORT_WEBSOCKETS]) cv.string, vol.In([TRANSPORT_TCP, TRANSPORT_WEBSOCKETS])
), ),
@ -195,32 +190,6 @@ CONFIG_SCHEMA_ENTRY = vol.Schema(
} }
) )
CONFIG_SCHEMA_BASE = PLATFORM_CONFIG_SCHEMA_BASE.extend(
{
vol.Optional(CONF_CLIENT_ID): cv.string,
vol.Optional(CONF_KEEPALIVE): vol.All(vol.Coerce(int), vol.Range(min=15)),
vol.Optional(CONF_BROKER): cv.string,
vol.Optional(CONF_PORT): cv.port,
vol.Optional(CONF_USERNAME): cv.string,
vol.Optional(CONF_PASSWORD): cv.string,
vol.Optional(CONF_CERTIFICATE): vol.Any("auto", cv.isfile),
vol.Inclusive(
CONF_CLIENT_KEY, "client_key_auth", msg=CLIENT_KEY_AUTH_MSG
): cv.isfile,
vol.Inclusive(
CONF_CLIENT_CERT, "client_key_auth", msg=CLIENT_KEY_AUTH_MSG
): cv.isfile,
vol.Optional(CONF_TLS_INSECURE): cv.boolean,
vol.Optional(CONF_PROTOCOL): vol.All(cv.string, vol.In(SUPPORTED_PROTOCOLS)),
vol.Optional(CONF_WILL_MESSAGE): valid_birth_will,
vol.Optional(CONF_BIRTH_MESSAGE): valid_birth_will,
vol.Optional(CONF_DISCOVERY): cv.boolean,
# discovery_prefix must be a valid publish topic because if no
# state topic is specified, it will be created with the given prefix.
vol.Optional(CONF_DISCOVERY_PREFIX): valid_publish_topic,
}
)
DEPRECATED_CONFIG_KEYS = [ DEPRECATED_CONFIG_KEYS = [
CONF_BIRTH_MESSAGE, CONF_BIRTH_MESSAGE,
CONF_BROKER, CONF_BROKER,

View file

@ -562,7 +562,6 @@ class MqttAvailability(Entity):
def available(self) -> bool: def available(self) -> bool:
"""Return if the device is available.""" """Return if the device is available."""
mqtt_data = get_mqtt_data(self.hass) mqtt_data = get_mqtt_data(self.hass)
assert mqtt_data.client is not None
client = mqtt_data.client client = mqtt_data.client
if not client.connected and not self.hass.is_stopping: if not client.connected and not self.hass.is_stopping:
return False return False

View file

@ -288,8 +288,8 @@ class EntityTopicState:
class MqttData: class MqttData:
"""Keep the MQTT entry data.""" """Keep the MQTT entry data."""
client: MQTT | None = None client: MQTT
config: ConfigType | None = None config: ConfigType
debug_info_entities: dict[str, EntityDebugInfo] = field(default_factory=dict) debug_info_entities: dict[str, EntityDebugInfo] = field(default_factory=dict)
debug_info_triggers: dict[tuple[str, str], TriggerDebugInfo] = field( debug_info_triggers: dict[tuple[str, str], TriggerDebugInfo] = field(
default_factory=dict default_factory=dict

View file

@ -136,12 +136,9 @@ def valid_birth_will(config: ConfigType) -> ConfigType:
return config return config
def get_mqtt_data(hass: HomeAssistant, ensure_exists: bool = False) -> MqttData: def get_mqtt_data(hass: HomeAssistant) -> MqttData:
"""Return typed MqttData from hass.data[DATA_MQTT].""" """Return typed MqttData from hass.data[DATA_MQTT]."""
mqtt_data: MqttData mqtt_data: MqttData
if ensure_exists:
mqtt_data = hass.data.setdefault(DATA_MQTT, MqttData())
return mqtt_data
mqtt_data = hass.data[DATA_MQTT] mqtt_data = hass.data[DATA_MQTT]
return mqtt_data return mqtt_data

View file

@ -183,7 +183,7 @@ async def test_user_connection_works(
assert result["type"] == "form" assert result["type"] == "form"
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"broker": "127.0.0.1", "advanced_options": False} result["flow_id"], {"broker": "127.0.0.1"}
) )
assert result["type"] == "create_entry" assert result["type"] == "create_entry"
@ -191,7 +191,6 @@ async def test_user_connection_works(
"broker": "127.0.0.1", "broker": "127.0.0.1",
"port": 1883, "port": 1883,
"discovery": True, "discovery": True,
"discovery_prefix": "homeassistant",
} }
# Check we tried the connection # Check we tried the connection
assert len(mock_try_connection.mock_calls) == 1 assert len(mock_try_connection.mock_calls) == 1
@ -231,7 +230,6 @@ async def test_user_v5_connection_works(
assert result["result"].data == { assert result["result"].data == {
"broker": "another-broker", "broker": "another-broker",
"discovery": True, "discovery": True,
"discovery_prefix": "homeassistant",
"port": 2345, "port": 2345,
"protocol": "5", "protocol": "5",
} }
@ -283,7 +281,7 @@ async def test_manual_config_set(
assert result["type"] == "form" assert result["type"] == "form"
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"broker": "127.0.0.1"} result["flow_id"], {"broker": "127.0.0.1", "port": "1883"}
) )
assert result["type"] == "create_entry" assert result["type"] == "create_entry"
@ -291,7 +289,6 @@ async def test_manual_config_set(
"broker": "127.0.0.1", "broker": "127.0.0.1",
"port": 1883, "port": 1883,
"discovery": True, "discovery": True,
"discovery_prefix": "homeassistant",
} }
# Check we tried the connection, with precedence for config entry settings # Check we tried the connection, with precedence for config entry settings
mock_try_connection.assert_called_once_with( mock_try_connection.assert_called_once_with(
@ -395,7 +392,6 @@ async def test_hassio_confirm(
"username": "mock-user", "username": "mock-user",
"password": "mock-pass", "password": "mock-pass",
"discovery": True, "discovery": True,
"discovery_prefix": "homeassistant",
} }
# Check we tried the connection # Check we tried the connection
assert len(mock_try_connection_success.mock_calls) assert len(mock_try_connection_success.mock_calls)

View file

@ -31,6 +31,8 @@ default_config = {
"retain": False, "retain": False,
"topic": "homeassistant/status", "topic": "homeassistant/status",
}, },
"ws_headers": {},
"ws_path": "/",
} }
@ -265,6 +267,7 @@ async def test_redact_diagnostics(
"name_by_user": None, "name_by_user": None,
} }
await get_diagnostics_for_config_entry(hass, hass_client, config_entry)
assert await get_diagnostics_for_config_entry(hass, hass_client, config_entry) == { assert await get_diagnostics_for_config_entry(hass, hass_client, config_entry) == {
"connected": True, "connected": True,
"devices": [expected_device], "devices": [expected_device],

View file

@ -952,6 +952,7 @@ async def test_subscribe_with_deprecated_callback_fails(
hass: HomeAssistant, mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator hass: HomeAssistant, mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator
) -> None: ) -> None:
"""Test the subscription of a topic using deprecated callback signature fails.""" """Test the subscription of a topic using deprecated callback signature fails."""
await mqtt_mock_entry_no_yaml_config()
async def record_calls(topic: str, payload: ReceivePayloadType, qos: int) -> None: async def record_calls(topic: str, payload: ReceivePayloadType, qos: int) -> None:
"""Record calls.""" """Record calls."""
@ -969,6 +970,7 @@ async def test_subscribe_deprecated_async_fails(
hass: HomeAssistant, mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator hass: HomeAssistant, mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator
) -> None: ) -> None:
"""Test the subscription of a topic using deprecated coroutine signature fails.""" """Test the subscription of a topic using deprecated coroutine signature fails."""
await mqtt_mock_entry_no_yaml_config()
@callback @callback
def async_record_calls(topic: str, payload: ReceivePayloadType, qos: int) -> None: def async_record_calls(topic: str, payload: ReceivePayloadType, qos: int) -> None:
@ -2256,34 +2258,29 @@ async def test_mqtt_subscribes_topics_on_connect(
mqtt_client_mock.subscribe.assert_any_call("still/pending", 1) mqtt_client_mock.subscribe.assert_any_call("still/pending", 1)
async def test_update_incomplete_entry( async def test_default_entry_setting_are_applied(
hass: HomeAssistant, hass: HomeAssistant,
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,
mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator, mqtt_mock_entry_no_yaml_config: MqttMockHAClientGenerator,
mqtt_client_mock: MqttMockPahoClient, mqtt_client_mock: MqttMockPahoClient,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
) -> None: ) -> None:
"""Test if the MQTT component loads when config entry data is incomplete.""" """Test if the MQTT component loads when config entry data not has all default settings."""
data = ( data = (
'{ "device":{"identifiers":["0AFFD2"]},' '{ "device":{"identifiers":["0AFFD2"]},'
' "state_topic": "foobar/sensor",' ' "state_topic": "foobar/sensor",'
' "unique_id": "unique" }' ' "unique_id": "unique" }'
) )
# Config entry data is incomplete # Config entry data is incomplete but valid according the schema
entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0] entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0]
entry.data = {"broker": "test-broker", "port": 1234} entry.data = {"broker": "test-broker", "port": 1234}
await mqtt_mock_entry_no_yaml_config() await mqtt_mock_entry_no_yaml_config()
await hass.async_block_till_done() await hass.async_block_till_done()
# Config entry data should now be updated
assert dict(entry.data) == {
"broker": "test-broker",
"port": 1234,
"discovery_prefix": "homeassistant",
}
# Discover a device to verify the entry was setup correctly # Discover a device to verify the entry was setup correctly
# The discovery prefix should be the default
# And that the default settings were merged
async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data) async_fire_mqtt_message(hass, "homeassistant/sensor/bla/config", data)
await hass.async_block_till_done() await hass.async_block_till_done()
@ -2297,12 +2294,15 @@ async def test_fail_no_broker(
mqtt_client_mock: MqttMockPahoClient, mqtt_client_mock: MqttMockPahoClient,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
) -> None: ) -> None:
"""Test if the MQTT component loads when broker configuration is missing.""" """Test the MQTT entry setup when broker configuration is missing."""
# Config entry data is incomplete # Config entry data is incomplete
entry = MockConfigEntry(domain=mqtt.DOMAIN, data={}) entry = MockConfigEntry(domain=mqtt.DOMAIN, data={})
entry.add_to_hass(hass) entry.add_to_hass(hass)
assert not await hass.config_entries.async_setup(entry.entry_id) assert not await hass.config_entries.async_setup(entry.entry_id)
assert "MQTT broker is not configured, please configure it" in caplog.text assert (
"The MQTT config entry is invalid, please correct it: required key not provided @ data['broker']"
in caplog.text
)
@pytest.mark.no_fail_on_log_exception @pytest.mark.no_fail_on_log_exception
@ -3312,7 +3312,7 @@ async def test_setup_manual_items_with_unique_ids(
assert bool("Platform mqtt does not generate unique IDs." in caplog.text) != unique assert bool("Platform mqtt does not generate unique IDs." in caplog.text) != unique
async def test_remove_unknown_conf_entry_options( async def test_fail_with_unknown_conf_entry_options(
hass: HomeAssistant, hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient, mqtt_client_mock: MqttMockPahoClient,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
@ -3331,14 +3331,9 @@ async def test_remove_unknown_conf_entry_options(
) )
entry.add_to_hass(hass) entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(entry.entry_id) assert await hass.config_entries.async_setup(entry.entry_id) is False
await hass.async_block_till_done()
assert mqtt.client.CONF_PROTOCOL not in entry.data assert ("extra keys not allowed @ data['old_option']") in caplog.text
assert (
"The following unsupported configuration options were removed from the "
"MQTT config entry: {'old_option'}"
) in caplog.text
@patch("homeassistant.components.mqtt.PLATFORMS", [Platform.LIGHT]) @patch("homeassistant.components.mqtt.PLATFORMS", [Platform.LIGHT])

View file

@ -967,9 +967,10 @@ async def _mqtt_mock_entry(
nonlocal mock_mqtt_instance nonlocal mock_mqtt_instance
nonlocal real_mqtt_instance nonlocal real_mqtt_instance
real_mqtt_instance = real_mqtt(*args, **kwargs) real_mqtt_instance = real_mqtt(*args, **kwargs)
spec = dir(real_mqtt_instance) + ["_mqttc"]
mock_mqtt_instance = MqttMockHAClient( mock_mqtt_instance = MqttMockHAClient(
return_value=real_mqtt_instance, return_value=real_mqtt_instance,
spec_set=real_mqtt_instance, spec_set=spec,
wraps=real_mqtt_instance, wraps=real_mqtt_instance,
) )
return mock_mqtt_instance return mock_mqtt_instance