Add Shelly gen2 authentication support (#69753)

This commit is contained in:
Shay Levy 2022-04-14 00:30:03 +03:00 committed by GitHub
parent 7edbe66b26
commit d4d819679c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 80 additions and 13 deletions

View file

@ -143,6 +143,8 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Handle the credentials step."""
errors: dict[str, str] = {}
if user_input is not None:
if get_info_gen(self.info) == 2:
user_input[CONF_USERNAME] = "admin"
try:
device_info = await validate_input(
self.hass, self.host, self.info, user_input
@ -152,8 +154,12 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
errors["base"] = "invalid_auth"
else:
errors["base"] = "cannot_connect"
except aioshelly.exceptions.InvalidAuthError:
errors["base"] = "invalid_auth"
except HTTP_CONNECT_ERRORS:
errors["base"] = "cannot_connect"
except aioshelly.exceptions.JSONRPCError:
errors["base"] = "cannot_connect"
except Exception: # pylint: disable=broad-except
LOGGER.exception("Unexpected exception")
errors["base"] = "unknown"
@ -171,15 +177,18 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
else:
user_input = {}
schema = vol.Schema(
{
if get_info_gen(self.info) == 2:
schema = {
vol.Required(CONF_PASSWORD, default=user_input.get(CONF_PASSWORD)): str,
}
else:
schema = {
vol.Required(CONF_USERNAME, default=user_input.get(CONF_USERNAME)): str,
vol.Required(CONF_PASSWORD, default=user_input.get(CONF_PASSWORD)): str,
}
)
return self.async_show_form(
step_id="credentials", data_schema=schema, errors=errors
step_id="credentials", data_schema=vol.Schema(schema), errors=errors
)
async def async_step_zeroconf(

View file

@ -135,8 +135,16 @@ async def test_title_without_name(hass):
assert len(mock_setup_entry.mock_calls) == 1
async def test_form_auth(hass):
@pytest.mark.parametrize(
"test_data",
[
(1, {"username": "test user", "password": "test1 password"}, "test user"),
(2, {"password": "test2 password"}, "admin"),
],
)
async def test_form_auth(hass, test_data):
"""Test manual configuration if auth is required."""
gen, user_input, username = test_data
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
@ -145,7 +153,7 @@ async def test_form_auth(hass):
with patch(
"aioshelly.common.get_info",
return_value={"mac": "test-mac", "type": "SHSW-1", "auth": True},
return_value={"mac": "test-mac", "type": "SHSW-1", "auth": True, "gen": gen},
):
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
@ -163,6 +171,15 @@ async def test_form_auth(hass):
settings=MOCK_SETTINGS,
)
),
), patch(
"aioshelly.rpc_device.RpcDevice.create",
new=AsyncMock(
return_value=Mock(
model="SHSW-1",
config=MOCK_CONFIG,
shutdown=AsyncMock(),
)
),
), patch(
"homeassistant.components.shelly.async_setup", return_value=True
) as mock_setup, patch(
@ -170,8 +187,7 @@ async def test_form_auth(hass):
return_value=True,
) as mock_setup_entry:
result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"],
{"username": "test username", "password": "test password"},
result2["flow_id"], user_input
)
await hass.async_block_till_done()
@ -181,9 +197,9 @@ async def test_form_auth(hass):
"host": "1.1.1.1",
"model": "SHSW-1",
"sleep_period": 0,
"gen": 1,
"username": "test username",
"password": "test password",
"gen": gen,
"username": username,
"password": user_input["password"],
}
assert len(mock_setup.mock_calls) == 1
assert len(mock_setup_entry.mock_calls) == 1
@ -345,8 +361,8 @@ async def test_form_firmware_unsupported(hass):
(ValueError, "unknown"),
],
)
async def test_form_auth_errors_test_connection(hass, error):
"""Test we handle errors in authenticated devices."""
async def test_form_auth_errors_test_connection_gen1(hass, error):
"""Test we handle errors in Gen1 authenticated devices."""
exc, base_error = error
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
@ -373,6 +389,48 @@ async def test_form_auth_errors_test_connection(hass, error):
assert result3["errors"] == {"base": base_error}
@pytest.mark.parametrize(
"error",
[
(
aioshelly.exceptions.JSONRPCError(code=400),
"cannot_connect",
),
(
aioshelly.exceptions.InvalidAuthError(code=401),
"invalid_auth",
),
(asyncio.TimeoutError, "cannot_connect"),
(ValueError, "unknown"),
],
)
async def test_form_auth_errors_test_connection_gen2(hass, error):
"""Test we handle errors in Gen2 authenticated devices."""
exc, base_error = error
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
with patch(
"aioshelly.common.get_info",
return_value={"mac": "test-mac", "auth": True, "gen": 2},
):
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"host": "1.1.1.1"},
)
with patch(
"aioshelly.rpc_device.RpcDevice.create",
new=AsyncMock(side_effect=exc),
):
result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"], {"password": "test password"}
)
assert result3["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result3["errors"] == {"base": base_error}
async def test_zeroconf(hass):
"""Test we get the form."""