From 82efb3d35b8dfaafe4ac4f9d74c84ad9df590f0a Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Thu, 7 Mar 2024 12:41:14 +0100 Subject: [PATCH] Make FlowResult a generic type (#111952) --- homeassistant/auth/__init__.py | 18 +++-- homeassistant/auth/models.py | 3 + homeassistant/auth/providers/__init__.py | 15 +++-- homeassistant/auth/providers/command_line.py | 5 +- homeassistant/auth/providers/homeassistant.py | 5 +- .../auth/providers/insecure_example.py | 5 +- .../auth/providers/legacy_api_password.py | 5 +- .../auth/providers/trusted_networks.py | 5 +- homeassistant/components/auth/login_flow.py | 20 +++--- .../components/zwave_js/config_flow.py | 2 +- homeassistant/config_entries.py | 14 ++-- homeassistant/data_entry_flow.py | 67 ++++++++++++------- homeassistant/helpers/data_entry_flow.py | 11 +-- 13 files changed, 95 insertions(+), 80 deletions(-) diff --git a/homeassistant/auth/__init__.py b/homeassistant/auth/__init__.py index f99e90dbc05..fa89134305e 100644 --- a/homeassistant/auth/__init__.py +++ b/homeassistant/auth/__init__.py @@ -19,13 +19,13 @@ from homeassistant.core import ( HomeAssistant, callback, ) -from homeassistant.data_entry_flow import FlowResult from homeassistant.helpers.event import async_track_point_in_utc_time from homeassistant.util import dt as dt_util from . import auth_store, jwt_wrapper, models from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRATION from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config +from .models import AuthFlowResult from .providers import AuthProvider, LoginFlow, auth_provider_from_config EVENT_USER_ADDED = "user_added" @@ -88,9 +88,13 @@ async def auth_manager_from_config( return manager -class AuthManagerFlowManager(data_entry_flow.FlowManager): +class AuthManagerFlowManager( + data_entry_flow.FlowManager[AuthFlowResult, tuple[str, str]] +): """Manage authentication flows.""" + _flow_result = AuthFlowResult + def __init__(self, hass: HomeAssistant, auth_manager: AuthManager) -> None: """Init auth manager flows.""" super().__init__(hass) @@ -98,11 +102,11 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager): async def async_create_flow( self, - handler_key: str, + handler_key: tuple[str, str], *, context: dict[str, Any] | None = None, data: dict[str, Any] | None = None, - ) -> data_entry_flow.FlowHandler: + ) -> LoginFlow: """Create a login flow.""" auth_provider = self.auth_manager.get_auth_provider(*handler_key) if not auth_provider: @@ -110,8 +114,10 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager): return await auth_provider.async_login_flow(context) async def async_finish_flow( - self, flow: data_entry_flow.FlowHandler, result: FlowResult - ) -> FlowResult: + self, + flow: data_entry_flow.FlowHandler[AuthFlowResult, tuple[str, str]], + result: AuthFlowResult, + ) -> AuthFlowResult: """Return a user as result of login flow.""" flow = cast(LoginFlow, flow) diff --git a/homeassistant/auth/models.py b/homeassistant/auth/models.py index 4cf94401478..d71cd682086 100644 --- a/homeassistant/auth/models.py +++ b/homeassistant/auth/models.py @@ -11,6 +11,7 @@ from attr import Attribute from attr.setters import validate from homeassistant.const import __version__ +from homeassistant.data_entry_flow import FlowResult from homeassistant.util import dt as dt_util from . import permissions as perm_mdl @@ -26,6 +27,8 @@ TOKEN_TYPE_NORMAL = "normal" TOKEN_TYPE_SYSTEM = "system" TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN = "long_lived_access_token" +AuthFlowResult = FlowResult[tuple[str, str]] + @attr.s(slots=True) class Group: diff --git a/homeassistant/auth/providers/__init__.py b/homeassistant/auth/providers/__init__.py index 7d74dd2dc26..2036178b225 100644 --- a/homeassistant/auth/providers/__init__.py +++ b/homeassistant/auth/providers/__init__.py @@ -13,14 +13,13 @@ from voluptuous.humanize import humanize_error from homeassistant import data_entry_flow, requirements from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE from homeassistant.core import HomeAssistant, callback -from homeassistant.data_entry_flow import FlowResult from homeassistant.exceptions import HomeAssistantError from homeassistant.util import dt as dt_util from homeassistant.util.decorator import Registry from ..auth_store import AuthStore from ..const import MFA_SESSION_EXPIRATION -from ..models import Credentials, RefreshToken, User, UserMeta +from ..models import AuthFlowResult, Credentials, RefreshToken, User, UserMeta _LOGGER = logging.getLogger(__name__) DATA_REQS = "auth_prov_reqs_processed" @@ -181,9 +180,11 @@ async def load_auth_provider_module( return module -class LoginFlow(data_entry_flow.FlowHandler): +class LoginFlow(data_entry_flow.FlowHandler[AuthFlowResult, tuple[str, str]]): """Handler for the login flow.""" + _flow_result = AuthFlowResult + def __init__(self, auth_provider: AuthProvider) -> None: """Initialize the login flow.""" self._auth_provider = auth_provider @@ -197,7 +198,7 @@ class LoginFlow(data_entry_flow.FlowHandler): async def async_step_init( self, user_input: dict[str, str] | None = None - ) -> FlowResult: + ) -> AuthFlowResult: """Handle the first step of login flow. Return self.async_show_form(step_id='init') if user_input is None. @@ -207,7 +208,7 @@ class LoginFlow(data_entry_flow.FlowHandler): async def async_step_select_mfa_module( self, user_input: dict[str, str] | None = None - ) -> FlowResult: + ) -> AuthFlowResult: """Handle the step of select mfa module.""" errors = {} @@ -232,7 +233,7 @@ class LoginFlow(data_entry_flow.FlowHandler): async def async_step_mfa( self, user_input: dict[str, str] | None = None - ) -> FlowResult: + ) -> AuthFlowResult: """Handle the step of mfa validation.""" assert self.credential assert self.user @@ -282,6 +283,6 @@ class LoginFlow(data_entry_flow.FlowHandler): errors=errors, ) - async def async_finish(self, flow_result: Any) -> FlowResult: + async def async_finish(self, flow_result: Any) -> AuthFlowResult: """Handle the pass of login flow.""" return self.async_create_entry(data=flow_result) diff --git a/homeassistant/auth/providers/command_line.py b/homeassistant/auth/providers/command_line.py index 4ec2ca18611..346d2cc503f 100644 --- a/homeassistant/auth/providers/command_line.py +++ b/homeassistant/auth/providers/command_line.py @@ -10,10 +10,9 @@ from typing import Any, cast import voluptuous as vol from homeassistant.const import CONF_COMMAND -from homeassistant.data_entry_flow import FlowResult from homeassistant.exceptions import HomeAssistantError -from ..models import Credentials, UserMeta +from ..models import AuthFlowResult, Credentials, UserMeta from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow CONF_ARGS = "args" @@ -138,7 +137,7 @@ class CommandLineLoginFlow(LoginFlow): async def async_step_init( self, user_input: dict[str, str] | None = None - ) -> FlowResult: + ) -> AuthFlowResult: """Handle the step of the form.""" errors = {} diff --git a/homeassistant/auth/providers/homeassistant.py b/homeassistant/auth/providers/homeassistant.py index 6f621b93a6a..2c4fb034dc9 100644 --- a/homeassistant/auth/providers/homeassistant.py +++ b/homeassistant/auth/providers/homeassistant.py @@ -12,11 +12,10 @@ import voluptuous as vol from homeassistant.const import CONF_ID from homeassistant.core import HomeAssistant, callback -from homeassistant.data_entry_flow import FlowResult from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.storage import Store -from ..models import Credentials, UserMeta +from ..models import AuthFlowResult, Credentials, UserMeta from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow STORAGE_VERSION = 1 @@ -321,7 +320,7 @@ class HassLoginFlow(LoginFlow): async def async_step_init( self, user_input: dict[str, str] | None = None - ) -> FlowResult: + ) -> AuthFlowResult: """Handle the step of the form.""" errors = {} diff --git a/homeassistant/auth/providers/insecure_example.py b/homeassistant/auth/providers/insecure_example.py index f7f01e74c27..851bed6638c 100644 --- a/homeassistant/auth/providers/insecure_example.py +++ b/homeassistant/auth/providers/insecure_example.py @@ -8,10 +8,9 @@ from typing import Any, cast import voluptuous as vol from homeassistant.core import callback -from homeassistant.data_entry_flow import FlowResult from homeassistant.exceptions import HomeAssistantError -from ..models import Credentials, UserMeta +from ..models import AuthFlowResult, Credentials, UserMeta from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow USER_SCHEMA = vol.Schema( @@ -98,7 +97,7 @@ class ExampleLoginFlow(LoginFlow): async def async_step_init( self, user_input: dict[str, str] | None = None - ) -> FlowResult: + ) -> AuthFlowResult: """Handle the step of the form.""" errors = None diff --git a/homeassistant/auth/providers/legacy_api_password.py b/homeassistant/auth/providers/legacy_api_password.py index 98c246d74e4..4fcfd4f7b12 100644 --- a/homeassistant/auth/providers/legacy_api_password.py +++ b/homeassistant/auth/providers/legacy_api_password.py @@ -11,12 +11,11 @@ from typing import Any, cast import voluptuous as vol from homeassistant.core import async_get_hass, callback -from homeassistant.data_entry_flow import FlowResult from homeassistant.exceptions import HomeAssistantError import homeassistant.helpers.config_validation as cv from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue -from ..models import Credentials, UserMeta +from ..models import AuthFlowResult, Credentials, UserMeta from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow AUTH_PROVIDER_TYPE = "legacy_api_password" @@ -101,7 +100,7 @@ class LegacyLoginFlow(LoginFlow): async def async_step_init( self, user_input: dict[str, str] | None = None - ) -> FlowResult: + ) -> AuthFlowResult: """Handle the step of the form.""" errors = {} diff --git a/homeassistant/auth/providers/trusted_networks.py b/homeassistant/auth/providers/trusted_networks.py index cc195c14c23..54633744bd9 100644 --- a/homeassistant/auth/providers/trusted_networks.py +++ b/homeassistant/auth/providers/trusted_networks.py @@ -19,13 +19,12 @@ from typing import Any, cast import voluptuous as vol from homeassistant.core import callback -from homeassistant.data_entry_flow import FlowResult from homeassistant.exceptions import HomeAssistantError import homeassistant.helpers.config_validation as cv from homeassistant.helpers.network import is_cloud_connection from .. import InvalidAuthError -from ..models import Credentials, RefreshToken, UserMeta +from ..models import AuthFlowResult, Credentials, RefreshToken, UserMeta from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow IPAddress = IPv4Address | IPv6Address @@ -226,7 +225,7 @@ class TrustedNetworksLoginFlow(LoginFlow): async def async_step_init( self, user_input: dict[str, str] | None = None - ) -> FlowResult: + ) -> AuthFlowResult: """Handle the step of the form.""" try: cast( diff --git a/homeassistant/components/auth/login_flow.py b/homeassistant/components/auth/login_flow.py index cc6cb5fc47a..12b1893bc9d 100644 --- a/homeassistant/components/auth/login_flow.py +++ b/homeassistant/components/auth/login_flow.py @@ -79,7 +79,7 @@ import voluptuous_serialize from homeassistant import data_entry_flow from homeassistant.auth import AuthManagerFlowManager, InvalidAuthError -from homeassistant.auth.models import Credentials +from homeassistant.auth.models import AuthFlowResult, Credentials from homeassistant.components import onboarding from homeassistant.components.http.auth import async_user_not_allowed_do_auth from homeassistant.components.http.ban import ( @@ -197,8 +197,8 @@ class AuthProvidersView(HomeAssistantView): def _prepare_result_json( - result: data_entry_flow.FlowResult, -) -> data_entry_flow.FlowResult: + result: AuthFlowResult, +) -> AuthFlowResult: """Convert result to JSON.""" if result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY: data = result.copy() @@ -237,7 +237,7 @@ class LoginFlowBaseView(HomeAssistantView): self, request: web.Request, client_id: str, - result: data_entry_flow.FlowResult, + result: AuthFlowResult, ) -> web.Response: """Convert the flow result to a response.""" if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY: @@ -297,7 +297,9 @@ class LoginFlowIndexView(LoginFlowBaseView): vol.Schema( { vol.Required("client_id"): str, - vol.Required("handler"): vol.Any(str, list), + vol.Required("handler"): vol.All( + [vol.Any(str, None)], vol.Length(2, 2), vol.Coerce(tuple) + ), vol.Required("redirect_uri"): str, vol.Optional("type", default="authorize"): str, } @@ -312,15 +314,11 @@ class LoginFlowIndexView(LoginFlowBaseView): if not indieauth.verify_client_id(client_id): return self.json_message("Invalid client id", HTTPStatus.BAD_REQUEST) - handler: tuple[str, ...] | str - if isinstance(data["handler"], list): - handler = tuple(data["handler"]) - else: - handler = data["handler"] + handler: tuple[str, str] = tuple(data["handler"]) try: result = await self._flow_mgr.async_init( - handler, # type: ignore[arg-type] + handler, context={ "ip_address": ip_address(request.remote), # type: ignore[arg-type] "credential_only": data.get("type") == "link_user", diff --git a/homeassistant/components/zwave_js/config_flow.py b/homeassistant/components/zwave_js/config_flow.py index cb564de924c..8c5aa55713a 100644 --- a/homeassistant/components/zwave_js/config_flow.py +++ b/homeassistant/components/zwave_js/config_flow.py @@ -182,7 +182,7 @@ class BaseZwaveJSFlow(ConfigEntryBaseFlow, ABC): @property @abstractmethod - def flow_manager(self) -> FlowManager[ConfigFlowResult]: + def flow_manager(self) -> FlowManager[ConfigFlowResult, str]: """Return the flow manager of the flow.""" async def async_step_install_addon( diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 164825c4dec..d9023e5e11a 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -1045,7 +1045,7 @@ class FlowCancelledError(Exception): """Error to indicate that a flow has been cancelled.""" -class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): +class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult, str]): """Manage all the config entry flows that are in progress.""" _flow_result = ConfigFlowResult @@ -1171,7 +1171,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): async def async_finish_flow( self, - flow: data_entry_flow.FlowHandler[ConfigFlowResult], + flow: data_entry_flow.FlowHandler[ConfigFlowResult, str], result: ConfigFlowResult, ) -> ConfigFlowResult: """Finish a config flow and add an entry.""" @@ -1293,7 +1293,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): async def async_post_init( self, - flow: data_entry_flow.FlowHandler[ConfigFlowResult], + flow: data_entry_flow.FlowHandler[ConfigFlowResult, str], result: ConfigFlowResult, ) -> None: """After a flow is initialised trigger new flow notifications.""" @@ -1940,7 +1940,7 @@ def _async_abort_entries_match( raise data_entry_flow.AbortFlow("already_configured") -class ConfigEntryBaseFlow(data_entry_flow.FlowHandler[ConfigFlowResult]): +class ConfigEntryBaseFlow(data_entry_flow.FlowHandler[ConfigFlowResult, str]): """Base class for config and option flows.""" _flow_result = ConfigFlowResult @@ -2292,7 +2292,7 @@ class ConfigFlow(ConfigEntryBaseFlow): return self.async_abort(reason=reason) -class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): +class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult, str]): """Flow to set options for a configuration entry.""" _flow_result = ConfigFlowResult @@ -2322,7 +2322,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): async def async_finish_flow( self, - flow: data_entry_flow.FlowHandler[ConfigFlowResult], + flow: data_entry_flow.FlowHandler[ConfigFlowResult, str], result: ConfigFlowResult, ) -> ConfigFlowResult: """Finish an options flow and update options for configuration entry. @@ -2344,7 +2344,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): return result async def _async_setup_preview( - self, flow: data_entry_flow.FlowHandler[ConfigFlowResult] + self, flow: data_entry_flow.FlowHandler[ConfigFlowResult, str] ) -> None: """Set up preview for an option flow handler.""" entry = self._async_get_config_entry(flow.handler) diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index 3005c21c272..4334ec2b274 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -85,7 +85,8 @@ STEP_ID_OPTIONAL_STEPS = { } -_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult", default="FlowResult") +_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult[Any]", default="FlowResult") +_HandlerT = TypeVar("_HandlerT", default=str) @dataclass(slots=True) @@ -138,7 +139,7 @@ class AbortFlow(FlowError): self.description_placeholders = description_placeholders -class FlowResult(TypedDict, total=False): +class FlowResult(TypedDict, Generic[_HandlerT], total=False): """Typed result dict.""" context: dict[str, Any] @@ -149,7 +150,7 @@ class FlowResult(TypedDict, total=False): errors: dict[str, str] | None extra: str flow_id: Required[str] - handler: Required[str] + handler: Required[_HandlerT] last_step: bool | None menu_options: list[str] | dict[str, str] options: Mapping[str, Any] @@ -189,7 +190,7 @@ def _map_error_to_schema_errors( schema_errors[path_part_str] = error.error_message -class FlowManager(abc.ABC, Generic[_FlowResultT]): +class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]): """Manage all the flows that are in progress.""" _flow_result: Callable[..., _FlowResultT] = FlowResult # type: ignore[assignment] @@ -200,19 +201,23 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]): ) -> None: """Initialize the flow manager.""" self.hass = hass - self._preview: set[str] = set() - self._progress: dict[str, FlowHandler[_FlowResultT]] = {} - self._handler_progress_index: dict[str, set[FlowHandler[_FlowResultT]]] = {} - self._init_data_process_index: dict[type, set[FlowHandler[_FlowResultT]]] = {} + self._preview: set[_HandlerT] = set() + self._progress: dict[str, FlowHandler[_FlowResultT, _HandlerT]] = {} + self._handler_progress_index: dict[ + _HandlerT, set[FlowHandler[_FlowResultT, _HandlerT]] + ] = {} + self._init_data_process_index: dict[ + type, set[FlowHandler[_FlowResultT, _HandlerT]] + ] = {} @abc.abstractmethod async def async_create_flow( self, - handler_key: str, + handler_key: _HandlerT, *, context: dict[str, Any] | None = None, data: dict[str, Any] | None = None, - ) -> FlowHandler[_FlowResultT]: + ) -> FlowHandler[_FlowResultT, _HandlerT]: """Create a flow for specified handler. Handler key is the domain of the component that we want to set up. @@ -220,18 +225,18 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]): @abc.abstractmethod async def async_finish_flow( - self, flow: FlowHandler[_FlowResultT], result: _FlowResultT + self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT ) -> _FlowResultT: """Finish a data entry flow.""" async def async_post_init( - self, flow: FlowHandler[_FlowResultT], result: _FlowResultT + self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT ) -> None: """Entry has finished executing its first step asynchronously.""" @callback def async_has_matching_flow( - self, handler: str, match_context: dict[str, Any], data: Any + self, handler: _HandlerT, match_context: dict[str, Any], data: Any ) -> bool: """Check if an existing matching flow is in progress. @@ -265,7 +270,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]): @callback def async_progress_by_handler( self, - handler: str, + handler: _HandlerT, include_uninitialized: bool = False, match_context: dict[str, Any] | None = None, ) -> list[_FlowResultT]: @@ -298,8 +303,8 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]): @callback def _async_progress_by_handler( - self, handler: str, match_context: dict[str, Any] | None - ) -> list[FlowHandler[_FlowResultT]]: + self, handler: _HandlerT, match_context: dict[str, Any] | None + ) -> list[FlowHandler[_FlowResultT, _HandlerT]]: """Return the flows in progress by handler. If match_context is specified, only return flows with a context that @@ -315,7 +320,11 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]): ] async def async_init( - self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None + self, + handler: _HandlerT, + *, + context: dict[str, Any] | None = None, + data: Any = None, ) -> _FlowResultT: """Start a data entry flow.""" if context is None: @@ -445,7 +454,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]): self._async_remove_flow_progress(flow_id) @callback - def _async_add_flow_progress(self, flow: FlowHandler[_FlowResultT]) -> None: + def _async_add_flow_progress( + self, flow: FlowHandler[_FlowResultT, _HandlerT] + ) -> None: """Add a flow to in progress.""" if flow.init_data is not None: init_data_type = type(flow.init_data) @@ -454,7 +465,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]): self._handler_progress_index.setdefault(flow.handler, set()).add(flow) @callback - def _async_remove_flow_from_index(self, flow: FlowHandler[_FlowResultT]) -> None: + def _async_remove_flow_from_index( + self, flow: FlowHandler[_FlowResultT, _HandlerT] + ) -> None: """Remove a flow from in progress.""" if flow.init_data is not None: init_data_type = type(flow.init_data) @@ -480,7 +493,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]): async def _async_handle_step( self, - flow: FlowHandler[_FlowResultT], + flow: FlowHandler[_FlowResultT, _HandlerT], step_id: str, user_input: dict | BaseServiceInfo | None, ) -> _FlowResultT: @@ -557,7 +570,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]): return result def _raise_if_step_does_not_exist( - self, flow: FlowHandler[_FlowResultT], step_id: str + self, flow: FlowHandler[_FlowResultT, _HandlerT], step_id: str ) -> None: """Raise if the step does not exist.""" method = f"async_step_{step_id}" @@ -568,7 +581,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]): f"Handler {self.__class__.__name__} doesn't support step {step_id}" ) - async def _async_setup_preview(self, flow: FlowHandler[_FlowResultT]) -> None: + async def _async_setup_preview( + self, flow: FlowHandler[_FlowResultT, _HandlerT] + ) -> None: """Set up preview for a flow handler.""" if flow.handler not in self._preview: self._preview.add(flow.handler) @@ -576,7 +591,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]): @callback def _async_flow_handler_to_flow_result( - self, flows: Iterable[FlowHandler[_FlowResultT]], include_uninitialized: bool + self, + flows: Iterable[FlowHandler[_FlowResultT, _HandlerT]], + include_uninitialized: bool, ) -> list[_FlowResultT]: """Convert a list of FlowHandler to a partial FlowResult that can be serialized.""" results = [] @@ -594,7 +611,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]): return results -class FlowHandler(Generic[_FlowResultT]): +class FlowHandler(Generic[_FlowResultT, _HandlerT]): """Handle a data entry flow.""" _flow_result: Callable[..., _FlowResultT] = FlowResult # type: ignore[assignment] @@ -606,7 +623,7 @@ class FlowHandler(Generic[_FlowResultT]): # and removes the need for constant None checks or asserts. flow_id: str = None # type: ignore[assignment] hass: HomeAssistant = None # type: ignore[assignment] - handler: str = None # type: ignore[assignment] + handler: _HandlerT = None # type: ignore[assignment] # Ensure the attribute has a subscriptable, but immutable, default value. context: dict[str, Any] = MappingProxyType({}) # type: ignore[assignment] diff --git a/homeassistant/helpers/data_entry_flow.py b/homeassistant/helpers/data_entry_flow.py index 21ebab7b4eb..524fd4ebf7d 100644 --- a/homeassistant/helpers/data_entry_flow.py +++ b/homeassistant/helpers/data_entry_flow.py @@ -17,7 +17,7 @@ from . import config_validation as cv _FlowManagerT = TypeVar( "_FlowManagerT", - bound=data_entry_flow.FlowManager[Any], + bound="data_entry_flow.FlowManager[Any]", default=data_entry_flow.FlowManager, ) @@ -61,7 +61,7 @@ class FlowManagerIndexView(_BaseFlowManagerView[_FlowManagerT]): @RequestDataValidator( vol.Schema( { - vol.Required("handler"): vol.Any(str, list), + vol.Required("handler"): str, vol.Optional("show_advanced_options", default=False): cv.boolean, }, extra=vol.ALLOW_EXTRA, @@ -79,14 +79,9 @@ class FlowManagerIndexView(_BaseFlowManagerView[_FlowManagerT]): self, request: web.Request, data: dict[str, Any] ) -> web.Response: """Handle a POST request.""" - if isinstance(data["handler"], list): - handler = tuple(data["handler"]) - else: - handler = data["handler"] - try: result = await self._flow_mgr.async_init( - handler, # type: ignore[arg-type] + data["handler"], context=self.get_context(data), ) except data_entry_flow.UnknownHandler: