Add new helper for matching reauth/reconfigure config flows (#127565)
This commit is contained in:
parent
15a1a83729
commit
2c664efb3c
5 changed files with 102 additions and 39 deletions
|
@ -21,7 +21,6 @@ from homeassistant.config_entries import (
|
||||||
)
|
)
|
||||||
from homeassistant.const import CONF_PASSWORD, CONF_REGION, CONF_SOURCE, CONF_USERNAME
|
from homeassistant.const import CONF_PASSWORD, CONF_REGION, CONF_SOURCE, CONF_USERNAME
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.data_entry_flow import AbortFlow
|
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers.selector import SelectSelector, SelectSelectorConfig
|
from homeassistant.helpers.selector import SelectSelector, SelectSelectorConfig
|
||||||
|
|
||||||
|
@ -75,7 +74,6 @@ class BMWConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||||
VERSION = 1
|
VERSION = 1
|
||||||
|
|
||||||
_existing_entry_data: Mapping[str, Any] | None = None
|
_existing_entry_data: Mapping[str, Any] | None = None
|
||||||
_existing_entry_unique_id: str | None = None
|
|
||||||
|
|
||||||
async def async_step_user(
|
async def async_step_user(
|
||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: dict[str, Any] | None = None
|
||||||
|
@ -85,15 +83,12 @@ class BMWConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||||
|
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
unique_id = f"{user_input[CONF_REGION]}-{user_input[CONF_USERNAME]}"
|
unique_id = f"{user_input[CONF_REGION]}-{user_input[CONF_USERNAME]}"
|
||||||
|
|
||||||
if self.source not in {SOURCE_REAUTH, SOURCE_RECONFIGURE}:
|
|
||||||
await self.async_set_unique_id(unique_id)
|
await self.async_set_unique_id(unique_id)
|
||||||
|
|
||||||
|
if self.source in {SOURCE_REAUTH, SOURCE_RECONFIGURE}:
|
||||||
|
self._abort_if_unique_id_mismatch(reason="account_mismatch")
|
||||||
|
else:
|
||||||
self._abort_if_unique_id_configured()
|
self._abort_if_unique_id_configured()
|
||||||
elif (
|
|
||||||
self.source in {SOURCE_REAUTH, SOURCE_RECONFIGURE}
|
|
||||||
and unique_id != self._existing_entry_unique_id
|
|
||||||
):
|
|
||||||
raise AbortFlow("account_mismatch")
|
|
||||||
|
|
||||||
info = None
|
info = None
|
||||||
try:
|
try:
|
||||||
|
@ -135,16 +130,13 @@ class BMWConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||||
) -> ConfigFlowResult:
|
) -> ConfigFlowResult:
|
||||||
"""Handle configuration by re-auth."""
|
"""Handle configuration by re-auth."""
|
||||||
self._existing_entry_data = entry_data
|
self._existing_entry_data = entry_data
|
||||||
self._existing_entry_unique_id = self._get_reauth_entry().unique_id
|
|
||||||
return await self.async_step_user()
|
return await self.async_step_user()
|
||||||
|
|
||||||
async def async_step_reconfigure(
|
async def async_step_reconfigure(
|
||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: dict[str, Any] | None = None
|
||||||
) -> ConfigFlowResult:
|
) -> ConfigFlowResult:
|
||||||
"""Handle a reconfiguration flow initialized by the user."""
|
"""Handle a reconfiguration flow initialized by the user."""
|
||||||
reconfigure_entry = self._get_reconfigure_entry()
|
self._existing_entry_data = self._get_reconfigure_entry().data
|
||||||
self._existing_entry_data = reconfigure_entry.data
|
|
||||||
self._existing_entry_unique_id = reconfigure_entry.unique_id
|
|
||||||
return await self.async_step_user()
|
return await self.async_step_user()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -50,11 +50,9 @@ class SpotifyFlowHandler(
|
||||||
await self.async_set_unique_id(current_user["id"])
|
await self.async_set_unique_id(current_user["id"])
|
||||||
|
|
||||||
if self.source == SOURCE_REAUTH:
|
if self.source == SOURCE_REAUTH:
|
||||||
reauth_entry = self._get_reauth_entry()
|
self._abort_if_unique_id_mismatch(reason="reauth_account_mismatch")
|
||||||
if reauth_entry.data["id"] != current_user["id"]:
|
|
||||||
return self.async_abort(reason="reauth_account_mismatch")
|
|
||||||
return self.async_update_reload_and_abort(
|
return self.async_update_reload_and_abort(
|
||||||
reauth_entry, title=name, data=data
|
self._get_reauth_entry(), title=name, data=data
|
||||||
)
|
)
|
||||||
return self.async_create_entry(title=name, data=data)
|
return self.async_create_entry(title=name, data=data)
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ from typing import Any
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
|
|
||||||
from homeassistant.config_entries import ConfigEntry, ConfigFlowResult
|
from homeassistant.config_entries import SOURCE_REAUTH, ConfigFlowResult
|
||||||
from homeassistant.helpers import config_entry_oauth2_flow
|
from homeassistant.helpers import config_entry_oauth2_flow
|
||||||
|
|
||||||
from .const import DOMAIN, LOGGER
|
from .const import DOMAIN, LOGGER
|
||||||
|
@ -21,7 +21,6 @@ class OAuth2FlowHandler(
|
||||||
"""Config flow to handle Tesla Fleet API OAuth2 authentication."""
|
"""Config flow to handle Tesla Fleet API OAuth2 authentication."""
|
||||||
|
|
||||||
DOMAIN = DOMAIN
|
DOMAIN = DOMAIN
|
||||||
reauth_entry: ConfigEntry | None = None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def logger(self) -> logging.Logger:
|
def logger(self) -> logging.Logger:
|
||||||
|
@ -50,32 +49,19 @@ class OAuth2FlowHandler(
|
||||||
)
|
)
|
||||||
uid = token["sub"]
|
uid = token["sub"]
|
||||||
|
|
||||||
if not self.reauth_entry:
|
|
||||||
await self.async_set_unique_id(uid)
|
await self.async_set_unique_id(uid)
|
||||||
|
if self.source == SOURCE_REAUTH:
|
||||||
|
self._abort_if_unique_id_mismatch(reason="reauth_account_mismatch")
|
||||||
|
return self.async_update_reload_and_abort(
|
||||||
|
self._get_reauth_entry(), data=data
|
||||||
|
)
|
||||||
self._abort_if_unique_id_configured()
|
self._abort_if_unique_id_configured()
|
||||||
|
|
||||||
return self.async_create_entry(title=uid, data=data)
|
return self.async_create_entry(title=uid, data=data)
|
||||||
|
|
||||||
if self.reauth_entry.unique_id == uid:
|
|
||||||
self.hass.config_entries.async_update_entry(
|
|
||||||
self.reauth_entry,
|
|
||||||
data=data,
|
|
||||||
)
|
|
||||||
await self.hass.config_entries.async_reload(self.reauth_entry.entry_id)
|
|
||||||
return self.async_abort(reason="reauth_successful")
|
|
||||||
|
|
||||||
return self.async_abort(
|
|
||||||
reason="reauth_account_mismatch",
|
|
||||||
description_placeholders={"title": self.reauth_entry.title},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def async_step_reauth(
|
async def async_step_reauth(
|
||||||
self, entry_data: Mapping[str, Any]
|
self, entry_data: Mapping[str, Any]
|
||||||
) -> ConfigFlowResult:
|
) -> ConfigFlowResult:
|
||||||
"""Perform reauth upon an API authentication error."""
|
"""Perform reauth upon an API authentication error."""
|
||||||
self.reauth_entry = self.hass.config_entries.async_get_entry(
|
|
||||||
self.context["entry_id"]
|
|
||||||
)
|
|
||||||
return await self.async_step_reauth_confirm()
|
return await self.async_step_reauth_confirm()
|
||||||
|
|
||||||
async def async_step_reauth_confirm(
|
async def async_step_reauth_confirm(
|
||||||
|
|
|
@ -2432,6 +2432,26 @@ class ConfigFlow(ConfigEntryBaseFlow):
|
||||||
self._async_current_entries(include_ignore=False), match_dict
|
self._async_current_entries(include_ignore=False), match_dict
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _abort_if_unique_id_mismatch(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
reason: str = "unique_id_mismatch",
|
||||||
|
) -> None:
|
||||||
|
"""Abort if the unique ID does not match the reauth/reconfigure context.
|
||||||
|
|
||||||
|
Requires strings.json entry corresponding to the `reason` parameter
|
||||||
|
in user visible flows.
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
self.source == SOURCE_REAUTH
|
||||||
|
and self._get_reauth_entry().unique_id != self.unique_id
|
||||||
|
) or (
|
||||||
|
self.source == SOURCE_RECONFIGURE
|
||||||
|
and self._get_reconfigure_entry().unique_id != self.unique_id
|
||||||
|
):
|
||||||
|
raise data_entry_flow.AbortFlow(reason)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _abort_if_unique_id_configured(
|
def _abort_if_unique_id_configured(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -6677,6 +6677,73 @@ async def test_reauth_helper_alignment(
|
||||||
assert helper_flow_init_data == reauth_flow_init_data
|
assert helper_flow_init_data == reauth_flow_init_data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("original_unique_id", "new_unique_id", "reason"),
|
||||||
|
[
|
||||||
|
("unique", "unique", "success"),
|
||||||
|
(None, None, "success"),
|
||||||
|
("unique", "new", "unique_id_mismatch"),
|
||||||
|
("unique", None, "unique_id_mismatch"),
|
||||||
|
(None, "new", "unique_id_mismatch"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"source",
|
||||||
|
[config_entries.SOURCE_REAUTH, config_entries.SOURCE_RECONFIGURE],
|
||||||
|
)
|
||||||
|
async def test_abort_if_unique_id_mismatch(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
source: str,
|
||||||
|
original_unique_id: str | None,
|
||||||
|
new_unique_id: str | None,
|
||||||
|
reason: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test to check if_unique_id_mismatch behavior."""
|
||||||
|
entry = MockConfigEntry(
|
||||||
|
title="From config flow",
|
||||||
|
domain="test",
|
||||||
|
entry_id="01J915Q6T9F6G5V0QJX6HBC94T",
|
||||||
|
data={"host": "any", "port": 123},
|
||||||
|
unique_id=original_unique_id,
|
||||||
|
)
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
|
||||||
|
mock_setup_entry = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
mock_integration(hass, MockModule("test", async_setup_entry=mock_setup_entry))
|
||||||
|
mock_platform(hass, "test.config_flow", None)
|
||||||
|
|
||||||
|
class TestFlow(config_entries.ConfigFlow):
|
||||||
|
VERSION = 1
|
||||||
|
|
||||||
|
async def async_step_user(self, user_input=None):
|
||||||
|
"""Test user step."""
|
||||||
|
return await self._async_step_confirm()
|
||||||
|
|
||||||
|
async def async_step_reauth(self, entry_data):
|
||||||
|
"""Test reauth step."""
|
||||||
|
return await self._async_step_confirm()
|
||||||
|
|
||||||
|
async def async_step_reconfigure(self, user_input=None):
|
||||||
|
"""Test reauth step."""
|
||||||
|
return await self._async_step_confirm()
|
||||||
|
|
||||||
|
async def _async_step_confirm(self):
|
||||||
|
"""Confirm input."""
|
||||||
|
await self.async_set_unique_id(new_unique_id)
|
||||||
|
self._abort_if_unique_id_mismatch()
|
||||||
|
return self.async_abort(reason="success")
|
||||||
|
|
||||||
|
with mock_config_flow("test", TestFlow):
|
||||||
|
if source == config_entries.SOURCE_REAUTH:
|
||||||
|
result = await entry.start_reauth_flow(hass)
|
||||||
|
elif source == config_entries.SOURCE_RECONFIGURE:
|
||||||
|
result = await entry.start_reconfigure_flow(hass)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert result["type"] is FlowResultType.ABORT
|
||||||
|
assert result["reason"] == reason
|
||||||
|
|
||||||
|
|
||||||
def test_state_not_stored_in_storage() -> None:
|
def test_state_not_stored_in_storage() -> None:
|
||||||
"""Test that state is not stored in storage.
|
"""Test that state is not stored in storage.
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue