Add Shelly gen2 authentication support (#69753)
This commit is contained in:
parent
7edbe66b26
commit
d4d819679c
2 changed files with 80 additions and 13 deletions
|
@ -143,6 +143,8 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
"""Handle the credentials step."""
|
||||
errors: dict[str, str] = {}
|
||||
if user_input is not None:
|
||||
if get_info_gen(self.info) == 2:
|
||||
user_input[CONF_USERNAME] = "admin"
|
||||
try:
|
||||
device_info = await validate_input(
|
||||
self.hass, self.host, self.info, user_input
|
||||
|
@ -152,8 +154,12 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
errors["base"] = "invalid_auth"
|
||||
else:
|
||||
errors["base"] = "cannot_connect"
|
||||
except aioshelly.exceptions.InvalidAuthError:
|
||||
errors["base"] = "invalid_auth"
|
||||
except HTTP_CONNECT_ERRORS:
|
||||
errors["base"] = "cannot_connect"
|
||||
except aioshelly.exceptions.JSONRPCError:
|
||||
errors["base"] = "cannot_connect"
|
||||
except Exception: # pylint: disable=broad-except
|
||||
LOGGER.exception("Unexpected exception")
|
||||
errors["base"] = "unknown"
|
||||
|
@ -171,15 +177,18 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
else:
|
||||
user_input = {}
|
||||
|
||||
schema = vol.Schema(
|
||||
{
|
||||
if get_info_gen(self.info) == 2:
|
||||
schema = {
|
||||
vol.Required(CONF_PASSWORD, default=user_input.get(CONF_PASSWORD)): str,
|
||||
}
|
||||
else:
|
||||
schema = {
|
||||
vol.Required(CONF_USERNAME, default=user_input.get(CONF_USERNAME)): str,
|
||||
vol.Required(CONF_PASSWORD, default=user_input.get(CONF_PASSWORD)): str,
|
||||
}
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="credentials", data_schema=schema, errors=errors
|
||||
step_id="credentials", data_schema=vol.Schema(schema), errors=errors
|
||||
)
|
||||
|
||||
async def async_step_zeroconf(
|
||||
|
|
|
@ -135,8 +135,16 @@ async def test_title_without_name(hass):
|
|||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_form_auth(hass):
|
||||
@pytest.mark.parametrize(
|
||||
"test_data",
|
||||
[
|
||||
(1, {"username": "test user", "password": "test1 password"}, "test user"),
|
||||
(2, {"password": "test2 password"}, "admin"),
|
||||
],
|
||||
)
|
||||
async def test_form_auth(hass, test_data):
|
||||
"""Test manual configuration if auth is required."""
|
||||
gen, user_input, username = test_data
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
@ -145,7 +153,7 @@ async def test_form_auth(hass):
|
|||
|
||||
with patch(
|
||||
"aioshelly.common.get_info",
|
||||
return_value={"mac": "test-mac", "type": "SHSW-1", "auth": True},
|
||||
return_value={"mac": "test-mac", "type": "SHSW-1", "auth": True, "gen": gen},
|
||||
):
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
|
@ -163,6 +171,15 @@ async def test_form_auth(hass):
|
|||
settings=MOCK_SETTINGS,
|
||||
)
|
||||
),
|
||||
), patch(
|
||||
"aioshelly.rpc_device.RpcDevice.create",
|
||||
new=AsyncMock(
|
||||
return_value=Mock(
|
||||
model="SHSW-1",
|
||||
config=MOCK_CONFIG,
|
||||
shutdown=AsyncMock(),
|
||||
)
|
||||
),
|
||||
), patch(
|
||||
"homeassistant.components.shelly.async_setup", return_value=True
|
||||
) as mock_setup, patch(
|
||||
|
@ -170,8 +187,7 @@ async def test_form_auth(hass):
|
|||
return_value=True,
|
||||
) as mock_setup_entry:
|
||||
result3 = await hass.config_entries.flow.async_configure(
|
||||
result2["flow_id"],
|
||||
{"username": "test username", "password": "test password"},
|
||||
result2["flow_id"], user_input
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
@ -181,9 +197,9 @@ async def test_form_auth(hass):
|
|||
"host": "1.1.1.1",
|
||||
"model": "SHSW-1",
|
||||
"sleep_period": 0,
|
||||
"gen": 1,
|
||||
"username": "test username",
|
||||
"password": "test password",
|
||||
"gen": gen,
|
||||
"username": username,
|
||||
"password": user_input["password"],
|
||||
}
|
||||
assert len(mock_setup.mock_calls) == 1
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
@ -345,8 +361,8 @@ async def test_form_firmware_unsupported(hass):
|
|||
(ValueError, "unknown"),
|
||||
],
|
||||
)
|
||||
async def test_form_auth_errors_test_connection(hass, error):
|
||||
"""Test we handle errors in authenticated devices."""
|
||||
async def test_form_auth_errors_test_connection_gen1(hass, error):
|
||||
"""Test we handle errors in Gen1 authenticated devices."""
|
||||
exc, base_error = error
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
|
@ -373,6 +389,48 @@ async def test_form_auth_errors_test_connection(hass, error):
|
|||
assert result3["errors"] == {"base": base_error}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"error",
|
||||
[
|
||||
(
|
||||
aioshelly.exceptions.JSONRPCError(code=400),
|
||||
"cannot_connect",
|
||||
),
|
||||
(
|
||||
aioshelly.exceptions.InvalidAuthError(code=401),
|
||||
"invalid_auth",
|
||||
),
|
||||
(asyncio.TimeoutError, "cannot_connect"),
|
||||
(ValueError, "unknown"),
|
||||
],
|
||||
)
|
||||
async def test_form_auth_errors_test_connection_gen2(hass, error):
|
||||
"""Test we handle errors in Gen2 authenticated devices."""
|
||||
exc, base_error = error
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"aioshelly.common.get_info",
|
||||
return_value={"mac": "test-mac", "auth": True, "gen": 2},
|
||||
):
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{"host": "1.1.1.1"},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"aioshelly.rpc_device.RpcDevice.create",
|
||||
new=AsyncMock(side_effect=exc),
|
||||
):
|
||||
result3 = await hass.config_entries.flow.async_configure(
|
||||
result2["flow_id"], {"password": "test password"}
|
||||
)
|
||||
assert result3["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result3["errors"] == {"base": base_error}
|
||||
|
||||
|
||||
async def test_zeroconf(hass):
|
||||
"""Test we get the form."""
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue