Add TypeVar default for FlowResult (#112345)
Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
This commit is contained in:
parent
33fe6ad647
commit
3d3e9900c3
14 changed files with 77 additions and 81 deletions
|
@ -11,8 +11,9 @@ from enum import StrEnum
|
|||
from functools import partial
|
||||
import logging
|
||||
from types import MappingProxyType
|
||||
from typing import Any, Generic, Required, TypedDict, TypeVar
|
||||
from typing import Any, Generic, Required, TypedDict
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
import voluptuous as vol
|
||||
|
||||
from .core import HomeAssistant, callback
|
||||
|
@ -84,7 +85,7 @@ STEP_ID_OPTIONAL_STEPS = {
|
|||
}
|
||||
|
||||
|
||||
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult")
|
||||
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult", default="FlowResult")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
|
@ -188,10 +189,10 @@ def _map_error_to_schema_errors(
|
|||
schema_errors[path_part_str] = error.error_message
|
||||
|
||||
|
||||
class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
|
||||
class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
||||
"""Manage all the flows that are in progress."""
|
||||
|
||||
_flow_result: Callable[..., _FlowResultT]
|
||||
_flow_result: Callable[..., _FlowResultT] = FlowResult # type: ignore[assignment]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -200,9 +201,9 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
"""Initialize the flow manager."""
|
||||
self.hass = hass
|
||||
self._preview: set[str] = set()
|
||||
self._progress: dict[str, BaseFlowHandler] = {}
|
||||
self._handler_progress_index: dict[str, set[BaseFlowHandler]] = {}
|
||||
self._init_data_process_index: dict[type, set[BaseFlowHandler]] = {}
|
||||
self._progress: dict[str, FlowHandler[_FlowResultT]] = {}
|
||||
self._handler_progress_index: dict[str, set[FlowHandler[_FlowResultT]]] = {}
|
||||
self._init_data_process_index: dict[type, set[FlowHandler[_FlowResultT]]] = {}
|
||||
|
||||
@abc.abstractmethod
|
||||
async def async_create_flow(
|
||||
|
@ -211,7 +212,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
*,
|
||||
context: dict[str, Any] | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> BaseFlowHandler[_FlowResultT]:
|
||||
) -> FlowHandler[_FlowResultT]:
|
||||
"""Create a flow for specified handler.
|
||||
|
||||
Handler key is the domain of the component that we want to set up.
|
||||
|
@ -219,12 +220,12 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
|
||||
@abc.abstractmethod
|
||||
async def async_finish_flow(
|
||||
self, flow: BaseFlowHandler, result: _FlowResultT
|
||||
self, flow: FlowHandler[_FlowResultT], result: _FlowResultT
|
||||
) -> _FlowResultT:
|
||||
"""Finish a data entry flow."""
|
||||
|
||||
async def async_post_init(
|
||||
self, flow: BaseFlowHandler, result: _FlowResultT
|
||||
self, flow: FlowHandler[_FlowResultT], result: _FlowResultT
|
||||
) -> None:
|
||||
"""Entry has finished executing its first step asynchronously."""
|
||||
|
||||
|
@ -298,7 +299,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
@callback
|
||||
def _async_progress_by_handler(
|
||||
self, handler: str, match_context: dict[str, Any] | None
|
||||
) -> list[BaseFlowHandler[_FlowResultT]]:
|
||||
) -> list[FlowHandler[_FlowResultT]]:
|
||||
"""Return the flows in progress by handler.
|
||||
|
||||
If match_context is specified, only return flows with a context that
|
||||
|
@ -362,7 +363,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
data_schema := cur_step.get("data_schema")
|
||||
) is not None and user_input is not None:
|
||||
try:
|
||||
user_input = data_schema(user_input)
|
||||
user_input = data_schema(user_input) # type: ignore[operator]
|
||||
except vol.Invalid as ex:
|
||||
raised_errors = [ex]
|
||||
if isinstance(ex, vol.MultipleInvalid):
|
||||
|
@ -444,7 +445,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
self._async_remove_flow_progress(flow_id)
|
||||
|
||||
@callback
|
||||
def _async_add_flow_progress(self, flow: BaseFlowHandler[_FlowResultT]) -> None:
|
||||
def _async_add_flow_progress(self, flow: FlowHandler[_FlowResultT]) -> None:
|
||||
"""Add a flow to in progress."""
|
||||
if flow.init_data is not None:
|
||||
init_data_type = type(flow.init_data)
|
||||
|
@ -453,9 +454,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
self._handler_progress_index.setdefault(flow.handler, set()).add(flow)
|
||||
|
||||
@callback
|
||||
def _async_remove_flow_from_index(
|
||||
self, flow: BaseFlowHandler[_FlowResultT]
|
||||
) -> None:
|
||||
def _async_remove_flow_from_index(self, flow: FlowHandler[_FlowResultT]) -> None:
|
||||
"""Remove a flow from in progress."""
|
||||
if flow.init_data is not None:
|
||||
init_data_type = type(flow.init_data)
|
||||
|
@ -481,7 +480,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
|
||||
async def _async_handle_step(
|
||||
self,
|
||||
flow: BaseFlowHandler[_FlowResultT],
|
||||
flow: FlowHandler[_FlowResultT],
|
||||
step_id: str,
|
||||
user_input: dict | BaseServiceInfo | None,
|
||||
) -> _FlowResultT:
|
||||
|
@ -558,7 +557,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
return result
|
||||
|
||||
def _raise_if_step_does_not_exist(
|
||||
self, flow: BaseFlowHandler, step_id: str
|
||||
self, flow: FlowHandler[_FlowResultT], step_id: str
|
||||
) -> None:
|
||||
"""Raise if the step does not exist."""
|
||||
method = f"async_step_{step_id}"
|
||||
|
@ -569,7 +568,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
f"Handler {self.__class__.__name__} doesn't support step {step_id}"
|
||||
)
|
||||
|
||||
async def _async_setup_preview(self, flow: BaseFlowHandler) -> None:
|
||||
async def _async_setup_preview(self, flow: FlowHandler[_FlowResultT]) -> None:
|
||||
"""Set up preview for a flow handler."""
|
||||
if flow.handler not in self._preview:
|
||||
self._preview.add(flow.handler)
|
||||
|
@ -577,7 +576,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
|
||||
@callback
|
||||
def _async_flow_handler_to_flow_result(
|
||||
self, flows: Iterable[BaseFlowHandler], include_uninitialized: bool
|
||||
self, flows: Iterable[FlowHandler[_FlowResultT]], include_uninitialized: bool
|
||||
) -> list[_FlowResultT]:
|
||||
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
|
||||
results = []
|
||||
|
@ -595,16 +594,10 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
|
|||
return results
|
||||
|
||||
|
||||
class FlowManager(BaseFlowManager[FlowResult]):
|
||||
"""Manage all the flows that are in progress."""
|
||||
|
||||
_flow_result = FlowResult
|
||||
|
||||
|
||||
class BaseFlowHandler(Generic[_FlowResultT]):
|
||||
class FlowHandler(Generic[_FlowResultT]):
|
||||
"""Handle a data entry flow."""
|
||||
|
||||
_flow_result: Callable[..., _FlowResultT]
|
||||
_flow_result: Callable[..., _FlowResultT] = FlowResult # type: ignore[assignment]
|
||||
|
||||
# Set by flow manager
|
||||
cur_step: _FlowResultT | None = None
|
||||
|
@ -881,12 +874,6 @@ class BaseFlowHandler(Generic[_FlowResultT]):
|
|||
self.__progress_task = progress_task
|
||||
|
||||
|
||||
class FlowHandler(BaseFlowHandler[FlowResult]):
|
||||
"""Handle a data entry flow."""
|
||||
|
||||
_flow_result = FlowResult
|
||||
|
||||
|
||||
# These can be removed if no deprecated constant are in this module anymore
|
||||
__getattr__ = partial(check_if_deprecated_constant, module_globals=globals())
|
||||
__dir__ = partial(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue