diff --git a/homeassistant/components/stream/__init__.py b/homeassistant/components/stream/__init__.py index 58b0dd00bc9..070dd062e42 100644 --- a/homeassistant/components/stream/__init__.py +++ b/homeassistant/components/stream/__init__.py @@ -286,9 +286,9 @@ class Stream: """Handle consuming streams and restart keepalive streams.""" # Keep import here so that we can import stream integration without installing reqs # pylint: disable=import-outside-toplevel - from .worker import SegmentBuffer, StreamWorkerError, stream_worker + from .worker import StreamState, StreamWorkerError, stream_worker - segment_buffer = SegmentBuffer(self.hass, self.outputs) + stream_state = StreamState(self.hass, self.outputs) wait_timeout = 0 while not self._thread_quit.wait(timeout=wait_timeout): start_time = time.time() @@ -298,14 +298,14 @@ class Stream: stream_worker( self.source, self.options, - segment_buffer, + stream_state, self._thread_quit, ) except StreamWorkerError as err: _LOGGER.error("Error from stream worker: %s", str(err)) self._available = False - segment_buffer.discontinuity() + stream_state.discontinuity() if not self.keepalive or self._thread_quit.is_set(): if self._fast_restart_once: # The stream source is updated, restart without any delay. diff --git a/homeassistant/components/stream/worker.py b/homeassistant/components/stream/worker.py index 5176b93dedf..b1d79e52800 100644 --- a/homeassistant/components/stream/worker.py +++ b/homeassistant/components/stream/worker.py @@ -40,28 +40,77 @@ class StreamEndedError(StreamWorkerError): """Raised when the stream is complete, exposed for facilitating testing.""" -class SegmentBuffer: - """Buffer for writing a sequence of packets to the output as a segment.""" +class StreamState: + """Responsible for trakcing output and playback state for a stream. + + Holds state used for playback to interpret a decoded stream. A source stream + may be reset (e.g. reconnecting to an rtsp stream) and this object tracks + the state to inform the player. + """ def __init__( self, hass: HomeAssistant, outputs_callback: Callable[[], Mapping[str, StreamOutput]], ) -> None: - """Initialize SegmentBuffer.""" + """Initialize StreamState.""" self._stream_id: int = 0 - self._hass = hass + self.hass = hass self._outputs_callback: Callable[ [], Mapping[str, StreamOutput] ] = outputs_callback # sequence gets incremented before the first segment so the first segment # has a sequence number of 0. self._sequence = -1 + + @property + def sequence(self) -> int: + """Return the current sequence for the latest segment.""" + return self._sequence + + def next_sequence(self) -> int: + """Increment the sequence number.""" + self._sequence += 1 + return self._sequence + + @property + def stream_id(self) -> int: + """Return the readonly stream_id attribute.""" + return self._stream_id + + def discontinuity(self) -> None: + """Mark the stream as having been restarted.""" + # Preserving sequence and stream_id here keep the HLS playlist logic + # simple to check for discontinuity at output time, and to determine + # the discontinuity sequence number. + self._stream_id += 1 + # Call discontinuity to remove incomplete segment from the HLS output + if hls_output := self._outputs_callback().get(HLS_PROVIDER): + cast(HlsStreamOutput, hls_output).discontinuity() + + @property + def outputs(self) -> list[StreamOutput]: + """Return the active stream outputs.""" + return list(self._outputs_callback().values()) + + +class StreamMuxer: + """StreamMuxer re-packages video/audio packets for output.""" + + def __init__( + self, + hass: HomeAssistant, + video_stream: av.video.VideoStream, + audio_stream: av.audio.stream.AudioStream | None, + stream_state: StreamState, + ) -> None: + """Initialize StreamMuxer.""" + self._hass = hass self._segment_start_dts: int = cast(int, None) self._memory_file: BytesIO = cast(BytesIO, None) self._av_output: av.container.OutputContainer = None - self._input_video_stream: av.video.VideoStream = None - self._input_audio_stream: av.audio.stream.AudioStream | None = None + self._input_video_stream: av.video.VideoStream = video_stream + self._input_audio_stream: av.audio.stream.AudioStream | None = audio_stream self._output_video_stream: av.video.VideoStream = None self._output_audio_stream: av.audio.stream.AudioStream | None = None self._segment: Segment | None = None @@ -70,6 +119,7 @@ class SegmentBuffer: self._part_start_dts: int = cast(int, None) self._part_has_keyframe = False self._stream_settings: StreamSettings = hass.data[DOMAIN][ATTR_SETTINGS] + self._stream_state = stream_state self._start_time = datetime.datetime.utcnow() def make_new_av( @@ -77,14 +127,13 @@ class SegmentBuffer: memory_file: BytesIO, sequence: int, input_vstream: av.video.VideoStream, - input_astream: av.audio.stream.AudioStream, + input_astream: av.audio.stream.AudioStream | None, ) -> tuple[ av.container.OutputContainer, av.video.VideoStream, av.audio.stream.AudioStream | None, ]: """Make a new av OutputContainer and add output streams.""" - add_audio = input_astream and input_astream.name in AUDIO_CODECS container = av.open( memory_file, mode="w", @@ -135,24 +184,12 @@ class SegmentBuffer: output_vstream = container.add_stream(template=input_vstream) # Check if audio is requested output_astream = None - if add_audio: + if input_astream: output_astream = container.add_stream(template=input_astream) return container, output_vstream, output_astream - def set_streams( - self, - video_stream: av.video.VideoStream, - audio_stream: Any, - # no type hint for audio_stream until https://github.com/PyAV-Org/PyAV/pull/775 is merged - ) -> None: - """Initialize output buffer with streams from container.""" - self._input_video_stream = video_stream - self._input_audio_stream = audio_stream - def reset(self, video_dts: int) -> None: """Initialize a new stream segment.""" - # Keep track of the number of segments we've processed - self._sequence += 1 self._part_start_dts = self._segment_start_dts = video_dts self._segment = None self._memory_file = BytesIO() @@ -163,7 +200,7 @@ class SegmentBuffer: self._output_audio_stream, ) = self.make_new_av( memory_file=self._memory_file, - sequence=self._sequence, + sequence=self._stream_state.next_sequence(), input_vstream=self._input_video_stream, input_astream=self._input_audio_stream, ) @@ -201,12 +238,12 @@ class SegmentBuffer: # We have our first non-zero byte position. This means the init has just # been written. Create a Segment and put it to the queue of each output. self._segment = Segment( - sequence=self._sequence, - stream_id=self._stream_id, + sequence=self._stream_state.sequence, + stream_id=self._stream_state.stream_id, init=self._memory_file.getvalue(), # Fetch the latest StreamOutputs, which may have changed since the # worker started. - stream_outputs=self._outputs_callback().values(), + stream_outputs=self._stream_state.outputs, start_time=self._start_time, ) self._memory_file_pos = self._memory_file.tell() @@ -283,17 +320,6 @@ class SegmentBuffer: self._part_start_dts = adjusted_dts self._part_has_keyframe = False - def discontinuity(self) -> None: - """Mark the stream as having been restarted.""" - # Preserving sequence and stream_id here keep the HLS playlist logic - # simple to check for discontinuity at output time, and to determine - # the discontinuity sequence number. - self._stream_id += 1 - self._start_time = datetime.datetime.utcnow() - # Call discontinuity to remove incomplete segment from the HLS output - if hls_output := self._outputs_callback().get(HLS_PROVIDER): - cast(HlsStreamOutput, hls_output).discontinuity() - def close(self) -> None: """Close stream buffer.""" self._av_output.close() @@ -412,7 +438,7 @@ def unsupported_audio(packets: Iterator[av.Packet], audio_stream: Any) -> bool: def stream_worker( source: str, options: dict[str, str], - segment_buffer: SegmentBuffer, + stream_state: StreamState, quit_event: Event, ) -> None: """Handle consuming streams.""" @@ -431,6 +457,8 @@ def stream_worker( audio_stream = container.streams.audio[0] except (KeyError, IndexError): audio_stream = None + if audio_stream and audio_stream.name not in AUDIO_CODECS: + audio_stream = None # These formats need aac_adtstoasc bitstream filter, but auto_bsf not # compatible with empty_moov and manual bitstream filters not in PyAV if container.format.name in {"hls", "mpegts"}: @@ -489,13 +517,13 @@ def stream_worker( "Error demuxing stream while finding first packet: %s" % str(ex) ) from ex - segment_buffer.set_streams(video_stream, audio_stream) - segment_buffer.reset(start_dts) + muxer = StreamMuxer(stream_state.hass, video_stream, audio_stream, stream_state) + muxer.reset(start_dts) # Mux the first keyframe, then proceed through the rest of the packets - segment_buffer.mux_packet(first_keyframe) + muxer.mux_packet(first_keyframe) - with contextlib.closing(container), contextlib.closing(segment_buffer): + with contextlib.closing(container), contextlib.closing(muxer): while not quit_event.is_set(): try: packet = next(container_packets) @@ -506,4 +534,4 @@ def stream_worker( except av.AVError as ex: raise StreamWorkerError("Error demuxing stream: %s" % str(ex)) from ex - segment_buffer.mux_packet(packet) + muxer.mux_packet(packet) diff --git a/tests/components/stream/conftest.py b/tests/components/stream/conftest.py index 62c62593c57..10328a8f87b 100644 --- a/tests/components/stream/conftest.py +++ b/tests/components/stream/conftest.py @@ -23,7 +23,7 @@ import async_timeout import pytest from homeassistant.components.stream.core import Segment, StreamOutput -from homeassistant.components.stream.worker import SegmentBuffer +from homeassistant.components.stream.worker import StreamState TEST_TIMEOUT = 7.0 # Lower than 9s home assistant timeout @@ -34,7 +34,7 @@ class WorkerSync: def __init__(self): """Initialize WorkerSync.""" self._event = None - self._original = SegmentBuffer.discontinuity + self._original = StreamState.discontinuity def pause(self): """Pause the worker before it finalizes the stream.""" @@ -45,7 +45,7 @@ class WorkerSync: logging.debug("waking blocked worker") self._event.set() - def blocking_discontinuity(self, stream: SegmentBuffer): + def blocking_discontinuity(self, stream_state: StreamState): """Intercept call to pause stream worker.""" # Worker is ending the stream, which clears all output buffers. # Block the worker thread until the test has a chance to verify @@ -55,7 +55,7 @@ class WorkerSync: self._event.wait() # Forward to actual implementation - self._original(stream) + self._original(stream_state) @pytest.fixture() @@ -63,7 +63,7 @@ def stream_worker_sync(hass): """Patch StreamOutput to allow test to synchronize worker stream end.""" sync = WorkerSync() with patch( - "homeassistant.components.stream.worker.SegmentBuffer.discontinuity", + "homeassistant.components.stream.worker.StreamState.discontinuity", side_effect=sync.blocking_discontinuity, autospec=True, ): diff --git a/tests/components/stream/test_worker.py b/tests/components/stream/test_worker.py index c65e10d65f3..3e9ea157934 100644 --- a/tests/components/stream/test_worker.py +++ b/tests/components/stream/test_worker.py @@ -38,8 +38,8 @@ from homeassistant.components.stream.const import ( ) from homeassistant.components.stream.core import StreamSettings from homeassistant.components.stream.worker import ( - SegmentBuffer, StreamEndedError, + StreamState, StreamWorkerError, stream_worker, ) @@ -255,6 +255,12 @@ class MockPyAv: return self.container +def run_worker(hass, stream, stream_source): + """Run the stream worker under test.""" + stream_state = StreamState(hass, stream.outputs) + stream_worker(stream_source, {}, stream_state, threading.Event()) + + async def async_decode_stream(hass, packets, py_av=None): """Start a stream worker that decodes incoming stream packets into output segments.""" stream = Stream(hass, STREAM_SOURCE, {}) @@ -268,9 +274,8 @@ async def async_decode_stream(hass, packets, py_av=None): "homeassistant.components.stream.core.StreamOutput.put", side_effect=py_av.capture_buffer.capture_output_segment, ): - segment_buffer = SegmentBuffer(hass, stream.outputs) try: - stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event()) + run_worker(hass, stream, STREAM_SOURCE) except StreamEndedError: # Tests only use a limited number of packets, then the worker exits as expected. In # production, stream ending would be unexpected. @@ -288,8 +293,7 @@ async def test_stream_open_fails(hass): stream.add_provider(HLS_PROVIDER) with patch("av.open") as av_open, pytest.raises(StreamWorkerError): av_open.side_effect = av.error.InvalidDataError(-2, "error") - segment_buffer = SegmentBuffer(hass, stream.outputs) - stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event()) + run_worker(hass, stream, STREAM_SOURCE) await hass.async_block_till_done() av_open.assert_called_once() @@ -695,10 +699,7 @@ async def test_worker_log(hass, caplog): with patch("av.open") as av_open, pytest.raises(StreamWorkerError) as err: av_open.side_effect = av.error.InvalidDataError(-2, "error") - segment_buffer = SegmentBuffer(hass, stream.outputs) - stream_worker( - "https://abcd:efgh@foo.bar", {}, segment_buffer, threading.Event() - ) + run_worker(hass, stream, "https://abcd:efgh@foo.bar") await hass.async_block_till_done() assert str(err.value) == "Error opening stream https://****:****@foo.bar" assert "https://abcd:efgh@foo.bar" not in caplog.text