diff --git a/homeassistant/components/stream/__init__.py b/homeassistant/components/stream/__init__.py index 7f88885ac0b..6c3f0104ad0 100644 --- a/homeassistant/components/stream/__init__.py +++ b/homeassistant/components/stream/__init__.py @@ -170,7 +170,7 @@ class Stream: def update_source(self, new_source): """Restart the stream with a new stream source.""" - _LOGGER.debug("Updating stream source %s", self.source) + _LOGGER.debug("Updating stream source %s", new_source) self.source = new_source self._fast_restart_once = True self._thread_quit.set() @@ -179,12 +179,14 @@ class Stream: """Handle consuming streams and restart keepalive streams.""" # Keep import here so that we can import stream integration without installing reqs # pylint: disable=import-outside-toplevel - from .worker import stream_worker + from .worker import SegmentBuffer, stream_worker + segment_buffer = SegmentBuffer(self.outputs) wait_timeout = 0 while not self._thread_quit.wait(timeout=wait_timeout): start_time = time.time() - stream_worker(self.source, self.options, self.outputs, self._thread_quit) + stream_worker(self.source, self.options, segment_buffer, self._thread_quit) + segment_buffer.discontinuity() if not self.keepalive or self._thread_quit.is_set(): if self._fast_restart_once: # The stream source is updated, restart without any delay. diff --git a/homeassistant/components/stream/core.py b/homeassistant/components/stream/core.py index f7beb3aa754..7a46de547d7 100644 --- a/homeassistant/components/stream/core.py +++ b/homeassistant/components/stream/core.py @@ -30,6 +30,8 @@ class Segment: sequence: int = attr.ib() segment: io.BytesIO = attr.ib() duration: float = attr.ib() + # For detecting discontinuities across stream restarts + stream_id: int = attr.ib(default=0) class IdleTimer: diff --git a/homeassistant/components/stream/hls.py b/homeassistant/components/stream/hls.py index 57894d17711..85102d208e7 100644 --- a/homeassistant/components/stream/hls.py +++ b/homeassistant/components/stream/hls.py @@ -78,21 +78,27 @@ class HlsPlaylistView(StreamView): @staticmethod def render_playlist(track): """Render playlist.""" - segments = track.segments[-NUM_PLAYLIST_SEGMENTS:] + segments = list(track.get_segment())[-NUM_PLAYLIST_SEGMENTS:] if not segments: return [] - playlist = ["#EXT-X-MEDIA-SEQUENCE:{}".format(segments[0])] + playlist = [ + "#EXT-X-MEDIA-SEQUENCE:{}".format(segments[0].sequence), + "#EXT-X-DISCONTINUITY-SEQUENCE:{}".format(segments[0].stream_id), + ] - for sequence in segments: - segment = track.get_segment(sequence) + last_stream_id = segments[0].stream_id + for segment in segments: + if last_stream_id != segment.stream_id: + playlist.append("#EXT-X-DISCONTINUITY") playlist.extend( [ "#EXTINF:{:.04f},".format(float(segment.duration)), f"./segment/{segment.sequence}.m4s", ] ) + last_stream_id = segment.stream_id return playlist diff --git a/homeassistant/components/stream/worker.py b/homeassistant/components/stream/worker.py index 41cb4bafd90..2592a74584e 100644 --- a/homeassistant/components/stream/worker.py +++ b/homeassistant/components/stream/worker.py @@ -49,16 +49,22 @@ def create_stream_buffer(stream_output, video_stream, audio_stream, sequence): class SegmentBuffer: """Buffer for writing a sequence of packets to the output as a segment.""" - def __init__(self, video_stream, audio_stream, outputs_callback) -> None: + def __init__(self, outputs_callback) -> None: """Initialize SegmentBuffer.""" - self._video_stream = video_stream - self._audio_stream = audio_stream + self._stream_id = 0 + self._video_stream = None + self._audio_stream = None self._outputs_callback = outputs_callback # tuple of StreamOutput, StreamBuffer self._outputs = [] self._sequence = 0 self._segment_start_pts = None + def set_streams(self, video_stream, audio_stream): + """Initialize output buffer with streams from container.""" + self._video_stream = video_stream + self._audio_stream = audio_stream + def reset(self, video_pts): """Initialize a new stream segment.""" # Keep track of the number of segments we've processed @@ -103,7 +109,16 @@ class SegmentBuffer: """Create a segment from the buffered packets and write to output.""" for (buffer, stream_output) in self._outputs: buffer.output.close() - stream_output.put(Segment(self._sequence, buffer.segment, duration)) + stream_output.put( + Segment(self._sequence, buffer.segment, duration, self._stream_id) + ) + + def discontinuity(self): + """Mark the stream as having been restarted.""" + # Preserving sequence and stream_id here keep the HLS playlist logic + # simple to check for discontinuity at output time, and to determine + # the discontinuity sequence number. + self._stream_id += 1 def close(self): """Close all StreamBuffers.""" @@ -111,7 +126,7 @@ class SegmentBuffer: buffer.output.close() -def stream_worker(source, options, outputs_callback, quit_event): +def stream_worker(source, options, segment_buffer, quit_event): """Handle consuming streams.""" try: @@ -143,8 +158,6 @@ def stream_worker(source, options, outputs_callback, quit_event): last_dts = {video_stream: float("-inf"), audio_stream: float("-inf")} # Keep track of consecutive packets without a dts to detect end of stream. missing_dts = 0 - # Holds the buffers for each stream provider - segment_buffer = SegmentBuffer(video_stream, audio_stream, outputs_callback) # The video pts at the beginning of the segment segment_start_pts = None # Because of problems 1 and 2 below, we need to store the first few packets and replay them @@ -225,6 +238,7 @@ def stream_worker(source, options, outputs_callback, quit_event): container.close() return + segment_buffer.set_streams(video_stream, audio_stream) segment_buffer.reset(segment_start_pts) while not quit_event.is_set(): diff --git a/tests/components/stream/test_hls.py b/tests/components/stream/test_hls.py index 55b79684b7b..ffe32d13c61 100644 --- a/tests/components/stream/test_hls.py +++ b/tests/components/stream/test_hls.py @@ -51,7 +51,16 @@ def hls_stream(hass, hass_client): return create_client_for_stream -def playlist_response(sequence, segments): +def make_segment(segment, discontinuity=False): + """Create a playlist response for a segment.""" + response = [] + if discontinuity: + response.append("#EXT-X-DISCONTINUITY") + response.extend(["#EXTINF:10.0000,", f"./segment/{segment}.m4s"]), + return "\n".join(response) + + +def make_playlist(sequence, discontinuity_sequence=0, segments=[]): """Create a an hls playlist response for tests to assert on.""" response = [ "#EXTM3U", @@ -59,14 +68,9 @@ def playlist_response(sequence, segments): "#EXT-X-TARGETDURATION:10", '#EXT-X-MAP:URI="init.mp4"', f"#EXT-X-MEDIA-SEQUENCE:{sequence}", + f"#EXT-X-DISCONTINUITY-SEQUENCE:{discontinuity_sequence}", ] - for segment in segments: - response.extend( - [ - "#EXTINF:10.0000,", - f"./segment/{segment}.m4s", - ] - ) + response.extend(segments) response.append("") return "\n".join(response) @@ -289,13 +293,15 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync): resp = await hls_client.get("/playlist.m3u8") assert resp.status == 200 - assert await resp.text() == playlist_response(sequence=1, segments=[1]) + assert await resp.text() == make_playlist(sequence=1, segments=[make_segment(1)]) hls.put(Segment(2, SEQUENCE_BYTES, DURATION)) await hass.async_block_till_done() resp = await hls_client.get("/playlist.m3u8") assert resp.status == 200 - assert await resp.text() == playlist_response(sequence=1, segments=[1, 2]) + assert await resp.text() == make_playlist( + sequence=1, segments=[make_segment(1), make_segment(2)] + ) stream_worker_sync.resume() stream.stop() @@ -321,8 +327,12 @@ async def test_hls_max_segments(hass, hls_stream, stream_worker_sync): # Only NUM_PLAYLIST_SEGMENTS are returned in the playlist. start = MAX_SEGMENTS + 2 - NUM_PLAYLIST_SEGMENTS - assert await resp.text() == playlist_response( - sequence=start, segments=range(start, MAX_SEGMENTS + 2) + segments = [] + for sequence in range(start, MAX_SEGMENTS + 2): + segments.append(make_segment(sequence)) + assert await resp.text() == make_playlist( + sequence=start, + segments=segments, ) # Fetch the actual segments with a fake byte payload @@ -340,3 +350,70 @@ async def test_hls_max_segments(hass, hls_stream, stream_worker_sync): stream_worker_sync.resume() stream.stop() + + +async def test_hls_playlist_view_discontinuity(hass, hls_stream, stream_worker_sync): + """Test a discontinuity across segments in the stream with 3 segments.""" + await async_setup_component(hass, "stream", {"stream": {}}) + + stream = create_stream(hass, STREAM_SOURCE) + stream_worker_sync.pause() + hls = stream.hls_output() + + hls.put(Segment(1, SEQUENCE_BYTES, DURATION, stream_id=0)) + hls.put(Segment(2, SEQUENCE_BYTES, DURATION, stream_id=0)) + hls.put(Segment(3, SEQUENCE_BYTES, DURATION, stream_id=1)) + await hass.async_block_till_done() + + hls_client = await hls_stream(stream) + + resp = await hls_client.get("/playlist.m3u8") + assert resp.status == 200 + assert await resp.text() == make_playlist( + sequence=1, + segments=[ + make_segment(1), + make_segment(2), + make_segment(3, discontinuity=True), + ], + ) + + stream_worker_sync.resume() + stream.stop() + + +async def test_hls_max_segments_discontinuity(hass, hls_stream, stream_worker_sync): + """Test a discontinuity with more segments than the segment deque can hold.""" + await async_setup_component(hass, "stream", {"stream": {}}) + + stream = create_stream(hass, STREAM_SOURCE) + stream_worker_sync.pause() + hls = stream.hls_output() + + hls_client = await hls_stream(stream) + + hls.put(Segment(1, SEQUENCE_BYTES, DURATION, stream_id=0)) + + # Produce enough segments to overfill the output buffer by one + for sequence in range(1, MAX_SEGMENTS + 2): + hls.put(Segment(sequence, SEQUENCE_BYTES, DURATION, stream_id=1)) + await hass.async_block_till_done() + + resp = await hls_client.get("/playlist.m3u8") + assert resp.status == 200 + + # Only NUM_PLAYLIST_SEGMENTS are returned in the playlist causing the + # EXT-X-DISCONTINUITY tag to be omitted and EXT-X-DISCONTINUITY-SEQUENCE + # returned instead. + start = MAX_SEGMENTS + 2 - NUM_PLAYLIST_SEGMENTS + segments = [] + for sequence in range(start, MAX_SEGMENTS + 2): + segments.append(make_segment(sequence)) + assert await resp.text() == make_playlist( + sequence=start, + discontinuity_sequence=1, + segments=segments, + ) + + stream_worker_sync.resume() + stream.stop() diff --git a/tests/components/stream/test_worker.py b/tests/components/stream/test_worker.py index d9006c81ad5..f7952b7db44 100644 --- a/tests/components/stream/test_worker.py +++ b/tests/components/stream/test_worker.py @@ -27,7 +27,7 @@ from homeassistant.components.stream.const import ( MIN_SEGMENT_DURATION, PACKETS_TO_WAIT_FOR_AUDIO, ) -from homeassistant.components.stream.worker import stream_worker +from homeassistant.components.stream.worker import SegmentBuffer, stream_worker STREAM_SOURCE = "some-stream-source" # Formats here are arbitrary, not exercised by tests @@ -197,7 +197,8 @@ async def async_decode_stream(hass, packets, py_av=None): "homeassistant.components.stream.core.StreamOutput.put", side_effect=py_av.capture_buffer.capture_output_segment, ): - stream_worker(STREAM_SOURCE, {}, stream.outputs, threading.Event()) + segment_buffer = SegmentBuffer(stream.outputs) + stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event()) await hass.async_block_till_done() return py_av.capture_buffer @@ -209,7 +210,8 @@ async def test_stream_open_fails(hass): stream.hls_output() with patch("av.open") as av_open: av_open.side_effect = av.error.InvalidDataError(-2, "error") - stream_worker(STREAM_SOURCE, {}, stream.outputs, threading.Event()) + segment_buffer = SegmentBuffer(stream.outputs) + stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event()) await hass.async_block_till_done() av_open.assert_called_once()