Make FlowHandler.context a typed dict (#126291)

* Make FlowHandler.context a typed dict

* Adjust typing

* Adjust typing

* Avoid calling ConfigFlowContext constructor in hot path
This commit is contained in:
Erik Montnemery 2024-10-08 12:18:45 +02:00 committed by GitHub
parent 217165208b
commit d6ee10a543
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 175 additions and 99 deletions

View file

@ -12,7 +12,6 @@ from typing import Any, cast
import jwt import jwt
from homeassistant import data_entry_flow
from homeassistant.core import ( from homeassistant.core import (
CALLBACK_TYPE, CALLBACK_TYPE,
HassJob, HassJob,
@ -20,13 +19,14 @@ from homeassistant.core import (
HomeAssistant, HomeAssistant,
callback, callback,
) )
from homeassistant.data_entry_flow import FlowHandler, FlowManager, FlowResultType
from homeassistant.helpers.event import async_track_point_in_utc_time from homeassistant.helpers.event import async_track_point_in_utc_time
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from . import auth_store, jwt_wrapper, models from . import auth_store, jwt_wrapper, models
from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRATION from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRATION
from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config
from .models import AuthFlowResult from .models import AuthFlowContext, AuthFlowResult
from .providers import AuthProvider, LoginFlow, auth_provider_from_config from .providers import AuthProvider, LoginFlow, auth_provider_from_config
from .providers.homeassistant import HassAuthProvider from .providers.homeassistant import HassAuthProvider
@ -98,7 +98,7 @@ async def auth_manager_from_config(
class AuthManagerFlowManager( class AuthManagerFlowManager(
data_entry_flow.FlowManager[AuthFlowResult, tuple[str, str]] FlowManager[AuthFlowContext, AuthFlowResult, tuple[str, str]]
): ):
"""Manage authentication flows.""" """Manage authentication flows."""
@ -113,7 +113,7 @@ class AuthManagerFlowManager(
self, self,
handler_key: tuple[str, str], handler_key: tuple[str, str],
*, *,
context: dict[str, Any] | None = None, context: AuthFlowContext | None = None,
data: dict[str, Any] | None = None, data: dict[str, Any] | None = None,
) -> LoginFlow: ) -> LoginFlow:
"""Create a login flow.""" """Create a login flow."""
@ -124,7 +124,7 @@ class AuthManagerFlowManager(
async def async_finish_flow( async def async_finish_flow(
self, self,
flow: data_entry_flow.FlowHandler[AuthFlowResult, tuple[str, str]], flow: FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]],
result: AuthFlowResult, result: AuthFlowResult,
) -> AuthFlowResult: ) -> AuthFlowResult:
"""Return a user as result of login flow. """Return a user as result of login flow.
@ -134,7 +134,7 @@ class AuthManagerFlowManager(
""" """
flow = cast(LoginFlow, flow) flow = cast(LoginFlow, flow)
if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY: if result["type"] != FlowResultType.CREATE_ENTRY:
return result return result
# we got final result # we got final result

View file

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timedelta from datetime import datetime, timedelta
from ipaddress import IPv4Address, IPv6Address
import secrets import secrets
from typing import Any, NamedTuple from typing import Any, NamedTuple
import uuid import uuid
@ -13,7 +14,7 @@ from attr.setters import validate
from propcache import cached_property from propcache import cached_property
from homeassistant.const import __version__ from homeassistant.const import __version__
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowContext, FlowResult
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from . import permissions as perm_mdl from . import permissions as perm_mdl
@ -23,7 +24,16 @@ TOKEN_TYPE_NORMAL = "normal"
TOKEN_TYPE_SYSTEM = "system" TOKEN_TYPE_SYSTEM = "system"
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN = "long_lived_access_token" TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN = "long_lived_access_token"
AuthFlowResult = FlowResult[tuple[str, str]]
class AuthFlowContext(FlowContext, total=False):
"""Typed context dict for auth flow."""
credential_only: bool
ip_address: IPv4Address | IPv6Address
redirect_uri: str
AuthFlowResult = FlowResult[AuthFlowContext, tuple[str, str]]
@attr.s(slots=True) @attr.s(slots=True)

View file

@ -10,9 +10,10 @@ from typing import Any
import voluptuous as vol import voluptuous as vol
from voluptuous.humanize import humanize_error from voluptuous.humanize import humanize_error
from homeassistant import data_entry_flow, requirements from homeassistant import requirements
from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowHandler
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.importlib import async_import_module from homeassistant.helpers.importlib import async_import_module
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
@ -21,7 +22,14 @@ from homeassistant.util.hass_dict import HassKey
from ..auth_store import AuthStore from ..auth_store import AuthStore
from ..const import MFA_SESSION_EXPIRATION from ..const import MFA_SESSION_EXPIRATION
from ..models import AuthFlowResult, Credentials, RefreshToken, User, UserMeta from ..models import (
AuthFlowContext,
AuthFlowResult,
Credentials,
RefreshToken,
User,
UserMeta,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DATA_REQS: HassKey[set[str]] = HassKey("auth_prov_reqs_processed") DATA_REQS: HassKey[set[str]] = HassKey("auth_prov_reqs_processed")
@ -97,7 +105,7 @@ class AuthProvider:
# Implement by extending class # Implement by extending class
async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow: async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
"""Return the data flow for logging in with auth provider. """Return the data flow for logging in with auth provider.
Auth provider should extend LoginFlow and return an instance. Auth provider should extend LoginFlow and return an instance.
@ -184,7 +192,7 @@ async def load_auth_provider_module(
return module return module
class LoginFlow(data_entry_flow.FlowHandler[AuthFlowResult, tuple[str, str]]): class LoginFlow(FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]]):
"""Handler for the login flow.""" """Handler for the login flow."""
_flow_result = AuthFlowResult _flow_result = AuthFlowResult

View file

@ -13,7 +13,7 @@ import voluptuous as vol
from homeassistant.const import CONF_COMMAND from homeassistant.const import CONF_COMMAND
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from ..models import AuthFlowResult, Credentials, UserMeta from ..models import AuthFlowContext, AuthFlowResult, Credentials, UserMeta
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
CONF_ARGS = "args" CONF_ARGS = "args"
@ -59,7 +59,7 @@ class CommandLineAuthProvider(AuthProvider):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._user_meta: dict[str, dict[str, Any]] = {} self._user_meta: dict[str, dict[str, Any]] = {}
async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow: async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
"""Return a flow to login.""" """Return a flow to login."""
return CommandLineLoginFlow(self) return CommandLineLoginFlow(self)

View file

@ -17,7 +17,7 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import issue_registry as ir from homeassistant.helpers import issue_registry as ir
from homeassistant.helpers.storage import Store from homeassistant.helpers.storage import Store
from ..models import AuthFlowResult, Credentials, UserMeta from ..models import AuthFlowContext, AuthFlowResult, Credentials, UserMeta
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
STORAGE_VERSION = 1 STORAGE_VERSION = 1
@ -305,7 +305,7 @@ class HassAuthProvider(AuthProvider):
await data.async_load() await data.async_load()
self.data = data self.data = data
async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow: async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
"""Return a flow to login.""" """Return a flow to login."""
return HassLoginFlow(self) return HassLoginFlow(self)

View file

@ -4,14 +4,14 @@ from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
import hmac import hmac
from typing import Any, cast from typing import cast
import voluptuous as vol import voluptuous as vol
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from ..models import AuthFlowResult, Credentials, UserMeta from ..models import AuthFlowContext, AuthFlowResult, Credentials, UserMeta
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
USER_SCHEMA = vol.Schema( USER_SCHEMA = vol.Schema(
@ -36,7 +36,7 @@ class InvalidAuthError(HomeAssistantError):
class ExampleAuthProvider(AuthProvider): class ExampleAuthProvider(AuthProvider):
"""Example auth provider based on hardcoded usernames and passwords.""" """Example auth provider based on hardcoded usernames and passwords."""
async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow: async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
"""Return a flow to login.""" """Return a flow to login."""
return ExampleLoginFlow(self) return ExampleLoginFlow(self)

View file

@ -25,7 +25,13 @@ import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.network import is_cloud_connection from homeassistant.helpers.network import is_cloud_connection
from .. import InvalidAuthError from .. import InvalidAuthError
from ..models import AuthFlowResult, Credentials, RefreshToken, UserMeta from ..models import (
AuthFlowContext,
AuthFlowResult,
Credentials,
RefreshToken,
UserMeta,
)
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
type IPAddress = IPv4Address | IPv6Address type IPAddress = IPv4Address | IPv6Address
@ -98,7 +104,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
"""Trusted Networks auth provider does not support MFA.""" """Trusted Networks auth provider does not support MFA."""
return False return False
async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow: async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
"""Return a flow to login.""" """Return a flow to login."""
assert context is not None assert context is not None
ip_addr = cast(IPAddress, context.get("ip_address")) ip_addr = cast(IPAddress, context.get("ip_address"))

View file

@ -80,7 +80,7 @@ import voluptuous_serialize
from homeassistant import data_entry_flow from homeassistant import data_entry_flow
from homeassistant.auth import AuthManagerFlowManager, InvalidAuthError from homeassistant.auth import AuthManagerFlowManager, InvalidAuthError
from homeassistant.auth.models import AuthFlowResult, Credentials from homeassistant.auth.models import AuthFlowContext, AuthFlowResult, Credentials
from homeassistant.components import onboarding from homeassistant.components import onboarding
from homeassistant.components.http import KEY_HASS from homeassistant.components.http import KEY_HASS
from homeassistant.components.http.auth import async_user_not_allowed_do_auth from homeassistant.components.http.auth import async_user_not_allowed_do_auth
@ -322,11 +322,11 @@ class LoginFlowIndexView(LoginFlowBaseView):
try: try:
result = await self._flow_mgr.async_init( result = await self._flow_mgr.async_init(
handler, handler,
context={ context=AuthFlowContext(
"ip_address": ip_address(request.remote), # type: ignore[arg-type] ip_address=ip_address(request.remote), # type: ignore[arg-type]
"credential_only": data.get("type") == "link_user", credential_only=data.get("type") == "link_user",
"redirect_uri": redirect_uri, redirect_uri=redirect_uri,
}, ),
) )
except data_entry_flow.UnknownHandler: except data_entry_flow.UnknownHandler:
return self.json_message("Invalid handler specified", HTTPStatus.NOT_FOUND) return self.json_message("Invalid handler specified", HTTPStatus.NOT_FOUND)

View file

@ -11,6 +11,7 @@ import voluptuous_serialize
from homeassistant import data_entry_flow from homeassistant import data_entry_flow
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowContext
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.util.hass_dict import HassKey from homeassistant.util.hass_dict import HassKey
@ -44,7 +45,7 @@ class MfaFlowManager(data_entry_flow.FlowManager):
self, self,
handler_key: str, handler_key: str,
*, *,
context: dict[str, Any], context: FlowContext | None,
data: dict[str, Any], data: dict[str, Any],
) -> data_entry_flow.FlowHandler: ) -> data_entry_flow.FlowHandler:
"""Create a setup flow. handler is a mfa module.""" """Create a setup flow. handler is a mfa module."""

View file

@ -463,7 +463,7 @@ async def ignore_config_flow(
) )
return return
context = {"source": config_entries.SOURCE_IGNORE} context = config_entries.ConfigFlowContext(source=config_entries.SOURCE_IGNORE)
if "discovery_key" in flow["context"]: if "discovery_key" in flow["context"]:
context["discovery_key"] = flow["context"]["discovery_key"] context["discovery_key"] = flow["context"]["discovery_key"]
await hass.config_entries.flow.async_init( await hass.config_entries.flow.async_init(

View file

@ -12,7 +12,13 @@ from homeassistant.components.homeassistant_hardware import (
firmware_config_flow, firmware_config_flow,
silabs_multiprotocol_addon, silabs_multiprotocol_addon,
) )
from homeassistant.config_entries import ConfigEntry, ConfigFlowResult, OptionsFlow from homeassistant.config_entries import (
ConfigEntry,
ConfigEntryBaseFlow,
ConfigFlowContext,
ConfigFlowResult,
OptionsFlow,
)
from homeassistant.core import callback from homeassistant.core import callback
from .const import DOCS_WEB_FLASHER_URL, DOMAIN, HardwareVariant from .const import DOCS_WEB_FLASHER_URL, DOMAIN, HardwareVariant
@ -33,10 +39,10 @@ else:
TranslationPlaceholderProtocol = object TranslationPlaceholderProtocol = object
class SkyConnectTranslationMixin(TranslationPlaceholderProtocol): class SkyConnectTranslationMixin(ConfigEntryBaseFlow, TranslationPlaceholderProtocol):
"""Translation placeholder mixin for Home Assistant SkyConnect.""" """Translation placeholder mixin for Home Assistant SkyConnect."""
context: dict[str, Any] context: ConfigFlowContext
def _get_translation_placeholders(self) -> dict[str, str]: def _get_translation_placeholders(self) -> dict[str, str]:
"""Shared translation placeholders.""" """Shared translation placeholders."""

View file

@ -53,7 +53,7 @@ class RepairsFlowManager(data_entry_flow.FlowManager):
self, self,
handler_key: str, handler_key: str,
*, *,
context: dict[str, Any] | None = None, context: data_entry_flow.FlowContext | None = None,
data: dict[str, Any] | None = None, data: dict[str, Any] | None = None,
) -> RepairsFlow: ) -> RepairsFlow:
"""Create a flow. platform is a repairs module.""" """Create a flow. platform is a repairs module."""

View file

@ -378,7 +378,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
for flow in _config_entries.flow.async_progress_by_handler( for flow in _config_entries.flow.async_progress_by_handler(
DOMAIN, include_uninitialized=True DOMAIN, include_uninitialized=True
): ):
context: dict[str, Any] = flow["context"] context = flow["context"]
if context.get("source") != SOURCE_REAUTH: if context.get("source") != SOURCE_REAUTH:
continue continue
entry_id: str = context["entry_id"] entry_id: str = context["entry_id"]

View file

@ -540,7 +540,9 @@ class ZeroconfDiscovery:
continue continue
matcher_domain = matcher[ATTR_DOMAIN] matcher_domain = matcher[ATTR_DOMAIN]
context = { # Create a type annotated regular dict since this is a hot path and creating
# a regular dict is slightly cheaper than calling ConfigFlowContext
context: config_entries.ConfigFlowContext = {
"source": config_entries.SOURCE_ZEROCONF, "source": config_entries.SOURCE_ZEROCONF,
} }
if domain: if domain:

View file

@ -29,6 +29,7 @@ from homeassistant.config_entries import (
ConfigEntryBaseFlow, ConfigEntryBaseFlow,
ConfigEntryState, ConfigEntryState,
ConfigFlow, ConfigFlow,
ConfigFlowContext,
ConfigFlowResult, ConfigFlowResult,
OptionsFlow, OptionsFlow,
OptionsFlowManager, OptionsFlowManager,
@ -192,7 +193,7 @@ class BaseZwaveJSFlow(ConfigEntryBaseFlow, ABC):
@property @property
@abstractmethod @abstractmethod
def flow_manager(self) -> FlowManager[ConfigFlowResult]: def flow_manager(self) -> FlowManager[ConfigFlowContext, ConfigFlowResult]:
"""Return the flow manager of the flow.""" """Return the flow manager of the flow."""
async def async_step_install_addon( async def async_step_install_addon(

View file

@ -41,7 +41,7 @@ from .core import (
HomeAssistant, HomeAssistant,
callback, callback,
) )
from .data_entry_flow import FLOW_NOT_COMPLETE_STEPS, FlowResult from .data_entry_flow import FLOW_NOT_COMPLETE_STEPS, FlowContext, FlowResult
from .exceptions import ( from .exceptions import (
ConfigEntryAuthFailed, ConfigEntryAuthFailed,
ConfigEntryError, ConfigEntryError,
@ -267,7 +267,19 @@ UPDATE_ENTRY_CONFIG_ENTRY_ATTRS = {
} }
class ConfigFlowResult(FlowResult, total=False): class ConfigFlowContext(FlowContext, total=False):
"""Typed context dict for config flow."""
alternative_domain: str
configuration_url: str
confirm_only: bool
discovery_key: DiscoveryKey
entry_id: str
title_placeholders: Mapping[str, str]
unique_id: str | None
class ConfigFlowResult(FlowResult[ConfigFlowContext, str], total=False):
"""Typed result dict for config flow.""" """Typed result dict for config flow."""
minor_version: int minor_version: int
@ -1026,7 +1038,7 @@ class ConfigEntry(Generic[_DataT]):
def async_start_reauth( def async_start_reauth(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
context: dict[str, Any] | None = None, context: ConfigFlowContext | None = None,
data: dict[str, Any] | None = None, data: dict[str, Any] | None = None,
) -> None: ) -> None:
"""Start a reauth flow.""" """Start a reauth flow."""
@ -1044,7 +1056,7 @@ class ConfigEntry(Generic[_DataT]):
async def _async_init_reauth( async def _async_init_reauth(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
context: dict[str, Any] | None = None, context: ConfigFlowContext | None = None,
data: dict[str, Any] | None = None, data: dict[str, Any] | None = None,
) -> None: ) -> None:
"""Start a reauth flow.""" """Start a reauth flow."""
@ -1056,12 +1068,12 @@ class ConfigEntry(Generic[_DataT]):
return return
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
self.domain, self.domain,
context={ context=ConfigFlowContext(
"source": SOURCE_REAUTH, source=SOURCE_REAUTH,
"entry_id": self.entry_id, entry_id=self.entry_id,
"title_placeholders": {"name": self.title}, title_placeholders={"name": self.title},
"unique_id": self.unique_id, unique_id=self.unique_id,
} )
| (context or {}), | (context or {}),
data=self.data | (data or {}), data=self.data | (data or {}),
) )
@ -1086,7 +1098,7 @@ class ConfigEntry(Generic[_DataT]):
def async_start_reconfigure( def async_start_reconfigure(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
context: dict[str, Any] | None = None, context: ConfigFlowContext | None = None,
data: dict[str, Any] | None = None, data: dict[str, Any] | None = None,
) -> None: ) -> None:
"""Start a reconfigure flow.""" """Start a reconfigure flow."""
@ -1103,7 +1115,7 @@ class ConfigEntry(Generic[_DataT]):
async def _async_init_reconfigure( async def _async_init_reconfigure(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
context: dict[str, Any] | None = None, context: ConfigFlowContext | None = None,
data: dict[str, Any] | None = None, data: dict[str, Any] | None = None,
) -> None: ) -> None:
"""Start a reconfigure flow.""" """Start a reconfigure flow."""
@ -1115,12 +1127,12 @@ class ConfigEntry(Generic[_DataT]):
return return
await hass.config_entries.flow.async_init( await hass.config_entries.flow.async_init(
self.domain, self.domain,
context={ context=ConfigFlowContext(
"source": SOURCE_RECONFIGURE, source=SOURCE_RECONFIGURE,
"entry_id": self.entry_id, entry_id=self.entry_id,
"title_placeholders": {"name": self.title}, title_placeholders={"name": self.title},
"unique_id": self.unique_id, unique_id=self.unique_id,
} )
| (context or {}), | (context or {}),
data=self.data | (data or {}), data=self.data | (data or {}),
) )
@ -1214,7 +1226,9 @@ def _report_non_awaited_platform_forwards(entry: ConfigEntry, what: str) -> None
) )
class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): class ConfigEntriesFlowManager(
data_entry_flow.FlowManager[ConfigFlowContext, ConfigFlowResult]
):
"""Manage all the config entry flows that are in progress.""" """Manage all the config entry flows that are in progress."""
_flow_result = ConfigFlowResult _flow_result = ConfigFlowResult
@ -1260,7 +1274,11 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
return False return False
async def async_init( async def async_init(
self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None self,
handler: str,
*,
context: ConfigFlowContext | None = None,
data: Any = None,
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Start a configuration flow.""" """Start a configuration flow."""
if not context or "source" not in context: if not context or "source" not in context:
@ -1319,7 +1337,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
self, self,
flow_id: str, flow_id: str,
handler: str, handler: str,
context: dict, context: ConfigFlowContext,
data: Any, data: Any,
) -> tuple[ConfigFlow, ConfigFlowResult]: ) -> tuple[ConfigFlow, ConfigFlowResult]:
"""Run the init in a task to allow it to be canceled at shutdown.""" """Run the init in a task to allow it to be canceled at shutdown."""
@ -1357,7 +1375,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
async def async_finish_flow( async def async_finish_flow(
self, self,
flow: data_entry_flow.FlowHandler[ConfigFlowResult], flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
result: ConfigFlowResult, result: ConfigFlowResult,
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Finish a config flow and add an entry. """Finish a config flow and add an entry.
@ -1504,7 +1522,11 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
return result return result
async def async_create_flow( async def async_create_flow(
self, handler_key: str, *, context: dict | None = None, data: Any = None self,
handler_key: str,
*,
context: ConfigFlowContext | None = None,
data: Any = None,
) -> ConfigFlow: ) -> ConfigFlow:
"""Create a flow for specified handler. """Create a flow for specified handler.
@ -1522,7 +1544,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
async def async_post_init( async def async_post_init(
self, self,
flow: data_entry_flow.FlowHandler[ConfigFlowResult], flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
result: ConfigFlowResult, result: ConfigFlowResult,
) -> None: ) -> None:
"""After a flow is initialised trigger new flow notifications.""" """After a flow is initialised trigger new flow notifications."""
@ -1560,7 +1582,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
@callback @callback
def async_has_matching_discovery_flow( def async_has_matching_discovery_flow(
self, handler: str, match_context: dict[str, Any], data: Any self, handler: str, match_context: ConfigFlowContext, data: Any
) -> bool: ) -> bool:
"""Check if an existing matching discovery flow is in progress. """Check if an existing matching discovery flow is in progress.
@ -2385,7 +2407,9 @@ def _async_abort_entries_match(
raise data_entry_flow.AbortFlow("already_configured") raise data_entry_flow.AbortFlow("already_configured")
class ConfigEntryBaseFlow(data_entry_flow.FlowHandler[ConfigFlowResult]): class ConfigEntryBaseFlow(
data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult]
):
"""Base class for config and option flows.""" """Base class for config and option flows."""
_flow_result = ConfigFlowResult _flow_result = ConfigFlowResult
@ -2406,7 +2430,7 @@ class ConfigFlow(ConfigEntryBaseFlow):
if not self.context: if not self.context:
return None return None
return cast(str | None, self.context.get("unique_id")) return self.context.get("unique_id")
@staticmethod @staticmethod
@callback @callback
@ -2779,7 +2803,7 @@ class ConfigFlow(ConfigEntryBaseFlow):
"""Return reauth entry id.""" """Return reauth entry id."""
if self.source != SOURCE_REAUTH: if self.source != SOURCE_REAUTH:
raise ValueError(f"Source is {self.source}, expected {SOURCE_REAUTH}") raise ValueError(f"Source is {self.source}, expected {SOURCE_REAUTH}")
return self.context["entry_id"] # type: ignore[no-any-return] return self.context["entry_id"]
@callback @callback
def _get_reauth_entry(self) -> ConfigEntry: def _get_reauth_entry(self) -> ConfigEntry:
@ -2793,7 +2817,7 @@ class ConfigFlow(ConfigEntryBaseFlow):
"""Return reconfigure entry id.""" """Return reconfigure entry id."""
if self.source != SOURCE_RECONFIGURE: if self.source != SOURCE_RECONFIGURE:
raise ValueError(f"Source is {self.source}, expected {SOURCE_RECONFIGURE}") raise ValueError(f"Source is {self.source}, expected {SOURCE_RECONFIGURE}")
return self.context["entry_id"] # type: ignore[no-any-return] return self.context["entry_id"]
@callback @callback
def _get_reconfigure_entry(self) -> ConfigEntry: def _get_reconfigure_entry(self) -> ConfigEntry:
@ -2805,7 +2829,9 @@ class ConfigFlow(ConfigEntryBaseFlow):
raise UnknownEntry raise UnknownEntry
class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): class OptionsFlowManager(
data_entry_flow.FlowManager[ConfigFlowContext, ConfigFlowResult]
):
"""Flow to set options for a configuration entry.""" """Flow to set options for a configuration entry."""
_flow_result = ConfigFlowResult _flow_result = ConfigFlowResult
@ -2822,7 +2848,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
self, self,
handler_key: str, handler_key: str,
*, *,
context: dict[str, Any] | None = None, context: ConfigFlowContext | None = None,
data: dict[str, Any] | None = None, data: dict[str, Any] | None = None,
) -> OptionsFlow: ) -> OptionsFlow:
"""Create an options flow for a config entry. """Create an options flow for a config entry.
@ -2835,7 +2861,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
async def async_finish_flow( async def async_finish_flow(
self, self,
flow: data_entry_flow.FlowHandler[ConfigFlowResult], flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
result: ConfigFlowResult, result: ConfigFlowResult,
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Finish an options flow and update options for configuration entry. """Finish an options flow and update options for configuration entry.
@ -2860,7 +2886,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
return result return result
async def _async_setup_preview( async def _async_setup_preview(
self, flow: data_entry_flow.FlowHandler[ConfigFlowResult] self, flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult]
) -> None: ) -> None:
"""Set up preview for an option flow handler.""" """Set up preview for an option flow handler."""
entry = self._async_get_config_entry(flow.handler) entry = self._async_get_config_entry(flow.handler)

View file

@ -87,7 +87,10 @@ STEP_ID_OPTIONAL_STEPS = {
} }
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult[Any]", default="FlowResult") _FlowContextT = TypeVar("_FlowContextT", bound="FlowContext", default="FlowContext")
_FlowResultT = TypeVar(
"_FlowResultT", bound="FlowResult[Any, Any]", default="FlowResult"
)
_HandlerT = TypeVar("_HandlerT", default=str) _HandlerT = TypeVar("_HandlerT", default=str)
@ -139,10 +142,17 @@ class AbortFlow(FlowError):
self.description_placeholders = description_placeholders self.description_placeholders = description_placeholders
class FlowResult(TypedDict, Generic[_HandlerT], total=False): class FlowContext(TypedDict, total=False):
"""Typed context dict."""
show_advanced_options: bool
source: str
class FlowResult(TypedDict, Generic[_FlowContextT, _HandlerT], total=False):
"""Typed result dict.""" """Typed result dict."""
context: dict[str, Any] context: _FlowContextT
data_schema: vol.Schema | None data_schema: vol.Schema | None
data: Mapping[str, Any] data: Mapping[str, Any]
description_placeholders: Mapping[str, str | None] | None description_placeholders: Mapping[str, str | None] | None
@ -189,7 +199,7 @@ def _map_error_to_schema_errors(
schema_errors[path_part_str] = error.error_message schema_errors[path_part_str] = error.error_message
class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]): class FlowManager(abc.ABC, Generic[_FlowContextT, _FlowResultT, _HandlerT]):
"""Manage all the flows that are in progress.""" """Manage all the flows that are in progress."""
_flow_result: type[_FlowResultT] = FlowResult # type: ignore[assignment] _flow_result: type[_FlowResultT] = FlowResult # type: ignore[assignment]
@ -201,12 +211,14 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
"""Initialize the flow manager.""" """Initialize the flow manager."""
self.hass = hass self.hass = hass
self._preview: set[_HandlerT] = set() self._preview: set[_HandlerT] = set()
self._progress: dict[str, FlowHandler[_FlowResultT, _HandlerT]] = {} self._progress: dict[
str, FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
] = {}
self._handler_progress_index: defaultdict[ self._handler_progress_index: defaultdict[
_HandlerT, set[FlowHandler[_FlowResultT, _HandlerT]] _HandlerT, set[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]]
] = defaultdict(set) ] = defaultdict(set)
self._init_data_process_index: defaultdict[ self._init_data_process_index: defaultdict[
type, set[FlowHandler[_FlowResultT, _HandlerT]] type, set[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]]
] = defaultdict(set) ] = defaultdict(set)
@abc.abstractmethod @abc.abstractmethod
@ -214,9 +226,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
self, self,
handler_key: _HandlerT, handler_key: _HandlerT,
*, *,
context: dict[str, Any] | None = None, context: _FlowContextT | None = None,
data: dict[str, Any] | None = None, data: dict[str, Any] | None = None,
) -> FlowHandler[_FlowResultT, _HandlerT]: ) -> FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]:
"""Create a flow for specified handler. """Create a flow for specified handler.
Handler key is the domain of the component that we want to set up. Handler key is the domain of the component that we want to set up.
@ -224,7 +236,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
@abc.abstractmethod @abc.abstractmethod
async def async_finish_flow( async def async_finish_flow(
self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT self,
flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
result: _FlowResultT,
) -> _FlowResultT: ) -> _FlowResultT:
"""Finish a data entry flow. """Finish a data entry flow.
@ -233,7 +247,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
""" """
async def async_post_init( async def async_post_init(
self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT self,
flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
result: _FlowResultT,
) -> None: ) -> None:
"""Entry has finished executing its first step asynchronously.""" """Entry has finished executing its first step asynchronously."""
@ -288,7 +304,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
@callback @callback
def _async_progress_by_handler( def _async_progress_by_handler(
self, handler: _HandlerT, match_context: dict[str, Any] | None self, handler: _HandlerT, match_context: dict[str, Any] | None
) -> list[FlowHandler[_FlowResultT, _HandlerT]]: ) -> list[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]]:
"""Return the flows in progress by handler. """Return the flows in progress by handler.
If match_context is specified, only return flows with a context that If match_context is specified, only return flows with a context that
@ -307,12 +323,12 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
self, self,
handler: _HandlerT, handler: _HandlerT,
*, *,
context: dict[str, Any] | None = None, context: _FlowContextT | None = None,
data: Any = None, data: Any = None,
) -> _FlowResultT: ) -> _FlowResultT:
"""Start a data entry flow.""" """Start a data entry flow."""
if context is None: if context is None:
context = {} context = cast(_FlowContextT, {})
flow = await self.async_create_flow(handler, context=context, data=data) flow = await self.async_create_flow(handler, context=context, data=data)
if not flow: if not flow:
raise UnknownFlow("Flow was not created") raise UnknownFlow("Flow was not created")
@ -452,7 +468,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
@callback @callback
def _async_add_flow_progress( def _async_add_flow_progress(
self, flow: FlowHandler[_FlowResultT, _HandlerT] self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
) -> None: ) -> None:
"""Add a flow to in progress.""" """Add a flow to in progress."""
if flow.init_data is not None: if flow.init_data is not None:
@ -462,7 +478,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
@callback @callback
def _async_remove_flow_from_index( def _async_remove_flow_from_index(
self, flow: FlowHandler[_FlowResultT, _HandlerT] self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
) -> None: ) -> None:
"""Remove a flow from in progress.""" """Remove a flow from in progress."""
if flow.init_data is not None: if flow.init_data is not None:
@ -489,7 +505,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
async def _async_handle_step( async def _async_handle_step(
self, self,
flow: FlowHandler[_FlowResultT, _HandlerT], flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
step_id: str, step_id: str,
user_input: dict | BaseServiceInfo | None, user_input: dict | BaseServiceInfo | None,
) -> _FlowResultT: ) -> _FlowResultT:
@ -566,7 +582,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
return result return result
def _raise_if_step_does_not_exist( def _raise_if_step_does_not_exist(
self, flow: FlowHandler[_FlowResultT, _HandlerT], step_id: str self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT], step_id: str
) -> None: ) -> None:
"""Raise if the step does not exist.""" """Raise if the step does not exist."""
method = f"async_step_{step_id}" method = f"async_step_{step_id}"
@ -578,7 +594,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
) )
async def _async_setup_preview( async def _async_setup_preview(
self, flow: FlowHandler[_FlowResultT, _HandlerT] self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
) -> None: ) -> None:
"""Set up preview for a flow handler.""" """Set up preview for a flow handler."""
if flow.handler not in self._preview: if flow.handler not in self._preview:
@ -588,7 +604,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
@callback @callback
def _async_flow_handler_to_flow_result( def _async_flow_handler_to_flow_result(
self, self,
flows: Iterable[FlowHandler[_FlowResultT, _HandlerT]], flows: Iterable[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]],
include_uninitialized: bool, include_uninitialized: bool,
) -> list[_FlowResultT]: ) -> list[_FlowResultT]:
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized.""" """Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
@ -610,7 +626,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
] ]
class FlowHandler(Generic[_FlowResultT, _HandlerT]): class FlowHandler(Generic[_FlowContextT, _FlowResultT, _HandlerT]):
"""Handle a data entry flow.""" """Handle a data entry flow."""
_flow_result: type[_FlowResultT] = FlowResult # type: ignore[assignment] _flow_result: type[_FlowResultT] = FlowResult # type: ignore[assignment]
@ -624,7 +640,7 @@ class FlowHandler(Generic[_FlowResultT, _HandlerT]):
hass: HomeAssistant = None # type: ignore[assignment] hass: HomeAssistant = None # type: ignore[assignment]
handler: _HandlerT = None # type: ignore[assignment] handler: _HandlerT = None # type: ignore[assignment]
# Ensure the attribute has a subscriptable, but immutable, default value. # Ensure the attribute has a subscriptable, but immutable, default value.
context: dict[str, Any] = MappingProxyType({}) # type: ignore[assignment] context: _FlowContextT = MappingProxyType({}) # type: ignore[assignment]
# Set by _async_create_flow callback # Set by _async_create_flow callback
init_step = "init" init_step = "init"
@ -643,12 +659,12 @@ class FlowHandler(Generic[_FlowResultT, _HandlerT]):
@property @property
def source(self) -> str | None: def source(self) -> str | None:
"""Source that initialized the flow.""" """Source that initialized the flow."""
return self.context.get("source", None) # type: ignore[no-any-return] return self.context.get("source", None) # type: ignore[return-value]
@property @property
def show_advanced_options(self) -> bool: def show_advanced_options(self) -> bool:
"""If we should show advanced options.""" """If we should show advanced options."""
return self.context.get("show_advanced_options", False) # type: ignore[no-any-return] return self.context.get("show_advanced_options", False) # type: ignore[return-value]
def add_suggested_values_to_schema( def add_suggested_values_to_schema(
self, data_schema: vol.Schema, suggested_values: Mapping[str, Any] | None self, data_schema: vol.Schema, suggested_values: Mapping[str, Any] | None

View file

@ -18,7 +18,7 @@ from . import config_validation as cv
_FlowManagerT = TypeVar( _FlowManagerT = TypeVar(
"_FlowManagerT", "_FlowManagerT",
bound=data_entry_flow.FlowManager[Any], bound=data_entry_flow.FlowManager[Any, Any],
default=data_entry_flow.FlowManager, default=data_entry_flow.FlowManager,
) )

View file

@ -13,7 +13,7 @@ from homeassistant.util.async_ import gather_with_limited_concurrency
from homeassistant.util.hass_dict import HassKey from homeassistant.util.hass_dict import HassKey
if TYPE_CHECKING: if TYPE_CHECKING:
from homeassistant.config_entries import ConfigFlowResult from homeassistant.config_entries import ConfigFlowContext, ConfigFlowResult
FLOW_INIT_LIMIT = 20 FLOW_INIT_LIMIT = 20
DISCOVERY_FLOW_DISPATCHER: HassKey[FlowDispatcher] = HassKey( DISCOVERY_FLOW_DISPATCHER: HassKey[FlowDispatcher] = HassKey(
@ -42,7 +42,7 @@ class DiscoveryKey:
def async_create_flow( def async_create_flow(
hass: HomeAssistant, hass: HomeAssistant,
domain: str, domain: str,
context: dict[str, Any], context: ConfigFlowContext,
data: Any, data: Any,
*, *,
discovery_key: DiscoveryKey | None = None, discovery_key: DiscoveryKey | None = None,
@ -70,7 +70,7 @@ def async_create_flow(
@callback @callback
def _async_init_flow( def _async_init_flow(
hass: HomeAssistant, domain: str, context: dict[str, Any], data: Any hass: HomeAssistant, domain: str, context: ConfigFlowContext, data: Any
) -> Coroutine[None, None, ConfigFlowResult] | None: ) -> Coroutine[None, None, ConfigFlowResult] | None:
"""Create a discovery flow.""" """Create a discovery flow."""
# Avoid spawning flows that have the same initial discovery data # Avoid spawning flows that have the same initial discovery data
@ -98,7 +98,7 @@ class PendingFlowKey(NamedTuple):
class PendingFlowValue(NamedTuple): class PendingFlowValue(NamedTuple):
"""Value for pending flows.""" """Value for pending flows."""
context: dict[str, Any] context: ConfigFlowContext
data: Any data: Any
@ -137,7 +137,7 @@ class FlowDispatcher:
await gather_with_limited_concurrency(FLOW_INIT_LIMIT, *init_coros) await gather_with_limited_concurrency(FLOW_INIT_LIMIT, *init_coros)
@callback @callback
def async_create(self, domain: str, context: dict[str, Any], data: Any) -> None: def async_create(self, domain: str, context: ConfigFlowContext, data: Any) -> None:
"""Create and add or queue a flow.""" """Create and add or queue a flow."""
key = PendingFlowKey(domain, context["source"]) key = PendingFlowKey(domain, context["source"])
values = PendingFlowValue(context, data) values = PendingFlowValue(context, data)