From 4aa4f7e28525d79d8b12d79ff74ba6a22bd8d5a8 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sat, 20 Feb 2021 06:49:39 -0800 Subject: [PATCH] Rollback stream StreamOutput refactoring in PR#46610 (#46684) * Rollback PR#46610 * Update stream tests post-merge --- homeassistant/components/camera/__init__.py | 14 +- homeassistant/components/stream/__init__.py | 153 ++++++++++---------- homeassistant/components/stream/const.py | 8 +- homeassistant/components/stream/core.py | 93 +++++++++++- homeassistant/components/stream/hls.py | 92 ++++-------- homeassistant/components/stream/recorder.py | 43 ++++-- homeassistant/components/stream/worker.py | 11 +- tests/components/stream/test_hls.py | 36 ++--- tests/components/stream/test_recorder.py | 76 +++++----- tests/components/stream/test_worker.py | 9 +- 10 files changed, 295 insertions(+), 240 deletions(-) diff --git a/homeassistant/components/camera/__init__.py b/homeassistant/components/camera/__init__.py index 4c5d89030b4..99b5cebc2a3 100644 --- a/homeassistant/components/camera/__init__.py +++ b/homeassistant/components/camera/__init__.py @@ -24,11 +24,7 @@ from homeassistant.components.media_player.const import ( SERVICE_PLAY_MEDIA, ) from homeassistant.components.stream import Stream, create_stream -from homeassistant.components.stream.const import ( - FORMAT_CONTENT_TYPE, - HLS_OUTPUT, - OUTPUT_FORMATS, -) +from homeassistant.components.stream.const import FORMAT_CONTENT_TYPE, OUTPUT_FORMATS from homeassistant.const import ( ATTR_ENTITY_ID, CONF_FILENAME, @@ -259,7 +255,7 @@ async def async_setup(hass, config): if not stream: continue stream.keepalive = True - stream.hls_output() + stream.add_provider("hls") stream.start() hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, preload_stream) @@ -707,8 +703,6 @@ async def async_handle_play_stream_service(camera, service_call): async def _async_stream_endpoint_url(hass, camera, fmt): - if fmt != HLS_OUTPUT: - raise ValueError("Only format {HLS_OUTPUT} is supported") stream = await camera.create_stream() if not stream: raise HomeAssistantError( @@ -719,9 +713,9 @@ async def _async_stream_endpoint_url(hass, camera, fmt): camera_prefs = hass.data[DATA_CAMERA_PREFS].get(camera.entity_id) stream.keepalive = camera_prefs.preload_stream - stream.hls_output() + stream.add_provider(fmt) stream.start() - return stream.endpoint_url() + return stream.endpoint_url(fmt) async def async_handle_record_service(camera, call): diff --git a/homeassistant/components/stream/__init__.py b/homeassistant/components/stream/__init__.py index 6c3f0104ad0..0027152dbd6 100644 --- a/homeassistant/components/stream/__init__.py +++ b/homeassistant/components/stream/__init__.py @@ -7,25 +7,25 @@ a new Stream object. Stream manages: - Home Assistant URLs for viewing a stream - Access tokens for URLs for viewing a stream -A Stream consists of a background worker and multiple output streams (e.g. hls -and recorder). The worker has a callback to retrieve the current active output -streams where it writes the decoded output packets. The HLS stream has an -inactivity idle timeout that expires the access token. When all output streams -are inactive, the background worker is shut down. Alternatively, a Stream -can be configured with keepalive to always keep workers active. +A Stream consists of a background worker, and one or more output formats each +with their own idle timeout managed by the stream component. When an output +format is no longer in use, the stream component will expire it. When there +are no active output formats, the background worker is shut down and access +tokens are expired. Alternatively, a Stream can be configured with keepalive +to always keep workers active. """ import logging import secrets import threading import time -from typing import List +from types import MappingProxyType from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.core import callback from homeassistant.exceptions import HomeAssistantError from .const import ( - ATTR_HLS_ENDPOINT, + ATTR_ENDPOINTS, ATTR_STREAMS, DOMAIN, MAX_SEGMENTS, @@ -33,8 +33,8 @@ from .const import ( STREAM_RESTART_INCREMENT, STREAM_RESTART_RESET_TIME, ) -from .core import IdleTimer, StreamOutput -from .hls import HlsStreamOutput, async_setup_hls +from .core import PROVIDERS, IdleTimer +from .hls import async_setup_hls _LOGGER = logging.getLogger(__name__) @@ -75,10 +75,12 @@ async def async_setup(hass, config): from .recorder import async_setup_recorder hass.data[DOMAIN] = {} + hass.data[DOMAIN][ATTR_ENDPOINTS] = {} hass.data[DOMAIN][ATTR_STREAMS] = [] # Setup HLS - hass.data[DOMAIN][ATTR_HLS_ENDPOINT] = async_setup_hls(hass) + hls_endpoint = async_setup_hls(hass) + hass.data[DOMAIN][ATTR_ENDPOINTS]["hls"] = hls_endpoint # Setup Recorder async_setup_recorder(hass) @@ -87,6 +89,7 @@ async def async_setup(hass, config): def shutdown(event): """Stop all stream workers.""" for stream in hass.data[DOMAIN][ATTR_STREAMS]: + stream.keepalive = False stream.stop() _LOGGER.info("Stopped stream workers") @@ -107,54 +110,58 @@ class Stream: self.access_token = None self._thread = None self._thread_quit = threading.Event() - self._hls = None - self._hls_timer = None - self._recorder = None + self._outputs = {} self._fast_restart_once = False if self.options is None: self.options = {} - def endpoint_url(self) -> str: - """Start the stream and returns a url for the hls endpoint.""" - if not self._hls: - raise ValueError("Stream is not configured for hls") + def endpoint_url(self, fmt): + """Start the stream and returns a url for the output format.""" + if fmt not in self._outputs: + raise ValueError(f"Stream is not configured for format '{fmt}'") if not self.access_token: self.access_token = secrets.token_hex() - return self.hass.data[DOMAIN][ATTR_HLS_ENDPOINT].format(self.access_token) + return self.hass.data[DOMAIN][ATTR_ENDPOINTS][fmt].format(self.access_token) - def outputs(self) -> List[StreamOutput]: - """Return the active stream outputs.""" - return [output for output in [self._hls, self._recorder] if output] + def outputs(self): + """Return a copy of the stream outputs.""" + # A copy is returned so the caller can iterate through the outputs + # without concern about self._outputs being modified from another thread. + return MappingProxyType(self._outputs.copy()) - def hls_output(self) -> StreamOutput: - """Return the hls output stream, creating if not already active.""" - if not self._hls: - self._hls = HlsStreamOutput(self.hass) - self._hls_timer = IdleTimer(self.hass, OUTPUT_IDLE_TIMEOUT, self._hls_idle) - self._hls_timer.start() - self._hls_timer.awake() - return self._hls + def add_provider(self, fmt, timeout=OUTPUT_IDLE_TIMEOUT): + """Add provider output stream.""" + if not self._outputs.get(fmt): - @callback - def _hls_idle(self): - """Reset access token and cleanup stream due to inactivity.""" - self.access_token = None - if not self.keepalive: - if self._hls: - self._hls.cleanup() - self._hls = None - self._hls_timer = None - self._check_idle() + @callback + def idle_callback(): + if not self.keepalive and fmt in self._outputs: + self.remove_provider(self._outputs[fmt]) + self.check_idle() - def _check_idle(self): - """Check if all outputs are idle and shut down worker.""" - if self.keepalive or self.outputs(): - return - self.stop() + provider = PROVIDERS[fmt]( + self.hass, IdleTimer(self.hass, timeout, idle_callback) + ) + self._outputs[fmt] = provider + return self._outputs[fmt] + + def remove_provider(self, provider): + """Remove provider output stream.""" + if provider.name in self._outputs: + self._outputs[provider.name].cleanup() + del self._outputs[provider.name] + + if not self._outputs: + self.stop() + + def check_idle(self): + """Reset access token if all providers are idle.""" + if all([p.idle for p in self._outputs.values()]): + self.access_token = None def start(self): - """Start stream decode worker.""" + """Start a stream.""" if self._thread is None or not self._thread.is_alive(): if self._thread is not None: # The thread must have crashed/exited. Join to clean up the @@ -210,21 +217,21 @@ class Stream: def _worker_finished(self): """Schedule cleanup of all outputs.""" - self.hass.loop.call_soon_threadsafe(self.stop) + + @callback + def remove_outputs(): + for provider in self.outputs().values(): + self.remove_provider(provider) + + self.hass.loop.call_soon_threadsafe(remove_outputs) def stop(self): """Remove outputs and access token.""" + self._outputs = {} self.access_token = None - if self._hls_timer: - self._hls_timer.clear() - self._hls_timer = None - if self._hls: - self._hls.cleanup() - self._hls = None - if self._recorder: - self._recorder.save() - self._recorder = None - self._stop() + + if not self.keepalive: + self._stop() def _stop(self): """Stop worker thread.""" @@ -237,35 +244,25 @@ class Stream: async def async_record(self, video_path, duration=30, lookback=5): """Make a .mp4 recording from a provided stream.""" - # Keep import here so that we can import stream integration without installing reqs - # pylint: disable=import-outside-toplevel - from .recorder import RecorderOutput - # Check for file access if not self.hass.config.is_allowed_path(video_path): raise HomeAssistantError(f"Can't write {video_path}, no access to path!") # Add recorder - if self._recorder: + recorder = self.outputs().get("recorder") + if recorder: raise HomeAssistantError( - f"Stream already recording to {self._recorder.video_path}!" + f"Stream already recording to {recorder.video_path}!" ) - self._recorder = RecorderOutput(self.hass) - self._recorder.video_path = video_path + recorder = self.add_provider("recorder", timeout=duration) + recorder.video_path = video_path + self.start() # Take advantage of lookback - if lookback > 0 and self._hls: - num_segments = min(int(lookback // self._hls.target_duration), MAX_SEGMENTS) + hls = self.outputs().get("hls") + if lookback > 0 and hls: + num_segments = min(int(lookback // hls.target_duration), MAX_SEGMENTS) # Wait for latest segment, then add the lookback - await self._hls.recv() - self._recorder.prepend(list(self._hls.get_segment())[-num_segments:]) - - @callback - def save_recording(): - if self._recorder: - self._recorder.save() - self._recorder = None - self._check_idle() - - IdleTimer(self.hass, duration, save_recording).start() + await hls.recv() + recorder.prepend(list(hls.get_segment())[-num_segments:]) diff --git a/homeassistant/components/stream/const.py b/homeassistant/components/stream/const.py index 55f447a9a69..41df806d020 100644 --- a/homeassistant/components/stream/const.py +++ b/homeassistant/components/stream/const.py @@ -1,14 +1,10 @@ """Constants for Stream component.""" DOMAIN = "stream" -ATTR_HLS_ENDPOINT = "hls_endpoint" +ATTR_ENDPOINTS = "endpoints" ATTR_STREAMS = "streams" -HLS_OUTPUT = "hls" -OUTPUT_FORMATS = [HLS_OUTPUT] -OUTPUT_CONTAINER_FORMAT = "mp4" -OUTPUT_VIDEO_CODECS = {"hevc", "h264"} -OUTPUT_AUDIO_CODECS = {"aac", "mp3"} +OUTPUT_FORMATS = ["hls"] FORMAT_CONTENT_TYPE = {"hls": "application/vnd.apple.mpegurl"} diff --git a/homeassistant/components/stream/core.py b/homeassistant/components/stream/core.py index 7a46de547d7..eba6a069698 100644 --- a/homeassistant/components/stream/core.py +++ b/homeassistant/components/stream/core.py @@ -1,7 +1,8 @@ """Provides core stream functionality.""" -import abc +import asyncio +from collections import deque import io -from typing import Callable +from typing import Any, Callable, List from aiohttp import web import attr @@ -9,8 +10,11 @@ import attr from homeassistant.components.http import HomeAssistantView from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.event import async_call_later +from homeassistant.util.decorator import Registry -from .const import ATTR_STREAMS, DOMAIN +from .const import ATTR_STREAMS, DOMAIN, MAX_SEGMENTS + +PROVIDERS = Registry() @attr.s @@ -76,18 +80,86 @@ class IdleTimer: self._callback() -class StreamOutput(abc.ABC): +class StreamOutput: """Represents a stream output.""" - def __init__(self, hass: HomeAssistant): + def __init__(self, hass: HomeAssistant, idle_timer: IdleTimer) -> None: """Initialize a stream output.""" self._hass = hass + self._idle_timer = idle_timer + self._cursor = None + self._event = asyncio.Event() + self._segments = deque(maxlen=MAX_SEGMENTS) + + @property + def name(self) -> str: + """Return provider name.""" + return None + + @property + def idle(self) -> bool: + """Return True if the output is idle.""" + return self._idle_timer.idle + + @property + def format(self) -> str: + """Return container format.""" + return None + + @property + def audio_codecs(self) -> str: + """Return desired audio codecs.""" + return None + + @property + def video_codecs(self) -> tuple: + """Return desired video codecs.""" + return None @property def container_options(self) -> Callable[[int], dict]: """Return Callable which takes a sequence number and returns container options.""" return None + @property + def segments(self) -> List[int]: + """Return current sequence from segments.""" + return [s.sequence for s in self._segments] + + @property + def target_duration(self) -> int: + """Return the max duration of any given segment in seconds.""" + segment_length = len(self._segments) + if not segment_length: + return 1 + durations = [s.duration for s in self._segments] + return round(max(durations)) or 1 + + def get_segment(self, sequence: int = None) -> Any: + """Retrieve a specific segment, or the whole list.""" + self._idle_timer.awake() + + if not sequence: + return self._segments + + for segment in self._segments: + if segment.sequence == sequence: + return segment + return None + + async def recv(self) -> Segment: + """Wait for and retrieve the latest segment.""" + last_segment = max(self.segments, default=0) + if self._cursor is None or self._cursor <= last_segment: + await self._event.wait() + + if not self._segments: + return None + + segment = self.get_segment()[-1] + self._cursor = segment.sequence + return segment + def put(self, segment: Segment) -> None: """Store output.""" self._hass.loop.call_soon_threadsafe(self._async_put, segment) @@ -95,6 +167,17 @@ class StreamOutput(abc.ABC): @callback def _async_put(self, segment: Segment) -> None: """Store output from event loop.""" + # Start idle timeout when we start receiving data + self._idle_timer.start() + self._segments.append(segment) + self._event.set() + self._event.clear() + + def cleanup(self): + """Handle cleanup.""" + self._event.set() + self._idle_timer.clear() + self._segments = deque(maxlen=MAX_SEGMENTS) class StreamView(HomeAssistantView): diff --git a/homeassistant/components/stream/hls.py b/homeassistant/components/stream/hls.py index 85102d208e7..28d6a300ae7 100644 --- a/homeassistant/components/stream/hls.py +++ b/homeassistant/components/stream/hls.py @@ -1,15 +1,13 @@ """Provide functionality to stream HLS.""" -import asyncio -from collections import deque import io -from typing import Any, Callable, List +from typing import Callable from aiohttp import web from homeassistant.core import callback -from .const import FORMAT_CONTENT_TYPE, MAX_SEGMENTS, NUM_PLAYLIST_SEGMENTS -from .core import Segment, StreamOutput, StreamView +from .const import FORMAT_CONTENT_TYPE, NUM_PLAYLIST_SEGMENTS +from .core import PROVIDERS, StreamOutput, StreamView from .fmp4utils import get_codec_string, get_init, get_m4s @@ -50,7 +48,8 @@ class HlsMasterPlaylistView(StreamView): async def handle(self, request, stream, sequence): """Return m3u8 playlist.""" - track = stream.hls_output() + track = stream.add_provider("hls") + stream.start() # Wait for a segment to be ready if not track.segments: if not await track.recv(): @@ -109,7 +108,8 @@ class HlsPlaylistView(StreamView): async def handle(self, request, stream, sequence): """Return m3u8 playlist.""" - track = stream.hls_output() + track = stream.add_provider("hls") + stream.start() # Wait for a segment to be ready if not track.segments: if not await track.recv(): @@ -127,7 +127,7 @@ class HlsInitView(StreamView): async def handle(self, request, stream, sequence): """Return init.mp4.""" - track = stream.hls_output() + track = stream.add_provider("hls") segments = track.get_segment() if not segments: return web.HTTPNotFound() @@ -144,7 +144,7 @@ class HlsSegmentView(StreamView): async def handle(self, request, stream, sequence): """Return fmp4 segment.""" - track = stream.hls_output() + track = stream.add_provider("hls") segment = track.get_segment(int(sequence)) if not segment: return web.HTTPNotFound() @@ -155,15 +155,29 @@ class HlsSegmentView(StreamView): ) +@PROVIDERS.register("hls") class HlsStreamOutput(StreamOutput): """Represents HLS Output formats.""" - def __init__(self, hass) -> None: - """Initialize HlsStreamOutput.""" - super().__init__(hass) - self._cursor = None - self._event = asyncio.Event() - self._segments = deque(maxlen=MAX_SEGMENTS) + @property + def name(self) -> str: + """Return provider name.""" + return "hls" + + @property + def format(self) -> str: + """Return container format.""" + return "mp4" + + @property + def audio_codecs(self) -> str: + """Return desired audio codecs.""" + return {"aac", "mp3"} + + @property + def video_codecs(self) -> tuple: + """Return desired video codecs.""" + return {"hevc", "h264"} @property def container_options(self) -> Callable[[int], dict]: @@ -174,51 +188,3 @@ class HlsStreamOutput(StreamOutput): "avoid_negative_ts": "make_non_negative", "fragment_index": str(sequence), } - - @property - def segments(self) -> List[int]: - """Return current sequence from segments.""" - return [s.sequence for s in self._segments] - - @property - def target_duration(self) -> int: - """Return the max duration of any given segment in seconds.""" - segment_length = len(self._segments) - if not segment_length: - return 1 - durations = [s.duration for s in self._segments] - return round(max(durations)) or 1 - - def get_segment(self, sequence: int = None) -> Any: - """Retrieve a specific segment, or the whole list.""" - if not sequence: - return self._segments - - for segment in self._segments: - if segment.sequence == sequence: - return segment - return None - - async def recv(self) -> Segment: - """Wait for and retrieve the latest segment.""" - last_segment = max(self.segments, default=0) - if self._cursor is None or self._cursor <= last_segment: - await self._event.wait() - - if not self._segments: - return None - - segment = self.get_segment()[-1] - self._cursor = segment.sequence - return segment - - def _async_put(self, segment: Segment) -> None: - """Store output from event loop.""" - self._segments.append(segment) - self._event.set() - self._event.clear() - - def cleanup(self): - """Handle cleanup.""" - self._event.set() - self._segments = deque(maxlen=MAX_SEGMENTS) diff --git a/homeassistant/components/stream/recorder.py b/homeassistant/components/stream/recorder.py index 96531233771..0b77d0ba630 100644 --- a/homeassistant/components/stream/recorder.py +++ b/homeassistant/components/stream/recorder.py @@ -6,10 +6,9 @@ from typing import List import av -from homeassistant.core import callback +from homeassistant.core import HomeAssistant, callback -from .const import OUTPUT_CONTAINER_FORMAT -from .core import Segment, StreamOutput +from .core import PROVIDERS, IdleTimer, Segment, StreamOutput _LOGGER = logging.getLogger(__name__) @@ -19,7 +18,7 @@ def async_setup_recorder(hass): """Only here so Provider Registry works.""" -def recorder_save_worker(file_out: str, segments: List[Segment], container_format): +def recorder_save_worker(file_out: str, segments: List[Segment], container_format: str): """Handle saving stream.""" if not os.path.exists(os.path.dirname(file_out)): os.makedirs(os.path.dirname(file_out), exist_ok=True) @@ -76,31 +75,51 @@ def recorder_save_worker(file_out: str, segments: List[Segment], container_forma output.close() +@PROVIDERS.register("recorder") class RecorderOutput(StreamOutput): """Represents HLS Output formats.""" - def __init__(self, hass) -> None: + def __init__(self, hass: HomeAssistant, idle_timer: IdleTimer) -> None: """Initialize recorder output.""" - super().__init__(hass) + super().__init__(hass, idle_timer) self.video_path = None self._segments = [] - def _async_put(self, segment: Segment) -> None: - """Store output.""" - self._segments.append(segment) + @property + def name(self) -> str: + """Return provider name.""" + return "recorder" + + @property + def format(self) -> str: + """Return container format.""" + return "mp4" + + @property + def audio_codecs(self) -> str: + """Return desired audio codec.""" + return {"aac", "mp3"} + + @property + def video_codecs(self) -> tuple: + """Return desired video codecs.""" + return {"hevc", "h264"} def prepend(self, segments: List[Segment]) -> None: """Prepend segments to existing list.""" - segments = [s for s in segments if s.sequence not in self._segments] + own_segments = self.segments + segments = [s for s in segments if s.sequence not in own_segments] self._segments = segments + self._segments - def save(self): + def cleanup(self): """Write recording and clean up.""" _LOGGER.debug("Starting recorder worker thread") thread = threading.Thread( name="recorder_save_worker", target=recorder_save_worker, - args=(self.video_path, self._segments, OUTPUT_CONTAINER_FORMAT), + args=(self.video_path, self._segments, self.format), ) thread.start() + + super().cleanup() self._segments = [] diff --git a/homeassistant/components/stream/worker.py b/homeassistant/components/stream/worker.py index 2592a74584e..61d4f5db17a 100644 --- a/homeassistant/components/stream/worker.py +++ b/homeassistant/components/stream/worker.py @@ -9,9 +9,6 @@ from .const import ( MAX_MISSING_DTS, MAX_TIMESTAMP_GAP, MIN_SEGMENT_DURATION, - OUTPUT_AUDIO_CODECS, - OUTPUT_CONTAINER_FORMAT, - OUTPUT_VIDEO_CODECS, PACKETS_TO_WAIT_FOR_AUDIO, STREAM_TIMEOUT, ) @@ -32,7 +29,7 @@ def create_stream_buffer(stream_output, video_stream, audio_stream, sequence): output = av.open( segment, mode="w", - format=OUTPUT_CONTAINER_FORMAT, + format=stream_output.format, container_options={ "video_track_timescale": str(int(1 / video_stream.time_base)), **container_options, @@ -41,7 +38,7 @@ def create_stream_buffer(stream_output, video_stream, audio_stream, sequence): vstream = output.add_stream(template=video_stream) # Check if audio is requested astream = None - if audio_stream and audio_stream.name in OUTPUT_AUDIO_CODECS: + if audio_stream and audio_stream.name in stream_output.audio_codecs: astream = output.add_stream(template=audio_stream) return StreamBuffer(segment, output, vstream, astream) @@ -74,8 +71,8 @@ class SegmentBuffer: # Fetch the latest StreamOutputs, which may have changed since the # worker started. self._outputs = [] - for stream_output in self._outputs_callback(): - if self._video_stream.name not in OUTPUT_VIDEO_CODECS: + for stream_output in self._outputs_callback().values(): + if self._video_stream.name not in stream_output.video_codecs: continue buffer = create_stream_buffer( stream_output, self._video_stream, self._audio_stream, self._sequence diff --git a/tests/components/stream/test_hls.py b/tests/components/stream/test_hls.py index ffe32d13c61..c11576d2570 100644 --- a/tests/components/stream/test_hls.py +++ b/tests/components/stream/test_hls.py @@ -45,7 +45,7 @@ def hls_stream(hass, hass_client): async def create_client_for_stream(stream): http_client = await hass_client() - parsed_url = urlparse(stream.endpoint_url()) + parsed_url = urlparse(stream.endpoint_url("hls")) return HlsClient(http_client, parsed_url) return create_client_for_stream @@ -91,7 +91,7 @@ async def test_hls_stream(hass, hls_stream, stream_worker_sync): stream = create_stream(hass, source) # Request stream - stream.hls_output() + stream.add_provider("hls") stream.start() hls_client = await hls_stream(stream) @@ -132,9 +132,9 @@ async def test_stream_timeout(hass, hass_client, stream_worker_sync): stream = create_stream(hass, source) # Request stream - stream.hls_output() + stream.add_provider("hls") stream.start() - url = stream.endpoint_url() + url = stream.endpoint_url("hls") http_client = await hass_client() @@ -174,16 +174,8 @@ async def test_stream_timeout_after_stop(hass, hass_client, stream_worker_sync): stream = create_stream(hass, source) # Request stream - stream.hls_output() + stream.add_provider("hls") stream.start() - url = stream.endpoint_url() - - http_client = await hass_client() - - # Fetch playlist - parsed_url = urlparse(url) - playlist_response = await http_client.get(parsed_url.path) - assert playlist_response.status == 200 stream_worker_sync.resume() stream.stop() @@ -204,10 +196,12 @@ async def test_stream_ended(hass, stream_worker_sync): # Setup demo HLS track source = generate_h264_video() stream = create_stream(hass, source) + track = stream.add_provider("hls") # Request stream - track = stream.hls_output() + stream.add_provider("hls") stream.start() + stream.endpoint_url("hls") # Run it dead while True: @@ -233,7 +227,7 @@ async def test_stream_keepalive(hass): # Setup demo HLS track source = "test_stream_keepalive_source" stream = create_stream(hass, source) - track = stream.hls_output() + track = stream.add_provider("hls") track.num_segments = 2 stream.start() @@ -264,12 +258,12 @@ async def test_stream_keepalive(hass): stream.stop() -async def test_hls_playlist_view_no_output(hass, hls_stream): +async def test_hls_playlist_view_no_output(hass, hass_client, hls_stream): """Test rendering the hls playlist with no output segments.""" await async_setup_component(hass, "stream", {"stream": {}}) stream = create_stream(hass, STREAM_SOURCE) - stream.hls_output() + stream.add_provider("hls") hls_client = await hls_stream(stream) @@ -284,7 +278,7 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync): stream = create_stream(hass, STREAM_SOURCE) stream_worker_sync.pause() - hls = stream.hls_output() + hls = stream.add_provider("hls") hls.put(Segment(1, SEQUENCE_BYTES, DURATION)) await hass.async_block_till_done() @@ -313,7 +307,7 @@ async def test_hls_max_segments(hass, hls_stream, stream_worker_sync): stream = create_stream(hass, STREAM_SOURCE) stream_worker_sync.pause() - hls = stream.hls_output() + hls = stream.add_provider("hls") hls_client = await hls_stream(stream) @@ -358,7 +352,7 @@ async def test_hls_playlist_view_discontinuity(hass, hls_stream, stream_worker_s stream = create_stream(hass, STREAM_SOURCE) stream_worker_sync.pause() - hls = stream.hls_output() + hls = stream.add_provider("hls") hls.put(Segment(1, SEQUENCE_BYTES, DURATION, stream_id=0)) hls.put(Segment(2, SEQUENCE_BYTES, DURATION, stream_id=0)) @@ -388,7 +382,7 @@ async def test_hls_max_segments_discontinuity(hass, hls_stream, stream_worker_sy stream = create_stream(hass, STREAM_SOURCE) stream_worker_sync.pause() - hls = stream.hls_output() + hls = stream.add_provider("hls") hls_client = await hls_stream(stream) diff --git a/tests/components/stream/test_recorder.py b/tests/components/stream/test_recorder.py index e8ff540ba41..9d418c360b1 100644 --- a/tests/components/stream/test_recorder.py +++ b/tests/components/stream/test_recorder.py @@ -1,12 +1,10 @@ """The tests for hls streams.""" -import asyncio from datetime import timedelta import logging import os import threading from unittest.mock import patch -import async_timeout import av import pytest @@ -34,30 +32,23 @@ class SaveRecordWorkerSync: def __init__(self): """Initialize SaveRecordWorkerSync.""" self.reset() - self._segments = None - def recorder_save_worker(self, file_out, segments, container_format): + def recorder_save_worker(self, *args, **kwargs): """Mock method for patch.""" logging.debug("recorder_save_worker thread started") - self._segments = segments assert self._save_thread is None self._save_thread = threading.current_thread() self._save_event.set() - async def get_segments(self): - """Verify save worker thread was invoked and return saved segments.""" - with async_timeout.timeout(TEST_TIMEOUT): - assert await self._save_event.wait() - return self._segments - def join(self): - """Block until the record worker thread exist to ensure cleanup.""" + """Verify save worker was invoked and block on shutdown.""" + assert self._save_event.wait(timeout=TEST_TIMEOUT) self._save_thread.join() def reset(self): """Reset callback state for reuse in tests.""" self._save_thread = None - self._save_event = asyncio.Event() + self._save_event = threading.Event() @pytest.fixture() @@ -72,7 +63,7 @@ def record_worker_sync(hass): yield sync -async def test_record_stream(hass, hass_client, record_worker_sync): +async def test_record_stream(hass, hass_client, stream_worker_sync, record_worker_sync): """ Test record stream. @@ -82,14 +73,28 @@ async def test_record_stream(hass, hass_client, record_worker_sync): """ await async_setup_component(hass, "stream", {"stream": {}}) + stream_worker_sync.pause() + # Setup demo track source = generate_h264_video() stream = create_stream(hass, source) with patch.object(hass.config, "is_allowed_path", return_value=True): await stream.async_record("/example/path") - segments = await record_worker_sync.get_segments() - assert len(segments) > 1 + recorder = stream.add_provider("recorder") + while True: + segment = await recorder.recv() + if not segment: + break + segments = segment.sequence + if segments > 1: + stream_worker_sync.resume() + + stream.stop() + assert segments > 1 + + # Verify that the save worker was invoked, then block until its + # thread completes and is shutdown completely to avoid thread leaks. record_worker_sync.join() @@ -102,24 +107,19 @@ async def test_record_lookback( source = generate_h264_video() stream = create_stream(hass, source) - # Don't let the stream finish (and clean itself up) until the test has had - # a chance to perform lookback - stream_worker_sync.pause() - # Start an HLS feed to enable lookback - stream.hls_output() + stream.add_provider("hls") + stream.start() with patch.object(hass.config, "is_allowed_path", return_value=True): await stream.async_record("/example/path", lookback=4) # This test does not need recorder cleanup since it is not fully exercised - stream_worker_sync.resume() + stream.stop() -async def test_recorder_timeout( - hass, hass_client, stream_worker_sync, record_worker_sync -): +async def test_recorder_timeout(hass, hass_client, stream_worker_sync): """ Test recorder timeout. @@ -137,8 +137,9 @@ async def test_recorder_timeout( stream = create_stream(hass, source) with patch.object(hass.config, "is_allowed_path", return_value=True): await stream.async_record("/example/path") + recorder = stream.add_provider("recorder") - assert not mock_timeout.called + await recorder.recv() # Wait a minute future = dt_util.utcnow() + timedelta(minutes=1) @@ -148,10 +149,6 @@ async def test_recorder_timeout( assert mock_timeout.called stream_worker_sync.resume() - # Verify worker is invoked, and do clean shutdown of worker thread - await record_worker_sync.get_segments() - record_worker_sync.join() - stream.stop() await hass.async_block_till_done() await hass.async_block_till_done() @@ -183,7 +180,9 @@ async def test_recorder_save(tmpdir): assert os.path.exists(filename) -async def test_record_stream_audio(hass, hass_client, record_worker_sync): +async def test_record_stream_audio( + hass, hass_client, stream_worker_sync, record_worker_sync +): """ Test treatment of different audio inputs. @@ -199,6 +198,7 @@ async def test_record_stream_audio(hass, hass_client, record_worker_sync): (None, 0), # no audio stream ): record_worker_sync.reset() + stream_worker_sync.pause() # Setup demo track source = generate_h264_video( @@ -207,14 +207,22 @@ async def test_record_stream_audio(hass, hass_client, record_worker_sync): stream = create_stream(hass, source) with patch.object(hass.config, "is_allowed_path", return_value=True): await stream.async_record("/example/path") + recorder = stream.add_provider("recorder") - segments = await record_worker_sync.get_segments() - last_segment = segments[-1] + while True: + segment = await recorder.recv() + if not segment: + break + last_segment = segment + stream_worker_sync.resume() result = av.open(last_segment.segment, "r", format="mp4") assert len(result.streams.audio) == expected_audio_streams result.close() - stream.stop() + await hass.async_block_till_done() + + # Verify that the save worker was invoked, then block until its + # thread completes and is shutdown completely to avoid thread leaks. record_worker_sync.join() diff --git a/tests/components/stream/test_worker.py b/tests/components/stream/test_worker.py index f7952b7db44..2c202a290ce 100644 --- a/tests/components/stream/test_worker.py +++ b/tests/components/stream/test_worker.py @@ -31,6 +31,7 @@ from homeassistant.components.stream.worker import SegmentBuffer, stream_worker STREAM_SOURCE = "some-stream-source" # Formats here are arbitrary, not exercised by tests +STREAM_OUTPUT_FORMAT = "hls" AUDIO_STREAM_FORMAT = "mp3" VIDEO_STREAM_FORMAT = "h264" VIDEO_FRAME_RATE = 12 @@ -187,7 +188,7 @@ class MockPyAv: async def async_decode_stream(hass, packets, py_av=None): """Start a stream worker that decodes incoming stream packets into output segments.""" stream = Stream(hass, STREAM_SOURCE) - stream.hls_output() + stream.add_provider(STREAM_OUTPUT_FORMAT) if not py_av: py_av = MockPyAv() @@ -207,7 +208,7 @@ async def async_decode_stream(hass, packets, py_av=None): async def test_stream_open_fails(hass): """Test failure on stream open.""" stream = Stream(hass, STREAM_SOURCE) - stream.hls_output() + stream.add_provider(STREAM_OUTPUT_FORMAT) with patch("av.open") as av_open: av_open.side_effect = av.error.InvalidDataError(-2, "error") segment_buffer = SegmentBuffer(stream.outputs) @@ -484,7 +485,7 @@ async def test_stream_stopped_while_decoding(hass): worker_wake = threading.Event() stream = Stream(hass, STREAM_SOURCE) - stream.hls_output() + stream.add_provider(STREAM_OUTPUT_FORMAT) py_av = MockPyAv() py_av.container.packets = PacketSequence(TEST_SEQUENCE_LENGTH) @@ -511,7 +512,7 @@ async def test_update_stream_source(hass): worker_wake = threading.Event() stream = Stream(hass, STREAM_SOURCE) - stream.hls_output() + stream.add_provider(STREAM_OUTPUT_FORMAT) # Note that keepalive is not set here. The stream is "restarted" even though # it is not stopping due to failure.