Wake word entity state/category fix (#98886)
* Only change wake word entity state on detection * Wake word entity is diagnostic
This commit is contained in:
parent
ba9c969d91
commit
4a417c7dcc
3 changed files with 37 additions and 9 deletions
|
@ -7,7 +7,7 @@ import logging
|
||||||
from typing import final
|
from typing import final
|
||||||
|
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
|
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN, EntityCategory
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv
|
||||||
from homeassistant.helpers.entity_component import EntityComponent
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
|
@ -71,16 +71,17 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
class WakeWordDetectionEntity(RestoreEntity):
|
class WakeWordDetectionEntity(RestoreEntity):
|
||||||
"""Represent a single wake word provider."""
|
"""Represent a single wake word provider."""
|
||||||
|
|
||||||
|
_attr_entity_category = EntityCategory.DIAGNOSTIC
|
||||||
_attr_should_poll = False
|
_attr_should_poll = False
|
||||||
__last_processed: str | None = None
|
__last_detected: str | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@final
|
@final
|
||||||
def state(self) -> str | None:
|
def state(self) -> str | None:
|
||||||
"""Return the state of the entity."""
|
"""Return the state of the entity."""
|
||||||
if self.__last_processed is None:
|
if self.__last_detected is None:
|
||||||
return None
|
return None
|
||||||
return self.__last_processed
|
return self.__last_detected
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -103,9 +104,13 @@ class WakeWordDetectionEntity(RestoreEntity):
|
||||||
|
|
||||||
Audio must be 16Khz sample rate with 16-bit mono PCM samples.
|
Audio must be 16Khz sample rate with 16-bit mono PCM samples.
|
||||||
"""
|
"""
|
||||||
self.__last_processed = dt_util.utcnow().isoformat()
|
result = await self._async_process_audio_stream(stream)
|
||||||
self.async_write_ha_state()
|
if result is not None:
|
||||||
return await self._async_process_audio_stream(stream)
|
# Update last detected only when there is a detection
|
||||||
|
self.__last_detected = dt_util.utcnow().isoformat()
|
||||||
|
self.async_write_ha_state()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
async def async_internal_added_to_hass(self) -> None:
|
async def async_internal_added_to_hass(self) -> None:
|
||||||
"""Call when the entity is added to hass."""
|
"""Call when the entity is added to hass."""
|
||||||
|
@ -116,4 +121,4 @@ class WakeWordDetectionEntity(RestoreEntity):
|
||||||
and state.state is not None
|
and state.state is not None
|
||||||
and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
|
and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
|
||||||
):
|
):
|
||||||
self.__last_processed = state.state
|
self.__last_detected = state.state
|
||||||
|
|
|
@ -1,4 +1,7 @@
|
||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
|
# name: test_detected_entity
|
||||||
|
None
|
||||||
|
# ---
|
||||||
# name: test_ws_detect
|
# name: test_ws_detect
|
||||||
dict({
|
dict({
|
||||||
'event': dict({
|
'event': dict({
|
||||||
|
|
|
@ -3,9 +3,11 @@ from collections.abc import AsyncIterable, Generator
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
from homeassistant.components import wake_word
|
from homeassistant.components import wake_word
|
||||||
from homeassistant.config_entries import ConfigEntry, ConfigEntryState, ConfigFlow
|
from homeassistant.config_entries import ConfigEntry, ConfigEntryState, ConfigFlow
|
||||||
|
from homeassistant.const import EntityCategory
|
||||||
from homeassistant.core import HomeAssistant, State
|
from homeassistant.core import HomeAssistant, State
|
||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
@ -147,7 +149,10 @@ async def test_config_entry_unload(
|
||||||
|
|
||||||
|
|
||||||
async def test_detected_entity(
|
async def test_detected_entity(
|
||||||
hass: HomeAssistant, tmp_path: Path, setup: MockProviderEntity
|
hass: HomeAssistant,
|
||||||
|
tmp_path: Path,
|
||||||
|
setup: MockProviderEntity,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test successful detection through entity."""
|
"""Test successful detection through entity."""
|
||||||
|
|
||||||
|
@ -158,9 +163,13 @@ async def test_detected_entity(
|
||||||
timestamp += _MS_PER_CHUNK
|
timestamp += _MS_PER_CHUNK
|
||||||
|
|
||||||
# Need 2 seconds to trigger
|
# Need 2 seconds to trigger
|
||||||
|
state = setup.state
|
||||||
result = await setup.async_process_audio_stream(three_second_stream())
|
result = await setup.async_process_audio_stream(three_second_stream())
|
||||||
assert result == wake_word.DetectionResult("test_ww", 2048)
|
assert result == wake_word.DetectionResult("test_ww", 2048)
|
||||||
|
|
||||||
|
assert state != setup.state
|
||||||
|
assert state == snapshot
|
||||||
|
|
||||||
|
|
||||||
async def test_not_detected_entity(
|
async def test_not_detected_entity(
|
||||||
hass: HomeAssistant, setup: MockProviderEntity
|
hass: HomeAssistant, setup: MockProviderEntity
|
||||||
|
@ -174,9 +183,13 @@ async def test_not_detected_entity(
|
||||||
timestamp += _MS_PER_CHUNK
|
timestamp += _MS_PER_CHUNK
|
||||||
|
|
||||||
# Need 2 seconds to trigger
|
# Need 2 seconds to trigger
|
||||||
|
state = setup.state
|
||||||
result = await setup.async_process_audio_stream(one_second_stream())
|
result = await setup.async_process_audio_stream(one_second_stream())
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
# State should only change when there's a detection
|
||||||
|
assert state == setup.state
|
||||||
|
|
||||||
|
|
||||||
async def test_default_engine_none(hass: HomeAssistant, tmp_path: Path) -> None:
|
async def test_default_engine_none(hass: HomeAssistant, tmp_path: Path) -> None:
|
||||||
"""Test async_default_engine."""
|
"""Test async_default_engine."""
|
||||||
|
@ -224,3 +237,10 @@ async def test_restore_state(
|
||||||
state = hass.states.get(entity_id)
|
state = hass.states.get(entity_id)
|
||||||
assert state
|
assert state
|
||||||
assert state.state == timestamp
|
assert state.state == timestamp
|
||||||
|
|
||||||
|
|
||||||
|
async def test_entity_attributes(
|
||||||
|
hass: HomeAssistant, mock_provider_entity: MockProviderEntity
|
||||||
|
) -> None:
|
||||||
|
"""Test that the provider entity attributes match expectations."""
|
||||||
|
assert mock_provider_entity.entity_category == EntityCategory.DIAGNOSTIC
|
||||||
|
|
Loading…
Add table
Reference in a new issue