Add an available property on Stream (#60429)

This commit is contained in:
Allen Porter 2021-11-29 21:23:58 -08:00 committed by GitHub
parent f0f88d56bd
commit df90fdf641
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 106 additions and 41 deletions

View file

@ -204,6 +204,7 @@ class Stream:
self._thread_quit = threading.Event() self._thread_quit = threading.Event()
self._outputs: dict[str, StreamOutput] = {} self._outputs: dict[str, StreamOutput] = {}
self._fast_restart_once = False self._fast_restart_once = False
self._available = True
def endpoint_url(self, fmt: str) -> str: def endpoint_url(self, fmt: str) -> str:
"""Start the stream and returns a url for the output format.""" """Start the stream and returns a url for the output format."""
@ -254,6 +255,11 @@ class Stream:
if all(p.idle for p in self._outputs.values()): if all(p.idle for p in self._outputs.values()):
self.access_token = None self.access_token = None
@property
def available(self) -> bool:
"""Return False if the stream is started and known to be unavailable."""
return self._available
def start(self) -> None: def start(self) -> None:
"""Start a stream.""" """Start a stream."""
if self._thread is None or not self._thread.is_alive(): if self._thread is None or not self._thread.is_alive():
@ -280,18 +286,25 @@ class Stream:
"""Handle consuming streams and restart keepalive streams.""" """Handle consuming streams and restart keepalive streams."""
# Keep import here so that we can import stream integration without installing reqs # Keep import here so that we can import stream integration without installing reqs
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
from .worker import SegmentBuffer, stream_worker from .worker import SegmentBuffer, StreamWorkerError, stream_worker
segment_buffer = SegmentBuffer(self.hass, self.outputs) segment_buffer = SegmentBuffer(self.hass, self.outputs)
wait_timeout = 0 wait_timeout = 0
while not self._thread_quit.wait(timeout=wait_timeout): while not self._thread_quit.wait(timeout=wait_timeout):
start_time = time.time() start_time = time.time()
self._available = True
try:
stream_worker( stream_worker(
self.source, self.source,
self.options, self.options,
segment_buffer, segment_buffer,
self._thread_quit, self._thread_quit,
) )
except StreamWorkerError as err:
_LOGGER.error("Error from stream worker: %s", str(err))
self._available = False
segment_buffer.discontinuity() segment_buffer.discontinuity()
if not self.keepalive or self._thread_quit.is_set(): if not self.keepalive or self._thread_quit.is_set():
if self._fast_restart_once: if self._fast_restart_once:
@ -300,6 +313,7 @@ class Stream:
self._thread_quit.clear() self._thread_quit.clear()
continue continue
break break
# To avoid excessive restarts, wait before restarting # To avoid excessive restarts, wait before restarting
# As the required recovery time may be different for different setups, start # As the required recovery time may be different for different setups, start
# with trying a short wait_timeout and increase it on each reconnection attempt. # with trying a short wait_timeout and increase it on each reconnection attempt.

View file

@ -3,6 +3,7 @@ from __future__ import annotations
from collections import defaultdict, deque from collections import defaultdict, deque
from collections.abc import Callable, Generator, Iterator, Mapping from collections.abc import Callable, Generator, Iterator, Mapping
import contextlib
import datetime import datetime
from io import BytesIO from io import BytesIO
import logging import logging
@ -31,6 +32,14 @@ from .hls import HlsStreamOutput
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
class StreamWorkerError(Exception):
"""An exception thrown while processing a stream."""
class StreamEndedError(StreamWorkerError):
"""Raised when the stream is complete, exposed for facilitating testing."""
class SegmentBuffer: class SegmentBuffer:
"""Buffer for writing a sequence of packets to the output as a segment.""" """Buffer for writing a sequence of packets to the output as a segment."""
@ -356,7 +365,7 @@ class TimestampValidator:
# Discard packets missing DTS. Terminate if too many are missing. # Discard packets missing DTS. Terminate if too many are missing.
if packet.dts is None: if packet.dts is None:
if self._missing_dts >= MAX_MISSING_DTS: if self._missing_dts >= MAX_MISSING_DTS:
raise StopIteration( raise StreamWorkerError(
f"No dts in {MAX_MISSING_DTS+1} consecutive packets" f"No dts in {MAX_MISSING_DTS+1} consecutive packets"
) )
self._missing_dts += 1 self._missing_dts += 1
@ -367,7 +376,7 @@ class TimestampValidator:
if packet.dts <= prev_dts: if packet.dts <= prev_dts:
gap = packet.time_base * (prev_dts - packet.dts) gap = packet.time_base * (prev_dts - packet.dts)
if gap > MAX_TIMESTAMP_GAP: if gap > MAX_TIMESTAMP_GAP:
raise StopIteration( raise StreamWorkerError(
f"Timestamp overflow detected: last dts = {prev_dts}, dts = {packet.dts}" f"Timestamp overflow detected: last dts = {prev_dts}, dts = {packet.dts}"
) )
return False return False
@ -410,15 +419,14 @@ def stream_worker(
try: try:
container = av.open(source, options=options, timeout=SOURCE_TIMEOUT) container = av.open(source, options=options, timeout=SOURCE_TIMEOUT)
except av.AVError: except av.AVError as err:
_LOGGER.error("Error opening stream %s", redact_credentials(str(source))) raise StreamWorkerError(
return "Error opening stream %s" % redact_credentials(str(source))
) from err
try: try:
video_stream = container.streams.video[0] video_stream = container.streams.video[0]
except (KeyError, IndexError): except (KeyError, IndexError) as ex:
_LOGGER.error("Stream has no video") raise StreamWorkerError("Stream has no video") from ex
container.close()
return
try: try:
audio_stream = container.streams.audio[0] audio_stream = container.streams.audio[0]
except (KeyError, IndexError): except (KeyError, IndexError):
@ -469,10 +477,17 @@ def stream_worker(
# dts. Use "or 1" to deal with this. # dts. Use "or 1" to deal with this.
start_dts = next_video_packet.dts - (next_video_packet.duration or 1) start_dts = next_video_packet.dts - (next_video_packet.duration or 1)
first_keyframe.dts = first_keyframe.pts = start_dts first_keyframe.dts = first_keyframe.pts = start_dts
except (av.AVError, StopIteration) as ex: except StreamWorkerError as ex:
_LOGGER.error("Error demuxing stream while finding first packet: %s", str(ex))
container.close() container.close()
return raise ex
except StopIteration as ex:
container.close()
raise StreamEndedError("Stream ended; no additional packets") from ex
except av.AVError as ex:
container.close()
raise StreamWorkerError(
"Error demuxing stream while finding first packet: %s" % str(ex)
) from ex
segment_buffer.set_streams(video_stream, audio_stream) segment_buffer.set_streams(video_stream, audio_stream)
segment_buffer.reset(start_dts) segment_buffer.reset(start_dts)
@ -480,14 +495,15 @@ def stream_worker(
# Mux the first keyframe, then proceed through the rest of the packets # Mux the first keyframe, then proceed through the rest of the packets
segment_buffer.mux_packet(first_keyframe) segment_buffer.mux_packet(first_keyframe)
with contextlib.closing(container), contextlib.closing(segment_buffer):
while not quit_event.is_set(): while not quit_event.is_set():
try: try:
packet = next(container_packets) packet = next(container_packets)
except (av.AVError, StopIteration) as ex: except StreamWorkerError as ex:
_LOGGER.error("Error demuxing stream: %s", str(ex)) raise ex
break except StopIteration as ex:
segment_buffer.mux_packet(packet) raise StreamEndedError("Stream ended; no additional packets") from ex
except av.AVError as ex:
raise StreamWorkerError("Error demuxing stream: %s" % str(ex)) from ex
# Close stream segment_buffer.mux_packet(packet)
segment_buffer.close()
container.close()

View file

@ -135,6 +135,7 @@ async def test_hls_stream(hass, hls_stream, stream_worker_sync):
# Request stream # Request stream
stream.add_provider(HLS_PROVIDER) stream.add_provider(HLS_PROVIDER)
assert stream.available
stream.start() stream.start()
hls_client = await hls_stream(stream) hls_client = await hls_stream(stream)
@ -161,6 +162,9 @@ async def test_hls_stream(hass, hls_stream, stream_worker_sync):
stream_worker_sync.resume() stream_worker_sync.resume()
# The stream worker reported end of stream and exited
assert not stream.available
# Stop stream, if it hasn't quit already # Stop stream, if it hasn't quit already
stream.stop() stream.stop()
@ -181,6 +185,7 @@ async def test_stream_timeout(hass, hass_client, stream_worker_sync):
# Request stream # Request stream
stream.add_provider(HLS_PROVIDER) stream.add_provider(HLS_PROVIDER)
assert stream.available
stream.start() stream.start()
url = stream.endpoint_url(HLS_PROVIDER) url = stream.endpoint_url(HLS_PROVIDER)
@ -267,6 +272,7 @@ async def test_stream_keepalive(hass):
stream._thread.join() stream._thread.join()
stream._thread = None stream._thread = None
assert av_open.call_count == 2 assert av_open.call_count == 2
assert not stream.available
# Stop stream, if it hasn't quit already # Stop stream, if it hasn't quit already
stream.stop() stream.stop()

View file

@ -37,7 +37,12 @@ from homeassistant.components.stream.const import (
TARGET_SEGMENT_DURATION_NON_LL_HLS, TARGET_SEGMENT_DURATION_NON_LL_HLS,
) )
from homeassistant.components.stream.core import StreamSettings from homeassistant.components.stream.core import StreamSettings
from homeassistant.components.stream.worker import SegmentBuffer, stream_worker from homeassistant.components.stream.worker import (
SegmentBuffer,
StreamEndedError,
StreamWorkerError,
stream_worker,
)
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.components.stream.common import generate_h264_video, generate_h265_video from tests.components.stream.common import generate_h264_video, generate_h265_video
@ -264,7 +269,14 @@ async def async_decode_stream(hass, packets, py_av=None):
side_effect=py_av.capture_buffer.capture_output_segment, side_effect=py_av.capture_buffer.capture_output_segment,
): ):
segment_buffer = SegmentBuffer(hass, stream.outputs) segment_buffer = SegmentBuffer(hass, stream.outputs)
try:
stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event()) stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event())
except StreamEndedError:
# Tests only use a limited number of packets, then the worker exits as expected. In
# production, stream ending would be unexpected.
pass
finally:
# Wait for all packets to be flushed even when exceptions are thrown
await hass.async_block_till_done() await hass.async_block_till_done()
return py_av.capture_buffer return py_av.capture_buffer
@ -274,7 +286,7 @@ async def test_stream_open_fails(hass):
"""Test failure on stream open.""" """Test failure on stream open."""
stream = Stream(hass, STREAM_SOURCE, {}) stream = Stream(hass, STREAM_SOURCE, {})
stream.add_provider(HLS_PROVIDER) stream.add_provider(HLS_PROVIDER)
with patch("av.open") as av_open: with patch("av.open") as av_open, pytest.raises(StreamWorkerError):
av_open.side_effect = av.error.InvalidDataError(-2, "error") av_open.side_effect = av.error.InvalidDataError(-2, "error")
segment_buffer = SegmentBuffer(hass, stream.outputs) segment_buffer = SegmentBuffer(hass, stream.outputs)
stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event()) stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event())
@ -371,7 +383,10 @@ async def test_packet_overflow(hass):
# Packet is so far out of order, exceeds max gap and looks like overflow # Packet is so far out of order, exceeds max gap and looks like overflow
packets[OUT_OF_ORDER_PACKET_INDEX].dts = -9000000 packets[OUT_OF_ORDER_PACKET_INDEX].dts = -9000000
decoded_stream = await async_decode_stream(hass, packets) py_av = MockPyAv()
with pytest.raises(StreamWorkerError, match=r"Timestamp overflow detected"):
await async_decode_stream(hass, packets, py_av=py_av)
decoded_stream = py_av.capture_buffer
segments = decoded_stream.segments segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments complete_segments = decoded_stream.complete_segments
# Check number of segments # Check number of segments
@ -425,7 +440,10 @@ async def test_too_many_initial_bad_packets_fails(hass):
for i in range(0, num_bad_packets): for i in range(0, num_bad_packets):
packets[i].dts = None packets[i].dts = None
decoded_stream = await async_decode_stream(hass, packets) py_av = MockPyAv()
with pytest.raises(StreamWorkerError, match=r"No dts"):
await async_decode_stream(hass, packets, py_av=py_av)
decoded_stream = py_av.capture_buffer
segments = decoded_stream.segments segments = decoded_stream.segments
assert len(segments) == 0 assert len(segments) == 0
assert len(decoded_stream.video_packets) == 0 assert len(decoded_stream.video_packets) == 0
@ -466,7 +484,10 @@ async def test_too_many_bad_packets(hass):
for i in range(bad_packet_start, bad_packet_start + num_bad_packets): for i in range(bad_packet_start, bad_packet_start + num_bad_packets):
packets[i].dts = None packets[i].dts = None
decoded_stream = await async_decode_stream(hass, packets) py_av = MockPyAv()
with pytest.raises(StreamWorkerError, match=r"No dts"):
await async_decode_stream(hass, packets, py_av=py_av)
decoded_stream = py_av.capture_buffer
complete_segments = decoded_stream.complete_segments complete_segments = decoded_stream.complete_segments
assert len(complete_segments) == int((bad_packet_start - 1) * SEGMENTS_PER_PACKET) assert len(complete_segments) == int((bad_packet_start - 1) * SEGMENTS_PER_PACKET)
assert len(decoded_stream.video_packets) == bad_packet_start assert len(decoded_stream.video_packets) == bad_packet_start
@ -477,9 +498,11 @@ async def test_no_video_stream(hass):
"""Test no video stream in the container means no resulting output.""" """Test no video stream in the container means no resulting output."""
py_av = MockPyAv(video=False) py_av = MockPyAv(video=False)
decoded_stream = await async_decode_stream( with pytest.raises(StreamWorkerError, match=r"Stream has no video"):
await async_decode_stream(
hass, PacketSequence(TEST_SEQUENCE_LENGTH), py_av=py_av hass, PacketSequence(TEST_SEQUENCE_LENGTH), py_av=py_av
) )
decoded_stream = py_av.capture_buffer
# Note: This failure scenario does not output an end of stream # Note: This failure scenario does not output an end of stream
segments = decoded_stream.segments segments = decoded_stream.segments
assert len(segments) == 0 assert len(segments) == 0
@ -616,6 +639,9 @@ async def test_stream_stopped_while_decoding(hass):
worker_wake.set() worker_wake.set()
stream.stop() stream.stop()
# Stream is still considered available when the worker was still active and asked to stop
assert stream.available
async def test_update_stream_source(hass): async def test_update_stream_source(hass):
"""Tests that the worker is re-invoked when the stream source is updated.""" """Tests that the worker is re-invoked when the stream source is updated."""
@ -646,6 +672,7 @@ async def test_update_stream_source(hass):
stream.start() stream.start()
assert worker_open.wait(TIMEOUT) assert worker_open.wait(TIMEOUT)
assert last_stream_source == STREAM_SOURCE assert last_stream_source == STREAM_SOURCE
assert stream.available
# Update the stream source, then the test wakes up the worker and assert # Update the stream source, then the test wakes up the worker and assert
# that it re-opens the new stream (the test again waits on thread_started) # that it re-opens the new stream (the test again waits on thread_started)
@ -655,6 +682,7 @@ async def test_update_stream_source(hass):
assert worker_open.wait(TIMEOUT) assert worker_open.wait(TIMEOUT)
assert last_stream_source == STREAM_SOURCE + "-updated-source" assert last_stream_source == STREAM_SOURCE + "-updated-source"
worker_wake.set() worker_wake.set()
assert stream.available
# Cleanup # Cleanup
stream.stop() stream.stop()
@ -664,15 +692,16 @@ async def test_worker_log(hass, caplog):
"""Test that the worker logs the url without username and password.""" """Test that the worker logs the url without username and password."""
stream = Stream(hass, "https://abcd:efgh@foo.bar", {}) stream = Stream(hass, "https://abcd:efgh@foo.bar", {})
stream.add_provider(HLS_PROVIDER) stream.add_provider(HLS_PROVIDER)
with patch("av.open") as av_open:
with patch("av.open") as av_open, pytest.raises(StreamWorkerError) as err:
av_open.side_effect = av.error.InvalidDataError(-2, "error") av_open.side_effect = av.error.InvalidDataError(-2, "error")
segment_buffer = SegmentBuffer(hass, stream.outputs) segment_buffer = SegmentBuffer(hass, stream.outputs)
stream_worker( stream_worker(
"https://abcd:efgh@foo.bar", {}, segment_buffer, threading.Event() "https://abcd:efgh@foo.bar", {}, segment_buffer, threading.Event()
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert str(err.value) == "Error opening stream https://****:****@foo.bar"
assert "https://abcd:efgh@foo.bar" not in caplog.text assert "https://abcd:efgh@foo.bar" not in caplog.text
assert "https://****:****@foo.bar" in caplog.text
async def test_durations(hass, record_worker_sync): async def test_durations(hass, record_worker_sync):