Pass device ID to conversation input (#93867)

This commit is contained in:
Paulus Schoutsen 2023-05-31 16:56:12 -04:00 committed by GitHub
parent a1e9cf1c24
commit cd330a2740
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 31 additions and 10 deletions

View file

@ -57,6 +57,7 @@ async def async_pipeline_from_audio_stream(
pipeline_id: str | None = None,
conversation_id: str | None = None,
tts_audio_output: str | None = None,
device_id: str | None = None,
) -> None:
"""Create an audio pipeline from an audio stream.
@ -64,6 +65,7 @@ async def async_pipeline_from_audio_stream(
"""
pipeline_input = PipelineInput(
conversation_id=conversation_id,
device_id=device_id,
stt_metadata=stt_metadata,
stt_stream=stt_stream,
run=PipelineRun(

View file

@ -499,7 +499,7 @@ class PipelineRun:
self.intent_agent = agent_info.id
async def recognize_intent(
self, intent_input: str, conversation_id: str | None
self, intent_input: str, conversation_id: str | None, device_id: str | None
) -> str:
"""Run intent recognition portion of pipeline. Returns text to speak."""
if self.intent_agent is None:
@ -521,6 +521,7 @@ class PipelineRun:
hass=self.hass,
text=intent_input,
conversation_id=conversation_id,
device_id=device_id,
context=self.context,
language=self.pipeline.conversation_language,
agent_id=self.intent_agent,
@ -655,6 +656,8 @@ class PipelineInput:
conversation_id: str | None = None
device_id: str | None = None
async def execute(self) -> None:
"""Run pipeline."""
self.run.start()
@ -678,7 +681,9 @@ class PipelineInput:
if current_stage == PipelineStage.INTENT:
assert intent_input is not None
tts_input = await self.run.recognize_intent(
intent_input, self.conversation_id
intent_input,
self.conversation_id,
self.device_id,
)
current_stage = PipelineStage.TTS

View file

@ -362,6 +362,7 @@ async def async_converse(
context: core.Context,
language: str | None = None,
agent_id: str | None = None,
device_id: str | None = None,
) -> ConversationResult:
"""Process text and get intent."""
agent = await _get_agent_manager(hass).async_get_agent(agent_id)
@ -375,6 +376,7 @@ async def async_converse(
text=text,
context=context,
conversation_id=conversation_id,
device_id=device_id,
language=language,
)
)

View file

@ -16,6 +16,7 @@ class ConversationInput:
text: str
context: Context
conversation_id: str | None
device_id: str | None
language: str

View file

@ -143,7 +143,7 @@ async def async_setup_entry( # noqa: C901
port = entry.data[CONF_PORT]
password = entry.data[CONF_PASSWORD]
noise_psk = entry.data.get(CONF_NOISE_PSK)
device_id: str | None = None
device_id: str = None # type: ignore[assignment]
zeroconf_instance = await zeroconf.async_get_instance(hass)
@ -316,6 +316,7 @@ async def async_setup_entry( # noqa: C901
hass.async_create_background_task(
voice_assistant_udp_server.run_pipeline(
device_id=device_id,
conversation_id=conversation_id or None,
use_vad=use_vad,
),

View file

@ -293,6 +293,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
async def run_pipeline(
self,
device_id: str,
conversation_id: str | None,
use_vad: bool = False,
pipeline_timeout: float = 30.0,
@ -331,6 +332,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
self.hass, DOMAIN, self.device_info.mac_address
),
conversation_id=conversation_id,
device_id=device_id,
tts_audio_output=tts_audio_output,
)

View file

@ -251,6 +251,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
self.hass, DOMAIN, self.voip_device.voip_id
),
conversation_id=self._conversation_id,
device_id=self.voip_device.device_id,
tts_audio_output="raw",
)

View file

@ -1371,6 +1371,7 @@ async def test_non_default_response(hass: HomeAssistant, init_components) -> Non
text="open the front door",
context=Context(),
conversation_id=None,
device_id=None,
language=hass.config.language,
)
)

View file

@ -68,7 +68,9 @@ async def test_pipeline_events(
) -> None:
"""Test that the pipeline function is called."""
async def async_pipeline_from_audio_stream(*args, **kwargs):
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
assert device_id == "mock-device-id"
event_callback = kwargs["event_callback"]
# Fake events
@ -121,7 +123,9 @@ async def test_pipeline_events(
):
voice_assistant_udp_server_v1.transport = Mock()
await voice_assistant_udp_server_v1.run_pipeline(conversation_id=None)
await voice_assistant_udp_server_v1.run_pipeline(
device_id="mock-device-id", conversation_id=None
)
async def test_udp_server(
@ -380,7 +384,7 @@ async def test_speech_detection(
voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND))
await voice_assistant_udp_server_v2.run_pipeline(
conversation_id=None, use_vad=True, pipeline_timeout=1.0
device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
)
@ -412,7 +416,7 @@ async def test_no_speech(
voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND))
await voice_assistant_udp_server_v2.run_pipeline(
conversation_id=None, use_vad=True, pipeline_timeout=1.0
device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
)
@ -452,7 +456,7 @@ async def test_speech_timeout(
voice_assistant_udp_server_v2.queue.put_nowait(bytes([255] * (_ONE_SECOND * 2)))
await voice_assistant_udp_server_v2.run_pipeline(
conversation_id=None, use_vad=True, pipeline_timeout=1.0
device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
)
@ -467,7 +471,7 @@ async def test_cancelled(
voice_assistant_udp_server_v2.queue.put_nowait(b"")
await voice_assistant_udp_server_v2.run_pipeline(
conversation_id=None, use_vad=True, pipeline_timeout=1.0
device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
)
# No events should be sent if cancelled while waiting for speech

View file

@ -31,7 +31,9 @@ async def test_pipeline(
# Used to test that audio queue is cleared before pipeline starts
bad_chunk = bytes([1, 2, 3, 4])
async def async_pipeline_from_audio_stream(*args, **kwargs):
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
assert device_id == voip_device.device_id
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream: