Pass device ID to conversation input (#93867)
This commit is contained in:
parent
a1e9cf1c24
commit
cd330a2740
10 changed files with 31 additions and 10 deletions
|
@ -57,6 +57,7 @@ async def async_pipeline_from_audio_stream(
|
||||||
pipeline_id: str | None = None,
|
pipeline_id: str | None = None,
|
||||||
conversation_id: str | None = None,
|
conversation_id: str | None = None,
|
||||||
tts_audio_output: str | None = None,
|
tts_audio_output: str | None = None,
|
||||||
|
device_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create an audio pipeline from an audio stream.
|
"""Create an audio pipeline from an audio stream.
|
||||||
|
|
||||||
|
@ -64,6 +65,7 @@ async def async_pipeline_from_audio_stream(
|
||||||
"""
|
"""
|
||||||
pipeline_input = PipelineInput(
|
pipeline_input = PipelineInput(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
|
device_id=device_id,
|
||||||
stt_metadata=stt_metadata,
|
stt_metadata=stt_metadata,
|
||||||
stt_stream=stt_stream,
|
stt_stream=stt_stream,
|
||||||
run=PipelineRun(
|
run=PipelineRun(
|
||||||
|
|
|
@ -499,7 +499,7 @@ class PipelineRun:
|
||||||
self.intent_agent = agent_info.id
|
self.intent_agent = agent_info.id
|
||||||
|
|
||||||
async def recognize_intent(
|
async def recognize_intent(
|
||||||
self, intent_input: str, conversation_id: str | None
|
self, intent_input: str, conversation_id: str | None, device_id: str | None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Run intent recognition portion of pipeline. Returns text to speak."""
|
"""Run intent recognition portion of pipeline. Returns text to speak."""
|
||||||
if self.intent_agent is None:
|
if self.intent_agent is None:
|
||||||
|
@ -521,6 +521,7 @@ class PipelineRun:
|
||||||
hass=self.hass,
|
hass=self.hass,
|
||||||
text=intent_input,
|
text=intent_input,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
|
device_id=device_id,
|
||||||
context=self.context,
|
context=self.context,
|
||||||
language=self.pipeline.conversation_language,
|
language=self.pipeline.conversation_language,
|
||||||
agent_id=self.intent_agent,
|
agent_id=self.intent_agent,
|
||||||
|
@ -655,6 +656,8 @@ class PipelineInput:
|
||||||
|
|
||||||
conversation_id: str | None = None
|
conversation_id: str | None = None
|
||||||
|
|
||||||
|
device_id: str | None = None
|
||||||
|
|
||||||
async def execute(self) -> None:
|
async def execute(self) -> None:
|
||||||
"""Run pipeline."""
|
"""Run pipeline."""
|
||||||
self.run.start()
|
self.run.start()
|
||||||
|
@ -678,7 +681,9 @@ class PipelineInput:
|
||||||
if current_stage == PipelineStage.INTENT:
|
if current_stage == PipelineStage.INTENT:
|
||||||
assert intent_input is not None
|
assert intent_input is not None
|
||||||
tts_input = await self.run.recognize_intent(
|
tts_input = await self.run.recognize_intent(
|
||||||
intent_input, self.conversation_id
|
intent_input,
|
||||||
|
self.conversation_id,
|
||||||
|
self.device_id,
|
||||||
)
|
)
|
||||||
current_stage = PipelineStage.TTS
|
current_stage = PipelineStage.TTS
|
||||||
|
|
||||||
|
|
|
@ -362,6 +362,7 @@ async def async_converse(
|
||||||
context: core.Context,
|
context: core.Context,
|
||||||
language: str | None = None,
|
language: str | None = None,
|
||||||
agent_id: str | None = None,
|
agent_id: str | None = None,
|
||||||
|
device_id: str | None = None,
|
||||||
) -> ConversationResult:
|
) -> ConversationResult:
|
||||||
"""Process text and get intent."""
|
"""Process text and get intent."""
|
||||||
agent = await _get_agent_manager(hass).async_get_agent(agent_id)
|
agent = await _get_agent_manager(hass).async_get_agent(agent_id)
|
||||||
|
@ -375,6 +376,7 @@ async def async_converse(
|
||||||
text=text,
|
text=text,
|
||||||
context=context,
|
context=context,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
|
device_id=device_id,
|
||||||
language=language,
|
language=language,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -16,6 +16,7 @@ class ConversationInput:
|
||||||
text: str
|
text: str
|
||||||
context: Context
|
context: Context
|
||||||
conversation_id: str | None
|
conversation_id: str | None
|
||||||
|
device_id: str | None
|
||||||
language: str
|
language: str
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -143,7 +143,7 @@ async def async_setup_entry( # noqa: C901
|
||||||
port = entry.data[CONF_PORT]
|
port = entry.data[CONF_PORT]
|
||||||
password = entry.data[CONF_PASSWORD]
|
password = entry.data[CONF_PASSWORD]
|
||||||
noise_psk = entry.data.get(CONF_NOISE_PSK)
|
noise_psk = entry.data.get(CONF_NOISE_PSK)
|
||||||
device_id: str | None = None
|
device_id: str = None # type: ignore[assignment]
|
||||||
|
|
||||||
zeroconf_instance = await zeroconf.async_get_instance(hass)
|
zeroconf_instance = await zeroconf.async_get_instance(hass)
|
||||||
|
|
||||||
|
@ -316,6 +316,7 @@ async def async_setup_entry( # noqa: C901
|
||||||
|
|
||||||
hass.async_create_background_task(
|
hass.async_create_background_task(
|
||||||
voice_assistant_udp_server.run_pipeline(
|
voice_assistant_udp_server.run_pipeline(
|
||||||
|
device_id=device_id,
|
||||||
conversation_id=conversation_id or None,
|
conversation_id=conversation_id or None,
|
||||||
use_vad=use_vad,
|
use_vad=use_vad,
|
||||||
),
|
),
|
||||||
|
|
|
@ -293,6 +293,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
||||||
|
|
||||||
async def run_pipeline(
|
async def run_pipeline(
|
||||||
self,
|
self,
|
||||||
|
device_id: str,
|
||||||
conversation_id: str | None,
|
conversation_id: str | None,
|
||||||
use_vad: bool = False,
|
use_vad: bool = False,
|
||||||
pipeline_timeout: float = 30.0,
|
pipeline_timeout: float = 30.0,
|
||||||
|
@ -331,6 +332,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
||||||
self.hass, DOMAIN, self.device_info.mac_address
|
self.hass, DOMAIN, self.device_info.mac_address
|
||||||
),
|
),
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
|
device_id=device_id,
|
||||||
tts_audio_output=tts_audio_output,
|
tts_audio_output=tts_audio_output,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -251,6 +251,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||||
self.hass, DOMAIN, self.voip_device.voip_id
|
self.hass, DOMAIN, self.voip_device.voip_id
|
||||||
),
|
),
|
||||||
conversation_id=self._conversation_id,
|
conversation_id=self._conversation_id,
|
||||||
|
device_id=self.voip_device.device_id,
|
||||||
tts_audio_output="raw",
|
tts_audio_output="raw",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1371,6 +1371,7 @@ async def test_non_default_response(hass: HomeAssistant, init_components) -> Non
|
||||||
text="open the front door",
|
text="open the front door",
|
||||||
context=Context(),
|
context=Context(),
|
||||||
conversation_id=None,
|
conversation_id=None,
|
||||||
|
device_id=None,
|
||||||
language=hass.config.language,
|
language=hass.config.language,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -68,7 +68,9 @@ async def test_pipeline_events(
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that the pipeline function is called."""
|
"""Test that the pipeline function is called."""
|
||||||
|
|
||||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
|
||||||
|
assert device_id == "mock-device-id"
|
||||||
|
|
||||||
event_callback = kwargs["event_callback"]
|
event_callback = kwargs["event_callback"]
|
||||||
|
|
||||||
# Fake events
|
# Fake events
|
||||||
|
@ -121,7 +123,9 @@ async def test_pipeline_events(
|
||||||
):
|
):
|
||||||
voice_assistant_udp_server_v1.transport = Mock()
|
voice_assistant_udp_server_v1.transport = Mock()
|
||||||
|
|
||||||
await voice_assistant_udp_server_v1.run_pipeline(conversation_id=None)
|
await voice_assistant_udp_server_v1.run_pipeline(
|
||||||
|
device_id="mock-device-id", conversation_id=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_udp_server(
|
async def test_udp_server(
|
||||||
|
@ -380,7 +384,7 @@ async def test_speech_detection(
|
||||||
voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND))
|
voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND))
|
||||||
|
|
||||||
await voice_assistant_udp_server_v2.run_pipeline(
|
await voice_assistant_udp_server_v2.run_pipeline(
|
||||||
conversation_id=None, use_vad=True, pipeline_timeout=1.0
|
device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -412,7 +416,7 @@ async def test_no_speech(
|
||||||
voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND))
|
voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND))
|
||||||
|
|
||||||
await voice_assistant_udp_server_v2.run_pipeline(
|
await voice_assistant_udp_server_v2.run_pipeline(
|
||||||
conversation_id=None, use_vad=True, pipeline_timeout=1.0
|
device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -452,7 +456,7 @@ async def test_speech_timeout(
|
||||||
voice_assistant_udp_server_v2.queue.put_nowait(bytes([255] * (_ONE_SECOND * 2)))
|
voice_assistant_udp_server_v2.queue.put_nowait(bytes([255] * (_ONE_SECOND * 2)))
|
||||||
|
|
||||||
await voice_assistant_udp_server_v2.run_pipeline(
|
await voice_assistant_udp_server_v2.run_pipeline(
|
||||||
conversation_id=None, use_vad=True, pipeline_timeout=1.0
|
device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -467,7 +471,7 @@ async def test_cancelled(
|
||||||
voice_assistant_udp_server_v2.queue.put_nowait(b"")
|
voice_assistant_udp_server_v2.queue.put_nowait(b"")
|
||||||
|
|
||||||
await voice_assistant_udp_server_v2.run_pipeline(
|
await voice_assistant_udp_server_v2.run_pipeline(
|
||||||
conversation_id=None, use_vad=True, pipeline_timeout=1.0
|
device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
# No events should be sent if cancelled while waiting for speech
|
# No events should be sent if cancelled while waiting for speech
|
||||||
|
|
|
@ -31,7 +31,9 @@ async def test_pipeline(
|
||||||
# Used to test that audio queue is cleared before pipeline starts
|
# Used to test that audio queue is cleared before pipeline starts
|
||||||
bad_chunk = bytes([1, 2, 3, 4])
|
bad_chunk = bytes([1, 2, 3, 4])
|
||||||
|
|
||||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
|
||||||
|
assert device_id == voip_device.device_id
|
||||||
|
|
||||||
stt_stream = kwargs["stt_stream"]
|
stt_stream = kwargs["stt_stream"]
|
||||||
event_callback = kwargs["event_callback"]
|
event_callback = kwargs["event_callback"]
|
||||||
async for _chunk in stt_stream:
|
async for _chunk in stt_stream:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue