diff --git a/homeassistant/components/mysensors/config_flow.py b/homeassistant/components/mysensors/config_flow.py index 5409e3c9a85..04e95f1dad3 100644 --- a/homeassistant/components/mysensors/config_flow.py +++ b/homeassistant/components/mysensors/config_flow.py @@ -20,13 +20,13 @@ from homeassistant.components.mqtt import ( from homeassistant.config_entries import ConfigEntry from homeassistant.core import callback from homeassistant.data_entry_flow import FlowResult +from homeassistant.helpers import selector import homeassistant.helpers.config_validation as cv from .const import ( CONF_BAUD_RATE, CONF_DEVICE, CONF_GATEWAY_TYPE, - CONF_GATEWAY_TYPE_ALL, CONF_GATEWAY_TYPE_MQTT, CONF_GATEWAY_TYPE_SERIAL, CONF_GATEWAY_TYPE_TCP, @@ -45,6 +45,15 @@ DEFAULT_BAUD_RATE = 115200 DEFAULT_TCP_PORT = 5003 DEFAULT_VERSION = "1.4" +_PORT_SELECTOR = vol.All( + selector.NumberSelector( + selector.NumberSelectorConfig( + min=1, max=65535, mode=selector.NumberSelectorMode.BOX + ), + ), + vol.Coerce(int), +) + def is_persistence_file(value: str) -> str: """Validate that persistence file path ends in either .pickle or .json.""" @@ -119,51 +128,34 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): self, user_input: dict[str, str] | None = None ) -> FlowResult: """Create a config entry from frontend user input.""" - schema = {vol.Required(CONF_GATEWAY_TYPE): vol.In(CONF_GATEWAY_TYPE_ALL)} - schema = vol.Schema(schema) - errors = {} - - if user_input is not None: - gw_type = self._gw_type = user_input[CONF_GATEWAY_TYPE] - input_pass = user_input if CONF_DEVICE in user_input else None - if gw_type == CONF_GATEWAY_TYPE_MQTT: - # Naive check that doesn't consider config entry state. - if MQTT_DOMAIN in self.hass.config.components: - return await self.async_step_gw_mqtt(input_pass) - - errors["base"] = "mqtt_required" - if gw_type == CONF_GATEWAY_TYPE_TCP: - return await self.async_step_gw_tcp(input_pass) - if gw_type == CONF_GATEWAY_TYPE_SERIAL: - return await self.async_step_gw_serial(input_pass) - - return self.async_show_form(step_id="user", data_schema=schema, errors=errors) + return self.async_show_menu( + step_id="select_gateway_type", + menu_options=["gw_serial", "gw_tcp", "gw_mqtt"], + ) async def async_step_gw_serial( self, user_input: dict[str, Any] | None = None ) -> FlowResult: """Create config entry for a serial gateway.""" + gw_type = self._gw_type = CONF_GATEWAY_TYPE_SERIAL errors: dict[str, str] = {} + if user_input is not None: - errors.update( - await self.validate_common(CONF_GATEWAY_TYPE_SERIAL, errors, user_input) - ) + errors.update(await self.validate_common(gw_type, errors, user_input)) if not errors: return self._async_create_entry(user_input) user_input = user_input or {} - schema = _get_schema_common(user_input) - schema[ + schema = { + vol.Required( + CONF_DEVICE, default=user_input.get(CONF_DEVICE, "/dev/ttyACM0") + ): str, vol.Required( CONF_BAUD_RATE, default=user_input.get(CONF_BAUD_RATE, DEFAULT_BAUD_RATE), - ) - ] = cv.positive_int - schema[ - vol.Required( - CONF_DEVICE, default=user_input.get(CONF_DEVICE, "/dev/ttyACM0") - ) - ] = str + ): cv.positive_int, + } + schema.update(_get_schema_common(user_input)) schema = vol.Schema(schema) return self.async_show_form( @@ -174,30 +166,24 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): self, user_input: dict[str, Any] | None = None ) -> FlowResult: """Create a config entry for a tcp gateway.""" - errors = {} - if user_input is not None: - if CONF_TCP_PORT in user_input: - port: int = user_input[CONF_TCP_PORT] - if not (0 < port <= 65535): - errors[CONF_TCP_PORT] = "port_out_of_range" + gw_type = self._gw_type = CONF_GATEWAY_TYPE_TCP + errors: dict[str, str] = {} - errors.update( - await self.validate_common(CONF_GATEWAY_TYPE_TCP, errors, user_input) - ) + if user_input is not None: + errors.update(await self.validate_common(gw_type, errors, user_input)) if not errors: return self._async_create_entry(user_input) user_input = user_input or {} - schema = _get_schema_common(user_input) - schema[ - vol.Required(CONF_DEVICE, default=user_input.get(CONF_DEVICE, "127.0.0.1")) - ] = str - # Don't use cv.port as that would show a slider *facepalm* - schema[ + schema = { + vol.Required( + CONF_DEVICE, default=user_input.get(CONF_DEVICE, "127.0.0.1") + ): str, vol.Optional( CONF_TCP_PORT, default=user_input.get(CONF_TCP_PORT, DEFAULT_TCP_PORT) - ) - ] = vol.Coerce(int) + ): _PORT_SELECTOR, + } + schema.update(_get_schema_common(user_input)) schema = vol.Schema(schema) return self.async_show_form(step_id="gw_tcp", data_schema=schema, errors=errors) @@ -214,7 +200,13 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): self, user_input: dict[str, Any] | None = None ) -> FlowResult: """Create a config entry for a mqtt gateway.""" - errors = {} + # Naive check that doesn't consider config entry state. + if MQTT_DOMAIN not in self.hass.config.components: + return self.async_abort(reason="mqtt_required") + + gw_type = self._gw_type = CONF_GATEWAY_TYPE_MQTT + errors: dict[str, str] = {} + if user_input is not None: user_input[CONF_DEVICE] = MQTT_COMPONENT @@ -239,27 +231,21 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): elif self._check_topic_exists(user_input[CONF_TOPIC_OUT_PREFIX]): errors[CONF_TOPIC_OUT_PREFIX] = "duplicate_topic" - errors.update( - await self.validate_common(CONF_GATEWAY_TYPE_MQTT, errors, user_input) - ) + errors.update(await self.validate_common(gw_type, errors, user_input)) if not errors: return self._async_create_entry(user_input) user_input = user_input or {} - schema = _get_schema_common(user_input) - schema[ - vol.Required(CONF_RETAIN, default=user_input.get(CONF_RETAIN, True)) - ] = bool - schema[ + schema = { vol.Required( CONF_TOPIC_IN_PREFIX, default=user_input.get(CONF_TOPIC_IN_PREFIX, "") - ) - ] = str - schema[ + ): str, vol.Required( CONF_TOPIC_OUT_PREFIX, default=user_input.get(CONF_TOPIC_OUT_PREFIX, "") - ) - ] = str + ): str, + vol.Required(CONF_RETAIN, default=user_input.get(CONF_RETAIN, True)): bool, + } + schema.update(_get_schema_common(user_input)) schema = vol.Schema(schema) return self.async_show_form( diff --git a/homeassistant/components/mysensors/const.py b/homeassistant/components/mysensors/const.py index 32e2110dd95..42df81ae526 100644 --- a/homeassistant/components/mysensors/const.py +++ b/homeassistant/components/mysensors/const.py @@ -22,11 +22,6 @@ ConfGatewayType = Literal["Serial", "TCP", "MQTT"] CONF_GATEWAY_TYPE_SERIAL: ConfGatewayType = "Serial" CONF_GATEWAY_TYPE_TCP: ConfGatewayType = "TCP" CONF_GATEWAY_TYPE_MQTT: ConfGatewayType = "MQTT" -CONF_GATEWAY_TYPE_ALL: list[str] = [ - CONF_GATEWAY_TYPE_MQTT, - CONF_GATEWAY_TYPE_SERIAL, - CONF_GATEWAY_TYPE_TCP, -] DOMAIN: Final = "mysensors" MYSENSORS_GATEWAY_START_TASK: str = "mysensors_gateway_start_task_{}" diff --git a/homeassistant/components/mysensors/strings.json b/homeassistant/components/mysensors/strings.json index d7722e565cb..dc5dc76c7ae 100644 --- a/homeassistant/components/mysensors/strings.json +++ b/homeassistant/components/mysensors/strings.json @@ -7,6 +7,14 @@ }, "description": "Choose connection method to the gateway" }, + "select_gateway_type": { + "description": "Select which gateway to configure.", + "menu_options": { + "gw_mqtt": "Configure an MQTT gateway", + "gw_serial": "Configure a serial gateway", + "gw_tcp": "Configure a TCP gateway" + } + }, "gw_tcp": { "description": "Ethernet gateway setup", "data": { @@ -51,7 +59,6 @@ "invalid_serial": "Invalid serial port", "invalid_device": "Invalid device", "invalid_version": "Invalid MySensors version", - "mqtt_required": "The MQTT integration is not set up", "not_a_number": "Please enter a number", "port_out_of_range": "Port number must be at least 1 and at most 65535", "unknown": "[%key:common::config_flow::error::unknown%]" @@ -71,6 +78,7 @@ "invalid_serial": "Invalid serial port", "invalid_device": "Invalid device", "invalid_version": "Invalid MySensors version", + "mqtt_required": "The MQTT integration is not set up", "not_a_number": "Please enter a number", "port_out_of_range": "Port number must be at least 1 and at most 65535", "unknown": "[%key:common::config_flow::error::unknown%]" diff --git a/homeassistant/components/mysensors/translations/en.json b/homeassistant/components/mysensors/translations/en.json index 5ec81c22186..b85a28fb7d3 100644 --- a/homeassistant/components/mysensors/translations/en.json +++ b/homeassistant/components/mysensors/translations/en.json @@ -14,6 +14,7 @@ "invalid_serial": "Invalid serial port", "invalid_subscribe_topic": "Invalid subscribe topic", "invalid_version": "Invalid MySensors version", + "mqtt_required": "The MQTT integration is not set up", "not_a_number": "Please enter a number", "port_out_of_range": "Port number must be at least 1 and at most 65535", "same_topic": "Subscribe and publish topics are the same", @@ -33,7 +34,6 @@ "invalid_serial": "Invalid serial port", "invalid_subscribe_topic": "Invalid subscribe topic", "invalid_version": "Invalid MySensors version", - "mqtt_required": "The MQTT integration is not set up", "not_a_number": "Please enter a number", "port_out_of_range": "Port number must be at least 1 and at most 65535", "same_topic": "Subscribe and publish topics are the same", @@ -68,6 +68,14 @@ }, "description": "Ethernet gateway setup" }, + "select_gateway_type": { + "description": "Select which gateway to configure.", + "menu_options": { + "gw_mqtt": "Configure an MQTT gateway", + "gw_serial": "Configure a serial gateway", + "gw_tcp": "Configure a TCP gateway" + } + }, "user": { "data": { "gateway_type": "Gateway type" diff --git a/tests/components/mysensors/test_config_flow.py b/tests/components/mysensors/test_config_flow.py index e7808162043..e14059c4e4f 100644 --- a/tests/components/mysensors/test_config_flow.py +++ b/tests/components/mysensors/test_config_flow.py @@ -24,28 +24,32 @@ from homeassistant.components.mysensors.const import ( ConfGatewayType, ) from homeassistant.core import HomeAssistant -from homeassistant.data_entry_flow import FlowResult +from homeassistant.data_entry_flow import FlowResult, FlowResultType from tests.common import MockConfigEntry +GATEWAY_TYPE_TO_STEP = { + CONF_GATEWAY_TYPE_TCP: "gw_tcp", + CONF_GATEWAY_TYPE_SERIAL: "gw_serial", + CONF_GATEWAY_TYPE_MQTT: "gw_mqtt", +} + async def get_form( - hass: HomeAssistant, gatway_type: ConfGatewayType, expected_step_id: str + hass: HomeAssistant, gateway_type: ConfGatewayType, expected_step_id: str ) -> FlowResult: """Get a form for the given gateway type.""" - stepuser = await hass.config_entries.flow.async_init( + result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} ) - assert stepuser["type"] == "form" - assert not stepuser["errors"] + assert result["type"] == FlowResultType.MENU result = await hass.config_entries.flow.async_configure( - stepuser["flow_id"], - {CONF_GATEWAY_TYPE: gatway_type}, + result["flow_id"], {"next_step_id": GATEWAY_TYPE_TO_STEP[gateway_type]} ) await hass.async_block_till_done() - assert result["type"] == "form" + assert result["type"] == FlowResultType.FORM assert result["step_id"] == expected_step_id return result @@ -62,7 +66,7 @@ async def test_config_mqtt(hass: HomeAssistant, mqtt: None) -> None: "homeassistant.components.mysensors.async_setup_entry", return_value=True, ) as mock_setup_entry: - result2 = await hass.config_entries.flow.async_configure( + result = await hass.config_entries.flow.async_configure( flow_id, { CONF_RETAIN: True, @@ -73,11 +77,11 @@ async def test_config_mqtt(hass: HomeAssistant, mqtt: None) -> None: ) await hass.async_block_till_done() - if "errors" in result2: - assert not result2["errors"] - assert result2["type"] == "create_entry" - assert result2["title"] == "mqtt" - assert result2["data"] == { + if "errors" in result: + assert not result["errors"] + assert result["type"] == FlowResultType.CREATE_ENTRY + assert result["title"] == "mqtt" + assert result["data"] == { CONF_DEVICE: "mqtt", CONF_RETAIN: True, CONF_TOPIC_IN_PREFIX: "bla", @@ -91,20 +95,19 @@ async def test_config_mqtt(hass: HomeAssistant, mqtt: None) -> None: async def test_missing_mqtt(hass: HomeAssistant) -> None: """Test configuring a mqtt gateway without mqtt integration setup.""" - result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} ) - assert result["type"] == "form" - assert not result["errors"] + assert result["type"] == FlowResultType.MENU result = await hass.config_entries.flow.async_configure( result["flow_id"], - {CONF_GATEWAY_TYPE: CONF_GATEWAY_TYPE_MQTT}, + {"next_step_id": GATEWAY_TYPE_TO_STEP[CONF_GATEWAY_TYPE_MQTT]}, ) - assert result["step_id"] == "user" - assert result["type"] == "form" - assert result["errors"] == {"base": "mqtt_required"} + await hass.async_block_till_done() + + assert result["type"] == FlowResultType.ABORT + assert result["reason"] == "mqtt_required" async def test_config_serial(hass: HomeAssistant) -> None: @@ -123,7 +126,7 @@ async def test_config_serial(hass: HomeAssistant) -> None: "homeassistant.components.mysensors.async_setup_entry", return_value=True, ) as mock_setup_entry: - result2 = await hass.config_entries.flow.async_configure( + result = await hass.config_entries.flow.async_configure( flow_id, { CONF_BAUD_RATE: 115200, @@ -133,11 +136,11 @@ async def test_config_serial(hass: HomeAssistant) -> None: ) await hass.async_block_till_done() - if "errors" in result2: - assert not result2["errors"] - assert result2["type"] == "create_entry" - assert result2["title"] == "/dev/ttyACM0" - assert result2["data"] == { + if "errors" in result: + assert not result["errors"] + assert result["type"] == FlowResultType.CREATE_ENTRY + assert result["title"] == "/dev/ttyACM0" + assert result["data"] == { CONF_DEVICE: "/dev/ttyACM0", CONF_BAUD_RATE: 115200, CONF_VERSION: "2.4", @@ -160,7 +163,7 @@ async def test_config_tcp(hass: HomeAssistant) -> None: "homeassistant.components.mysensors.async_setup_entry", return_value=True, ) as mock_setup_entry: - result2 = await hass.config_entries.flow.async_configure( + result = await hass.config_entries.flow.async_configure( flow_id, { CONF_TCP_PORT: 5003, @@ -170,11 +173,11 @@ async def test_config_tcp(hass: HomeAssistant) -> None: ) await hass.async_block_till_done() - if "errors" in result2: - assert not result2["errors"] - assert result2["type"] == "create_entry" - assert result2["title"] == "127.0.0.1" - assert result2["data"] == { + if "errors" in result: + assert not result["errors"] + assert result["type"] == FlowResultType.CREATE_ENTRY + assert result["title"] == "127.0.0.1" + assert result["data"] == { CONF_DEVICE: "127.0.0.1", CONF_TCP_PORT: 5003, CONF_VERSION: "2.4", @@ -197,7 +200,7 @@ async def test_fail_to_connect(hass: HomeAssistant) -> None: "homeassistant.components.mysensors.async_setup_entry", return_value=True, ) as mock_setup_entry: - result2 = await hass.config_entries.flow.async_configure( + result = await hass.config_entries.flow.async_configure( flow_id, { CONF_TCP_PORT: 5003, @@ -207,9 +210,9 @@ async def test_fail_to_connect(hass: HomeAssistant) -> None: ) await hass.async_block_till_done() - assert result2["type"] == "form" - assert "errors" in result2 - errors = result2["errors"] + assert result["type"] == FlowResultType.FORM + assert "errors" in result + errors = result["errors"] assert errors assert errors.get("base") == "cannot_connect" assert len(mock_setup.mock_calls) == 0 @@ -219,28 +222,6 @@ async def test_fail_to_connect(hass: HomeAssistant) -> None: @pytest.mark.parametrize( "gateway_type, expected_step_id, user_input, err_field, err_string", [ - ( - CONF_GATEWAY_TYPE_TCP, - "gw_tcp", - { - CONF_TCP_PORT: 600_000, - CONF_DEVICE: "127.0.0.1", - CONF_VERSION: "2.4", - }, - CONF_TCP_PORT, - "port_out_of_range", - ), - ( - CONF_GATEWAY_TYPE_TCP, - "gw_tcp", - { - CONF_TCP_PORT: 0, - CONF_DEVICE: "127.0.0.1", - CONF_VERSION: "2.4", - }, - CONF_TCP_PORT, - "port_out_of_range", - ), ( CONF_GATEWAY_TYPE_TCP, "gw_tcp", @@ -382,15 +363,15 @@ async def test_config_invalid( "homeassistant.components.mysensors.async_setup_entry", return_value=True, ) as mock_setup_entry: - result2 = await hass.config_entries.flow.async_configure( + result = await hass.config_entries.flow.async_configure( flow_id, user_input, ) await hass.async_block_till_done() - assert result2["type"] == "form" - assert "errors" in result2 - errors = result2["errors"] + assert result["type"] == FlowResultType.FORM + assert "errors" in result + errors = result["errors"] assert errors assert err_field in errors assert errors[err_field] == err_string @@ -681,10 +662,8 @@ async def test_duplicate( MockConfigEntry(domain=DOMAIN, data=first_input).add_to_hass(hass) second_gateway_type = second_input.pop(CONF_GATEWAY_TYPE) - result = await hass.config_entries.flow.async_init( - DOMAIN, - data={CONF_GATEWAY_TYPE: second_gateway_type}, - context={"source": config_entries.SOURCE_USER}, + result = await get_form( + hass, second_gateway_type, GATEWAY_TYPE_TO_STEP[second_gateway_type] ) result = await hass.config_entries.flow.async_configure( result["flow_id"],