Fix re-authentication in AirVisual (#41801)

This commit is contained in:
Aaron Bach 2020-10-15 01:30:39 -06:00 committed by GitHub
parent 53a1d92f2b
commit 099de37ee5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 70 additions and 35 deletions

View file

@ -12,7 +12,7 @@ from pyairvisual.errors import (
) )
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import SOURCE_IMPORT from homeassistant.config_entries import SOURCE_IMPORT, SOURCE_REAUTH
from homeassistant.const import ( from homeassistant.const import (
ATTR_ATTRIBUTION, ATTR_ATTRIBUTION,
CONF_API_KEY, CONF_API_KEY,
@ -97,14 +97,12 @@ def async_get_geography_id(geography_dict):
@callback @callback
def async_get_cloud_api_update_interval(hass, api_key): def async_get_cloud_api_update_interval(hass, api_key, num_consumers):
"""Get a leveled scan interval for a particular cloud API key. """Get a leveled scan interval for a particular cloud API key.
This will shift based on the number of active consumers, thus keeping the user This will shift based on the number of active consumers, thus keeping the user
under the monthly API limit. under the monthly API limit.
""" """
num_consumers = len(async_get_cloud_coordinators_by_api_key(hass, api_key))
# Assuming 10,000 calls per month and a "smallest possible month" of 28 days; note # Assuming 10,000 calls per month and a "smallest possible month" of 28 days; note
# that we give a buffer of 1500 API calls for any drift, restarts, etc.: # that we give a buffer of 1500 API calls for any drift, restarts, etc.:
minutes_between_api_calls = ceil(1 / (8500 / 28 / 24 / 60 / num_consumers)) minutes_between_api_calls = ceil(1 / (8500 / 28 / 24 / 60 / num_consumers))
@ -133,8 +131,16 @@ def async_get_cloud_coordinators_by_api_key(hass, api_key):
@callback @callback
def async_sync_geo_coordinator_update_intervals(hass, api_key): def async_sync_geo_coordinator_update_intervals(hass, api_key):
"""Sync the update interval for geography-based data coordinators (by API key).""" """Sync the update interval for geography-based data coordinators (by API key)."""
update_interval = async_get_cloud_api_update_interval(hass, api_key) coordinators = async_get_cloud_coordinators_by_api_key(hass, api_key)
for coordinator in async_get_cloud_coordinators_by_api_key(hass, api_key):
if not coordinators:
return
update_interval = async_get_cloud_api_update_interval(
hass, api_key, len(coordinators)
)
for coordinator in coordinators:
LOGGER.debug( LOGGER.debug(
"Updating interval for coordinator: %s, %s", "Updating interval for coordinator: %s, %s",
coordinator.name, coordinator.name,
@ -234,13 +240,26 @@ async def async_setup_entry(hass, config_entry):
try: try:
return await api_coro return await api_coro
except (InvalidKeyError, KeyExpiredError): except (InvalidKeyError, KeyExpiredError):
hass.async_create_task( matching_flows = [
hass.config_entries.flow.async_init( flow
DOMAIN, for flow in hass.config_entries.flow.async_progress()
context={"source": "reauth"}, if flow["context"]["source"] == SOURCE_REAUTH
data=config_entry.data, and flow["context"]["unique_id"] == config_entry.unique_id
]
if not matching_flows:
hass.async_create_task(
hass.config_entries.flow.async_init(
DOMAIN,
context={
"source": SOURCE_REAUTH,
"unique_id": config_entry.unique_id,
},
data=config_entry.data,
)
) )
)
return {}
except AirVisualError as err: except AirVisualError as err:
raise UpdateFailed(f"Error while retrieving data: {err}") from err raise UpdateFailed(f"Error while retrieving data: {err}") from err
@ -262,7 +281,7 @@ async def async_setup_entry(hass, config_entry):
) )
# Only geography-based entries have options: # Only geography-based entries have options:
config_entry.add_update_listener(async_update_options) config_entry.add_update_listener(async_reload_entry)
else: else:
_standardize_node_pro_config_entry(hass, config_entry) _standardize_node_pro_config_entry(hass, config_entry)
@ -356,10 +375,9 @@ async def async_unload_entry(hass, config_entry):
return unload_ok return unload_ok
async def async_update_options(hass, config_entry): async def async_reload_entry(hass, config_entry):
"""Handle an options update.""" """Handle an options update."""
coordinator = hass.data[DOMAIN][DATA_COORDINATOR][config_entry.entry_id] await hass.config_entries.async_reload(config_entry.entry_id)
await coordinator.async_request_refresh()
class AirVisualEntity(CoordinatorEntity): class AirVisualEntity(CoordinatorEntity):

View file

@ -107,33 +107,35 @@ class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
): ):
return self.async_abort(reason="already_configured") return self.async_abort(reason="already_configured")
return await self.async_step_geography_finish(
user_input, "geography", self.geography_schema
)
async def async_step_geography_finish(self, user_input, error_step, error_schema):
"""Validate a Cloud API key."""
websession = aiohttp_client.async_get_clientsession(self.hass) websession = aiohttp_client.async_get_clientsession(self.hass)
cloud_api = CloudAPI(user_input[CONF_API_KEY], session=websession) cloud_api = CloudAPI(user_input[CONF_API_KEY], session=websession)
# If this is the first (and only the first) time we've seen this API key, check # If this is the first (and only the first) time we've seen this API key, check
# that it's valid: # that it's valid:
checked_keys = self.hass.data.setdefault("airvisual_checked_api_keys", set()) valid_keys = self.hass.data.setdefault("airvisual_checked_api_keys", set())
check_keys_lock = self.hass.data.setdefault( valid_keys_lock = self.hass.data.setdefault(
"airvisual_checked_api_keys_lock", asyncio.Lock() "airvisual_checked_api_keys_lock", asyncio.Lock()
) )
async with check_keys_lock: async with valid_keys_lock:
if user_input[CONF_API_KEY] not in checked_keys: if user_input[CONF_API_KEY] not in valid_keys:
try: try:
await cloud_api.air_quality.nearest_city() await cloud_api.air_quality.nearest_city()
except InvalidKeyError: except InvalidKeyError:
return self.async_show_form( return self.async_show_form(
step_id="geography", step_id=error_step,
data_schema=self.geography_schema, data_schema=error_schema,
errors={CONF_API_KEY: "invalid_api_key"}, errors={CONF_API_KEY: "invalid_api_key"},
) )
checked_keys.add(user_input[CONF_API_KEY]) valid_keys.add(user_input[CONF_API_KEY])
return await self.async_step_geography_finish(user_input)
async def async_step_geography_finish(self, user_input=None):
"""Handle the finalization of a Cloud API config entry."""
existing_entry = await self.async_set_unique_id(self._geo_id) existing_entry = await self.async_set_unique_id(self._geo_id)
if existing_entry: if existing_entry:
self.hass.config_entries.async_update_entry(existing_entry, data=user_input) self.hass.config_entries.async_update_entry(existing_entry, data=user_input)
@ -178,6 +180,7 @@ class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
async def async_step_reauth(self, data): async def async_step_reauth(self, data):
"""Handle configuration by re-auth.""" """Handle configuration by re-auth."""
self._geo_id = async_get_geography_id(data)
self._latitude = data[CONF_LATITUDE] self._latitude = data[CONF_LATITUDE]
self._longitude = data[CONF_LONGITUDE] self._longitude = data[CONF_LONGITUDE]
@ -194,11 +197,12 @@ class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
CONF_API_KEY: user_input[CONF_API_KEY], CONF_API_KEY: user_input[CONF_API_KEY],
CONF_LATITUDE: self._latitude, CONF_LATITUDE: self._latitude,
CONF_LONGITUDE: self._longitude, CONF_LONGITUDE: self._longitude,
CONF_INTEGRATION_TYPE: INTEGRATION_TYPE_GEOGRAPHY,
} }
self._geo_id = async_get_geography_id(conf) return await self.async_step_geography_finish(
conf, "reauth_confirm", self.api_key_data_schema
return await self.async_step_geography_finish(conf) )
async def async_step_user(self, user_input=None): async def async_step_user(self, user_input=None):
"""Handle the start of the config flow.""" """Handle the start of the config flow."""

View file

@ -18,6 +18,12 @@
"password": "[%key:common::config_flow::data::password%]" "password": "[%key:common::config_flow::data::password%]"
} }
}, },
"reauth_confirm": {
"title": "Re-authenticate AirVisual",
"data": {
"api_key": "[%key:common::config_flow::data::api_key%]"
}
},
"user": { "user": {
"title": "Configure AirVisual", "title": "Configure AirVisual",
"description": "Pick what type of AirVisual data you want to monitor.", "description": "Pick what type of AirVisual data you want to monitor.",
@ -34,7 +40,8 @@
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]" "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]"
}, },
"abort": { "abort": {
"already_configured": "[%key:common::config_flow::abort::already_configured_location%] or Node/Pro ID is already registered." "already_configured": "[%key:common::config_flow::abort::already_configured_location%] or Node/Pro ID is already registered.",
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]"
} }
}, },
"options": { "options": {

View file

@ -1,13 +1,13 @@
{ {
"config": { "config": {
"abort": { "abort": {
"already_configured": "Location is already configured or Node/Pro ID is already registered." "already_configured": "Location is already configured or Node/Pro ID is already registered.",
"reauth_successful": "Re-authentication was successful"
}, },
"error": { "error": {
"cannot_connect": "Failed to connect", "cannot_connect": "Failed to connect",
"general_error": "Unexpected error", "general_error": "Unexpected error",
"invalid_api_key": "Invalid API key", "invalid_api_key": "Invalid API key"
"unable_to_connect": "Unable to connect to Node/Pro unit."
}, },
"step": { "step": {
"geography": { "geography": {
@ -27,6 +27,12 @@
"description": "Monitor a personal AirVisual unit. The password can be retrieved from the unit's UI.", "description": "Monitor a personal AirVisual unit. The password can be retrieved from the unit's UI.",
"title": "Configure an AirVisual Node/Pro" "title": "Configure an AirVisual Node/Pro"
}, },
"reauth_confirm": {
"data": {
"api_key": "API Key"
},
"title": "Re-authenticate AirVisual"
},
"user": { "user": {
"data": { "data": {
"cloud_api": "Geographical Location", "cloud_api": "Geographical Location",

View file

@ -273,7 +273,7 @@ async def test_step_reauth(hass):
with patch( with patch(
"homeassistant.components.airvisual.async_setup_entry", return_value=True "homeassistant.components.airvisual.async_setup_entry", return_value=True
), patch("pyairvisual.air_quality.AirQuality"): ), patch("pyairvisual.air_quality.AirQuality.nearest_city", return_value=True):
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input={CONF_API_KEY: "defgh67890"} result["flow_id"], user_input={CONF_API_KEY: "defgh67890"}
) )