Add TypeVar default for FlowResult (#112345)

Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
This commit is contained in:
Erik Montnemery 2024-03-05 22:52:11 +01:00 committed by GitHub
parent 33fe6ad647
commit 3d3e9900c3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 77 additions and 81 deletions

View file

@ -91,8 +91,6 @@ async def auth_manager_from_config(
class AuthManagerFlowManager(data_entry_flow.FlowManager): class AuthManagerFlowManager(data_entry_flow.FlowManager):
"""Manage authentication flows.""" """Manage authentication flows."""
_flow_result = FlowResult
def __init__(self, hass: HomeAssistant, auth_manager: AuthManager) -> None: def __init__(self, hass: HomeAssistant, auth_manager: AuthManager) -> None:
"""Init auth manager flows.""" """Init auth manager flows."""
super().__init__(hass) super().__init__(hass)
@ -112,7 +110,7 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
return await auth_provider.async_login_flow(context) return await auth_provider.async_login_flow(context)
async def async_finish_flow( async def async_finish_flow(
self, flow: data_entry_flow.BaseFlowHandler, result: FlowResult self, flow: data_entry_flow.FlowHandler, result: FlowResult
) -> FlowResult: ) -> FlowResult:
"""Return a user as result of login flow.""" """Return a user as result of login flow."""
flow = cast(LoginFlow, flow) flow = cast(LoginFlow, flow)

View file

@ -96,8 +96,6 @@ class MultiFactorAuthModule:
class SetupFlow(data_entry_flow.FlowHandler): class SetupFlow(data_entry_flow.FlowHandler):
"""Handler for the setup flow.""" """Handler for the setup flow."""
_flow_result = FlowResult
def __init__( def __init__(
self, auth_module: MultiFactorAuthModule, setup_schema: vol.Schema, user_id: str self, auth_module: MultiFactorAuthModule, setup_schema: vol.Schema, user_id: str
) -> None: ) -> None:

View file

@ -184,8 +184,6 @@ async def load_auth_provider_module(
class LoginFlow(data_entry_flow.FlowHandler): class LoginFlow(data_entry_flow.FlowHandler):
"""Handler for the login flow.""" """Handler for the login flow."""
_flow_result = FlowResult
def __init__(self, auth_provider: AuthProvider) -> None: def __init__(self, auth_provider: AuthProvider) -> None:
"""Initialize the login flow.""" """Initialize the login flow."""
self._auth_provider = auth_provider self._auth_provider = auth_provider

View file

@ -38,8 +38,6 @@ _LOGGER = logging.getLogger(__name__)
class MfaFlowManager(data_entry_flow.FlowManager): class MfaFlowManager(data_entry_flow.FlowManager):
"""Manage multi factor authentication flows.""" """Manage multi factor authentication flows."""
_flow_result = data_entry_flow.FlowResult
async def async_create_flow( # type: ignore[override] async def async_create_flow( # type: ignore[override]
self, self,
handler_key: str, handler_key: str,
@ -56,7 +54,7 @@ class MfaFlowManager(data_entry_flow.FlowManager):
return await mfa_module.async_setup_flow(user_id) return await mfa_module.async_setup_flow(user_id)
async def async_finish_flow( async def async_finish_flow(
self, flow: data_entry_flow.BaseFlowHandler, result: data_entry_flow.FlowResult self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult
) -> data_entry_flow.FlowResult: ) -> data_entry_flow.FlowResult:
"""Complete an mfs setup flow.""" """Complete an mfs setup flow."""
_LOGGER.debug("flow_result: %s", result) _LOGGER.debug("flow_result: %s", result)

View file

@ -141,7 +141,9 @@ def _prepare_config_flow_result_json(
return data return data
class ConfigManagerFlowIndexView(FlowManagerIndexView): class ConfigManagerFlowIndexView(
FlowManagerIndexView[config_entries.ConfigEntriesFlowManager]
):
"""View to create config flows.""" """View to create config flows."""
url = "/api/config/config_entries/flow" url = "/api/config/config_entries/flow"
@ -196,7 +198,9 @@ class ConfigManagerFlowIndexView(FlowManagerIndexView):
return _prepare_config_flow_result_json(result, super()._prepare_result_json) return _prepare_config_flow_result_json(result, super()._prepare_result_json)
class ConfigManagerFlowResourceView(FlowManagerResourceView): class ConfigManagerFlowResourceView(
FlowManagerResourceView[config_entries.ConfigEntriesFlowManager]
):
"""View to interact with the flow manager.""" """View to interact with the flow manager."""
url = "/api/config/config_entries/flow/{flow_id}" url = "/api/config/config_entries/flow/{flow_id}"
@ -238,7 +242,9 @@ class ConfigManagerAvailableFlowView(HomeAssistantView):
return self.json(await async_get_config_flows(hass, **kwargs)) return self.json(await async_get_config_flows(hass, **kwargs))
class OptionManagerFlowIndexView(FlowManagerIndexView): class OptionManagerFlowIndexView(
FlowManagerIndexView[config_entries.OptionsFlowManager]
):
"""View to create option flows.""" """View to create option flows."""
url = "/api/config/config_entries/options/flow" url = "/api/config/config_entries/options/flow"
@ -255,7 +261,9 @@ class OptionManagerFlowIndexView(FlowManagerIndexView):
return await super().post(request) return await super().post(request)
class OptionManagerFlowResourceView(FlowManagerResourceView): class OptionManagerFlowResourceView(
FlowManagerResourceView[config_entries.OptionsFlowManager]
):
"""View to interact with the option flow manager.""" """View to interact with the option flow manager."""
url = "/api/config/config_entries/options/flow/{flow_id}" url = "/api/config/config_entries/options/flow/{flow_id}"

View file

@ -48,11 +48,9 @@ class ConfirmRepairFlow(RepairsFlow):
) )
class RepairsFlowManager(data_entry_flow.BaseFlowManager[data_entry_flow.FlowResult]): class RepairsFlowManager(data_entry_flow.FlowManager):
"""Manage repairs flows.""" """Manage repairs flows."""
_flow_result = data_entry_flow.FlowResult
async def async_create_flow( async def async_create_flow(
self, self,
handler_key: str, handler_key: str,
@ -84,7 +82,7 @@ class RepairsFlowManager(data_entry_flow.BaseFlowManager[data_entry_flow.FlowRes
return flow return flow
async def async_finish_flow( async def async_finish_flow(
self, flow: data_entry_flow.BaseFlowHandler, result: data_entry_flow.FlowResult self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult
) -> data_entry_flow.FlowResult: ) -> data_entry_flow.FlowResult:
"""Complete a fix flow.""" """Complete a fix flow."""
if result.get("type") != data_entry_flow.FlowResultType.ABORT: if result.get("type") != data_entry_flow.FlowResultType.ABORT:

View file

@ -10,8 +10,6 @@ from homeassistant.core import HomeAssistant
class RepairsFlow(data_entry_flow.FlowHandler): class RepairsFlow(data_entry_flow.FlowHandler):
"""Handle a flow for fixing an issue.""" """Handle a flow for fixing an issue."""
_flow_result = data_entry_flow.FlowResult
issue_id: str issue_id: str
data: dict[str, str | int | float | None] | None data: dict[str, str | int | float | None] | None

View file

@ -34,7 +34,7 @@ from homeassistant.config_entries import (
) )
from homeassistant.const import CONF_NAME, CONF_URL from homeassistant.const import CONF_NAME, CONF_URL
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import AbortFlow, BaseFlowManager from homeassistant.data_entry_flow import AbortFlow, FlowManager
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
@ -182,7 +182,7 @@ class BaseZwaveJSFlow(ConfigEntryBaseFlow, ABC):
@property @property
@abstractmethod @abstractmethod
def flow_manager(self) -> BaseFlowManager: def flow_manager(self) -> FlowManager[ConfigFlowResult]:
"""Return the flow manager of the flow.""" """Return the flow manager of the flow."""
async def async_step_install_addon( async def async_step_install_addon(

View file

@ -1045,7 +1045,7 @@ class FlowCancelledError(Exception):
"""Error to indicate that a flow has been cancelled.""" """Error to indicate that a flow has been cancelled."""
class ConfigEntriesFlowManager(data_entry_flow.BaseFlowManager[ConfigFlowResult]): class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
"""Manage all the config entry flows that are in progress.""" """Manage all the config entry flows that are in progress."""
_flow_result = ConfigFlowResult _flow_result = ConfigFlowResult
@ -1170,7 +1170,9 @@ class ConfigEntriesFlowManager(data_entry_flow.BaseFlowManager[ConfigFlowResult]
self._discovery_debouncer.async_shutdown() self._discovery_debouncer.async_shutdown()
async def async_finish_flow( async def async_finish_flow(
self, flow: data_entry_flow.BaseFlowHandler, result: ConfigFlowResult self,
flow: data_entry_flow.FlowHandler[ConfigFlowResult],
result: ConfigFlowResult,
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Finish a config flow and add an entry.""" """Finish a config flow and add an entry."""
flow = cast(ConfigFlow, flow) flow = cast(ConfigFlow, flow)
@ -1290,7 +1292,9 @@ class ConfigEntriesFlowManager(data_entry_flow.BaseFlowManager[ConfigFlowResult]
return flow return flow
async def async_post_init( async def async_post_init(
self, flow: data_entry_flow.BaseFlowHandler, result: ConfigFlowResult self,
flow: data_entry_flow.FlowHandler[ConfigFlowResult],
result: ConfigFlowResult,
) -> None: ) -> None:
"""After a flow is initialised trigger new flow notifications.""" """After a flow is initialised trigger new flow notifications."""
source = flow.context["source"] source = flow.context["source"]
@ -1936,7 +1940,7 @@ def _async_abort_entries_match(
raise data_entry_flow.AbortFlow("already_configured") raise data_entry_flow.AbortFlow("already_configured")
class ConfigEntryBaseFlow(data_entry_flow.BaseFlowHandler[ConfigFlowResult]): class ConfigEntryBaseFlow(data_entry_flow.FlowHandler[ConfigFlowResult]):
"""Base class for config and option flows.""" """Base class for config and option flows."""
_flow_result = ConfigFlowResult _flow_result = ConfigFlowResult
@ -2288,7 +2292,7 @@ class ConfigFlow(ConfigEntryBaseFlow):
return self.async_abort(reason=reason) return self.async_abort(reason=reason)
class OptionsFlowManager(data_entry_flow.BaseFlowManager[ConfigFlowResult]): class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
"""Flow to set options for a configuration entry.""" """Flow to set options for a configuration entry."""
_flow_result = ConfigFlowResult _flow_result = ConfigFlowResult
@ -2317,7 +2321,9 @@ class OptionsFlowManager(data_entry_flow.BaseFlowManager[ConfigFlowResult]):
return handler.async_get_options_flow(entry) return handler.async_get_options_flow(entry)
async def async_finish_flow( async def async_finish_flow(
self, flow: data_entry_flow.BaseFlowHandler, result: ConfigFlowResult self,
flow: data_entry_flow.FlowHandler[ConfigFlowResult],
result: ConfigFlowResult,
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Finish an options flow and update options for configuration entry. """Finish an options flow and update options for configuration entry.
@ -2337,7 +2343,9 @@ class OptionsFlowManager(data_entry_flow.BaseFlowManager[ConfigFlowResult]):
result["result"] = True result["result"] = True
return result return result
async def _async_setup_preview(self, flow: data_entry_flow.BaseFlowHandler) -> None: async def _async_setup_preview(
self, flow: data_entry_flow.FlowHandler[ConfigFlowResult]
) -> None:
"""Set up preview for an option flow handler.""" """Set up preview for an option flow handler."""
entry = self._async_get_config_entry(flow.handler) entry = self._async_get_config_entry(flow.handler)
await _load_integration(self.hass, entry.domain, {}) await _load_integration(self.hass, entry.domain, {})

View file

@ -11,8 +11,9 @@ from enum import StrEnum
from functools import partial from functools import partial
import logging import logging
from types import MappingProxyType from types import MappingProxyType
from typing import Any, Generic, Required, TypedDict, TypeVar from typing import Any, Generic, Required, TypedDict
from typing_extensions import TypeVar
import voluptuous as vol import voluptuous as vol
from .core import HomeAssistant, callback from .core import HomeAssistant, callback
@ -84,7 +85,7 @@ STEP_ID_OPTIONAL_STEPS = {
} }
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult") _FlowResultT = TypeVar("_FlowResultT", bound="FlowResult", default="FlowResult")
@dataclass(slots=True) @dataclass(slots=True)
@ -188,10 +189,10 @@ def _map_error_to_schema_errors(
schema_errors[path_part_str] = error.error_message schema_errors[path_part_str] = error.error_message
class BaseFlowManager(abc.ABC, Generic[_FlowResultT]): class FlowManager(abc.ABC, Generic[_FlowResultT]):
"""Manage all the flows that are in progress.""" """Manage all the flows that are in progress."""
_flow_result: Callable[..., _FlowResultT] _flow_result: Callable[..., _FlowResultT] = FlowResult # type: ignore[assignment]
def __init__( def __init__(
self, self,
@ -200,9 +201,9 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
"""Initialize the flow manager.""" """Initialize the flow manager."""
self.hass = hass self.hass = hass
self._preview: set[str] = set() self._preview: set[str] = set()
self._progress: dict[str, BaseFlowHandler] = {} self._progress: dict[str, FlowHandler[_FlowResultT]] = {}
self._handler_progress_index: dict[str, set[BaseFlowHandler]] = {} self._handler_progress_index: dict[str, set[FlowHandler[_FlowResultT]]] = {}
self._init_data_process_index: dict[type, set[BaseFlowHandler]] = {} self._init_data_process_index: dict[type, set[FlowHandler[_FlowResultT]]] = {}
@abc.abstractmethod @abc.abstractmethod
async def async_create_flow( async def async_create_flow(
@ -211,7 +212,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
*, *,
context: dict[str, Any] | None = None, context: dict[str, Any] | None = None,
data: dict[str, Any] | None = None, data: dict[str, Any] | None = None,
) -> BaseFlowHandler[_FlowResultT]: ) -> FlowHandler[_FlowResultT]:
"""Create a flow for specified handler. """Create a flow for specified handler.
Handler key is the domain of the component that we want to set up. Handler key is the domain of the component that we want to set up.
@ -219,12 +220,12 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
@abc.abstractmethod @abc.abstractmethod
async def async_finish_flow( async def async_finish_flow(
self, flow: BaseFlowHandler, result: _FlowResultT self, flow: FlowHandler[_FlowResultT], result: _FlowResultT
) -> _FlowResultT: ) -> _FlowResultT:
"""Finish a data entry flow.""" """Finish a data entry flow."""
async def async_post_init( async def async_post_init(
self, flow: BaseFlowHandler, result: _FlowResultT self, flow: FlowHandler[_FlowResultT], result: _FlowResultT
) -> None: ) -> None:
"""Entry has finished executing its first step asynchronously.""" """Entry has finished executing its first step asynchronously."""
@ -298,7 +299,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
@callback @callback
def _async_progress_by_handler( def _async_progress_by_handler(
self, handler: str, match_context: dict[str, Any] | None self, handler: str, match_context: dict[str, Any] | None
) -> list[BaseFlowHandler[_FlowResultT]]: ) -> list[FlowHandler[_FlowResultT]]:
"""Return the flows in progress by handler. """Return the flows in progress by handler.
If match_context is specified, only return flows with a context that If match_context is specified, only return flows with a context that
@ -362,7 +363,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
data_schema := cur_step.get("data_schema") data_schema := cur_step.get("data_schema")
) is not None and user_input is not None: ) is not None and user_input is not None:
try: try:
user_input = data_schema(user_input) user_input = data_schema(user_input) # type: ignore[operator]
except vol.Invalid as ex: except vol.Invalid as ex:
raised_errors = [ex] raised_errors = [ex]
if isinstance(ex, vol.MultipleInvalid): if isinstance(ex, vol.MultipleInvalid):
@ -444,7 +445,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
self._async_remove_flow_progress(flow_id) self._async_remove_flow_progress(flow_id)
@callback @callback
def _async_add_flow_progress(self, flow: BaseFlowHandler[_FlowResultT]) -> None: def _async_add_flow_progress(self, flow: FlowHandler[_FlowResultT]) -> None:
"""Add a flow to in progress.""" """Add a flow to in progress."""
if flow.init_data is not None: if flow.init_data is not None:
init_data_type = type(flow.init_data) init_data_type = type(flow.init_data)
@ -453,9 +454,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
self._handler_progress_index.setdefault(flow.handler, set()).add(flow) self._handler_progress_index.setdefault(flow.handler, set()).add(flow)
@callback @callback
def _async_remove_flow_from_index( def _async_remove_flow_from_index(self, flow: FlowHandler[_FlowResultT]) -> None:
self, flow: BaseFlowHandler[_FlowResultT]
) -> None:
"""Remove a flow from in progress.""" """Remove a flow from in progress."""
if flow.init_data is not None: if flow.init_data is not None:
init_data_type = type(flow.init_data) init_data_type = type(flow.init_data)
@ -481,7 +480,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
async def _async_handle_step( async def _async_handle_step(
self, self,
flow: BaseFlowHandler[_FlowResultT], flow: FlowHandler[_FlowResultT],
step_id: str, step_id: str,
user_input: dict | BaseServiceInfo | None, user_input: dict | BaseServiceInfo | None,
) -> _FlowResultT: ) -> _FlowResultT:
@ -558,7 +557,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
return result return result
def _raise_if_step_does_not_exist( def _raise_if_step_does_not_exist(
self, flow: BaseFlowHandler, step_id: str self, flow: FlowHandler[_FlowResultT], step_id: str
) -> None: ) -> None:
"""Raise if the step does not exist.""" """Raise if the step does not exist."""
method = f"async_step_{step_id}" method = f"async_step_{step_id}"
@ -569,7 +568,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
f"Handler {self.__class__.__name__} doesn't support step {step_id}" f"Handler {self.__class__.__name__} doesn't support step {step_id}"
) )
async def _async_setup_preview(self, flow: BaseFlowHandler) -> None: async def _async_setup_preview(self, flow: FlowHandler[_FlowResultT]) -> None:
"""Set up preview for a flow handler.""" """Set up preview for a flow handler."""
if flow.handler not in self._preview: if flow.handler not in self._preview:
self._preview.add(flow.handler) self._preview.add(flow.handler)
@ -577,7 +576,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
@callback @callback
def _async_flow_handler_to_flow_result( def _async_flow_handler_to_flow_result(
self, flows: Iterable[BaseFlowHandler], include_uninitialized: bool self, flows: Iterable[FlowHandler[_FlowResultT]], include_uninitialized: bool
) -> list[_FlowResultT]: ) -> list[_FlowResultT]:
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized.""" """Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
results = [] results = []
@ -595,16 +594,10 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
return results return results
class FlowManager(BaseFlowManager[FlowResult]): class FlowHandler(Generic[_FlowResultT]):
"""Manage all the flows that are in progress."""
_flow_result = FlowResult
class BaseFlowHandler(Generic[_FlowResultT]):
"""Handle a data entry flow.""" """Handle a data entry flow."""
_flow_result: Callable[..., _FlowResultT] _flow_result: Callable[..., _FlowResultT] = FlowResult # type: ignore[assignment]
# Set by flow manager # Set by flow manager
cur_step: _FlowResultT | None = None cur_step: _FlowResultT | None = None
@ -881,12 +874,6 @@ class BaseFlowHandler(Generic[_FlowResultT]):
self.__progress_task = progress_task self.__progress_task = progress_task
class FlowHandler(BaseFlowHandler[FlowResult]):
"""Handle a data entry flow."""
_flow_result = FlowResult
# These can be removed if no deprecated constant are in this module anymore # These can be removed if no deprecated constant are in this module anymore
__getattr__ = partial(check_if_deprecated_constant, module_globals=globals()) __getattr__ = partial(check_if_deprecated_constant, module_globals=globals())
__dir__ = partial( __dir__ = partial(

View file

@ -2,9 +2,10 @@
from __future__ import annotations from __future__ import annotations
from http import HTTPStatus from http import HTTPStatus
from typing import Any from typing import Any, Generic
from aiohttp import web from aiohttp import web
from typing_extensions import TypeVar
import voluptuous as vol import voluptuous as vol
import voluptuous_serialize import voluptuous_serialize
@ -14,11 +15,17 @@ from homeassistant.components.http.data_validator import RequestDataValidator
from . import config_validation as cv from . import config_validation as cv
_FlowManagerT = TypeVar(
"_FlowManagerT",
bound=data_entry_flow.FlowManager[Any],
default=data_entry_flow.FlowManager,
)
class _BaseFlowManagerView(HomeAssistantView):
class _BaseFlowManagerView(HomeAssistantView, Generic[_FlowManagerT]):
"""Foundation for flow manager views.""" """Foundation for flow manager views."""
def __init__(self, flow_mgr: data_entry_flow.BaseFlowManager) -> None: def __init__(self, flow_mgr: _FlowManagerT) -> None:
"""Initialize the flow manager index view.""" """Initialize the flow manager index view."""
self._flow_mgr = flow_mgr self._flow_mgr = flow_mgr
@ -48,7 +55,7 @@ class _BaseFlowManagerView(HomeAssistantView):
return data return data
class FlowManagerIndexView(_BaseFlowManagerView): class FlowManagerIndexView(_BaseFlowManagerView[_FlowManagerT]):
"""View to create config flows.""" """View to create config flows."""
@RequestDataValidator( @RequestDataValidator(
@ -96,7 +103,7 @@ class FlowManagerIndexView(_BaseFlowManagerView):
return {"show_advanced_options": data["show_advanced_options"]} return {"show_advanced_options": data["show_advanced_options"]}
class FlowManagerResourceView(_BaseFlowManagerView): class FlowManagerResourceView(_BaseFlowManagerView[_FlowManagerT]):
"""View to interact with the flow manager.""" """View to interact with the flow manager."""
async def get(self, request: web.Request, /, flow_id: str) -> web.Response: async def get(self, request: web.Request, /, flow_id: str) -> web.Response:

View file

@ -63,7 +63,7 @@ async def test_async_create_flow_checks_existing_flows_after_startup(
"""Test existing flows prevent an identical ones from being after startup.""" """Test existing flows prevent an identical ones from being after startup."""
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
with patch( with patch(
"homeassistant.data_entry_flow.BaseFlowManager.async_has_matching_flow", "homeassistant.data_entry_flow.FlowManager.async_has_matching_flow",
return_value=True, return_value=True,
): ):
discovery_flow.async_create_flow( discovery_flow.async_create_flow(

View file

@ -45,7 +45,7 @@ def manager_fixture():
handlers = Registry() handlers = Registry()
entries = [] entries = []
class FlowManager(data_entry_flow.BaseFlowManager): class FlowManager(data_entry_flow.FlowManager):
"""Test flow manager.""" """Test flow manager."""
async def async_create_flow(self, handler_key, *, context, data): async def async_create_flow(self, handler_key, *, context, data):
@ -105,7 +105,7 @@ async def test_name(hass: HomeAssistant, entity_registry: er.EntityRegistry) ->
@pytest.mark.parametrize("marker", (vol.Required, vol.Optional)) @pytest.mark.parametrize("marker", (vol.Required, vol.Optional))
async def test_config_flow_advanced_option( async def test_config_flow_advanced_option(
hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager, marker hass: HomeAssistant, manager: data_entry_flow.FlowManager, marker
) -> None: ) -> None:
"""Test handling of advanced options in config flow.""" """Test handling of advanced options in config flow."""
manager.hass = hass manager.hass = hass
@ -200,7 +200,7 @@ async def test_config_flow_advanced_option(
@pytest.mark.parametrize("marker", (vol.Required, vol.Optional)) @pytest.mark.parametrize("marker", (vol.Required, vol.Optional))
async def test_options_flow_advanced_option( async def test_options_flow_advanced_option(
hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager, marker hass: HomeAssistant, manager: data_entry_flow.FlowManager, marker
) -> None: ) -> None:
"""Test handling of advanced options in options flow.""" """Test handling of advanced options in options flow."""
manager.hass = hass manager.hass = hass
@ -475,7 +475,7 @@ async def test_next_step_function(hass: HomeAssistant) -> None:
async def test_suggested_values( async def test_suggested_values(
hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager hass: HomeAssistant, manager: data_entry_flow.FlowManager
) -> None: ) -> None:
"""Test suggested_values handling in SchemaFlowFormStep.""" """Test suggested_values handling in SchemaFlowFormStep."""
manager.hass = hass manager.hass = hass
@ -667,7 +667,7 @@ async def test_options_flow_state(hass: HomeAssistant) -> None:
async def test_options_flow_omit_optional_keys( async def test_options_flow_omit_optional_keys(
hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager hass: HomeAssistant, manager: data_entry_flow.FlowManager
) -> None: ) -> None:
"""Test handling of advanced options in options flow.""" """Test handling of advanced options in options flow."""
manager.hass = hass manager.hass = hass

View file

@ -24,11 +24,9 @@ def manager():
handlers = Registry() handlers = Registry()
entries = [] entries = []
class FlowManager(data_entry_flow.BaseFlowManager): class FlowManager(data_entry_flow.FlowManager):
"""Test flow manager.""" """Test flow manager."""
_flow_result = data_entry_flow.FlowResult
async def async_create_flow(self, handler_key, *, context, data): async def async_create_flow(self, handler_key, *, context, data):
"""Test create flow.""" """Test create flow."""
handler = handlers.get(handler_key) handler = handlers.get(handler_key)
@ -81,7 +79,7 @@ async def test_configure_reuses_handler_instance(manager) -> None:
assert len(manager.mock_created_entries) == 0 assert len(manager.mock_created_entries) == 0
async def test_configure_two_steps(manager: data_entry_flow.BaseFlowManager) -> None: async def test_configure_two_steps(manager: data_entry_flow.FlowManager) -> None:
"""Test that we reuse instances.""" """Test that we reuse instances."""
@manager.mock_reg_handler("test") @manager.mock_reg_handler("test")
@ -258,7 +256,7 @@ async def test_finish_callback_change_result_type(hass: HomeAssistant) -> None:
step_id="init", data_schema=vol.Schema({"count": int}) step_id="init", data_schema=vol.Schema({"count": int})
) )
class FlowManager(data_entry_flow.BaseFlowManager): class FlowManager(data_entry_flow.FlowManager):
async def async_create_flow(self, handler_name, *, context, data): async def async_create_flow(self, handler_name, *, context, data):
"""Create a test flow.""" """Create a test flow."""
return TestFlow() return TestFlow()
@ -775,7 +773,7 @@ async def test_async_get_unknown_flow(manager) -> None:
async def test_async_has_matching_flow( async def test_async_has_matching_flow(
hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager hass: HomeAssistant, manager: data_entry_flow.FlowManager
) -> None: ) -> None:
"""Test we can check for matching flows.""" """Test we can check for matching flows."""
manager.hass = hass manager.hass = hass
@ -951,7 +949,7 @@ async def test_show_menu(hass: HomeAssistant, manager, menu_options) -> None:
async def test_find_flows_by_init_data_type( async def test_find_flows_by_init_data_type(
manager: data_entry_flow.BaseFlowManager, manager: data_entry_flow.FlowManager,
) -> None: ) -> None:
"""Test we can find flows by init data type.""" """Test we can find flows by init data type."""