Make FlowHandler.context a typed dict (#126291)
* Make FlowHandler.context a typed dict * Adjust typing * Adjust typing * Avoid calling ConfigFlowContext constructor in hot path
This commit is contained in:
parent
217165208b
commit
d6ee10a543
19 changed files with 175 additions and 99 deletions
|
@ -12,7 +12,6 @@ from typing import Any, cast
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
|
|
||||||
from homeassistant import data_entry_flow
|
|
||||||
from homeassistant.core import (
|
from homeassistant.core import (
|
||||||
CALLBACK_TYPE,
|
CALLBACK_TYPE,
|
||||||
HassJob,
|
HassJob,
|
||||||
|
@ -20,13 +19,14 @@ from homeassistant.core import (
|
||||||
HomeAssistant,
|
HomeAssistant,
|
||||||
callback,
|
callback,
|
||||||
)
|
)
|
||||||
|
from homeassistant.data_entry_flow import FlowHandler, FlowManager, FlowResultType
|
||||||
from homeassistant.helpers.event import async_track_point_in_utc_time
|
from homeassistant.helpers.event import async_track_point_in_utc_time
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
from . import auth_store, jwt_wrapper, models
|
from . import auth_store, jwt_wrapper, models
|
||||||
from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRATION
|
from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRATION
|
||||||
from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config
|
from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config
|
||||||
from .models import AuthFlowResult
|
from .models import AuthFlowContext, AuthFlowResult
|
||||||
from .providers import AuthProvider, LoginFlow, auth_provider_from_config
|
from .providers import AuthProvider, LoginFlow, auth_provider_from_config
|
||||||
from .providers.homeassistant import HassAuthProvider
|
from .providers.homeassistant import HassAuthProvider
|
||||||
|
|
||||||
|
@ -98,7 +98,7 @@ async def auth_manager_from_config(
|
||||||
|
|
||||||
|
|
||||||
class AuthManagerFlowManager(
|
class AuthManagerFlowManager(
|
||||||
data_entry_flow.FlowManager[AuthFlowResult, tuple[str, str]]
|
FlowManager[AuthFlowContext, AuthFlowResult, tuple[str, str]]
|
||||||
):
|
):
|
||||||
"""Manage authentication flows."""
|
"""Manage authentication flows."""
|
||||||
|
|
||||||
|
@ -113,7 +113,7 @@ class AuthManagerFlowManager(
|
||||||
self,
|
self,
|
||||||
handler_key: tuple[str, str],
|
handler_key: tuple[str, str],
|
||||||
*,
|
*,
|
||||||
context: dict[str, Any] | None = None,
|
context: AuthFlowContext | None = None,
|
||||||
data: dict[str, Any] | None = None,
|
data: dict[str, Any] | None = None,
|
||||||
) -> LoginFlow:
|
) -> LoginFlow:
|
||||||
"""Create a login flow."""
|
"""Create a login flow."""
|
||||||
|
@ -124,7 +124,7 @@ class AuthManagerFlowManager(
|
||||||
|
|
||||||
async def async_finish_flow(
|
async def async_finish_flow(
|
||||||
self,
|
self,
|
||||||
flow: data_entry_flow.FlowHandler[AuthFlowResult, tuple[str, str]],
|
flow: FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]],
|
||||||
result: AuthFlowResult,
|
result: AuthFlowResult,
|
||||||
) -> AuthFlowResult:
|
) -> AuthFlowResult:
|
||||||
"""Return a user as result of login flow.
|
"""Return a user as result of login flow.
|
||||||
|
@ -134,7 +134,7 @@ class AuthManagerFlowManager(
|
||||||
"""
|
"""
|
||||||
flow = cast(LoginFlow, flow)
|
flow = cast(LoginFlow, flow)
|
||||||
|
|
||||||
if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY:
|
if result["type"] != FlowResultType.CREATE_ENTRY:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# we got final result
|
# we got final result
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from ipaddress import IPv4Address, IPv6Address
|
||||||
import secrets
|
import secrets
|
||||||
from typing import Any, NamedTuple
|
from typing import Any, NamedTuple
|
||||||
import uuid
|
import uuid
|
||||||
|
@ -13,7 +14,7 @@ from attr.setters import validate
|
||||||
from propcache import cached_property
|
from propcache import cached_property
|
||||||
|
|
||||||
from homeassistant.const import __version__
|
from homeassistant.const import __version__
|
||||||
from homeassistant.data_entry_flow import FlowResult
|
from homeassistant.data_entry_flow import FlowContext, FlowResult
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
from . import permissions as perm_mdl
|
from . import permissions as perm_mdl
|
||||||
|
@ -23,7 +24,16 @@ TOKEN_TYPE_NORMAL = "normal"
|
||||||
TOKEN_TYPE_SYSTEM = "system"
|
TOKEN_TYPE_SYSTEM = "system"
|
||||||
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN = "long_lived_access_token"
|
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN = "long_lived_access_token"
|
||||||
|
|
||||||
AuthFlowResult = FlowResult[tuple[str, str]]
|
|
||||||
|
class AuthFlowContext(FlowContext, total=False):
|
||||||
|
"""Typed context dict for auth flow."""
|
||||||
|
|
||||||
|
credential_only: bool
|
||||||
|
ip_address: IPv4Address | IPv6Address
|
||||||
|
redirect_uri: str
|
||||||
|
|
||||||
|
|
||||||
|
AuthFlowResult = FlowResult[AuthFlowContext, tuple[str, str]]
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True)
|
@attr.s(slots=True)
|
||||||
|
|
|
@ -10,9 +10,10 @@ from typing import Any
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
from voluptuous.humanize import humanize_error
|
from voluptuous.humanize import humanize_error
|
||||||
|
|
||||||
from homeassistant import data_entry_flow, requirements
|
from homeassistant import requirements
|
||||||
from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE
|
from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
from homeassistant.data_entry_flow import FlowHandler
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers.importlib import async_import_module
|
from homeassistant.helpers.importlib import async_import_module
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
@ -21,7 +22,14 @@ from homeassistant.util.hass_dict import HassKey
|
||||||
|
|
||||||
from ..auth_store import AuthStore
|
from ..auth_store import AuthStore
|
||||||
from ..const import MFA_SESSION_EXPIRATION
|
from ..const import MFA_SESSION_EXPIRATION
|
||||||
from ..models import AuthFlowResult, Credentials, RefreshToken, User, UserMeta
|
from ..models import (
|
||||||
|
AuthFlowContext,
|
||||||
|
AuthFlowResult,
|
||||||
|
Credentials,
|
||||||
|
RefreshToken,
|
||||||
|
User,
|
||||||
|
UserMeta,
|
||||||
|
)
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
DATA_REQS: HassKey[set[str]] = HassKey("auth_prov_reqs_processed")
|
DATA_REQS: HassKey[set[str]] = HassKey("auth_prov_reqs_processed")
|
||||||
|
@ -97,7 +105,7 @@ class AuthProvider:
|
||||||
|
|
||||||
# Implement by extending class
|
# Implement by extending class
|
||||||
|
|
||||||
async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow:
|
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
|
||||||
"""Return the data flow for logging in with auth provider.
|
"""Return the data flow for logging in with auth provider.
|
||||||
|
|
||||||
Auth provider should extend LoginFlow and return an instance.
|
Auth provider should extend LoginFlow and return an instance.
|
||||||
|
@ -184,7 +192,7 @@ async def load_auth_provider_module(
|
||||||
return module
|
return module
|
||||||
|
|
||||||
|
|
||||||
class LoginFlow(data_entry_flow.FlowHandler[AuthFlowResult, tuple[str, str]]):
|
class LoginFlow(FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]]):
|
||||||
"""Handler for the login flow."""
|
"""Handler for the login flow."""
|
||||||
|
|
||||||
_flow_result = AuthFlowResult
|
_flow_result = AuthFlowResult
|
||||||
|
|
|
@ -13,7 +13,7 @@ import voluptuous as vol
|
||||||
from homeassistant.const import CONF_COMMAND
|
from homeassistant.const import CONF_COMMAND
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
from ..models import AuthFlowResult, Credentials, UserMeta
|
from ..models import AuthFlowContext, AuthFlowResult, Credentials, UserMeta
|
||||||
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
|
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
|
||||||
|
|
||||||
CONF_ARGS = "args"
|
CONF_ARGS = "args"
|
||||||
|
@ -59,7 +59,7 @@ class CommandLineAuthProvider(AuthProvider):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._user_meta: dict[str, dict[str, Any]] = {}
|
self._user_meta: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow:
|
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
|
||||||
"""Return a flow to login."""
|
"""Return a flow to login."""
|
||||||
return CommandLineLoginFlow(self)
|
return CommandLineLoginFlow(self)
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import issue_registry as ir
|
from homeassistant.helpers import issue_registry as ir
|
||||||
from homeassistant.helpers.storage import Store
|
from homeassistant.helpers.storage import Store
|
||||||
|
|
||||||
from ..models import AuthFlowResult, Credentials, UserMeta
|
from ..models import AuthFlowContext, AuthFlowResult, Credentials, UserMeta
|
||||||
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
|
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
|
||||||
|
|
||||||
STORAGE_VERSION = 1
|
STORAGE_VERSION = 1
|
||||||
|
@ -305,7 +305,7 @@ class HassAuthProvider(AuthProvider):
|
||||||
await data.async_load()
|
await data.async_load()
|
||||||
self.data = data
|
self.data = data
|
||||||
|
|
||||||
async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow:
|
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
|
||||||
"""Return a flow to login."""
|
"""Return a flow to login."""
|
||||||
return HassLoginFlow(self)
|
return HassLoginFlow(self)
|
||||||
|
|
||||||
|
|
|
@ -4,14 +4,14 @@ from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
import hmac
|
import hmac
|
||||||
from typing import Any, cast
|
from typing import cast
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
from ..models import AuthFlowResult, Credentials, UserMeta
|
from ..models import AuthFlowContext, AuthFlowResult, Credentials, UserMeta
|
||||||
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
|
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
|
||||||
|
|
||||||
USER_SCHEMA = vol.Schema(
|
USER_SCHEMA = vol.Schema(
|
||||||
|
@ -36,7 +36,7 @@ class InvalidAuthError(HomeAssistantError):
|
||||||
class ExampleAuthProvider(AuthProvider):
|
class ExampleAuthProvider(AuthProvider):
|
||||||
"""Example auth provider based on hardcoded usernames and passwords."""
|
"""Example auth provider based on hardcoded usernames and passwords."""
|
||||||
|
|
||||||
async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow:
|
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
|
||||||
"""Return a flow to login."""
|
"""Return a flow to login."""
|
||||||
return ExampleLoginFlow(self)
|
return ExampleLoginFlow(self)
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,13 @@ import homeassistant.helpers.config_validation as cv
|
||||||
from homeassistant.helpers.network import is_cloud_connection
|
from homeassistant.helpers.network import is_cloud_connection
|
||||||
|
|
||||||
from .. import InvalidAuthError
|
from .. import InvalidAuthError
|
||||||
from ..models import AuthFlowResult, Credentials, RefreshToken, UserMeta
|
from ..models import (
|
||||||
|
AuthFlowContext,
|
||||||
|
AuthFlowResult,
|
||||||
|
Credentials,
|
||||||
|
RefreshToken,
|
||||||
|
UserMeta,
|
||||||
|
)
|
||||||
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
|
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
|
||||||
|
|
||||||
type IPAddress = IPv4Address | IPv6Address
|
type IPAddress = IPv4Address | IPv6Address
|
||||||
|
@ -98,7 +104,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
||||||
"""Trusted Networks auth provider does not support MFA."""
|
"""Trusted Networks auth provider does not support MFA."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow:
|
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
|
||||||
"""Return a flow to login."""
|
"""Return a flow to login."""
|
||||||
assert context is not None
|
assert context is not None
|
||||||
ip_addr = cast(IPAddress, context.get("ip_address"))
|
ip_addr = cast(IPAddress, context.get("ip_address"))
|
||||||
|
|
|
@ -80,7 +80,7 @@ import voluptuous_serialize
|
||||||
|
|
||||||
from homeassistant import data_entry_flow
|
from homeassistant import data_entry_flow
|
||||||
from homeassistant.auth import AuthManagerFlowManager, InvalidAuthError
|
from homeassistant.auth import AuthManagerFlowManager, InvalidAuthError
|
||||||
from homeassistant.auth.models import AuthFlowResult, Credentials
|
from homeassistant.auth.models import AuthFlowContext, AuthFlowResult, Credentials
|
||||||
from homeassistant.components import onboarding
|
from homeassistant.components import onboarding
|
||||||
from homeassistant.components.http import KEY_HASS
|
from homeassistant.components.http import KEY_HASS
|
||||||
from homeassistant.components.http.auth import async_user_not_allowed_do_auth
|
from homeassistant.components.http.auth import async_user_not_allowed_do_auth
|
||||||
|
@ -322,11 +322,11 @@ class LoginFlowIndexView(LoginFlowBaseView):
|
||||||
try:
|
try:
|
||||||
result = await self._flow_mgr.async_init(
|
result = await self._flow_mgr.async_init(
|
||||||
handler,
|
handler,
|
||||||
context={
|
context=AuthFlowContext(
|
||||||
"ip_address": ip_address(request.remote), # type: ignore[arg-type]
|
ip_address=ip_address(request.remote), # type: ignore[arg-type]
|
||||||
"credential_only": data.get("type") == "link_user",
|
credential_only=data.get("type") == "link_user",
|
||||||
"redirect_uri": redirect_uri,
|
redirect_uri=redirect_uri,
|
||||||
},
|
),
|
||||||
)
|
)
|
||||||
except data_entry_flow.UnknownHandler:
|
except data_entry_flow.UnknownHandler:
|
||||||
return self.json_message("Invalid handler specified", HTTPStatus.NOT_FOUND)
|
return self.json_message("Invalid handler specified", HTTPStatus.NOT_FOUND)
|
||||||
|
|
|
@ -11,6 +11,7 @@ import voluptuous_serialize
|
||||||
from homeassistant import data_entry_flow
|
from homeassistant import data_entry_flow
|
||||||
from homeassistant.components import websocket_api
|
from homeassistant.components import websocket_api
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
from homeassistant.data_entry_flow import FlowContext
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
from homeassistant.util.hass_dict import HassKey
|
from homeassistant.util.hass_dict import HassKey
|
||||||
|
|
||||||
|
@ -44,7 +45,7 @@ class MfaFlowManager(data_entry_flow.FlowManager):
|
||||||
self,
|
self,
|
||||||
handler_key: str,
|
handler_key: str,
|
||||||
*,
|
*,
|
||||||
context: dict[str, Any],
|
context: FlowContext | None,
|
||||||
data: dict[str, Any],
|
data: dict[str, Any],
|
||||||
) -> data_entry_flow.FlowHandler:
|
) -> data_entry_flow.FlowHandler:
|
||||||
"""Create a setup flow. handler is a mfa module."""
|
"""Create a setup flow. handler is a mfa module."""
|
||||||
|
|
|
@ -463,7 +463,7 @@ async def ignore_config_flow(
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
context = {"source": config_entries.SOURCE_IGNORE}
|
context = config_entries.ConfigFlowContext(source=config_entries.SOURCE_IGNORE)
|
||||||
if "discovery_key" in flow["context"]:
|
if "discovery_key" in flow["context"]:
|
||||||
context["discovery_key"] = flow["context"]["discovery_key"]
|
context["discovery_key"] = flow["context"]["discovery_key"]
|
||||||
await hass.config_entries.flow.async_init(
|
await hass.config_entries.flow.async_init(
|
||||||
|
|
|
@ -12,7 +12,13 @@ from homeassistant.components.homeassistant_hardware import (
|
||||||
firmware_config_flow,
|
firmware_config_flow,
|
||||||
silabs_multiprotocol_addon,
|
silabs_multiprotocol_addon,
|
||||||
)
|
)
|
||||||
from homeassistant.config_entries import ConfigEntry, ConfigFlowResult, OptionsFlow
|
from homeassistant.config_entries import (
|
||||||
|
ConfigEntry,
|
||||||
|
ConfigEntryBaseFlow,
|
||||||
|
ConfigFlowContext,
|
||||||
|
ConfigFlowResult,
|
||||||
|
OptionsFlow,
|
||||||
|
)
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
|
||||||
from .const import DOCS_WEB_FLASHER_URL, DOMAIN, HardwareVariant
|
from .const import DOCS_WEB_FLASHER_URL, DOMAIN, HardwareVariant
|
||||||
|
@ -33,10 +39,10 @@ else:
|
||||||
TranslationPlaceholderProtocol = object
|
TranslationPlaceholderProtocol = object
|
||||||
|
|
||||||
|
|
||||||
class SkyConnectTranslationMixin(TranslationPlaceholderProtocol):
|
class SkyConnectTranslationMixin(ConfigEntryBaseFlow, TranslationPlaceholderProtocol):
|
||||||
"""Translation placeholder mixin for Home Assistant SkyConnect."""
|
"""Translation placeholder mixin for Home Assistant SkyConnect."""
|
||||||
|
|
||||||
context: dict[str, Any]
|
context: ConfigFlowContext
|
||||||
|
|
||||||
def _get_translation_placeholders(self) -> dict[str, str]:
|
def _get_translation_placeholders(self) -> dict[str, str]:
|
||||||
"""Shared translation placeholders."""
|
"""Shared translation placeholders."""
|
||||||
|
|
|
@ -53,7 +53,7 @@ class RepairsFlowManager(data_entry_flow.FlowManager):
|
||||||
self,
|
self,
|
||||||
handler_key: str,
|
handler_key: str,
|
||||||
*,
|
*,
|
||||||
context: dict[str, Any] | None = None,
|
context: data_entry_flow.FlowContext | None = None,
|
||||||
data: dict[str, Any] | None = None,
|
data: dict[str, Any] | None = None,
|
||||||
) -> RepairsFlow:
|
) -> RepairsFlow:
|
||||||
"""Create a flow. platform is a repairs module."""
|
"""Create a flow. platform is a repairs module."""
|
||||||
|
|
|
@ -378,7 +378,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||||
for flow in _config_entries.flow.async_progress_by_handler(
|
for flow in _config_entries.flow.async_progress_by_handler(
|
||||||
DOMAIN, include_uninitialized=True
|
DOMAIN, include_uninitialized=True
|
||||||
):
|
):
|
||||||
context: dict[str, Any] = flow["context"]
|
context = flow["context"]
|
||||||
if context.get("source") != SOURCE_REAUTH:
|
if context.get("source") != SOURCE_REAUTH:
|
||||||
continue
|
continue
|
||||||
entry_id: str = context["entry_id"]
|
entry_id: str = context["entry_id"]
|
||||||
|
|
|
@ -540,7 +540,9 @@ class ZeroconfDiscovery:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
matcher_domain = matcher[ATTR_DOMAIN]
|
matcher_domain = matcher[ATTR_DOMAIN]
|
||||||
context = {
|
# Create a type annotated regular dict since this is a hot path and creating
|
||||||
|
# a regular dict is slightly cheaper than calling ConfigFlowContext
|
||||||
|
context: config_entries.ConfigFlowContext = {
|
||||||
"source": config_entries.SOURCE_ZEROCONF,
|
"source": config_entries.SOURCE_ZEROCONF,
|
||||||
}
|
}
|
||||||
if domain:
|
if domain:
|
||||||
|
|
|
@ -29,6 +29,7 @@ from homeassistant.config_entries import (
|
||||||
ConfigEntryBaseFlow,
|
ConfigEntryBaseFlow,
|
||||||
ConfigEntryState,
|
ConfigEntryState,
|
||||||
ConfigFlow,
|
ConfigFlow,
|
||||||
|
ConfigFlowContext,
|
||||||
ConfigFlowResult,
|
ConfigFlowResult,
|
||||||
OptionsFlow,
|
OptionsFlow,
|
||||||
OptionsFlowManager,
|
OptionsFlowManager,
|
||||||
|
@ -192,7 +193,7 @@ class BaseZwaveJSFlow(ConfigEntryBaseFlow, ABC):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def flow_manager(self) -> FlowManager[ConfigFlowResult]:
|
def flow_manager(self) -> FlowManager[ConfigFlowContext, ConfigFlowResult]:
|
||||||
"""Return the flow manager of the flow."""
|
"""Return the flow manager of the flow."""
|
||||||
|
|
||||||
async def async_step_install_addon(
|
async def async_step_install_addon(
|
||||||
|
|
|
@ -41,7 +41,7 @@ from .core import (
|
||||||
HomeAssistant,
|
HomeAssistant,
|
||||||
callback,
|
callback,
|
||||||
)
|
)
|
||||||
from .data_entry_flow import FLOW_NOT_COMPLETE_STEPS, FlowResult
|
from .data_entry_flow import FLOW_NOT_COMPLETE_STEPS, FlowContext, FlowResult
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
ConfigEntryAuthFailed,
|
ConfigEntryAuthFailed,
|
||||||
ConfigEntryError,
|
ConfigEntryError,
|
||||||
|
@ -267,7 +267,19 @@ UPDATE_ENTRY_CONFIG_ENTRY_ATTRS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ConfigFlowResult(FlowResult, total=False):
|
class ConfigFlowContext(FlowContext, total=False):
|
||||||
|
"""Typed context dict for config flow."""
|
||||||
|
|
||||||
|
alternative_domain: str
|
||||||
|
configuration_url: str
|
||||||
|
confirm_only: bool
|
||||||
|
discovery_key: DiscoveryKey
|
||||||
|
entry_id: str
|
||||||
|
title_placeholders: Mapping[str, str]
|
||||||
|
unique_id: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigFlowResult(FlowResult[ConfigFlowContext, str], total=False):
|
||||||
"""Typed result dict for config flow."""
|
"""Typed result dict for config flow."""
|
||||||
|
|
||||||
minor_version: int
|
minor_version: int
|
||||||
|
@ -1026,7 +1038,7 @@ class ConfigEntry(Generic[_DataT]):
|
||||||
def async_start_reauth(
|
def async_start_reauth(
|
||||||
self,
|
self,
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
context: dict[str, Any] | None = None,
|
context: ConfigFlowContext | None = None,
|
||||||
data: dict[str, Any] | None = None,
|
data: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start a reauth flow."""
|
"""Start a reauth flow."""
|
||||||
|
@ -1044,7 +1056,7 @@ class ConfigEntry(Generic[_DataT]):
|
||||||
async def _async_init_reauth(
|
async def _async_init_reauth(
|
||||||
self,
|
self,
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
context: dict[str, Any] | None = None,
|
context: ConfigFlowContext | None = None,
|
||||||
data: dict[str, Any] | None = None,
|
data: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start a reauth flow."""
|
"""Start a reauth flow."""
|
||||||
|
@ -1056,12 +1068,12 @@ class ConfigEntry(Generic[_DataT]):
|
||||||
return
|
return
|
||||||
result = await hass.config_entries.flow.async_init(
|
result = await hass.config_entries.flow.async_init(
|
||||||
self.domain,
|
self.domain,
|
||||||
context={
|
context=ConfigFlowContext(
|
||||||
"source": SOURCE_REAUTH,
|
source=SOURCE_REAUTH,
|
||||||
"entry_id": self.entry_id,
|
entry_id=self.entry_id,
|
||||||
"title_placeholders": {"name": self.title},
|
title_placeholders={"name": self.title},
|
||||||
"unique_id": self.unique_id,
|
unique_id=self.unique_id,
|
||||||
}
|
)
|
||||||
| (context or {}),
|
| (context or {}),
|
||||||
data=self.data | (data or {}),
|
data=self.data | (data or {}),
|
||||||
)
|
)
|
||||||
|
@ -1086,7 +1098,7 @@ class ConfigEntry(Generic[_DataT]):
|
||||||
def async_start_reconfigure(
|
def async_start_reconfigure(
|
||||||
self,
|
self,
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
context: dict[str, Any] | None = None,
|
context: ConfigFlowContext | None = None,
|
||||||
data: dict[str, Any] | None = None,
|
data: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start a reconfigure flow."""
|
"""Start a reconfigure flow."""
|
||||||
|
@ -1103,7 +1115,7 @@ class ConfigEntry(Generic[_DataT]):
|
||||||
async def _async_init_reconfigure(
|
async def _async_init_reconfigure(
|
||||||
self,
|
self,
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
context: dict[str, Any] | None = None,
|
context: ConfigFlowContext | None = None,
|
||||||
data: dict[str, Any] | None = None,
|
data: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start a reconfigure flow."""
|
"""Start a reconfigure flow."""
|
||||||
|
@ -1115,12 +1127,12 @@ class ConfigEntry(Generic[_DataT]):
|
||||||
return
|
return
|
||||||
await hass.config_entries.flow.async_init(
|
await hass.config_entries.flow.async_init(
|
||||||
self.domain,
|
self.domain,
|
||||||
context={
|
context=ConfigFlowContext(
|
||||||
"source": SOURCE_RECONFIGURE,
|
source=SOURCE_RECONFIGURE,
|
||||||
"entry_id": self.entry_id,
|
entry_id=self.entry_id,
|
||||||
"title_placeholders": {"name": self.title},
|
title_placeholders={"name": self.title},
|
||||||
"unique_id": self.unique_id,
|
unique_id=self.unique_id,
|
||||||
}
|
)
|
||||||
| (context or {}),
|
| (context or {}),
|
||||||
data=self.data | (data or {}),
|
data=self.data | (data or {}),
|
||||||
)
|
)
|
||||||
|
@ -1214,7 +1226,9 @@ def _report_non_awaited_platform_forwards(entry: ConfigEntry, what: str) -> None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
class ConfigEntriesFlowManager(
|
||||||
|
data_entry_flow.FlowManager[ConfigFlowContext, ConfigFlowResult]
|
||||||
|
):
|
||||||
"""Manage all the config entry flows that are in progress."""
|
"""Manage all the config entry flows that are in progress."""
|
||||||
|
|
||||||
_flow_result = ConfigFlowResult
|
_flow_result = ConfigFlowResult
|
||||||
|
@ -1260,7 +1274,11 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def async_init(
|
async def async_init(
|
||||||
self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None
|
self,
|
||||||
|
handler: str,
|
||||||
|
*,
|
||||||
|
context: ConfigFlowContext | None = None,
|
||||||
|
data: Any = None,
|
||||||
) -> ConfigFlowResult:
|
) -> ConfigFlowResult:
|
||||||
"""Start a configuration flow."""
|
"""Start a configuration flow."""
|
||||||
if not context or "source" not in context:
|
if not context or "source" not in context:
|
||||||
|
@ -1319,7 +1337,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
||||||
self,
|
self,
|
||||||
flow_id: str,
|
flow_id: str,
|
||||||
handler: str,
|
handler: str,
|
||||||
context: dict,
|
context: ConfigFlowContext,
|
||||||
data: Any,
|
data: Any,
|
||||||
) -> tuple[ConfigFlow, ConfigFlowResult]:
|
) -> tuple[ConfigFlow, ConfigFlowResult]:
|
||||||
"""Run the init in a task to allow it to be canceled at shutdown."""
|
"""Run the init in a task to allow it to be canceled at shutdown."""
|
||||||
|
@ -1357,7 +1375,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
||||||
|
|
||||||
async def async_finish_flow(
|
async def async_finish_flow(
|
||||||
self,
|
self,
|
||||||
flow: data_entry_flow.FlowHandler[ConfigFlowResult],
|
flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
|
||||||
result: ConfigFlowResult,
|
result: ConfigFlowResult,
|
||||||
) -> ConfigFlowResult:
|
) -> ConfigFlowResult:
|
||||||
"""Finish a config flow and add an entry.
|
"""Finish a config flow and add an entry.
|
||||||
|
@ -1504,7 +1522,11 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def async_create_flow(
|
async def async_create_flow(
|
||||||
self, handler_key: str, *, context: dict | None = None, data: Any = None
|
self,
|
||||||
|
handler_key: str,
|
||||||
|
*,
|
||||||
|
context: ConfigFlowContext | None = None,
|
||||||
|
data: Any = None,
|
||||||
) -> ConfigFlow:
|
) -> ConfigFlow:
|
||||||
"""Create a flow for specified handler.
|
"""Create a flow for specified handler.
|
||||||
|
|
||||||
|
@ -1522,7 +1544,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
||||||
|
|
||||||
async def async_post_init(
|
async def async_post_init(
|
||||||
self,
|
self,
|
||||||
flow: data_entry_flow.FlowHandler[ConfigFlowResult],
|
flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
|
||||||
result: ConfigFlowResult,
|
result: ConfigFlowResult,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""After a flow is initialised trigger new flow notifications."""
|
"""After a flow is initialised trigger new flow notifications."""
|
||||||
|
@ -1560,7 +1582,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_has_matching_discovery_flow(
|
def async_has_matching_discovery_flow(
|
||||||
self, handler: str, match_context: dict[str, Any], data: Any
|
self, handler: str, match_context: ConfigFlowContext, data: Any
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check if an existing matching discovery flow is in progress.
|
"""Check if an existing matching discovery flow is in progress.
|
||||||
|
|
||||||
|
@ -2385,7 +2407,9 @@ def _async_abort_entries_match(
|
||||||
raise data_entry_flow.AbortFlow("already_configured")
|
raise data_entry_flow.AbortFlow("already_configured")
|
||||||
|
|
||||||
|
|
||||||
class ConfigEntryBaseFlow(data_entry_flow.FlowHandler[ConfigFlowResult]):
|
class ConfigEntryBaseFlow(
|
||||||
|
data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult]
|
||||||
|
):
|
||||||
"""Base class for config and option flows."""
|
"""Base class for config and option flows."""
|
||||||
|
|
||||||
_flow_result = ConfigFlowResult
|
_flow_result = ConfigFlowResult
|
||||||
|
@ -2406,7 +2430,7 @@ class ConfigFlow(ConfigEntryBaseFlow):
|
||||||
if not self.context:
|
if not self.context:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return cast(str | None, self.context.get("unique_id"))
|
return self.context.get("unique_id")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@callback
|
@callback
|
||||||
|
@ -2779,7 +2803,7 @@ class ConfigFlow(ConfigEntryBaseFlow):
|
||||||
"""Return reauth entry id."""
|
"""Return reauth entry id."""
|
||||||
if self.source != SOURCE_REAUTH:
|
if self.source != SOURCE_REAUTH:
|
||||||
raise ValueError(f"Source is {self.source}, expected {SOURCE_REAUTH}")
|
raise ValueError(f"Source is {self.source}, expected {SOURCE_REAUTH}")
|
||||||
return self.context["entry_id"] # type: ignore[no-any-return]
|
return self.context["entry_id"]
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _get_reauth_entry(self) -> ConfigEntry:
|
def _get_reauth_entry(self) -> ConfigEntry:
|
||||||
|
@ -2793,7 +2817,7 @@ class ConfigFlow(ConfigEntryBaseFlow):
|
||||||
"""Return reconfigure entry id."""
|
"""Return reconfigure entry id."""
|
||||||
if self.source != SOURCE_RECONFIGURE:
|
if self.source != SOURCE_RECONFIGURE:
|
||||||
raise ValueError(f"Source is {self.source}, expected {SOURCE_RECONFIGURE}")
|
raise ValueError(f"Source is {self.source}, expected {SOURCE_RECONFIGURE}")
|
||||||
return self.context["entry_id"] # type: ignore[no-any-return]
|
return self.context["entry_id"]
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _get_reconfigure_entry(self) -> ConfigEntry:
|
def _get_reconfigure_entry(self) -> ConfigEntry:
|
||||||
|
@ -2805,7 +2829,9 @@ class ConfigFlow(ConfigEntryBaseFlow):
|
||||||
raise UnknownEntry
|
raise UnknownEntry
|
||||||
|
|
||||||
|
|
||||||
class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
class OptionsFlowManager(
|
||||||
|
data_entry_flow.FlowManager[ConfigFlowContext, ConfigFlowResult]
|
||||||
|
):
|
||||||
"""Flow to set options for a configuration entry."""
|
"""Flow to set options for a configuration entry."""
|
||||||
|
|
||||||
_flow_result = ConfigFlowResult
|
_flow_result = ConfigFlowResult
|
||||||
|
@ -2822,7 +2848,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
||||||
self,
|
self,
|
||||||
handler_key: str,
|
handler_key: str,
|
||||||
*,
|
*,
|
||||||
context: dict[str, Any] | None = None,
|
context: ConfigFlowContext | None = None,
|
||||||
data: dict[str, Any] | None = None,
|
data: dict[str, Any] | None = None,
|
||||||
) -> OptionsFlow:
|
) -> OptionsFlow:
|
||||||
"""Create an options flow for a config entry.
|
"""Create an options flow for a config entry.
|
||||||
|
@ -2835,7 +2861,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
||||||
|
|
||||||
async def async_finish_flow(
|
async def async_finish_flow(
|
||||||
self,
|
self,
|
||||||
flow: data_entry_flow.FlowHandler[ConfigFlowResult],
|
flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
|
||||||
result: ConfigFlowResult,
|
result: ConfigFlowResult,
|
||||||
) -> ConfigFlowResult:
|
) -> ConfigFlowResult:
|
||||||
"""Finish an options flow and update options for configuration entry.
|
"""Finish an options flow and update options for configuration entry.
|
||||||
|
@ -2860,7 +2886,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def _async_setup_preview(
|
async def _async_setup_preview(
|
||||||
self, flow: data_entry_flow.FlowHandler[ConfigFlowResult]
|
self, flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up preview for an option flow handler."""
|
"""Set up preview for an option flow handler."""
|
||||||
entry = self._async_get_config_entry(flow.handler)
|
entry = self._async_get_config_entry(flow.handler)
|
||||||
|
|
|
@ -87,7 +87,10 @@ STEP_ID_OPTIONAL_STEPS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult[Any]", default="FlowResult")
|
_FlowContextT = TypeVar("_FlowContextT", bound="FlowContext", default="FlowContext")
|
||||||
|
_FlowResultT = TypeVar(
|
||||||
|
"_FlowResultT", bound="FlowResult[Any, Any]", default="FlowResult"
|
||||||
|
)
|
||||||
_HandlerT = TypeVar("_HandlerT", default=str)
|
_HandlerT = TypeVar("_HandlerT", default=str)
|
||||||
|
|
||||||
|
|
||||||
|
@ -139,10 +142,17 @@ class AbortFlow(FlowError):
|
||||||
self.description_placeholders = description_placeholders
|
self.description_placeholders = description_placeholders
|
||||||
|
|
||||||
|
|
||||||
class FlowResult(TypedDict, Generic[_HandlerT], total=False):
|
class FlowContext(TypedDict, total=False):
|
||||||
|
"""Typed context dict."""
|
||||||
|
|
||||||
|
show_advanced_options: bool
|
||||||
|
source: str
|
||||||
|
|
||||||
|
|
||||||
|
class FlowResult(TypedDict, Generic[_FlowContextT, _HandlerT], total=False):
|
||||||
"""Typed result dict."""
|
"""Typed result dict."""
|
||||||
|
|
||||||
context: dict[str, Any]
|
context: _FlowContextT
|
||||||
data_schema: vol.Schema | None
|
data_schema: vol.Schema | None
|
||||||
data: Mapping[str, Any]
|
data: Mapping[str, Any]
|
||||||
description_placeholders: Mapping[str, str | None] | None
|
description_placeholders: Mapping[str, str | None] | None
|
||||||
|
@ -189,7 +199,7 @@ def _map_error_to_schema_errors(
|
||||||
schema_errors[path_part_str] = error.error_message
|
schema_errors[path_part_str] = error.error_message
|
||||||
|
|
||||||
|
|
||||||
class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
class FlowManager(abc.ABC, Generic[_FlowContextT, _FlowResultT, _HandlerT]):
|
||||||
"""Manage all the flows that are in progress."""
|
"""Manage all the flows that are in progress."""
|
||||||
|
|
||||||
_flow_result: type[_FlowResultT] = FlowResult # type: ignore[assignment]
|
_flow_result: type[_FlowResultT] = FlowResult # type: ignore[assignment]
|
||||||
|
@ -201,12 +211,14 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||||
"""Initialize the flow manager."""
|
"""Initialize the flow manager."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self._preview: set[_HandlerT] = set()
|
self._preview: set[_HandlerT] = set()
|
||||||
self._progress: dict[str, FlowHandler[_FlowResultT, _HandlerT]] = {}
|
self._progress: dict[
|
||||||
|
str, FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
|
||||||
|
] = {}
|
||||||
self._handler_progress_index: defaultdict[
|
self._handler_progress_index: defaultdict[
|
||||||
_HandlerT, set[FlowHandler[_FlowResultT, _HandlerT]]
|
_HandlerT, set[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]]
|
||||||
] = defaultdict(set)
|
] = defaultdict(set)
|
||||||
self._init_data_process_index: defaultdict[
|
self._init_data_process_index: defaultdict[
|
||||||
type, set[FlowHandler[_FlowResultT, _HandlerT]]
|
type, set[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]]
|
||||||
] = defaultdict(set)
|
] = defaultdict(set)
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
|
@ -214,9 +226,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||||
self,
|
self,
|
||||||
handler_key: _HandlerT,
|
handler_key: _HandlerT,
|
||||||
*,
|
*,
|
||||||
context: dict[str, Any] | None = None,
|
context: _FlowContextT | None = None,
|
||||||
data: dict[str, Any] | None = None,
|
data: dict[str, Any] | None = None,
|
||||||
) -> FlowHandler[_FlowResultT, _HandlerT]:
|
) -> FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]:
|
||||||
"""Create a flow for specified handler.
|
"""Create a flow for specified handler.
|
||||||
|
|
||||||
Handler key is the domain of the component that we want to set up.
|
Handler key is the domain of the component that we want to set up.
|
||||||
|
@ -224,7 +236,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def async_finish_flow(
|
async def async_finish_flow(
|
||||||
self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT
|
self,
|
||||||
|
flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
|
||||||
|
result: _FlowResultT,
|
||||||
) -> _FlowResultT:
|
) -> _FlowResultT:
|
||||||
"""Finish a data entry flow.
|
"""Finish a data entry flow.
|
||||||
|
|
||||||
|
@ -233,7 +247,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def async_post_init(
|
async def async_post_init(
|
||||||
self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT
|
self,
|
||||||
|
flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
|
||||||
|
result: _FlowResultT,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Entry has finished executing its first step asynchronously."""
|
"""Entry has finished executing its first step asynchronously."""
|
||||||
|
|
||||||
|
@ -288,7 +304,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||||
@callback
|
@callback
|
||||||
def _async_progress_by_handler(
|
def _async_progress_by_handler(
|
||||||
self, handler: _HandlerT, match_context: dict[str, Any] | None
|
self, handler: _HandlerT, match_context: dict[str, Any] | None
|
||||||
) -> list[FlowHandler[_FlowResultT, _HandlerT]]:
|
) -> list[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]]:
|
||||||
"""Return the flows in progress by handler.
|
"""Return the flows in progress by handler.
|
||||||
|
|
||||||
If match_context is specified, only return flows with a context that
|
If match_context is specified, only return flows with a context that
|
||||||
|
@ -307,12 +323,12 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||||
self,
|
self,
|
||||||
handler: _HandlerT,
|
handler: _HandlerT,
|
||||||
*,
|
*,
|
||||||
context: dict[str, Any] | None = None,
|
context: _FlowContextT | None = None,
|
||||||
data: Any = None,
|
data: Any = None,
|
||||||
) -> _FlowResultT:
|
) -> _FlowResultT:
|
||||||
"""Start a data entry flow."""
|
"""Start a data entry flow."""
|
||||||
if context is None:
|
if context is None:
|
||||||
context = {}
|
context = cast(_FlowContextT, {})
|
||||||
flow = await self.async_create_flow(handler, context=context, data=data)
|
flow = await self.async_create_flow(handler, context=context, data=data)
|
||||||
if not flow:
|
if not flow:
|
||||||
raise UnknownFlow("Flow was not created")
|
raise UnknownFlow("Flow was not created")
|
||||||
|
@ -452,7 +468,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_add_flow_progress(
|
def _async_add_flow_progress(
|
||||||
self, flow: FlowHandler[_FlowResultT, _HandlerT]
|
self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add a flow to in progress."""
|
"""Add a flow to in progress."""
|
||||||
if flow.init_data is not None:
|
if flow.init_data is not None:
|
||||||
|
@ -462,7 +478,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_remove_flow_from_index(
|
def _async_remove_flow_from_index(
|
||||||
self, flow: FlowHandler[_FlowResultT, _HandlerT]
|
self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Remove a flow from in progress."""
|
"""Remove a flow from in progress."""
|
||||||
if flow.init_data is not None:
|
if flow.init_data is not None:
|
||||||
|
@ -489,7 +505,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||||
|
|
||||||
async def _async_handle_step(
|
async def _async_handle_step(
|
||||||
self,
|
self,
|
||||||
flow: FlowHandler[_FlowResultT, _HandlerT],
|
flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
|
||||||
step_id: str,
|
step_id: str,
|
||||||
user_input: dict | BaseServiceInfo | None,
|
user_input: dict | BaseServiceInfo | None,
|
||||||
) -> _FlowResultT:
|
) -> _FlowResultT:
|
||||||
|
@ -566,7 +582,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _raise_if_step_does_not_exist(
|
def _raise_if_step_does_not_exist(
|
||||||
self, flow: FlowHandler[_FlowResultT, _HandlerT], step_id: str
|
self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT], step_id: str
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Raise if the step does not exist."""
|
"""Raise if the step does not exist."""
|
||||||
method = f"async_step_{step_id}"
|
method = f"async_step_{step_id}"
|
||||||
|
@ -578,7 +594,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _async_setup_preview(
|
async def _async_setup_preview(
|
||||||
self, flow: FlowHandler[_FlowResultT, _HandlerT]
|
self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up preview for a flow handler."""
|
"""Set up preview for a flow handler."""
|
||||||
if flow.handler not in self._preview:
|
if flow.handler not in self._preview:
|
||||||
|
@ -588,7 +604,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||||
@callback
|
@callback
|
||||||
def _async_flow_handler_to_flow_result(
|
def _async_flow_handler_to_flow_result(
|
||||||
self,
|
self,
|
||||||
flows: Iterable[FlowHandler[_FlowResultT, _HandlerT]],
|
flows: Iterable[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]],
|
||||||
include_uninitialized: bool,
|
include_uninitialized: bool,
|
||||||
) -> list[_FlowResultT]:
|
) -> list[_FlowResultT]:
|
||||||
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
|
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
|
||||||
|
@ -610,7 +626,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class FlowHandler(Generic[_FlowResultT, _HandlerT]):
|
class FlowHandler(Generic[_FlowContextT, _FlowResultT, _HandlerT]):
|
||||||
"""Handle a data entry flow."""
|
"""Handle a data entry flow."""
|
||||||
|
|
||||||
_flow_result: type[_FlowResultT] = FlowResult # type: ignore[assignment]
|
_flow_result: type[_FlowResultT] = FlowResult # type: ignore[assignment]
|
||||||
|
@ -624,7 +640,7 @@ class FlowHandler(Generic[_FlowResultT, _HandlerT]):
|
||||||
hass: HomeAssistant = None # type: ignore[assignment]
|
hass: HomeAssistant = None # type: ignore[assignment]
|
||||||
handler: _HandlerT = None # type: ignore[assignment]
|
handler: _HandlerT = None # type: ignore[assignment]
|
||||||
# Ensure the attribute has a subscriptable, but immutable, default value.
|
# Ensure the attribute has a subscriptable, but immutable, default value.
|
||||||
context: dict[str, Any] = MappingProxyType({}) # type: ignore[assignment]
|
context: _FlowContextT = MappingProxyType({}) # type: ignore[assignment]
|
||||||
|
|
||||||
# Set by _async_create_flow callback
|
# Set by _async_create_flow callback
|
||||||
init_step = "init"
|
init_step = "init"
|
||||||
|
@ -643,12 +659,12 @@ class FlowHandler(Generic[_FlowResultT, _HandlerT]):
|
||||||
@property
|
@property
|
||||||
def source(self) -> str | None:
|
def source(self) -> str | None:
|
||||||
"""Source that initialized the flow."""
|
"""Source that initialized the flow."""
|
||||||
return self.context.get("source", None) # type: ignore[no-any-return]
|
return self.context.get("source", None) # type: ignore[return-value]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def show_advanced_options(self) -> bool:
|
def show_advanced_options(self) -> bool:
|
||||||
"""If we should show advanced options."""
|
"""If we should show advanced options."""
|
||||||
return self.context.get("show_advanced_options", False) # type: ignore[no-any-return]
|
return self.context.get("show_advanced_options", False) # type: ignore[return-value]
|
||||||
|
|
||||||
def add_suggested_values_to_schema(
|
def add_suggested_values_to_schema(
|
||||||
self, data_schema: vol.Schema, suggested_values: Mapping[str, Any] | None
|
self, data_schema: vol.Schema, suggested_values: Mapping[str, Any] | None
|
||||||
|
|
|
@ -18,7 +18,7 @@ from . import config_validation as cv
|
||||||
|
|
||||||
_FlowManagerT = TypeVar(
|
_FlowManagerT = TypeVar(
|
||||||
"_FlowManagerT",
|
"_FlowManagerT",
|
||||||
bound=data_entry_flow.FlowManager[Any],
|
bound=data_entry_flow.FlowManager[Any, Any],
|
||||||
default=data_entry_flow.FlowManager,
|
default=data_entry_flow.FlowManager,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ from homeassistant.util.async_ import gather_with_limited_concurrency
|
||||||
from homeassistant.util.hass_dict import HassKey
|
from homeassistant.util.hass_dict import HassKey
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from homeassistant.config_entries import ConfigFlowResult
|
from homeassistant.config_entries import ConfigFlowContext, ConfigFlowResult
|
||||||
|
|
||||||
FLOW_INIT_LIMIT = 20
|
FLOW_INIT_LIMIT = 20
|
||||||
DISCOVERY_FLOW_DISPATCHER: HassKey[FlowDispatcher] = HassKey(
|
DISCOVERY_FLOW_DISPATCHER: HassKey[FlowDispatcher] = HassKey(
|
||||||
|
@ -42,7 +42,7 @@ class DiscoveryKey:
|
||||||
def async_create_flow(
|
def async_create_flow(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
domain: str,
|
domain: str,
|
||||||
context: dict[str, Any],
|
context: ConfigFlowContext,
|
||||||
data: Any,
|
data: Any,
|
||||||
*,
|
*,
|
||||||
discovery_key: DiscoveryKey | None = None,
|
discovery_key: DiscoveryKey | None = None,
|
||||||
|
@ -70,7 +70,7 @@ def async_create_flow(
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_init_flow(
|
def _async_init_flow(
|
||||||
hass: HomeAssistant, domain: str, context: dict[str, Any], data: Any
|
hass: HomeAssistant, domain: str, context: ConfigFlowContext, data: Any
|
||||||
) -> Coroutine[None, None, ConfigFlowResult] | None:
|
) -> Coroutine[None, None, ConfigFlowResult] | None:
|
||||||
"""Create a discovery flow."""
|
"""Create a discovery flow."""
|
||||||
# Avoid spawning flows that have the same initial discovery data
|
# Avoid spawning flows that have the same initial discovery data
|
||||||
|
@ -98,7 +98,7 @@ class PendingFlowKey(NamedTuple):
|
||||||
class PendingFlowValue(NamedTuple):
|
class PendingFlowValue(NamedTuple):
|
||||||
"""Value for pending flows."""
|
"""Value for pending flows."""
|
||||||
|
|
||||||
context: dict[str, Any]
|
context: ConfigFlowContext
|
||||||
data: Any
|
data: Any
|
||||||
|
|
||||||
|
|
||||||
|
@ -137,7 +137,7 @@ class FlowDispatcher:
|
||||||
await gather_with_limited_concurrency(FLOW_INIT_LIMIT, *init_coros)
|
await gather_with_limited_concurrency(FLOW_INIT_LIMIT, *init_coros)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_create(self, domain: str, context: dict[str, Any], data: Any) -> None:
|
def async_create(self, domain: str, context: ConfigFlowContext, data: Any) -> None:
|
||||||
"""Create and add or queue a flow."""
|
"""Create and add or queue a flow."""
|
||||||
key = PendingFlowKey(domain, context["source"])
|
key = PendingFlowKey(domain, context["source"])
|
||||||
values = PendingFlowValue(context, data)
|
values = PendingFlowValue(context, data)
|
||||||
|
|
Loading…
Add table
Reference in a new issue