Make context a mandatory parameter for async_pipeline_from_audio_stream (#91658)

This commit is contained in:
Erik Montnemery 2023-04-19 15:30:29 +02:00 committed by GitHub
parent ebd20c8a7b
commit 090f59aaa2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 10 additions and 8 deletions

View file

@ -42,18 +42,15 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def async_pipeline_from_audio_stream( async def async_pipeline_from_audio_stream(
hass: HomeAssistant, hass: HomeAssistant,
context: Context,
event_callback: PipelineEventCallback, event_callback: PipelineEventCallback,
stt_metadata: stt.SpeechMetadata, stt_metadata: stt.SpeechMetadata,
stt_stream: AsyncIterable[bytes], stt_stream: AsyncIterable[bytes],
pipeline_id: str | None = None, pipeline_id: str | None = None,
conversation_id: str | None = None, conversation_id: str | None = None,
context: Context | None = None,
tts_options: dict | None = None, tts_options: dict | None = None,
) -> None: ) -> None:
"""Create an audio pipeline from an audio stream.""" """Create an audio pipeline from an audio stream."""
if context is None:
context = Context()
pipeline = await async_get_pipeline(hass, pipeline_id=pipeline_id) pipeline = await async_get_pipeline(hass, pipeline_id=pipeline_id)
if pipeline is None: if pipeline is None:
raise PipelineNotFound( raise PipelineNotFound(

View file

@ -16,7 +16,7 @@ from homeassistant.components.assist_pipeline import (
async_pipeline_from_audio_stream, async_pipeline_from_audio_stream,
) )
from homeassistant.components.media_player import async_process_play_media_url from homeassistant.components.media_player import async_process_play_media_url
from homeassistant.core import HomeAssistant, callback from homeassistant.core import Context, HomeAssistant, callback
from .enum_mapper import EsphomeEnumMapper from .enum_mapper import EsphomeEnumMapper
@ -50,6 +50,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize UDP receiver.""" """Initialize UDP receiver."""
self.context = Context()
self.hass = hass self.hass = hass
self.queue = asyncio.Queue() self.queue = asyncio.Queue()
@ -151,6 +152,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
await async_pipeline_from_audio_stream( await async_pipeline_from_audio_stream(
self.hass, self.hass,
context=self.context,
event_callback=handle_pipeline_event, event_callback=handle_pipeline_event,
stt_metadata=stt.SpeechMetadata( stt_metadata=stt.SpeechMetadata(
language="", language="",

View file

@ -21,7 +21,7 @@ from homeassistant.components.assist_pipeline import (
) )
from homeassistant.components.assist_pipeline.vad import VoiceCommandSegmenter from homeassistant.components.assist_pipeline.vad import VoiceCommandSegmenter
from homeassistant.const import __version__ from homeassistant.const import __version__
from homeassistant.core import HomeAssistant from homeassistant.core import Context, HomeAssistant
from .const import DOMAIN from .const import DOMAIN
@ -82,8 +82,9 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
self.audio_timeout = audio_timeout self.audio_timeout = audio_timeout
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue() self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
self._pipeline_task: asyncio.Task | None = None self._context = Context()
self._conversation_id: str | None = None self._conversation_id: str | None = None
self._pipeline_task: asyncio.Task | None = None
def connection_made(self, transport): def connection_made(self, transport):
"""Server is ready.""" """Server is ready."""
@ -133,6 +134,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
async with async_timeout.timeout(self.pipeline_timeout): async with async_timeout.timeout(self.pipeline_timeout):
await async_pipeline_from_audio_stream( await async_pipeline_from_audio_stream(
self.hass, self.hass,
context=self._context,
event_callback=self._event_callback, event_callback=self._event_callback,
stt_metadata=stt.SpeechMetadata( stt_metadata=stt.SpeechMetadata(
language="", # set in async_pipeline_from_audio_stream language="", # set in async_pipeline_from_audio_stream

View file

@ -4,7 +4,7 @@ from dataclasses import asdict
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from homeassistant.components import assist_pipeline, stt from homeassistant.components import assist_pipeline, stt
from homeassistant.core import HomeAssistant from homeassistant.core import Context, HomeAssistant
async def test_pipeline_from_audio_stream( async def test_pipeline_from_audio_stream(
@ -21,6 +21,7 @@ async def test_pipeline_from_audio_stream(
await assist_pipeline.async_pipeline_from_audio_stream( await assist_pipeline.async_pipeline_from_audio_stream(
hass, hass,
Context(),
events.append, events.append,
stt.SpeechMetadata( stt.SpeechMetadata(
language="", language="",