Avoid double lookups with data_entry_flow indices (#100627)

This commit is contained in:
J. Nick Koston 2023-09-20 11:55:51 +02:00 committed by GitHub
parent 06c7f0959c
commit d675825b5a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -138,8 +138,8 @@ class FlowManager(abc.ABC):
self.hass = hass self.hass = hass
self._preview: set[str] = set() self._preview: set[str] = set()
self._progress: dict[str, FlowHandler] = {} self._progress: dict[str, FlowHandler] = {}
self._handler_progress_index: dict[str, set[str]] = {} self._handler_progress_index: dict[str, set[FlowHandler]] = {}
self._init_data_process_index: dict[type, set[str]] = {} self._init_data_process_index: dict[type, set[FlowHandler]] = {}
@abc.abstractmethod @abc.abstractmethod
async def async_create_flow( async def async_create_flow(
@ -221,9 +221,9 @@ class FlowManager(abc.ABC):
"""Return flows in progress init matching by data type as a partial FlowResult.""" """Return flows in progress init matching by data type as a partial FlowResult."""
return _async_flow_handler_to_flow_result( return _async_flow_handler_to_flow_result(
( (
self._progress[flow_id] progress
for flow_id in self._init_data_process_index.get(init_data_type, {}) for progress in self._init_data_process_index.get(init_data_type, set())
if matcher(self._progress[flow_id].init_data) if matcher(progress.init_data)
), ),
include_uninitialized, include_uninitialized,
) )
@ -237,18 +237,13 @@ class FlowManager(abc.ABC):
If match_context is specified, only return flows with a context that If match_context is specified, only return flows with a context that
is a superset of match_context. is a superset of match_context.
""" """
match_context_items = match_context.items() if match_context else None if not match_context:
return list(self._handler_progress_index.get(handler, []))
match_context_items = match_context.items()
return [ return [
progress progress
for flow_id in self._handler_progress_index.get(handler, {}) for progress in self._handler_progress_index.get(handler, set())
if (progress := self._progress[flow_id]) if match_context_items <= progress.context.items()
and (
not match_context_items
or (
(context := progress.context)
and match_context_items <= context.items()
)
)
] ]
async def async_init( async def async_init(
@ -348,22 +343,20 @@ class FlowManager(abc.ABC):
"""Add a flow to in progress.""" """Add a flow to in progress."""
if flow.init_data is not None: if flow.init_data is not None:
init_data_type = type(flow.init_data) init_data_type = type(flow.init_data)
self._init_data_process_index.setdefault(init_data_type, set()).add( self._init_data_process_index.setdefault(init_data_type, set()).add(flow)
flow.flow_id
)
self._progress[flow.flow_id] = flow self._progress[flow.flow_id] = flow
self._handler_progress_index.setdefault(flow.handler, set()).add(flow.flow_id) self._handler_progress_index.setdefault(flow.handler, set()).add(flow)
@callback @callback
def _async_remove_flow_from_index(self, flow: FlowHandler) -> None: def _async_remove_flow_from_index(self, flow: FlowHandler) -> None:
"""Remove a flow from in progress.""" """Remove a flow from in progress."""
if flow.init_data is not None: if flow.init_data is not None:
init_data_type = type(flow.init_data) init_data_type = type(flow.init_data)
self._init_data_process_index[init_data_type].remove(flow.flow_id) self._init_data_process_index[init_data_type].remove(flow)
if not self._init_data_process_index[init_data_type]: if not self._init_data_process_index[init_data_type]:
del self._init_data_process_index[init_data_type] del self._init_data_process_index[init_data_type]
handler = flow.handler handler = flow.handler
self._handler_progress_index[handler].remove(flow.flow_id) self._handler_progress_index[handler].remove(flow)
if not self._handler_progress_index[handler]: if not self._handler_progress_index[handler]:
del self._handler_progress_index[handler] del self._handler_progress_index[handler]