Split StreamState class out of SegmentBuffer (#60423)

This refactoring was pulled out of https://github.com/home-assistant/core/pull/53676 as an
initial step towards reverting the addition of the SegmentBuffer class, which will be
unrolled back into a for loop.

The StreamState class holds the persistent state in stream that is used across stream worker
instantiations, e.g. state across a retry or url expiration, which primarily handles
discontinuities. By itself, this PR is not a large win until follow up PRs further simplify
the SegmentBuffer class.
This commit is contained in:
Allen Porter 2021-11-29 22:25:28 -08:00 committed by GitHub
parent 890790a659
commit 8ca89b10eb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 89 additions and 60 deletions

View file

@ -286,9 +286,9 @@ class Stream:
"""Handle consuming streams and restart keepalive streams.""" """Handle consuming streams and restart keepalive streams."""
# Keep import here so that we can import stream integration without installing reqs # Keep import here so that we can import stream integration without installing reqs
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
from .worker import SegmentBuffer, StreamWorkerError, stream_worker from .worker import StreamState, StreamWorkerError, stream_worker
segment_buffer = SegmentBuffer(self.hass, self.outputs) stream_state = StreamState(self.hass, self.outputs)
wait_timeout = 0 wait_timeout = 0
while not self._thread_quit.wait(timeout=wait_timeout): while not self._thread_quit.wait(timeout=wait_timeout):
start_time = time.time() start_time = time.time()
@ -298,14 +298,14 @@ class Stream:
stream_worker( stream_worker(
self.source, self.source,
self.options, self.options,
segment_buffer, stream_state,
self._thread_quit, self._thread_quit,
) )
except StreamWorkerError as err: except StreamWorkerError as err:
_LOGGER.error("Error from stream worker: %s", str(err)) _LOGGER.error("Error from stream worker: %s", str(err))
self._available = False self._available = False
segment_buffer.discontinuity() stream_state.discontinuity()
if not self.keepalive or self._thread_quit.is_set(): if not self.keepalive or self._thread_quit.is_set():
if self._fast_restart_once: if self._fast_restart_once:
# The stream source is updated, restart without any delay. # The stream source is updated, restart without any delay.

View file

@ -40,28 +40,77 @@ class StreamEndedError(StreamWorkerError):
"""Raised when the stream is complete, exposed for facilitating testing.""" """Raised when the stream is complete, exposed for facilitating testing."""
class SegmentBuffer: class StreamState:
"""Buffer for writing a sequence of packets to the output as a segment.""" """Responsible for trakcing output and playback state for a stream.
Holds state used for playback to interpret a decoded stream. A source stream
may be reset (e.g. reconnecting to an rtsp stream) and this object tracks
the state to inform the player.
"""
def __init__( def __init__(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
outputs_callback: Callable[[], Mapping[str, StreamOutput]], outputs_callback: Callable[[], Mapping[str, StreamOutput]],
) -> None: ) -> None:
"""Initialize SegmentBuffer.""" """Initialize StreamState."""
self._stream_id: int = 0 self._stream_id: int = 0
self._hass = hass self.hass = hass
self._outputs_callback: Callable[ self._outputs_callback: Callable[
[], Mapping[str, StreamOutput] [], Mapping[str, StreamOutput]
] = outputs_callback ] = outputs_callback
# sequence gets incremented before the first segment so the first segment # sequence gets incremented before the first segment so the first segment
# has a sequence number of 0. # has a sequence number of 0.
self._sequence = -1 self._sequence = -1
@property
def sequence(self) -> int:
"""Return the current sequence for the latest segment."""
return self._sequence
def next_sequence(self) -> int:
"""Increment the sequence number."""
self._sequence += 1
return self._sequence
@property
def stream_id(self) -> int:
"""Return the readonly stream_id attribute."""
return self._stream_id
def discontinuity(self) -> None:
"""Mark the stream as having been restarted."""
# Preserving sequence and stream_id here keep the HLS playlist logic
# simple to check for discontinuity at output time, and to determine
# the discontinuity sequence number.
self._stream_id += 1
# Call discontinuity to remove incomplete segment from the HLS output
if hls_output := self._outputs_callback().get(HLS_PROVIDER):
cast(HlsStreamOutput, hls_output).discontinuity()
@property
def outputs(self) -> list[StreamOutput]:
"""Return the active stream outputs."""
return list(self._outputs_callback().values())
class StreamMuxer:
"""StreamMuxer re-packages video/audio packets for output."""
def __init__(
self,
hass: HomeAssistant,
video_stream: av.video.VideoStream,
audio_stream: av.audio.stream.AudioStream | None,
stream_state: StreamState,
) -> None:
"""Initialize StreamMuxer."""
self._hass = hass
self._segment_start_dts: int = cast(int, None) self._segment_start_dts: int = cast(int, None)
self._memory_file: BytesIO = cast(BytesIO, None) self._memory_file: BytesIO = cast(BytesIO, None)
self._av_output: av.container.OutputContainer = None self._av_output: av.container.OutputContainer = None
self._input_video_stream: av.video.VideoStream = None self._input_video_stream: av.video.VideoStream = video_stream
self._input_audio_stream: av.audio.stream.AudioStream | None = None self._input_audio_stream: av.audio.stream.AudioStream | None = audio_stream
self._output_video_stream: av.video.VideoStream = None self._output_video_stream: av.video.VideoStream = None
self._output_audio_stream: av.audio.stream.AudioStream | None = None self._output_audio_stream: av.audio.stream.AudioStream | None = None
self._segment: Segment | None = None self._segment: Segment | None = None
@ -70,6 +119,7 @@ class SegmentBuffer:
self._part_start_dts: int = cast(int, None) self._part_start_dts: int = cast(int, None)
self._part_has_keyframe = False self._part_has_keyframe = False
self._stream_settings: StreamSettings = hass.data[DOMAIN][ATTR_SETTINGS] self._stream_settings: StreamSettings = hass.data[DOMAIN][ATTR_SETTINGS]
self._stream_state = stream_state
self._start_time = datetime.datetime.utcnow() self._start_time = datetime.datetime.utcnow()
def make_new_av( def make_new_av(
@ -77,14 +127,13 @@ class SegmentBuffer:
memory_file: BytesIO, memory_file: BytesIO,
sequence: int, sequence: int,
input_vstream: av.video.VideoStream, input_vstream: av.video.VideoStream,
input_astream: av.audio.stream.AudioStream, input_astream: av.audio.stream.AudioStream | None,
) -> tuple[ ) -> tuple[
av.container.OutputContainer, av.container.OutputContainer,
av.video.VideoStream, av.video.VideoStream,
av.audio.stream.AudioStream | None, av.audio.stream.AudioStream | None,
]: ]:
"""Make a new av OutputContainer and add output streams.""" """Make a new av OutputContainer and add output streams."""
add_audio = input_astream and input_astream.name in AUDIO_CODECS
container = av.open( container = av.open(
memory_file, memory_file,
mode="w", mode="w",
@ -135,24 +184,12 @@ class SegmentBuffer:
output_vstream = container.add_stream(template=input_vstream) output_vstream = container.add_stream(template=input_vstream)
# Check if audio is requested # Check if audio is requested
output_astream = None output_astream = None
if add_audio: if input_astream:
output_astream = container.add_stream(template=input_astream) output_astream = container.add_stream(template=input_astream)
return container, output_vstream, output_astream return container, output_vstream, output_astream
def set_streams(
self,
video_stream: av.video.VideoStream,
audio_stream: Any,
# no type hint for audio_stream until https://github.com/PyAV-Org/PyAV/pull/775 is merged
) -> None:
"""Initialize output buffer with streams from container."""
self._input_video_stream = video_stream
self._input_audio_stream = audio_stream
def reset(self, video_dts: int) -> None: def reset(self, video_dts: int) -> None:
"""Initialize a new stream segment.""" """Initialize a new stream segment."""
# Keep track of the number of segments we've processed
self._sequence += 1
self._part_start_dts = self._segment_start_dts = video_dts self._part_start_dts = self._segment_start_dts = video_dts
self._segment = None self._segment = None
self._memory_file = BytesIO() self._memory_file = BytesIO()
@ -163,7 +200,7 @@ class SegmentBuffer:
self._output_audio_stream, self._output_audio_stream,
) = self.make_new_av( ) = self.make_new_av(
memory_file=self._memory_file, memory_file=self._memory_file,
sequence=self._sequence, sequence=self._stream_state.next_sequence(),
input_vstream=self._input_video_stream, input_vstream=self._input_video_stream,
input_astream=self._input_audio_stream, input_astream=self._input_audio_stream,
) )
@ -201,12 +238,12 @@ class SegmentBuffer:
# We have our first non-zero byte position. This means the init has just # We have our first non-zero byte position. This means the init has just
# been written. Create a Segment and put it to the queue of each output. # been written. Create a Segment and put it to the queue of each output.
self._segment = Segment( self._segment = Segment(
sequence=self._sequence, sequence=self._stream_state.sequence,
stream_id=self._stream_id, stream_id=self._stream_state.stream_id,
init=self._memory_file.getvalue(), init=self._memory_file.getvalue(),
# Fetch the latest StreamOutputs, which may have changed since the # Fetch the latest StreamOutputs, which may have changed since the
# worker started. # worker started.
stream_outputs=self._outputs_callback().values(), stream_outputs=self._stream_state.outputs,
start_time=self._start_time, start_time=self._start_time,
) )
self._memory_file_pos = self._memory_file.tell() self._memory_file_pos = self._memory_file.tell()
@ -283,17 +320,6 @@ class SegmentBuffer:
self._part_start_dts = adjusted_dts self._part_start_dts = adjusted_dts
self._part_has_keyframe = False self._part_has_keyframe = False
def discontinuity(self) -> None:
"""Mark the stream as having been restarted."""
# Preserving sequence and stream_id here keep the HLS playlist logic
# simple to check for discontinuity at output time, and to determine
# the discontinuity sequence number.
self._stream_id += 1
self._start_time = datetime.datetime.utcnow()
# Call discontinuity to remove incomplete segment from the HLS output
if hls_output := self._outputs_callback().get(HLS_PROVIDER):
cast(HlsStreamOutput, hls_output).discontinuity()
def close(self) -> None: def close(self) -> None:
"""Close stream buffer.""" """Close stream buffer."""
self._av_output.close() self._av_output.close()
@ -412,7 +438,7 @@ def unsupported_audio(packets: Iterator[av.Packet], audio_stream: Any) -> bool:
def stream_worker( def stream_worker(
source: str, source: str,
options: dict[str, str], options: dict[str, str],
segment_buffer: SegmentBuffer, stream_state: StreamState,
quit_event: Event, quit_event: Event,
) -> None: ) -> None:
"""Handle consuming streams.""" """Handle consuming streams."""
@ -431,6 +457,8 @@ def stream_worker(
audio_stream = container.streams.audio[0] audio_stream = container.streams.audio[0]
except (KeyError, IndexError): except (KeyError, IndexError):
audio_stream = None audio_stream = None
if audio_stream and audio_stream.name not in AUDIO_CODECS:
audio_stream = None
# These formats need aac_adtstoasc bitstream filter, but auto_bsf not # These formats need aac_adtstoasc bitstream filter, but auto_bsf not
# compatible with empty_moov and manual bitstream filters not in PyAV # compatible with empty_moov and manual bitstream filters not in PyAV
if container.format.name in {"hls", "mpegts"}: if container.format.name in {"hls", "mpegts"}:
@ -489,13 +517,13 @@ def stream_worker(
"Error demuxing stream while finding first packet: %s" % str(ex) "Error demuxing stream while finding first packet: %s" % str(ex)
) from ex ) from ex
segment_buffer.set_streams(video_stream, audio_stream) muxer = StreamMuxer(stream_state.hass, video_stream, audio_stream, stream_state)
segment_buffer.reset(start_dts) muxer.reset(start_dts)
# Mux the first keyframe, then proceed through the rest of the packets # Mux the first keyframe, then proceed through the rest of the packets
segment_buffer.mux_packet(first_keyframe) muxer.mux_packet(first_keyframe)
with contextlib.closing(container), contextlib.closing(segment_buffer): with contextlib.closing(container), contextlib.closing(muxer):
while not quit_event.is_set(): while not quit_event.is_set():
try: try:
packet = next(container_packets) packet = next(container_packets)
@ -506,4 +534,4 @@ def stream_worker(
except av.AVError as ex: except av.AVError as ex:
raise StreamWorkerError("Error demuxing stream: %s" % str(ex)) from ex raise StreamWorkerError("Error demuxing stream: %s" % str(ex)) from ex
segment_buffer.mux_packet(packet) muxer.mux_packet(packet)

View file

@ -23,7 +23,7 @@ import async_timeout
import pytest import pytest
from homeassistant.components.stream.core import Segment, StreamOutput from homeassistant.components.stream.core import Segment, StreamOutput
from homeassistant.components.stream.worker import SegmentBuffer from homeassistant.components.stream.worker import StreamState
TEST_TIMEOUT = 7.0 # Lower than 9s home assistant timeout TEST_TIMEOUT = 7.0 # Lower than 9s home assistant timeout
@ -34,7 +34,7 @@ class WorkerSync:
def __init__(self): def __init__(self):
"""Initialize WorkerSync.""" """Initialize WorkerSync."""
self._event = None self._event = None
self._original = SegmentBuffer.discontinuity self._original = StreamState.discontinuity
def pause(self): def pause(self):
"""Pause the worker before it finalizes the stream.""" """Pause the worker before it finalizes the stream."""
@ -45,7 +45,7 @@ class WorkerSync:
logging.debug("waking blocked worker") logging.debug("waking blocked worker")
self._event.set() self._event.set()
def blocking_discontinuity(self, stream: SegmentBuffer): def blocking_discontinuity(self, stream_state: StreamState):
"""Intercept call to pause stream worker.""" """Intercept call to pause stream worker."""
# Worker is ending the stream, which clears all output buffers. # Worker is ending the stream, which clears all output buffers.
# Block the worker thread until the test has a chance to verify # Block the worker thread until the test has a chance to verify
@ -55,7 +55,7 @@ class WorkerSync:
self._event.wait() self._event.wait()
# Forward to actual implementation # Forward to actual implementation
self._original(stream) self._original(stream_state)
@pytest.fixture() @pytest.fixture()
@ -63,7 +63,7 @@ def stream_worker_sync(hass):
"""Patch StreamOutput to allow test to synchronize worker stream end.""" """Patch StreamOutput to allow test to synchronize worker stream end."""
sync = WorkerSync() sync = WorkerSync()
with patch( with patch(
"homeassistant.components.stream.worker.SegmentBuffer.discontinuity", "homeassistant.components.stream.worker.StreamState.discontinuity",
side_effect=sync.blocking_discontinuity, side_effect=sync.blocking_discontinuity,
autospec=True, autospec=True,
): ):

View file

@ -38,8 +38,8 @@ from homeassistant.components.stream.const import (
) )
from homeassistant.components.stream.core import StreamSettings from homeassistant.components.stream.core import StreamSettings
from homeassistant.components.stream.worker import ( from homeassistant.components.stream.worker import (
SegmentBuffer,
StreamEndedError, StreamEndedError,
StreamState,
StreamWorkerError, StreamWorkerError,
stream_worker, stream_worker,
) )
@ -255,6 +255,12 @@ class MockPyAv:
return self.container return self.container
def run_worker(hass, stream, stream_source):
"""Run the stream worker under test."""
stream_state = StreamState(hass, stream.outputs)
stream_worker(stream_source, {}, stream_state, threading.Event())
async def async_decode_stream(hass, packets, py_av=None): async def async_decode_stream(hass, packets, py_av=None):
"""Start a stream worker that decodes incoming stream packets into output segments.""" """Start a stream worker that decodes incoming stream packets into output segments."""
stream = Stream(hass, STREAM_SOURCE, {}) stream = Stream(hass, STREAM_SOURCE, {})
@ -268,9 +274,8 @@ async def async_decode_stream(hass, packets, py_av=None):
"homeassistant.components.stream.core.StreamOutput.put", "homeassistant.components.stream.core.StreamOutput.put",
side_effect=py_av.capture_buffer.capture_output_segment, side_effect=py_av.capture_buffer.capture_output_segment,
): ):
segment_buffer = SegmentBuffer(hass, stream.outputs)
try: try:
stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event()) run_worker(hass, stream, STREAM_SOURCE)
except StreamEndedError: except StreamEndedError:
# Tests only use a limited number of packets, then the worker exits as expected. In # Tests only use a limited number of packets, then the worker exits as expected. In
# production, stream ending would be unexpected. # production, stream ending would be unexpected.
@ -288,8 +293,7 @@ async def test_stream_open_fails(hass):
stream.add_provider(HLS_PROVIDER) stream.add_provider(HLS_PROVIDER)
with patch("av.open") as av_open, pytest.raises(StreamWorkerError): with patch("av.open") as av_open, pytest.raises(StreamWorkerError):
av_open.side_effect = av.error.InvalidDataError(-2, "error") av_open.side_effect = av.error.InvalidDataError(-2, "error")
segment_buffer = SegmentBuffer(hass, stream.outputs) run_worker(hass, stream, STREAM_SOURCE)
stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event())
await hass.async_block_till_done() await hass.async_block_till_done()
av_open.assert_called_once() av_open.assert_called_once()
@ -695,10 +699,7 @@ async def test_worker_log(hass, caplog):
with patch("av.open") as av_open, pytest.raises(StreamWorkerError) as err: with patch("av.open") as av_open, pytest.raises(StreamWorkerError) as err:
av_open.side_effect = av.error.InvalidDataError(-2, "error") av_open.side_effect = av.error.InvalidDataError(-2, "error")
segment_buffer = SegmentBuffer(hass, stream.outputs) run_worker(hass, stream, "https://abcd:efgh@foo.bar")
stream_worker(
"https://abcd:efgh@foo.bar", {}, segment_buffer, threading.Event()
)
await hass.async_block_till_done() await hass.async_block_till_done()
assert str(err.value) == "Error opening stream https://****:****@foo.bar" assert str(err.value) == "Error opening stream https://****:****@foo.bar"
assert "https://abcd:efgh@foo.bar" not in caplog.text assert "https://abcd:efgh@foo.bar" not in caplog.text