diff --git a/homeassistant/components/airvisual/__init__.py b/homeassistant/components/airvisual/__init__.py index 563e24bf8fd..f06e4fe70b7 100644 --- a/homeassistant/components/airvisual/__init__.py +++ b/homeassistant/components/airvisual/__init__.py @@ -4,7 +4,12 @@ from datetime import timedelta from math import ceil from pyairvisual import Client -from pyairvisual.errors import AirVisualError, NodeProError +from pyairvisual.errors import ( + AirVisualError, + InvalidKeyError, + KeyExpiredError, + NodeProError, +) import voluptuous as vol from homeassistant.config_entries import SOURCE_IMPORT @@ -229,6 +234,14 @@ async def async_setup_entry(hass, config_entry): try: return await api_coro + except (InvalidKeyError, KeyExpiredError): + hass.async_create_task( + hass.config_entries.flow.async_init( + DOMAIN, + context={"source": "reauth"}, + data=config_entry.data, + ) + ) except AirVisualError as err: raise UpdateFailed(f"Error while retrieving data: {err}") from err diff --git a/homeassistant/components/airvisual/config_flow.py b/homeassistant/components/airvisual/config_flow.py index abbc2df9061..bb1c262eba7 100644 --- a/homeassistant/components/airvisual/config_flow.py +++ b/homeassistant/components/airvisual/config_flow.py @@ -34,12 +34,19 @@ class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): VERSION = 2 CONNECTION_CLASS = config_entries.CONN_CLASS_CLOUD_POLL + def __init__(self): + """Initialize the config flow.""" + self._geo_id = None + self._latitude = None + self._longitude = None + + self.api_key_data_schema = vol.Schema({vol.Required(CONF_API_KEY): str}) + @property def geography_schema(self): """Return the data schema for the cloud API.""" - return vol.Schema( + return self.api_key_data_schema.extend( { - vol.Required(CONF_API_KEY): str, vol.Required( CONF_LATITUDE, default=self.hass.config.latitude ): cv.latitude, @@ -85,8 +92,8 @@ class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): step_id="geography", data_schema=self.geography_schema ) - geo_id = async_get_geography_id(user_input) - await self._async_set_unique_id(geo_id) + self._geo_id = async_get_geography_id(user_input) + await self._async_set_unique_id(self._geo_id) self._abort_if_unique_id_configured() # Find older config entries without unique ID: @@ -95,7 +102,7 @@ class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): continue if any( - geo_id == async_get_geography_id(geography) + self._geo_id == async_get_geography_id(geography) for geography in entry.data[CONF_GEOGRAPHIES] ): return self.async_abort(reason="already_configured") @@ -123,10 +130,19 @@ class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): checked_keys.add(user_input[CONF_API_KEY]) - return self.async_create_entry( - title=f"Cloud API ({geo_id})", - data={**user_input, CONF_INTEGRATION_TYPE: INTEGRATION_TYPE_GEOGRAPHY}, - ) + return await self.async_step_geography_finish(user_input) + + async def async_step_geography_finish(self, user_input=None): + """Handle the finalization of a Cloud API config entry.""" + existing_entry = await self.async_set_unique_id(self._geo_id) + if existing_entry: + self.hass.config_entries.async_update_entry(existing_entry, data=user_input) + return self.async_abort(reason="reauth_successful") + + return self.async_create_entry( + title=f"Cloud API ({self._geo_id})", + data={**user_input, CONF_INTEGRATION_TYPE: INTEGRATION_TYPE_GEOGRAPHY}, + ) async def async_step_import(self, import_config): """Import a config entry from configuration.yaml.""" @@ -164,6 +180,30 @@ class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): data={**user_input, CONF_INTEGRATION_TYPE: INTEGRATION_TYPE_NODE_PRO}, ) + async def async_step_reauth(self, data): + """Handle configuration by re-auth.""" + self._latitude = data[CONF_LATITUDE] + self._longitude = data[CONF_LONGITUDE] + + return await self.async_step_reauth_confirm() + + async def async_step_reauth_confirm(self, user_input=None): + """Handle re-auth completion.""" + if not user_input: + return self.async_show_form( + step_id="reauth_confirm", data_schema=self.api_key_data_schema + ) + + conf = { + CONF_API_KEY: user_input[CONF_API_KEY], + CONF_LATITUDE: self._latitude, + CONF_LONGITUDE: self._longitude, + } + + self._geo_id = async_get_geography_id(conf) + + return await self.async_step_geography_finish(conf) + async def async_step_user(self, user_input=None): """Handle the start of the config flow.""" if not user_input: diff --git a/tests/components/airvisual/test_config_flow.py b/tests/components/airvisual/test_config_flow.py index 8912b0287d7..d365720ad26 100644 --- a/tests/components/airvisual/test_config_flow.py +++ b/tests/components/airvisual/test_config_flow.py @@ -31,7 +31,6 @@ async def test_duplicate_error(hass): CONF_LATITUDE: 51.528308, CONF_LONGITUDE: -0.3817765, } - node_pro_conf = {CONF_IP_ADDRESS: "192.168.1.100", CONF_PASSWORD: "12345"} MockConfigEntry( domain=DOMAIN, unique_id="51.528308, -0.3817765", data=geography_conf @@ -44,6 +43,8 @@ async def test_duplicate_error(hass): assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT assert result["reason"] == "already_configured" + node_pro_conf = {CONF_IP_ADDRESS: "192.168.1.100", CONF_PASSWORD: "12345"} + MockConfigEntry( domain=DOMAIN, unique_id="192.168.1.100", data=node_pro_conf ).add_to_hass(hass) @@ -78,24 +79,6 @@ async def test_invalid_identifier(hass): assert result["errors"] == {CONF_API_KEY: "invalid_api_key"} -async def test_node_pro_error(hass): - """Test that an invalid Node/Pro ID shows an error.""" - node_pro_conf = {CONF_IP_ADDRESS: "192.168.1.100", CONF_PASSWORD: "my_password"} - - with patch( - "pyairvisual.node.Node.from_samba", - side_effect=NodeProError, - ): - result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": SOURCE_USER}, data={"type": "AirVisual Node/Pro"} - ) - result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input=node_pro_conf - ) - assert result["type"] == data_entry_flow.RESULT_TYPE_FORM - assert result["errors"] == {CONF_IP_ADDRESS: "unable_to_connect"} - - async def test_migration(hass): """Test migrating from version 1 to the current version.""" conf = { @@ -142,6 +125,24 @@ async def test_migration(hass): } +async def test_node_pro_error(hass): + """Test that an invalid Node/Pro ID shows an error.""" + node_pro_conf = {CONF_IP_ADDRESS: "192.168.1.100", CONF_PASSWORD: "my_password"} + + with patch( + "pyairvisual.node.Node.from_samba", + side_effect=NodeProError, + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": SOURCE_USER}, data={"type": "AirVisual Node/Pro"} + ) + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input=node_pro_conf + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["errors"] == {CONF_IP_ADDRESS: "unable_to_connect"} + + async def test_options_flow(hass): """Test config flow options.""" geography_conf = { @@ -198,28 +199,6 @@ async def test_step_geography(hass): } -async def test_step_node_pro(hass): - """Test the Node/Pro step.""" - conf = {CONF_IP_ADDRESS: "192.168.1.100", CONF_PASSWORD: "my_password"} - - with patch( - "homeassistant.components.airvisual.async_setup_entry", return_value=True - ), patch("pyairvisual.node.Node.from_samba"): - result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": SOURCE_USER}, data={"type": "AirVisual Node/Pro"} - ) - result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input=conf - ) - assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY - assert result["title"] == "Node/Pro (192.168.1.100)" - assert result["data"] == { - CONF_IP_ADDRESS: "192.168.1.100", - CONF_PASSWORD: "my_password", - CONF_INTEGRATION_TYPE: INTEGRATION_TYPE_NODE_PRO, - } - - async def test_step_import(hass): """Test the import step for both types of configuration.""" geography_conf = { @@ -245,6 +224,61 @@ async def test_step_import(hass): } +async def test_step_node_pro(hass): + """Test the Node/Pro step.""" + conf = {CONF_IP_ADDRESS: "192.168.1.100", CONF_PASSWORD: "my_password"} + + with patch( + "homeassistant.components.airvisual.async_setup_entry", return_value=True + ), patch("pyairvisual.node.Node.from_samba"): + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": SOURCE_USER}, data={"type": "AirVisual Node/Pro"} + ) + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input=conf + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + assert result["title"] == "Node/Pro (192.168.1.100)" + assert result["data"] == { + CONF_IP_ADDRESS: "192.168.1.100", + CONF_PASSWORD: "my_password", + CONF_INTEGRATION_TYPE: INTEGRATION_TYPE_NODE_PRO, + } + + +async def test_step_reauth(hass): + """Test that the reauth step works.""" + geography_conf = { + CONF_API_KEY: "abcde12345", + CONF_LATITUDE: 51.528308, + CONF_LONGITUDE: -0.3817765, + } + + MockConfigEntry( + domain=DOMAIN, unique_id="51.528308, -0.3817765", data=geography_conf + ).add_to_hass(hass) + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": "reauth"}, data=geography_conf + ) + assert result["step_id"] == "reauth_confirm" + + result = await hass.config_entries.flow.async_configure(result["flow_id"]) + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["step_id"] == "reauth_confirm" + + with patch( + "homeassistant.components.simplisafe.async_setup_entry", return_value=True + ), patch("pyairvisual.api.API.nearest_city"): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input={CONF_API_KEY: "defgh67890"} + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT + assert result["reason"] == "reauth_successful" + + assert len(hass.config_entries.async_entries()) == 1 + + async def test_step_user(hass): """Test the user ("pick the integration type") step.""" result = await hass.config_entries.flow.async_init(