Add reauth support to NUT (#114131)

This commit is contained in:
J. Nick Koston 2024-03-25 07:59:46 -10:00 committed by GitHub
parent 135c40cad8
commit c3f4aca4e3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 295 additions and 71 deletions

View file

@ -7,7 +7,7 @@ from datetime import timedelta
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from aionut import AIONUTClient, NUTError from aionut import AIONUTClient, NUTError, NUTLoginError
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( from homeassistant.const import (
@ -21,7 +21,7 @@ from homeassistant.const import (
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_STOP,
) )
from homeassistant.core import Event, HomeAssistant, callback from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError
from homeassistant.helpers import device_registry as dr from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
@ -70,6 +70,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Fetch data from NUT.""" """Fetch data from NUT."""
try: try:
return await data.async_update() return await data.async_update()
except NUTLoginError as err:
raise ConfigEntryAuthFailed from err
except NUTError as err: except NUTError as err:
raise UpdateFailed(f"Error fetching UPS state: {err}") from err raise UpdateFailed(f"Error fetching UPS state: {err}") from err
@ -249,16 +251,9 @@ class PyNUTData:
async def _async_get_alias(self) -> str | None: async def _async_get_alias(self) -> str | None:
"""Get the ups alias from NUT.""" """Get the ups alias from NUT."""
try: if not (ups_list := await self._client.list_ups()):
ups_list = await self._client.list_ups()
except NUTError as err:
_LOGGER.error("Failure getting NUT ups alias, %s", err)
return None
if not ups_list:
_LOGGER.error("Empty list while getting NUT ups aliases") _LOGGER.error("Empty list while getting NUT ups aliases")
return None return None
self.ups_list = ups_list self.ups_list = ups_list
return list(ups_list)[0] return list(ups_list)[0]

View file

@ -6,7 +6,7 @@ from collections.abc import Mapping
import logging import logging
from typing import Any from typing import Any
from aionut import NUTError from aionut import NUTError, NUTLoginError
import voluptuous as vol import voluptuous as vol
from homeassistant.components import zeroconf from homeassistant.components import zeroconf
@ -26,28 +26,23 @@ from homeassistant.const import (
CONF_USERNAME, CONF_USERNAME,
) )
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.data_entry_flow import AbortFlow
from . import PyNUTData from . import PyNUTData
from .const import DEFAULT_HOST, DEFAULT_PORT, DEFAULT_SCAN_INTERVAL, DOMAIN from .const import DEFAULT_HOST, DEFAULT_PORT, DEFAULT_SCAN_INTERVAL, DOMAIN
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
AUTH_SCHEMA = {vol.Optional(CONF_USERNAME): str, vol.Optional(CONF_PASSWORD): str}
def _base_schema(discovery_info: zeroconf.ZeroconfServiceInfo | None) -> vol.Schema:
def _base_schema(nut_config: dict[str, Any]) -> vol.Schema:
"""Generate base schema.""" """Generate base schema."""
base_schema = {} base_schema = {
if not discovery_info: vol.Optional(CONF_HOST, default=nut_config.get(CONF_HOST) or DEFAULT_HOST): str,
base_schema.update( vol.Optional(CONF_PORT, default=nut_config.get(CONF_PORT) or DEFAULT_PORT): int,
{ }
vol.Optional(CONF_HOST, default=DEFAULT_HOST): str, base_schema.update(AUTH_SCHEMA)
vol.Optional(CONF_PORT, default=DEFAULT_PORT): int,
}
)
base_schema.update(
{vol.Optional(CONF_USERNAME): str, vol.Optional(CONF_PASSWORD): str}
)
return vol.Schema(base_schema) return vol.Schema(base_schema)
@ -69,13 +64,10 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> dict[str,
password = data.get(CONF_PASSWORD) password = data.get(CONF_PASSWORD)
nut_data = PyNUTData(host, port, alias, username, password, persistent=False) nut_data = PyNUTData(host, port, alias, username, password, persistent=False)
try: status = await nut_data.async_update()
status = await nut_data.async_update()
except NUTError as err:
raise CannotConnect(str(err)) from err
if not alias and not nut_data.ups_list: if not alias and not nut_data.ups_list:
raise CannotConnect("No UPSes found on the NUT server") raise AbortFlow("no_ups_found")
return {"ups_list": nut_data.ups_list, "available_resources": status} return {"ups_list": nut_data.ups_list, "available_resources": status}
@ -98,20 +90,20 @@ class NutConfigFlow(ConfigFlow, domain=DOMAIN):
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize the nut config flow.""" """Initialize the nut config flow."""
self.nut_config: dict[str, Any] = {} self.nut_config: dict[str, Any] = {}
self.discovery_info: zeroconf.ZeroconfServiceInfo | None = None
self.ups_list: dict[str, str] | None = None self.ups_list: dict[str, str] | None = None
self.title: str | None = None self.title: str | None = None
self.reauth_entry: ConfigEntry | None = None
async def async_step_zeroconf( async def async_step_zeroconf(
self, discovery_info: zeroconf.ZeroconfServiceInfo self, discovery_info: zeroconf.ZeroconfServiceInfo
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Prepare configuration for a discovered nut device.""" """Prepare configuration for a discovered nut device."""
self.discovery_info = discovery_info
await self._async_handle_discovery_without_unique_id() await self._async_handle_discovery_without_unique_id()
self.context["title_placeholders"] = { self.nut_config = {
CONF_HOST: discovery_info.host or DEFAULT_HOST,
CONF_PORT: discovery_info.port or DEFAULT_PORT, CONF_PORT: discovery_info.port or DEFAULT_PORT,
CONF_HOST: discovery_info.host,
} }
self.context["title_placeholders"] = self.nut_config.copy()
return await self.async_step_user() return await self.async_step_user()
async def async_step_user( async def async_step_user(
@ -119,29 +111,28 @@ class NutConfigFlow(ConfigFlow, domain=DOMAIN):
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Handle the user input.""" """Handle the user input."""
errors: dict[str, str] = {} errors: dict[str, str] = {}
placeholders: dict[str, str] = {}
nut_config = self.nut_config
if user_input is not None: if user_input is not None:
if self.discovery_info: nut_config.update(user_input)
user_input.update(
{ info, errors, placeholders = await self._async_validate_or_error(nut_config)
CONF_HOST: self.discovery_info.host,
CONF_PORT: self.discovery_info.port or DEFAULT_PORT,
}
)
info, errors = await self._async_validate_or_error(user_input)
if not errors: if not errors:
self.nut_config.update(user_input)
if len(info["ups_list"]) > 1: if len(info["ups_list"]) > 1:
self.ups_list = info["ups_list"] self.ups_list = info["ups_list"]
return await self.async_step_ups() return await self.async_step_ups()
if self._host_port_alias_already_configured(self.nut_config): if self._host_port_alias_already_configured(nut_config):
return self.async_abort(reason="already_configured") return self.async_abort(reason="already_configured")
title = _format_host_port_alias(self.nut_config) title = _format_host_port_alias(nut_config)
return self.async_create_entry(title=title, data=self.nut_config) return self.async_create_entry(title=title, data=nut_config)
return self.async_show_form( return self.async_show_form(
step_id="user", data_schema=_base_schema(self.discovery_info), errors=errors step_id="user",
data_schema=_base_schema(nut_config),
errors=errors,
description_placeholders=placeholders,
) )
async def async_step_ups( async def async_step_ups(
@ -149,20 +140,23 @@ class NutConfigFlow(ConfigFlow, domain=DOMAIN):
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Handle the picking the ups.""" """Handle the picking the ups."""
errors: dict[str, str] = {} errors: dict[str, str] = {}
placeholders: dict[str, str] = {}
nut_config = self.nut_config
if user_input is not None: if user_input is not None:
self.nut_config.update(user_input) self.nut_config.update(user_input)
if self._host_port_alias_already_configured(self.nut_config): if self._host_port_alias_already_configured(nut_config):
return self.async_abort(reason="already_configured") return self.async_abort(reason="already_configured")
_, errors = await self._async_validate_or_error(self.nut_config) _, errors, placeholders = await self._async_validate_or_error(nut_config)
if not errors: if not errors:
title = _format_host_port_alias(self.nut_config) title = _format_host_port_alias(nut_config)
return self.async_create_entry(title=title, data=self.nut_config) return self.async_create_entry(title=title, data=nut_config)
return self.async_show_form( return self.async_show_form(
step_id="ups", step_id="ups",
data_schema=_ups_schema(self.ups_list or {}), data_schema=_ups_schema(self.ups_list or {}),
errors=errors, errors=errors,
description_placeholders=placeholders,
) )
def _host_port_alias_already_configured(self, user_input: dict[str, Any]) -> bool: def _host_port_alias_already_configured(self, user_input: dict[str, Any]) -> bool:
@ -176,17 +170,66 @@ class NutConfigFlow(ConfigFlow, domain=DOMAIN):
async def _async_validate_or_error( async def _async_validate_or_error(
self, config: dict[str, Any] self, config: dict[str, Any]
) -> tuple[dict[str, Any], dict[str, str]]: ) -> tuple[dict[str, Any], dict[str, str], dict[str, str]]:
errors = {} errors: dict[str, str] = {}
info = {} info: dict[str, Any] = {}
description_placeholders: dict[str, str] = {}
try: try:
info = await validate_input(self.hass, config) info = await validate_input(self.hass, config)
except CannotConnect: except NUTLoginError:
errors[CONF_PASSWORD] = "invalid_auth"
except NUTError as ex:
errors[CONF_BASE] = "cannot_connect" errors[CONF_BASE] = "cannot_connect"
description_placeholders["error"] = str(ex)
except AbortFlow:
raise
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception") _LOGGER.exception("Unexpected exception")
errors[CONF_BASE] = "unknown" errors[CONF_BASE] = "unknown"
return info, errors return info, errors, description_placeholders
async def async_step_reauth(
self, entry_data: Mapping[str, Any]
) -> ConfigFlowResult:
"""Handle reauth."""
entry_id = self.context["entry_id"]
self.reauth_entry = self.hass.config_entries.async_get_entry(entry_id)
return await self.async_step_reauth_confirm()
async def async_step_reauth_confirm(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle reauth input."""
errors: dict[str, str] = {}
existing_entry = self.reauth_entry
assert existing_entry
existing_data = existing_entry.data
description_placeholders: dict[str, str] = {
CONF_HOST: existing_data[CONF_HOST],
CONF_PORT: existing_data[CONF_PORT],
}
if user_input is not None:
new_config = {
**existing_data,
# Username/password are optional and some servers
# use ip based authentication and will fail if
# username/password are provided
CONF_USERNAME: user_input.get(CONF_USERNAME),
CONF_PASSWORD: user_input.get(CONF_PASSWORD),
}
_, errors, placeholders = await self._async_validate_or_error(new_config)
if not errors:
return self.async_update_reload_and_abort(
existing_entry, data=new_config
)
description_placeholders.update(placeholders)
return self.async_show_form(
description_placeholders=description_placeholders,
step_id="reauth_confirm",
data_schema=vol.Schema(AUTH_SCHEMA),
errors=errors,
)
@staticmethod @staticmethod
@callback @callback
@ -220,7 +263,3 @@ class OptionsFlowHandler(OptionsFlow):
} }
return self.async_show_form(step_id="init", data_schema=vol.Schema(base_schema)) return self.async_show_form(step_id="init", data_schema=vol.Schema(base_schema))
class CannotConnect(HomeAssistantError):
"""Error to indicate we cannot connect."""

View file

@ -18,14 +18,24 @@
"data": { "data": {
"alias": "Alias" "alias": "Alias"
} }
},
"reauth_confirm": {
"description": "Re-authenticate NUT server at {host}:{port}",
"data": {
"username": "[%key:common::config_flow::data::username%]",
"password": "[%key:common::config_flow::data::password%]"
}
} }
}, },
"error": { "error": {
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", "cannot_connect": "Connection error: {error}",
"unknown": "[%key:common::config_flow::error::unknown%]" "unknown": "[%key:common::config_flow::error::unknown%]",
"invalid_auth": "[%key:common::config_flow::error::invalid_auth%]"
}, },
"abort": { "abort": {
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]" "already_configured": "[%key:common::config_flow::abort::already_configured_device%]",
"no_ups_found": "There are no UPS devices available on the NUT server.",
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]"
} }
}, },
"options": { "options": {

View file

@ -3,7 +3,7 @@
from ipaddress import ip_address from ipaddress import ip_address
from unittest.mock import patch from unittest.mock import patch
from aionut import NUTError from aionut import NUTError, NUTLoginError
from homeassistant import config_entries, data_entry_flow, setup from homeassistant import config_entries, data_entry_flow, setup
from homeassistant.components import zeroconf from homeassistant.components import zeroconf
@ -232,8 +232,8 @@ async def test_form_user_one_ups_with_ignored_entry(hass: HomeAssistant) -> None
assert len(mock_setup_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1
async def test_form_cannot_connect(hass: HomeAssistant) -> None: async def test_form_no_upses_found(hass: HomeAssistant) -> None:
"""Test we handle cannot connect error.""" """Test we abort when the NUT server has not UPSes."""
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER} DOMAIN, context={"source": config_entries.SOURCE_USER}
) )
@ -254,15 +254,22 @@ async def test_form_cannot_connect(hass: HomeAssistant) -> None:
}, },
) )
assert result2["type"] == data_entry_flow.FlowResultType.FORM assert result2["type"] is data_entry_flow.FlowResultType.ABORT
assert result2["errors"] == {"base": "cannot_connect"} assert result2["reason"] == "no_ups_found"
async def test_form_cannot_connect(hass: HomeAssistant) -> None:
"""Test we handle cannot connect error."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
with patch( with patch(
"homeassistant.components.nut.AIONUTClient.list_ups", "homeassistant.components.nut.AIONUTClient.list_ups",
side_effect=NUTError, side_effect=NUTError("no route to host"),
), patch( ), patch(
"homeassistant.components.nut.AIONUTClient.list_vars", "homeassistant.components.nut.AIONUTClient.list_vars",
side_effect=NUTError, side_effect=NUTError("no route to host"),
): ):
result2 = await hass.config_entries.flow.async_configure( result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], result["flow_id"],
@ -276,6 +283,7 @@ async def test_form_cannot_connect(hass: HomeAssistant) -> None:
assert result2["type"] == data_entry_flow.FlowResultType.FORM assert result2["type"] == data_entry_flow.FlowResultType.FORM
assert result2["errors"] == {"base": "cannot_connect"} assert result2["errors"] == {"base": "cannot_connect"}
assert result2["description_placeholders"] == {"error": "no route to host"}
with patch( with patch(
"homeassistant.components.nut.AIONUTClient.list_ups", "homeassistant.components.nut.AIONUTClient.list_ups",
@ -297,6 +305,154 @@ async def test_form_cannot_connect(hass: HomeAssistant) -> None:
assert result2["type"] == data_entry_flow.FlowResultType.FORM assert result2["type"] == data_entry_flow.FlowResultType.FORM
assert result2["errors"] == {"base": "unknown"} assert result2["errors"] == {"base": "unknown"}
mock_pynut = _get_mock_nutclient(
list_vars={"battery.voltage": "voltage", "ups.status": "OL"}, list_ups=["ups1"]
)
with patch(
"homeassistant.components.nut.AIONUTClient",
return_value=mock_pynut,
), patch(
"homeassistant.components.nut.async_setup_entry",
return_value=True,
) as mock_setup_entry:
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_HOST: "1.1.1.1",
CONF_USERNAME: "test-username",
CONF_PASSWORD: "test-password",
CONF_PORT: 2222,
},
)
await hass.async_block_till_done()
assert result2["type"] is data_entry_flow.FlowResultType.CREATE_ENTRY
assert result2["title"] == "1.1.1.1:2222"
assert result2["data"] == {
CONF_HOST: "1.1.1.1",
CONF_PASSWORD: "test-password",
CONF_PORT: 2222,
CONF_USERNAME: "test-username",
}
assert len(mock_setup_entry.mock_calls) == 1
async def test_auth_failures(hass: HomeAssistant) -> None:
"""Test authentication failures."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
with patch(
"homeassistant.components.nut.AIONUTClient.list_ups",
side_effect=NUTLoginError,
), patch(
"homeassistant.components.nut.AIONUTClient.list_vars",
side_effect=NUTLoginError,
):
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_HOST: "1.1.1.1",
CONF_USERNAME: "test-username",
CONF_PASSWORD: "test-password",
CONF_PORT: 2222,
},
)
assert result2["type"] is data_entry_flow.FlowResultType.FORM
assert result2["errors"] == {"password": "invalid_auth"}
mock_pynut = _get_mock_nutclient(
list_vars={"battery.voltage": "voltage", "ups.status": "OL"}, list_ups=["ups1"]
)
with patch(
"homeassistant.components.nut.AIONUTClient",
return_value=mock_pynut,
), patch(
"homeassistant.components.nut.async_setup_entry",
return_value=True,
) as mock_setup_entry:
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_HOST: "1.1.1.1",
CONF_USERNAME: "test-username",
CONF_PASSWORD: "test-password",
CONF_PORT: 2222,
},
)
await hass.async_block_till_done()
assert result2["type"] is data_entry_flow.FlowResultType.CREATE_ENTRY
assert result2["title"] == "1.1.1.1:2222"
assert result2["data"] == {
CONF_HOST: "1.1.1.1",
CONF_PASSWORD: "test-password",
CONF_PORT: 2222,
CONF_USERNAME: "test-username",
}
assert len(mock_setup_entry.mock_calls) == 1
async def test_reauth(hass: HomeAssistant) -> None:
"""Test reauth flow."""
config_entry = MockConfigEntry(
domain=DOMAIN,
data={
CONF_HOST: "1.1.1.1",
CONF_PORT: 123,
CONF_RESOURCES: ["battery.voltage"],
},
)
config_entry.add_to_hass(hass)
config_entry.async_start_reauth(hass)
await hass.async_block_till_done()
flows = hass.config_entries.flow.async_progress_by_handler(DOMAIN)
assert len(flows) == 1
flow = flows[0]
with patch(
"homeassistant.components.nut.AIONUTClient.list_ups",
side_effect=NUTLoginError,
), patch(
"homeassistant.components.nut.AIONUTClient.list_vars",
side_effect=NUTLoginError,
):
result2 = await hass.config_entries.flow.async_configure(
flow["flow_id"],
{
CONF_USERNAME: "test-username",
CONF_PASSWORD: "test-password",
},
)
assert result2["type"] is data_entry_flow.FlowResultType.FORM
assert result2["errors"] == {"password": "invalid_auth"}
mock_pynut = _get_mock_nutclient(
list_vars={"battery.voltage": "voltage", "ups.status": "OL"}, list_ups=["ups1"]
)
with patch(
"homeassistant.components.nut.AIONUTClient",
return_value=mock_pynut,
), patch(
"homeassistant.components.nut.async_setup_entry",
return_value=True,
) as mock_setup_entry:
result2 = await hass.config_entries.flow.async_configure(
flow["flow_id"],
{
CONF_USERNAME: "test-username",
CONF_PASSWORD: "test-password",
},
)
await hass.async_block_till_done()
assert result2["type"] is data_entry_flow.FlowResultType.ABORT
assert result2["reason"] == "reauth_successful"
assert len(mock_setup_entry.mock_calls) == 1
async def test_abort_if_already_setup(hass: HomeAssistant) -> None: async def test_abort_if_already_setup(hass: HomeAssistant) -> None:
"""Test we abort if component is already setup.""" """Test we abort if component is already setup."""

View file

@ -2,7 +2,7 @@
from unittest.mock import patch from unittest.mock import patch
from aionut import NUTError from aionut import NUTError, NUTLoginError
from homeassistant.components.nut.const import DOMAIN from homeassistant.components.nut.const import DOMAIN
from homeassistant.config_entries import ConfigEntryState from homeassistant.config_entries import ConfigEntryState
@ -66,3 +66,27 @@ async def test_config_not_ready(hass: HomeAssistant) -> None:
await hass.config_entries.async_setup(entry.entry_id) await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
assert entry.state is ConfigEntryState.SETUP_RETRY assert entry.state is ConfigEntryState.SETUP_RETRY
async def test_auth_fails(hass: HomeAssistant) -> None:
"""Test for setup failure if auth has changed."""
entry = MockConfigEntry(
domain=DOMAIN,
data={CONF_HOST: "mock", CONF_PORT: "mock"},
)
entry.add_to_hass(hass)
with patch(
"homeassistant.components.nut.AIONUTClient.list_ups",
return_value={"ups1"},
), patch(
"homeassistant.components.nut.AIONUTClient.list_vars",
side_effect=NUTLoginError,
):
await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.SETUP_ERROR
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
assert flows[0]["context"]["source"] == "reauth"