Dynamic wake word loading for Wyoming (#101827)
* Change supported_wake_words property to async method * Add test * Add timeout + test --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
6c4ac71218
commit
257686fcfe
5 changed files with 120 additions and 17 deletions
|
@ -2,6 +2,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
import logging
|
||||
from typing import final
|
||||
|
@ -34,6 +35,8 @@ _LOGGER = logging.getLogger(__name__)
|
|||
|
||||
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
|
||||
|
||||
TIMEOUT_FETCH_WAKE_WORDS = 10
|
||||
|
||||
|
||||
@callback
|
||||
def async_default_entity(hass: HomeAssistant) -> str | None:
|
||||
|
@ -86,9 +89,8 @@ class WakeWordDetectionEntity(RestoreEntity):
|
|||
"""Return the state of the entity."""
|
||||
return self.__last_detected
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_wake_words(self) -> list[WakeWord]:
|
||||
async def get_supported_wake_words(self) -> list[WakeWord]:
|
||||
"""Return a list of supported wake words."""
|
||||
|
||||
@abstractmethod
|
||||
|
@ -133,8 +135,9 @@ class WakeWordDetectionEntity(RestoreEntity):
|
|||
vol.Required("entity_id"): cv.entity_domain(DOMAIN),
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
@callback
|
||||
def websocket_entity_info(
|
||||
async def websocket_entity_info(
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||
) -> None:
|
||||
"""Get info about wake word entity."""
|
||||
|
@ -147,7 +150,16 @@ def websocket_entity_info(
|
|||
)
|
||||
return
|
||||
|
||||
try:
|
||||
async with asyncio.timeout(TIMEOUT_FETCH_WAKE_WORDS):
|
||||
wake_words = await entity.get_supported_wake_words()
|
||||
except asyncio.TimeoutError:
|
||||
connection.send_error(
|
||||
msg["id"], websocket_api.const.ERR_TIMEOUT, "Timeout fetching wake words"
|
||||
)
|
||||
return
|
||||
|
||||
connection.send_result(
|
||||
msg["id"],
|
||||
{"wake_words": entity.supported_wake_words},
|
||||
{"wake_words": wake_words},
|
||||
)
|
||||
|
|
|
@ -13,7 +13,7 @@ from homeassistant.core import HomeAssistant
|
|||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from .const import DOMAIN
|
||||
from .data import WyomingService
|
||||
from .data import WyomingService, load_wyoming_info
|
||||
from .error import WyomingError
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
@ -28,7 +28,7 @@ async def async_setup_entry(
|
|||
service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
|
||||
async_add_entities(
|
||||
[
|
||||
WyomingWakeWordProvider(config_entry, service),
|
||||
WyomingWakeWordProvider(hass, config_entry, service),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -38,10 +38,12 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
service: WyomingService,
|
||||
) -> None:
|
||||
"""Set up provider."""
|
||||
self.hass = hass
|
||||
self.service = service
|
||||
wake_service = service.info.wake[0]
|
||||
|
||||
|
@ -52,9 +54,19 @@ class WyomingWakeWordProvider(wake_word.WakeWordDetectionEntity):
|
|||
self._attr_name = wake_service.name
|
||||
self._attr_unique_id = f"{config_entry.entry_id}-wake_word"
|
||||
|
||||
@property
|
||||
def supported_wake_words(self) -> list[wake_word.WakeWord]:
|
||||
async def get_supported_wake_words(self) -> list[wake_word.WakeWord]:
|
||||
"""Return a list of supported wake words."""
|
||||
info = await load_wyoming_info(
|
||||
self.service.host, self.service.port, retries=0, timeout=1
|
||||
)
|
||||
|
||||
if info is not None:
|
||||
wake_service = info.wake[0]
|
||||
self._supported_wake_words = [
|
||||
wake_word.WakeWord(id=ww.name, name=ww.description or ww.name)
|
||||
for ww in wake_service.models
|
||||
]
|
||||
|
||||
return self._supported_wake_words
|
||||
|
||||
async def _async_process_audio_stream(
|
||||
|
|
|
@ -181,8 +181,7 @@ class MockWakeWordEntity(wake_word.WakeWordDetectionEntity):
|
|||
url_path = "wake_word.test"
|
||||
_attr_name = "test"
|
||||
|
||||
@property
|
||||
def supported_wake_words(self) -> list[wake_word.WakeWord]:
|
||||
async def get_supported_wake_words(self) -> list[wake_word.WakeWord]:
|
||||
"""Return a list of supported wake words."""
|
||||
return [wake_word.WakeWord(id="test_ww", name="Test Wake Word")]
|
||||
|
||||
|
@ -191,7 +190,7 @@ class MockWakeWordEntity(wake_word.WakeWordDetectionEntity):
|
|||
) -> wake_word.DetectionResult | None:
|
||||
"""Try to detect wake word(s) in an audio stream with timestamps."""
|
||||
if wake_word_id is None:
|
||||
wake_word_id = self.supported_wake_words[0].id
|
||||
wake_word_id = (await self.get_supported_wake_words())[0].id
|
||||
async for chunk, timestamp in stream:
|
||||
if chunk.startswith(b"wake word"):
|
||||
return wake_word.DetectionResult(
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
"""Test wake_word component setup."""
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable, Generator
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from freezegun import freeze_time
|
||||
import pytest
|
||||
|
@ -37,8 +40,7 @@ class MockProviderEntity(wake_word.WakeWordDetectionEntity):
|
|||
url_path = "wake_word.test"
|
||||
_attr_name = "test"
|
||||
|
||||
@property
|
||||
def supported_wake_words(self) -> list[wake_word.WakeWord]:
|
||||
async def get_supported_wake_words(self) -> list[wake_word.WakeWord]:
|
||||
"""Return a list of supported wake words."""
|
||||
return [
|
||||
wake_word.WakeWord(id="test_ww", name="Test Wake Word"),
|
||||
|
@ -50,7 +52,7 @@ class MockProviderEntity(wake_word.WakeWordDetectionEntity):
|
|||
) -> wake_word.DetectionResult | None:
|
||||
"""Try to detect wake word(s) in an audio stream with timestamps."""
|
||||
if wake_word_id is None:
|
||||
wake_word_id = self.supported_wake_words[0].id
|
||||
wake_word_id = (await self.get_supported_wake_words())[0].id
|
||||
|
||||
async for _chunk, timestamp in stream:
|
||||
if timestamp >= 2000:
|
||||
|
@ -294,7 +296,7 @@ async def test_list_wake_words_unknown_entity(
|
|||
setup: MockProviderEntity,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test that the list_wake_words websocket command works."""
|
||||
"""Test that the list_wake_words websocket command handles unknown entity."""
|
||||
client = await hass_ws_client(hass)
|
||||
await client.send_json(
|
||||
{
|
||||
|
@ -308,3 +310,28 @@ async def test_list_wake_words_unknown_entity(
|
|||
|
||||
assert not msg["success"]
|
||||
assert msg["error"] == {"code": "not_found", "message": "Entity not found"}
|
||||
|
||||
|
||||
async def test_list_wake_words_timeout(
|
||||
hass: HomeAssistant,
|
||||
setup: MockProviderEntity,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test that the list_wake_words websocket command handles unknown entity."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
with patch.object(
|
||||
setup, "get_supported_wake_words", partial(asyncio.sleep, 1)
|
||||
), patch("homeassistant.components.wake_word.TIMEOUT_FETCH_WAKE_WORDS", 0):
|
||||
await client.send_json(
|
||||
{
|
||||
"id": 5,
|
||||
"type": "wake_word/info",
|
||||
"entity_id": setup.entity_id,
|
||||
}
|
||||
)
|
||||
|
||||
msg = await client.receive_json()
|
||||
|
||||
assert not msg["success"]
|
||||
assert msg["error"] == {"code": "timeout", "message": "Timeout fetching wake words"}
|
||||
|
|
|
@ -6,12 +6,13 @@ from unittest.mock import patch
|
|||
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
from wyoming.asr import Transcript
|
||||
from wyoming.info import Info, WakeModel, WakeProgram
|
||||
from wyoming.wake import Detection
|
||||
|
||||
from homeassistant.components import wake_word
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from . import MockAsyncTcpClient
|
||||
from . import TEST_ATTR, MockAsyncTcpClient
|
||||
|
||||
|
||||
async def test_support(hass: HomeAssistant, init_wyoming_wake_word) -> None:
|
||||
|
@ -24,7 +25,7 @@ async def test_support(hass: HomeAssistant, init_wyoming_wake_word) -> None:
|
|||
)
|
||||
assert entity is not None
|
||||
|
||||
assert entity.supported_wake_words == [
|
||||
assert (await entity.get_supported_wake_words()) == [
|
||||
wake_word.WakeWord(id="Test Model", name="Test Model")
|
||||
]
|
||||
|
||||
|
@ -157,3 +158,55 @@ async def test_detect_message_with_wrong_wake_word(
|
|||
result = await entity.async_process_audio_stream(audio_stream(), "my-wake-word")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
async def test_dynamic_wake_word_info(
|
||||
hass: HomeAssistant, init_wyoming_wake_word
|
||||
) -> None:
|
||||
"""Test that supported wake words are loaded dynamically."""
|
||||
entity = wake_word.async_get_wake_word_detection_entity(
|
||||
hass, "wake_word.test_wake_word"
|
||||
)
|
||||
assert entity is not None
|
||||
|
||||
# Original info
|
||||
assert (await entity.get_supported_wake_words()) == [
|
||||
wake_word.WakeWord("Test Model", "Test Model")
|
||||
]
|
||||
|
||||
new_info = Info(
|
||||
wake=[
|
||||
WakeProgram(
|
||||
name="dynamic",
|
||||
description="Dynamic Wake Word",
|
||||
installed=True,
|
||||
attribution=TEST_ATTR,
|
||||
models=[
|
||||
WakeModel(
|
||||
name="ww1",
|
||||
description="Wake Word 1",
|
||||
installed=True,
|
||||
attribution=TEST_ATTR,
|
||||
languages=[],
|
||||
),
|
||||
WakeModel(
|
||||
name="ww2",
|
||||
description="Wake Word 2",
|
||||
installed=True,
|
||||
attribution=TEST_ATTR,
|
||||
languages=[],
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Different Wyoming info will be fetched
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.wake_word.load_wyoming_info",
|
||||
return_value=new_info,
|
||||
):
|
||||
assert (await entity.get_supported_wake_words()) == [
|
||||
wake_word.WakeWord("ww1", "Wake Word 1"),
|
||||
wake_word.WakeWord("ww2", "Wake Word 2"),
|
||||
]
|
||||
|
|
Loading…
Add table
Reference in a new issue