Add generic classes BaseFlowHandler and BaseFlowManager (#111814)

* Add generic classes BaseFlowHandler and BaseFlowManager

* Migrate zwave_js

* Update tests

* Update tests

* Address review comments
This commit is contained in:
Erik Montnemery 2024-02-29 16:52:39 +01:00 committed by GitHub
parent 3a8b6412ed
commit a0e558c457
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 341 additions and 273 deletions

View file

@ -242,6 +242,9 @@ UPDATE_ENTRY_CONFIG_ENTRY_ATTRS = {
}
ConfigFlowResult = FlowResult
class ConfigEntry:
"""Hold a configuration entry."""
@ -903,7 +906,7 @@ class ConfigEntry:
@callback
def async_get_active_flows(
self, hass: HomeAssistant, sources: set[str]
) -> Generator[FlowResult, None, None]:
) -> Generator[ConfigFlowResult, None, None]:
"""Get any active flows of certain sources for this entry."""
return (
flow
@ -970,9 +973,11 @@ class FlowCancelledError(Exception):
"""Error to indicate that a flow has been cancelled."""
class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
class ConfigEntriesFlowManager(data_entry_flow.BaseFlowManager[ConfigFlowResult]):
"""Manage all the config entry flows that are in progress."""
_flow_result = ConfigFlowResult
def __init__(
self,
hass: HomeAssistant,
@ -1010,7 +1015,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
async def async_init(
self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Start a configuration flow."""
if not context or "source" not in context:
raise KeyError("Context not set or doesn't have a source set")
@ -1024,7 +1029,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
and await _support_single_config_entry_only(self.hass, handler)
and self.config_entries.async_entries(handler, include_ignore=False)
):
return FlowResult(
return ConfigFlowResult(
type=data_entry_flow.FlowResultType.ABORT,
flow_id=flow_id,
handler=handler,
@ -1065,7 +1070,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
handler: str,
context: dict,
data: Any,
) -> tuple[data_entry_flow.FlowHandler, FlowResult]:
) -> tuple[ConfigFlow, ConfigFlowResult]:
"""Run the init in a task to allow it to be canceled at shutdown."""
flow = await self.async_create_flow(handler, context=context, data=data)
if not flow:
@ -1093,8 +1098,8 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
self._discovery_debouncer.async_shutdown()
async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult
) -> data_entry_flow.FlowResult:
self, flow: data_entry_flow.BaseFlowHandler, result: ConfigFlowResult
) -> ConfigFlowResult:
"""Finish a config flow and add an entry."""
flow = cast(ConfigFlow, flow)
@ -1128,7 +1133,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
and flow.context["source"] != SOURCE_IGNORE
and self.config_entries.async_entries(flow.handler, include_ignore=False)
):
return FlowResult(
return ConfigFlowResult(
type=data_entry_flow.FlowResultType.ABORT,
flow_id=flow.flow_id,
handler=flow.handler,
@ -1213,7 +1218,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
return flow
async def async_post_init(
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult
self, flow: data_entry_flow.BaseFlowHandler, result: ConfigFlowResult
) -> None:
"""After a flow is initialised trigger new flow notifications."""
source = flow.context["source"]
@ -1852,7 +1857,13 @@ def _async_abort_entries_match(
raise data_entry_flow.AbortFlow("already_configured")
class ConfigFlow(data_entry_flow.FlowHandler):
class ConfigEntryBaseFlow(data_entry_flow.BaseFlowHandler[ConfigFlowResult]):
"""Base class for config and option flows."""
_flow_result = ConfigFlowResult
class ConfigFlow(ConfigEntryBaseFlow):
"""Base class for config flows with some helpers."""
def __init_subclass__(cls, *, domain: str | None = None, **kwargs: Any) -> None:
@ -2008,7 +2019,7 @@ class ConfigFlow(data_entry_flow.FlowHandler):
self,
include_uninitialized: bool = False,
match_context: dict[str, Any] | None = None,
) -> list[data_entry_flow.FlowResult]:
) -> list[ConfigFlowResult]:
"""Return other in progress flows for current domain."""
return [
flw
@ -2020,22 +2031,18 @@ class ConfigFlow(data_entry_flow.FlowHandler):
if flw["flow_id"] != self.flow_id
]
async def async_step_ignore(
self, user_input: dict[str, Any]
) -> data_entry_flow.FlowResult:
async def async_step_ignore(self, user_input: dict[str, Any]) -> ConfigFlowResult:
"""Ignore this config flow."""
await self.async_set_unique_id(user_input["unique_id"], raise_on_progress=False)
return self.async_create_entry(title=user_input["title"], data={})
async def async_step_unignore(
self, user_input: dict[str, Any]
) -> data_entry_flow.FlowResult:
async def async_step_unignore(self, user_input: dict[str, Any]) -> ConfigFlowResult:
"""Rediscover a config entry by it's unique_id."""
return self.async_abort(reason="not_implemented")
async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Handle a flow initiated by the user."""
return self.async_abort(reason="not_implemented")
@ -2068,14 +2075,14 @@ class ConfigFlow(data_entry_flow.FlowHandler):
async def _async_step_discovery_without_unique_id(
self,
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Handle a flow initialized by discovery."""
await self._async_handle_discovery_without_unique_id()
return await self.async_step_user()
async def async_step_discovery(
self, discovery_info: DiscoveryInfoType
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Handle a flow initialized by discovery."""
return await self._async_step_discovery_without_unique_id()
@ -2085,7 +2092,7 @@ class ConfigFlow(data_entry_flow.FlowHandler):
*,
reason: str,
description_placeholders: Mapping[str, str] | None = None,
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Abort the config flow."""
# Remove reauth notification if no reauth flows are in progress
if self.source == SOURCE_REAUTH and not any(
@ -2104,55 +2111,53 @@ class ConfigFlow(data_entry_flow.FlowHandler):
async def async_step_bluetooth(
self, discovery_info: BluetoothServiceInfoBleak
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Handle a flow initialized by Bluetooth discovery."""
return await self._async_step_discovery_without_unique_id()
async def async_step_dhcp(
self, discovery_info: DhcpServiceInfo
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Handle a flow initialized by DHCP discovery."""
return await self._async_step_discovery_without_unique_id()
async def async_step_hassio(
self, discovery_info: HassioServiceInfo
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Handle a flow initialized by HASS IO discovery."""
return await self._async_step_discovery_without_unique_id()
async def async_step_integration_discovery(
self, discovery_info: DiscoveryInfoType
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Handle a flow initialized by integration specific discovery."""
return await self._async_step_discovery_without_unique_id()
async def async_step_homekit(
self, discovery_info: ZeroconfServiceInfo
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Handle a flow initialized by Homekit discovery."""
return await self._async_step_discovery_without_unique_id()
async def async_step_mqtt(
self, discovery_info: MqttServiceInfo
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Handle a flow initialized by MQTT discovery."""
return await self._async_step_discovery_without_unique_id()
async def async_step_ssdp(
self, discovery_info: SsdpServiceInfo
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Handle a flow initialized by SSDP discovery."""
return await self._async_step_discovery_without_unique_id()
async def async_step_usb(
self, discovery_info: UsbServiceInfo
) -> data_entry_flow.FlowResult:
async def async_step_usb(self, discovery_info: UsbServiceInfo) -> ConfigFlowResult:
"""Handle a flow initialized by USB discovery."""
return await self._async_step_discovery_without_unique_id()
async def async_step_zeroconf(
self, discovery_info: ZeroconfServiceInfo
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Handle a flow initialized by Zeroconf discovery."""
return await self._async_step_discovery_without_unique_id()
@ -2165,7 +2170,7 @@ class ConfigFlow(data_entry_flow.FlowHandler):
description: str | None = None,
description_placeholders: Mapping[str, str] | None = None,
options: Mapping[str, Any] | None = None,
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Finish config flow and create a config entry."""
result = super().async_create_entry(
title=title,
@ -2175,6 +2180,8 @@ class ConfigFlow(data_entry_flow.FlowHandler):
)
result["options"] = options or {}
result["minor_version"] = self.MINOR_VERSION
result["version"] = self.VERSION
return result
@ -2188,7 +2195,7 @@ class ConfigFlow(data_entry_flow.FlowHandler):
data: Mapping[str, Any] | UndefinedType = UNDEFINED,
options: Mapping[str, Any] | UndefinedType = UNDEFINED,
reason: str = "reauth_successful",
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Update config entry, reload config entry and finish config flow."""
result = self.hass.config_entries.async_update_entry(
entry=entry,
@ -2202,9 +2209,11 @@ class ConfigFlow(data_entry_flow.FlowHandler):
return self.async_abort(reason=reason)
class OptionsFlowManager(data_entry_flow.FlowManager):
class OptionsFlowManager(data_entry_flow.BaseFlowManager[ConfigFlowResult]):
"""Flow to set options for a configuration entry."""
_flow_result = ConfigFlowResult
def _async_get_config_entry(self, config_entry_id: str) -> ConfigEntry:
"""Return config entry or raise if not found."""
entry = self.hass.config_entries.async_get_entry(config_entry_id)
@ -2229,8 +2238,8 @@ class OptionsFlowManager(data_entry_flow.FlowManager):
return handler.async_get_options_flow(entry)
async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult
) -> data_entry_flow.FlowResult:
self, flow: data_entry_flow.BaseFlowHandler, result: ConfigFlowResult
) -> ConfigFlowResult:
"""Finish an options flow and update options for configuration entry.
Flow.handler and entry_id is the same thing to map flow with entry.
@ -2249,7 +2258,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager):
result["result"] = True
return result
async def _async_setup_preview(self, flow: data_entry_flow.FlowHandler) -> None:
async def _async_setup_preview(self, flow: data_entry_flow.BaseFlowHandler) -> None:
"""Set up preview for an option flow handler."""
entry = self._async_get_config_entry(flow.handler)
await _load_integration(self.hass, entry.domain, {})
@ -2258,7 +2267,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager):
await flow.async_setup_preview(self.hass)
class OptionsFlow(data_entry_flow.FlowHandler):
class OptionsFlow(ConfigEntryBaseFlow):
"""Base class for config options flows."""
handler: str