Wake word entity state/category fix (#98886)
* Only change wake word entity state on detection * Wake word entity is diagnostic
This commit is contained in:
parent
ba9c969d91
commit
4a417c7dcc
3 changed files with 37 additions and 9 deletions
|
@ -7,7 +7,7 @@ import logging
|
|||
from typing import final
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
|
||||
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN, EntityCategory
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
|
@ -71,16 +71,17 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
class WakeWordDetectionEntity(RestoreEntity):
|
||||
"""Represent a single wake word provider."""
|
||||
|
||||
_attr_entity_category = EntityCategory.DIAGNOSTIC
|
||||
_attr_should_poll = False
|
||||
__last_processed: str | None = None
|
||||
__last_detected: str | None = None
|
||||
|
||||
@property
|
||||
@final
|
||||
def state(self) -> str | None:
|
||||
"""Return the state of the entity."""
|
||||
if self.__last_processed is None:
|
||||
if self.__last_detected is None:
|
||||
return None
|
||||
return self.__last_processed
|
||||
return self.__last_detected
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
|
@ -103,9 +104,13 @@ class WakeWordDetectionEntity(RestoreEntity):
|
|||
|
||||
Audio must be 16Khz sample rate with 16-bit mono PCM samples.
|
||||
"""
|
||||
self.__last_processed = dt_util.utcnow().isoformat()
|
||||
self.async_write_ha_state()
|
||||
return await self._async_process_audio_stream(stream)
|
||||
result = await self._async_process_audio_stream(stream)
|
||||
if result is not None:
|
||||
# Update last detected only when there is a detection
|
||||
self.__last_detected = dt_util.utcnow().isoformat()
|
||||
self.async_write_ha_state()
|
||||
|
||||
return result
|
||||
|
||||
async def async_internal_added_to_hass(self) -> None:
|
||||
"""Call when the entity is added to hass."""
|
||||
|
@ -116,4 +121,4 @@ class WakeWordDetectionEntity(RestoreEntity):
|
|||
and state.state is not None
|
||||
and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
|
||||
):
|
||||
self.__last_processed = state.state
|
||||
self.__last_detected = state.state
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
# serializer version: 1
|
||||
# name: test_detected_entity
|
||||
None
|
||||
# ---
|
||||
# name: test_ws_detect
|
||||
dict({
|
||||
'event': dict({
|
||||
|
|
|
@ -3,9 +3,11 @@ from collections.abc import AsyncIterable, Generator
|
|||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components import wake_word
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigEntryState, ConfigFlow
|
||||
from homeassistant.const import EntityCategory
|
||||
from homeassistant.core import HomeAssistant, State
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
@ -147,7 +149,10 @@ async def test_config_entry_unload(
|
|||
|
||||
|
||||
async def test_detected_entity(
|
||||
hass: HomeAssistant, tmp_path: Path, setup: MockProviderEntity
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
setup: MockProviderEntity,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test successful detection through entity."""
|
||||
|
||||
|
@ -158,9 +163,13 @@ async def test_detected_entity(
|
|||
timestamp += _MS_PER_CHUNK
|
||||
|
||||
# Need 2 seconds to trigger
|
||||
state = setup.state
|
||||
result = await setup.async_process_audio_stream(three_second_stream())
|
||||
assert result == wake_word.DetectionResult("test_ww", 2048)
|
||||
|
||||
assert state != setup.state
|
||||
assert state == snapshot
|
||||
|
||||
|
||||
async def test_not_detected_entity(
|
||||
hass: HomeAssistant, setup: MockProviderEntity
|
||||
|
@ -174,9 +183,13 @@ async def test_not_detected_entity(
|
|||
timestamp += _MS_PER_CHUNK
|
||||
|
||||
# Need 2 seconds to trigger
|
||||
state = setup.state
|
||||
result = await setup.async_process_audio_stream(one_second_stream())
|
||||
assert result is None
|
||||
|
||||
# State should only change when there's a detection
|
||||
assert state == setup.state
|
||||
|
||||
|
||||
async def test_default_engine_none(hass: HomeAssistant, tmp_path: Path) -> None:
|
||||
"""Test async_default_engine."""
|
||||
|
@ -224,3 +237,10 @@ async def test_restore_state(
|
|||
state = hass.states.get(entity_id)
|
||||
assert state
|
||||
assert state.state == timestamp
|
||||
|
||||
|
||||
async def test_entity_attributes(
|
||||
hass: HomeAssistant, mock_provider_entity: MockProviderEntity
|
||||
) -> None:
|
||||
"""Test that the provider entity attributes match expectations."""
|
||||
assert mock_provider_entity.entity_category == EntityCategory.DIAGNOSTIC
|
||||
|
|
Loading…
Add table
Reference in a new issue