Add an available property on Stream (#60429)
This commit is contained in:
parent
f0f88d56bd
commit
df90fdf641
4 changed files with 106 additions and 41 deletions
|
@ -204,6 +204,7 @@ class Stream:
|
||||||
self._thread_quit = threading.Event()
|
self._thread_quit = threading.Event()
|
||||||
self._outputs: dict[str, StreamOutput] = {}
|
self._outputs: dict[str, StreamOutput] = {}
|
||||||
self._fast_restart_once = False
|
self._fast_restart_once = False
|
||||||
|
self._available = True
|
||||||
|
|
||||||
def endpoint_url(self, fmt: str) -> str:
|
def endpoint_url(self, fmt: str) -> str:
|
||||||
"""Start the stream and returns a url for the output format."""
|
"""Start the stream and returns a url for the output format."""
|
||||||
|
@ -254,6 +255,11 @@ class Stream:
|
||||||
if all(p.idle for p in self._outputs.values()):
|
if all(p.idle for p in self._outputs.values()):
|
||||||
self.access_token = None
|
self.access_token = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def available(self) -> bool:
|
||||||
|
"""Return False if the stream is started and known to be unavailable."""
|
||||||
|
return self._available
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self) -> None:
|
||||||
"""Start a stream."""
|
"""Start a stream."""
|
||||||
if self._thread is None or not self._thread.is_alive():
|
if self._thread is None or not self._thread.is_alive():
|
||||||
|
@ -280,18 +286,25 @@ class Stream:
|
||||||
"""Handle consuming streams and restart keepalive streams."""
|
"""Handle consuming streams and restart keepalive streams."""
|
||||||
# Keep import here so that we can import stream integration without installing reqs
|
# Keep import here so that we can import stream integration without installing reqs
|
||||||
# pylint: disable=import-outside-toplevel
|
# pylint: disable=import-outside-toplevel
|
||||||
from .worker import SegmentBuffer, stream_worker
|
from .worker import SegmentBuffer, StreamWorkerError, stream_worker
|
||||||
|
|
||||||
segment_buffer = SegmentBuffer(self.hass, self.outputs)
|
segment_buffer = SegmentBuffer(self.hass, self.outputs)
|
||||||
wait_timeout = 0
|
wait_timeout = 0
|
||||||
while not self._thread_quit.wait(timeout=wait_timeout):
|
while not self._thread_quit.wait(timeout=wait_timeout):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
self._available = True
|
||||||
|
try:
|
||||||
stream_worker(
|
stream_worker(
|
||||||
self.source,
|
self.source,
|
||||||
self.options,
|
self.options,
|
||||||
segment_buffer,
|
segment_buffer,
|
||||||
self._thread_quit,
|
self._thread_quit,
|
||||||
)
|
)
|
||||||
|
except StreamWorkerError as err:
|
||||||
|
_LOGGER.error("Error from stream worker: %s", str(err))
|
||||||
|
self._available = False
|
||||||
|
|
||||||
segment_buffer.discontinuity()
|
segment_buffer.discontinuity()
|
||||||
if not self.keepalive or self._thread_quit.is_set():
|
if not self.keepalive or self._thread_quit.is_set():
|
||||||
if self._fast_restart_once:
|
if self._fast_restart_once:
|
||||||
|
@ -300,6 +313,7 @@ class Stream:
|
||||||
self._thread_quit.clear()
|
self._thread_quit.clear()
|
||||||
continue
|
continue
|
||||||
break
|
break
|
||||||
|
|
||||||
# To avoid excessive restarts, wait before restarting
|
# To avoid excessive restarts, wait before restarting
|
||||||
# As the required recovery time may be different for different setups, start
|
# As the required recovery time may be different for different setups, start
|
||||||
# with trying a short wait_timeout and increase it on each reconnection attempt.
|
# with trying a short wait_timeout and increase it on each reconnection attempt.
|
||||||
|
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
from collections.abc import Callable, Generator, Iterator, Mapping
|
from collections.abc import Callable, Generator, Iterator, Mapping
|
||||||
|
import contextlib
|
||||||
import datetime
|
import datetime
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import logging
|
import logging
|
||||||
|
@ -31,6 +32,14 @@ from .hls import HlsStreamOutput
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamWorkerError(Exception):
|
||||||
|
"""An exception thrown while processing a stream."""
|
||||||
|
|
||||||
|
|
||||||
|
class StreamEndedError(StreamWorkerError):
|
||||||
|
"""Raised when the stream is complete, exposed for facilitating testing."""
|
||||||
|
|
||||||
|
|
||||||
class SegmentBuffer:
|
class SegmentBuffer:
|
||||||
"""Buffer for writing a sequence of packets to the output as a segment."""
|
"""Buffer for writing a sequence of packets to the output as a segment."""
|
||||||
|
|
||||||
|
@ -356,7 +365,7 @@ class TimestampValidator:
|
||||||
# Discard packets missing DTS. Terminate if too many are missing.
|
# Discard packets missing DTS. Terminate if too many are missing.
|
||||||
if packet.dts is None:
|
if packet.dts is None:
|
||||||
if self._missing_dts >= MAX_MISSING_DTS:
|
if self._missing_dts >= MAX_MISSING_DTS:
|
||||||
raise StopIteration(
|
raise StreamWorkerError(
|
||||||
f"No dts in {MAX_MISSING_DTS+1} consecutive packets"
|
f"No dts in {MAX_MISSING_DTS+1} consecutive packets"
|
||||||
)
|
)
|
||||||
self._missing_dts += 1
|
self._missing_dts += 1
|
||||||
|
@ -367,7 +376,7 @@ class TimestampValidator:
|
||||||
if packet.dts <= prev_dts:
|
if packet.dts <= prev_dts:
|
||||||
gap = packet.time_base * (prev_dts - packet.dts)
|
gap = packet.time_base * (prev_dts - packet.dts)
|
||||||
if gap > MAX_TIMESTAMP_GAP:
|
if gap > MAX_TIMESTAMP_GAP:
|
||||||
raise StopIteration(
|
raise StreamWorkerError(
|
||||||
f"Timestamp overflow detected: last dts = {prev_dts}, dts = {packet.dts}"
|
f"Timestamp overflow detected: last dts = {prev_dts}, dts = {packet.dts}"
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
|
@ -410,15 +419,14 @@ def stream_worker(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
container = av.open(source, options=options, timeout=SOURCE_TIMEOUT)
|
container = av.open(source, options=options, timeout=SOURCE_TIMEOUT)
|
||||||
except av.AVError:
|
except av.AVError as err:
|
||||||
_LOGGER.error("Error opening stream %s", redact_credentials(str(source)))
|
raise StreamWorkerError(
|
||||||
return
|
"Error opening stream %s" % redact_credentials(str(source))
|
||||||
|
) from err
|
||||||
try:
|
try:
|
||||||
video_stream = container.streams.video[0]
|
video_stream = container.streams.video[0]
|
||||||
except (KeyError, IndexError):
|
except (KeyError, IndexError) as ex:
|
||||||
_LOGGER.error("Stream has no video")
|
raise StreamWorkerError("Stream has no video") from ex
|
||||||
container.close()
|
|
||||||
return
|
|
||||||
try:
|
try:
|
||||||
audio_stream = container.streams.audio[0]
|
audio_stream = container.streams.audio[0]
|
||||||
except (KeyError, IndexError):
|
except (KeyError, IndexError):
|
||||||
|
@ -469,10 +477,17 @@ def stream_worker(
|
||||||
# dts. Use "or 1" to deal with this.
|
# dts. Use "or 1" to deal with this.
|
||||||
start_dts = next_video_packet.dts - (next_video_packet.duration or 1)
|
start_dts = next_video_packet.dts - (next_video_packet.duration or 1)
|
||||||
first_keyframe.dts = first_keyframe.pts = start_dts
|
first_keyframe.dts = first_keyframe.pts = start_dts
|
||||||
except (av.AVError, StopIteration) as ex:
|
except StreamWorkerError as ex:
|
||||||
_LOGGER.error("Error demuxing stream while finding first packet: %s", str(ex))
|
|
||||||
container.close()
|
container.close()
|
||||||
return
|
raise ex
|
||||||
|
except StopIteration as ex:
|
||||||
|
container.close()
|
||||||
|
raise StreamEndedError("Stream ended; no additional packets") from ex
|
||||||
|
except av.AVError as ex:
|
||||||
|
container.close()
|
||||||
|
raise StreamWorkerError(
|
||||||
|
"Error demuxing stream while finding first packet: %s" % str(ex)
|
||||||
|
) from ex
|
||||||
|
|
||||||
segment_buffer.set_streams(video_stream, audio_stream)
|
segment_buffer.set_streams(video_stream, audio_stream)
|
||||||
segment_buffer.reset(start_dts)
|
segment_buffer.reset(start_dts)
|
||||||
|
@ -480,14 +495,15 @@ def stream_worker(
|
||||||
# Mux the first keyframe, then proceed through the rest of the packets
|
# Mux the first keyframe, then proceed through the rest of the packets
|
||||||
segment_buffer.mux_packet(first_keyframe)
|
segment_buffer.mux_packet(first_keyframe)
|
||||||
|
|
||||||
|
with contextlib.closing(container), contextlib.closing(segment_buffer):
|
||||||
while not quit_event.is_set():
|
while not quit_event.is_set():
|
||||||
try:
|
try:
|
||||||
packet = next(container_packets)
|
packet = next(container_packets)
|
||||||
except (av.AVError, StopIteration) as ex:
|
except StreamWorkerError as ex:
|
||||||
_LOGGER.error("Error demuxing stream: %s", str(ex))
|
raise ex
|
||||||
break
|
except StopIteration as ex:
|
||||||
segment_buffer.mux_packet(packet)
|
raise StreamEndedError("Stream ended; no additional packets") from ex
|
||||||
|
except av.AVError as ex:
|
||||||
|
raise StreamWorkerError("Error demuxing stream: %s" % str(ex)) from ex
|
||||||
|
|
||||||
# Close stream
|
segment_buffer.mux_packet(packet)
|
||||||
segment_buffer.close()
|
|
||||||
container.close()
|
|
||||||
|
|
|
@ -135,6 +135,7 @@ async def test_hls_stream(hass, hls_stream, stream_worker_sync):
|
||||||
|
|
||||||
# Request stream
|
# Request stream
|
||||||
stream.add_provider(HLS_PROVIDER)
|
stream.add_provider(HLS_PROVIDER)
|
||||||
|
assert stream.available
|
||||||
stream.start()
|
stream.start()
|
||||||
|
|
||||||
hls_client = await hls_stream(stream)
|
hls_client = await hls_stream(stream)
|
||||||
|
@ -161,6 +162,9 @@ async def test_hls_stream(hass, hls_stream, stream_worker_sync):
|
||||||
|
|
||||||
stream_worker_sync.resume()
|
stream_worker_sync.resume()
|
||||||
|
|
||||||
|
# The stream worker reported end of stream and exited
|
||||||
|
assert not stream.available
|
||||||
|
|
||||||
# Stop stream, if it hasn't quit already
|
# Stop stream, if it hasn't quit already
|
||||||
stream.stop()
|
stream.stop()
|
||||||
|
|
||||||
|
@ -181,6 +185,7 @@ async def test_stream_timeout(hass, hass_client, stream_worker_sync):
|
||||||
|
|
||||||
# Request stream
|
# Request stream
|
||||||
stream.add_provider(HLS_PROVIDER)
|
stream.add_provider(HLS_PROVIDER)
|
||||||
|
assert stream.available
|
||||||
stream.start()
|
stream.start()
|
||||||
url = stream.endpoint_url(HLS_PROVIDER)
|
url = stream.endpoint_url(HLS_PROVIDER)
|
||||||
|
|
||||||
|
@ -267,6 +272,7 @@ async def test_stream_keepalive(hass):
|
||||||
stream._thread.join()
|
stream._thread.join()
|
||||||
stream._thread = None
|
stream._thread = None
|
||||||
assert av_open.call_count == 2
|
assert av_open.call_count == 2
|
||||||
|
assert not stream.available
|
||||||
|
|
||||||
# Stop stream, if it hasn't quit already
|
# Stop stream, if it hasn't quit already
|
||||||
stream.stop()
|
stream.stop()
|
||||||
|
|
|
@ -37,7 +37,12 @@ from homeassistant.components.stream.const import (
|
||||||
TARGET_SEGMENT_DURATION_NON_LL_HLS,
|
TARGET_SEGMENT_DURATION_NON_LL_HLS,
|
||||||
)
|
)
|
||||||
from homeassistant.components.stream.core import StreamSettings
|
from homeassistant.components.stream.core import StreamSettings
|
||||||
from homeassistant.components.stream.worker import SegmentBuffer, stream_worker
|
from homeassistant.components.stream.worker import (
|
||||||
|
SegmentBuffer,
|
||||||
|
StreamEndedError,
|
||||||
|
StreamWorkerError,
|
||||||
|
stream_worker,
|
||||||
|
)
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.components.stream.common import generate_h264_video, generate_h265_video
|
from tests.components.stream.common import generate_h264_video, generate_h265_video
|
||||||
|
@ -264,7 +269,14 @@ async def async_decode_stream(hass, packets, py_av=None):
|
||||||
side_effect=py_av.capture_buffer.capture_output_segment,
|
side_effect=py_av.capture_buffer.capture_output_segment,
|
||||||
):
|
):
|
||||||
segment_buffer = SegmentBuffer(hass, stream.outputs)
|
segment_buffer = SegmentBuffer(hass, stream.outputs)
|
||||||
|
try:
|
||||||
stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event())
|
stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event())
|
||||||
|
except StreamEndedError:
|
||||||
|
# Tests only use a limited number of packets, then the worker exits as expected. In
|
||||||
|
# production, stream ending would be unexpected.
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
# Wait for all packets to be flushed even when exceptions are thrown
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
return py_av.capture_buffer
|
return py_av.capture_buffer
|
||||||
|
@ -274,7 +286,7 @@ async def test_stream_open_fails(hass):
|
||||||
"""Test failure on stream open."""
|
"""Test failure on stream open."""
|
||||||
stream = Stream(hass, STREAM_SOURCE, {})
|
stream = Stream(hass, STREAM_SOURCE, {})
|
||||||
stream.add_provider(HLS_PROVIDER)
|
stream.add_provider(HLS_PROVIDER)
|
||||||
with patch("av.open") as av_open:
|
with patch("av.open") as av_open, pytest.raises(StreamWorkerError):
|
||||||
av_open.side_effect = av.error.InvalidDataError(-2, "error")
|
av_open.side_effect = av.error.InvalidDataError(-2, "error")
|
||||||
segment_buffer = SegmentBuffer(hass, stream.outputs)
|
segment_buffer = SegmentBuffer(hass, stream.outputs)
|
||||||
stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event())
|
stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event())
|
||||||
|
@ -371,7 +383,10 @@ async def test_packet_overflow(hass):
|
||||||
# Packet is so far out of order, exceeds max gap and looks like overflow
|
# Packet is so far out of order, exceeds max gap and looks like overflow
|
||||||
packets[OUT_OF_ORDER_PACKET_INDEX].dts = -9000000
|
packets[OUT_OF_ORDER_PACKET_INDEX].dts = -9000000
|
||||||
|
|
||||||
decoded_stream = await async_decode_stream(hass, packets)
|
py_av = MockPyAv()
|
||||||
|
with pytest.raises(StreamWorkerError, match=r"Timestamp overflow detected"):
|
||||||
|
await async_decode_stream(hass, packets, py_av=py_av)
|
||||||
|
decoded_stream = py_av.capture_buffer
|
||||||
segments = decoded_stream.segments
|
segments = decoded_stream.segments
|
||||||
complete_segments = decoded_stream.complete_segments
|
complete_segments = decoded_stream.complete_segments
|
||||||
# Check number of segments
|
# Check number of segments
|
||||||
|
@ -425,7 +440,10 @@ async def test_too_many_initial_bad_packets_fails(hass):
|
||||||
for i in range(0, num_bad_packets):
|
for i in range(0, num_bad_packets):
|
||||||
packets[i].dts = None
|
packets[i].dts = None
|
||||||
|
|
||||||
decoded_stream = await async_decode_stream(hass, packets)
|
py_av = MockPyAv()
|
||||||
|
with pytest.raises(StreamWorkerError, match=r"No dts"):
|
||||||
|
await async_decode_stream(hass, packets, py_av=py_av)
|
||||||
|
decoded_stream = py_av.capture_buffer
|
||||||
segments = decoded_stream.segments
|
segments = decoded_stream.segments
|
||||||
assert len(segments) == 0
|
assert len(segments) == 0
|
||||||
assert len(decoded_stream.video_packets) == 0
|
assert len(decoded_stream.video_packets) == 0
|
||||||
|
@ -466,7 +484,10 @@ async def test_too_many_bad_packets(hass):
|
||||||
for i in range(bad_packet_start, bad_packet_start + num_bad_packets):
|
for i in range(bad_packet_start, bad_packet_start + num_bad_packets):
|
||||||
packets[i].dts = None
|
packets[i].dts = None
|
||||||
|
|
||||||
decoded_stream = await async_decode_stream(hass, packets)
|
py_av = MockPyAv()
|
||||||
|
with pytest.raises(StreamWorkerError, match=r"No dts"):
|
||||||
|
await async_decode_stream(hass, packets, py_av=py_av)
|
||||||
|
decoded_stream = py_av.capture_buffer
|
||||||
complete_segments = decoded_stream.complete_segments
|
complete_segments = decoded_stream.complete_segments
|
||||||
assert len(complete_segments) == int((bad_packet_start - 1) * SEGMENTS_PER_PACKET)
|
assert len(complete_segments) == int((bad_packet_start - 1) * SEGMENTS_PER_PACKET)
|
||||||
assert len(decoded_stream.video_packets) == bad_packet_start
|
assert len(decoded_stream.video_packets) == bad_packet_start
|
||||||
|
@ -477,9 +498,11 @@ async def test_no_video_stream(hass):
|
||||||
"""Test no video stream in the container means no resulting output."""
|
"""Test no video stream in the container means no resulting output."""
|
||||||
py_av = MockPyAv(video=False)
|
py_av = MockPyAv(video=False)
|
||||||
|
|
||||||
decoded_stream = await async_decode_stream(
|
with pytest.raises(StreamWorkerError, match=r"Stream has no video"):
|
||||||
|
await async_decode_stream(
|
||||||
hass, PacketSequence(TEST_SEQUENCE_LENGTH), py_av=py_av
|
hass, PacketSequence(TEST_SEQUENCE_LENGTH), py_av=py_av
|
||||||
)
|
)
|
||||||
|
decoded_stream = py_av.capture_buffer
|
||||||
# Note: This failure scenario does not output an end of stream
|
# Note: This failure scenario does not output an end of stream
|
||||||
segments = decoded_stream.segments
|
segments = decoded_stream.segments
|
||||||
assert len(segments) == 0
|
assert len(segments) == 0
|
||||||
|
@ -616,6 +639,9 @@ async def test_stream_stopped_while_decoding(hass):
|
||||||
worker_wake.set()
|
worker_wake.set()
|
||||||
stream.stop()
|
stream.stop()
|
||||||
|
|
||||||
|
# Stream is still considered available when the worker was still active and asked to stop
|
||||||
|
assert stream.available
|
||||||
|
|
||||||
|
|
||||||
async def test_update_stream_source(hass):
|
async def test_update_stream_source(hass):
|
||||||
"""Tests that the worker is re-invoked when the stream source is updated."""
|
"""Tests that the worker is re-invoked when the stream source is updated."""
|
||||||
|
@ -646,6 +672,7 @@ async def test_update_stream_source(hass):
|
||||||
stream.start()
|
stream.start()
|
||||||
assert worker_open.wait(TIMEOUT)
|
assert worker_open.wait(TIMEOUT)
|
||||||
assert last_stream_source == STREAM_SOURCE
|
assert last_stream_source == STREAM_SOURCE
|
||||||
|
assert stream.available
|
||||||
|
|
||||||
# Update the stream source, then the test wakes up the worker and assert
|
# Update the stream source, then the test wakes up the worker and assert
|
||||||
# that it re-opens the new stream (the test again waits on thread_started)
|
# that it re-opens the new stream (the test again waits on thread_started)
|
||||||
|
@ -655,6 +682,7 @@ async def test_update_stream_source(hass):
|
||||||
assert worker_open.wait(TIMEOUT)
|
assert worker_open.wait(TIMEOUT)
|
||||||
assert last_stream_source == STREAM_SOURCE + "-updated-source"
|
assert last_stream_source == STREAM_SOURCE + "-updated-source"
|
||||||
worker_wake.set()
|
worker_wake.set()
|
||||||
|
assert stream.available
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
stream.stop()
|
stream.stop()
|
||||||
|
@ -664,15 +692,16 @@ async def test_worker_log(hass, caplog):
|
||||||
"""Test that the worker logs the url without username and password."""
|
"""Test that the worker logs the url without username and password."""
|
||||||
stream = Stream(hass, "https://abcd:efgh@foo.bar", {})
|
stream = Stream(hass, "https://abcd:efgh@foo.bar", {})
|
||||||
stream.add_provider(HLS_PROVIDER)
|
stream.add_provider(HLS_PROVIDER)
|
||||||
with patch("av.open") as av_open:
|
|
||||||
|
with patch("av.open") as av_open, pytest.raises(StreamWorkerError) as err:
|
||||||
av_open.side_effect = av.error.InvalidDataError(-2, "error")
|
av_open.side_effect = av.error.InvalidDataError(-2, "error")
|
||||||
segment_buffer = SegmentBuffer(hass, stream.outputs)
|
segment_buffer = SegmentBuffer(hass, stream.outputs)
|
||||||
stream_worker(
|
stream_worker(
|
||||||
"https://abcd:efgh@foo.bar", {}, segment_buffer, threading.Event()
|
"https://abcd:efgh@foo.bar", {}, segment_buffer, threading.Event()
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
assert str(err.value) == "Error opening stream https://****:****@foo.bar"
|
||||||
assert "https://abcd:efgh@foo.bar" not in caplog.text
|
assert "https://abcd:efgh@foo.bar" not in caplog.text
|
||||||
assert "https://****:****@foo.bar" in caplog.text
|
|
||||||
|
|
||||||
|
|
||||||
async def test_durations(hass, record_worker_sync):
|
async def test_durations(hass, record_worker_sync):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue