diff --git a/homeassistant/components/nut/__init__.py b/homeassistant/components/nut/__init__.py index c9067bbb254..575def8bf0f 100644 --- a/homeassistant/components/nut/__init__.py +++ b/homeassistant/components/nut/__init__.py @@ -7,7 +7,7 @@ from datetime import timedelta import logging from typing import TYPE_CHECKING -from aionut import AIONUTClient, NUTError +from aionut import AIONUTClient, NUTError, NUTLoginError from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( @@ -21,7 +21,7 @@ from homeassistant.const import ( EVENT_HOMEASSISTANT_STOP, ) from homeassistant.core import Event, HomeAssistant, callback -from homeassistant.exceptions import HomeAssistantError +from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError from homeassistant.helpers import device_registry as dr from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed @@ -70,6 +70,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Fetch data from NUT.""" try: return await data.async_update() + except NUTLoginError as err: + raise ConfigEntryAuthFailed from err except NUTError as err: raise UpdateFailed(f"Error fetching UPS state: {err}") from err @@ -249,16 +251,9 @@ class PyNUTData: async def _async_get_alias(self) -> str | None: """Get the ups alias from NUT.""" - try: - ups_list = await self._client.list_ups() - except NUTError as err: - _LOGGER.error("Failure getting NUT ups alias, %s", err) - return None - - if not ups_list: + if not (ups_list := await self._client.list_ups()): _LOGGER.error("Empty list while getting NUT ups aliases") return None - self.ups_list = ups_list return list(ups_list)[0] diff --git a/homeassistant/components/nut/config_flow.py b/homeassistant/components/nut/config_flow.py index 3f3de8a126c..f0126ba4894 100644 --- a/homeassistant/components/nut/config_flow.py +++ b/homeassistant/components/nut/config_flow.py @@ -6,7 +6,7 @@ from collections.abc import Mapping import logging from typing import Any -from aionut import NUTError +from aionut import NUTError, NUTLoginError import voluptuous as vol from homeassistant.components import zeroconf @@ -26,28 +26,23 @@ from homeassistant.const import ( CONF_USERNAME, ) from homeassistant.core import HomeAssistant, callback -from homeassistant.exceptions import HomeAssistantError +from homeassistant.data_entry_flow import AbortFlow from . import PyNUTData from .const import DEFAULT_HOST, DEFAULT_PORT, DEFAULT_SCAN_INTERVAL, DOMAIN _LOGGER = logging.getLogger(__name__) +AUTH_SCHEMA = {vol.Optional(CONF_USERNAME): str, vol.Optional(CONF_PASSWORD): str} -def _base_schema(discovery_info: zeroconf.ZeroconfServiceInfo | None) -> vol.Schema: + +def _base_schema(nut_config: dict[str, Any]) -> vol.Schema: """Generate base schema.""" - base_schema = {} - if not discovery_info: - base_schema.update( - { - vol.Optional(CONF_HOST, default=DEFAULT_HOST): str, - vol.Optional(CONF_PORT, default=DEFAULT_PORT): int, - } - ) - base_schema.update( - {vol.Optional(CONF_USERNAME): str, vol.Optional(CONF_PASSWORD): str} - ) - + base_schema = { + vol.Optional(CONF_HOST, default=nut_config.get(CONF_HOST) or DEFAULT_HOST): str, + vol.Optional(CONF_PORT, default=nut_config.get(CONF_PORT) or DEFAULT_PORT): int, + } + base_schema.update(AUTH_SCHEMA) return vol.Schema(base_schema) @@ -69,13 +64,10 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> dict[str, password = data.get(CONF_PASSWORD) nut_data = PyNUTData(host, port, alias, username, password, persistent=False) - try: - status = await nut_data.async_update() - except NUTError as err: - raise CannotConnect(str(err)) from err + status = await nut_data.async_update() if not alias and not nut_data.ups_list: - raise CannotConnect("No UPSes found on the NUT server") + raise AbortFlow("no_ups_found") return {"ups_list": nut_data.ups_list, "available_resources": status} @@ -98,20 +90,20 @@ class NutConfigFlow(ConfigFlow, domain=DOMAIN): def __init__(self) -> None: """Initialize the nut config flow.""" self.nut_config: dict[str, Any] = {} - self.discovery_info: zeroconf.ZeroconfServiceInfo | None = None self.ups_list: dict[str, str] | None = None self.title: str | None = None + self.reauth_entry: ConfigEntry | None = None async def async_step_zeroconf( self, discovery_info: zeroconf.ZeroconfServiceInfo ) -> ConfigFlowResult: """Prepare configuration for a discovered nut device.""" - self.discovery_info = discovery_info await self._async_handle_discovery_without_unique_id() - self.context["title_placeholders"] = { + self.nut_config = { + CONF_HOST: discovery_info.host or DEFAULT_HOST, CONF_PORT: discovery_info.port or DEFAULT_PORT, - CONF_HOST: discovery_info.host, } + self.context["title_placeholders"] = self.nut_config.copy() return await self.async_step_user() async def async_step_user( @@ -119,29 +111,28 @@ class NutConfigFlow(ConfigFlow, domain=DOMAIN): ) -> ConfigFlowResult: """Handle the user input.""" errors: dict[str, str] = {} + placeholders: dict[str, str] = {} + nut_config = self.nut_config if user_input is not None: - if self.discovery_info: - user_input.update( - { - CONF_HOST: self.discovery_info.host, - CONF_PORT: self.discovery_info.port or DEFAULT_PORT, - } - ) - info, errors = await self._async_validate_or_error(user_input) + nut_config.update(user_input) + + info, errors, placeholders = await self._async_validate_or_error(nut_config) if not errors: - self.nut_config.update(user_input) if len(info["ups_list"]) > 1: self.ups_list = info["ups_list"] return await self.async_step_ups() - if self._host_port_alias_already_configured(self.nut_config): + if self._host_port_alias_already_configured(nut_config): return self.async_abort(reason="already_configured") - title = _format_host_port_alias(self.nut_config) - return self.async_create_entry(title=title, data=self.nut_config) + title = _format_host_port_alias(nut_config) + return self.async_create_entry(title=title, data=nut_config) return self.async_show_form( - step_id="user", data_schema=_base_schema(self.discovery_info), errors=errors + step_id="user", + data_schema=_base_schema(nut_config), + errors=errors, + description_placeholders=placeholders, ) async def async_step_ups( @@ -149,20 +140,23 @@ class NutConfigFlow(ConfigFlow, domain=DOMAIN): ) -> ConfigFlowResult: """Handle the picking the ups.""" errors: dict[str, str] = {} + placeholders: dict[str, str] = {} + nut_config = self.nut_config if user_input is not None: self.nut_config.update(user_input) - if self._host_port_alias_already_configured(self.nut_config): + if self._host_port_alias_already_configured(nut_config): return self.async_abort(reason="already_configured") - _, errors = await self._async_validate_or_error(self.nut_config) + _, errors, placeholders = await self._async_validate_or_error(nut_config) if not errors: - title = _format_host_port_alias(self.nut_config) - return self.async_create_entry(title=title, data=self.nut_config) + title = _format_host_port_alias(nut_config) + return self.async_create_entry(title=title, data=nut_config) return self.async_show_form( step_id="ups", data_schema=_ups_schema(self.ups_list or {}), errors=errors, + description_placeholders=placeholders, ) def _host_port_alias_already_configured(self, user_input: dict[str, Any]) -> bool: @@ -176,17 +170,66 @@ class NutConfigFlow(ConfigFlow, domain=DOMAIN): async def _async_validate_or_error( self, config: dict[str, Any] - ) -> tuple[dict[str, Any], dict[str, str]]: - errors = {} - info = {} + ) -> tuple[dict[str, Any], dict[str, str], dict[str, str]]: + errors: dict[str, str] = {} + info: dict[str, Any] = {} + description_placeholders: dict[str, str] = {} try: info = await validate_input(self.hass, config) - except CannotConnect: + except NUTLoginError: + errors[CONF_PASSWORD] = "invalid_auth" + except NUTError as ex: errors[CONF_BASE] = "cannot_connect" + description_placeholders["error"] = str(ex) + except AbortFlow: + raise except Exception: # pylint: disable=broad-except _LOGGER.exception("Unexpected exception") errors[CONF_BASE] = "unknown" - return info, errors + return info, errors, description_placeholders + + async def async_step_reauth( + self, entry_data: Mapping[str, Any] + ) -> ConfigFlowResult: + """Handle reauth.""" + entry_id = self.context["entry_id"] + self.reauth_entry = self.hass.config_entries.async_get_entry(entry_id) + return await self.async_step_reauth_confirm() + + async def async_step_reauth_confirm( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Handle reauth input.""" + errors: dict[str, str] = {} + existing_entry = self.reauth_entry + assert existing_entry + existing_data = existing_entry.data + description_placeholders: dict[str, str] = { + CONF_HOST: existing_data[CONF_HOST], + CONF_PORT: existing_data[CONF_PORT], + } + if user_input is not None: + new_config = { + **existing_data, + # Username/password are optional and some servers + # use ip based authentication and will fail if + # username/password are provided + CONF_USERNAME: user_input.get(CONF_USERNAME), + CONF_PASSWORD: user_input.get(CONF_PASSWORD), + } + _, errors, placeholders = await self._async_validate_or_error(new_config) + if not errors: + return self.async_update_reload_and_abort( + existing_entry, data=new_config + ) + description_placeholders.update(placeholders) + + return self.async_show_form( + description_placeholders=description_placeholders, + step_id="reauth_confirm", + data_schema=vol.Schema(AUTH_SCHEMA), + errors=errors, + ) @staticmethod @callback @@ -220,7 +263,3 @@ class OptionsFlowHandler(OptionsFlow): } return self.async_show_form(step_id="init", data_schema=vol.Schema(base_schema)) - - -class CannotConnect(HomeAssistantError): - """Error to indicate we cannot connect.""" diff --git a/homeassistant/components/nut/strings.json b/homeassistant/components/nut/strings.json index 3c446926fe0..d5b9acbdaad 100644 --- a/homeassistant/components/nut/strings.json +++ b/homeassistant/components/nut/strings.json @@ -18,14 +18,24 @@ "data": { "alias": "Alias" } + }, + "reauth_confirm": { + "description": "Re-authenticate NUT server at {host}:{port}", + "data": { + "username": "[%key:common::config_flow::data::username%]", + "password": "[%key:common::config_flow::data::password%]" + } } }, "error": { - "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", - "unknown": "[%key:common::config_flow::error::unknown%]" + "cannot_connect": "Connection error: {error}", + "unknown": "[%key:common::config_flow::error::unknown%]", + "invalid_auth": "[%key:common::config_flow::error::invalid_auth%]" }, "abort": { - "already_configured": "[%key:common::config_flow::abort::already_configured_device%]" + "already_configured": "[%key:common::config_flow::abort::already_configured_device%]", + "no_ups_found": "There are no UPS devices available on the NUT server.", + "reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]" } }, "options": { diff --git a/tests/components/nut/test_config_flow.py b/tests/components/nut/test_config_flow.py index b6a9590f457..0fd9949ff37 100644 --- a/tests/components/nut/test_config_flow.py +++ b/tests/components/nut/test_config_flow.py @@ -3,7 +3,7 @@ from ipaddress import ip_address from unittest.mock import patch -from aionut import NUTError +from aionut import NUTError, NUTLoginError from homeassistant import config_entries, data_entry_flow, setup from homeassistant.components import zeroconf @@ -232,8 +232,8 @@ async def test_form_user_one_ups_with_ignored_entry(hass: HomeAssistant) -> None assert len(mock_setup_entry.mock_calls) == 1 -async def test_form_cannot_connect(hass: HomeAssistant) -> None: - """Test we handle cannot connect error.""" +async def test_form_no_upses_found(hass: HomeAssistant) -> None: + """Test we abort when the NUT server has not UPSes.""" result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} ) @@ -254,15 +254,22 @@ async def test_form_cannot_connect(hass: HomeAssistant) -> None: }, ) - assert result2["type"] == data_entry_flow.FlowResultType.FORM - assert result2["errors"] == {"base": "cannot_connect"} + assert result2["type"] is data_entry_flow.FlowResultType.ABORT + assert result2["reason"] == "no_ups_found" + + +async def test_form_cannot_connect(hass: HomeAssistant) -> None: + """Test we handle cannot connect error.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) with patch( "homeassistant.components.nut.AIONUTClient.list_ups", - side_effect=NUTError, + side_effect=NUTError("no route to host"), ), patch( "homeassistant.components.nut.AIONUTClient.list_vars", - side_effect=NUTError, + side_effect=NUTError("no route to host"), ): result2 = await hass.config_entries.flow.async_configure( result["flow_id"], @@ -276,6 +283,7 @@ async def test_form_cannot_connect(hass: HomeAssistant) -> None: assert result2["type"] == data_entry_flow.FlowResultType.FORM assert result2["errors"] == {"base": "cannot_connect"} + assert result2["description_placeholders"] == {"error": "no route to host"} with patch( "homeassistant.components.nut.AIONUTClient.list_ups", @@ -297,6 +305,154 @@ async def test_form_cannot_connect(hass: HomeAssistant) -> None: assert result2["type"] == data_entry_flow.FlowResultType.FORM assert result2["errors"] == {"base": "unknown"} + mock_pynut = _get_mock_nutclient( + list_vars={"battery.voltage": "voltage", "ups.status": "OL"}, list_ups=["ups1"] + ) + with patch( + "homeassistant.components.nut.AIONUTClient", + return_value=mock_pynut, + ), patch( + "homeassistant.components.nut.async_setup_entry", + return_value=True, + ) as mock_setup_entry: + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + CONF_HOST: "1.1.1.1", + CONF_USERNAME: "test-username", + CONF_PASSWORD: "test-password", + CONF_PORT: 2222, + }, + ) + await hass.async_block_till_done() + + assert result2["type"] is data_entry_flow.FlowResultType.CREATE_ENTRY + assert result2["title"] == "1.1.1.1:2222" + assert result2["data"] == { + CONF_HOST: "1.1.1.1", + CONF_PASSWORD: "test-password", + CONF_PORT: 2222, + CONF_USERNAME: "test-username", + } + assert len(mock_setup_entry.mock_calls) == 1 + + +async def test_auth_failures(hass: HomeAssistant) -> None: + """Test authentication failures.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + + with patch( + "homeassistant.components.nut.AIONUTClient.list_ups", + side_effect=NUTLoginError, + ), patch( + "homeassistant.components.nut.AIONUTClient.list_vars", + side_effect=NUTLoginError, + ): + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + CONF_HOST: "1.1.1.1", + CONF_USERNAME: "test-username", + CONF_PASSWORD: "test-password", + CONF_PORT: 2222, + }, + ) + + assert result2["type"] is data_entry_flow.FlowResultType.FORM + assert result2["errors"] == {"password": "invalid_auth"} + + mock_pynut = _get_mock_nutclient( + list_vars={"battery.voltage": "voltage", "ups.status": "OL"}, list_ups=["ups1"] + ) + with patch( + "homeassistant.components.nut.AIONUTClient", + return_value=mock_pynut, + ), patch( + "homeassistant.components.nut.async_setup_entry", + return_value=True, + ) as mock_setup_entry: + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + CONF_HOST: "1.1.1.1", + CONF_USERNAME: "test-username", + CONF_PASSWORD: "test-password", + CONF_PORT: 2222, + }, + ) + await hass.async_block_till_done() + + assert result2["type"] is data_entry_flow.FlowResultType.CREATE_ENTRY + assert result2["title"] == "1.1.1.1:2222" + assert result2["data"] == { + CONF_HOST: "1.1.1.1", + CONF_PASSWORD: "test-password", + CONF_PORT: 2222, + CONF_USERNAME: "test-username", + } + assert len(mock_setup_entry.mock_calls) == 1 + + +async def test_reauth(hass: HomeAssistant) -> None: + """Test reauth flow.""" + config_entry = MockConfigEntry( + domain=DOMAIN, + data={ + CONF_HOST: "1.1.1.1", + CONF_PORT: 123, + CONF_RESOURCES: ["battery.voltage"], + }, + ) + config_entry.add_to_hass(hass) + config_entry.async_start_reauth(hass) + await hass.async_block_till_done() + flows = hass.config_entries.flow.async_progress_by_handler(DOMAIN) + assert len(flows) == 1 + flow = flows[0] + + with patch( + "homeassistant.components.nut.AIONUTClient.list_ups", + side_effect=NUTLoginError, + ), patch( + "homeassistant.components.nut.AIONUTClient.list_vars", + side_effect=NUTLoginError, + ): + result2 = await hass.config_entries.flow.async_configure( + flow["flow_id"], + { + CONF_USERNAME: "test-username", + CONF_PASSWORD: "test-password", + }, + ) + + assert result2["type"] is data_entry_flow.FlowResultType.FORM + assert result2["errors"] == {"password": "invalid_auth"} + + mock_pynut = _get_mock_nutclient( + list_vars={"battery.voltage": "voltage", "ups.status": "OL"}, list_ups=["ups1"] + ) + with patch( + "homeassistant.components.nut.AIONUTClient", + return_value=mock_pynut, + ), patch( + "homeassistant.components.nut.async_setup_entry", + return_value=True, + ) as mock_setup_entry: + result2 = await hass.config_entries.flow.async_configure( + flow["flow_id"], + { + CONF_USERNAME: "test-username", + CONF_PASSWORD: "test-password", + }, + ) + await hass.async_block_till_done() + + assert result2["type"] is data_entry_flow.FlowResultType.ABORT + assert result2["reason"] == "reauth_successful" + assert len(mock_setup_entry.mock_calls) == 1 + async def test_abort_if_already_setup(hass: HomeAssistant) -> None: """Test we abort if component is already setup.""" diff --git a/tests/components/nut/test_init.py b/tests/components/nut/test_init.py index d15e9d4b12a..4dd5f2357e8 100644 --- a/tests/components/nut/test_init.py +++ b/tests/components/nut/test_init.py @@ -2,7 +2,7 @@ from unittest.mock import patch -from aionut import NUTError +from aionut import NUTError, NUTLoginError from homeassistant.components.nut.const import DOMAIN from homeassistant.config_entries import ConfigEntryState @@ -66,3 +66,27 @@ async def test_config_not_ready(hass: HomeAssistant) -> None: await hass.config_entries.async_setup(entry.entry_id) await hass.async_block_till_done() assert entry.state is ConfigEntryState.SETUP_RETRY + + +async def test_auth_fails(hass: HomeAssistant) -> None: + """Test for setup failure if auth has changed.""" + entry = MockConfigEntry( + domain=DOMAIN, + data={CONF_HOST: "mock", CONF_PORT: "mock"}, + ) + entry.add_to_hass(hass) + + with patch( + "homeassistant.components.nut.AIONUTClient.list_ups", + return_value={"ups1"}, + ), patch( + "homeassistant.components.nut.AIONUTClient.list_vars", + side_effect=NUTLoginError, + ): + await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() + assert entry.state is ConfigEntryState.SETUP_ERROR + + flows = hass.config_entries.flow.async_progress() + assert len(flows) == 1 + assert flows[0]["context"]["source"] == "reauth"