Remove incomplete segment on stream restart (#59532)

This commit is contained in:
uvjustin 2021-11-12 00:59:13 +08:00 committed by GitHub
parent 90ee1f4783
commit 9ea338c121
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 50 additions and 5 deletions

View file

@ -77,6 +77,16 @@ class HlsStreamOutput(StreamOutput):
or self.stream_settings.min_segment_duration or self.stream_settings.min_segment_duration
) )
def discontinuity(self) -> None:
"""Remove incomplete segment from deque."""
self._hass.loop.call_soon_threadsafe(self._async_discontinuity)
@callback
def _async_discontinuity(self) -> None:
"""Remove incomplete segment from deque in event loop."""
if self._segments and not self._segments[-1].complete:
self._segments.pop()
class HlsMasterPlaylistView(StreamView): class HlsMasterPlaylistView(StreamView):
"""Stream view used only for Chromecast compatibility.""" """Stream view used only for Chromecast compatibility."""

View file

@ -18,6 +18,7 @@ from .const import (
ATTR_SETTINGS, ATTR_SETTINGS,
AUDIO_CODECS, AUDIO_CODECS,
DOMAIN, DOMAIN,
HLS_PROVIDER,
MAX_MISSING_DTS, MAX_MISSING_DTS,
MAX_TIMESTAMP_GAP, MAX_TIMESTAMP_GAP,
PACKETS_TO_WAIT_FOR_AUDIO, PACKETS_TO_WAIT_FOR_AUDIO,
@ -25,6 +26,7 @@ from .const import (
SOURCE_TIMEOUT, SOURCE_TIMEOUT,
) )
from .core import Part, Segment, StreamOutput, StreamSettings from .core import Part, Segment, StreamOutput, StreamSettings
from .hls import HlsStreamOutput
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -279,6 +281,9 @@ class SegmentBuffer:
# the discontinuity sequence number. # the discontinuity sequence number.
self._stream_id += 1 self._stream_id += 1
self._start_time = datetime.datetime.utcnow() self._start_time = datetime.datetime.utcnow()
# Call discontinuity to remove incomplete segment from the HLS output
if hls_output := self._outputs_callback().get(HLS_PROVIDER):
cast(HlsStreamOutput, hls_output).discontinuity()
def close(self) -> None: def close(self) -> None:
"""Close stream buffer.""" """Close stream buffer."""

View file

@ -22,8 +22,8 @@ from aiohttp import web
import async_timeout import async_timeout
import pytest import pytest
from homeassistant.components.stream import Stream
from homeassistant.components.stream.core import Segment, StreamOutput from homeassistant.components.stream.core import Segment, StreamOutput
from homeassistant.components.stream.worker import SegmentBuffer
TEST_TIMEOUT = 7.0 # Lower than 9s home assistant timeout TEST_TIMEOUT = 7.0 # Lower than 9s home assistant timeout
@ -34,7 +34,7 @@ class WorkerSync:
def __init__(self): def __init__(self):
"""Initialize WorkerSync.""" """Initialize WorkerSync."""
self._event = None self._event = None
self._original = Stream._worker_finished self._original = SegmentBuffer.discontinuity
def pause(self): def pause(self):
"""Pause the worker before it finalizes the stream.""" """Pause the worker before it finalizes the stream."""
@ -45,7 +45,7 @@ class WorkerSync:
logging.debug("waking blocked worker") logging.debug("waking blocked worker")
self._event.set() self._event.set()
def blocking_finish(self, stream: Stream): def blocking_discontinuity(self, stream: SegmentBuffer):
"""Intercept call to pause stream worker.""" """Intercept call to pause stream worker."""
# Worker is ending the stream, which clears all output buffers. # Worker is ending the stream, which clears all output buffers.
# Block the worker thread until the test has a chance to verify # Block the worker thread until the test has a chance to verify
@ -63,8 +63,8 @@ def stream_worker_sync(hass):
"""Patch StreamOutput to allow test to synchronize worker stream end.""" """Patch StreamOutput to allow test to synchronize worker stream end."""
sync = WorkerSync() sync = WorkerSync()
with patch( with patch(
"homeassistant.components.stream.Stream._worker_finished", "homeassistant.components.stream.worker.SegmentBuffer.discontinuity",
side_effect=sync.blocking_finish, side_effect=sync.blocking_discontinuity,
autospec=True, autospec=True,
): ):
yield sync yield sync

View file

@ -448,3 +448,33 @@ async def test_hls_max_segments_discontinuity(hass, hls_stream, stream_worker_sy
stream_worker_sync.resume() stream_worker_sync.resume()
stream.stop() stream.stop()
async def test_remove_incomplete_segment_on_exit(hass, stream_worker_sync):
"""Test that the incomplete segment gets removed when the worker thread quits."""
await async_setup_component(hass, "stream", {"stream": {}})
stream = create_stream(hass, STREAM_SOURCE, {})
stream_worker_sync.pause()
stream.start()
hls = stream.add_provider(HLS_PROVIDER)
segment = Segment(sequence=0, stream_id=0, duration=SEGMENT_DURATION)
hls.put(segment)
segment = Segment(sequence=1, stream_id=0, duration=SEGMENT_DURATION)
hls.put(segment)
segment = Segment(sequence=2, stream_id=0, duration=0)
hls.put(segment)
await hass.async_block_till_done()
segments = hls._segments
assert len(segments) == 3
assert not segments[-1].complete
stream_worker_sync.resume()
stream._thread_quit.set()
stream._thread.join()
stream._thread = None
await hass.async_block_till_done()
assert segments[-1].complete
assert len(segments) == 2
stream.stop()