diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index 467fc3b5228..63cbfda5b9b 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -138,8 +138,8 @@ class FlowManager(abc.ABC): self.hass = hass self._preview: set[str] = set() self._progress: dict[str, FlowHandler] = {} - self._handler_progress_index: dict[str, set[str]] = {} - self._init_data_process_index: dict[type, set[str]] = {} + self._handler_progress_index: dict[str, set[FlowHandler]] = {} + self._init_data_process_index: dict[type, set[FlowHandler]] = {} @abc.abstractmethod async def async_create_flow( @@ -221,9 +221,9 @@ class FlowManager(abc.ABC): """Return flows in progress init matching by data type as a partial FlowResult.""" return _async_flow_handler_to_flow_result( ( - self._progress[flow_id] - for flow_id in self._init_data_process_index.get(init_data_type, {}) - if matcher(self._progress[flow_id].init_data) + progress + for progress in self._init_data_process_index.get(init_data_type, set()) + if matcher(progress.init_data) ), include_uninitialized, ) @@ -237,18 +237,13 @@ class FlowManager(abc.ABC): If match_context is specified, only return flows with a context that is a superset of match_context. """ - match_context_items = match_context.items() if match_context else None + if not match_context: + return list(self._handler_progress_index.get(handler, [])) + match_context_items = match_context.items() return [ progress - for flow_id in self._handler_progress_index.get(handler, {}) - if (progress := self._progress[flow_id]) - and ( - not match_context_items - or ( - (context := progress.context) - and match_context_items <= context.items() - ) - ) + for progress in self._handler_progress_index.get(handler, set()) + if match_context_items <= progress.context.items() ] async def async_init( @@ -348,22 +343,20 @@ class FlowManager(abc.ABC): """Add a flow to in progress.""" if flow.init_data is not None: init_data_type = type(flow.init_data) - self._init_data_process_index.setdefault(init_data_type, set()).add( - flow.flow_id - ) + self._init_data_process_index.setdefault(init_data_type, set()).add(flow) self._progress[flow.flow_id] = flow - self._handler_progress_index.setdefault(flow.handler, set()).add(flow.flow_id) + self._handler_progress_index.setdefault(flow.handler, set()).add(flow) @callback def _async_remove_flow_from_index(self, flow: FlowHandler) -> None: """Remove a flow from in progress.""" if flow.init_data is not None: init_data_type = type(flow.init_data) - self._init_data_process_index[init_data_type].remove(flow.flow_id) + self._init_data_process_index[init_data_type].remove(flow) if not self._init_data_process_index[init_data_type]: del self._init_data_process_index[init_data_type] handler = flow.handler - self._handler_progress_index[handler].remove(flow.flow_id) + self._handler_progress_index[handler].remove(flow) if not self._handler_progress_index[handler]: del self._handler_progress_index[handler]