Dynamic wake word loading for Wyoming (#101827)

* Change supported_wake_words property to async method

* Add test

* Add timeout + test

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Michael Hansen 2023-10-11 12:21:32 -05:00 committed by GitHub
parent 6c4ac71218
commit 257686fcfe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 120 additions and 17 deletions

View file

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod
import asyncio
from collections.abc import AsyncIterable from collections.abc import AsyncIterable
import logging import logging
from typing import final from typing import final
@ -34,6 +35,8 @@ _LOGGER = logging.getLogger(__name__)
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN) CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
TIMEOUT_FETCH_WAKE_WORDS = 10
@callback @callback
def async_default_entity(hass: HomeAssistant) -> str | None: def async_default_entity(hass: HomeAssistant) -> str | None:
@ -86,9 +89,8 @@ class WakeWordDetectionEntity(RestoreEntity):
"""Return the state of the entity.""" """Return the state of the entity."""
return self.__last_detected return self.__last_detected
@property
@abstractmethod @abstractmethod
def supported_wake_words(self) -> list[WakeWord]: async def get_supported_wake_words(self) -> list[WakeWord]:
"""Return a list of supported wake words.""" """Return a list of supported wake words."""
@abstractmethod @abstractmethod
@ -133,8 +135,9 @@ class WakeWordDetectionEntity(RestoreEntity):
vol.Required("entity_id"): cv.entity_domain(DOMAIN), vol.Required("entity_id"): cv.entity_domain(DOMAIN),
} }
) )
@websocket_api.async_response
@callback @callback
def websocket_entity_info( async def websocket_entity_info(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None: ) -> None:
"""Get info about wake word entity.""" """Get info about wake word entity."""
@ -147,7 +150,16 @@ def websocket_entity_info(
) )
return return
try:
async with asyncio.timeout(TIMEOUT_FETCH_WAKE_WORDS):
wake_words = await entity.get_supported_wake_words()
except asyncio.TimeoutError:
connection.send_error(
msg["id"], websocket_api.const.ERR_TIMEOUT, "Timeout fetching wake words"
)
return
connection.send_result( connection.send_result(
msg["id"], msg["id"],
{"wake_words": entity.supported_wake_words}, {"wake_words": wake_words},
) )

View file

@ -13,7 +13,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN from .const import DOMAIN
from .data import WyomingService from .data import WyomingService, load_wyoming_info
from .error import WyomingError from .error import WyomingError
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -28,7 +28,7 @@ async def async_setup_entry(
service: WyomingService = hass.data[DOMAIN][config_entry.entry_id] service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
async_add_entities( async_add_entities(
[ [
WyomingWakeWordProvider(config_entry, service), WyomingWakeWordProvider(hass, config_entry, service),
] ]
) )
@ -38,10 +38,12 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
def __init__( def __init__(
self, self,
hass: HomeAssistant,
config_entry: ConfigEntry, config_entry: ConfigEntry,
service: WyomingService, service: WyomingService,
) -> None: ) -> None:
"""Set up provider.""" """Set up provider."""
self.hass = hass
self.service = service self.service = service
wake_service = service.info.wake[0] wake_service = service.info.wake[0]
@ -52,9 +54,19 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
self._attr_name = wake_service.name self._attr_name = wake_service.name
self._attr_unique_id = f"{config_entry.entry_id}-wake_word" self._attr_unique_id = f"{config_entry.entry_id}-wake_word"
@property async def get_supported_wake_words(self) -> list[wake_word.WakeWord]:
def supported_wake_words(self) -> list[wake_word.WakeWord]:
"""Return a list of supported wake words.""" """Return a list of supported wake words."""
info = await load_wyoming_info(
self.service.host, self.service.port, retries=0, timeout=1
)
if info is not None:
wake_service = info.wake[0]
self._supported_wake_words = [
wake_word.WakeWord(id=ww.name, name=ww.description or ww.name)
for ww in wake_service.models
]
return self._supported_wake_words return self._supported_wake_words
async def _async_process_audio_stream( async def _async_process_audio_stream(

View file

@ -181,8 +181,7 @@ class MockWakeWordEntity(wake_word.WakeWordDetectionEntity):
url_path = "wake_word.test" url_path = "wake_word.test"
_attr_name = "test" _attr_name = "test"
@property async def get_supported_wake_words(self) -> list[wake_word.WakeWord]:
def supported_wake_words(self) -> list[wake_word.WakeWord]:
"""Return a list of supported wake words.""" """Return a list of supported wake words."""
return [wake_word.WakeWord(id="test_ww", name="Test Wake Word")] return [wake_word.WakeWord(id="test_ww", name="Test Wake Word")]
@ -191,7 +190,7 @@ class MockWakeWordEntity(wake_word.WakeWordDetectionEntity):
) -> wake_word.DetectionResult | None: ) -> wake_word.DetectionResult | None:
"""Try to detect wake word(s) in an audio stream with timestamps.""" """Try to detect wake word(s) in an audio stream with timestamps."""
if wake_word_id is None: if wake_word_id is None:
wake_word_id = self.supported_wake_words[0].id wake_word_id = (await self.get_supported_wake_words())[0].id
async for chunk, timestamp in stream: async for chunk, timestamp in stream:
if chunk.startswith(b"wake word"): if chunk.startswith(b"wake word"):
return wake_word.DetectionResult( return wake_word.DetectionResult(

View file

@ -1,6 +1,9 @@
"""Test wake_word component setup.""" """Test wake_word component setup."""
import asyncio
from collections.abc import AsyncIterable, Generator from collections.abc import AsyncIterable, Generator
from functools import partial
from pathlib import Path from pathlib import Path
from unittest.mock import patch
from freezegun import freeze_time from freezegun import freeze_time
import pytest import pytest
@ -37,8 +40,7 @@ class MockProviderEntity(wake_word.WakeWordDetectionEntity):
url_path = "wake_word.test" url_path = "wake_word.test"
_attr_name = "test" _attr_name = "test"
@property async def get_supported_wake_words(self) -> list[wake_word.WakeWord]:
def supported_wake_words(self) -> list[wake_word.WakeWord]:
"""Return a list of supported wake words.""" """Return a list of supported wake words."""
return [ return [
wake_word.WakeWord(id="test_ww", name="Test Wake Word"), wake_word.WakeWord(id="test_ww", name="Test Wake Word"),
@ -50,7 +52,7 @@ class MockProviderEntity(wake_word.WakeWordDetectionEntity):
) -> wake_word.DetectionResult | None: ) -> wake_word.DetectionResult | None:
"""Try to detect wake word(s) in an audio stream with timestamps.""" """Try to detect wake word(s) in an audio stream with timestamps."""
if wake_word_id is None: if wake_word_id is None:
wake_word_id = self.supported_wake_words[0].id wake_word_id = (await self.get_supported_wake_words())[0].id
async for _chunk, timestamp in stream: async for _chunk, timestamp in stream:
if timestamp >= 2000: if timestamp >= 2000:
@ -294,7 +296,7 @@ async def test_list_wake_words_unknown_entity(
setup: MockProviderEntity, setup: MockProviderEntity,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
) -> None: ) -> None:
"""Test that the list_wake_words websocket command works.""" """Test that the list_wake_words websocket command handles unknown entity."""
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
await client.send_json( await client.send_json(
{ {
@ -308,3 +310,28 @@ async def test_list_wake_words_unknown_entity(
assert not msg["success"] assert not msg["success"]
assert msg["error"] == {"code": "not_found", "message": "Entity not found"} assert msg["error"] == {"code": "not_found", "message": "Entity not found"}
async def test_list_wake_words_timeout(
hass: HomeAssistant,
setup: MockProviderEntity,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test that the list_wake_words websocket command handles unknown entity."""
client = await hass_ws_client(hass)
with patch.object(
setup, "get_supported_wake_words", partial(asyncio.sleep, 1)
), patch("homeassistant.components.wake_word.TIMEOUT_FETCH_WAKE_WORDS", 0):
await client.send_json(
{
"id": 5,
"type": "wake_word/info",
"entity_id": setup.entity_id,
}
)
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"] == {"code": "timeout", "message": "Timeout fetching wake words"}

View file

@ -6,12 +6,13 @@ from unittest.mock import patch
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from wyoming.asr import Transcript from wyoming.asr import Transcript
from wyoming.info import Info, WakeModel, WakeProgram
from wyoming.wake import Detection from wyoming.wake import Detection
from homeassistant.components import wake_word from homeassistant.components import wake_word
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from . import MockAsyncTcpClient from . import TEST_ATTR, MockAsyncTcpClient
async def test_support(hass: HomeAssistant, init_wyoming_wake_word) -> None: async def test_support(hass: HomeAssistant, init_wyoming_wake_word) -> None:
@ -24,7 +25,7 @@ async def test_support(hass: HomeAssistant, init_wyoming_wake_word) -> None:
) )
assert entity is not None assert entity is not None
assert entity.supported_wake_words == [ assert (await entity.get_supported_wake_words()) == [
wake_word.WakeWord(id="Test Model", name="Test Model") wake_word.WakeWord(id="Test Model", name="Test Model")
] ]
@ -157,3 +158,55 @@ async def test_detect_message_with_wrong_wake_word(
result = await entity.async_process_audio_stream(audio_stream(), "my-wake-word") result = await entity.async_process_audio_stream(audio_stream(), "my-wake-word")
assert result is None assert result is None
async def test_dynamic_wake_word_info(
hass: HomeAssistant, init_wyoming_wake_word
) -> None:
"""Test that supported wake words are loaded dynamically."""
entity = wake_word.async_get_wake_word_detection_entity(
hass, "wake_word.test_wake_word"
)
assert entity is not None
# Original info
assert (await entity.get_supported_wake_words()) == [
wake_word.WakeWord("Test Model", "Test Model")
]
new_info = Info(
wake=[
WakeProgram(
name="dynamic",
description="Dynamic Wake Word",
installed=True,
attribution=TEST_ATTR,
models=[
WakeModel(
name="ww1",
description="Wake Word 1",
installed=True,
attribution=TEST_ATTR,
languages=[],
),
WakeModel(
name="ww2",
description="Wake Word 2",
installed=True,
attribution=TEST_ATTR,
languages=[],
),
],
)
]
)
# Different Wyoming info will be fetched
with patch(
"homeassistant.components.wyoming.wake_word.load_wyoming_info",
return_value=new_info,
):
assert (await entity.get_supported_wake_words()) == [
wake_word.WakeWord("ww1", "Wake Word 1"),
wake_word.WakeWord("ww2", "Wake Word 2"),
]