Add TypeVar default for FlowResult (#112345)

Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
This commit is contained in:
Erik Montnemery 2024-03-05 22:52:11 +01:00 committed by GitHub
parent 33fe6ad647
commit 3d3e9900c3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 77 additions and 81 deletions

View file

@ -11,8 +11,9 @@ from enum import StrEnum
from functools import partial
import logging
from types import MappingProxyType
from typing import Any, Generic, Required, TypedDict, TypeVar
from typing import Any, Generic, Required, TypedDict
from typing_extensions import TypeVar
import voluptuous as vol
from .core import HomeAssistant, callback
@ -84,7 +85,7 @@ STEP_ID_OPTIONAL_STEPS = {
}
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult")
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult", default="FlowResult")
@dataclass(slots=True)
@ -188,10 +189,10 @@ def _map_error_to_schema_errors(
schema_errors[path_part_str] = error.error_message
class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
class FlowManager(abc.ABC, Generic[_FlowResultT]):
"""Manage all the flows that are in progress."""
_flow_result: Callable[..., _FlowResultT]
_flow_result: Callable[..., _FlowResultT] = FlowResult # type: ignore[assignment]
def __init__(
self,
@ -200,9 +201,9 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
"""Initialize the flow manager."""
self.hass = hass
self._preview: set[str] = set()
self._progress: dict[str, BaseFlowHandler] = {}
self._handler_progress_index: dict[str, set[BaseFlowHandler]] = {}
self._init_data_process_index: dict[type, set[BaseFlowHandler]] = {}
self._progress: dict[str, FlowHandler[_FlowResultT]] = {}
self._handler_progress_index: dict[str, set[FlowHandler[_FlowResultT]]] = {}
self._init_data_process_index: dict[type, set[FlowHandler[_FlowResultT]]] = {}
@abc.abstractmethod
async def async_create_flow(
@ -211,7 +212,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
*,
context: dict[str, Any] | None = None,
data: dict[str, Any] | None = None,
) -> BaseFlowHandler[_FlowResultT]:
) -> FlowHandler[_FlowResultT]:
"""Create a flow for specified handler.
Handler key is the domain of the component that we want to set up.
@ -219,12 +220,12 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
@abc.abstractmethod
async def async_finish_flow(
self, flow: BaseFlowHandler, result: _FlowResultT
self, flow: FlowHandler[_FlowResultT], result: _FlowResultT
) -> _FlowResultT:
"""Finish a data entry flow."""
async def async_post_init(
self, flow: BaseFlowHandler, result: _FlowResultT
self, flow: FlowHandler[_FlowResultT], result: _FlowResultT
) -> None:
"""Entry has finished executing its first step asynchronously."""
@ -298,7 +299,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
@callback
def _async_progress_by_handler(
self, handler: str, match_context: dict[str, Any] | None
) -> list[BaseFlowHandler[_FlowResultT]]:
) -> list[FlowHandler[_FlowResultT]]:
"""Return the flows in progress by handler.
If match_context is specified, only return flows with a context that
@ -362,7 +363,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
data_schema := cur_step.get("data_schema")
) is not None and user_input is not None:
try:
user_input = data_schema(user_input)
user_input = data_schema(user_input) # type: ignore[operator]
except vol.Invalid as ex:
raised_errors = [ex]
if isinstance(ex, vol.MultipleInvalid):
@ -444,7 +445,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
self._async_remove_flow_progress(flow_id)
@callback
def _async_add_flow_progress(self, flow: BaseFlowHandler[_FlowResultT]) -> None:
def _async_add_flow_progress(self, flow: FlowHandler[_FlowResultT]) -> None:
"""Add a flow to in progress."""
if flow.init_data is not None:
init_data_type = type(flow.init_data)
@ -453,9 +454,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
self._handler_progress_index.setdefault(flow.handler, set()).add(flow)
@callback
def _async_remove_flow_from_index(
self, flow: BaseFlowHandler[_FlowResultT]
) -> None:
def _async_remove_flow_from_index(self, flow: FlowHandler[_FlowResultT]) -> None:
"""Remove a flow from in progress."""
if flow.init_data is not None:
init_data_type = type(flow.init_data)
@ -481,7 +480,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
async def _async_handle_step(
self,
flow: BaseFlowHandler[_FlowResultT],
flow: FlowHandler[_FlowResultT],
step_id: str,
user_input: dict | BaseServiceInfo | None,
) -> _FlowResultT:
@ -558,7 +557,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
return result
def _raise_if_step_does_not_exist(
self, flow: BaseFlowHandler, step_id: str
self, flow: FlowHandler[_FlowResultT], step_id: str
) -> None:
"""Raise if the step does not exist."""
method = f"async_step_{step_id}"
@ -569,7 +568,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
f"Handler {self.__class__.__name__} doesn't support step {step_id}"
)
async def _async_setup_preview(self, flow: BaseFlowHandler) -> None:
async def _async_setup_preview(self, flow: FlowHandler[_FlowResultT]) -> None:
"""Set up preview for a flow handler."""
if flow.handler not in self._preview:
self._preview.add(flow.handler)
@ -577,7 +576,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
@callback
def _async_flow_handler_to_flow_result(
self, flows: Iterable[BaseFlowHandler], include_uninitialized: bool
self, flows: Iterable[FlowHandler[_FlowResultT]], include_uninitialized: bool
) -> list[_FlowResultT]:
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
results = []
@ -595,16 +594,10 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
return results
class FlowManager(BaseFlowManager[FlowResult]):
"""Manage all the flows that are in progress."""
_flow_result = FlowResult
class BaseFlowHandler(Generic[_FlowResultT]):
class FlowHandler(Generic[_FlowResultT]):
"""Handle a data entry flow."""
_flow_result: Callable[..., _FlowResultT]
_flow_result: Callable[..., _FlowResultT] = FlowResult # type: ignore[assignment]
# Set by flow manager
cur_step: _FlowResultT | None = None
@ -881,12 +874,6 @@ class BaseFlowHandler(Generic[_FlowResultT]):
self.__progress_task = progress_task
class FlowHandler(BaseFlowHandler[FlowResult]):
"""Handle a data entry flow."""
_flow_result = FlowResult
# These can be removed if no deprecated constant are in this module anymore
__getattr__ = partial(check_if_deprecated_constant, module_globals=globals())
__dir__ = partial(