Pass device ID to conversation input (#93867)

This commit is contained in:
Paulus Schoutsen 2023-05-31 16:56:12 -04:00 committed by GitHub
parent a1e9cf1c24
commit cd330a2740
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 31 additions and 10 deletions

View file

@ -57,6 +57,7 @@ async def async_pipeline_from_audio_stream(
pipeline_id: str | None = None, pipeline_id: str | None = None,
conversation_id: str | None = None, conversation_id: str | None = None,
tts_audio_output: str | None = None, tts_audio_output: str | None = None,
device_id: str | None = None,
) -> None: ) -> None:
"""Create an audio pipeline from an audio stream. """Create an audio pipeline from an audio stream.
@ -64,6 +65,7 @@ async def async_pipeline_from_audio_stream(
""" """
pipeline_input = PipelineInput( pipeline_input = PipelineInput(
conversation_id=conversation_id, conversation_id=conversation_id,
device_id=device_id,
stt_metadata=stt_metadata, stt_metadata=stt_metadata,
stt_stream=stt_stream, stt_stream=stt_stream,
run=PipelineRun( run=PipelineRun(

View file

@ -499,7 +499,7 @@ class PipelineRun:
self.intent_agent = agent_info.id self.intent_agent = agent_info.id
async def recognize_intent( async def recognize_intent(
self, intent_input: str, conversation_id: str | None self, intent_input: str, conversation_id: str | None, device_id: str | None
) -> str: ) -> str:
"""Run intent recognition portion of pipeline. Returns text to speak.""" """Run intent recognition portion of pipeline. Returns text to speak."""
if self.intent_agent is None: if self.intent_agent is None:
@ -521,6 +521,7 @@ class PipelineRun:
hass=self.hass, hass=self.hass,
text=intent_input, text=intent_input,
conversation_id=conversation_id, conversation_id=conversation_id,
device_id=device_id,
context=self.context, context=self.context,
language=self.pipeline.conversation_language, language=self.pipeline.conversation_language,
agent_id=self.intent_agent, agent_id=self.intent_agent,
@ -655,6 +656,8 @@ class PipelineInput:
conversation_id: str | None = None conversation_id: str | None = None
device_id: str | None = None
async def execute(self) -> None: async def execute(self) -> None:
"""Run pipeline.""" """Run pipeline."""
self.run.start() self.run.start()
@ -678,7 +681,9 @@ class PipelineInput:
if current_stage == PipelineStage.INTENT: if current_stage == PipelineStage.INTENT:
assert intent_input is not None assert intent_input is not None
tts_input = await self.run.recognize_intent( tts_input = await self.run.recognize_intent(
intent_input, self.conversation_id intent_input,
self.conversation_id,
self.device_id,
) )
current_stage = PipelineStage.TTS current_stage = PipelineStage.TTS

View file

@ -362,6 +362,7 @@ async def async_converse(
context: core.Context, context: core.Context,
language: str | None = None, language: str | None = None,
agent_id: str | None = None, agent_id: str | None = None,
device_id: str | None = None,
) -> ConversationResult: ) -> ConversationResult:
"""Process text and get intent.""" """Process text and get intent."""
agent = await _get_agent_manager(hass).async_get_agent(agent_id) agent = await _get_agent_manager(hass).async_get_agent(agent_id)
@ -375,6 +376,7 @@ async def async_converse(
text=text, text=text,
context=context, context=context,
conversation_id=conversation_id, conversation_id=conversation_id,
device_id=device_id,
language=language, language=language,
) )
) )

View file

@ -16,6 +16,7 @@ class ConversationInput:
text: str text: str
context: Context context: Context
conversation_id: str | None conversation_id: str | None
device_id: str | None
language: str language: str

View file

@ -143,7 +143,7 @@ async def async_setup_entry( # noqa: C901
port = entry.data[CONF_PORT] port = entry.data[CONF_PORT]
password = entry.data[CONF_PASSWORD] password = entry.data[CONF_PASSWORD]
noise_psk = entry.data.get(CONF_NOISE_PSK) noise_psk = entry.data.get(CONF_NOISE_PSK)
device_id: str | None = None device_id: str = None # type: ignore[assignment]
zeroconf_instance = await zeroconf.async_get_instance(hass) zeroconf_instance = await zeroconf.async_get_instance(hass)
@ -316,6 +316,7 @@ async def async_setup_entry( # noqa: C901
hass.async_create_background_task( hass.async_create_background_task(
voice_assistant_udp_server.run_pipeline( voice_assistant_udp_server.run_pipeline(
device_id=device_id,
conversation_id=conversation_id or None, conversation_id=conversation_id or None,
use_vad=use_vad, use_vad=use_vad,
), ),

View file

@ -293,6 +293,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
async def run_pipeline( async def run_pipeline(
self, self,
device_id: str,
conversation_id: str | None, conversation_id: str | None,
use_vad: bool = False, use_vad: bool = False,
pipeline_timeout: float = 30.0, pipeline_timeout: float = 30.0,
@ -331,6 +332,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
self.hass, DOMAIN, self.device_info.mac_address self.hass, DOMAIN, self.device_info.mac_address
), ),
conversation_id=conversation_id, conversation_id=conversation_id,
device_id=device_id,
tts_audio_output=tts_audio_output, tts_audio_output=tts_audio_output,
) )

View file

@ -251,6 +251,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
self.hass, DOMAIN, self.voip_device.voip_id self.hass, DOMAIN, self.voip_device.voip_id
), ),
conversation_id=self._conversation_id, conversation_id=self._conversation_id,
device_id=self.voip_device.device_id,
tts_audio_output="raw", tts_audio_output="raw",
) )

View file

@ -1371,6 +1371,7 @@ async def test_non_default_response(hass: HomeAssistant, init_components) -> Non
text="open the front door", text="open the front door",
context=Context(), context=Context(),
conversation_id=None, conversation_id=None,
device_id=None,
language=hass.config.language, language=hass.config.language,
) )
) )

View file

@ -68,7 +68,9 @@ async def test_pipeline_events(
) -> None: ) -> None:
"""Test that the pipeline function is called.""" """Test that the pipeline function is called."""
async def async_pipeline_from_audio_stream(*args, **kwargs): async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
assert device_id == "mock-device-id"
event_callback = kwargs["event_callback"] event_callback = kwargs["event_callback"]
# Fake events # Fake events
@ -121,7 +123,9 @@ async def test_pipeline_events(
): ):
voice_assistant_udp_server_v1.transport = Mock() voice_assistant_udp_server_v1.transport = Mock()
await voice_assistant_udp_server_v1.run_pipeline(conversation_id=None) await voice_assistant_udp_server_v1.run_pipeline(
device_id="mock-device-id", conversation_id=None
)
async def test_udp_server( async def test_udp_server(
@ -380,7 +384,7 @@ async def test_speech_detection(
voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND)) voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND))
await voice_assistant_udp_server_v2.run_pipeline( await voice_assistant_udp_server_v2.run_pipeline(
conversation_id=None, use_vad=True, pipeline_timeout=1.0 device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
) )
@ -412,7 +416,7 @@ async def test_no_speech(
voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND)) voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND))
await voice_assistant_udp_server_v2.run_pipeline( await voice_assistant_udp_server_v2.run_pipeline(
conversation_id=None, use_vad=True, pipeline_timeout=1.0 device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
) )
@ -452,7 +456,7 @@ async def test_speech_timeout(
voice_assistant_udp_server_v2.queue.put_nowait(bytes([255] * (_ONE_SECOND * 2))) voice_assistant_udp_server_v2.queue.put_nowait(bytes([255] * (_ONE_SECOND * 2)))
await voice_assistant_udp_server_v2.run_pipeline( await voice_assistant_udp_server_v2.run_pipeline(
conversation_id=None, use_vad=True, pipeline_timeout=1.0 device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
) )
@ -467,7 +471,7 @@ async def test_cancelled(
voice_assistant_udp_server_v2.queue.put_nowait(b"") voice_assistant_udp_server_v2.queue.put_nowait(b"")
await voice_assistant_udp_server_v2.run_pipeline( await voice_assistant_udp_server_v2.run_pipeline(
conversation_id=None, use_vad=True, pipeline_timeout=1.0 device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
) )
# No events should be sent if cancelled while waiting for speech # No events should be sent if cancelled while waiting for speech

View file

@ -31,7 +31,9 @@ async def test_pipeline(
# Used to test that audio queue is cleared before pipeline starts # Used to test that audio queue is cleared before pipeline starts
bad_chunk = bytes([1, 2, 3, 4]) bad_chunk = bytes([1, 2, 3, 4])
async def async_pipeline_from_audio_stream(*args, **kwargs): async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
assert device_id == voip_device.device_id
stt_stream = kwargs["stt_stream"] stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"] event_callback = kwargs["event_callback"]
async for _chunk in stt_stream: async for _chunk in stt_stream: