De-duplicate MQTT config_flow code (#79369)
* De-duplicate config_flow code * De duplicate code birth and will
This commit is contained in:
parent
9a81b65815
commit
aee82e2b3b
4 changed files with 195 additions and 121 deletions
|
@ -54,6 +54,7 @@ from .const import (
|
||||||
CONF_TLS_INSECURE,
|
CONF_TLS_INSECURE,
|
||||||
CONF_WILL_MESSAGE,
|
CONF_WILL_MESSAGE,
|
||||||
DEFAULT_ENCODING,
|
DEFAULT_ENCODING,
|
||||||
|
DEFAULT_PROTOCOL,
|
||||||
DEFAULT_QOS,
|
DEFAULT_QOS,
|
||||||
MQTT_CONNECTED,
|
MQTT_CONNECTED,
|
||||||
MQTT_DISCONNECTED,
|
MQTT_DISCONNECTED,
|
||||||
|
@ -272,7 +273,7 @@ class MqttClientSetup:
|
||||||
# should be able to optionally rely on MQTT.
|
# should be able to optionally rely on MQTT.
|
||||||
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
|
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
if config[CONF_PROTOCOL] == PROTOCOL_31:
|
if config.get(CONF_PROTOCOL, DEFAULT_PROTOCOL) == PROTOCOL_31:
|
||||||
proto = mqtt.MQTTv31
|
proto = mqtt.MQTTv31
|
||||||
else:
|
else:
|
||||||
proto = mqtt.MQTTv311
|
proto = mqtt.MQTTv311
|
||||||
|
|
|
@ -2,7 +2,9 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from collections.abc import Callable
|
||||||
import queue
|
import queue
|
||||||
|
from types import MappingProxyType
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
@ -15,10 +17,9 @@ from homeassistant.const import (
|
||||||
CONF_PASSWORD,
|
CONF_PASSWORD,
|
||||||
CONF_PAYLOAD,
|
CONF_PAYLOAD,
|
||||||
CONF_PORT,
|
CONF_PORT,
|
||||||
CONF_PROTOCOL,
|
|
||||||
CONF_USERNAME,
|
CONF_USERNAME,
|
||||||
)
|
)
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.data_entry_flow import FlowResult
|
from homeassistant.data_entry_flow import FlowResult
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
|
@ -33,6 +34,7 @@ from .const import (
|
||||||
CONF_WILL_MESSAGE,
|
CONF_WILL_MESSAGE,
|
||||||
DEFAULT_BIRTH,
|
DEFAULT_BIRTH,
|
||||||
DEFAULT_DISCOVERY,
|
DEFAULT_DISCOVERY,
|
||||||
|
DEFAULT_PORT,
|
||||||
DEFAULT_WILL,
|
DEFAULT_WILL,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
)
|
)
|
||||||
|
@ -56,9 +58,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
||||||
"""Get the options flow for this handler."""
|
"""Get the options flow for this handler."""
|
||||||
return MQTTOptionsFlowHandler(config_entry)
|
return MQTTOptionsFlowHandler(config_entry)
|
||||||
|
|
||||||
async def async_step_user(
|
async def async_step_user(self, user_input: ConfigType | None = None) -> FlowResult:
|
||||||
self, user_input: dict[str, Any] | None = None
|
|
||||||
) -> FlowResult:
|
|
||||||
"""Handle a flow initialized by the user."""
|
"""Handle a flow initialized by the user."""
|
||||||
if self._async_current_entries():
|
if self._async_current_entries():
|
||||||
return self.async_abort(reason="single_instance_allowed")
|
return self.async_abort(reason="single_instance_allowed")
|
||||||
|
@ -66,35 +66,38 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
||||||
return await self.async_step_broker()
|
return await self.async_step_broker()
|
||||||
|
|
||||||
async def async_step_broker(
|
async def async_step_broker(
|
||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: ConfigType | None = None
|
||||||
) -> FlowResult:
|
) -> FlowResult:
|
||||||
"""Confirm the setup."""
|
"""Confirm the setup."""
|
||||||
errors = {}
|
yaml_config: ConfigType = get_mqtt_data(self.hass, True).config or {}
|
||||||
|
errors: dict[str, str] = {}
|
||||||
if user_input is not None:
|
fields: OrderedDict[Any, Any] = OrderedDict()
|
||||||
|
validated_user_input: ConfigType = {}
|
||||||
|
if await async_get_broker_settings(
|
||||||
|
self.hass,
|
||||||
|
fields,
|
||||||
|
yaml_config,
|
||||||
|
None,
|
||||||
|
user_input,
|
||||||
|
validated_user_input,
|
||||||
|
errors,
|
||||||
|
):
|
||||||
|
test_config: ConfigType = yaml_config.copy()
|
||||||
|
test_config.update(validated_user_input)
|
||||||
can_connect = await self.hass.async_add_executor_job(
|
can_connect = await self.hass.async_add_executor_job(
|
||||||
try_connection,
|
try_connection,
|
||||||
get_mqtt_data(self.hass, True).config or {},
|
test_config,
|
||||||
user_input[CONF_BROKER],
|
|
||||||
user_input[CONF_PORT],
|
|
||||||
user_input.get(CONF_USERNAME),
|
|
||||||
user_input.get(CONF_PASSWORD),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if can_connect:
|
if can_connect:
|
||||||
user_input[CONF_DISCOVERY] = DEFAULT_DISCOVERY
|
validated_user_input[CONF_DISCOVERY] = DEFAULT_DISCOVERY
|
||||||
return self.async_create_entry(
|
return self.async_create_entry(
|
||||||
title=user_input[CONF_BROKER], data=user_input
|
title=validated_user_input[CONF_BROKER],
|
||||||
|
data=validated_user_input,
|
||||||
)
|
)
|
||||||
|
|
||||||
errors["base"] = "cannot_connect"
|
errors["base"] = "cannot_connect"
|
||||||
|
|
||||||
fields = OrderedDict()
|
|
||||||
fields[vol.Required(CONF_BROKER)] = str
|
|
||||||
fields[vol.Required(CONF_PORT, default=1883)] = vol.Coerce(int)
|
|
||||||
fields[vol.Optional(CONF_USERNAME)] = str
|
|
||||||
fields[vol.Optional(CONF_PASSWORD)] = str
|
|
||||||
|
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
step_id="broker", data_schema=vol.Schema(fields), errors=errors
|
step_id="broker", data_schema=vol.Schema(fields), errors=errors
|
||||||
)
|
)
|
||||||
|
@ -111,26 +114,22 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: dict[str, Any] | None = None
|
||||||
) -> FlowResult:
|
) -> FlowResult:
|
||||||
"""Confirm a Hass.io discovery."""
|
"""Confirm a Hass.io discovery."""
|
||||||
errors = {}
|
errors: dict[str, str] = {}
|
||||||
assert self._hassio_discovery
|
assert self._hassio_discovery
|
||||||
|
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
data = self._hassio_discovery
|
data: ConfigType = self._hassio_discovery.copy()
|
||||||
|
data[CONF_BROKER] = data.pop(CONF_HOST)
|
||||||
can_connect = await self.hass.async_add_executor_job(
|
can_connect = await self.hass.async_add_executor_job(
|
||||||
try_connection,
|
try_connection,
|
||||||
get_mqtt_data(self.hass, True).config or {},
|
data,
|
||||||
data[CONF_HOST],
|
|
||||||
data[CONF_PORT],
|
|
||||||
data.get(CONF_USERNAME),
|
|
||||||
data.get(CONF_PASSWORD),
|
|
||||||
data.get(CONF_PROTOCOL),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if can_connect:
|
if can_connect:
|
||||||
return self.async_create_entry(
|
return self.async_create_entry(
|
||||||
title=data["addon"],
|
title=data["addon"],
|
||||||
data={
|
data={
|
||||||
CONF_BROKER: data[CONF_HOST],
|
CONF_BROKER: data[CONF_BROKER],
|
||||||
CONF_PORT: data[CONF_PORT],
|
CONF_PORT: data[CONF_PORT],
|
||||||
CONF_USERNAME: data.get(CONF_USERNAME),
|
CONF_USERNAME: data.get(CONF_USERNAME),
|
||||||
CONF_PASSWORD: data.get(CONF_PASSWORD),
|
CONF_PASSWORD: data.get(CONF_PASSWORD),
|
||||||
|
@ -164,46 +163,32 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
|
||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: dict[str, Any] | None = None
|
||||||
) -> FlowResult:
|
) -> FlowResult:
|
||||||
"""Manage the MQTT broker configuration."""
|
"""Manage the MQTT broker configuration."""
|
||||||
mqtt_data = get_mqtt_data(self.hass, True)
|
errors: dict[str, str] = {}
|
||||||
yaml_config = mqtt_data.config or {}
|
yaml_config: ConfigType = get_mqtt_data(self.hass, True).config or {}
|
||||||
errors = {}
|
fields: OrderedDict[Any, Any] = OrderedDict()
|
||||||
current_config = self.config_entry.data
|
validated_user_input: ConfigType = {}
|
||||||
if user_input is not None:
|
if await async_get_broker_settings(
|
||||||
|
self.hass,
|
||||||
|
fields,
|
||||||
|
yaml_config,
|
||||||
|
self.config_entry.data,
|
||||||
|
user_input,
|
||||||
|
validated_user_input,
|
||||||
|
errors,
|
||||||
|
):
|
||||||
|
test_config: ConfigType = yaml_config.copy()
|
||||||
|
test_config.update(validated_user_input)
|
||||||
can_connect = await self.hass.async_add_executor_job(
|
can_connect = await self.hass.async_add_executor_job(
|
||||||
try_connection,
|
try_connection,
|
||||||
yaml_config,
|
test_config,
|
||||||
user_input[CONF_BROKER],
|
|
||||||
user_input[CONF_PORT],
|
|
||||||
user_input.get(CONF_USERNAME),
|
|
||||||
user_input.get(CONF_PASSWORD),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if can_connect:
|
if can_connect:
|
||||||
self.broker_config.update(user_input)
|
self.broker_config.update(validated_user_input)
|
||||||
return await self.async_step_options()
|
return await self.async_step_options()
|
||||||
|
|
||||||
errors["base"] = "cannot_connect"
|
errors["base"] = "cannot_connect"
|
||||||
|
|
||||||
fields = OrderedDict()
|
|
||||||
current_broker = current_config.get(CONF_BROKER, yaml_config.get(CONF_BROKER))
|
|
||||||
current_port = current_config.get(CONF_PORT, yaml_config.get(CONF_PORT))
|
|
||||||
current_user = current_config.get(CONF_USERNAME, yaml_config.get(CONF_USERNAME))
|
|
||||||
current_pass = current_config.get(CONF_PASSWORD, yaml_config.get(CONF_PASSWORD))
|
|
||||||
fields[vol.Required(CONF_BROKER, default=current_broker)] = str
|
|
||||||
fields[vol.Required(CONF_PORT, default=current_port)] = vol.Coerce(int)
|
|
||||||
fields[
|
|
||||||
vol.Optional(
|
|
||||||
CONF_USERNAME,
|
|
||||||
description={"suggested_value": current_user},
|
|
||||||
)
|
|
||||||
] = str
|
|
||||||
fields[
|
|
||||||
vol.Optional(
|
|
||||||
CONF_PASSWORD,
|
|
||||||
description={"suggested_value": current_pass},
|
|
||||||
)
|
|
||||||
] = str
|
|
||||||
|
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
step_id="broker",
|
step_id="broker",
|
||||||
data_schema=vol.Schema(fields),
|
data_schema=vol.Schema(fields),
|
||||||
|
@ -212,53 +197,61 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_step_options(
|
async def async_step_options(
|
||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: ConfigType | None = None
|
||||||
) -> FlowResult:
|
) -> FlowResult:
|
||||||
"""Manage the MQTT options."""
|
"""Manage the MQTT options."""
|
||||||
mqtt_data = get_mqtt_data(self.hass, True)
|
|
||||||
errors = {}
|
errors = {}
|
||||||
current_config = self.config_entry.data
|
current_config = self.config_entry.data
|
||||||
yaml_config = mqtt_data.config or {}
|
yaml_config = get_mqtt_data(self.hass, True).config or {}
|
||||||
options_config: dict[str, Any] = {}
|
options_config: ConfigType = {}
|
||||||
if user_input is not None:
|
bad_input: bool = False
|
||||||
bad_birth = False
|
|
||||||
bad_will = False
|
|
||||||
|
|
||||||
if "birth_topic" in user_input:
|
def _birth_will(birt_or_will: str) -> dict:
|
||||||
birth_message = {
|
"""Return the user input for birth or will."""
|
||||||
ATTR_TOPIC: user_input["birth_topic"],
|
assert user_input
|
||||||
ATTR_PAYLOAD: user_input.get("birth_payload", ""),
|
return {
|
||||||
ATTR_QOS: user_input["birth_qos"],
|
ATTR_TOPIC: user_input[f"{birt_or_will}_topic"],
|
||||||
ATTR_RETAIN: user_input["birth_retain"],
|
ATTR_PAYLOAD: user_input.get(f"{birt_or_will}_payload", ""),
|
||||||
|
ATTR_QOS: user_input[f"{birt_or_will}_qos"],
|
||||||
|
ATTR_RETAIN: user_input[f"{birt_or_will}_retain"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _validate(
|
||||||
|
field: str, values: ConfigType, error_code: str, schema: Callable
|
||||||
|
):
|
||||||
|
"""Validate the user input."""
|
||||||
|
nonlocal bad_input
|
||||||
try:
|
try:
|
||||||
birth_message = MQTT_WILL_BIRTH_SCHEMA(birth_message)
|
option_values = schema(values)
|
||||||
options_config[CONF_BIRTH_MESSAGE] = birth_message
|
options_config[field] = option_values
|
||||||
except vol.Invalid:
|
except vol.Invalid:
|
||||||
errors["base"] = "bad_birth"
|
errors["base"] = error_code
|
||||||
bad_birth = True
|
bad_input = True
|
||||||
|
|
||||||
|
if user_input is not None:
|
||||||
|
# validate input
|
||||||
|
options_config[CONF_DISCOVERY] = user_input[CONF_DISCOVERY]
|
||||||
|
if "birth_topic" in user_input:
|
||||||
|
_validate(
|
||||||
|
CONF_BIRTH_MESSAGE,
|
||||||
|
_birth_will("birth"),
|
||||||
|
"bad_birth",
|
||||||
|
MQTT_WILL_BIRTH_SCHEMA,
|
||||||
|
)
|
||||||
if not user_input["birth_enable"]:
|
if not user_input["birth_enable"]:
|
||||||
options_config[CONF_BIRTH_MESSAGE] = {}
|
options_config[CONF_BIRTH_MESSAGE] = {}
|
||||||
|
|
||||||
if "will_topic" in user_input:
|
if "will_topic" in user_input:
|
||||||
will_message = {
|
_validate(
|
||||||
ATTR_TOPIC: user_input["will_topic"],
|
CONF_WILL_MESSAGE,
|
||||||
ATTR_PAYLOAD: user_input.get("will_payload", ""),
|
_birth_will("will"),
|
||||||
ATTR_QOS: user_input["will_qos"],
|
"bad_will",
|
||||||
ATTR_RETAIN: user_input["will_retain"],
|
MQTT_WILL_BIRTH_SCHEMA,
|
||||||
}
|
)
|
||||||
try:
|
|
||||||
will_message = MQTT_WILL_BIRTH_SCHEMA(will_message)
|
|
||||||
options_config[CONF_WILL_MESSAGE] = will_message
|
|
||||||
except vol.Invalid:
|
|
||||||
errors["base"] = "bad_will"
|
|
||||||
bad_will = True
|
|
||||||
if not user_input["will_enable"]:
|
if not user_input["will_enable"]:
|
||||||
options_config[CONF_WILL_MESSAGE] = {}
|
options_config[CONF_WILL_MESSAGE] = {}
|
||||||
|
|
||||||
options_config[CONF_DISCOVERY] = user_input[CONF_DISCOVERY]
|
if not bad_input:
|
||||||
|
|
||||||
if not bad_birth and not bad_will:
|
|
||||||
updated_config = {}
|
updated_config = {}
|
||||||
updated_config.update(self.broker_config)
|
updated_config.update(self.broker_config)
|
||||||
updated_config.update(options_config)
|
updated_config.update(options_config)
|
||||||
|
@ -285,6 +278,7 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
|
||||||
CONF_DISCOVERY, yaml_config.get(CONF_DISCOVERY, DEFAULT_DISCOVERY)
|
CONF_DISCOVERY, yaml_config.get(CONF_DISCOVERY, DEFAULT_DISCOVERY)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# build form
|
||||||
fields: OrderedDict[vol.Marker, Any] = OrderedDict()
|
fields: OrderedDict[vol.Marker, Any] = OrderedDict()
|
||||||
fields[vol.Optional(CONF_DISCOVERY, default=discovery)] = bool
|
fields[vol.Optional(CONF_DISCOVERY, default=discovery)] = bool
|
||||||
|
|
||||||
|
@ -338,28 +332,66 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def try_connection(
|
async def async_get_broker_settings(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
fields: OrderedDict[Any, Any],
|
||||||
yaml_config: ConfigType,
|
yaml_config: ConfigType,
|
||||||
broker: str,
|
entry_config: MappingProxyType[str, Any] | None,
|
||||||
port: int,
|
user_input: ConfigType | None,
|
||||||
username: str | None,
|
validated_user_input: ConfigType,
|
||||||
password: str | None,
|
errors: dict[str, str],
|
||||||
protocol: str = "3.1",
|
) -> bool:
|
||||||
|
"""Build the config flow schema to collect the broker settings.
|
||||||
|
|
||||||
|
Returns True when settings are collected successfully.
|
||||||
|
"""
|
||||||
|
user_input_basic: ConfigType = ConfigType()
|
||||||
|
current_config = entry_config.copy() if entry_config is not None else ConfigType()
|
||||||
|
|
||||||
|
if user_input is not None:
|
||||||
|
validated_user_input.update(user_input)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Update the current settings the the new posted data to fill the defaults
|
||||||
|
current_config.update(user_input_basic)
|
||||||
|
|
||||||
|
# Get default settings (if any)
|
||||||
|
current_broker = current_config.get(CONF_BROKER, yaml_config.get(CONF_BROKER))
|
||||||
|
current_port = current_config.get(
|
||||||
|
CONF_PORT, yaml_config.get(CONF_PORT, DEFAULT_PORT)
|
||||||
|
)
|
||||||
|
current_user = current_config.get(CONF_USERNAME, yaml_config.get(CONF_USERNAME))
|
||||||
|
current_pass = current_config.get(CONF_PASSWORD, yaml_config.get(CONF_PASSWORD))
|
||||||
|
|
||||||
|
# Build form
|
||||||
|
fields[vol.Required(CONF_BROKER, default=current_broker)] = str
|
||||||
|
fields[vol.Required(CONF_PORT, default=current_port)] = vol.Coerce(int)
|
||||||
|
fields[
|
||||||
|
vol.Optional(
|
||||||
|
CONF_USERNAME,
|
||||||
|
description={"suggested_value": current_user},
|
||||||
|
)
|
||||||
|
] = str
|
||||||
|
fields[
|
||||||
|
vol.Optional(
|
||||||
|
CONF_PASSWORD,
|
||||||
|
description={"suggested_value": current_pass},
|
||||||
|
)
|
||||||
|
] = str
|
||||||
|
|
||||||
|
# Show form
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def try_connection(
|
||||||
|
user_input: ConfigType,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Test if we can connect to an MQTT broker."""
|
"""Test if we can connect to an MQTT broker."""
|
||||||
# We don't import on the top because some integrations
|
# We don't import on the top because some integrations
|
||||||
# should be able to optionally rely on MQTT.
|
# should be able to optionally rely on MQTT.
|
||||||
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
|
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
# Get the config from configuration.yaml
|
client = MqttClientSetup(user_input).client
|
||||||
entry_config = {
|
|
||||||
CONF_BROKER: broker,
|
|
||||||
CONF_PORT: port,
|
|
||||||
CONF_USERNAME: username,
|
|
||||||
CONF_PASSWORD: password,
|
|
||||||
CONF_PROTOCOL: protocol,
|
|
||||||
}
|
|
||||||
client = MqttClientSetup({**yaml_config, **entry_config}).client
|
|
||||||
|
|
||||||
result: queue.Queue[bool] = queue.Queue(maxsize=1)
|
result: queue.Queue[bool] = queue.Queue(maxsize=1)
|
||||||
|
|
||||||
|
@ -369,7 +401,7 @@ def try_connection(
|
||||||
|
|
||||||
client.on_connect = on_connect
|
client.on_connect = on_connect
|
||||||
|
|
||||||
client.connect_async(broker, port)
|
client.connect_async(user_input[CONF_BROKER], user_input[CONF_PORT])
|
||||||
client.loop_start()
|
client.loop_start()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -40,6 +40,7 @@ DEFAULT_ENCODING = "utf-8"
|
||||||
DEFAULT_QOS = 0
|
DEFAULT_QOS = 0
|
||||||
DEFAULT_PAYLOAD_AVAILABLE = "online"
|
DEFAULT_PAYLOAD_AVAILABLE = "online"
|
||||||
DEFAULT_PAYLOAD_NOT_AVAILABLE = "offline"
|
DEFAULT_PAYLOAD_NOT_AVAILABLE = "offline"
|
||||||
|
DEFAULT_PORT = 1883
|
||||||
DEFAULT_RETAIN = False
|
DEFAULT_RETAIN = False
|
||||||
|
|
||||||
DEFAULT_BIRTH = {
|
DEFAULT_BIRTH = {
|
||||||
|
@ -67,6 +68,8 @@ PAYLOAD_NONE = "None"
|
||||||
PROTOCOL_31 = "3.1"
|
PROTOCOL_31 = "3.1"
|
||||||
PROTOCOL_311 = "3.1.1"
|
PROTOCOL_311 = "3.1.1"
|
||||||
|
|
||||||
|
DEFAULT_PROTOCOL = PROTOCOL_311
|
||||||
|
|
||||||
PLATFORMS = [
|
PLATFORMS = [
|
||||||
Platform.ALARM_CONTROL_PANEL,
|
Platform.ALARM_CONTROL_PANEL,
|
||||||
Platform.BINARY_SENSOR,
|
Platform.BINARY_SENSOR,
|
||||||
|
|
|
@ -188,15 +188,12 @@ async def test_manual_config_set(
|
||||||
# Check we tried the connection, with precedence for config entry settings
|
# Check we tried the connection, with precedence for config entry settings
|
||||||
mock_try_connection.assert_called_once_with(
|
mock_try_connection.assert_called_once_with(
|
||||||
{
|
{
|
||||||
"broker": "bla",
|
"broker": "127.0.0.1",
|
||||||
|
"protocol": "3.1.1",
|
||||||
"keepalive": 60,
|
"keepalive": 60,
|
||||||
"discovery_prefix": "homeassistant",
|
"discovery_prefix": "homeassistant",
|
||||||
"protocol": "3.1.1",
|
"port": 1883,
|
||||||
},
|
},
|
||||||
"127.0.0.1",
|
|
||||||
1883,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
)
|
)
|
||||||
# Check config entry got setup
|
# Check config entry got setup
|
||||||
assert len(mock_finish_setup.mock_calls) == 1
|
assert len(mock_finish_setup.mock_calls) == 1
|
||||||
|
@ -291,6 +288,44 @@ async def test_hassio_confirm(hass, mock_try_connection_success, mock_finish_set
|
||||||
assert len(mock_finish_setup.mock_calls) == 1
|
assert len(mock_finish_setup.mock_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_hassio_cannot_connect(
|
||||||
|
hass, mock_try_connection_time_out, mock_finish_setup
|
||||||
|
):
|
||||||
|
"""Test a config flow is aborted when a connection was not successful."""
|
||||||
|
mock_try_connection.return_value = True
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
"mqtt",
|
||||||
|
data=HassioServiceInfo(
|
||||||
|
config={
|
||||||
|
"addon": "Mock Addon",
|
||||||
|
"host": "mock-broker",
|
||||||
|
"port": 1883,
|
||||||
|
"username": "mock-user",
|
||||||
|
"password": "mock-pass",
|
||||||
|
"protocol": "3.1.1", # Set by the addon's discovery, ignored by HA
|
||||||
|
"ssl": False, # Set by the addon's discovery, ignored by HA
|
||||||
|
}
|
||||||
|
),
|
||||||
|
context={"source": config_entries.SOURCE_HASSIO},
|
||||||
|
)
|
||||||
|
assert result["type"] == "form"
|
||||||
|
assert result["step_id"] == "hassio_confirm"
|
||||||
|
assert result["description_placeholders"] == {"addon": "Mock Addon"}
|
||||||
|
|
||||||
|
mock_try_connection_time_out.reset_mock()
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"], {"discovery": True}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] == "form"
|
||||||
|
assert result["errors"]["base"] == "cannot_connect"
|
||||||
|
# Check we tried the connection
|
||||||
|
assert len(mock_try_connection_time_out.mock_calls)
|
||||||
|
# Check config entry got setup
|
||||||
|
assert len(mock_finish_setup.mock_calls) == 0
|
||||||
|
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
"homeassistant.config.async_hass_config_yaml",
|
"homeassistant.config.async_hass_config_yaml",
|
||||||
AsyncMock(return_value={}),
|
AsyncMock(return_value={}),
|
||||||
|
@ -299,7 +334,7 @@ async def test_option_flow(
|
||||||
hass,
|
hass,
|
||||||
mqtt_mock_entry_no_yaml_config,
|
mqtt_mock_entry_no_yaml_config,
|
||||||
mock_try_connection,
|
mock_try_connection,
|
||||||
mock_reload_after_entry_update,
|
caplog,
|
||||||
):
|
):
|
||||||
"""Test config flow options."""
|
"""Test config flow options."""
|
||||||
mqtt_mock = await mqtt_mock_entry_no_yaml_config()
|
mqtt_mock = await mqtt_mock_entry_no_yaml_config()
|
||||||
|
@ -372,7 +407,10 @@ async def test_option_flow(
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert config_entry.title == "another-broker"
|
assert config_entry.title == "another-broker"
|
||||||
# assert that the entry was reloaded with the new config
|
# assert that the entry was reloaded with the new config
|
||||||
assert mock_reload_after_entry_update.call_count == 1
|
assert (
|
||||||
|
"<Event call_service[L]: domain=mqtt, service=reload, service_data=>"
|
||||||
|
in caplog.text
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_disable_birth_will(
|
async def test_disable_birth_will(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue