From e6504218bc14734cc87201ac31a652d85a510547 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Tue, 3 Oct 2023 16:52:31 -0500 Subject: [PATCH] Pipeline runs are only equal with same id (#101341) * Pipeline runs are only equal with same id * Use dict instead of list in PipelineRuns * Let it blow up * Test * Test rest of __eq__ --- .../components/assist_pipeline/pipeline.py | 21 +++++++++----- tests/components/assist_pipeline/test_init.py | 29 +++++++++++++++++++ 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 7e4c71671ad..76444fb2436 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -3,7 +3,7 @@ from __future__ import annotations import array import asyncio -from collections import deque +from collections import defaultdict, deque from collections.abc import AsyncGenerator, AsyncIterable, Callable, Iterable from dataclasses import asdict, dataclass, field from enum import StrEnum @@ -475,7 +475,7 @@ class PipelineRun: stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False, repr=False) tts_engine: str = field(init=False, repr=False) tts_options: dict | None = field(init=False, default=None) - wake_word_entity_id: str = field(init=False, repr=False) + wake_word_entity_id: str | None = field(init=False, default=None, repr=False) wake_word_entity: wake_word.WakeWordDetectionEntity = field(init=False, repr=False) abort_wake_word_detection: bool = field(init=False, default=False) @@ -518,6 +518,13 @@ class PipelineRun: self.audio_settings.noise_suppression_level, ) + def __eq__(self, other: Any) -> bool: + """Compare pipeline runs by id.""" + if isinstance(other, PipelineRun): + return self.id == other.id + + return False + @callback def process_event(self, event: PipelineEvent) -> None: """Log an event and call listener.""" @@ -1565,21 +1572,19 @@ class PipelineRuns: def __init__(self, pipeline_store: PipelineStorageCollection) -> None: """Initialize.""" - self._pipeline_runs: dict[str, list[PipelineRun]] = {} + self._pipeline_runs: dict[str, dict[str, PipelineRun]] = defaultdict(dict) self._pipeline_store = pipeline_store pipeline_store.async_add_listener(self._change_listener) def add_run(self, pipeline_run: PipelineRun) -> None: """Add pipeline run.""" pipeline_id = pipeline_run.pipeline.id - if pipeline_id not in self._pipeline_runs: - self._pipeline_runs[pipeline_id] = [] - self._pipeline_runs[pipeline_id].append(pipeline_run) + self._pipeline_runs[pipeline_id][pipeline_run.id] = pipeline_run def remove_run(self, pipeline_run: PipelineRun) -> None: """Remove pipeline run.""" pipeline_id = pipeline_run.pipeline.id - self._pipeline_runs[pipeline_id].remove(pipeline_run) + self._pipeline_runs[pipeline_id].pop(pipeline_run.id) async def _change_listener( self, change_type: str, item_id: str, change: dict @@ -1589,7 +1594,7 @@ class PipelineRuns: return if pipeline_runs := self._pipeline_runs.get(item_id): # Create a temporary list in case the list is modified while we iterate - for pipeline_run in list(pipeline_runs): + for pipeline_run in list(pipeline_runs.values()): pipeline_run.abort_wake_word_detection = True diff --git a/tests/components/assist_pipeline/test_init.py b/tests/components/assist_pipeline/test_init.py index 5258736c89f..98ecae628f1 100644 --- a/tests/components/assist_pipeline/test_init.py +++ b/tests/components/assist_pipeline/test_init.py @@ -627,3 +627,32 @@ async def test_wake_word_detection_aborted( await pipeline_input.execute() assert process_events(events) == snapshot + + +def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None: + """Test that pipeline run equality uses unique id.""" + + def event_callback(event): + pass + + pipeline = assist_pipeline.pipeline.async_get_pipeline(hass) + run_1 = assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.STT, + end_stage=assist_pipeline.PipelineStage.TTS, + event_callback=event_callback, + ) + run_2 = assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.STT, + end_stage=assist_pipeline.PipelineStage.TTS, + event_callback=event_callback, + ) + + assert run_1 == run_1 + assert run_1 != run_2 + assert run_1 != 1234