Change wake word interception to a subscription (#125629)
* Allow stopping intercepting wake words * Make wake word interception a subscription * Keep future * Add test for unsub
This commit is contained in:
parent
3ba39d5158
commit
c63cab336c
2 changed files with 129 additions and 31 deletions
|
@ -6,6 +6,7 @@ import voluptuous as vol
|
|||
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
|
||||
|
@ -42,5 +43,19 @@ async def websocket_intercept_wake_word(
|
|||
)
|
||||
return
|
||||
|
||||
wake_word_phrase = await satellite.async_intercept_wake_word()
|
||||
connection.send_result(msg["id"], {"wake_word_phrase": wake_word_phrase})
|
||||
async def intercept_wake_word() -> None:
|
||||
"""Push an intercepted wake word to websocket."""
|
||||
try:
|
||||
wake_word_phrase = await satellite.async_intercept_wake_word()
|
||||
connection.send_message(
|
||||
websocket_api.event_message(
|
||||
msg["id"],
|
||||
{"wake_word_phrase": wake_word_phrase},
|
||||
)
|
||||
)
|
||||
except HomeAssistantError as err:
|
||||
connection.send_error(msg["id"], "home_assistant_error", str(err))
|
||||
|
||||
task = hass.async_create_task(intercept_wake_word(), "intercept_wake_word")
|
||||
connection.subscriptions[msg["id"]] = task.cancel
|
||||
connection.send_message(websocket_api.result_message(msg["id"]))
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
"""Test WebSocket API."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.assist_pipeline import PipelineStage
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
|
@ -28,20 +31,23 @@ async def test_intercept_wake_word(
|
|||
"entity_id": ENTITY_ID,
|
||||
}
|
||||
)
|
||||
|
||||
for _ in range(3):
|
||||
await asyncio.sleep(0)
|
||||
msg = await ws_client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] is None
|
||||
subscription_id = msg["id"]
|
||||
|
||||
await entity.async_accept_pipeline_from_satellite(
|
||||
object(),
|
||||
object(), # type: ignore[arg-type]
|
||||
start_stage=PipelineStage.STT,
|
||||
wake_word_phrase="ok, nabu",
|
||||
)
|
||||
|
||||
response = await ws_client.receive_json()
|
||||
async with asyncio.timeout(1):
|
||||
msg = await ws_client.receive_json()
|
||||
|
||||
assert response["success"]
|
||||
assert response["result"] == {"wake_word_phrase": "ok, nabu"}
|
||||
assert msg["id"] == subscription_id
|
||||
assert msg["type"] == "event"
|
||||
assert msg["event"] == {"wake_word_phrase": "ok, nabu"}
|
||||
|
||||
|
||||
async def test_intercept_wake_word_requires_on_device_wake_word(
|
||||
|
@ -60,18 +66,23 @@ async def test_intercept_wake_word_requires_on_device_wake_word(
|
|||
}
|
||||
)
|
||||
|
||||
for _ in range(3):
|
||||
await asyncio.sleep(0)
|
||||
async with asyncio.timeout(1):
|
||||
msg = await ws_client.receive_json()
|
||||
|
||||
assert msg["success"]
|
||||
assert msg["result"] is None
|
||||
|
||||
await entity.async_accept_pipeline_from_satellite(
|
||||
object(),
|
||||
object(), # type: ignore[arg-type]
|
||||
# Emulate wake word processing in Home Assistant
|
||||
start_stage=PipelineStage.WAKE_WORD,
|
||||
)
|
||||
|
||||
response = await ws_client.receive_json()
|
||||
assert not response["success"]
|
||||
assert response["error"] == {
|
||||
async with asyncio.timeout(1):
|
||||
msg = await ws_client.receive_json()
|
||||
|
||||
assert not msg["success"]
|
||||
assert msg["error"] == {
|
||||
"code": "home_assistant_error",
|
||||
"message": "Only on-device wake words currently supported",
|
||||
}
|
||||
|
@ -93,18 +104,23 @@ async def test_intercept_wake_word_requires_wake_word_phrase(
|
|||
}
|
||||
)
|
||||
|
||||
for _ in range(3):
|
||||
await asyncio.sleep(0)
|
||||
async with asyncio.timeout(1):
|
||||
msg = await ws_client.receive_json()
|
||||
|
||||
assert msg["success"]
|
||||
assert msg["result"] is None
|
||||
|
||||
await entity.async_accept_pipeline_from_satellite(
|
||||
object(),
|
||||
object(), # type: ignore[arg-type]
|
||||
start_stage=PipelineStage.STT,
|
||||
# We are not passing wake word phrase
|
||||
)
|
||||
|
||||
response = await ws_client.receive_json()
|
||||
assert not response["success"]
|
||||
assert response["error"] == {
|
||||
async with asyncio.timeout(1):
|
||||
msg = await ws_client.receive_json()
|
||||
|
||||
assert not msg["success"]
|
||||
assert msg["error"] == {
|
||||
"code": "home_assistant_error",
|
||||
"message": "No wake word phrase provided",
|
||||
}
|
||||
|
@ -128,10 +144,12 @@ async def test_intercept_wake_word_require_admin(
|
|||
"entity_id": ENTITY_ID,
|
||||
}
|
||||
)
|
||||
response = await ws_client.receive_json()
|
||||
|
||||
assert not response["success"]
|
||||
assert response["error"] == {
|
||||
async with asyncio.timeout(1):
|
||||
msg = await ws_client.receive_json()
|
||||
|
||||
assert not msg["success"]
|
||||
assert msg["error"] == {
|
||||
"code": "unauthorized",
|
||||
"message": "Unauthorized",
|
||||
}
|
||||
|
@ -152,10 +170,11 @@ async def test_intercept_wake_word_invalid_satellite(
|
|||
"entity_id": "assist_satellite.invalid",
|
||||
}
|
||||
)
|
||||
response = await ws_client.receive_json()
|
||||
async with asyncio.timeout(1):
|
||||
msg = await ws_client.receive_json()
|
||||
|
||||
assert not response["success"]
|
||||
assert response["error"] == {
|
||||
assert not msg["success"]
|
||||
assert msg["error"] == {
|
||||
"code": "not_found",
|
||||
"message": "Entity not found",
|
||||
}
|
||||
|
@ -167,7 +186,7 @@ async def test_intercept_wake_word_twice(
|
|||
entity: MockAssistSatellite,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test intercepting a wake word requires admin access."""
|
||||
"""Test intercepting a wake word twice cancels the previous request."""
|
||||
ws_client = await hass_ws_client(hass)
|
||||
|
||||
await ws_client.send_json_auto_id(
|
||||
|
@ -177,16 +196,80 @@ async def test_intercept_wake_word_twice(
|
|||
}
|
||||
)
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
msg = await ws_client.receive_json()
|
||||
|
||||
assert msg["success"]
|
||||
assert msg["result"] is None
|
||||
|
||||
task = hass.async_create_task(ws_client.receive_json())
|
||||
|
||||
await ws_client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_satellite/intercept_wake_word",
|
||||
"entity_id": ENTITY_ID,
|
||||
}
|
||||
)
|
||||
response = await ws_client.receive_json()
|
||||
|
||||
assert not response["success"]
|
||||
assert response["error"] == {
|
||||
# Should get an error from previous subscription
|
||||
async with asyncio.timeout(1):
|
||||
msg = await task
|
||||
|
||||
assert not msg["success"]
|
||||
assert msg["error"] == {
|
||||
"code": "home_assistant_error",
|
||||
"message": "Wake word interception already in progress",
|
||||
}
|
||||
|
||||
# Response to second subscription
|
||||
async with asyncio.timeout(1):
|
||||
msg = await ws_client.receive_json()
|
||||
|
||||
assert msg["success"]
|
||||
assert msg["result"] is None
|
||||
|
||||
|
||||
async def test_intercept_wake_word_unsubscribe(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test that closing the websocket connection stops interception."""
|
||||
ws_client = await hass_ws_client(hass)
|
||||
|
||||
await ws_client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_satellite/intercept_wake_word",
|
||||
"entity_id": ENTITY_ID,
|
||||
}
|
||||
)
|
||||
|
||||
# Wait for interception to start
|
||||
for _ in range(3):
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def receive_json():
|
||||
with pytest.raises(TypeError):
|
||||
# Raises TypeError when connection is closed
|
||||
await ws_client.receive_json()
|
||||
|
||||
task = hass.async_create_task(receive_json())
|
||||
|
||||
# Close connection
|
||||
await ws_client.close()
|
||||
await task
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
) as mock_pipeline_from_audio_stream,
|
||||
):
|
||||
# Start a pipeline with a wake word
|
||||
await entity.async_accept_pipeline_from_satellite(
|
||||
object(),
|
||||
wake_word_phrase="ok, nabu", # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Wake word should not be intercepted
|
||||
mock_pipeline_from_audio_stream.assert_called_once()
|
||||
|
|
Loading…
Add table
Reference in a new issue