De-duplicate MQTT config_flow code (#79369)

* De-duplicate config_flow code

* De duplicate code birth and will
This commit is contained in:
Jan Bouwhuis 2022-10-07 10:12:19 +02:00 committed by GitHub
parent 9a81b65815
commit aee82e2b3b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 195 additions and 121 deletions

View file

@ -54,6 +54,7 @@ from .const import (
CONF_TLS_INSECURE,
CONF_WILL_MESSAGE,
DEFAULT_ENCODING,
DEFAULT_PROTOCOL,
DEFAULT_QOS,
MQTT_CONNECTED,
MQTT_DISCONNECTED,
@ -272,7 +273,7 @@ class MqttClientSetup:
# should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
if config[CONF_PROTOCOL] == PROTOCOL_31:
if config.get(CONF_PROTOCOL, DEFAULT_PROTOCOL) == PROTOCOL_31:
proto = mqtt.MQTTv31
else:
proto = mqtt.MQTTv311

View file

@ -2,7 +2,9 @@
from __future__ import annotations
from collections import OrderedDict
from collections.abc import Callable
import queue
from types import MappingProxyType
from typing import Any
import voluptuous as vol
@ -15,10 +17,9 @@ from homeassistant.const import (
CONF_PASSWORD,
CONF_PAYLOAD,
CONF_PORT,
CONF_PROTOCOL,
CONF_USERNAME,
)
from homeassistant.core import callback
from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.typing import ConfigType
@ -33,6 +34,7 @@ from .const import (
CONF_WILL_MESSAGE,
DEFAULT_BIRTH,
DEFAULT_DISCOVERY,
DEFAULT_PORT,
DEFAULT_WILL,
DOMAIN,
)
@ -56,9 +58,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
"""Get the options flow for this handler."""
return MQTTOptionsFlowHandler(config_entry)
async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
async def async_step_user(self, user_input: ConfigType | None = None) -> FlowResult:
"""Handle a flow initialized by the user."""
if self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")
@ -66,35 +66,38 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
return await self.async_step_broker()
async def async_step_broker(
self, user_input: dict[str, Any] | None = None
self, user_input: ConfigType | None = None
) -> FlowResult:
"""Confirm the setup."""
errors = {}
if user_input is not None:
yaml_config: ConfigType = get_mqtt_data(self.hass, True).config or {}
errors: dict[str, str] = {}
fields: OrderedDict[Any, Any] = OrderedDict()
validated_user_input: ConfigType = {}
if await async_get_broker_settings(
self.hass,
fields,
yaml_config,
None,
user_input,
validated_user_input,
errors,
):
test_config: ConfigType = yaml_config.copy()
test_config.update(validated_user_input)
can_connect = await self.hass.async_add_executor_job(
try_connection,
get_mqtt_data(self.hass, True).config or {},
user_input[CONF_BROKER],
user_input[CONF_PORT],
user_input.get(CONF_USERNAME),
user_input.get(CONF_PASSWORD),
test_config,
)
if can_connect:
user_input[CONF_DISCOVERY] = DEFAULT_DISCOVERY
validated_user_input[CONF_DISCOVERY] = DEFAULT_DISCOVERY
return self.async_create_entry(
title=user_input[CONF_BROKER], data=user_input
title=validated_user_input[CONF_BROKER],
data=validated_user_input,
)
errors["base"] = "cannot_connect"
fields = OrderedDict()
fields[vol.Required(CONF_BROKER)] = str
fields[vol.Required(CONF_PORT, default=1883)] = vol.Coerce(int)
fields[vol.Optional(CONF_USERNAME)] = str
fields[vol.Optional(CONF_PASSWORD)] = str
return self.async_show_form(
step_id="broker", data_schema=vol.Schema(fields), errors=errors
)
@ -111,26 +114,22 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Confirm a Hass.io discovery."""
errors = {}
errors: dict[str, str] = {}
assert self._hassio_discovery
if user_input is not None:
data = self._hassio_discovery
data: ConfigType = self._hassio_discovery.copy()
data[CONF_BROKER] = data.pop(CONF_HOST)
can_connect = await self.hass.async_add_executor_job(
try_connection,
get_mqtt_data(self.hass, True).config or {},
data[CONF_HOST],
data[CONF_PORT],
data.get(CONF_USERNAME),
data.get(CONF_PASSWORD),
data.get(CONF_PROTOCOL),
data,
)
if can_connect:
return self.async_create_entry(
title=data["addon"],
data={
CONF_BROKER: data[CONF_HOST],
CONF_BROKER: data[CONF_BROKER],
CONF_PORT: data[CONF_PORT],
CONF_USERNAME: data.get(CONF_USERNAME),
CONF_PASSWORD: data.get(CONF_PASSWORD),
@ -164,46 +163,32 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Manage the MQTT broker configuration."""
mqtt_data = get_mqtt_data(self.hass, True)
yaml_config = mqtt_data.config or {}
errors = {}
current_config = self.config_entry.data
if user_input is not None:
errors: dict[str, str] = {}
yaml_config: ConfigType = get_mqtt_data(self.hass, True).config or {}
fields: OrderedDict[Any, Any] = OrderedDict()
validated_user_input: ConfigType = {}
if await async_get_broker_settings(
self.hass,
fields,
yaml_config,
self.config_entry.data,
user_input,
validated_user_input,
errors,
):
test_config: ConfigType = yaml_config.copy()
test_config.update(validated_user_input)
can_connect = await self.hass.async_add_executor_job(
try_connection,
yaml_config,
user_input[CONF_BROKER],
user_input[CONF_PORT],
user_input.get(CONF_USERNAME),
user_input.get(CONF_PASSWORD),
test_config,
)
if can_connect:
self.broker_config.update(user_input)
self.broker_config.update(validated_user_input)
return await self.async_step_options()
errors["base"] = "cannot_connect"
fields = OrderedDict()
current_broker = current_config.get(CONF_BROKER, yaml_config.get(CONF_BROKER))
current_port = current_config.get(CONF_PORT, yaml_config.get(CONF_PORT))
current_user = current_config.get(CONF_USERNAME, yaml_config.get(CONF_USERNAME))
current_pass = current_config.get(CONF_PASSWORD, yaml_config.get(CONF_PASSWORD))
fields[vol.Required(CONF_BROKER, default=current_broker)] = str
fields[vol.Required(CONF_PORT, default=current_port)] = vol.Coerce(int)
fields[
vol.Optional(
CONF_USERNAME,
description={"suggested_value": current_user},
)
] = str
fields[
vol.Optional(
CONF_PASSWORD,
description={"suggested_value": current_pass},
)
] = str
return self.async_show_form(
step_id="broker",
data_schema=vol.Schema(fields),
@ -212,53 +197,61 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
)
async def async_step_options(
self, user_input: dict[str, Any] | None = None
self, user_input: ConfigType | None = None
) -> FlowResult:
"""Manage the MQTT options."""
mqtt_data = get_mqtt_data(self.hass, True)
errors = {}
current_config = self.config_entry.data
yaml_config = mqtt_data.config or {}
options_config: dict[str, Any] = {}
if user_input is not None:
bad_birth = False
bad_will = False
yaml_config = get_mqtt_data(self.hass, True).config or {}
options_config: ConfigType = {}
bad_input: bool = False
def _birth_will(birt_or_will: str) -> dict:
"""Return the user input for birth or will."""
assert user_input
return {
ATTR_TOPIC: user_input[f"{birt_or_will}_topic"],
ATTR_PAYLOAD: user_input.get(f"{birt_or_will}_payload", ""),
ATTR_QOS: user_input[f"{birt_or_will}_qos"],
ATTR_RETAIN: user_input[f"{birt_or_will}_retain"],
}
def _validate(
field: str, values: ConfigType, error_code: str, schema: Callable
):
"""Validate the user input."""
nonlocal bad_input
try:
option_values = schema(values)
options_config[field] = option_values
except vol.Invalid:
errors["base"] = error_code
bad_input = True
if user_input is not None:
# validate input
options_config[CONF_DISCOVERY] = user_input[CONF_DISCOVERY]
if "birth_topic" in user_input:
birth_message = {
ATTR_TOPIC: user_input["birth_topic"],
ATTR_PAYLOAD: user_input.get("birth_payload", ""),
ATTR_QOS: user_input["birth_qos"],
ATTR_RETAIN: user_input["birth_retain"],
}
try:
birth_message = MQTT_WILL_BIRTH_SCHEMA(birth_message)
options_config[CONF_BIRTH_MESSAGE] = birth_message
except vol.Invalid:
errors["base"] = "bad_birth"
bad_birth = True
_validate(
CONF_BIRTH_MESSAGE,
_birth_will("birth"),
"bad_birth",
MQTT_WILL_BIRTH_SCHEMA,
)
if not user_input["birth_enable"]:
options_config[CONF_BIRTH_MESSAGE] = {}
if "will_topic" in user_input:
will_message = {
ATTR_TOPIC: user_input["will_topic"],
ATTR_PAYLOAD: user_input.get("will_payload", ""),
ATTR_QOS: user_input["will_qos"],
ATTR_RETAIN: user_input["will_retain"],
}
try:
will_message = MQTT_WILL_BIRTH_SCHEMA(will_message)
options_config[CONF_WILL_MESSAGE] = will_message
except vol.Invalid:
errors["base"] = "bad_will"
bad_will = True
_validate(
CONF_WILL_MESSAGE,
_birth_will("will"),
"bad_will",
MQTT_WILL_BIRTH_SCHEMA,
)
if not user_input["will_enable"]:
options_config[CONF_WILL_MESSAGE] = {}
options_config[CONF_DISCOVERY] = user_input[CONF_DISCOVERY]
if not bad_birth and not bad_will:
if not bad_input:
updated_config = {}
updated_config.update(self.broker_config)
updated_config.update(options_config)
@ -285,6 +278,7 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
CONF_DISCOVERY, yaml_config.get(CONF_DISCOVERY, DEFAULT_DISCOVERY)
)
# build form
fields: OrderedDict[vol.Marker, Any] = OrderedDict()
fields[vol.Optional(CONF_DISCOVERY, default=discovery)] = bool
@ -338,28 +332,66 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
)
def try_connection(
async def async_get_broker_settings(
hass: HomeAssistant,
fields: OrderedDict[Any, Any],
yaml_config: ConfigType,
broker: str,
port: int,
username: str | None,
password: str | None,
protocol: str = "3.1",
entry_config: MappingProxyType[str, Any] | None,
user_input: ConfigType | None,
validated_user_input: ConfigType,
errors: dict[str, str],
) -> bool:
"""Build the config flow schema to collect the broker settings.
Returns True when settings are collected successfully.
"""
user_input_basic: ConfigType = ConfigType()
current_config = entry_config.copy() if entry_config is not None else ConfigType()
if user_input is not None:
validated_user_input.update(user_input)
return True
# Update the current settings the the new posted data to fill the defaults
current_config.update(user_input_basic)
# Get default settings (if any)
current_broker = current_config.get(CONF_BROKER, yaml_config.get(CONF_BROKER))
current_port = current_config.get(
CONF_PORT, yaml_config.get(CONF_PORT, DEFAULT_PORT)
)
current_user = current_config.get(CONF_USERNAME, yaml_config.get(CONF_USERNAME))
current_pass = current_config.get(CONF_PASSWORD, yaml_config.get(CONF_PASSWORD))
# Build form
fields[vol.Required(CONF_BROKER, default=current_broker)] = str
fields[vol.Required(CONF_PORT, default=current_port)] = vol.Coerce(int)
fields[
vol.Optional(
CONF_USERNAME,
description={"suggested_value": current_user},
)
] = str
fields[
vol.Optional(
CONF_PASSWORD,
description={"suggested_value": current_pass},
)
] = str
# Show form
return False
def try_connection(
user_input: ConfigType,
) -> bool:
"""Test if we can connect to an MQTT broker."""
# We don't import on the top because some integrations
# should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
# Get the config from configuration.yaml
entry_config = {
CONF_BROKER: broker,
CONF_PORT: port,
CONF_USERNAME: username,
CONF_PASSWORD: password,
CONF_PROTOCOL: protocol,
}
client = MqttClientSetup({**yaml_config, **entry_config}).client
client = MqttClientSetup(user_input).client
result: queue.Queue[bool] = queue.Queue(maxsize=1)
@ -369,7 +401,7 @@ def try_connection(
client.on_connect = on_connect
client.connect_async(broker, port)
client.connect_async(user_input[CONF_BROKER], user_input[CONF_PORT])
client.loop_start()
try:

View file

@ -40,6 +40,7 @@ DEFAULT_ENCODING = "utf-8"
DEFAULT_QOS = 0
DEFAULT_PAYLOAD_AVAILABLE = "online"
DEFAULT_PAYLOAD_NOT_AVAILABLE = "offline"
DEFAULT_PORT = 1883
DEFAULT_RETAIN = False
DEFAULT_BIRTH = {
@ -67,6 +68,8 @@ PAYLOAD_NONE = "None"
PROTOCOL_31 = "3.1"
PROTOCOL_311 = "3.1.1"
DEFAULT_PROTOCOL = PROTOCOL_311
PLATFORMS = [
Platform.ALARM_CONTROL_PANEL,
Platform.BINARY_SENSOR,

View file

@ -188,15 +188,12 @@ async def test_manual_config_set(
# Check we tried the connection, with precedence for config entry settings
mock_try_connection.assert_called_once_with(
{
"broker": "bla",
"broker": "127.0.0.1",
"protocol": "3.1.1",
"keepalive": 60,
"discovery_prefix": "homeassistant",
"protocol": "3.1.1",
"port": 1883,
},
"127.0.0.1",
1883,
None,
None,
)
# Check config entry got setup
assert len(mock_finish_setup.mock_calls) == 1
@ -291,6 +288,44 @@ async def test_hassio_confirm(hass, mock_try_connection_success, mock_finish_set
assert len(mock_finish_setup.mock_calls) == 1
async def test_hassio_cannot_connect(
hass, mock_try_connection_time_out, mock_finish_setup
):
"""Test a config flow is aborted when a connection was not successful."""
mock_try_connection.return_value = True
result = await hass.config_entries.flow.async_init(
"mqtt",
data=HassioServiceInfo(
config={
"addon": "Mock Addon",
"host": "mock-broker",
"port": 1883,
"username": "mock-user",
"password": "mock-pass",
"protocol": "3.1.1", # Set by the addon's discovery, ignored by HA
"ssl": False, # Set by the addon's discovery, ignored by HA
}
),
context={"source": config_entries.SOURCE_HASSIO},
)
assert result["type"] == "form"
assert result["step_id"] == "hassio_confirm"
assert result["description_placeholders"] == {"addon": "Mock Addon"}
mock_try_connection_time_out.reset_mock()
result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"discovery": True}
)
assert result["type"] == "form"
assert result["errors"]["base"] == "cannot_connect"
# Check we tried the connection
assert len(mock_try_connection_time_out.mock_calls)
# Check config entry got setup
assert len(mock_finish_setup.mock_calls) == 0
@patch(
"homeassistant.config.async_hass_config_yaml",
AsyncMock(return_value={}),
@ -299,7 +334,7 @@ async def test_option_flow(
hass,
mqtt_mock_entry_no_yaml_config,
mock_try_connection,
mock_reload_after_entry_update,
caplog,
):
"""Test config flow options."""
mqtt_mock = await mqtt_mock_entry_no_yaml_config()
@ -372,7 +407,10 @@ async def test_option_flow(
await hass.async_block_till_done()
assert config_entry.title == "another-broker"
# assert that the entry was reloaded with the new config
assert mock_reload_after_entry_update.call_count == 1
assert (
"<Event call_service[L]: domain=mqtt, service=reload, service_data=>"
in caplog.text
)
async def test_disable_birth_will(