Change wake word interception to a subscription (#125629)

* Allow stopping intercepting wake words

* Make wake word interception a subscription

* Keep future

* Add test for unsub
This commit is contained in:
Michael Hansen 2024-09-16 07:50:43 -05:00 committed by GitHub
parent 3ba39d5158
commit c63cab336c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 129 additions and 31 deletions

View file

@ -6,6 +6,7 @@ import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent
@ -42,5 +43,19 @@ async def websocket_intercept_wake_word(
)
return
wake_word_phrase = await satellite.async_intercept_wake_word()
connection.send_result(msg["id"], {"wake_word_phrase": wake_word_phrase})
async def intercept_wake_word() -> None:
"""Push an intercepted wake word to websocket."""
try:
wake_word_phrase = await satellite.async_intercept_wake_word()
connection.send_message(
websocket_api.event_message(
msg["id"],
{"wake_word_phrase": wake_word_phrase},
)
)
except HomeAssistantError as err:
connection.send_error(msg["id"], "home_assistant_error", str(err))
task = hass.async_create_task(intercept_wake_word(), "intercept_wake_word")
connection.subscriptions[msg["id"]] = task.cancel
connection.send_message(websocket_api.result_message(msg["id"]))

View file

@ -1,6 +1,9 @@
"""Test WebSocket API."""
import asyncio
from unittest.mock import patch
import pytest
from homeassistant.components.assist_pipeline import PipelineStage
from homeassistant.config_entries import ConfigEntry
@ -28,20 +31,23 @@ async def test_intercept_wake_word(
"entity_id": ENTITY_ID,
}
)
for _ in range(3):
await asyncio.sleep(0)
msg = await ws_client.receive_json()
assert msg["success"]
assert msg["result"] is None
subscription_id = msg["id"]
await entity.async_accept_pipeline_from_satellite(
object(),
object(), # type: ignore[arg-type]
start_stage=PipelineStage.STT,
wake_word_phrase="ok, nabu",
)
response = await ws_client.receive_json()
async with asyncio.timeout(1):
msg = await ws_client.receive_json()
assert response["success"]
assert response["result"] == {"wake_word_phrase": "ok, nabu"}
assert msg["id"] == subscription_id
assert msg["type"] == "event"
assert msg["event"] == {"wake_word_phrase": "ok, nabu"}
async def test_intercept_wake_word_requires_on_device_wake_word(
@ -60,18 +66,23 @@ async def test_intercept_wake_word_requires_on_device_wake_word(
}
)
for _ in range(3):
await asyncio.sleep(0)
async with asyncio.timeout(1):
msg = await ws_client.receive_json()
assert msg["success"]
assert msg["result"] is None
await entity.async_accept_pipeline_from_satellite(
object(),
object(), # type: ignore[arg-type]
# Emulate wake word processing in Home Assistant
start_stage=PipelineStage.WAKE_WORD,
)
response = await ws_client.receive_json()
assert not response["success"]
assert response["error"] == {
async with asyncio.timeout(1):
msg = await ws_client.receive_json()
assert not msg["success"]
assert msg["error"] == {
"code": "home_assistant_error",
"message": "Only on-device wake words currently supported",
}
@ -93,18 +104,23 @@ async def test_intercept_wake_word_requires_wake_word_phrase(
}
)
for _ in range(3):
await asyncio.sleep(0)
async with asyncio.timeout(1):
msg = await ws_client.receive_json()
assert msg["success"]
assert msg["result"] is None
await entity.async_accept_pipeline_from_satellite(
object(),
object(), # type: ignore[arg-type]
start_stage=PipelineStage.STT,
# We are not passing wake word phrase
)
response = await ws_client.receive_json()
assert not response["success"]
assert response["error"] == {
async with asyncio.timeout(1):
msg = await ws_client.receive_json()
assert not msg["success"]
assert msg["error"] == {
"code": "home_assistant_error",
"message": "No wake word phrase provided",
}
@ -128,10 +144,12 @@ async def test_intercept_wake_word_require_admin(
"entity_id": ENTITY_ID,
}
)
response = await ws_client.receive_json()
assert not response["success"]
assert response["error"] == {
async with asyncio.timeout(1):
msg = await ws_client.receive_json()
assert not msg["success"]
assert msg["error"] == {
"code": "unauthorized",
"message": "Unauthorized",
}
@ -152,10 +170,11 @@ async def test_intercept_wake_word_invalid_satellite(
"entity_id": "assist_satellite.invalid",
}
)
response = await ws_client.receive_json()
async with asyncio.timeout(1):
msg = await ws_client.receive_json()
assert not response["success"]
assert response["error"] == {
assert not msg["success"]
assert msg["error"] == {
"code": "not_found",
"message": "Entity not found",
}
@ -167,7 +186,7 @@ async def test_intercept_wake_word_twice(
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test intercepting a wake word requires admin access."""
"""Test intercepting a wake word twice cancels the previous request."""
ws_client = await hass_ws_client(hass)
await ws_client.send_json_auto_id(
@ -177,16 +196,80 @@ async def test_intercept_wake_word_twice(
}
)
async with asyncio.timeout(1):
msg = await ws_client.receive_json()
assert msg["success"]
assert msg["result"] is None
task = hass.async_create_task(ws_client.receive_json())
await ws_client.send_json_auto_id(
{
"type": "assist_satellite/intercept_wake_word",
"entity_id": ENTITY_ID,
}
)
response = await ws_client.receive_json()
assert not response["success"]
assert response["error"] == {
# Should get an error from previous subscription
async with asyncio.timeout(1):
msg = await task
assert not msg["success"]
assert msg["error"] == {
"code": "home_assistant_error",
"message": "Wake word interception already in progress",
}
# Response to second subscription
async with asyncio.timeout(1):
msg = await ws_client.receive_json()
assert msg["success"]
assert msg["result"] is None
async def test_intercept_wake_word_unsubscribe(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test that closing the websocket connection stops interception."""
ws_client = await hass_ws_client(hass)
await ws_client.send_json_auto_id(
{
"type": "assist_satellite/intercept_wake_word",
"entity_id": ENTITY_ID,
}
)
# Wait for interception to start
for _ in range(3):
await asyncio.sleep(0)
async def receive_json():
with pytest.raises(TypeError):
# Raises TypeError when connection is closed
await ws_client.receive_json()
task = hass.async_create_task(receive_json())
# Close connection
await ws_client.close()
await task
with (
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
) as mock_pipeline_from_audio_stream,
):
# Start a pipeline with a wake word
await entity.async_accept_pipeline_from_satellite(
object(),
wake_word_phrase="ok, nabu", # type: ignore[arg-type]
)
# Wake word should not be intercepted
mock_pipeline_from_audio_stream.assert_called_once()