Dynamic wake word loading for Wyoming (#101827)

* Change supported_wake_words property to async method

* Add test

* Add timeout + test

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Michael Hansen 2023-10-11 12:21:32 -05:00 committed by GitHub
parent 6c4ac71218
commit 257686fcfe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 120 additions and 17 deletions

View file

@ -2,6 +2,7 @@
from __future__ import annotations
from abc import abstractmethod
import asyncio
from collections.abc import AsyncIterable
import logging
from typing import final
@ -34,6 +35,8 @@ _LOGGER = logging.getLogger(__name__)
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
TIMEOUT_FETCH_WAKE_WORDS = 10
@callback
def async_default_entity(hass: HomeAssistant) -> str | None:
@ -86,9 +89,8 @@ class WakeWordDetectionEntity(RestoreEntity):
"""Return the state of the entity."""
return self.__last_detected
@property
@abstractmethod
def supported_wake_words(self) -> list[WakeWord]:
async def get_supported_wake_words(self) -> list[WakeWord]:
"""Return a list of supported wake words."""
@abstractmethod
@ -133,8 +135,9 @@ class WakeWordDetectionEntity(RestoreEntity):
vol.Required("entity_id"): cv.entity_domain(DOMAIN),
}
)
@websocket_api.async_response
@callback
def websocket_entity_info(
async def websocket_entity_info(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Get info about wake word entity."""
@ -147,7 +150,16 @@ def websocket_entity_info(
)
return
try:
async with asyncio.timeout(TIMEOUT_FETCH_WAKE_WORDS):
wake_words = await entity.get_supported_wake_words()
except asyncio.TimeoutError:
connection.send_error(
msg["id"], websocket_api.const.ERR_TIMEOUT, "Timeout fetching wake words"
)
return
connection.send_result(
msg["id"],
{"wake_words": entity.supported_wake_words},
{"wake_words": wake_words},
)

View file

@ -13,7 +13,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN
from .data import WyomingService
from .data import WyomingService, load_wyoming_info
from .error import WyomingError
_LOGGER = logging.getLogger(__name__)
@ -28,7 +28,7 @@ async def async_setup_entry(
service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
async_add_entities(
[
WyomingWakeWordProvider(config_entry, service),
WyomingWakeWordProvider(hass, config_entry, service),
]
)
@ -38,10 +38,12 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
def __init__(
self,
hass: HomeAssistant,
config_entry: ConfigEntry,
service: WyomingService,
) -> None:
"""Set up provider."""
self.hass = hass
self.service = service
wake_service = service.info.wake[0]
@ -52,9 +54,19 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
self._attr_name = wake_service.name
self._attr_unique_id = f"{config_entry.entry_id}-wake_word"
@property
def supported_wake_words(self) -> list[wake_word.WakeWord]:
async def get_supported_wake_words(self) -> list[wake_word.WakeWord]:
"""Return a list of supported wake words."""
info = await load_wyoming_info(
self.service.host, self.service.port, retries=0, timeout=1
)
if info is not None:
wake_service = info.wake[0]
self._supported_wake_words = [
wake_word.WakeWord(id=ww.name, name=ww.description or ww.name)
for ww in wake_service.models
]
return self._supported_wake_words
async def _async_process_audio_stream(

View file

@ -181,8 +181,7 @@ class MockWakeWordEntity(wake_word.WakeWordDetectionEntity):
url_path = "wake_word.test"
_attr_name = "test"
@property
def supported_wake_words(self) -> list[wake_word.WakeWord]:
async def get_supported_wake_words(self) -> list[wake_word.WakeWord]:
"""Return a list of supported wake words."""
return [wake_word.WakeWord(id="test_ww", name="Test Wake Word")]
@ -191,7 +190,7 @@ class MockWakeWordEntity(wake_word.WakeWordDetectionEntity):
) -> wake_word.DetectionResult | None:
"""Try to detect wake word(s) in an audio stream with timestamps."""
if wake_word_id is None:
wake_word_id = self.supported_wake_words[0].id
wake_word_id = (await self.get_supported_wake_words())[0].id
async for chunk, timestamp in stream:
if chunk.startswith(b"wake word"):
return wake_word.DetectionResult(

View file

@ -1,6 +1,9 @@
"""Test wake_word component setup."""
import asyncio
from collections.abc import AsyncIterable, Generator
from functools import partial
from pathlib import Path
from unittest.mock import patch
from freezegun import freeze_time
import pytest
@ -37,8 +40,7 @@ class MockProviderEntity(wake_word.WakeWordDetectionEntity):
url_path = "wake_word.test"
_attr_name = "test"
@property
def supported_wake_words(self) -> list[wake_word.WakeWord]:
async def get_supported_wake_words(self) -> list[wake_word.WakeWord]:
"""Return a list of supported wake words."""
return [
wake_word.WakeWord(id="test_ww", name="Test Wake Word"),
@ -50,7 +52,7 @@ class MockProviderEntity(wake_word.WakeWordDetectionEntity):
) -> wake_word.DetectionResult | None:
"""Try to detect wake word(s) in an audio stream with timestamps."""
if wake_word_id is None:
wake_word_id = self.supported_wake_words[0].id
wake_word_id = (await self.get_supported_wake_words())[0].id
async for _chunk, timestamp in stream:
if timestamp >= 2000:
@ -294,7 +296,7 @@ async def test_list_wake_words_unknown_entity(
setup: MockProviderEntity,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test that the list_wake_words websocket command works."""
"""Test that the list_wake_words websocket command handles unknown entity."""
client = await hass_ws_client(hass)
await client.send_json(
{
@ -308,3 +310,28 @@ async def test_list_wake_words_unknown_entity(
assert not msg["success"]
assert msg["error"] == {"code": "not_found", "message": "Entity not found"}
async def test_list_wake_words_timeout(
hass: HomeAssistant,
setup: MockProviderEntity,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test that the list_wake_words websocket command handles unknown entity."""
client = await hass_ws_client(hass)
with patch.object(
setup, "get_supported_wake_words", partial(asyncio.sleep, 1)
), patch("homeassistant.components.wake_word.TIMEOUT_FETCH_WAKE_WORDS", 0):
await client.send_json(
{
"id": 5,
"type": "wake_word/info",
"entity_id": setup.entity_id,
}
)
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"] == {"code": "timeout", "message": "Timeout fetching wake words"}

View file

@ -6,12 +6,13 @@ from unittest.mock import patch
from syrupy.assertion import SnapshotAssertion
from wyoming.asr import Transcript
from wyoming.info import Info, WakeModel, WakeProgram
from wyoming.wake import Detection
from homeassistant.components import wake_word
from homeassistant.core import HomeAssistant
from . import MockAsyncTcpClient
from . import TEST_ATTR, MockAsyncTcpClient
async def test_support(hass: HomeAssistant, init_wyoming_wake_word) -> None:
@ -24,7 +25,7 @@ async def test_support(hass: HomeAssistant, init_wyoming_wake_word) -> None:
)
assert entity is not None
assert entity.supported_wake_words == [
assert (await entity.get_supported_wake_words()) == [
wake_word.WakeWord(id="Test Model", name="Test Model")
]
@ -157,3 +158,55 @@ async def test_detect_message_with_wrong_wake_word(
result = await entity.async_process_audio_stream(audio_stream(), "my-wake-word")
assert result is None
async def test_dynamic_wake_word_info(
hass: HomeAssistant, init_wyoming_wake_word
) -> None:
"""Test that supported wake words are loaded dynamically."""
entity = wake_word.async_get_wake_word_detection_entity(
hass, "wake_word.test_wake_word"
)
assert entity is not None
# Original info
assert (await entity.get_supported_wake_words()) == [
wake_word.WakeWord("Test Model", "Test Model")
]
new_info = Info(
wake=[
WakeProgram(
name="dynamic",
description="Dynamic Wake Word",
installed=True,
attribution=TEST_ATTR,
models=[
WakeModel(
name="ww1",
description="Wake Word 1",
installed=True,
attribution=TEST_ATTR,
languages=[],
),
WakeModel(
name="ww2",
description="Wake Word 2",
installed=True,
attribution=TEST_ATTR,
languages=[],
),
],
)
]
)
# Different Wyoming info will be fetched
with patch(
"homeassistant.components.wyoming.wake_word.load_wyoming_info",
return_value=new_info,
):
assert (await entity.get_supported_wake_words()) == [
wake_word.WakeWord("ww1", "Wake Word 1"),
wake_word.WakeWord("ww2", "Wake Word 2"),
]