Change wake word interception to a subscription (#125629)
* Allow stopping intercepting wake words * Make wake word interception a subscription * Keep future * Add test for unsub
This commit is contained in:
parent
3ba39d5158
commit
c63cab336c
2 changed files with 129 additions and 31 deletions
homeassistant/components/assist_satellite
tests/components/assist_satellite
|
@ -6,6 +6,7 @@ import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import websocket_api
|
from homeassistant.components import websocket_api
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv
|
||||||
from homeassistant.helpers.entity_component import EntityComponent
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
|
|
||||||
|
@ -42,5 +43,19 @@ async def websocket_intercept_wake_word(
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
async def intercept_wake_word() -> None:
|
||||||
|
"""Push an intercepted wake word to websocket."""
|
||||||
|
try:
|
||||||
wake_word_phrase = await satellite.async_intercept_wake_word()
|
wake_word_phrase = await satellite.async_intercept_wake_word()
|
||||||
connection.send_result(msg["id"], {"wake_word_phrase": wake_word_phrase})
|
connection.send_message(
|
||||||
|
websocket_api.event_message(
|
||||||
|
msg["id"],
|
||||||
|
{"wake_word_phrase": wake_word_phrase},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except HomeAssistantError as err:
|
||||||
|
connection.send_error(msg["id"], "home_assistant_error", str(err))
|
||||||
|
|
||||||
|
task = hass.async_create_task(intercept_wake_word(), "intercept_wake_word")
|
||||||
|
connection.subscriptions[msg["id"]] = task.cancel
|
||||||
|
connection.send_message(websocket_api.result_message(msg["id"]))
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
"""Test WebSocket API."""
|
"""Test WebSocket API."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components.assist_pipeline import PipelineStage
|
from homeassistant.components.assist_pipeline import PipelineStage
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
@ -28,20 +31,23 @@ async def test_intercept_wake_word(
|
||||||
"entity_id": ENTITY_ID,
|
"entity_id": ENTITY_ID,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
msg = await ws_client.receive_json()
|
||||||
for _ in range(3):
|
assert msg["success"]
|
||||||
await asyncio.sleep(0)
|
assert msg["result"] is None
|
||||||
|
subscription_id = msg["id"]
|
||||||
|
|
||||||
await entity.async_accept_pipeline_from_satellite(
|
await entity.async_accept_pipeline_from_satellite(
|
||||||
object(),
|
object(), # type: ignore[arg-type]
|
||||||
start_stage=PipelineStage.STT,
|
start_stage=PipelineStage.STT,
|
||||||
wake_word_phrase="ok, nabu",
|
wake_word_phrase="ok, nabu",
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await ws_client.receive_json()
|
async with asyncio.timeout(1):
|
||||||
|
msg = await ws_client.receive_json()
|
||||||
|
|
||||||
assert response["success"]
|
assert msg["id"] == subscription_id
|
||||||
assert response["result"] == {"wake_word_phrase": "ok, nabu"}
|
assert msg["type"] == "event"
|
||||||
|
assert msg["event"] == {"wake_word_phrase": "ok, nabu"}
|
||||||
|
|
||||||
|
|
||||||
async def test_intercept_wake_word_requires_on_device_wake_word(
|
async def test_intercept_wake_word_requires_on_device_wake_word(
|
||||||
|
@ -60,18 +66,23 @@ async def test_intercept_wake_word_requires_on_device_wake_word(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
for _ in range(3):
|
async with asyncio.timeout(1):
|
||||||
await asyncio.sleep(0)
|
msg = await ws_client.receive_json()
|
||||||
|
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] is None
|
||||||
|
|
||||||
await entity.async_accept_pipeline_from_satellite(
|
await entity.async_accept_pipeline_from_satellite(
|
||||||
object(),
|
object(), # type: ignore[arg-type]
|
||||||
# Emulate wake word processing in Home Assistant
|
# Emulate wake word processing in Home Assistant
|
||||||
start_stage=PipelineStage.WAKE_WORD,
|
start_stage=PipelineStage.WAKE_WORD,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await ws_client.receive_json()
|
async with asyncio.timeout(1):
|
||||||
assert not response["success"]
|
msg = await ws_client.receive_json()
|
||||||
assert response["error"] == {
|
|
||||||
|
assert not msg["success"]
|
||||||
|
assert msg["error"] == {
|
||||||
"code": "home_assistant_error",
|
"code": "home_assistant_error",
|
||||||
"message": "Only on-device wake words currently supported",
|
"message": "Only on-device wake words currently supported",
|
||||||
}
|
}
|
||||||
|
@ -93,18 +104,23 @@ async def test_intercept_wake_word_requires_wake_word_phrase(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
for _ in range(3):
|
async with asyncio.timeout(1):
|
||||||
await asyncio.sleep(0)
|
msg = await ws_client.receive_json()
|
||||||
|
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] is None
|
||||||
|
|
||||||
await entity.async_accept_pipeline_from_satellite(
|
await entity.async_accept_pipeline_from_satellite(
|
||||||
object(),
|
object(), # type: ignore[arg-type]
|
||||||
start_stage=PipelineStage.STT,
|
start_stage=PipelineStage.STT,
|
||||||
# We are not passing wake word phrase
|
# We are not passing wake word phrase
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await ws_client.receive_json()
|
async with asyncio.timeout(1):
|
||||||
assert not response["success"]
|
msg = await ws_client.receive_json()
|
||||||
assert response["error"] == {
|
|
||||||
|
assert not msg["success"]
|
||||||
|
assert msg["error"] == {
|
||||||
"code": "home_assistant_error",
|
"code": "home_assistant_error",
|
||||||
"message": "No wake word phrase provided",
|
"message": "No wake word phrase provided",
|
||||||
}
|
}
|
||||||
|
@ -128,10 +144,12 @@ async def test_intercept_wake_word_require_admin(
|
||||||
"entity_id": ENTITY_ID,
|
"entity_id": ENTITY_ID,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
response = await ws_client.receive_json()
|
|
||||||
|
|
||||||
assert not response["success"]
|
async with asyncio.timeout(1):
|
||||||
assert response["error"] == {
|
msg = await ws_client.receive_json()
|
||||||
|
|
||||||
|
assert not msg["success"]
|
||||||
|
assert msg["error"] == {
|
||||||
"code": "unauthorized",
|
"code": "unauthorized",
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
}
|
}
|
||||||
|
@ -152,10 +170,11 @@ async def test_intercept_wake_word_invalid_satellite(
|
||||||
"entity_id": "assist_satellite.invalid",
|
"entity_id": "assist_satellite.invalid",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
response = await ws_client.receive_json()
|
async with asyncio.timeout(1):
|
||||||
|
msg = await ws_client.receive_json()
|
||||||
|
|
||||||
assert not response["success"]
|
assert not msg["success"]
|
||||||
assert response["error"] == {
|
assert msg["error"] == {
|
||||||
"code": "not_found",
|
"code": "not_found",
|
||||||
"message": "Entity not found",
|
"message": "Entity not found",
|
||||||
}
|
}
|
||||||
|
@ -167,7 +186,7 @@ async def test_intercept_wake_word_twice(
|
||||||
entity: MockAssistSatellite,
|
entity: MockAssistSatellite,
|
||||||
hass_ws_client: WebSocketGenerator,
|
hass_ws_client: WebSocketGenerator,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test intercepting a wake word requires admin access."""
|
"""Test intercepting a wake word twice cancels the previous request."""
|
||||||
ws_client = await hass_ws_client(hass)
|
ws_client = await hass_ws_client(hass)
|
||||||
|
|
||||||
await ws_client.send_json_auto_id(
|
await ws_client.send_json_auto_id(
|
||||||
|
@ -177,16 +196,80 @@ async def test_intercept_wake_word_twice(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
msg = await ws_client.receive_json()
|
||||||
|
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] is None
|
||||||
|
|
||||||
|
task = hass.async_create_task(ws_client.receive_json())
|
||||||
|
|
||||||
await ws_client.send_json_auto_id(
|
await ws_client.send_json_auto_id(
|
||||||
{
|
{
|
||||||
"type": "assist_satellite/intercept_wake_word",
|
"type": "assist_satellite/intercept_wake_word",
|
||||||
"entity_id": ENTITY_ID,
|
"entity_id": ENTITY_ID,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
response = await ws_client.receive_json()
|
|
||||||
|
|
||||||
assert not response["success"]
|
# Should get an error from previous subscription
|
||||||
assert response["error"] == {
|
async with asyncio.timeout(1):
|
||||||
|
msg = await task
|
||||||
|
|
||||||
|
assert not msg["success"]
|
||||||
|
assert msg["error"] == {
|
||||||
"code": "home_assistant_error",
|
"code": "home_assistant_error",
|
||||||
"message": "Wake word interception already in progress",
|
"message": "Wake word interception already in progress",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Response to second subscription
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
msg = await ws_client.receive_json()
|
||||||
|
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_intercept_wake_word_unsubscribe(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test that closing the websocket connection stops interception."""
|
||||||
|
ws_client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await ws_client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_satellite/intercept_wake_word",
|
||||||
|
"entity_id": ENTITY_ID,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait for interception to start
|
||||||
|
for _ in range(3):
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
async def receive_json():
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
# Raises TypeError when connection is closed
|
||||||
|
await ws_client.receive_json()
|
||||||
|
|
||||||
|
task = hass.async_create_task(receive_json())
|
||||||
|
|
||||||
|
# Close connection
|
||||||
|
await ws_client.close()
|
||||||
|
await task
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
|
) as mock_pipeline_from_audio_stream,
|
||||||
|
):
|
||||||
|
# Start a pipeline with a wake word
|
||||||
|
await entity.async_accept_pipeline_from_satellite(
|
||||||
|
object(),
|
||||||
|
wake_word_phrase="ok, nabu", # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wake word should not be intercepted
|
||||||
|
mock_pipeline_from_audio_stream.assert_called_once()
|
||||||
|
|
Loading…
Add table
Reference in a new issue