Make context a mandatory parameter for async_pipeline_from_audio_stream (#91658)

This commit is contained in:
Erik Montnemery 2023-04-19 15:30:29 +02:00 committed by GitHub
parent ebd20c8a7b
commit 090f59aaa2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 10 additions and 8 deletions

View file

@ -42,18 +42,15 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def async_pipeline_from_audio_stream(
hass: HomeAssistant,
context: Context,
event_callback: PipelineEventCallback,
stt_metadata: stt.SpeechMetadata,
stt_stream: AsyncIterable[bytes],
pipeline_id: str | None = None,
conversation_id: str | None = None,
context: Context | None = None,
tts_options: dict | None = None,
) -> None:
"""Create an audio pipeline from an audio stream."""
if context is None:
context = Context()
pipeline = await async_get_pipeline(hass, pipeline_id=pipeline_id)
if pipeline is None:
raise PipelineNotFound(

View file

@ -16,7 +16,7 @@ from homeassistant.components.assist_pipeline import (
async_pipeline_from_audio_stream,
)
from homeassistant.components.media_player import async_process_play_media_url
from homeassistant.core import HomeAssistant, callback
from homeassistant.core import Context, HomeAssistant, callback
from .enum_mapper import EsphomeEnumMapper
@ -50,6 +50,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize UDP receiver."""
self.context = Context()
self.hass = hass
self.queue = asyncio.Queue()
@ -151,6 +152,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
await async_pipeline_from_audio_stream(
self.hass,
context=self.context,
event_callback=handle_pipeline_event,
stt_metadata=stt.SpeechMetadata(
language="",

View file

@ -21,7 +21,7 @@ from homeassistant.components.assist_pipeline import (
)
from homeassistant.components.assist_pipeline.vad import VoiceCommandSegmenter
from homeassistant.const import __version__
from homeassistant.core import HomeAssistant
from homeassistant.core import Context, HomeAssistant
from .const import DOMAIN
@ -82,8 +82,9 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
self.audio_timeout = audio_timeout
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
self._pipeline_task: asyncio.Task | None = None
self._context = Context()
self._conversation_id: str | None = None
self._pipeline_task: asyncio.Task | None = None
def connection_made(self, transport):
"""Server is ready."""
@ -133,6 +134,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
async with async_timeout.timeout(self.pipeline_timeout):
await async_pipeline_from_audio_stream(
self.hass,
context=self._context,
event_callback=self._event_callback,
stt_metadata=stt.SpeechMetadata(
language="", # set in async_pipeline_from_audio_stream

View file

@ -4,7 +4,7 @@ from dataclasses import asdict
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import assist_pipeline, stt
from homeassistant.core import HomeAssistant
from homeassistant.core import Context, HomeAssistant
async def test_pipeline_from_audio_stream(
@ -21,6 +21,7 @@ async def test_pipeline_from_audio_stream(
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
Context(),
events.append,
stt.SpeechMetadata(
language="",