Split StreamState class out of SegmentBuffer (#60423)

This refactoring was pulled out of https://github.com/home-assistant/core/pull/53676 as an
initial step towards reverting the addition of the SegmentBuffer class, which will be
unrolled back into a for loop.

The StreamState class holds the persistent state in stream that is used across stream worker
instantiations, e.g. state across a retry or url expiration, which primarily handles
discontinuities. By itself, this PR is not a large win until follow up PRs further simplify
the SegmentBuffer class.
This commit is contained in:
Allen Porter 2021-11-29 22:25:28 -08:00 committed by GitHub
parent 890790a659
commit 8ca89b10eb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 89 additions and 60 deletions

View file

@ -40,28 +40,77 @@ class StreamEndedError(StreamWorkerError):
"""Raised when the stream is complete, exposed for facilitating testing."""
class SegmentBuffer:
"""Buffer for writing a sequence of packets to the output as a segment."""
class StreamState:
"""Responsible for trakcing output and playback state for a stream.
Holds state used for playback to interpret a decoded stream. A source stream
may be reset (e.g. reconnecting to an rtsp stream) and this object tracks
the state to inform the player.
"""
def __init__(
self,
hass: HomeAssistant,
outputs_callback: Callable[[], Mapping[str, StreamOutput]],
) -> None:
"""Initialize SegmentBuffer."""
"""Initialize StreamState."""
self._stream_id: int = 0
self._hass = hass
self.hass = hass
self._outputs_callback: Callable[
[], Mapping[str, StreamOutput]
] = outputs_callback
# sequence gets incremented before the first segment so the first segment
# has a sequence number of 0.
self._sequence = -1
@property
def sequence(self) -> int:
"""Return the current sequence for the latest segment."""
return self._sequence
def next_sequence(self) -> int:
"""Increment the sequence number."""
self._sequence += 1
return self._sequence
@property
def stream_id(self) -> int:
"""Return the readonly stream_id attribute."""
return self._stream_id
def discontinuity(self) -> None:
"""Mark the stream as having been restarted."""
# Preserving sequence and stream_id here keep the HLS playlist logic
# simple to check for discontinuity at output time, and to determine
# the discontinuity sequence number.
self._stream_id += 1
# Call discontinuity to remove incomplete segment from the HLS output
if hls_output := self._outputs_callback().get(HLS_PROVIDER):
cast(HlsStreamOutput, hls_output).discontinuity()
@property
def outputs(self) -> list[StreamOutput]:
"""Return the active stream outputs."""
return list(self._outputs_callback().values())
class StreamMuxer:
"""StreamMuxer re-packages video/audio packets for output."""
def __init__(
self,
hass: HomeAssistant,
video_stream: av.video.VideoStream,
audio_stream: av.audio.stream.AudioStream | None,
stream_state: StreamState,
) -> None:
"""Initialize StreamMuxer."""
self._hass = hass
self._segment_start_dts: int = cast(int, None)
self._memory_file: BytesIO = cast(BytesIO, None)
self._av_output: av.container.OutputContainer = None
self._input_video_stream: av.video.VideoStream = None
self._input_audio_stream: av.audio.stream.AudioStream | None = None
self._input_video_stream: av.video.VideoStream = video_stream
self._input_audio_stream: av.audio.stream.AudioStream | None = audio_stream
self._output_video_stream: av.video.VideoStream = None
self._output_audio_stream: av.audio.stream.AudioStream | None = None
self._segment: Segment | None = None
@ -70,6 +119,7 @@ class SegmentBuffer:
self._part_start_dts: int = cast(int, None)
self._part_has_keyframe = False
self._stream_settings: StreamSettings = hass.data[DOMAIN][ATTR_SETTINGS]
self._stream_state = stream_state
self._start_time = datetime.datetime.utcnow()
def make_new_av(
@ -77,14 +127,13 @@ class SegmentBuffer:
memory_file: BytesIO,
sequence: int,
input_vstream: av.video.VideoStream,
input_astream: av.audio.stream.AudioStream,
input_astream: av.audio.stream.AudioStream | None,
) -> tuple[
av.container.OutputContainer,
av.video.VideoStream,
av.audio.stream.AudioStream | None,
]:
"""Make a new av OutputContainer and add output streams."""
add_audio = input_astream and input_astream.name in AUDIO_CODECS
container = av.open(
memory_file,
mode="w",
@ -135,24 +184,12 @@ class SegmentBuffer:
output_vstream = container.add_stream(template=input_vstream)
# Check if audio is requested
output_astream = None
if add_audio:
if input_astream:
output_astream = container.add_stream(template=input_astream)
return container, output_vstream, output_astream
def set_streams(
self,
video_stream: av.video.VideoStream,
audio_stream: Any,
# no type hint for audio_stream until https://github.com/PyAV-Org/PyAV/pull/775 is merged
) -> None:
"""Initialize output buffer with streams from container."""
self._input_video_stream = video_stream
self._input_audio_stream = audio_stream
def reset(self, video_dts: int) -> None:
"""Initialize a new stream segment."""
# Keep track of the number of segments we've processed
self._sequence += 1
self._part_start_dts = self._segment_start_dts = video_dts
self._segment = None
self._memory_file = BytesIO()
@ -163,7 +200,7 @@ class SegmentBuffer:
self._output_audio_stream,
) = self.make_new_av(
memory_file=self._memory_file,
sequence=self._sequence,
sequence=self._stream_state.next_sequence(),
input_vstream=self._input_video_stream,
input_astream=self._input_audio_stream,
)
@ -201,12 +238,12 @@ class SegmentBuffer:
# We have our first non-zero byte position. This means the init has just
# been written. Create a Segment and put it to the queue of each output.
self._segment = Segment(
sequence=self._sequence,
stream_id=self._stream_id,
sequence=self._stream_state.sequence,
stream_id=self._stream_state.stream_id,
init=self._memory_file.getvalue(),
# Fetch the latest StreamOutputs, which may have changed since the
# worker started.
stream_outputs=self._outputs_callback().values(),
stream_outputs=self._stream_state.outputs,
start_time=self._start_time,
)
self._memory_file_pos = self._memory_file.tell()
@ -283,17 +320,6 @@ class SegmentBuffer:
self._part_start_dts = adjusted_dts
self._part_has_keyframe = False
def discontinuity(self) -> None:
"""Mark the stream as having been restarted."""
# Preserving sequence and stream_id here keep the HLS playlist logic
# simple to check for discontinuity at output time, and to determine
# the discontinuity sequence number.
self._stream_id += 1
self._start_time = datetime.datetime.utcnow()
# Call discontinuity to remove incomplete segment from the HLS output
if hls_output := self._outputs_callback().get(HLS_PROVIDER):
cast(HlsStreamOutput, hls_output).discontinuity()
def close(self) -> None:
"""Close stream buffer."""
self._av_output.close()
@ -412,7 +438,7 @@ def unsupported_audio(packets: Iterator[av.Packet], audio_stream: Any) -> bool:
def stream_worker(
source: str,
options: dict[str, str],
segment_buffer: SegmentBuffer,
stream_state: StreamState,
quit_event: Event,
) -> None:
"""Handle consuming streams."""
@ -431,6 +457,8 @@ def stream_worker(
audio_stream = container.streams.audio[0]
except (KeyError, IndexError):
audio_stream = None
if audio_stream and audio_stream.name not in AUDIO_CODECS:
audio_stream = None
# These formats need aac_adtstoasc bitstream filter, but auto_bsf not
# compatible with empty_moov and manual bitstream filters not in PyAV
if container.format.name in {"hls", "mpegts"}:
@ -489,13 +517,13 @@ def stream_worker(
"Error demuxing stream while finding first packet: %s" % str(ex)
) from ex
segment_buffer.set_streams(video_stream, audio_stream)
segment_buffer.reset(start_dts)
muxer = StreamMuxer(stream_state.hass, video_stream, audio_stream, stream_state)
muxer.reset(start_dts)
# Mux the first keyframe, then proceed through the rest of the packets
segment_buffer.mux_packet(first_keyframe)
muxer.mux_packet(first_keyframe)
with contextlib.closing(container), contextlib.closing(segment_buffer):
with contextlib.closing(container), contextlib.closing(muxer):
while not quit_event.is_set():
try:
packet = next(container_packets)
@ -506,4 +534,4 @@ def stream_worker(
except av.AVError as ex:
raise StreamWorkerError("Error demuxing stream: %s" % str(ex)) from ex
segment_buffer.mux_packet(packet)
muxer.mux_packet(packet)