De-duplicate MQTT config_flow code (#79369)

* De-duplicate config_flow code

* De duplicate code birth and will
This commit is contained in:
Jan Bouwhuis 2022-10-07 10:12:19 +02:00 committed by GitHub
parent 9a81b65815
commit aee82e2b3b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 195 additions and 121 deletions

View file

@ -54,6 +54,7 @@ from .const import (
CONF_TLS_INSECURE, CONF_TLS_INSECURE,
CONF_WILL_MESSAGE, CONF_WILL_MESSAGE,
DEFAULT_ENCODING, DEFAULT_ENCODING,
DEFAULT_PROTOCOL,
DEFAULT_QOS, DEFAULT_QOS,
MQTT_CONNECTED, MQTT_CONNECTED,
MQTT_DISCONNECTED, MQTT_DISCONNECTED,
@ -272,7 +273,7 @@ class MqttClientSetup:
# should be able to optionally rely on MQTT. # should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
if config[CONF_PROTOCOL] == PROTOCOL_31: if config.get(CONF_PROTOCOL, DEFAULT_PROTOCOL) == PROTOCOL_31:
proto = mqtt.MQTTv31 proto = mqtt.MQTTv31
else: else:
proto = mqtt.MQTTv311 proto = mqtt.MQTTv311

View file

@ -2,7 +2,9 @@
from __future__ import annotations from __future__ import annotations
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Callable
import queue import queue
from types import MappingProxyType
from typing import Any from typing import Any
import voluptuous as vol import voluptuous as vol
@ -15,10 +17,9 @@ from homeassistant.const import (
CONF_PASSWORD, CONF_PASSWORD,
CONF_PAYLOAD, CONF_PAYLOAD,
CONF_PORT, CONF_PORT,
CONF_PROTOCOL,
CONF_USERNAME, CONF_USERNAME,
) )
from homeassistant.core import callback from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
@ -33,6 +34,7 @@ from .const import (
CONF_WILL_MESSAGE, CONF_WILL_MESSAGE,
DEFAULT_BIRTH, DEFAULT_BIRTH,
DEFAULT_DISCOVERY, DEFAULT_DISCOVERY,
DEFAULT_PORT,
DEFAULT_WILL, DEFAULT_WILL,
DOMAIN, DOMAIN,
) )
@ -56,9 +58,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
"""Get the options flow for this handler.""" """Get the options flow for this handler."""
return MQTTOptionsFlowHandler(config_entry) return MQTTOptionsFlowHandler(config_entry)
async def async_step_user( async def async_step_user(self, user_input: ConfigType | None = None) -> FlowResult:
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle a flow initialized by the user.""" """Handle a flow initialized by the user."""
if self._async_current_entries(): if self._async_current_entries():
return self.async_abort(reason="single_instance_allowed") return self.async_abort(reason="single_instance_allowed")
@ -66,35 +66,38 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
return await self.async_step_broker() return await self.async_step_broker()
async def async_step_broker( async def async_step_broker(
self, user_input: dict[str, Any] | None = None self, user_input: ConfigType | None = None
) -> FlowResult: ) -> FlowResult:
"""Confirm the setup.""" """Confirm the setup."""
errors = {} yaml_config: ConfigType = get_mqtt_data(self.hass, True).config or {}
errors: dict[str, str] = {}
if user_input is not None: fields: OrderedDict[Any, Any] = OrderedDict()
validated_user_input: ConfigType = {}
if await async_get_broker_settings(
self.hass,
fields,
yaml_config,
None,
user_input,
validated_user_input,
errors,
):
test_config: ConfigType = yaml_config.copy()
test_config.update(validated_user_input)
can_connect = await self.hass.async_add_executor_job( can_connect = await self.hass.async_add_executor_job(
try_connection, try_connection,
get_mqtt_data(self.hass, True).config or {}, test_config,
user_input[CONF_BROKER],
user_input[CONF_PORT],
user_input.get(CONF_USERNAME),
user_input.get(CONF_PASSWORD),
) )
if can_connect: if can_connect:
user_input[CONF_DISCOVERY] = DEFAULT_DISCOVERY validated_user_input[CONF_DISCOVERY] = DEFAULT_DISCOVERY
return self.async_create_entry( return self.async_create_entry(
title=user_input[CONF_BROKER], data=user_input title=validated_user_input[CONF_BROKER],
data=validated_user_input,
) )
errors["base"] = "cannot_connect" errors["base"] = "cannot_connect"
fields = OrderedDict()
fields[vol.Required(CONF_BROKER)] = str
fields[vol.Required(CONF_PORT, default=1883)] = vol.Coerce(int)
fields[vol.Optional(CONF_USERNAME)] = str
fields[vol.Optional(CONF_PASSWORD)] = str
return self.async_show_form( return self.async_show_form(
step_id="broker", data_schema=vol.Schema(fields), errors=errors step_id="broker", data_schema=vol.Schema(fields), errors=errors
) )
@ -111,26 +114,22 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> FlowResult:
"""Confirm a Hass.io discovery.""" """Confirm a Hass.io discovery."""
errors = {} errors: dict[str, str] = {}
assert self._hassio_discovery assert self._hassio_discovery
if user_input is not None: if user_input is not None:
data = self._hassio_discovery data: ConfigType = self._hassio_discovery.copy()
data[CONF_BROKER] = data.pop(CONF_HOST)
can_connect = await self.hass.async_add_executor_job( can_connect = await self.hass.async_add_executor_job(
try_connection, try_connection,
get_mqtt_data(self.hass, True).config or {}, data,
data[CONF_HOST],
data[CONF_PORT],
data.get(CONF_USERNAME),
data.get(CONF_PASSWORD),
data.get(CONF_PROTOCOL),
) )
if can_connect: if can_connect:
return self.async_create_entry( return self.async_create_entry(
title=data["addon"], title=data["addon"],
data={ data={
CONF_BROKER: data[CONF_HOST], CONF_BROKER: data[CONF_BROKER],
CONF_PORT: data[CONF_PORT], CONF_PORT: data[CONF_PORT],
CONF_USERNAME: data.get(CONF_USERNAME), CONF_USERNAME: data.get(CONF_USERNAME),
CONF_PASSWORD: data.get(CONF_PASSWORD), CONF_PASSWORD: data.get(CONF_PASSWORD),
@ -164,46 +163,32 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> FlowResult:
"""Manage the MQTT broker configuration.""" """Manage the MQTT broker configuration."""
mqtt_data = get_mqtt_data(self.hass, True) errors: dict[str, str] = {}
yaml_config = mqtt_data.config or {} yaml_config: ConfigType = get_mqtt_data(self.hass, True).config or {}
errors = {} fields: OrderedDict[Any, Any] = OrderedDict()
current_config = self.config_entry.data validated_user_input: ConfigType = {}
if user_input is not None: if await async_get_broker_settings(
self.hass,
fields,
yaml_config,
self.config_entry.data,
user_input,
validated_user_input,
errors,
):
test_config: ConfigType = yaml_config.copy()
test_config.update(validated_user_input)
can_connect = await self.hass.async_add_executor_job( can_connect = await self.hass.async_add_executor_job(
try_connection, try_connection,
yaml_config, test_config,
user_input[CONF_BROKER],
user_input[CONF_PORT],
user_input.get(CONF_USERNAME),
user_input.get(CONF_PASSWORD),
) )
if can_connect: if can_connect:
self.broker_config.update(user_input) self.broker_config.update(validated_user_input)
return await self.async_step_options() return await self.async_step_options()
errors["base"] = "cannot_connect" errors["base"] = "cannot_connect"
fields = OrderedDict()
current_broker = current_config.get(CONF_BROKER, yaml_config.get(CONF_BROKER))
current_port = current_config.get(CONF_PORT, yaml_config.get(CONF_PORT))
current_user = current_config.get(CONF_USERNAME, yaml_config.get(CONF_USERNAME))
current_pass = current_config.get(CONF_PASSWORD, yaml_config.get(CONF_PASSWORD))
fields[vol.Required(CONF_BROKER, default=current_broker)] = str
fields[vol.Required(CONF_PORT, default=current_port)] = vol.Coerce(int)
fields[
vol.Optional(
CONF_USERNAME,
description={"suggested_value": current_user},
)
] = str
fields[
vol.Optional(
CONF_PASSWORD,
description={"suggested_value": current_pass},
)
] = str
return self.async_show_form( return self.async_show_form(
step_id="broker", step_id="broker",
data_schema=vol.Schema(fields), data_schema=vol.Schema(fields),
@ -212,53 +197,61 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
) )
async def async_step_options( async def async_step_options(
self, user_input: dict[str, Any] | None = None self, user_input: ConfigType | None = None
) -> FlowResult: ) -> FlowResult:
"""Manage the MQTT options.""" """Manage the MQTT options."""
mqtt_data = get_mqtt_data(self.hass, True)
errors = {} errors = {}
current_config = self.config_entry.data current_config = self.config_entry.data
yaml_config = mqtt_data.config or {} yaml_config = get_mqtt_data(self.hass, True).config or {}
options_config: dict[str, Any] = {} options_config: ConfigType = {}
if user_input is not None: bad_input: bool = False
bad_birth = False
bad_will = False
def _birth_will(birt_or_will: str) -> dict:
"""Return the user input for birth or will."""
assert user_input
return {
ATTR_TOPIC: user_input[f"{birt_or_will}_topic"],
ATTR_PAYLOAD: user_input.get(f"{birt_or_will}_payload", ""),
ATTR_QOS: user_input[f"{birt_or_will}_qos"],
ATTR_RETAIN: user_input[f"{birt_or_will}_retain"],
}
def _validate(
field: str, values: ConfigType, error_code: str, schema: Callable
):
"""Validate the user input."""
nonlocal bad_input
try:
option_values = schema(values)
options_config[field] = option_values
except vol.Invalid:
errors["base"] = error_code
bad_input = True
if user_input is not None:
# validate input
options_config[CONF_DISCOVERY] = user_input[CONF_DISCOVERY]
if "birth_topic" in user_input: if "birth_topic" in user_input:
birth_message = { _validate(
ATTR_TOPIC: user_input["birth_topic"], CONF_BIRTH_MESSAGE,
ATTR_PAYLOAD: user_input.get("birth_payload", ""), _birth_will("birth"),
ATTR_QOS: user_input["birth_qos"], "bad_birth",
ATTR_RETAIN: user_input["birth_retain"], MQTT_WILL_BIRTH_SCHEMA,
} )
try:
birth_message = MQTT_WILL_BIRTH_SCHEMA(birth_message)
options_config[CONF_BIRTH_MESSAGE] = birth_message
except vol.Invalid:
errors["base"] = "bad_birth"
bad_birth = True
if not user_input["birth_enable"]: if not user_input["birth_enable"]:
options_config[CONF_BIRTH_MESSAGE] = {} options_config[CONF_BIRTH_MESSAGE] = {}
if "will_topic" in user_input: if "will_topic" in user_input:
will_message = { _validate(
ATTR_TOPIC: user_input["will_topic"], CONF_WILL_MESSAGE,
ATTR_PAYLOAD: user_input.get("will_payload", ""), _birth_will("will"),
ATTR_QOS: user_input["will_qos"], "bad_will",
ATTR_RETAIN: user_input["will_retain"], MQTT_WILL_BIRTH_SCHEMA,
} )
try:
will_message = MQTT_WILL_BIRTH_SCHEMA(will_message)
options_config[CONF_WILL_MESSAGE] = will_message
except vol.Invalid:
errors["base"] = "bad_will"
bad_will = True
if not user_input["will_enable"]: if not user_input["will_enable"]:
options_config[CONF_WILL_MESSAGE] = {} options_config[CONF_WILL_MESSAGE] = {}
options_config[CONF_DISCOVERY] = user_input[CONF_DISCOVERY] if not bad_input:
if not bad_birth and not bad_will:
updated_config = {} updated_config = {}
updated_config.update(self.broker_config) updated_config.update(self.broker_config)
updated_config.update(options_config) updated_config.update(options_config)
@ -285,6 +278,7 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
CONF_DISCOVERY, yaml_config.get(CONF_DISCOVERY, DEFAULT_DISCOVERY) CONF_DISCOVERY, yaml_config.get(CONF_DISCOVERY, DEFAULT_DISCOVERY)
) )
# build form
fields: OrderedDict[vol.Marker, Any] = OrderedDict() fields: OrderedDict[vol.Marker, Any] = OrderedDict()
fields[vol.Optional(CONF_DISCOVERY, default=discovery)] = bool fields[vol.Optional(CONF_DISCOVERY, default=discovery)] = bool
@ -338,28 +332,66 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
) )
def try_connection( async def async_get_broker_settings(
hass: HomeAssistant,
fields: OrderedDict[Any, Any],
yaml_config: ConfigType, yaml_config: ConfigType,
broker: str, entry_config: MappingProxyType[str, Any] | None,
port: int, user_input: ConfigType | None,
username: str | None, validated_user_input: ConfigType,
password: str | None, errors: dict[str, str],
protocol: str = "3.1", ) -> bool:
"""Build the config flow schema to collect the broker settings.
Returns True when settings are collected successfully.
"""
user_input_basic: ConfigType = ConfigType()
current_config = entry_config.copy() if entry_config is not None else ConfigType()
if user_input is not None:
validated_user_input.update(user_input)
return True
# Update the current settings the the new posted data to fill the defaults
current_config.update(user_input_basic)
# Get default settings (if any)
current_broker = current_config.get(CONF_BROKER, yaml_config.get(CONF_BROKER))
current_port = current_config.get(
CONF_PORT, yaml_config.get(CONF_PORT, DEFAULT_PORT)
)
current_user = current_config.get(CONF_USERNAME, yaml_config.get(CONF_USERNAME))
current_pass = current_config.get(CONF_PASSWORD, yaml_config.get(CONF_PASSWORD))
# Build form
fields[vol.Required(CONF_BROKER, default=current_broker)] = str
fields[vol.Required(CONF_PORT, default=current_port)] = vol.Coerce(int)
fields[
vol.Optional(
CONF_USERNAME,
description={"suggested_value": current_user},
)
] = str
fields[
vol.Optional(
CONF_PASSWORD,
description={"suggested_value": current_pass},
)
] = str
# Show form
return False
def try_connection(
user_input: ConfigType,
) -> bool: ) -> bool:
"""Test if we can connect to an MQTT broker.""" """Test if we can connect to an MQTT broker."""
# We don't import on the top because some integrations # We don't import on the top because some integrations
# should be able to optionally rely on MQTT. # should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
# Get the config from configuration.yaml client = MqttClientSetup(user_input).client
entry_config = {
CONF_BROKER: broker,
CONF_PORT: port,
CONF_USERNAME: username,
CONF_PASSWORD: password,
CONF_PROTOCOL: protocol,
}
client = MqttClientSetup({**yaml_config, **entry_config}).client
result: queue.Queue[bool] = queue.Queue(maxsize=1) result: queue.Queue[bool] = queue.Queue(maxsize=1)
@ -369,7 +401,7 @@ def try_connection(
client.on_connect = on_connect client.on_connect = on_connect
client.connect_async(broker, port) client.connect_async(user_input[CONF_BROKER], user_input[CONF_PORT])
client.loop_start() client.loop_start()
try: try:

View file

@ -40,6 +40,7 @@ DEFAULT_ENCODING = "utf-8"
DEFAULT_QOS = 0 DEFAULT_QOS = 0
DEFAULT_PAYLOAD_AVAILABLE = "online" DEFAULT_PAYLOAD_AVAILABLE = "online"
DEFAULT_PAYLOAD_NOT_AVAILABLE = "offline" DEFAULT_PAYLOAD_NOT_AVAILABLE = "offline"
DEFAULT_PORT = 1883
DEFAULT_RETAIN = False DEFAULT_RETAIN = False
DEFAULT_BIRTH = { DEFAULT_BIRTH = {
@ -67,6 +68,8 @@ PAYLOAD_NONE = "None"
PROTOCOL_31 = "3.1" PROTOCOL_31 = "3.1"
PROTOCOL_311 = "3.1.1" PROTOCOL_311 = "3.1.1"
DEFAULT_PROTOCOL = PROTOCOL_311
PLATFORMS = [ PLATFORMS = [
Platform.ALARM_CONTROL_PANEL, Platform.ALARM_CONTROL_PANEL,
Platform.BINARY_SENSOR, Platform.BINARY_SENSOR,

View file

@ -188,15 +188,12 @@ async def test_manual_config_set(
# Check we tried the connection, with precedence for config entry settings # Check we tried the connection, with precedence for config entry settings
mock_try_connection.assert_called_once_with( mock_try_connection.assert_called_once_with(
{ {
"broker": "bla", "broker": "127.0.0.1",
"protocol": "3.1.1",
"keepalive": 60, "keepalive": 60,
"discovery_prefix": "homeassistant", "discovery_prefix": "homeassistant",
"protocol": "3.1.1", "port": 1883,
}, },
"127.0.0.1",
1883,
None,
None,
) )
# Check config entry got setup # Check config entry got setup
assert len(mock_finish_setup.mock_calls) == 1 assert len(mock_finish_setup.mock_calls) == 1
@ -291,6 +288,44 @@ async def test_hassio_confirm(hass, mock_try_connection_success, mock_finish_set
assert len(mock_finish_setup.mock_calls) == 1 assert len(mock_finish_setup.mock_calls) == 1
async def test_hassio_cannot_connect(
hass, mock_try_connection_time_out, mock_finish_setup
):
"""Test a config flow is aborted when a connection was not successful."""
mock_try_connection.return_value = True
result = await hass.config_entries.flow.async_init(
"mqtt",
data=HassioServiceInfo(
config={
"addon": "Mock Addon",
"host": "mock-broker",
"port": 1883,
"username": "mock-user",
"password": "mock-pass",
"protocol": "3.1.1", # Set by the addon's discovery, ignored by HA
"ssl": False, # Set by the addon's discovery, ignored by HA
}
),
context={"source": config_entries.SOURCE_HASSIO},
)
assert result["type"] == "form"
assert result["step_id"] == "hassio_confirm"
assert result["description_placeholders"] == {"addon": "Mock Addon"}
mock_try_connection_time_out.reset_mock()
result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"discovery": True}
)
assert result["type"] == "form"
assert result["errors"]["base"] == "cannot_connect"
# Check we tried the connection
assert len(mock_try_connection_time_out.mock_calls)
# Check config entry got setup
assert len(mock_finish_setup.mock_calls) == 0
@patch( @patch(
"homeassistant.config.async_hass_config_yaml", "homeassistant.config.async_hass_config_yaml",
AsyncMock(return_value={}), AsyncMock(return_value={}),
@ -299,7 +334,7 @@ async def test_option_flow(
hass, hass,
mqtt_mock_entry_no_yaml_config, mqtt_mock_entry_no_yaml_config,
mock_try_connection, mock_try_connection,
mock_reload_after_entry_update, caplog,
): ):
"""Test config flow options.""" """Test config flow options."""
mqtt_mock = await mqtt_mock_entry_no_yaml_config() mqtt_mock = await mqtt_mock_entry_no_yaml_config()
@ -372,7 +407,10 @@ async def test_option_flow(
await hass.async_block_till_done() await hass.async_block_till_done()
assert config_entry.title == "another-broker" assert config_entry.title == "another-broker"
# assert that the entry was reloaded with the new config # assert that the entry was reloaded with the new config
assert mock_reload_after_entry_update.call_count == 1 assert (
"<Event call_service[L]: domain=mqtt, service=reload, service_data=>"
in caplog.text
)
async def test_disable_birth_will( async def test_disable_birth_will(