Refactor stream worker responsibilities for segmenting into a separate class (#46563)

* Remove stream_worker dependencies on Stream

Remove stream_worker dependencies on Stream and split out the logic
for writing segments to a stream buffer.

* Stop calling internal stream methods

* Update homeassistant/components/stream/worker.py

Co-authored-by: uvjustin <46082645+uvjustin@users.noreply.github.com>

* Reuse self._outputs when creating new streams

Co-authored-by: uvjustin <46082645+uvjustin@users.noreply.github.com>
This commit is contained in:
Allen Porter 2021-02-15 09:52:37 -08:00 committed by GitHub
parent f2ca4acff0
commit 89aaeb3c35
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 81 additions and 75 deletions

View file

@@ -124,7 +124,6 @@ class Stream:
self.access_token = secrets.token_hex()
return self.hass.data[DOMAIN][ATTR_ENDPOINTS][fmt].format(self.access_token)
@property
def outputs(self):
"""Return a copy of the stream outputs."""
# A copy is returned so the caller can iterate through the outputs
@@ -192,7 +191,7 @@ class Stream:
wait_timeout = 0
while not self._thread_quit.wait(timeout=wait_timeout):
start_time = time.time()
stream_worker(self.hass, self, self._thread_quit)
stream_worker(self.source, self.options, self.outputs, self._thread_quit)
if not self.keepalive or self._thread_quit.is_set():
if self._fast_restart_once:
# The stream source is updated, restart without any delay.
@@ -219,7 +218,7 @@
@callback
def remove_outputs():
for provider in self.outputs.values():
for provider in self.outputs().values():
self.remove_provider(provider)
self.hass.loop.call_soon_threadsafe(remove_outputs)
@@ -248,7 +247,7 @@
raise HomeAssistantError(f"Can't write {video_path}, no access to path!")
# Add recorder
recorder = self.outputs.get("recorder")
recorder = self.outputs().get("recorder")
if recorder:
raise HomeAssistantError(
f"Stream already recording to {recorder.video_path}!"
@@ -259,7 +258,7 @@
self.start()
# Take advantage of lookback
hls = self.outputs.get("hls")
hls = self.outputs().get("hls")
if lookback > 0 and hls:
num_segments = min(int(lookback // hls.target_duration), MAX_SEGMENTS)
# Wait for latest segment, then add the lookback

View file

@@ -43,15 +43,78 @@ def create_stream_buffer(stream_output, video_stream, audio_stream, sequence):
return StreamBuffer(segment, output, vstream, astream)
def stream_worker(hass, stream, quit_event):
class SegmentBuffer:
    """Buffer for writing a sequence of packets to the output as a segment.

    Packets are muxed into a fresh set of StreamBuffers for each segment;
    when a video keyframe arrives past the minimum segment duration, the
    buffered packets are flushed to each StreamOutput as a Segment.
    """

    def __init__(self, video_stream, audio_stream, outputs_callback) -> None:
        """Initialize SegmentBuffer.

        video_stream/audio_stream are the input container streams used to
        classify incoming packets. outputs_callback returns the current
        mapping of StreamOutputs; it is re-invoked on every reset() so
        outputs added after the worker started are picked up.
        """
        self._video_stream = video_stream
        self._audio_stream = audio_stream
        self._outputs_callback = outputs_callback
        # list of (StreamBuffer, StreamOutput) pairs for the current segment
        self._outputs = []
        # Monotonically increasing segment sequence number
        self._sequence = 0
        # Video pts of the first packet in the current segment
        self._segment_start_pts = None

    def reset(self, video_pts):
        """Initialize a new stream segment starting at video_pts."""
        # Keep track of the number of segments we've processed
        self._sequence += 1
        self._segment_start_pts = video_pts

        # Fetch the latest StreamOutputs, which may have changed since the
        # worker started.
        self._outputs = []
        for stream_output in self._outputs_callback().values():
            # Skip outputs that do not support this video codec
            if self._video_stream.name not in stream_output.video_codecs:
                continue
            buffer = create_stream_buffer(
                stream_output, self._video_stream, self._audio_stream, self._sequence
            )
            self._outputs.append((buffer, stream_output))

    def mux_packet(self, packet):
        """Mux a packet to the appropriate StreamBuffers.

        Mutates the packet's stream assignment. May flush the current
        segment and start a new one when a keyframe crosses the minimum
        segment duration.
        """
        # Check for end of segment
        if packet.stream == self._video_stream and packet.is_keyframe:
            duration = (packet.pts - self._segment_start_pts) * packet.time_base
            if duration >= MIN_SEGMENT_DURATION:
                # Save segment to outputs
                self.flush(duration)
                # Reinitialize
                self.reset(packet.pts)

        # Resolve the source stream ONCE before the loop. packet.stream is
        # reassigned below, so comparing it inside the loop would fail for
        # every output after the first and silently drop the packet for them.
        if packet.stream == self._video_stream:
            stream_attr = "vstream"
        elif packet.stream == self._audio_stream:
            stream_attr = "astream"
        else:
            # Packet belongs to a stream we are not muxing (e.g. data/subtitles)
            return

        # Mux the packet into every output buffer
        for (buffer, _) in self._outputs:
            packet.stream = getattr(buffer, stream_attr)
            buffer.output.mux(packet)

    def flush(self, duration):
        """Create a segment from the buffered packets and write to output."""
        for (buffer, stream_output) in self._outputs:
            buffer.output.close()
            stream_output.put(Segment(self._sequence, buffer.segment, duration))

    def close(self):
        """Close all StreamBuffers."""
        for (buffer, _) in self._outputs:
            buffer.output.close()
def stream_worker(source, options, outputs_callback, quit_event):
"""Handle consuming streams."""
try:
container = av.open(
stream.source, options=stream.options, timeout=STREAM_TIMEOUT
)
container = av.open(source, options=options, timeout=STREAM_TIMEOUT)
except av.AVError:
_LOGGER.error("Error opening stream %s", stream.source)
_LOGGER.error("Error opening stream %s", source)
return
try:
video_stream = container.streams.video[0]
@@ -78,9 +141,7 @@ def stream_worker(hass, stream, quit_event):
# Keep track of consecutive packets without a dts to detect end of stream.
missing_dts = 0
# Holds the buffers for each stream provider
outputs = None
# Keep track of the number of segments we've processed
sequence = 0
segment_buffer = SegmentBuffer(video_stream, audio_stream, outputs_callback)
# The video pts at the beginning of the segment
segment_start_pts = None
# Because of problems 1 and 2 below, we need to store the first few packets and replay them
@@ -157,44 +218,11 @@ def stream_worker(hass, stream, quit_event):
return False
return True
def initialize_segment(video_pts):
"""Reset some variables and initialize outputs for each segment."""
nonlocal outputs, sequence, segment_start_pts
# Clear outputs and increment sequence
outputs = {}
sequence += 1
segment_start_pts = video_pts
for stream_output in stream.outputs.values():
if video_stream.name not in stream_output.video_codecs:
continue
buffer = create_stream_buffer(
stream_output, video_stream, audio_stream, sequence
)
outputs[stream_output.name] = (
buffer,
{video_stream: buffer.vstream, audio_stream: buffer.astream},
)
def mux_video_packet(packet):
# mux packets to each buffer
for buffer, output_streams in outputs.values():
# Assign the packet to the new stream & mux
packet.stream = output_streams[video_stream]
buffer.output.mux(packet)
def mux_audio_packet(packet):
# almost the same as muxing video but add extra check
for buffer, output_streams in outputs.values():
# Assign the packet to the new stream & mux
if output_streams.get(audio_stream):
packet.stream = output_streams[audio_stream]
buffer.output.mux(packet)
if not peek_first_pts():
container.close()
return
initialize_segment(segment_start_pts)
segment_buffer.reset(segment_start_pts)
while not quit_event.is_set():
try:
@@ -229,34 +257,13 @@ def stream_worker(hass, stream, quit_event):
break
continue
# Check for end of segment
if packet.stream == video_stream and packet.is_keyframe:
segment_duration = (packet.pts - segment_start_pts) * packet.time_base
if segment_duration >= MIN_SEGMENT_DURATION:
# Save segment to outputs
for fmt, (buffer, _) in outputs.items():
buffer.output.close()
if stream.outputs.get(fmt):
stream.outputs[fmt].put(
Segment(
sequence,
buffer.segment,
segment_duration,
),
)
# Reinitialize
initialize_segment(packet.pts)
# Update last_dts processed
last_dts[packet.stream] = packet.dts
# mux packets
if packet.stream == video_stream:
mux_video_packet(packet) # mutates packet timestamps
else:
mux_audio_packet(packet) # mutates packet timestamps
# Mux packets, and possibly write a segment to the output stream.
# This mutates packet timestamps and stream
segment_buffer.mux_packet(packet)
# Close stream
for buffer, _ in outputs.values():
buffer.output.close()
segment_buffer.close()
container.close()

View file

@@ -198,7 +198,7 @@ async def async_decode_stream(hass, packets, py_av=None):
"homeassistant.components.stream.core.StreamOutput.put",
side_effect=py_av.capture_buffer.capture_output_segment,
):
stream_worker(hass, stream, threading.Event())
stream_worker(STREAM_SOURCE, {}, stream.outputs, threading.Event())
await hass.async_block_till_done()
return py_av.capture_buffer
@@ -210,7 +210,7 @@ async def test_stream_open_fails(hass):
stream.add_provider(STREAM_OUTPUT_FORMAT)
with patch("av.open") as av_open:
av_open.side_effect = av.error.InvalidDataError(-2, "error")
stream_worker(hass, stream, threading.Event())
stream_worker(STREAM_SOURCE, {}, stream.outputs, threading.Event())
await hass.async_block_till_done()
av_open.assert_called_once()