Pipeline runs are only equal with same id (#101341)
* Pipeline runs are only equal with same id * Use dict instead of list in PipelineRuns * Let it blow up * Test * Test rest of __eq__
This commit is contained in:
parent
d8f1023210
commit
e6504218bc
2 changed files with 42 additions and 8 deletions
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
|
||||
import array
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import AsyncGenerator, AsyncIterable, Callable, Iterable
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from enum import StrEnum
|
||||
|
@ -475,7 +475,7 @@ class PipelineRun:
|
|||
stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False, repr=False)
|
||||
tts_engine: str = field(init=False, repr=False)
|
||||
tts_options: dict | None = field(init=False, default=None)
|
||||
wake_word_entity_id: str = field(init=False, repr=False)
|
||||
wake_word_entity_id: str | None = field(init=False, default=None, repr=False)
|
||||
wake_word_entity: wake_word.WakeWordDetectionEntity = field(init=False, repr=False)
|
||||
|
||||
abort_wake_word_detection: bool = field(init=False, default=False)
|
||||
|
@ -518,6 +518,13 @@ class PipelineRun:
|
|||
self.audio_settings.noise_suppression_level,
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Compare pipeline runs by id."""
|
||||
if isinstance(other, PipelineRun):
|
||||
return self.id == other.id
|
||||
|
||||
return False
|
||||
|
||||
@callback
|
||||
def process_event(self, event: PipelineEvent) -> None:
|
||||
"""Log an event and call listener."""
|
||||
|
@ -1565,21 +1572,19 @@ class PipelineRuns:
|
|||
|
||||
def __init__(self, pipeline_store: PipelineStorageCollection) -> None:
|
||||
"""Initialize."""
|
||||
self._pipeline_runs: dict[str, list[PipelineRun]] = {}
|
||||
self._pipeline_runs: dict[str, dict[str, PipelineRun]] = defaultdict(dict)
|
||||
self._pipeline_store = pipeline_store
|
||||
pipeline_store.async_add_listener(self._change_listener)
|
||||
|
||||
def add_run(self, pipeline_run: PipelineRun) -> None:
|
||||
"""Add pipeline run."""
|
||||
pipeline_id = pipeline_run.pipeline.id
|
||||
if pipeline_id not in self._pipeline_runs:
|
||||
self._pipeline_runs[pipeline_id] = []
|
||||
self._pipeline_runs[pipeline_id].append(pipeline_run)
|
||||
self._pipeline_runs[pipeline_id][pipeline_run.id] = pipeline_run
|
||||
|
||||
def remove_run(self, pipeline_run: PipelineRun) -> None:
|
||||
"""Remove pipeline run."""
|
||||
pipeline_id = pipeline_run.pipeline.id
|
||||
self._pipeline_runs[pipeline_id].remove(pipeline_run)
|
||||
self._pipeline_runs[pipeline_id].pop(pipeline_run.id)
|
||||
|
||||
async def _change_listener(
|
||||
self, change_type: str, item_id: str, change: dict
|
||||
|
@ -1589,7 +1594,7 @@ class PipelineRuns:
|
|||
return
|
||||
if pipeline_runs := self._pipeline_runs.get(item_id):
|
||||
# Create a temporary list in case the list is modified while we iterate
|
||||
for pipeline_run in list(pipeline_runs):
|
||||
for pipeline_run in list(pipeline_runs.values()):
|
||||
pipeline_run.abort_wake_word_detection = True
|
||||
|
||||
|
||||
|
|
|
@ -627,3 +627,32 @@ async def test_wake_word_detection_aborted(
|
|||
await pipeline_input.execute()
|
||||
|
||||
assert process_events(events) == snapshot
|
||||
|
||||
|
||||
def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None:
|
||||
"""Test that pipeline run equality uses unique id."""
|
||||
|
||||
def event_callback(event):
|
||||
pass
|
||||
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass)
|
||||
run_1 = assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.STT,
|
||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||
event_callback=event_callback,
|
||||
)
|
||||
run_2 = assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.STT,
|
||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||
event_callback=event_callback,
|
||||
)
|
||||
|
||||
assert run_1 == run_1
|
||||
assert run_1 != run_2
|
||||
assert run_1 != 1234
|
||||
|
|
Loading…
Add table
Reference in a new issue