Add generic classes BaseFlowHandler and BaseFlowManager (#111814)

* Add generic classes BaseFlowHandler and BaseFlowManager

* Migrate zwave_js

* Update tests

* Update tests

* Address review comments
This commit is contained in:
Erik Montnemery 2024-02-29 16:52:39 +01:00 committed by GitHub
parent 3a8b6412ed
commit a0e558c457
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 341 additions and 273 deletions

View file

@ -91,6 +91,8 @@ async def auth_manager_from_config(
class AuthManagerFlowManager(data_entry_flow.FlowManager): class AuthManagerFlowManager(data_entry_flow.FlowManager):
"""Manage authentication flows.""" """Manage authentication flows."""
_flow_result = FlowResult
def __init__(self, hass: HomeAssistant, auth_manager: AuthManager) -> None: def __init__(self, hass: HomeAssistant, auth_manager: AuthManager) -> None:
"""Init auth manager flows.""" """Init auth manager flows."""
super().__init__(hass) super().__init__(hass)
@ -110,7 +112,7 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
return await auth_provider.async_login_flow(context) return await auth_provider.async_login_flow(context)
async def async_finish_flow( async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: FlowResult self, flow: data_entry_flow.BaseFlowHandler, result: FlowResult
) -> FlowResult: ) -> FlowResult:
"""Return a user as result of login flow.""" """Return a user as result of login flow."""
flow = cast(LoginFlow, flow) flow = cast(LoginFlow, flow)

View file

@ -96,6 +96,8 @@ class MultiFactorAuthModule:
class SetupFlow(data_entry_flow.FlowHandler): class SetupFlow(data_entry_flow.FlowHandler):
"""Handler for the setup flow.""" """Handler for the setup flow."""
_flow_result = FlowResult
def __init__( def __init__(
self, auth_module: MultiFactorAuthModule, setup_schema: vol.Schema, user_id: str self, auth_module: MultiFactorAuthModule, setup_schema: vol.Schema, user_id: str
) -> None: ) -> None:

View file

@ -184,6 +184,8 @@ async def load_auth_provider_module(
class LoginFlow(data_entry_flow.FlowHandler): class LoginFlow(data_entry_flow.FlowHandler):
"""Handler for the login flow.""" """Handler for the login flow."""
_flow_result = FlowResult
def __init__(self, auth_provider: AuthProvider) -> None: def __init__(self, auth_provider: AuthProvider) -> None:
"""Initialize the login flow.""" """Initialize the login flow."""
self._auth_provider = auth_provider self._auth_provider = auth_provider

View file

@ -38,6 +38,8 @@ _LOGGER = logging.getLogger(__name__)
class MfaFlowManager(data_entry_flow.FlowManager): class MfaFlowManager(data_entry_flow.FlowManager):
"""Manage multi factor authentication flows.""" """Manage multi factor authentication flows."""
_flow_result = data_entry_flow.FlowResult
async def async_create_flow( # type: ignore[override] async def async_create_flow( # type: ignore[override]
self, self,
handler_key: str, handler_key: str,
@ -54,7 +56,7 @@ class MfaFlowManager(data_entry_flow.FlowManager):
return await mfa_module.async_setup_flow(user_id) return await mfa_module.async_setup_flow(user_id)
async def async_finish_flow( async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult self, flow: data_entry_flow.BaseFlowHandler, result: data_entry_flow.FlowResult
) -> data_entry_flow.FlowResult: ) -> data_entry_flow.FlowResult:
"""Complete an mfs setup flow.""" """Complete an mfs setup flow."""
_LOGGER.debug("flow_result: %s", result) _LOGGER.debug("flow_result: %s", result)

View file

@ -48,9 +48,11 @@ class ConfirmRepairFlow(RepairsFlow):
) )
class RepairsFlowManager(data_entry_flow.FlowManager): class RepairsFlowManager(data_entry_flow.BaseFlowManager[data_entry_flow.FlowResult]):
"""Manage repairs flows.""" """Manage repairs flows."""
_flow_result = data_entry_flow.FlowResult
async def async_create_flow( async def async_create_flow(
self, self,
handler_key: str, handler_key: str,
@ -82,7 +84,7 @@ class RepairsFlowManager(data_entry_flow.FlowManager):
return flow return flow
async def async_finish_flow( async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult self, flow: data_entry_flow.BaseFlowHandler, result: data_entry_flow.FlowResult
) -> data_entry_flow.FlowResult: ) -> data_entry_flow.FlowResult:
"""Complete a fix flow.""" """Complete a fix flow."""
if result.get("type") != data_entry_flow.FlowResultType.ABORT: if result.get("type") != data_entry_flow.FlowResultType.ABORT:

View file

@ -7,9 +7,11 @@ from homeassistant import data_entry_flow
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
class RepairsFlow(data_entry_flow.FlowHandler): class RepairsFlow(data_entry_flow.BaseFlowHandler[data_entry_flow.FlowResult]):
"""Handle a flow for fixing an issue.""" """Handle a flow for fixing an issue."""
_flow_result = data_entry_flow.FlowResult
issue_id: str issue_id: str
data: dict[str, str | int | float | None] | None data: dict[str, str | int | float | None] | None

View file

@ -11,7 +11,6 @@ from serial.tools import list_ports
import voluptuous as vol import voluptuous as vol
from zwave_js_server.version import VersionInfo, get_server_version from zwave_js_server.version import VersionInfo, get_server_version
from homeassistant import config_entries, exceptions
from homeassistant.components import usb from homeassistant.components import usb
from homeassistant.components.hassio import ( from homeassistant.components.hassio import (
AddonError, AddonError,
@ -22,14 +21,21 @@ from homeassistant.components.hassio import (
is_hassio, is_hassio,
) )
from homeassistant.components.zeroconf import ZeroconfServiceInfo from homeassistant.components.zeroconf import ZeroconfServiceInfo
from homeassistant.config_entries import (
SOURCE_USB,
ConfigEntriesFlowManager,
ConfigEntry,
ConfigEntryBaseFlow,
ConfigEntryState,
ConfigFlow,
ConfigFlowResult,
OptionsFlow,
OptionsFlowManager,
)
from homeassistant.const import CONF_NAME, CONF_URL from homeassistant.const import CONF_NAME, CONF_URL
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import ( from homeassistant.data_entry_flow import AbortFlow, BaseFlowManager
AbortFlow, from homeassistant.exceptions import HomeAssistantError
FlowHandler,
FlowManager,
FlowResult,
)
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from . import disconnect_client from . import disconnect_client
@ -156,7 +162,7 @@ async def async_get_usb_ports(hass: HomeAssistant) -> dict[str, str]:
return await hass.async_add_executor_job(get_usb_ports) return await hass.async_add_executor_job(get_usb_ports)
class BaseZwaveJSFlow(FlowHandler, ABC): class BaseZwaveJSFlow(ConfigEntryBaseFlow, ABC):
"""Represent the base config flow for Z-Wave JS.""" """Represent the base config flow for Z-Wave JS."""
def __init__(self) -> None: def __init__(self) -> None:
@ -176,12 +182,12 @@ class BaseZwaveJSFlow(FlowHandler, ABC):
@property @property
@abstractmethod @abstractmethod
def flow_manager(self) -> FlowManager: def flow_manager(self) -> BaseFlowManager:
"""Return the flow manager of the flow.""" """Return the flow manager of the flow."""
async def async_step_install_addon( async def async_step_install_addon(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Install Z-Wave JS add-on.""" """Install Z-Wave JS add-on."""
if not self.install_task: if not self.install_task:
self.install_task = self.hass.async_create_task(self._async_install_addon()) self.install_task = self.hass.async_create_task(self._async_install_addon())
@ -207,13 +213,13 @@ class BaseZwaveJSFlow(FlowHandler, ABC):
async def async_step_install_failed( async def async_step_install_failed(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Add-on installation failed.""" """Add-on installation failed."""
return self.async_abort(reason="addon_install_failed") return self.async_abort(reason="addon_install_failed")
async def async_step_start_addon( async def async_step_start_addon(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Start Z-Wave JS add-on.""" """Start Z-Wave JS add-on."""
if not self.start_task: if not self.start_task:
self.start_task = self.hass.async_create_task(self._async_start_addon()) self.start_task = self.hass.async_create_task(self._async_start_addon())
@ -237,7 +243,7 @@ class BaseZwaveJSFlow(FlowHandler, ABC):
async def async_step_start_failed( async def async_step_start_failed(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Add-on start failed.""" """Add-on start failed."""
return self.async_abort(reason="addon_start_failed") return self.async_abort(reason="addon_start_failed")
@ -275,13 +281,13 @@ class BaseZwaveJSFlow(FlowHandler, ABC):
@abstractmethod @abstractmethod
async def async_step_configure_addon( async def async_step_configure_addon(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Ask for config for Z-Wave JS add-on.""" """Ask for config for Z-Wave JS add-on."""
@abstractmethod @abstractmethod
async def async_step_finish_addon_setup( async def async_step_finish_addon_setup(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Prepare info needed to complete the config entry. """Prepare info needed to complete the config entry.
Get add-on discovery info and server version info. Get add-on discovery info and server version info.
@ -325,7 +331,7 @@ class BaseZwaveJSFlow(FlowHandler, ABC):
return discovery_info_config return discovery_info_config
class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN): class ZWaveJSConfigFlow(BaseZwaveJSFlow, ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Z-Wave JS.""" """Handle a config flow for Z-Wave JS."""
VERSION = 1 VERSION = 1
@ -338,19 +344,19 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
self._usb_discovery = False self._usb_discovery = False
@property @property
def flow_manager(self) -> config_entries.ConfigEntriesFlowManager: def flow_manager(self) -> ConfigEntriesFlowManager:
"""Return the correct flow manager.""" """Return the correct flow manager."""
return self.hass.config_entries.flow return self.hass.config_entries.flow
@staticmethod @staticmethod
@callback @callback
def async_get_options_flow( def async_get_options_flow(
config_entry: config_entries.ConfigEntry, config_entry: ConfigEntry,
) -> OptionsFlowHandler: ) -> OptionsFlowHandler:
"""Return the options flow.""" """Return the options flow."""
return OptionsFlowHandler(config_entry) return OptionsFlowHandler(config_entry)
async def async_step_import(self, data: dict[str, Any]) -> FlowResult: async def async_step_import(self, data: dict[str, Any]) -> ConfigFlowResult:
"""Handle imported data. """Handle imported data.
This step will be used when importing data This step will be used when importing data
@ -364,7 +370,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
async def async_step_user( async def async_step_user(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Handle the initial step.""" """Handle the initial step."""
if is_hassio(self.hass): if is_hassio(self.hass):
return await self.async_step_on_supervisor() return await self.async_step_on_supervisor()
@ -373,7 +379,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
async def async_step_zeroconf( async def async_step_zeroconf(
self, discovery_info: ZeroconfServiceInfo self, discovery_info: ZeroconfServiceInfo
) -> FlowResult: ) -> ConfigFlowResult:
"""Handle zeroconf discovery.""" """Handle zeroconf discovery."""
home_id = str(discovery_info.properties["homeId"]) home_id = str(discovery_info.properties["homeId"])
await self.async_set_unique_id(home_id) await self.async_set_unique_id(home_id)
@ -384,7 +390,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
async def async_step_zeroconf_confirm( async def async_step_zeroconf_confirm(
self, user_input: dict | None = None self, user_input: dict | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Confirm the setup.""" """Confirm the setup."""
if user_input is not None: if user_input is not None:
return await self.async_step_manual({CONF_URL: self.ws_address}) return await self.async_step_manual({CONF_URL: self.ws_address})
@ -398,7 +404,9 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
}, },
) )
async def async_step_usb(self, discovery_info: usb.UsbServiceInfo) -> FlowResult: async def async_step_usb(
self, discovery_info: usb.UsbServiceInfo
) -> ConfigFlowResult:
"""Handle USB Discovery.""" """Handle USB Discovery."""
if not is_hassio(self.hass): if not is_hassio(self.hass):
return self.async_abort(reason="discovery_requires_supervisor") return self.async_abort(reason="discovery_requires_supervisor")
@ -441,7 +449,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
async def async_step_usb_confirm( async def async_step_usb_confirm(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Handle USB Discovery confirmation.""" """Handle USB Discovery confirmation."""
if user_input is None: if user_input is None:
return self.async_show_form( return self.async_show_form(
@ -455,7 +463,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
async def async_step_manual( async def async_step_manual(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Handle a manual configuration.""" """Handle a manual configuration."""
if user_input is None: if user_input is None:
return self.async_show_form( return self.async_show_form(
@ -491,7 +499,9 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
step_id="manual", data_schema=get_manual_schema(user_input), errors=errors step_id="manual", data_schema=get_manual_schema(user_input), errors=errors
) )
async def async_step_hassio(self, discovery_info: HassioServiceInfo) -> FlowResult: async def async_step_hassio(
self, discovery_info: HassioServiceInfo
) -> ConfigFlowResult:
"""Receive configuration from add-on discovery info. """Receive configuration from add-on discovery info.
This flow is triggered by the Z-Wave JS add-on. This flow is triggered by the Z-Wave JS add-on.
@ -517,7 +527,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
async def async_step_hassio_confirm( async def async_step_hassio_confirm(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Confirm the add-on discovery.""" """Confirm the add-on discovery."""
if user_input is not None: if user_input is not None:
return await self.async_step_on_supervisor( return await self.async_step_on_supervisor(
@ -528,7 +538,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
async def async_step_on_supervisor( async def async_step_on_supervisor(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Handle logic when on Supervisor host.""" """Handle logic when on Supervisor host."""
if user_input is None: if user_input is None:
return self.async_show_form( return self.async_show_form(
@ -563,7 +573,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
async def async_step_configure_addon( async def async_step_configure_addon(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Ask for config for Z-Wave JS add-on.""" """Ask for config for Z-Wave JS add-on."""
addon_info = await self._async_get_addon_info() addon_info = await self._async_get_addon_info()
addon_config = addon_info.options addon_config = addon_info.options
@ -628,7 +638,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
async def async_step_finish_addon_setup( async def async_step_finish_addon_setup(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Prepare info needed to complete the config entry. """Prepare info needed to complete the config entry.
Get add-on discovery info and server version info. Get add-on discovery info and server version info.
@ -638,7 +648,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
discovery_info = await self._async_get_addon_discovery_info() discovery_info = await self._async_get_addon_discovery_info()
self.ws_address = f"ws://{discovery_info['host']}:{discovery_info['port']}" self.ws_address = f"ws://{discovery_info['host']}:{discovery_info['port']}"
if not self.unique_id or self.context["source"] == config_entries.SOURCE_USB: if not self.unique_id or self.context["source"] == SOURCE_USB:
if not self.version_info: if not self.version_info:
try: try:
self.version_info = await async_get_version_info( self.version_info = await async_get_version_info(
@ -664,7 +674,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
return self._async_create_entry_from_vars() return self._async_create_entry_from_vars()
@callback @callback
def _async_create_entry_from_vars(self) -> FlowResult: def _async_create_entry_from_vars(self) -> ConfigFlowResult:
"""Return a config entry for the flow.""" """Return a config entry for the flow."""
# Abort any other flows that may be in progress # Abort any other flows that may be in progress
for progress in self._async_in_progress(): for progress in self._async_in_progress():
@ -685,10 +695,10 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
) )
class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow): class OptionsFlowHandler(BaseZwaveJSFlow, OptionsFlow):
"""Handle an options flow for Z-Wave JS.""" """Handle an options flow for Z-Wave JS."""
def __init__(self, config_entry: config_entries.ConfigEntry) -> None: def __init__(self, config_entry: ConfigEntry) -> None:
"""Set up the options flow.""" """Set up the options flow."""
super().__init__() super().__init__()
self.config_entry = config_entry self.config_entry = config_entry
@ -696,7 +706,7 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow):
self.revert_reason: str | None = None self.revert_reason: str | None = None
@property @property
def flow_manager(self) -> config_entries.OptionsFlowManager: def flow_manager(self) -> OptionsFlowManager:
"""Return the correct flow manager.""" """Return the correct flow manager."""
return self.hass.config_entries.options return self.hass.config_entries.options
@ -707,7 +717,7 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow):
async def async_step_init( async def async_step_init(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Manage the options.""" """Manage the options."""
if is_hassio(self.hass): if is_hassio(self.hass):
return await self.async_step_on_supervisor() return await self.async_step_on_supervisor()
@ -716,7 +726,7 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow):
async def async_step_manual( async def async_step_manual(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Handle a manual configuration.""" """Handle a manual configuration."""
if user_input is None: if user_input is None:
return self.async_show_form( return self.async_show_form(
@ -759,7 +769,7 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow):
async def async_step_on_supervisor( async def async_step_on_supervisor(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Handle logic when on Supervisor host.""" """Handle logic when on Supervisor host."""
if user_input is None: if user_input is None:
return self.async_show_form( return self.async_show_form(
@ -780,7 +790,7 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow):
async def async_step_configure_addon( async def async_step_configure_addon(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Ask for config for Z-Wave JS add-on.""" """Ask for config for Z-Wave JS add-on."""
addon_info = await self._async_get_addon_info() addon_info = await self._async_get_addon_info()
addon_config = addon_info.options addon_config = addon_info.options
@ -819,7 +829,7 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow):
if ( if (
self.config_entry.data.get(CONF_USE_ADDON) self.config_entry.data.get(CONF_USE_ADDON)
and self.config_entry.state == config_entries.ConfigEntryState.LOADED and self.config_entry.state == ConfigEntryState.LOADED
): ):
# Disconnect integration before restarting add-on. # Disconnect integration before restarting add-on.
await disconnect_client(self.hass, self.config_entry) await disconnect_client(self.hass, self.config_entry)
@ -868,13 +878,13 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow):
async def async_step_start_failed( async def async_step_start_failed(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Add-on start failed.""" """Add-on start failed."""
return await self.async_revert_addon_config(reason="addon_start_failed") return await self.async_revert_addon_config(reason="addon_start_failed")
async def async_step_finish_addon_setup( async def async_step_finish_addon_setup(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Prepare info needed to complete the config entry update. """Prepare info needed to complete the config entry update.
Get add-on discovery info and server version info. Get add-on discovery info and server version info.
@ -918,7 +928,7 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow):
self.hass.config_entries.async_schedule_reload(self.config_entry.entry_id) self.hass.config_entries.async_schedule_reload(self.config_entry.entry_id)
return self.async_create_entry(title=TITLE, data={}) return self.async_create_entry(title=TITLE, data={})
async def async_revert_addon_config(self, reason: str) -> FlowResult: async def async_revert_addon_config(self, reason: str) -> ConfigFlowResult:
"""Abort the options flow. """Abort the options flow.
If the add-on options have been changed, revert those and restart add-on. If the add-on options have been changed, revert those and restart add-on.
@ -944,11 +954,11 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow):
return await self.async_step_configure_addon(addon_config_input) return await self.async_step_configure_addon(addon_config_input)
class CannotConnect(exceptions.HomeAssistantError): class CannotConnect(HomeAssistantError):
"""Indicate connection error.""" """Indicate connection error."""
class InvalidInput(exceptions.HomeAssistantError): class InvalidInput(HomeAssistantError):
"""Error to indicate input data is invalid.""" """Error to indicate input data is invalid."""
def __init__(self, error: str) -> None: def __init__(self, error: str) -> None:

View file

@ -242,6 +242,9 @@ UPDATE_ENTRY_CONFIG_ENTRY_ATTRS = {
} }
ConfigFlowResult = FlowResult
class ConfigEntry: class ConfigEntry:
"""Hold a configuration entry.""" """Hold a configuration entry."""
@ -903,7 +906,7 @@ class ConfigEntry:
@callback @callback
def async_get_active_flows( def async_get_active_flows(
self, hass: HomeAssistant, sources: set[str] self, hass: HomeAssistant, sources: set[str]
) -> Generator[FlowResult, None, None]: ) -> Generator[ConfigFlowResult, None, None]:
"""Get any active flows of certain sources for this entry.""" """Get any active flows of certain sources for this entry."""
return ( return (
flow flow
@ -970,9 +973,11 @@ class FlowCancelledError(Exception):
"""Error to indicate that a flow has been cancelled.""" """Error to indicate that a flow has been cancelled."""
class ConfigEntriesFlowManager(data_entry_flow.FlowManager): class ConfigEntriesFlowManager(data_entry_flow.BaseFlowManager[ConfigFlowResult]):
"""Manage all the config entry flows that are in progress.""" """Manage all the config entry flows that are in progress."""
_flow_result = ConfigFlowResult
def __init__( def __init__(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
@ -1010,7 +1015,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
async def async_init( async def async_init(
self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Start a configuration flow.""" """Start a configuration flow."""
if not context or "source" not in context: if not context or "source" not in context:
raise KeyError("Context not set or doesn't have a source set") raise KeyError("Context not set or doesn't have a source set")
@ -1024,7 +1029,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
and await _support_single_config_entry_only(self.hass, handler) and await _support_single_config_entry_only(self.hass, handler)
and self.config_entries.async_entries(handler, include_ignore=False) and self.config_entries.async_entries(handler, include_ignore=False)
): ):
return FlowResult( return ConfigFlowResult(
type=data_entry_flow.FlowResultType.ABORT, type=data_entry_flow.FlowResultType.ABORT,
flow_id=flow_id, flow_id=flow_id,
handler=handler, handler=handler,
@ -1065,7 +1070,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
handler: str, handler: str,
context: dict, context: dict,
data: Any, data: Any,
) -> tuple[data_entry_flow.FlowHandler, FlowResult]: ) -> tuple[ConfigFlow, ConfigFlowResult]:
"""Run the init in a task to allow it to be canceled at shutdown.""" """Run the init in a task to allow it to be canceled at shutdown."""
flow = await self.async_create_flow(handler, context=context, data=data) flow = await self.async_create_flow(handler, context=context, data=data)
if not flow: if not flow:
@ -1093,8 +1098,8 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
self._discovery_debouncer.async_shutdown() self._discovery_debouncer.async_shutdown()
async def async_finish_flow( async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult self, flow: data_entry_flow.BaseFlowHandler, result: ConfigFlowResult
) -> data_entry_flow.FlowResult: ) -> ConfigFlowResult:
"""Finish a config flow and add an entry.""" """Finish a config flow and add an entry."""
flow = cast(ConfigFlow, flow) flow = cast(ConfigFlow, flow)
@ -1128,7 +1133,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
and flow.context["source"] != SOURCE_IGNORE and flow.context["source"] != SOURCE_IGNORE
and self.config_entries.async_entries(flow.handler, include_ignore=False) and self.config_entries.async_entries(flow.handler, include_ignore=False)
): ):
return FlowResult( return ConfigFlowResult(
type=data_entry_flow.FlowResultType.ABORT, type=data_entry_flow.FlowResultType.ABORT,
flow_id=flow.flow_id, flow_id=flow.flow_id,
handler=flow.handler, handler=flow.handler,
@ -1213,7 +1218,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
return flow return flow
async def async_post_init( async def async_post_init(
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult self, flow: data_entry_flow.BaseFlowHandler, result: ConfigFlowResult
) -> None: ) -> None:
"""After a flow is initialised trigger new flow notifications.""" """After a flow is initialised trigger new flow notifications."""
source = flow.context["source"] source = flow.context["source"]
@ -1852,7 +1857,13 @@ def _async_abort_entries_match(
raise data_entry_flow.AbortFlow("already_configured") raise data_entry_flow.AbortFlow("already_configured")
class ConfigFlow(data_entry_flow.FlowHandler): class ConfigEntryBaseFlow(data_entry_flow.BaseFlowHandler[ConfigFlowResult]):
"""Base class for config and option flows."""
_flow_result = ConfigFlowResult
class ConfigFlow(ConfigEntryBaseFlow):
"""Base class for config flows with some helpers.""" """Base class for config flows with some helpers."""
def __init_subclass__(cls, *, domain: str | None = None, **kwargs: Any) -> None: def __init_subclass__(cls, *, domain: str | None = None, **kwargs: Any) -> None:
@ -2008,7 +2019,7 @@ class ConfigFlow(data_entry_flow.FlowHandler):
self, self,
include_uninitialized: bool = False, include_uninitialized: bool = False,
match_context: dict[str, Any] | None = None, match_context: dict[str, Any] | None = None,
) -> list[data_entry_flow.FlowResult]: ) -> list[ConfigFlowResult]:
"""Return other in progress flows for current domain.""" """Return other in progress flows for current domain."""
return [ return [
flw flw
@ -2020,22 +2031,18 @@ class ConfigFlow(data_entry_flow.FlowHandler):
if flw["flow_id"] != self.flow_id if flw["flow_id"] != self.flow_id
] ]
async def async_step_ignore( async def async_step_ignore(self, user_input: dict[str, Any]) -> ConfigFlowResult:
self, user_input: dict[str, Any]
) -> data_entry_flow.FlowResult:
"""Ignore this config flow.""" """Ignore this config flow."""
await self.async_set_unique_id(user_input["unique_id"], raise_on_progress=False) await self.async_set_unique_id(user_input["unique_id"], raise_on_progress=False)
return self.async_create_entry(title=user_input["title"], data={}) return self.async_create_entry(title=user_input["title"], data={})
async def async_step_unignore( async def async_step_unignore(self, user_input: dict[str, Any]) -> ConfigFlowResult:
self, user_input: dict[str, Any]
) -> data_entry_flow.FlowResult:
"""Rediscover a config entry by it's unique_id.""" """Rediscover a config entry by it's unique_id."""
return self.async_abort(reason="not_implemented") return self.async_abort(reason="not_implemented")
async def async_step_user( async def async_step_user(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> data_entry_flow.FlowResult: ) -> ConfigFlowResult:
"""Handle a flow initiated by the user.""" """Handle a flow initiated by the user."""
return self.async_abort(reason="not_implemented") return self.async_abort(reason="not_implemented")
@ -2068,14 +2075,14 @@ class ConfigFlow(data_entry_flow.FlowHandler):
async def _async_step_discovery_without_unique_id( async def _async_step_discovery_without_unique_id(
self, self,
) -> data_entry_flow.FlowResult: ) -> ConfigFlowResult:
"""Handle a flow initialized by discovery.""" """Handle a flow initialized by discovery."""
await self._async_handle_discovery_without_unique_id() await self._async_handle_discovery_without_unique_id()
return await self.async_step_user() return await self.async_step_user()
async def async_step_discovery( async def async_step_discovery(
self, discovery_info: DiscoveryInfoType self, discovery_info: DiscoveryInfoType
) -> data_entry_flow.FlowResult: ) -> ConfigFlowResult:
"""Handle a flow initialized by discovery.""" """Handle a flow initialized by discovery."""
return await self._async_step_discovery_without_unique_id() return await self._async_step_discovery_without_unique_id()
@ -2085,7 +2092,7 @@ class ConfigFlow(data_entry_flow.FlowHandler):
*, *,
reason: str, reason: str,
description_placeholders: Mapping[str, str] | None = None, description_placeholders: Mapping[str, str] | None = None,
) -> data_entry_flow.FlowResult: ) -> ConfigFlowResult:
"""Abort the config flow.""" """Abort the config flow."""
# Remove reauth notification if no reauth flows are in progress # Remove reauth notification if no reauth flows are in progress
if self.source == SOURCE_REAUTH and not any( if self.source == SOURCE_REAUTH and not any(
@ -2104,55 +2111,53 @@ class ConfigFlow(data_entry_flow.FlowHandler):
async def async_step_bluetooth( async def async_step_bluetooth(
self, discovery_info: BluetoothServiceInfoBleak self, discovery_info: BluetoothServiceInfoBleak
) -> data_entry_flow.FlowResult: ) -> ConfigFlowResult:
"""Handle a flow initialized by Bluetooth discovery.""" """Handle a flow initialized by Bluetooth discovery."""
return await self._async_step_discovery_without_unique_id() return await self._async_step_discovery_without_unique_id()
async def async_step_dhcp( async def async_step_dhcp(
self, discovery_info: DhcpServiceInfo self, discovery_info: DhcpServiceInfo
) -> data_entry_flow.FlowResult: ) -> ConfigFlowResult:
"""Handle a flow initialized by DHCP discovery.""" """Handle a flow initialized by DHCP discovery."""
return await self._async_step_discovery_without_unique_id() return await self._async_step_discovery_without_unique_id()
async def async_step_hassio( async def async_step_hassio(
self, discovery_info: HassioServiceInfo self, discovery_info: HassioServiceInfo
) -> data_entry_flow.FlowResult: ) -> ConfigFlowResult:
"""Handle a flow initialized by HASS IO discovery.""" """Handle a flow initialized by HASS IO discovery."""
return await self._async_step_discovery_without_unique_id() return await self._async_step_discovery_without_unique_id()
async def async_step_integration_discovery( async def async_step_integration_discovery(
self, discovery_info: DiscoveryInfoType self, discovery_info: DiscoveryInfoType
) -> data_entry_flow.FlowResult: ) -> ConfigFlowResult:
"""Handle a flow initialized by integration specific discovery.""" """Handle a flow initialized by integration specific discovery."""
return await self._async_step_discovery_without_unique_id() return await self._async_step_discovery_without_unique_id()
async def async_step_homekit( async def async_step_homekit(
self, discovery_info: ZeroconfServiceInfo self, discovery_info: ZeroconfServiceInfo
) -> data_entry_flow.FlowResult: ) -> ConfigFlowResult:
"""Handle a flow initialized by Homekit discovery.""" """Handle a flow initialized by Homekit discovery."""
return await self._async_step_discovery_without_unique_id() return await self._async_step_discovery_without_unique_id()
async def async_step_mqtt( async def async_step_mqtt(
self, discovery_info: MqttServiceInfo self, discovery_info: MqttServiceInfo
) -> data_entry_flow.FlowResult: ) -> ConfigFlowResult:
"""Handle a flow initialized by MQTT discovery.""" """Handle a flow initialized by MQTT discovery."""
return await self._async_step_discovery_without_unique_id() return await self._async_step_discovery_without_unique_id()
async def async_step_ssdp( async def async_step_ssdp(
self, discovery_info: SsdpServiceInfo self, discovery_info: SsdpServiceInfo
) -> data_entry_flow.FlowResult: ) -> ConfigFlowResult:
"""Handle a flow initialized by SSDP discovery.""" """Handle a flow initialized by SSDP discovery."""
return await self._async_step_discovery_without_unique_id() return await self._async_step_discovery_without_unique_id()
async def async_step_usb( async def async_step_usb(self, discovery_info: UsbServiceInfo) -> ConfigFlowResult:
self, discovery_info: UsbServiceInfo
) -> data_entry_flow.FlowResult:
"""Handle a flow initialized by USB discovery.""" """Handle a flow initialized by USB discovery."""
return await self._async_step_discovery_without_unique_id() return await self._async_step_discovery_without_unique_id()
async def async_step_zeroconf( async def async_step_zeroconf(
self, discovery_info: ZeroconfServiceInfo self, discovery_info: ZeroconfServiceInfo
) -> data_entry_flow.FlowResult: ) -> ConfigFlowResult:
"""Handle a flow initialized by Zeroconf discovery.""" """Handle a flow initialized by Zeroconf discovery."""
return await self._async_step_discovery_without_unique_id() return await self._async_step_discovery_without_unique_id()
@ -2165,7 +2170,7 @@ class ConfigFlow(data_entry_flow.FlowHandler):
description: str | None = None, description: str | None = None,
description_placeholders: Mapping[str, str] | None = None, description_placeholders: Mapping[str, str] | None = None,
options: Mapping[str, Any] | None = None, options: Mapping[str, Any] | None = None,
) -> data_entry_flow.FlowResult: ) -> ConfigFlowResult:
"""Finish config flow and create a config entry.""" """Finish config flow and create a config entry."""
result = super().async_create_entry( result = super().async_create_entry(
title=title, title=title,
@ -2175,6 +2180,8 @@ class ConfigFlow(data_entry_flow.FlowHandler):
) )
result["options"] = options or {} result["options"] = options or {}
result["minor_version"] = self.MINOR_VERSION
result["version"] = self.VERSION
return result return result
@ -2188,7 +2195,7 @@ class ConfigFlow(data_entry_flow.FlowHandler):
data: Mapping[str, Any] | UndefinedType = UNDEFINED, data: Mapping[str, Any] | UndefinedType = UNDEFINED,
options: Mapping[str, Any] | UndefinedType = UNDEFINED, options: Mapping[str, Any] | UndefinedType = UNDEFINED,
reason: str = "reauth_successful", reason: str = "reauth_successful",
) -> data_entry_flow.FlowResult: ) -> ConfigFlowResult:
"""Update config entry, reload config entry and finish config flow.""" """Update config entry, reload config entry and finish config flow."""
result = self.hass.config_entries.async_update_entry( result = self.hass.config_entries.async_update_entry(
entry=entry, entry=entry,
@ -2202,9 +2209,11 @@ class ConfigFlow(data_entry_flow.FlowHandler):
return self.async_abort(reason=reason) return self.async_abort(reason=reason)
class OptionsFlowManager(data_entry_flow.FlowManager): class OptionsFlowManager(data_entry_flow.BaseFlowManager[ConfigFlowResult]):
"""Flow to set options for a configuration entry.""" """Flow to set options for a configuration entry."""
_flow_result = ConfigFlowResult
def _async_get_config_entry(self, config_entry_id: str) -> ConfigEntry: def _async_get_config_entry(self, config_entry_id: str) -> ConfigEntry:
"""Return config entry or raise if not found.""" """Return config entry or raise if not found."""
entry = self.hass.config_entries.async_get_entry(config_entry_id) entry = self.hass.config_entries.async_get_entry(config_entry_id)
@ -2229,8 +2238,8 @@ class OptionsFlowManager(data_entry_flow.FlowManager):
return handler.async_get_options_flow(entry) return handler.async_get_options_flow(entry)
async def async_finish_flow( async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult self, flow: data_entry_flow.BaseFlowHandler, result: ConfigFlowResult
) -> data_entry_flow.FlowResult: ) -> ConfigFlowResult:
"""Finish an options flow and update options for configuration entry. """Finish an options flow and update options for configuration entry.
Flow.handler and entry_id is the same thing to map flow with entry. Flow.handler and entry_id is the same thing to map flow with entry.
@ -2249,7 +2258,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager):
result["result"] = True result["result"] = True
return result return result
async def _async_setup_preview(self, flow: data_entry_flow.FlowHandler) -> None: async def _async_setup_preview(self, flow: data_entry_flow.BaseFlowHandler) -> None:
"""Set up preview for an option flow handler.""" """Set up preview for an option flow handler."""
entry = self._async_get_config_entry(flow.handler) entry = self._async_get_config_entry(flow.handler)
await _load_integration(self.hass, entry.domain, {}) await _load_integration(self.hass, entry.domain, {})
@ -2258,7 +2267,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager):
await flow.async_setup_preview(self.hass) await flow.async_setup_preview(self.hass)
class OptionsFlow(data_entry_flow.FlowHandler): class OptionsFlow(ConfigEntryBaseFlow):
"""Base class for config options flows.""" """Base class for config options flows."""
handler: str handler: str

View file

@ -11,7 +11,7 @@ from enum import StrEnum
from functools import partial from functools import partial
import logging import logging
from types import MappingProxyType from types import MappingProxyType
from typing import Any, Required, TypedDict from typing import Any, Generic, Required, TypedDict, TypeVar
import voluptuous as vol import voluptuous as vol
@ -75,6 +75,7 @@ FLOW_NOT_COMPLETE_STEPS = {
FlowResultType.MENU, FlowResultType.MENU,
} }
STEP_ID_OPTIONAL_STEPS = { STEP_ID_OPTIONAL_STEPS = {
FlowResultType.EXTERNAL_STEP, FlowResultType.EXTERNAL_STEP,
FlowResultType.FORM, FlowResultType.FORM,
@ -83,6 +84,9 @@ STEP_ID_OPTIONAL_STEPS = {
} }
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult")
@dataclass(slots=True) @dataclass(slots=True)
class BaseServiceInfo: class BaseServiceInfo:
"""Base class for discovery ServiceInfo.""" """Base class for discovery ServiceInfo."""
@ -163,26 +167,6 @@ class FlowResult(TypedDict, total=False):
version: int version: int
@callback
def _async_flow_handler_to_flow_result(
flows: Iterable[FlowHandler], include_uninitialized: bool
) -> list[FlowResult]:
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
results = []
for flow in flows:
if not include_uninitialized and flow.cur_step is None:
continue
result = FlowResult(
flow_id=flow.flow_id,
handler=flow.handler,
context=flow.context,
)
if flow.cur_step:
result["step_id"] = flow.cur_step["step_id"]
results.append(result)
return results
def _map_error_to_schema_errors( def _map_error_to_schema_errors(
schema_errors: dict[str, Any], schema_errors: dict[str, Any],
error: vol.Invalid, error: vol.Invalid,
@ -206,9 +190,11 @@ def _map_error_to_schema_errors(
schema_errors[path_part_str] = error.error_message schema_errors[path_part_str] = error.error_message
class FlowManager(abc.ABC): class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
"""Manage all the flows that are in progress.""" """Manage all the flows that are in progress."""
_flow_result: Callable[..., _FlowResultT]
def __init__( def __init__(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
@ -216,9 +202,9 @@ class FlowManager(abc.ABC):
"""Initialize the flow manager.""" """Initialize the flow manager."""
self.hass = hass self.hass = hass
self._preview: set[str] = set() self._preview: set[str] = set()
self._progress: dict[str, FlowHandler] = {} self._progress: dict[str, BaseFlowHandler] = {}
self._handler_progress_index: dict[str, set[FlowHandler]] = {} self._handler_progress_index: dict[str, set[BaseFlowHandler]] = {}
self._init_data_process_index: dict[type, set[FlowHandler]] = {} self._init_data_process_index: dict[type, set[BaseFlowHandler]] = {}
@abc.abstractmethod @abc.abstractmethod
async def async_create_flow( async def async_create_flow(
@ -227,7 +213,7 @@ class FlowManager(abc.ABC):
*, *,
context: dict[str, Any] | None = None, context: dict[str, Any] | None = None,
data: dict[str, Any] | None = None, data: dict[str, Any] | None = None,
) -> FlowHandler: ) -> BaseFlowHandler[_FlowResultT]:
"""Create a flow for specified handler. """Create a flow for specified handler.
Handler key is the domain of the component that we want to set up. Handler key is the domain of the component that we want to set up.
@ -235,11 +221,13 @@ class FlowManager(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
async def async_finish_flow( async def async_finish_flow(
self, flow: FlowHandler, result: FlowResult self, flow: BaseFlowHandler, result: _FlowResultT
) -> FlowResult: ) -> _FlowResultT:
"""Finish a data entry flow.""" """Finish a data entry flow."""
async def async_post_init(self, flow: FlowHandler, result: FlowResult) -> None: async def async_post_init(
self, flow: BaseFlowHandler, result: _FlowResultT
) -> None:
"""Entry has finished executing its first step asynchronously.""" """Entry has finished executing its first step asynchronously."""
@callback @callback
@ -262,16 +250,16 @@ class FlowManager(abc.ABC):
return False return False
@callback @callback
def async_get(self, flow_id: str) -> FlowResult: def async_get(self, flow_id: str) -> _FlowResultT:
"""Return a flow in progress as a partial FlowResult.""" """Return a flow in progress as a partial FlowResult."""
if (flow := self._progress.get(flow_id)) is None: if (flow := self._progress.get(flow_id)) is None:
raise UnknownFlow raise UnknownFlow
return _async_flow_handler_to_flow_result([flow], False)[0] return self._async_flow_handler_to_flow_result([flow], False)[0]
@callback @callback
def async_progress(self, include_uninitialized: bool = False) -> list[FlowResult]: def async_progress(self, include_uninitialized: bool = False) -> list[_FlowResultT]:
"""Return the flows in progress as a partial FlowResult.""" """Return the flows in progress as a partial FlowResult."""
return _async_flow_handler_to_flow_result( return self._async_flow_handler_to_flow_result(
self._progress.values(), include_uninitialized self._progress.values(), include_uninitialized
) )
@ -281,13 +269,13 @@ class FlowManager(abc.ABC):
handler: str, handler: str,
include_uninitialized: bool = False, include_uninitialized: bool = False,
match_context: dict[str, Any] | None = None, match_context: dict[str, Any] | None = None,
) -> list[FlowResult]: ) -> list[_FlowResultT]:
"""Return the flows in progress by handler as a partial FlowResult. """Return the flows in progress by handler as a partial FlowResult.
If match_context is specified, only return flows with a context that If match_context is specified, only return flows with a context that
is a superset of match_context. is a superset of match_context.
""" """
return _async_flow_handler_to_flow_result( return self._async_flow_handler_to_flow_result(
self._async_progress_by_handler(handler, match_context), self._async_progress_by_handler(handler, match_context),
include_uninitialized, include_uninitialized,
) )
@ -298,9 +286,9 @@ class FlowManager(abc.ABC):
init_data_type: type, init_data_type: type,
matcher: Callable[[Any], bool], matcher: Callable[[Any], bool],
include_uninitialized: bool = False, include_uninitialized: bool = False,
) -> list[FlowResult]: ) -> list[_FlowResultT]:
"""Return flows in progress init matching by data type as a partial FlowResult.""" """Return flows in progress init matching by data type as a partial FlowResult."""
return _async_flow_handler_to_flow_result( return self._async_flow_handler_to_flow_result(
( (
progress progress
for progress in self._init_data_process_index.get(init_data_type, set()) for progress in self._init_data_process_index.get(init_data_type, set())
@ -312,7 +300,7 @@ class FlowManager(abc.ABC):
@callback @callback
def _async_progress_by_handler( def _async_progress_by_handler(
self, handler: str, match_context: dict[str, Any] | None self, handler: str, match_context: dict[str, Any] | None
) -> list[FlowHandler]: ) -> list[BaseFlowHandler[_FlowResultT]]:
"""Return the flows in progress by handler. """Return the flows in progress by handler.
If match_context is specified, only return flows with a context that If match_context is specified, only return flows with a context that
@ -329,7 +317,7 @@ class FlowManager(abc.ABC):
async def async_init( async def async_init(
self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None
) -> FlowResult: ) -> _FlowResultT:
"""Start a data entry flow.""" """Start a data entry flow."""
if context is None: if context is None:
context = {} context = {}
@ -352,9 +340,9 @@ class FlowManager(abc.ABC):
async def async_configure( async def async_configure(
self, flow_id: str, user_input: dict | None = None self, flow_id: str, user_input: dict | None = None
) -> FlowResult: ) -> _FlowResultT:
"""Continue a data entry flow.""" """Continue a data entry flow."""
result: FlowResult | None = None result: _FlowResultT | None = None
while not result or result["type"] == FlowResultType.SHOW_PROGRESS_DONE: while not result or result["type"] == FlowResultType.SHOW_PROGRESS_DONE:
result = await self._async_configure(flow_id, user_input) result = await self._async_configure(flow_id, user_input)
flow = self._progress.get(flow_id) flow = self._progress.get(flow_id)
@ -364,7 +352,7 @@ class FlowManager(abc.ABC):
async def _async_configure( async def _async_configure(
self, flow_id: str, user_input: dict | None = None self, flow_id: str, user_input: dict | None = None
) -> FlowResult: ) -> _FlowResultT:
"""Continue a data entry flow.""" """Continue a data entry flow."""
if (flow := self._progress.get(flow_id)) is None: if (flow := self._progress.get(flow_id)) is None:
raise UnknownFlow raise UnknownFlow
@ -458,7 +446,7 @@ class FlowManager(abc.ABC):
self._async_remove_flow_progress(flow_id) self._async_remove_flow_progress(flow_id)
@callback @callback
def _async_add_flow_progress(self, flow: FlowHandler) -> None: def _async_add_flow_progress(self, flow: BaseFlowHandler[_FlowResultT]) -> None:
"""Add a flow to in progress.""" """Add a flow to in progress."""
if flow.init_data is not None: if flow.init_data is not None:
init_data_type = type(flow.init_data) init_data_type = type(flow.init_data)
@ -467,7 +455,9 @@ class FlowManager(abc.ABC):
self._handler_progress_index.setdefault(flow.handler, set()).add(flow) self._handler_progress_index.setdefault(flow.handler, set()).add(flow)
@callback @callback
def _async_remove_flow_from_index(self, flow: FlowHandler) -> None: def _async_remove_flow_from_index(
self, flow: BaseFlowHandler[_FlowResultT]
) -> None:
"""Remove a flow from in progress.""" """Remove a flow from in progress."""
if flow.init_data is not None: if flow.init_data is not None:
init_data_type = type(flow.init_data) init_data_type = type(flow.init_data)
@ -492,17 +482,24 @@ class FlowManager(abc.ABC):
_LOGGER.exception("Error removing %s flow: %s", flow.handler, err) _LOGGER.exception("Error removing %s flow: %s", flow.handler, err)
async def _async_handle_step( async def _async_handle_step(
self, flow: FlowHandler, step_id: str, user_input: dict | BaseServiceInfo | None self,
) -> FlowResult: flow: BaseFlowHandler[_FlowResultT],
step_id: str,
user_input: dict | BaseServiceInfo | None,
) -> _FlowResultT:
"""Handle a step of a flow.""" """Handle a step of a flow."""
self._raise_if_step_does_not_exist(flow, step_id) self._raise_if_step_does_not_exist(flow, step_id)
method = f"async_step_{step_id}" method = f"async_step_{step_id}"
try: try:
result: FlowResult = await getattr(flow, method)(user_input) result: _FlowResultT = await getattr(flow, method)(user_input)
except AbortFlow as err: except AbortFlow as err:
result = _create_abort_data( result = self._flow_result(
flow.flow_id, flow.handler, err.reason, err.description_placeholders type=FlowResultType.ABORT,
flow_id=flow.flow_id,
handler=flow.handler,
reason=err.reason,
description_placeholders=err.description_placeholders,
) )
# Setup the flow handler's preview if needed # Setup the flow handler's preview if needed
@ -521,7 +518,8 @@ class FlowManager(abc.ABC):
if ( if (
result["type"] == FlowResultType.SHOW_PROGRESS result["type"] == FlowResultType.SHOW_PROGRESS
and (progress_task := result.pop("progress_task", None)) # Mypy does not agree with using pop on _FlowResultT
and (progress_task := result.pop("progress_task", None)) # type: ignore[arg-type]
and progress_task != flow.async_get_progress_task() and progress_task != flow.async_get_progress_task()
): ):
# The flow's progress task was changed, register a callback on it # The flow's progress task was changed, register a callback on it
@ -532,8 +530,9 @@ class FlowManager(abc.ABC):
def schedule_configure(_: asyncio.Task) -> None: def schedule_configure(_: asyncio.Task) -> None:
self.hass.async_create_task(call_configure()) self.hass.async_create_task(call_configure())
progress_task.add_done_callback(schedule_configure) # The mypy ignores are a consequence of mypy not accepting the pop above
flow.async_set_progress_task(progress_task) progress_task.add_done_callback(schedule_configure) # type: ignore[attr-defined]
flow.async_set_progress_task(progress_task) # type: ignore[arg-type]
elif result["type"] != FlowResultType.SHOW_PROGRESS: elif result["type"] != FlowResultType.SHOW_PROGRESS:
flow.async_cancel_progress_task() flow.async_cancel_progress_task()
@ -560,7 +559,9 @@ class FlowManager(abc.ABC):
return result return result
def _raise_if_step_does_not_exist(self, flow: FlowHandler, step_id: str) -> None: def _raise_if_step_does_not_exist(
self, flow: BaseFlowHandler, step_id: str
) -> None:
"""Raise if the step does not exist.""" """Raise if the step does not exist."""
method = f"async_step_{step_id}" method = f"async_step_{step_id}"
@ -570,18 +571,45 @@ class FlowManager(abc.ABC):
f"Handler {self.__class__.__name__} doesn't support step {step_id}" f"Handler {self.__class__.__name__} doesn't support step {step_id}"
) )
async def _async_setup_preview(self, flow: FlowHandler) -> None: async def _async_setup_preview(self, flow: BaseFlowHandler) -> None:
"""Set up preview for a flow handler.""" """Set up preview for a flow handler."""
if flow.handler not in self._preview: if flow.handler not in self._preview:
self._preview.add(flow.handler) self._preview.add(flow.handler)
await flow.async_setup_preview(self.hass) await flow.async_setup_preview(self.hass)
@callback
def _async_flow_handler_to_flow_result(
self, flows: Iterable[BaseFlowHandler], include_uninitialized: bool
) -> list[_FlowResultT]:
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
results = []
for flow in flows:
if not include_uninitialized and flow.cur_step is None:
continue
result = self._flow_result(
flow_id=flow.flow_id,
handler=flow.handler,
context=flow.context,
)
if flow.cur_step:
result["step_id"] = flow.cur_step["step_id"]
results.append(result)
return results
class FlowHandler:
class FlowManager(BaseFlowManager[FlowResult]):
"""Manage all the flows that are in progress."""
_flow_result = FlowResult
class BaseFlowHandler(Generic[_FlowResultT]):
"""Handle a data entry flow.""" """Handle a data entry flow."""
_flow_result: Callable[..., _FlowResultT]
# Set by flow manager # Set by flow manager
cur_step: FlowResult | None = None cur_step: _FlowResultT | None = None
# While not purely typed, it makes typehinting more useful for us # While not purely typed, it makes typehinting more useful for us
# and removes the need for constant None checks or asserts. # and removes the need for constant None checks or asserts.
@ -657,12 +685,12 @@ class FlowHandler:
description_placeholders: Mapping[str, str | None] | None = None, description_placeholders: Mapping[str, str | None] | None = None,
last_step: bool | None = None, last_step: bool | None = None,
preview: str | None = None, preview: str | None = None,
) -> FlowResult: ) -> _FlowResultT:
"""Return the definition of a form to gather user input. """Return the definition of a form to gather user input.
The step_id parameter is deprecated and will be removed in a future release. The step_id parameter is deprecated and will be removed in a future release.
""" """
flow_result = FlowResult( flow_result = self._flow_result(
type=FlowResultType.FORM, type=FlowResultType.FORM,
flow_id=self.flow_id, flow_id=self.flow_id,
handler=self.handler, handler=self.handler,
@ -684,11 +712,9 @@ class FlowHandler:
data: Mapping[str, Any], data: Mapping[str, Any],
description: str | None = None, description: str | None = None,
description_placeholders: Mapping[str, str] | None = None, description_placeholders: Mapping[str, str] | None = None,
) -> FlowResult: ) -> _FlowResultT:
"""Finish flow.""" """Finish flow."""
flow_result = FlowResult( flow_result = self._flow_result(
version=self.VERSION,
minor_version=self.MINOR_VERSION,
type=FlowResultType.CREATE_ENTRY, type=FlowResultType.CREATE_ENTRY,
flow_id=self.flow_id, flow_id=self.flow_id,
handler=self.handler, handler=self.handler,
@ -707,10 +733,14 @@ class FlowHandler:
*, *,
reason: str, reason: str,
description_placeholders: Mapping[str, str] | None = None, description_placeholders: Mapping[str, str] | None = None,
) -> FlowResult: ) -> _FlowResultT:
"""Abort the flow.""" """Abort the flow."""
return _create_abort_data( return self._flow_result(
self.flow_id, self.handler, reason, description_placeholders type=FlowResultType.ABORT,
flow_id=self.flow_id,
handler=self.handler,
reason=reason,
description_placeholders=description_placeholders,
) )
@callback @callback
@ -720,12 +750,12 @@ class FlowHandler:
step_id: str | None = None, step_id: str | None = None,
url: str, url: str,
description_placeholders: Mapping[str, str] | None = None, description_placeholders: Mapping[str, str] | None = None,
) -> FlowResult: ) -> _FlowResultT:
"""Return the definition of an external step for the user to take. """Return the definition of an external step for the user to take.
The step_id parameter is deprecated and will be removed in a future release. The step_id parameter is deprecated and will be removed in a future release.
""" """
flow_result = FlowResult( flow_result = self._flow_result(
type=FlowResultType.EXTERNAL_STEP, type=FlowResultType.EXTERNAL_STEP,
flow_id=self.flow_id, flow_id=self.flow_id,
handler=self.handler, handler=self.handler,
@ -737,9 +767,9 @@ class FlowHandler:
return flow_result return flow_result
@callback @callback
def async_external_step_done(self, *, next_step_id: str) -> FlowResult: def async_external_step_done(self, *, next_step_id: str) -> _FlowResultT:
"""Return the definition of an external step for the user to take.""" """Return the definition of an external step for the user to take."""
return FlowResult( return self._flow_result(
type=FlowResultType.EXTERNAL_STEP_DONE, type=FlowResultType.EXTERNAL_STEP_DONE,
flow_id=self.flow_id, flow_id=self.flow_id,
handler=self.handler, handler=self.handler,
@ -754,7 +784,7 @@ class FlowHandler:
progress_action: str, progress_action: str,
description_placeholders: Mapping[str, str] | None = None, description_placeholders: Mapping[str, str] | None = None,
progress_task: asyncio.Task[Any] | None = None, progress_task: asyncio.Task[Any] | None = None,
) -> FlowResult: ) -> _FlowResultT:
"""Show a progress message to the user, without user input allowed. """Show a progress message to the user, without user input allowed.
The step_id parameter is deprecated and will be removed in a future release. The step_id parameter is deprecated and will be removed in a future release.
@ -777,7 +807,7 @@ class FlowHandler:
if progress_task is None: if progress_task is None:
self.deprecated_show_progress = True self.deprecated_show_progress = True
flow_result = FlowResult( flow_result = self._flow_result(
type=FlowResultType.SHOW_PROGRESS, type=FlowResultType.SHOW_PROGRESS,
flow_id=self.flow_id, flow_id=self.flow_id,
handler=self.handler, handler=self.handler,
@ -790,9 +820,9 @@ class FlowHandler:
return flow_result return flow_result
@callback @callback
def async_show_progress_done(self, *, next_step_id: str) -> FlowResult: def async_show_progress_done(self, *, next_step_id: str) -> _FlowResultT:
"""Mark the progress done.""" """Mark the progress done."""
return FlowResult( return self._flow_result(
type=FlowResultType.SHOW_PROGRESS_DONE, type=FlowResultType.SHOW_PROGRESS_DONE,
flow_id=self.flow_id, flow_id=self.flow_id,
handler=self.handler, handler=self.handler,
@ -806,13 +836,13 @@ class FlowHandler:
step_id: str | None = None, step_id: str | None = None,
menu_options: list[str] | dict[str, str], menu_options: list[str] | dict[str, str],
description_placeholders: Mapping[str, str] | None = None, description_placeholders: Mapping[str, str] | None = None,
) -> FlowResult: ) -> _FlowResultT:
"""Show a navigation menu to the user. """Show a navigation menu to the user.
Options dict maps step_id => i18n label Options dict maps step_id => i18n label
The step_id parameter is deprecated and will be removed in a future release. The step_id parameter is deprecated and will be removed in a future release.
""" """
flow_result = FlowResult( flow_result = self._flow_result(
type=FlowResultType.MENU, type=FlowResultType.MENU,
flow_id=self.flow_id, flow_id=self.flow_id,
handler=self.handler, handler=self.handler,
@ -853,21 +883,10 @@ class FlowHandler:
self.__progress_task = progress_task self.__progress_task = progress_task
@callback class FlowHandler(BaseFlowHandler[FlowResult]):
def _create_abort_data( """Handle a data entry flow."""
flow_id: str,
handler: str, _flow_result = FlowResult
reason: str,
description_placeholders: Mapping[str, str] | None = None,
) -> FlowResult:
"""Return the definition of an external step for the user to take."""
return FlowResult(
type=FlowResultType.ABORT,
flow_id=flow_id,
handler=handler,
reason=reason,
description_placeholders=description_placeholders,
)
# These can be removed if no deprecated constant are in this module anymore # These can be removed if no deprecated constant are in this module anymore

View file

@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components import onboarding from homeassistant.components import onboarding
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResult
from .typing import DiscoveryInfoType from .typing import DiscoveryInfoType
@ -46,7 +45,7 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]):
async def async_step_user( async def async_step_user(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> config_entries.ConfigFlowResult:
"""Handle a flow initialized by the user.""" """Handle a flow initialized by the user."""
if self._async_current_entries(): if self._async_current_entries():
return self.async_abort(reason="single_instance_allowed") return self.async_abort(reason="single_instance_allowed")
@ -57,7 +56,7 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]):
async def async_step_confirm( async def async_step_confirm(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> config_entries.ConfigFlowResult:
"""Confirm setup.""" """Confirm setup."""
if user_input is None and onboarding.async_is_onboarded(self.hass): if user_input is None and onboarding.async_is_onboarded(self.hass):
self._set_confirm_only() self._set_confirm_only()
@ -87,7 +86,7 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]):
async def async_step_discovery( async def async_step_discovery(
self, discovery_info: DiscoveryInfoType self, discovery_info: DiscoveryInfoType
) -> FlowResult: ) -> config_entries.ConfigFlowResult:
"""Handle a flow initialized by discovery.""" """Handle a flow initialized by discovery."""
if self._async_in_progress() or self._async_current_entries(): if self._async_in_progress() or self._async_current_entries():
return self.async_abort(reason="single_instance_allowed") return self.async_abort(reason="single_instance_allowed")
@ -98,7 +97,7 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]):
async def async_step_bluetooth( async def async_step_bluetooth(
self, discovery_info: BluetoothServiceInfoBleak self, discovery_info: BluetoothServiceInfoBleak
) -> FlowResult: ) -> config_entries.ConfigFlowResult:
"""Handle a flow initialized by bluetooth discovery.""" """Handle a flow initialized by bluetooth discovery."""
if self._async_in_progress() or self._async_current_entries(): if self._async_in_progress() or self._async_current_entries():
return self.async_abort(reason="single_instance_allowed") return self.async_abort(reason="single_instance_allowed")
@ -107,7 +106,9 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]):
return await self.async_step_confirm() return await self.async_step_confirm()
async def async_step_dhcp(self, discovery_info: DhcpServiceInfo) -> FlowResult: async def async_step_dhcp(
self, discovery_info: DhcpServiceInfo
) -> config_entries.ConfigFlowResult:
"""Handle a flow initialized by dhcp discovery.""" """Handle a flow initialized by dhcp discovery."""
if self._async_in_progress() or self._async_current_entries(): if self._async_in_progress() or self._async_current_entries():
return self.async_abort(reason="single_instance_allowed") return self.async_abort(reason="single_instance_allowed")
@ -118,7 +119,7 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]):
async def async_step_homekit( async def async_step_homekit(
self, discovery_info: ZeroconfServiceInfo self, discovery_info: ZeroconfServiceInfo
) -> FlowResult: ) -> config_entries.ConfigFlowResult:
"""Handle a flow initialized by Homekit discovery.""" """Handle a flow initialized by Homekit discovery."""
if self._async_in_progress() or self._async_current_entries(): if self._async_in_progress() or self._async_current_entries():
return self.async_abort(reason="single_instance_allowed") return self.async_abort(reason="single_instance_allowed")
@ -127,7 +128,9 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]):
return await self.async_step_confirm() return await self.async_step_confirm()
async def async_step_mqtt(self, discovery_info: MqttServiceInfo) -> FlowResult: async def async_step_mqtt(
self, discovery_info: MqttServiceInfo
) -> config_entries.ConfigFlowResult:
"""Handle a flow initialized by mqtt discovery.""" """Handle a flow initialized by mqtt discovery."""
if self._async_in_progress() or self._async_current_entries(): if self._async_in_progress() or self._async_current_entries():
return self.async_abort(reason="single_instance_allowed") return self.async_abort(reason="single_instance_allowed")
@ -138,7 +141,7 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]):
async def async_step_zeroconf( async def async_step_zeroconf(
self, discovery_info: ZeroconfServiceInfo self, discovery_info: ZeroconfServiceInfo
) -> FlowResult: ) -> config_entries.ConfigFlowResult:
"""Handle a flow initialized by Zeroconf discovery.""" """Handle a flow initialized by Zeroconf discovery."""
if self._async_in_progress() or self._async_current_entries(): if self._async_in_progress() or self._async_current_entries():
return self.async_abort(reason="single_instance_allowed") return self.async_abort(reason="single_instance_allowed")
@ -147,7 +150,9 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]):
return await self.async_step_confirm() return await self.async_step_confirm()
async def async_step_ssdp(self, discovery_info: SsdpServiceInfo) -> FlowResult: async def async_step_ssdp(
self, discovery_info: SsdpServiceInfo
) -> config_entries.ConfigFlowResult:
"""Handle a flow initialized by Ssdp discovery.""" """Handle a flow initialized by Ssdp discovery."""
if self._async_in_progress() or self._async_current_entries(): if self._async_in_progress() or self._async_current_entries():
return self.async_abort(reason="single_instance_allowed") return self.async_abort(reason="single_instance_allowed")
@ -156,7 +161,9 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]):
return await self.async_step_confirm() return await self.async_step_confirm()
async def async_step_import(self, _: dict[str, Any] | None) -> FlowResult: async def async_step_import(
self, _: dict[str, Any] | None
) -> config_entries.ConfigFlowResult:
"""Handle a flow initialized by import.""" """Handle a flow initialized by import."""
if self._async_current_entries(): if self._async_current_entries():
return self.async_abort(reason="single_instance_allowed") return self.async_abort(reason="single_instance_allowed")
@ -205,7 +212,7 @@ class WebhookFlowHandler(config_entries.ConfigFlow):
async def async_step_user( async def async_step_user(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> config_entries.ConfigFlowResult:
"""Handle a user initiated set up flow to create a webhook.""" """Handle a user initiated set up flow to create a webhook."""
if not self._allow_multiple and self._async_current_entries(): if not self._allow_multiple and self._async_current_entries():
return self.async_abort(reason="single_instance_allowed") return self.async_abort(reason="single_instance_allowed")

View file

@ -25,7 +25,6 @@ from yarl import URL
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components import http from homeassistant.components import http
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.loader import async_get_application_credentials from homeassistant.loader import async_get_application_credentials
from .aiohttp_client import async_get_clientsession from .aiohttp_client import async_get_clientsession
@ -253,7 +252,7 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
async def async_step_pick_implementation( async def async_step_pick_implementation(
self, user_input: dict | None = None self, user_input: dict | None = None
) -> FlowResult: ) -> config_entries.ConfigFlowResult:
"""Handle a flow start.""" """Handle a flow start."""
implementations = await async_get_implementations(self.hass, self.DOMAIN) implementations = await async_get_implementations(self.hass, self.DOMAIN)
@ -286,7 +285,7 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
async def async_step_auth( async def async_step_auth(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> config_entries.ConfigFlowResult:
"""Create an entry for auth.""" """Create an entry for auth."""
# Flow has been triggered by external data # Flow has been triggered by external data
if user_input is not None: if user_input is not None:
@ -314,7 +313,7 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
async def async_step_creation( async def async_step_creation(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> config_entries.ConfigFlowResult:
"""Create config entry from external data.""" """Create config entry from external data."""
_LOGGER.debug("Creating config entry from external data") _LOGGER.debug("Creating config entry from external data")
@ -353,14 +352,18 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
{"auth_implementation": self.flow_impl.domain, "token": token} {"auth_implementation": self.flow_impl.domain, "token": token}
) )
async def async_step_authorize_rejected(self, data: None = None) -> FlowResult: async def async_step_authorize_rejected(
self, data: None = None
) -> config_entries.ConfigFlowResult:
"""Step to handle flow rejection.""" """Step to handle flow rejection."""
return self.async_abort( return self.async_abort(
reason="user_rejected_authorize", reason="user_rejected_authorize",
description_placeholders={"error": self.external_data["error"]}, description_placeholders={"error": self.external_data["error"]},
) )
async def async_oauth_create_entry(self, data: dict) -> FlowResult: async def async_oauth_create_entry(
self, data: dict
) -> config_entries.ConfigFlowResult:
"""Create an entry for the flow. """Create an entry for the flow.
Ok to override if you want to fetch extra info or even add another step. Ok to override if you want to fetch extra info or even add another step.
@ -369,7 +372,7 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
async def async_step_user( async def async_step_user(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> config_entries.ConfigFlowResult:
"""Handle a flow start.""" """Handle a flow start."""
return await self.async_step_pick_implementation(user_input) return await self.async_step_pick_implementation(user_input)

View file

@ -18,7 +18,7 @@ from . import config_validation as cv
class _BaseFlowManagerView(HomeAssistantView): class _BaseFlowManagerView(HomeAssistantView):
"""Foundation for flow manager views.""" """Foundation for flow manager views."""
def __init__(self, flow_mgr: data_entry_flow.FlowManager) -> None: def __init__(self, flow_mgr: data_entry_flow.BaseFlowManager) -> None:
"""Initialize the flow manager index view.""" """Initialize the flow manager index view."""
self._flow_mgr = flow_mgr self._flow_mgr = flow_mgr

View file

@ -4,9 +4,9 @@ from __future__ import annotations
from collections.abc import Coroutine from collections.abc import Coroutine
from typing import Any, NamedTuple from typing import Any, NamedTuple
from homeassistant.config_entries import ConfigFlowResult
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
from homeassistant.core import CoreState, Event, HomeAssistant, callback from homeassistant.core import CoreState, Event, HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util.async_ import gather_with_limited_concurrency from homeassistant.util.async_ import gather_with_limited_concurrency
@ -40,7 +40,7 @@ def async_create_flow(
@callback @callback
def _async_init_flow( def _async_init_flow(
hass: HomeAssistant, domain: str, context: dict[str, Any], data: Any hass: HomeAssistant, domain: str, context: dict[str, Any], data: Any
) -> Coroutine[None, None, FlowResult] | None: ) -> Coroutine[None, None, ConfigFlowResult] | None:
"""Create a discovery flow.""" """Create a discovery flow."""
# Avoid spawning flows that have the same initial discovery data # Avoid spawning flows that have the same initial discovery data
# as ones in progress as it may cause additional device probing # as ones in progress as it may cause additional device probing

View file

@ -10,9 +10,15 @@ from typing import Any, cast
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries from homeassistant.config_entries import (
ConfigEntry,
ConfigFlow,
ConfigFlowResult,
OptionsFlow,
OptionsFlowWithConfigEntry,
)
from homeassistant.core import HomeAssistant, callback, split_entity_id from homeassistant.core import HomeAssistant, callback, split_entity_id
from homeassistant.data_entry_flow import FlowResult, UnknownHandler from homeassistant.data_entry_flow import UnknownHandler
from . import entity_registry as er, selector from . import entity_registry as er, selector
from .typing import UNDEFINED, UndefinedType from .typing import UNDEFINED, UndefinedType
@ -126,7 +132,7 @@ class SchemaCommonFlowHandler:
async def async_step( async def async_step(
self, step_id: str, user_input: dict[str, Any] | None = None self, step_id: str, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Handle a step.""" """Handle a step."""
if isinstance(self._flow[step_id], SchemaFlowFormStep): if isinstance(self._flow[step_id], SchemaFlowFormStep):
return await self._async_form_step(step_id, user_input) return await self._async_form_step(step_id, user_input)
@ -141,7 +147,7 @@ class SchemaCommonFlowHandler:
async def _async_form_step( async def _async_form_step(
self, step_id: str, user_input: dict[str, Any] | None = None self, step_id: str, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Handle a form step.""" """Handle a form step."""
form_step: SchemaFlowFormStep = cast(SchemaFlowFormStep, self._flow[step_id]) form_step: SchemaFlowFormStep = cast(SchemaFlowFormStep, self._flow[step_id])
@ -204,7 +210,7 @@ class SchemaCommonFlowHandler:
async def _show_next_step_or_create_entry( async def _show_next_step_or_create_entry(
self, form_step: SchemaFlowFormStep self, form_step: SchemaFlowFormStep
) -> FlowResult: ) -> ConfigFlowResult:
next_step_id_or_end_flow: str | None next_step_id_or_end_flow: str | None
if callable(form_step.next_step): if callable(form_step.next_step):
@ -222,7 +228,7 @@ class SchemaCommonFlowHandler:
next_step_id: str, next_step_id: str,
error: SchemaFlowError | None = None, error: SchemaFlowError | None = None,
user_input: dict[str, Any] | None = None, user_input: dict[str, Any] | None = None,
) -> FlowResult: ) -> ConfigFlowResult:
"""Show form for next step.""" """Show form for next step."""
if isinstance(self._flow[next_step_id], SchemaFlowMenuStep): if isinstance(self._flow[next_step_id], SchemaFlowMenuStep):
menu_step = cast(SchemaFlowMenuStep, self._flow[next_step_id]) menu_step = cast(SchemaFlowMenuStep, self._flow[next_step_id])
@ -271,7 +277,7 @@ class SchemaCommonFlowHandler:
async def _async_menu_step( async def _async_menu_step(
self, step_id: str, user_input: dict[str, Any] | None = None self, step_id: str, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Handle a menu step.""" """Handle a menu step."""
menu_step: SchemaFlowMenuStep = cast(SchemaFlowMenuStep, self._flow[step_id]) menu_step: SchemaFlowMenuStep = cast(SchemaFlowMenuStep, self._flow[step_id])
return self._handler.async_show_menu( return self._handler.async_show_menu(
@ -280,7 +286,7 @@ class SchemaCommonFlowHandler:
) )
class SchemaConfigFlowHandler(config_entries.ConfigFlow, ABC): class SchemaConfigFlowHandler(ConfigFlow, ABC):
"""Handle a schema based config flow.""" """Handle a schema based config flow."""
config_flow: Mapping[str, SchemaFlowStep] config_flow: Mapping[str, SchemaFlowStep]
@ -294,8 +300,8 @@ class SchemaConfigFlowHandler(config_entries.ConfigFlow, ABC):
@callback @callback
def _async_get_options_flow( def _async_get_options_flow(
config_entry: config_entries.ConfigEntry, config_entry: ConfigEntry,
) -> config_entries.OptionsFlow: ) -> OptionsFlow:
"""Get the options flow for this handler.""" """Get the options flow for this handler."""
if cls.options_flow is None: if cls.options_flow is None:
raise UnknownHandler raise UnknownHandler
@ -324,9 +330,7 @@ class SchemaConfigFlowHandler(config_entries.ConfigFlow, ABC):
@classmethod @classmethod
@callback @callback
def async_supports_options_flow( def async_supports_options_flow(cls, config_entry: ConfigEntry) -> bool:
cls, config_entry: config_entries.ConfigEntry
) -> bool:
"""Return options flow support for this handler.""" """Return options flow support for this handler."""
return cls.options_flow is not None return cls.options_flow is not None
@ -335,13 +339,13 @@ class SchemaConfigFlowHandler(config_entries.ConfigFlow, ABC):
step_id: str, step_id: str,
) -> Callable[ ) -> Callable[
[SchemaConfigFlowHandler, dict[str, Any] | None], [SchemaConfigFlowHandler, dict[str, Any] | None],
Coroutine[Any, Any, FlowResult], Coroutine[Any, Any, ConfigFlowResult],
]: ]:
"""Generate a step handler.""" """Generate a step handler."""
async def _async_step( async def _async_step(
self: SchemaConfigFlowHandler, user_input: dict[str, Any] | None = None self: SchemaConfigFlowHandler, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Handle a config flow step.""" """Handle a config flow step."""
# pylint: disable-next=protected-access # pylint: disable-next=protected-access
result = await self._common_handler.async_step(step_id, user_input) result = await self._common_handler.async_step(step_id, user_input)
@ -382,7 +386,7 @@ class SchemaConfigFlowHandler(config_entries.ConfigFlow, ABC):
self, self,
data: Mapping[str, Any], data: Mapping[str, Any],
**kwargs: Any, **kwargs: Any,
) -> FlowResult: ) -> ConfigFlowResult:
"""Finish config flow and create a config entry.""" """Finish config flow and create a config entry."""
self.async_config_flow_finished(data) self.async_config_flow_finished(data)
return super().async_create_entry( return super().async_create_entry(
@ -390,12 +394,12 @@ class SchemaConfigFlowHandler(config_entries.ConfigFlow, ABC):
) )
class SchemaOptionsFlowHandler(config_entries.OptionsFlowWithConfigEntry): class SchemaOptionsFlowHandler(OptionsFlowWithConfigEntry):
"""Handle a schema based options flow.""" """Handle a schema based options flow."""
def __init__( def __init__(
self, self,
config_entry: config_entries.ConfigEntry, config_entry: ConfigEntry,
options_flow: Mapping[str, SchemaFlowStep], options_flow: Mapping[str, SchemaFlowStep],
async_options_flow_finished: Callable[[HomeAssistant, Mapping[str, Any]], None] async_options_flow_finished: Callable[[HomeAssistant, Mapping[str, Any]], None]
| None = None, | None = None,
@ -430,13 +434,13 @@ class SchemaOptionsFlowHandler(config_entries.OptionsFlowWithConfigEntry):
step_id: str, step_id: str,
) -> Callable[ ) -> Callable[
[SchemaConfigFlowHandler, dict[str, Any] | None], [SchemaConfigFlowHandler, dict[str, Any] | None],
Coroutine[Any, Any, FlowResult], Coroutine[Any, Any, ConfigFlowResult],
]: ]:
"""Generate a step handler.""" """Generate a step handler."""
async def _async_step( async def _async_step(
self: SchemaConfigFlowHandler, user_input: dict[str, Any] | None = None self: SchemaConfigFlowHandler, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Handle an options flow step.""" """Handle an options flow step."""
# pylint: disable-next=protected-access # pylint: disable-next=protected-access
result = await self._common_handler.async_step(step_id, user_input) result = await self._common_handler.async_step(step_id, user_input)
@ -449,7 +453,7 @@ class SchemaOptionsFlowHandler(config_entries.OptionsFlowWithConfigEntry):
self, self,
data: Mapping[str, Any], data: Mapping[str, Any],
**kwargs: Any, **kwargs: Any,
) -> FlowResult: ) -> ConfigFlowResult:
"""Finish config flow and create a config entry.""" """Finish config flow and create a config entry."""
if self._async_options_flow_finished: if self._async_options_flow_finished:
self._async_options_flow_finished(self.hass, data) self._async_options_flow_finished(self.hass, data)

View file

@ -55,11 +55,12 @@ class TypeHintMatch:
) )
@dataclass @dataclass(kw_only=True)
class ClassTypeHintMatch: class ClassTypeHintMatch:
"""Class for pattern matching.""" """Class for pattern matching."""
base_class: str base_class: str
exclude_base_classes: set[str] | None = None
matches: list[TypeHintMatch] matches: list[TypeHintMatch]
@ -481,6 +482,7 @@ _CLASS_MATCH: dict[str, list[ClassTypeHintMatch]] = {
"config_flow": [ "config_flow": [
ClassTypeHintMatch( ClassTypeHintMatch(
base_class="FlowHandler", base_class="FlowHandler",
exclude_base_classes={"ConfigEntryBaseFlow"},
matches=[ matches=[
TypeHintMatch( TypeHintMatch(
function_name="async_step_*", function_name="async_step_*",
@ -492,6 +494,11 @@ _CLASS_MATCH: dict[str, list[ClassTypeHintMatch]] = {
ClassTypeHintMatch( ClassTypeHintMatch(
base_class="ConfigFlow", base_class="ConfigFlow",
matches=[ matches=[
TypeHintMatch(
function_name="async_step123_*",
arg_types={},
return_type=["ConfigFlowResult", "FlowResult"],
),
TypeHintMatch( TypeHintMatch(
function_name="async_get_options_flow", function_name="async_get_options_flow",
arg_types={ arg_types={
@ -504,56 +511,66 @@ _CLASS_MATCH: dict[str, list[ClassTypeHintMatch]] = {
arg_types={ arg_types={
1: "DhcpServiceInfo", 1: "DhcpServiceInfo",
}, },
return_type="FlowResult", return_type=["ConfigFlowResult", "FlowResult"],
), ),
TypeHintMatch( TypeHintMatch(
function_name="async_step_hassio", function_name="async_step_hassio",
arg_types={ arg_types={
1: "HassioServiceInfo", 1: "HassioServiceInfo",
}, },
return_type="FlowResult", return_type=["ConfigFlowResult", "FlowResult"],
), ),
TypeHintMatch( TypeHintMatch(
function_name="async_step_homekit", function_name="async_step_homekit",
arg_types={ arg_types={
1: "ZeroconfServiceInfo", 1: "ZeroconfServiceInfo",
}, },
return_type="FlowResult", return_type=["ConfigFlowResult", "FlowResult"],
), ),
TypeHintMatch( TypeHintMatch(
function_name="async_step_mqtt", function_name="async_step_mqtt",
arg_types={ arg_types={
1: "MqttServiceInfo", 1: "MqttServiceInfo",
}, },
return_type="FlowResult", return_type=["ConfigFlowResult", "FlowResult"],
), ),
TypeHintMatch( TypeHintMatch(
function_name="async_step_reauth", function_name="async_step_reauth",
arg_types={ arg_types={
1: "Mapping[str, Any]", 1: "Mapping[str, Any]",
}, },
return_type="FlowResult", return_type=["ConfigFlowResult", "FlowResult"],
), ),
TypeHintMatch( TypeHintMatch(
function_name="async_step_ssdp", function_name="async_step_ssdp",
arg_types={ arg_types={
1: "SsdpServiceInfo", 1: "SsdpServiceInfo",
}, },
return_type="FlowResult", return_type=["ConfigFlowResult", "FlowResult"],
), ),
TypeHintMatch( TypeHintMatch(
function_name="async_step_usb", function_name="async_step_usb",
arg_types={ arg_types={
1: "UsbServiceInfo", 1: "UsbServiceInfo",
}, },
return_type="FlowResult", return_type=["ConfigFlowResult", "FlowResult"],
), ),
TypeHintMatch( TypeHintMatch(
function_name="async_step_zeroconf", function_name="async_step_zeroconf",
arg_types={ arg_types={
1: "ZeroconfServiceInfo", 1: "ZeroconfServiceInfo",
}, },
return_type="FlowResult", return_type=["ConfigFlowResult", "FlowResult"],
),
],
),
ClassTypeHintMatch(
base_class="OptionsFlow",
matches=[
TypeHintMatch(
function_name="async_step_*",
arg_types={},
return_type=["ConfigFlowResult", "FlowResult"],
), ),
], ],
), ),
@ -3126,11 +3143,19 @@ class HassTypeHintChecker(BaseChecker):
ancestor: nodes.ClassDef ancestor: nodes.ClassDef
checked_class_methods: set[str] = set() checked_class_methods: set[str] = set()
ancestors = list(node.ancestors()) # cache result for inside loop ancestors = list(node.ancestors()) # cache result for inside loop
for class_matches in self._class_matchers: for class_matcher in self._class_matchers:
skip_matcher = False
if exclude_base_classes := class_matcher.exclude_base_classes:
for ancestor in ancestors:
if ancestor.name in exclude_base_classes:
skip_matcher = True
break
if skip_matcher:
continue
for ancestor in ancestors: for ancestor in ancestors:
if ancestor.name == class_matches.base_class: if ancestor.name == class_matcher.base_class:
self._visit_class_functions( self._visit_class_functions(
node, class_matches.matches, checked_class_methods node, class_matcher.matches, checked_class_methods
) )
def _visit_class_functions( def _visit_class_functions(

View file

@ -6,10 +6,9 @@ from typing import Any
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries from homeassistant.config_entries import ConfigFlow, ConfigFlowResult
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResult
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from .const import DOMAIN from .const import DOMAIN
@ -68,14 +67,14 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> dict[str,
return {"title": "Name of the device"} return {"title": "Name of the device"}
class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): class ConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for NEW_NAME.""" """Handle a config flow for NEW_NAME."""
VERSION = 1 VERSION = 1
async def async_step_user( async def async_step_user(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> ConfigFlowResult:
"""Handle the initial step.""" """Handle the initial step."""
errors: dict[str, str] = {} errors: dict[str, str] = {}
if user_input is not None: if user_input is not None:

View file

@ -147,13 +147,11 @@ async def test_legacy_subscription_repair_flow(
flow_id = data["flow_id"] flow_id = data["flow_id"]
assert data == { assert data == {
"version": 1,
"type": "create_entry", "type": "create_entry",
"flow_id": flow_id, "flow_id": flow_id,
"handler": DOMAIN, "handler": DOMAIN,
"description": None, "description": None,
"description_placeholders": None, "description_placeholders": None,
"minor_version": 1,
} }
assert not issue_registry.async_get_issue( assert not issue_registry.async_get_issue(

View file

@ -941,10 +941,8 @@ async def test_two_step_options_flow(hass: HomeAssistant, client) -> None:
"handler": "test1", "handler": "test1",
"type": "create_entry", "type": "create_entry",
"title": "Enable disable", "title": "Enable disable",
"version": 1,
"description": None, "description": None,
"description_placeholders": None, "description_placeholders": None,
"minor_version": 1,
} }

View file

@ -94,13 +94,11 @@ async def test_supervisor_issue_repair_flow(
flow_id = data["flow_id"] flow_id = data["flow_id"]
assert data == { assert data == {
"version": 1,
"type": "create_entry", "type": "create_entry",
"flow_id": flow_id, "flow_id": flow_id,
"handler": "hassio", "handler": "hassio",
"description": None, "description": None,
"description_placeholders": None, "description_placeholders": None,
"minor_version": 1,
} }
assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234") assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234")
@ -190,13 +188,11 @@ async def test_supervisor_issue_repair_flow_with_multiple_suggestions(
flow_id = data["flow_id"] flow_id = data["flow_id"]
assert data == { assert data == {
"version": 1,
"type": "create_entry", "type": "create_entry",
"flow_id": flow_id, "flow_id": flow_id,
"handler": "hassio", "handler": "hassio",
"description": None, "description": None,
"description_placeholders": None, "description_placeholders": None,
"minor_version": 1,
} }
assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234") assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234")
@ -305,13 +301,11 @@ async def test_supervisor_issue_repair_flow_with_multiple_suggestions_and_confir
flow_id = data["flow_id"] flow_id = data["flow_id"]
assert data == { assert data == {
"version": 1,
"type": "create_entry", "type": "create_entry",
"flow_id": flow_id, "flow_id": flow_id,
"handler": "hassio", "handler": "hassio",
"description": None, "description": None,
"description_placeholders": None, "description_placeholders": None,
"minor_version": 1,
} }
assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234") assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234")
@ -386,13 +380,11 @@ async def test_supervisor_issue_repair_flow_skip_confirmation(
flow_id = data["flow_id"] flow_id = data["flow_id"]
assert data == { assert data == {
"version": 1,
"type": "create_entry", "type": "create_entry",
"flow_id": flow_id, "flow_id": flow_id,
"handler": "hassio", "handler": "hassio",
"description": None, "description": None,
"description_placeholders": None, "description_placeholders": None,
"minor_version": 1,
} }
assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234") assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234")
@ -486,13 +478,11 @@ async def test_mount_failed_repair_flow(
flow_id = data["flow_id"] flow_id = data["flow_id"]
assert data == { assert data == {
"version": 1,
"type": "create_entry", "type": "create_entry",
"flow_id": flow_id, "flow_id": flow_id,
"handler": "hassio", "handler": "hassio",
"description": None, "description": None,
"description_placeholders": None, "description_placeholders": None,
"minor_version": 1,
} }
assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234") assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234")
@ -598,13 +588,11 @@ async def test_supervisor_issue_docker_config_repair_flow(
flow_id = data["flow_id"] flow_id = data["flow_id"]
assert data == { assert data == {
"version": 1,
"type": "create_entry", "type": "create_entry",
"flow_id": flow_id, "flow_id": flow_id,
"handler": "hassio", "handler": "hassio",
"description": None, "description": None,
"description_placeholders": None, "description_placeholders": None,
"minor_version": 1,
} }
assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234") assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234")

View file

@ -244,9 +244,7 @@ async def test_issues_created(
"description_placeholders": None, "description_placeholders": None,
"flow_id": flow_id, "flow_id": flow_id,
"handler": DOMAIN, "handler": DOMAIN,
"minor_version": 1,
"type": "create_entry", "type": "create_entry",
"version": 1,
} }
await ws_client.send_json({"id": 4, "type": "repairs/list_issues"}) await ws_client.send_json({"id": 4, "type": "repairs/list_issues"})

View file

@ -338,9 +338,7 @@ async def test_fix_issue(
"description_placeholders": None, "description_placeholders": None,
"flow_id": flow_id, "flow_id": flow_id,
"handler": domain, "handler": domain,
"minor_version": 1,
"type": "create_entry", "type": "create_entry",
"version": 1,
} }
await ws_client.send_json({"id": 4, "type": "repairs/list_issues"}) await ws_client.send_json({"id": 4, "type": "repairs/list_issues"})

View file

@ -63,7 +63,7 @@ async def test_async_create_flow_checks_existing_flows_after_startup(
"""Test existing flows prevent an identical ones from being after startup.""" """Test existing flows prevent an identical ones from being after startup."""
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
with patch( with patch(
"homeassistant.data_entry_flow.FlowManager.async_has_matching_flow", "homeassistant.data_entry_flow.BaseFlowManager.async_has_matching_flow",
return_value=True, return_value=True,
): ):
discovery_flow.async_create_flow( discovery_flow.async_create_flow(

View file

@ -45,7 +45,7 @@ def manager_fixture():
handlers = Registry() handlers = Registry()
entries = [] entries = []
class FlowManager(data_entry_flow.FlowManager): class FlowManager(data_entry_flow.BaseFlowManager):
"""Test flow manager.""" """Test flow manager."""
async def async_create_flow(self, handler_key, *, context, data): async def async_create_flow(self, handler_key, *, context, data):
@ -105,7 +105,7 @@ async def test_name(hass: HomeAssistant, entity_registry: er.EntityRegistry) ->
@pytest.mark.parametrize("marker", (vol.Required, vol.Optional)) @pytest.mark.parametrize("marker", (vol.Required, vol.Optional))
async def test_config_flow_advanced_option( async def test_config_flow_advanced_option(
hass: HomeAssistant, manager: data_entry_flow.FlowManager, marker hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager, marker
) -> None: ) -> None:
"""Test handling of advanced options in config flow.""" """Test handling of advanced options in config flow."""
manager.hass = hass manager.hass = hass
@ -200,7 +200,7 @@ async def test_config_flow_advanced_option(
@pytest.mark.parametrize("marker", (vol.Required, vol.Optional)) @pytest.mark.parametrize("marker", (vol.Required, vol.Optional))
async def test_options_flow_advanced_option( async def test_options_flow_advanced_option(
hass: HomeAssistant, manager: data_entry_flow.FlowManager, marker hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager, marker
) -> None: ) -> None:
"""Test handling of advanced options in options flow.""" """Test handling of advanced options in options flow."""
manager.hass = hass manager.hass = hass
@ -475,7 +475,7 @@ async def test_next_step_function(hass: HomeAssistant) -> None:
async def test_suggested_values( async def test_suggested_values(
hass: HomeAssistant, manager: data_entry_flow.FlowManager hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager
) -> None: ) -> None:
"""Test suggested_values handling in SchemaFlowFormStep.""" """Test suggested_values handling in SchemaFlowFormStep."""
manager.hass = hass manager.hass = hass
@ -667,7 +667,7 @@ async def test_options_flow_state(hass: HomeAssistant) -> None:
async def test_options_flow_omit_optional_keys( async def test_options_flow_omit_optional_keys(
hass: HomeAssistant, manager: data_entry_flow.FlowManager hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager
) -> None: ) -> None:
"""Test handling of advanced options in options flow.""" """Test handling of advanced options in options flow."""
manager.hass = hass manager.hass = hass

View file

@ -346,7 +346,7 @@ def test_invalid_config_flow_step(
pylint.testutils.MessageTest( pylint.testutils.MessageTest(
msg_id="hass-return-type", msg_id="hass-return-type",
node=func_node, node=func_node,
args=("FlowResult", "async_step_zeroconf"), args=(["ConfigFlowResult", "FlowResult"], "async_step_zeroconf"),
line=11, line=11,
col_offset=4, col_offset=4,
end_line=11, end_line=11,
@ -374,7 +374,7 @@ def test_valid_config_flow_step(
async def async_step_zeroconf( async def async_step_zeroconf(
self, self,
device_config: ZeroconfServiceInfo device_config: ZeroconfServiceInfo
) -> FlowResult: ) -> ConfigFlowResult:
pass pass
""", """,
"homeassistant.components.pylint_test.config_flow", "homeassistant.components.pylint_test.config_flow",

View file

@ -24,9 +24,11 @@ def manager():
handlers = Registry() handlers = Registry()
entries = [] entries = []
class FlowManager(data_entry_flow.FlowManager): class FlowManager(data_entry_flow.BaseFlowManager):
"""Test flow manager.""" """Test flow manager."""
_flow_result = data_entry_flow.FlowResult
async def async_create_flow(self, handler_key, *, context, data): async def async_create_flow(self, handler_key, *, context, data):
"""Test create flow.""" """Test create flow."""
handler = handlers.get(handler_key) handler = handlers.get(handler_key)
@ -79,7 +81,7 @@ async def test_configure_reuses_handler_instance(manager) -> None:
assert len(manager.mock_created_entries) == 0 assert len(manager.mock_created_entries) == 0
async def test_configure_two_steps(manager: data_entry_flow.FlowManager) -> None: async def test_configure_two_steps(manager: data_entry_flow.BaseFlowManager) -> None:
"""Test that we reuse instances.""" """Test that we reuse instances."""
@manager.mock_reg_handler("test") @manager.mock_reg_handler("test")
@ -211,7 +213,6 @@ async def test_create_saves_data(manager) -> None:
assert len(manager.mock_created_entries) == 1 assert len(manager.mock_created_entries) == 1
entry = manager.mock_created_entries[0] entry = manager.mock_created_entries[0]
assert entry["version"] == 5
assert entry["handler"] == "test" assert entry["handler"] == "test"
assert entry["title"] == "Test Title" assert entry["title"] == "Test Title"
assert entry["data"] == "Test Data" assert entry["data"] == "Test Data"
@ -237,7 +238,6 @@ async def test_discovery_init_flow(manager) -> None:
assert len(manager.mock_created_entries) == 1 assert len(manager.mock_created_entries) == 1
entry = manager.mock_created_entries[0] entry = manager.mock_created_entries[0]
assert entry["version"] == 5
assert entry["handler"] == "test" assert entry["handler"] == "test"
assert entry["title"] == "hello" assert entry["title"] == "hello"
assert entry["data"] == data assert entry["data"] == data
@ -258,7 +258,7 @@ async def test_finish_callback_change_result_type(hass: HomeAssistant) -> None:
step_id="init", data_schema=vol.Schema({"count": int}) step_id="init", data_schema=vol.Schema({"count": int})
) )
class FlowManager(data_entry_flow.FlowManager): class FlowManager(data_entry_flow.BaseFlowManager):
async def async_create_flow(self, handler_name, *, context, data): async def async_create_flow(self, handler_name, *, context, data):
"""Create a test flow.""" """Create a test flow."""
return TestFlow() return TestFlow()
@ -775,7 +775,7 @@ async def test_async_get_unknown_flow(manager) -> None:
async def test_async_has_matching_flow( async def test_async_has_matching_flow(
hass: HomeAssistant, manager: data_entry_flow.FlowManager hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager
) -> None: ) -> None:
"""Test we can check for matching flows.""" """Test we can check for matching flows."""
manager.hass = hass manager.hass = hass
@ -951,7 +951,7 @@ async def test_show_menu(hass: HomeAssistant, manager, menu_options) -> None:
async def test_find_flows_by_init_data_type( async def test_find_flows_by_init_data_type(
manager: data_entry_flow.FlowManager, manager: data_entry_flow.BaseFlowManager,
) -> None: ) -> None:
"""Test we can find flows by init data type.""" """Test we can find flows by init data type."""