Rollback stream StreamOutput refactoring in PR#46610 (#46684)

* Rollback PR#46610

* Update stream tests post-merge
This commit is contained in:
Allen Porter 2021-02-20 06:49:39 -08:00 committed by GitHub
parent 788134cbc4
commit 4aa4f7e285
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 295 additions and 240 deletions

View file

@ -24,11 +24,7 @@ from homeassistant.components.media_player.const import (
SERVICE_PLAY_MEDIA,
)
from homeassistant.components.stream import Stream, create_stream
from homeassistant.components.stream.const import (
FORMAT_CONTENT_TYPE,
HLS_OUTPUT,
OUTPUT_FORMATS,
)
from homeassistant.components.stream.const import FORMAT_CONTENT_TYPE, OUTPUT_FORMATS
from homeassistant.const import (
ATTR_ENTITY_ID,
CONF_FILENAME,
@ -259,7 +255,7 @@ async def async_setup(hass, config):
if not stream:
continue
stream.keepalive = True
stream.hls_output()
stream.add_provider("hls")
stream.start()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, preload_stream)
@ -707,8 +703,6 @@ async def async_handle_play_stream_service(camera, service_call):
async def _async_stream_endpoint_url(hass, camera, fmt):
if fmt != HLS_OUTPUT:
raise ValueError("Only format {HLS_OUTPUT} is supported")
stream = await camera.create_stream()
if not stream:
raise HomeAssistantError(
@ -719,9 +713,9 @@ async def _async_stream_endpoint_url(hass, camera, fmt):
camera_prefs = hass.data[DATA_CAMERA_PREFS].get(camera.entity_id)
stream.keepalive = camera_prefs.preload_stream
stream.hls_output()
stream.add_provider(fmt)
stream.start()
return stream.endpoint_url()
return stream.endpoint_url(fmt)
async def async_handle_record_service(camera, call):

View file

@ -7,25 +7,25 @@ a new Stream object. Stream manages:
- Home Assistant URLs for viewing a stream
- Access tokens for URLs for viewing a stream
A Stream consists of a background worker and multiple output streams (e.g. hls
and recorder). The worker has a callback to retrieve the current active output
streams where it writes the decoded output packets. The HLS stream has an
inactivity idle timeout that expires the access token. When all output streams
are inactive, the background worker is shut down. Alternatively, a Stream
can be configured with keepalive to always keep workers active.
A Stream consists of a background worker, and one or more output formats each
with their own idle timeout managed by the stream component. When an output
format is no longer in use, the stream component will expire it. When there
are no active output formats, the background worker is shut down and access
tokens are expired. Alternatively, a Stream can be configured with keepalive
to always keep workers active.
"""
import logging
import secrets
import threading
import time
from typing import List
from types import MappingProxyType
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError
from .const import (
ATTR_HLS_ENDPOINT,
ATTR_ENDPOINTS,
ATTR_STREAMS,
DOMAIN,
MAX_SEGMENTS,
@ -33,8 +33,8 @@ from .const import (
STREAM_RESTART_INCREMENT,
STREAM_RESTART_RESET_TIME,
)
from .core import IdleTimer, StreamOutput
from .hls import HlsStreamOutput, async_setup_hls
from .core import PROVIDERS, IdleTimer
from .hls import async_setup_hls
_LOGGER = logging.getLogger(__name__)
@ -75,10 +75,12 @@ async def async_setup(hass, config):
from .recorder import async_setup_recorder
hass.data[DOMAIN] = {}
hass.data[DOMAIN][ATTR_ENDPOINTS] = {}
hass.data[DOMAIN][ATTR_STREAMS] = []
# Setup HLS
hass.data[DOMAIN][ATTR_HLS_ENDPOINT] = async_setup_hls(hass)
hls_endpoint = async_setup_hls(hass)
hass.data[DOMAIN][ATTR_ENDPOINTS]["hls"] = hls_endpoint
# Setup Recorder
async_setup_recorder(hass)
@ -87,6 +89,7 @@ async def async_setup(hass, config):
def shutdown(event):
"""Stop all stream workers."""
for stream in hass.data[DOMAIN][ATTR_STREAMS]:
stream.keepalive = False
stream.stop()
_LOGGER.info("Stopped stream workers")
@ -107,54 +110,58 @@ class Stream:
self.access_token = None
self._thread = None
self._thread_quit = threading.Event()
self._hls = None
self._hls_timer = None
self._recorder = None
self._outputs = {}
self._fast_restart_once = False
if self.options is None:
self.options = {}
def endpoint_url(self) -> str:
"""Start the stream and returns a url for the hls endpoint."""
if not self._hls:
raise ValueError("Stream is not configured for hls")
def endpoint_url(self, fmt):
"""Start the stream and returns a url for the output format."""
if fmt not in self._outputs:
raise ValueError(f"Stream is not configured for format '{fmt}'")
if not self.access_token:
self.access_token = secrets.token_hex()
return self.hass.data[DOMAIN][ATTR_HLS_ENDPOINT].format(self.access_token)
return self.hass.data[DOMAIN][ATTR_ENDPOINTS][fmt].format(self.access_token)
def outputs(self) -> List[StreamOutput]:
"""Return the active stream outputs."""
return [output for output in [self._hls, self._recorder] if output]
def outputs(self):
"""Return a copy of the stream outputs."""
# A copy is returned so the caller can iterate through the outputs
# without concern about self._outputs being modified from another thread.
return MappingProxyType(self._outputs.copy())
def hls_output(self) -> StreamOutput:
"""Return the hls output stream, creating if not already active."""
if not self._hls:
self._hls = HlsStreamOutput(self.hass)
self._hls_timer = IdleTimer(self.hass, OUTPUT_IDLE_TIMEOUT, self._hls_idle)
self._hls_timer.start()
self._hls_timer.awake()
return self._hls
def add_provider(self, fmt, timeout=OUTPUT_IDLE_TIMEOUT):
"""Add provider output stream."""
if not self._outputs.get(fmt):
@callback
def _hls_idle(self):
"""Reset access token and cleanup stream due to inactivity."""
self.access_token = None
if not self.keepalive:
if self._hls:
self._hls.cleanup()
self._hls = None
self._hls_timer = None
self._check_idle()
@callback
def idle_callback():
if not self.keepalive and fmt in self._outputs:
self.remove_provider(self._outputs[fmt])
self.check_idle()
def _check_idle(self):
"""Check if all outputs are idle and shut down worker."""
if self.keepalive or self.outputs():
return
self.stop()
provider = PROVIDERS[fmt](
self.hass, IdleTimer(self.hass, timeout, idle_callback)
)
self._outputs[fmt] = provider
return self._outputs[fmt]
def remove_provider(self, provider):
"""Remove provider output stream."""
if provider.name in self._outputs:
self._outputs[provider.name].cleanup()
del self._outputs[provider.name]
if not self._outputs:
self.stop()
def check_idle(self):
"""Reset access token if all providers are idle."""
if all([p.idle for p in self._outputs.values()]):
self.access_token = None
def start(self):
"""Start stream decode worker."""
"""Start a stream."""
if self._thread is None or not self._thread.is_alive():
if self._thread is not None:
# The thread must have crashed/exited. Join to clean up the
@ -210,21 +217,21 @@ class Stream:
def _worker_finished(self):
"""Schedule cleanup of all outputs."""
self.hass.loop.call_soon_threadsafe(self.stop)
@callback
def remove_outputs():
for provider in self.outputs().values():
self.remove_provider(provider)
self.hass.loop.call_soon_threadsafe(remove_outputs)
def stop(self):
"""Remove outputs and access token."""
self._outputs = {}
self.access_token = None
if self._hls_timer:
self._hls_timer.clear()
self._hls_timer = None
if self._hls:
self._hls.cleanup()
self._hls = None
if self._recorder:
self._recorder.save()
self._recorder = None
self._stop()
if not self.keepalive:
self._stop()
def _stop(self):
"""Stop worker thread."""
@ -237,35 +244,25 @@ class Stream:
async def async_record(self, video_path, duration=30, lookback=5):
"""Make a .mp4 recording from a provided stream."""
# Keep import here so that we can import stream integration without installing reqs
# pylint: disable=import-outside-toplevel
from .recorder import RecorderOutput
# Check for file access
if not self.hass.config.is_allowed_path(video_path):
raise HomeAssistantError(f"Can't write {video_path}, no access to path!")
# Add recorder
if self._recorder:
recorder = self.outputs().get("recorder")
if recorder:
raise HomeAssistantError(
f"Stream already recording to {self._recorder.video_path}!"
f"Stream already recording to {recorder.video_path}!"
)
self._recorder = RecorderOutput(self.hass)
self._recorder.video_path = video_path
recorder = self.add_provider("recorder", timeout=duration)
recorder.video_path = video_path
self.start()
# Take advantage of lookback
if lookback > 0 and self._hls:
num_segments = min(int(lookback // self._hls.target_duration), MAX_SEGMENTS)
hls = self.outputs().get("hls")
if lookback > 0 and hls:
num_segments = min(int(lookback // hls.target_duration), MAX_SEGMENTS)
# Wait for latest segment, then add the lookback
await self._hls.recv()
self._recorder.prepend(list(self._hls.get_segment())[-num_segments:])
@callback
def save_recording():
if self._recorder:
self._recorder.save()
self._recorder = None
self._check_idle()
IdleTimer(self.hass, duration, save_recording).start()
await hls.recv()
recorder.prepend(list(hls.get_segment())[-num_segments:])

View file

@ -1,14 +1,10 @@
"""Constants for Stream component."""
DOMAIN = "stream"
ATTR_HLS_ENDPOINT = "hls_endpoint"
ATTR_ENDPOINTS = "endpoints"
ATTR_STREAMS = "streams"
HLS_OUTPUT = "hls"
OUTPUT_FORMATS = [HLS_OUTPUT]
OUTPUT_CONTAINER_FORMAT = "mp4"
OUTPUT_VIDEO_CODECS = {"hevc", "h264"}
OUTPUT_AUDIO_CODECS = {"aac", "mp3"}
OUTPUT_FORMATS = ["hls"]
FORMAT_CONTENT_TYPE = {"hls": "application/vnd.apple.mpegurl"}

View file

@ -1,7 +1,8 @@
"""Provides core stream functionality."""
import abc
import asyncio
from collections import deque
import io
from typing import Callable
from typing import Any, Callable, List
from aiohttp import web
import attr
@ -9,8 +10,11 @@ import attr
from homeassistant.components.http import HomeAssistantView
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.event import async_call_later
from homeassistant.util.decorator import Registry
from .const import ATTR_STREAMS, DOMAIN
from .const import ATTR_STREAMS, DOMAIN, MAX_SEGMENTS
PROVIDERS = Registry()
@attr.s
@ -76,18 +80,86 @@ class IdleTimer:
self._callback()
class StreamOutput(abc.ABC):
class StreamOutput:
"""Represents a stream output."""
def __init__(self, hass: HomeAssistant):
def __init__(self, hass: HomeAssistant, idle_timer: IdleTimer) -> None:
"""Initialize a stream output."""
self._hass = hass
self._idle_timer = idle_timer
self._cursor = None
self._event = asyncio.Event()
self._segments = deque(maxlen=MAX_SEGMENTS)
@property
def name(self) -> str:
"""Return provider name."""
return None
@property
def idle(self) -> bool:
"""Return True if the output is idle."""
return self._idle_timer.idle
@property
def format(self) -> str:
"""Return container format."""
return None
@property
def audio_codecs(self) -> str:
"""Return desired audio codecs."""
return None
@property
def video_codecs(self) -> tuple:
"""Return desired video codecs."""
return None
@property
def container_options(self) -> Callable[[int], dict]:
"""Return Callable which takes a sequence number and returns container options."""
return None
@property
def segments(self) -> List[int]:
"""Return current sequence from segments."""
return [s.sequence for s in self._segments]
@property
def target_duration(self) -> int:
"""Return the max duration of any given segment in seconds."""
segment_length = len(self._segments)
if not segment_length:
return 1
durations = [s.duration for s in self._segments]
return round(max(durations)) or 1
def get_segment(self, sequence: int = None) -> Any:
"""Retrieve a specific segment, or the whole list."""
self._idle_timer.awake()
if not sequence:
return self._segments
for segment in self._segments:
if segment.sequence == sequence:
return segment
return None
async def recv(self) -> Segment:
"""Wait for and retrieve the latest segment."""
last_segment = max(self.segments, default=0)
if self._cursor is None or self._cursor <= last_segment:
await self._event.wait()
if not self._segments:
return None
segment = self.get_segment()[-1]
self._cursor = segment.sequence
return segment
def put(self, segment: Segment) -> None:
"""Store output."""
self._hass.loop.call_soon_threadsafe(self._async_put, segment)
@ -95,6 +167,17 @@ class StreamOutput(abc.ABC):
@callback
def _async_put(self, segment: Segment) -> None:
"""Store output from event loop."""
# Start idle timeout when we start receiving data
self._idle_timer.start()
self._segments.append(segment)
self._event.set()
self._event.clear()
def cleanup(self):
"""Handle cleanup."""
self._event.set()
self._idle_timer.clear()
self._segments = deque(maxlen=MAX_SEGMENTS)
class StreamView(HomeAssistantView):

View file

@ -1,15 +1,13 @@
"""Provide functionality to stream HLS."""
import asyncio
from collections import deque
import io
from typing import Any, Callable, List
from typing import Callable
from aiohttp import web
from homeassistant.core import callback
from .const import FORMAT_CONTENT_TYPE, MAX_SEGMENTS, NUM_PLAYLIST_SEGMENTS
from .core import Segment, StreamOutput, StreamView
from .const import FORMAT_CONTENT_TYPE, NUM_PLAYLIST_SEGMENTS
from .core import PROVIDERS, StreamOutput, StreamView
from .fmp4utils import get_codec_string, get_init, get_m4s
@ -50,7 +48,8 @@ class HlsMasterPlaylistView(StreamView):
async def handle(self, request, stream, sequence):
"""Return m3u8 playlist."""
track = stream.hls_output()
track = stream.add_provider("hls")
stream.start()
# Wait for a segment to be ready
if not track.segments:
if not await track.recv():
@ -109,7 +108,8 @@ class HlsPlaylistView(StreamView):
async def handle(self, request, stream, sequence):
"""Return m3u8 playlist."""
track = stream.hls_output()
track = stream.add_provider("hls")
stream.start()
# Wait for a segment to be ready
if not track.segments:
if not await track.recv():
@ -127,7 +127,7 @@ class HlsInitView(StreamView):
async def handle(self, request, stream, sequence):
"""Return init.mp4."""
track = stream.hls_output()
track = stream.add_provider("hls")
segments = track.get_segment()
if not segments:
return web.HTTPNotFound()
@ -144,7 +144,7 @@ class HlsSegmentView(StreamView):
async def handle(self, request, stream, sequence):
"""Return fmp4 segment."""
track = stream.hls_output()
track = stream.add_provider("hls")
segment = track.get_segment(int(sequence))
if not segment:
return web.HTTPNotFound()
@ -155,15 +155,29 @@ class HlsSegmentView(StreamView):
)
@PROVIDERS.register("hls")
class HlsStreamOutput(StreamOutput):
"""Represents HLS Output formats."""
def __init__(self, hass) -> None:
"""Initialize HlsStreamOutput."""
super().__init__(hass)
self._cursor = None
self._event = asyncio.Event()
self._segments = deque(maxlen=MAX_SEGMENTS)
@property
def name(self) -> str:
"""Return provider name."""
return "hls"
@property
def format(self) -> str:
"""Return container format."""
return "mp4"
@property
def audio_codecs(self) -> str:
"""Return desired audio codecs."""
return {"aac", "mp3"}
@property
def video_codecs(self) -> tuple:
"""Return desired video codecs."""
return {"hevc", "h264"}
@property
def container_options(self) -> Callable[[int], dict]:
@ -174,51 +188,3 @@ class HlsStreamOutput(StreamOutput):
"avoid_negative_ts": "make_non_negative",
"fragment_index": str(sequence),
}
@property
def segments(self) -> List[int]:
"""Return current sequence from segments."""
return [s.sequence for s in self._segments]
@property
def target_duration(self) -> int:
"""Return the max duration of any given segment in seconds."""
segment_length = len(self._segments)
if not segment_length:
return 1
durations = [s.duration for s in self._segments]
return round(max(durations)) or 1
def get_segment(self, sequence: int = None) -> Any:
"""Retrieve a specific segment, or the whole list."""
if not sequence:
return self._segments
for segment in self._segments:
if segment.sequence == sequence:
return segment
return None
async def recv(self) -> Segment:
"""Wait for and retrieve the latest segment."""
last_segment = max(self.segments, default=0)
if self._cursor is None or self._cursor <= last_segment:
await self._event.wait()
if not self._segments:
return None
segment = self.get_segment()[-1]
self._cursor = segment.sequence
return segment
def _async_put(self, segment: Segment) -> None:
"""Store output from event loop."""
self._segments.append(segment)
self._event.set()
self._event.clear()
def cleanup(self):
"""Handle cleanup."""
self._event.set()
self._segments = deque(maxlen=MAX_SEGMENTS)

View file

@ -6,10 +6,9 @@ from typing import List
import av
from homeassistant.core import callback
from homeassistant.core import HomeAssistant, callback
from .const import OUTPUT_CONTAINER_FORMAT
from .core import Segment, StreamOutput
from .core import PROVIDERS, IdleTimer, Segment, StreamOutput
_LOGGER = logging.getLogger(__name__)
@ -19,7 +18,7 @@ def async_setup_recorder(hass):
"""Only here so Provider Registry works."""
def recorder_save_worker(file_out: str, segments: List[Segment], container_format):
def recorder_save_worker(file_out: str, segments: List[Segment], container_format: str):
"""Handle saving stream."""
if not os.path.exists(os.path.dirname(file_out)):
os.makedirs(os.path.dirname(file_out), exist_ok=True)
@ -76,31 +75,51 @@ def recorder_save_worker(file_out: str, segments: List[Segment], container_forma
output.close()
@PROVIDERS.register("recorder")
class RecorderOutput(StreamOutput):
"""Represents HLS Output formats."""
def __init__(self, hass) -> None:
def __init__(self, hass: HomeAssistant, idle_timer: IdleTimer) -> None:
"""Initialize recorder output."""
super().__init__(hass)
super().__init__(hass, idle_timer)
self.video_path = None
self._segments = []
def _async_put(self, segment: Segment) -> None:
"""Store output."""
self._segments.append(segment)
@property
def name(self) -> str:
"""Return provider name."""
return "recorder"
@property
def format(self) -> str:
"""Return container format."""
return "mp4"
@property
def audio_codecs(self) -> str:
"""Return desired audio codec."""
return {"aac", "mp3"}
@property
def video_codecs(self) -> tuple:
"""Return desired video codecs."""
return {"hevc", "h264"}
def prepend(self, segments: List[Segment]) -> None:
"""Prepend segments to existing list."""
segments = [s for s in segments if s.sequence not in self._segments]
own_segments = self.segments
segments = [s for s in segments if s.sequence not in own_segments]
self._segments = segments + self._segments
def save(self):
def cleanup(self):
"""Write recording and clean up."""
_LOGGER.debug("Starting recorder worker thread")
thread = threading.Thread(
name="recorder_save_worker",
target=recorder_save_worker,
args=(self.video_path, self._segments, OUTPUT_CONTAINER_FORMAT),
args=(self.video_path, self._segments, self.format),
)
thread.start()
super().cleanup()
self._segments = []

View file

@ -9,9 +9,6 @@ from .const import (
MAX_MISSING_DTS,
MAX_TIMESTAMP_GAP,
MIN_SEGMENT_DURATION,
OUTPUT_AUDIO_CODECS,
OUTPUT_CONTAINER_FORMAT,
OUTPUT_VIDEO_CODECS,
PACKETS_TO_WAIT_FOR_AUDIO,
STREAM_TIMEOUT,
)
@ -32,7 +29,7 @@ def create_stream_buffer(stream_output, video_stream, audio_stream, sequence):
output = av.open(
segment,
mode="w",
format=OUTPUT_CONTAINER_FORMAT,
format=stream_output.format,
container_options={
"video_track_timescale": str(int(1 / video_stream.time_base)),
**container_options,
@ -41,7 +38,7 @@ def create_stream_buffer(stream_output, video_stream, audio_stream, sequence):
vstream = output.add_stream(template=video_stream)
# Check if audio is requested
astream = None
if audio_stream and audio_stream.name in OUTPUT_AUDIO_CODECS:
if audio_stream and audio_stream.name in stream_output.audio_codecs:
astream = output.add_stream(template=audio_stream)
return StreamBuffer(segment, output, vstream, astream)
@ -74,8 +71,8 @@ class SegmentBuffer:
# Fetch the latest StreamOutputs, which may have changed since the
# worker started.
self._outputs = []
for stream_output in self._outputs_callback():
if self._video_stream.name not in OUTPUT_VIDEO_CODECS:
for stream_output in self._outputs_callback().values():
if self._video_stream.name not in stream_output.video_codecs:
continue
buffer = create_stream_buffer(
stream_output, self._video_stream, self._audio_stream, self._sequence

View file

@ -45,7 +45,7 @@ def hls_stream(hass, hass_client):
async def create_client_for_stream(stream):
http_client = await hass_client()
parsed_url = urlparse(stream.endpoint_url())
parsed_url = urlparse(stream.endpoint_url("hls"))
return HlsClient(http_client, parsed_url)
return create_client_for_stream
@ -91,7 +91,7 @@ async def test_hls_stream(hass, hls_stream, stream_worker_sync):
stream = create_stream(hass, source)
# Request stream
stream.hls_output()
stream.add_provider("hls")
stream.start()
hls_client = await hls_stream(stream)
@ -132,9 +132,9 @@ async def test_stream_timeout(hass, hass_client, stream_worker_sync):
stream = create_stream(hass, source)
# Request stream
stream.hls_output()
stream.add_provider("hls")
stream.start()
url = stream.endpoint_url()
url = stream.endpoint_url("hls")
http_client = await hass_client()
@ -174,16 +174,8 @@ async def test_stream_timeout_after_stop(hass, hass_client, stream_worker_sync):
stream = create_stream(hass, source)
# Request stream
stream.hls_output()
stream.add_provider("hls")
stream.start()
url = stream.endpoint_url()
http_client = await hass_client()
# Fetch playlist
parsed_url = urlparse(url)
playlist_response = await http_client.get(parsed_url.path)
assert playlist_response.status == 200
stream_worker_sync.resume()
stream.stop()
@ -204,10 +196,12 @@ async def test_stream_ended(hass, stream_worker_sync):
# Setup demo HLS track
source = generate_h264_video()
stream = create_stream(hass, source)
track = stream.add_provider("hls")
# Request stream
track = stream.hls_output()
stream.add_provider("hls")
stream.start()
stream.endpoint_url("hls")
# Run it dead
while True:
@ -233,7 +227,7 @@ async def test_stream_keepalive(hass):
# Setup demo HLS track
source = "test_stream_keepalive_source"
stream = create_stream(hass, source)
track = stream.hls_output()
track = stream.add_provider("hls")
track.num_segments = 2
stream.start()
@ -264,12 +258,12 @@ async def test_stream_keepalive(hass):
stream.stop()
async def test_hls_playlist_view_no_output(hass, hls_stream):
async def test_hls_playlist_view_no_output(hass, hass_client, hls_stream):
"""Test rendering the hls playlist with no output segments."""
await async_setup_component(hass, "stream", {"stream": {}})
stream = create_stream(hass, STREAM_SOURCE)
stream.hls_output()
stream.add_provider("hls")
hls_client = await hls_stream(stream)
@ -284,7 +278,7 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync):
stream = create_stream(hass, STREAM_SOURCE)
stream_worker_sync.pause()
hls = stream.hls_output()
hls = stream.add_provider("hls")
hls.put(Segment(1, SEQUENCE_BYTES, DURATION))
await hass.async_block_till_done()
@ -313,7 +307,7 @@ async def test_hls_max_segments(hass, hls_stream, stream_worker_sync):
stream = create_stream(hass, STREAM_SOURCE)
stream_worker_sync.pause()
hls = stream.hls_output()
hls = stream.add_provider("hls")
hls_client = await hls_stream(stream)
@ -358,7 +352,7 @@ async def test_hls_playlist_view_discontinuity(hass, hls_stream, stream_worker_s
stream = create_stream(hass, STREAM_SOURCE)
stream_worker_sync.pause()
hls = stream.hls_output()
hls = stream.add_provider("hls")
hls.put(Segment(1, SEQUENCE_BYTES, DURATION, stream_id=0))
hls.put(Segment(2, SEQUENCE_BYTES, DURATION, stream_id=0))
@ -388,7 +382,7 @@ async def test_hls_max_segments_discontinuity(hass, hls_stream, stream_worker_sy
stream = create_stream(hass, STREAM_SOURCE)
stream_worker_sync.pause()
hls = stream.hls_output()
hls = stream.add_provider("hls")
hls_client = await hls_stream(stream)

View file

@ -1,12 +1,10 @@
"""The tests for hls streams."""
import asyncio
from datetime import timedelta
import logging
import os
import threading
from unittest.mock import patch
import async_timeout
import av
import pytest
@ -34,30 +32,23 @@ class SaveRecordWorkerSync:
def __init__(self):
"""Initialize SaveRecordWorkerSync."""
self.reset()
self._segments = None
def recorder_save_worker(self, file_out, segments, container_format):
def recorder_save_worker(self, *args, **kwargs):
"""Mock method for patch."""
logging.debug("recorder_save_worker thread started")
self._segments = segments
assert self._save_thread is None
self._save_thread = threading.current_thread()
self._save_event.set()
async def get_segments(self):
"""Verify save worker thread was invoked and return saved segments."""
with async_timeout.timeout(TEST_TIMEOUT):
assert await self._save_event.wait()
return self._segments
def join(self):
"""Block until the record worker thread exist to ensure cleanup."""
"""Verify save worker was invoked and block on shutdown."""
assert self._save_event.wait(timeout=TEST_TIMEOUT)
self._save_thread.join()
def reset(self):
"""Reset callback state for reuse in tests."""
self._save_thread = None
self._save_event = asyncio.Event()
self._save_event = threading.Event()
@pytest.fixture()
@ -72,7 +63,7 @@ def record_worker_sync(hass):
yield sync
async def test_record_stream(hass, hass_client, record_worker_sync):
async def test_record_stream(hass, hass_client, stream_worker_sync, record_worker_sync):
"""
Test record stream.
@ -82,14 +73,28 @@ async def test_record_stream(hass, hass_client, record_worker_sync):
"""
await async_setup_component(hass, "stream", {"stream": {}})
stream_worker_sync.pause()
# Setup demo track
source = generate_h264_video()
stream = create_stream(hass, source)
with patch.object(hass.config, "is_allowed_path", return_value=True):
await stream.async_record("/example/path")
segments = await record_worker_sync.get_segments()
assert len(segments) > 1
recorder = stream.add_provider("recorder")
while True:
segment = await recorder.recv()
if not segment:
break
segments = segment.sequence
if segments > 1:
stream_worker_sync.resume()
stream.stop()
assert segments > 1
# Verify that the save worker was invoked, then block until its
# thread completes and is shutdown completely to avoid thread leaks.
record_worker_sync.join()
@ -102,24 +107,19 @@ async def test_record_lookback(
source = generate_h264_video()
stream = create_stream(hass, source)
# Don't let the stream finish (and clean itself up) until the test has had
# a chance to perform lookback
stream_worker_sync.pause()
# Start an HLS feed to enable lookback
stream.hls_output()
stream.add_provider("hls")
stream.start()
with patch.object(hass.config, "is_allowed_path", return_value=True):
await stream.async_record("/example/path", lookback=4)
# This test does not need recorder cleanup since it is not fully exercised
stream_worker_sync.resume()
stream.stop()
async def test_recorder_timeout(
hass, hass_client, stream_worker_sync, record_worker_sync
):
async def test_recorder_timeout(hass, hass_client, stream_worker_sync):
"""
Test recorder timeout.
@ -137,8 +137,9 @@ async def test_recorder_timeout(
stream = create_stream(hass, source)
with patch.object(hass.config, "is_allowed_path", return_value=True):
await stream.async_record("/example/path")
recorder = stream.add_provider("recorder")
assert not mock_timeout.called
await recorder.recv()
# Wait a minute
future = dt_util.utcnow() + timedelta(minutes=1)
@ -148,10 +149,6 @@ async def test_recorder_timeout(
assert mock_timeout.called
stream_worker_sync.resume()
# Verify worker is invoked, and do clean shutdown of worker thread
await record_worker_sync.get_segments()
record_worker_sync.join()
stream.stop()
await hass.async_block_till_done()
await hass.async_block_till_done()
@ -183,7 +180,9 @@ async def test_recorder_save(tmpdir):
assert os.path.exists(filename)
async def test_record_stream_audio(hass, hass_client, record_worker_sync):
async def test_record_stream_audio(
hass, hass_client, stream_worker_sync, record_worker_sync
):
"""
Test treatment of different audio inputs.
@ -199,6 +198,7 @@ async def test_record_stream_audio(hass, hass_client, record_worker_sync):
(None, 0), # no audio stream
):
record_worker_sync.reset()
stream_worker_sync.pause()
# Setup demo track
source = generate_h264_video(
@ -207,14 +207,22 @@ async def test_record_stream_audio(hass, hass_client, record_worker_sync):
stream = create_stream(hass, source)
with patch.object(hass.config, "is_allowed_path", return_value=True):
await stream.async_record("/example/path")
recorder = stream.add_provider("recorder")
segments = await record_worker_sync.get_segments()
last_segment = segments[-1]
while True:
segment = await recorder.recv()
if not segment:
break
last_segment = segment
stream_worker_sync.resume()
result = av.open(last_segment.segment, "r", format="mp4")
assert len(result.streams.audio) == expected_audio_streams
result.close()
stream.stop()
await hass.async_block_till_done()
# Verify that the save worker was invoked, then block until its
# thread completes and is shutdown completely to avoid thread leaks.
record_worker_sync.join()

View file

@ -31,6 +31,7 @@ from homeassistant.components.stream.worker import SegmentBuffer, stream_worker
STREAM_SOURCE = "some-stream-source"
# Formats here are arbitrary, not exercised by tests
STREAM_OUTPUT_FORMAT = "hls"
AUDIO_STREAM_FORMAT = "mp3"
VIDEO_STREAM_FORMAT = "h264"
VIDEO_FRAME_RATE = 12
@ -187,7 +188,7 @@ class MockPyAv:
async def async_decode_stream(hass, packets, py_av=None):
"""Start a stream worker that decodes incoming stream packets into output segments."""
stream = Stream(hass, STREAM_SOURCE)
stream.hls_output()
stream.add_provider(STREAM_OUTPUT_FORMAT)
if not py_av:
py_av = MockPyAv()
@ -207,7 +208,7 @@ async def async_decode_stream(hass, packets, py_av=None):
async def test_stream_open_fails(hass):
"""Test failure on stream open."""
stream = Stream(hass, STREAM_SOURCE)
stream.hls_output()
stream.add_provider(STREAM_OUTPUT_FORMAT)
with patch("av.open") as av_open:
av_open.side_effect = av.error.InvalidDataError(-2, "error")
segment_buffer = SegmentBuffer(stream.outputs)
@ -484,7 +485,7 @@ async def test_stream_stopped_while_decoding(hass):
worker_wake = threading.Event()
stream = Stream(hass, STREAM_SOURCE)
stream.hls_output()
stream.add_provider(STREAM_OUTPUT_FORMAT)
py_av = MockPyAv()
py_av.container.packets = PacketSequence(TEST_SEQUENCE_LENGTH)
@ -511,7 +512,7 @@ async def test_update_stream_source(hass):
worker_wake = threading.Event()
stream = Stream(hass, STREAM_SOURCE)
stream.hls_output()
stream.add_provider(STREAM_OUTPUT_FORMAT)
# Note that keepalive is not set here. The stream is "restarted" even though
# it is not stopping due to failure.