diff --git a/homeassistant/components/shelly/config_flow.py b/homeassistant/components/shelly/config_flow.py index 04213a28ad6..9311be1a49e 100644 --- a/homeassistant/components/shelly/config_flow.py +++ b/homeassistant/components/shelly/config_flow.py @@ -143,6 +143,8 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): """Handle the credentials step.""" errors: dict[str, str] = {} if user_input is not None: + if get_info_gen(self.info) == 2: + user_input[CONF_USERNAME] = "admin" try: device_info = await validate_input( self.hass, self.host, self.info, user_input @@ -152,8 +154,12 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): errors["base"] = "invalid_auth" else: errors["base"] = "cannot_connect" + except aioshelly.exceptions.InvalidAuthError: + errors["base"] = "invalid_auth" except HTTP_CONNECT_ERRORS: errors["base"] = "cannot_connect" + except aioshelly.exceptions.JSONRPCError: + errors["base"] = "cannot_connect" except Exception: # pylint: disable=broad-except LOGGER.exception("Unexpected exception") errors["base"] = "unknown" @@ -171,15 +177,18 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): else: user_input = {} - schema = vol.Schema( - { + if get_info_gen(self.info) == 2: + schema = { + vol.Required(CONF_PASSWORD, default=user_input.get(CONF_PASSWORD)): str, + } + else: + schema = { vol.Required(CONF_USERNAME, default=user_input.get(CONF_USERNAME)): str, vol.Required(CONF_PASSWORD, default=user_input.get(CONF_PASSWORD)): str, } - ) return self.async_show_form( - step_id="credentials", data_schema=schema, errors=errors + step_id="credentials", data_schema=vol.Schema(schema), errors=errors ) async def async_step_zeroconf( diff --git a/tests/components/shelly/test_config_flow.py b/tests/components/shelly/test_config_flow.py index 02e86ef03f8..1293fe92760 100644 --- a/tests/components/shelly/test_config_flow.py +++ b/tests/components/shelly/test_config_flow.py @@ -135,8 +135,16 @@ async def test_title_without_name(hass): assert len(mock_setup_entry.mock_calls) == 1 -async def test_form_auth(hass): +@pytest.mark.parametrize( + "test_data", + [ + (1, {"username": "test user", "password": "test1 password"}, "test user"), + (2, {"password": "test2 password"}, "admin"), + ], +) +async def test_form_auth(hass, test_data): """Test manual configuration if auth is required.""" + gen, user_input, username = test_data result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} ) @@ -145,7 +153,7 @@ async def test_form_auth(hass): with patch( "aioshelly.common.get_info", - return_value={"mac": "test-mac", "type": "SHSW-1", "auth": True}, + return_value={"mac": "test-mac", "type": "SHSW-1", "auth": True, "gen": gen}, ): result2 = await hass.config_entries.flow.async_configure( result["flow_id"], @@ -163,6 +171,15 @@ async def test_form_auth(hass): settings=MOCK_SETTINGS, ) ), + ), patch( + "aioshelly.rpc_device.RpcDevice.create", + new=AsyncMock( + return_value=Mock( + model="SHSW-1", + config=MOCK_CONFIG, + shutdown=AsyncMock(), + ) + ), ), patch( "homeassistant.components.shelly.async_setup", return_value=True ) as mock_setup, patch( @@ -170,8 +187,7 @@ async def test_form_auth(hass): return_value=True, ) as mock_setup_entry: result3 = await hass.config_entries.flow.async_configure( - result2["flow_id"], - {"username": "test username", "password": "test password"}, + result2["flow_id"], user_input ) await hass.async_block_till_done() @@ -181,9 +197,9 @@ async def test_form_auth(hass): "host": "1.1.1.1", "model": "SHSW-1", "sleep_period": 0, - "gen": 1, - "username": "test username", - "password": "test password", + "gen": gen, + "username": username, + "password": user_input["password"], } assert len(mock_setup.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1 @@ -345,8 +361,8 @@ async def test_form_firmware_unsupported(hass): (ValueError, "unknown"), ], ) -async def test_form_auth_errors_test_connection(hass, error): - """Test we handle errors in authenticated devices.""" +async def test_form_auth_errors_test_connection_gen1(hass, error): + """Test we handle errors in Gen1 authenticated devices.""" exc, base_error = error result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} @@ -373,6 +389,48 @@ async def test_form_auth_errors_test_connection(hass, error): assert result3["errors"] == {"base": base_error} +@pytest.mark.parametrize( + "error", + [ + ( + aioshelly.exceptions.JSONRPCError(code=400), + "cannot_connect", + ), + ( + aioshelly.exceptions.InvalidAuthError(code=401), + "invalid_auth", + ), + (asyncio.TimeoutError, "cannot_connect"), + (ValueError, "unknown"), + ], +) +async def test_form_auth_errors_test_connection_gen2(hass, error): + """Test we handle errors in Gen2 authenticated devices.""" + exc, base_error = error + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + + with patch( + "aioshelly.common.get_info", + return_value={"mac": "test-mac", "auth": True, "gen": 2}, + ): + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], + {"host": "1.1.1.1"}, + ) + + with patch( + "aioshelly.rpc_device.RpcDevice.create", + new=AsyncMock(side_effect=exc), + ): + result3 = await hass.config_entries.flow.async_configure( + result2["flow_id"], {"password": "test password"} + ) + assert result3["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result3["errors"] == {"base": base_error} + + async def test_zeroconf(hass): """Test we get the form."""