De-duplicate MQTT config_flow code (#79369)
* De-duplicate config_flow code * De duplicate code birth and will
This commit is contained in:
parent
9a81b65815
commit
aee82e2b3b
4 changed files with 195 additions and 121 deletions
|
@ -54,6 +54,7 @@ from .const import (
|
|||
CONF_TLS_INSECURE,
|
||||
CONF_WILL_MESSAGE,
|
||||
DEFAULT_ENCODING,
|
||||
DEFAULT_PROTOCOL,
|
||||
DEFAULT_QOS,
|
||||
MQTT_CONNECTED,
|
||||
MQTT_DISCONNECTED,
|
||||
|
@ -272,7 +273,7 @@ class MqttClientSetup:
|
|||
# should be able to optionally rely on MQTT.
|
||||
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
|
||||
|
||||
if config[CONF_PROTOCOL] == PROTOCOL_31:
|
||||
if config.get(CONF_PROTOCOL, DEFAULT_PROTOCOL) == PROTOCOL_31:
|
||||
proto = mqtt.MQTTv31
|
||||
else:
|
||||
proto = mqtt.MQTTv311
|
||||
|
|
|
@ -2,7 +2,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Callable
|
||||
import queue
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
@ -15,10 +17,9 @@ from homeassistant.const import (
|
|||
CONF_PASSWORD,
|
||||
CONF_PAYLOAD,
|
||||
CONF_PORT,
|
||||
CONF_PROTOCOL,
|
||||
CONF_USERNAME,
|
||||
)
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
|
@ -33,6 +34,7 @@ from .const import (
|
|||
CONF_WILL_MESSAGE,
|
||||
DEFAULT_BIRTH,
|
||||
DEFAULT_DISCOVERY,
|
||||
DEFAULT_PORT,
|
||||
DEFAULT_WILL,
|
||||
DOMAIN,
|
||||
)
|
||||
|
@ -56,9 +58,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
"""Get the options flow for this handler."""
|
||||
return MQTTOptionsFlowHandler(config_entry)
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
async def async_step_user(self, user_input: ConfigType | None = None) -> FlowResult:
|
||||
"""Handle a flow initialized by the user."""
|
||||
if self._async_current_entries():
|
||||
return self.async_abort(reason="single_instance_allowed")
|
||||
|
@ -66,35 +66,38 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
return await self.async_step_broker()
|
||||
|
||||
async def async_step_broker(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
self, user_input: ConfigType | None = None
|
||||
) -> FlowResult:
|
||||
"""Confirm the setup."""
|
||||
errors = {}
|
||||
|
||||
if user_input is not None:
|
||||
yaml_config: ConfigType = get_mqtt_data(self.hass, True).config or {}
|
||||
errors: dict[str, str] = {}
|
||||
fields: OrderedDict[Any, Any] = OrderedDict()
|
||||
validated_user_input: ConfigType = {}
|
||||
if await async_get_broker_settings(
|
||||
self.hass,
|
||||
fields,
|
||||
yaml_config,
|
||||
None,
|
||||
user_input,
|
||||
validated_user_input,
|
||||
errors,
|
||||
):
|
||||
test_config: ConfigType = yaml_config.copy()
|
||||
test_config.update(validated_user_input)
|
||||
can_connect = await self.hass.async_add_executor_job(
|
||||
try_connection,
|
||||
get_mqtt_data(self.hass, True).config or {},
|
||||
user_input[CONF_BROKER],
|
||||
user_input[CONF_PORT],
|
||||
user_input.get(CONF_USERNAME),
|
||||
user_input.get(CONF_PASSWORD),
|
||||
test_config,
|
||||
)
|
||||
|
||||
if can_connect:
|
||||
user_input[CONF_DISCOVERY] = DEFAULT_DISCOVERY
|
||||
validated_user_input[CONF_DISCOVERY] = DEFAULT_DISCOVERY
|
||||
return self.async_create_entry(
|
||||
title=user_input[CONF_BROKER], data=user_input
|
||||
title=validated_user_input[CONF_BROKER],
|
||||
data=validated_user_input,
|
||||
)
|
||||
|
||||
errors["base"] = "cannot_connect"
|
||||
|
||||
fields = OrderedDict()
|
||||
fields[vol.Required(CONF_BROKER)] = str
|
||||
fields[vol.Required(CONF_PORT, default=1883)] = vol.Coerce(int)
|
||||
fields[vol.Optional(CONF_USERNAME)] = str
|
||||
fields[vol.Optional(CONF_PASSWORD)] = str
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="broker", data_schema=vol.Schema(fields), errors=errors
|
||||
)
|
||||
|
@ -111,26 +114,22 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Confirm a Hass.io discovery."""
|
||||
errors = {}
|
||||
errors: dict[str, str] = {}
|
||||
assert self._hassio_discovery
|
||||
|
||||
if user_input is not None:
|
||||
data = self._hassio_discovery
|
||||
data: ConfigType = self._hassio_discovery.copy()
|
||||
data[CONF_BROKER] = data.pop(CONF_HOST)
|
||||
can_connect = await self.hass.async_add_executor_job(
|
||||
try_connection,
|
||||
get_mqtt_data(self.hass, True).config or {},
|
||||
data[CONF_HOST],
|
||||
data[CONF_PORT],
|
||||
data.get(CONF_USERNAME),
|
||||
data.get(CONF_PASSWORD),
|
||||
data.get(CONF_PROTOCOL),
|
||||
data,
|
||||
)
|
||||
|
||||
if can_connect:
|
||||
return self.async_create_entry(
|
||||
title=data["addon"],
|
||||
data={
|
||||
CONF_BROKER: data[CONF_HOST],
|
||||
CONF_BROKER: data[CONF_BROKER],
|
||||
CONF_PORT: data[CONF_PORT],
|
||||
CONF_USERNAME: data.get(CONF_USERNAME),
|
||||
CONF_PASSWORD: data.get(CONF_PASSWORD),
|
||||
|
@ -164,46 +163,32 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
|
|||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Manage the MQTT broker configuration."""
|
||||
mqtt_data = get_mqtt_data(self.hass, True)
|
||||
yaml_config = mqtt_data.config or {}
|
||||
errors = {}
|
||||
current_config = self.config_entry.data
|
||||
if user_input is not None:
|
||||
errors: dict[str, str] = {}
|
||||
yaml_config: ConfigType = get_mqtt_data(self.hass, True).config or {}
|
||||
fields: OrderedDict[Any, Any] = OrderedDict()
|
||||
validated_user_input: ConfigType = {}
|
||||
if await async_get_broker_settings(
|
||||
self.hass,
|
||||
fields,
|
||||
yaml_config,
|
||||
self.config_entry.data,
|
||||
user_input,
|
||||
validated_user_input,
|
||||
errors,
|
||||
):
|
||||
test_config: ConfigType = yaml_config.copy()
|
||||
test_config.update(validated_user_input)
|
||||
can_connect = await self.hass.async_add_executor_job(
|
||||
try_connection,
|
||||
yaml_config,
|
||||
user_input[CONF_BROKER],
|
||||
user_input[CONF_PORT],
|
||||
user_input.get(CONF_USERNAME),
|
||||
user_input.get(CONF_PASSWORD),
|
||||
test_config,
|
||||
)
|
||||
|
||||
if can_connect:
|
||||
self.broker_config.update(user_input)
|
||||
self.broker_config.update(validated_user_input)
|
||||
return await self.async_step_options()
|
||||
|
||||
errors["base"] = "cannot_connect"
|
||||
|
||||
fields = OrderedDict()
|
||||
current_broker = current_config.get(CONF_BROKER, yaml_config.get(CONF_BROKER))
|
||||
current_port = current_config.get(CONF_PORT, yaml_config.get(CONF_PORT))
|
||||
current_user = current_config.get(CONF_USERNAME, yaml_config.get(CONF_USERNAME))
|
||||
current_pass = current_config.get(CONF_PASSWORD, yaml_config.get(CONF_PASSWORD))
|
||||
fields[vol.Required(CONF_BROKER, default=current_broker)] = str
|
||||
fields[vol.Required(CONF_PORT, default=current_port)] = vol.Coerce(int)
|
||||
fields[
|
||||
vol.Optional(
|
||||
CONF_USERNAME,
|
||||
description={"suggested_value": current_user},
|
||||
)
|
||||
] = str
|
||||
fields[
|
||||
vol.Optional(
|
||||
CONF_PASSWORD,
|
||||
description={"suggested_value": current_pass},
|
||||
)
|
||||
] = str
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="broker",
|
||||
data_schema=vol.Schema(fields),
|
||||
|
@ -212,53 +197,61 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
|
|||
)
|
||||
|
||||
async def async_step_options(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
self, user_input: ConfigType | None = None
|
||||
) -> FlowResult:
|
||||
"""Manage the MQTT options."""
|
||||
mqtt_data = get_mqtt_data(self.hass, True)
|
||||
errors = {}
|
||||
current_config = self.config_entry.data
|
||||
yaml_config = mqtt_data.config or {}
|
||||
options_config: dict[str, Any] = {}
|
||||
if user_input is not None:
|
||||
bad_birth = False
|
||||
bad_will = False
|
||||
yaml_config = get_mqtt_data(self.hass, True).config or {}
|
||||
options_config: ConfigType = {}
|
||||
bad_input: bool = False
|
||||
|
||||
def _birth_will(birt_or_will: str) -> dict:
|
||||
"""Return the user input for birth or will."""
|
||||
assert user_input
|
||||
return {
|
||||
ATTR_TOPIC: user_input[f"{birt_or_will}_topic"],
|
||||
ATTR_PAYLOAD: user_input.get(f"{birt_or_will}_payload", ""),
|
||||
ATTR_QOS: user_input[f"{birt_or_will}_qos"],
|
||||
ATTR_RETAIN: user_input[f"{birt_or_will}_retain"],
|
||||
}
|
||||
|
||||
def _validate(
|
||||
field: str, values: ConfigType, error_code: str, schema: Callable
|
||||
):
|
||||
"""Validate the user input."""
|
||||
nonlocal bad_input
|
||||
try:
|
||||
option_values = schema(values)
|
||||
options_config[field] = option_values
|
||||
except vol.Invalid:
|
||||
errors["base"] = error_code
|
||||
bad_input = True
|
||||
|
||||
if user_input is not None:
|
||||
# validate input
|
||||
options_config[CONF_DISCOVERY] = user_input[CONF_DISCOVERY]
|
||||
if "birth_topic" in user_input:
|
||||
birth_message = {
|
||||
ATTR_TOPIC: user_input["birth_topic"],
|
||||
ATTR_PAYLOAD: user_input.get("birth_payload", ""),
|
||||
ATTR_QOS: user_input["birth_qos"],
|
||||
ATTR_RETAIN: user_input["birth_retain"],
|
||||
}
|
||||
try:
|
||||
birth_message = MQTT_WILL_BIRTH_SCHEMA(birth_message)
|
||||
options_config[CONF_BIRTH_MESSAGE] = birth_message
|
||||
except vol.Invalid:
|
||||
errors["base"] = "bad_birth"
|
||||
bad_birth = True
|
||||
_validate(
|
||||
CONF_BIRTH_MESSAGE,
|
||||
_birth_will("birth"),
|
||||
"bad_birth",
|
||||
MQTT_WILL_BIRTH_SCHEMA,
|
||||
)
|
||||
if not user_input["birth_enable"]:
|
||||
options_config[CONF_BIRTH_MESSAGE] = {}
|
||||
|
||||
if "will_topic" in user_input:
|
||||
will_message = {
|
||||
ATTR_TOPIC: user_input["will_topic"],
|
||||
ATTR_PAYLOAD: user_input.get("will_payload", ""),
|
||||
ATTR_QOS: user_input["will_qos"],
|
||||
ATTR_RETAIN: user_input["will_retain"],
|
||||
}
|
||||
try:
|
||||
will_message = MQTT_WILL_BIRTH_SCHEMA(will_message)
|
||||
options_config[CONF_WILL_MESSAGE] = will_message
|
||||
except vol.Invalid:
|
||||
errors["base"] = "bad_will"
|
||||
bad_will = True
|
||||
_validate(
|
||||
CONF_WILL_MESSAGE,
|
||||
_birth_will("will"),
|
||||
"bad_will",
|
||||
MQTT_WILL_BIRTH_SCHEMA,
|
||||
)
|
||||
if not user_input["will_enable"]:
|
||||
options_config[CONF_WILL_MESSAGE] = {}
|
||||
|
||||
options_config[CONF_DISCOVERY] = user_input[CONF_DISCOVERY]
|
||||
|
||||
if not bad_birth and not bad_will:
|
||||
if not bad_input:
|
||||
updated_config = {}
|
||||
updated_config.update(self.broker_config)
|
||||
updated_config.update(options_config)
|
||||
|
@ -285,6 +278,7 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
|
|||
CONF_DISCOVERY, yaml_config.get(CONF_DISCOVERY, DEFAULT_DISCOVERY)
|
||||
)
|
||||
|
||||
# build form
|
||||
fields: OrderedDict[vol.Marker, Any] = OrderedDict()
|
||||
fields[vol.Optional(CONF_DISCOVERY, default=discovery)] = bool
|
||||
|
||||
|
@ -338,28 +332,66 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
|
|||
)
|
||||
|
||||
|
||||
def try_connection(
|
||||
async def async_get_broker_settings(
|
||||
hass: HomeAssistant,
|
||||
fields: OrderedDict[Any, Any],
|
||||
yaml_config: ConfigType,
|
||||
broker: str,
|
||||
port: int,
|
||||
username: str | None,
|
||||
password: str | None,
|
||||
protocol: str = "3.1",
|
||||
entry_config: MappingProxyType[str, Any] | None,
|
||||
user_input: ConfigType | None,
|
||||
validated_user_input: ConfigType,
|
||||
errors: dict[str, str],
|
||||
) -> bool:
|
||||
"""Build the config flow schema to collect the broker settings.
|
||||
|
||||
Returns True when settings are collected successfully.
|
||||
"""
|
||||
user_input_basic: ConfigType = ConfigType()
|
||||
current_config = entry_config.copy() if entry_config is not None else ConfigType()
|
||||
|
||||
if user_input is not None:
|
||||
validated_user_input.update(user_input)
|
||||
return True
|
||||
|
||||
# Update the current settings the the new posted data to fill the defaults
|
||||
current_config.update(user_input_basic)
|
||||
|
||||
# Get default settings (if any)
|
||||
current_broker = current_config.get(CONF_BROKER, yaml_config.get(CONF_BROKER))
|
||||
current_port = current_config.get(
|
||||
CONF_PORT, yaml_config.get(CONF_PORT, DEFAULT_PORT)
|
||||
)
|
||||
current_user = current_config.get(CONF_USERNAME, yaml_config.get(CONF_USERNAME))
|
||||
current_pass = current_config.get(CONF_PASSWORD, yaml_config.get(CONF_PASSWORD))
|
||||
|
||||
# Build form
|
||||
fields[vol.Required(CONF_BROKER, default=current_broker)] = str
|
||||
fields[vol.Required(CONF_PORT, default=current_port)] = vol.Coerce(int)
|
||||
fields[
|
||||
vol.Optional(
|
||||
CONF_USERNAME,
|
||||
description={"suggested_value": current_user},
|
||||
)
|
||||
] = str
|
||||
fields[
|
||||
vol.Optional(
|
||||
CONF_PASSWORD,
|
||||
description={"suggested_value": current_pass},
|
||||
)
|
||||
] = str
|
||||
|
||||
# Show form
|
||||
return False
|
||||
|
||||
|
||||
def try_connection(
|
||||
user_input: ConfigType,
|
||||
) -> bool:
|
||||
"""Test if we can connect to an MQTT broker."""
|
||||
# We don't import on the top because some integrations
|
||||
# should be able to optionally rely on MQTT.
|
||||
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
|
||||
|
||||
# Get the config from configuration.yaml
|
||||
entry_config = {
|
||||
CONF_BROKER: broker,
|
||||
CONF_PORT: port,
|
||||
CONF_USERNAME: username,
|
||||
CONF_PASSWORD: password,
|
||||
CONF_PROTOCOL: protocol,
|
||||
}
|
||||
client = MqttClientSetup({**yaml_config, **entry_config}).client
|
||||
client = MqttClientSetup(user_input).client
|
||||
|
||||
result: queue.Queue[bool] = queue.Queue(maxsize=1)
|
||||
|
||||
|
@ -369,7 +401,7 @@ def try_connection(
|
|||
|
||||
client.on_connect = on_connect
|
||||
|
||||
client.connect_async(broker, port)
|
||||
client.connect_async(user_input[CONF_BROKER], user_input[CONF_PORT])
|
||||
client.loop_start()
|
||||
|
||||
try:
|
||||
|
|
|
@ -40,6 +40,7 @@ DEFAULT_ENCODING = "utf-8"
|
|||
DEFAULT_QOS = 0
|
||||
DEFAULT_PAYLOAD_AVAILABLE = "online"
|
||||
DEFAULT_PAYLOAD_NOT_AVAILABLE = "offline"
|
||||
DEFAULT_PORT = 1883
|
||||
DEFAULT_RETAIN = False
|
||||
|
||||
DEFAULT_BIRTH = {
|
||||
|
@ -67,6 +68,8 @@ PAYLOAD_NONE = "None"
|
|||
PROTOCOL_31 = "3.1"
|
||||
PROTOCOL_311 = "3.1.1"
|
||||
|
||||
DEFAULT_PROTOCOL = PROTOCOL_311
|
||||
|
||||
PLATFORMS = [
|
||||
Platform.ALARM_CONTROL_PANEL,
|
||||
Platform.BINARY_SENSOR,
|
||||
|
|
|
@ -188,15 +188,12 @@ async def test_manual_config_set(
|
|||
# Check we tried the connection, with precedence for config entry settings
|
||||
mock_try_connection.assert_called_once_with(
|
||||
{
|
||||
"broker": "bla",
|
||||
"broker": "127.0.0.1",
|
||||
"protocol": "3.1.1",
|
||||
"keepalive": 60,
|
||||
"discovery_prefix": "homeassistant",
|
||||
"protocol": "3.1.1",
|
||||
"port": 1883,
|
||||
},
|
||||
"127.0.0.1",
|
||||
1883,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
# Check config entry got setup
|
||||
assert len(mock_finish_setup.mock_calls) == 1
|
||||
|
@ -291,6 +288,44 @@ async def test_hassio_confirm(hass, mock_try_connection_success, mock_finish_set
|
|||
assert len(mock_finish_setup.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_hassio_cannot_connect(
|
||||
hass, mock_try_connection_time_out, mock_finish_setup
|
||||
):
|
||||
"""Test a config flow is aborted when a connection was not successful."""
|
||||
mock_try_connection.return_value = True
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
"mqtt",
|
||||
data=HassioServiceInfo(
|
||||
config={
|
||||
"addon": "Mock Addon",
|
||||
"host": "mock-broker",
|
||||
"port": 1883,
|
||||
"username": "mock-user",
|
||||
"password": "mock-pass",
|
||||
"protocol": "3.1.1", # Set by the addon's discovery, ignored by HA
|
||||
"ssl": False, # Set by the addon's discovery, ignored by HA
|
||||
}
|
||||
),
|
||||
context={"source": config_entries.SOURCE_HASSIO},
|
||||
)
|
||||
assert result["type"] == "form"
|
||||
assert result["step_id"] == "hassio_confirm"
|
||||
assert result["description_placeholders"] == {"addon": "Mock Addon"}
|
||||
|
||||
mock_try_connection_time_out.reset_mock()
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {"discovery": True}
|
||||
)
|
||||
|
||||
assert result["type"] == "form"
|
||||
assert result["errors"]["base"] == "cannot_connect"
|
||||
# Check we tried the connection
|
||||
assert len(mock_try_connection_time_out.mock_calls)
|
||||
# Check config entry got setup
|
||||
assert len(mock_finish_setup.mock_calls) == 0
|
||||
|
||||
|
||||
@patch(
|
||||
"homeassistant.config.async_hass_config_yaml",
|
||||
AsyncMock(return_value={}),
|
||||
|
@ -299,7 +334,7 @@ async def test_option_flow(
|
|||
hass,
|
||||
mqtt_mock_entry_no_yaml_config,
|
||||
mock_try_connection,
|
||||
mock_reload_after_entry_update,
|
||||
caplog,
|
||||
):
|
||||
"""Test config flow options."""
|
||||
mqtt_mock = await mqtt_mock_entry_no_yaml_config()
|
||||
|
@ -372,7 +407,10 @@ async def test_option_flow(
|
|||
await hass.async_block_till_done()
|
||||
assert config_entry.title == "another-broker"
|
||||
# assert that the entry was reloaded with the new config
|
||||
assert mock_reload_after_entry_update.call_count == 1
|
||||
assert (
|
||||
"<Event call_service[L]: domain=mqtt, service=reload, service_data=>"
|
||||
in caplog.text
|
||||
)
|
||||
|
||||
|
||||
async def test_disable_birth_will(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue