Pipeline runs are only equal with same id (#101341)

* Pipeline runs are only equal with same id

* Use dict instead of list in PipelineRuns

* Let it blow up

* Test

* Test rest of __eq__
This commit is contained in:
Michael Hansen 2023-10-03 16:52:31 -05:00 committed by GitHub
parent d8f1023210
commit e6504218bc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 8 deletions

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import array
import asyncio
from collections import deque
from collections import defaultdict, deque
from collections.abc import AsyncGenerator, AsyncIterable, Callable, Iterable
from dataclasses import asdict, dataclass, field
from enum import StrEnum
@ -475,7 +475,7 @@ class PipelineRun:
stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False, repr=False)
tts_engine: str = field(init=False, repr=False)
tts_options: dict | None = field(init=False, default=None)
wake_word_entity_id: str = field(init=False, repr=False)
wake_word_entity_id: str | None = field(init=False, default=None, repr=False)
wake_word_entity: wake_word.WakeWordDetectionEntity = field(init=False, repr=False)
abort_wake_word_detection: bool = field(init=False, default=False)
@ -518,6 +518,13 @@ class PipelineRun:
self.audio_settings.noise_suppression_level,
)
def __eq__(self, other: Any) -> bool:
"""Compare pipeline runs by id."""
if isinstance(other, PipelineRun):
return self.id == other.id
return False
@callback
def process_event(self, event: PipelineEvent) -> None:
"""Log an event and call listener."""
@ -1565,21 +1572,19 @@ class PipelineRuns:
def __init__(self, pipeline_store: PipelineStorageCollection) -> None:
"""Initialize."""
self._pipeline_runs: dict[str, list[PipelineRun]] = {}
self._pipeline_runs: dict[str, dict[str, PipelineRun]] = defaultdict(dict)
self._pipeline_store = pipeline_store
pipeline_store.async_add_listener(self._change_listener)
def add_run(self, pipeline_run: PipelineRun) -> None:
"""Add pipeline run."""
pipeline_id = pipeline_run.pipeline.id
if pipeline_id not in self._pipeline_runs:
self._pipeline_runs[pipeline_id] = []
self._pipeline_runs[pipeline_id].append(pipeline_run)
self._pipeline_runs[pipeline_id][pipeline_run.id] = pipeline_run
def remove_run(self, pipeline_run: PipelineRun) -> None:
"""Remove pipeline run."""
pipeline_id = pipeline_run.pipeline.id
self._pipeline_runs[pipeline_id].remove(pipeline_run)
self._pipeline_runs[pipeline_id].pop(pipeline_run.id)
async def _change_listener(
self, change_type: str, item_id: str, change: dict
@ -1589,7 +1594,7 @@ class PipelineRuns:
return
if pipeline_runs := self._pipeline_runs.get(item_id):
# Create a temporary list in case the list is modified while we iterate
for pipeline_run in list(pipeline_runs):
for pipeline_run in list(pipeline_runs.values()):
pipeline_run.abort_wake_word_detection = True

View file

@ -627,3 +627,32 @@ async def test_wake_word_detection_aborted(
await pipeline_input.execute()
assert process_events(events) == snapshot
def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None:
"""Test that pipeline run equality uses unique id."""
def event_callback(event):
pass
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass)
run_1 = assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
pipeline=pipeline,
start_stage=assist_pipeline.PipelineStage.STT,
end_stage=assist_pipeline.PipelineStage.TTS,
event_callback=event_callback,
)
run_2 = assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
pipeline=pipeline,
start_stage=assist_pipeline.PipelineStage.STT,
end_stage=assist_pipeline.PipelineStage.TTS,
event_callback=event_callback,
)
assert run_1 == run_1
assert run_1 != run_2
assert run_1 != 1234