Rollback stream StreamOutput refactoring in PR#46610 (#46684)
* Rollback PR#46610 * Update stream tests post-merge
This commit is contained in:
parent
788134cbc4
commit
4aa4f7e285
10 changed files with 295 additions and 240 deletions
|
@ -24,11 +24,7 @@ from homeassistant.components.media_player.const import (
|
|||
SERVICE_PLAY_MEDIA,
|
||||
)
|
||||
from homeassistant.components.stream import Stream, create_stream
|
||||
from homeassistant.components.stream.const import (
|
||||
FORMAT_CONTENT_TYPE,
|
||||
HLS_OUTPUT,
|
||||
OUTPUT_FORMATS,
|
||||
)
|
||||
from homeassistant.components.stream.const import FORMAT_CONTENT_TYPE, OUTPUT_FORMATS
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID,
|
||||
CONF_FILENAME,
|
||||
|
@ -259,7 +255,7 @@ async def async_setup(hass, config):
|
|||
if not stream:
|
||||
continue
|
||||
stream.keepalive = True
|
||||
stream.hls_output()
|
||||
stream.add_provider("hls")
|
||||
stream.start()
|
||||
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, preload_stream)
|
||||
|
@ -707,8 +703,6 @@ async def async_handle_play_stream_service(camera, service_call):
|
|||
|
||||
|
||||
async def _async_stream_endpoint_url(hass, camera, fmt):
|
||||
if fmt != HLS_OUTPUT:
|
||||
raise ValueError("Only format {HLS_OUTPUT} is supported")
|
||||
stream = await camera.create_stream()
|
||||
if not stream:
|
||||
raise HomeAssistantError(
|
||||
|
@ -719,9 +713,9 @@ async def _async_stream_endpoint_url(hass, camera, fmt):
|
|||
camera_prefs = hass.data[DATA_CAMERA_PREFS].get(camera.entity_id)
|
||||
stream.keepalive = camera_prefs.preload_stream
|
||||
|
||||
stream.hls_output()
|
||||
stream.add_provider(fmt)
|
||||
stream.start()
|
||||
return stream.endpoint_url()
|
||||
return stream.endpoint_url(fmt)
|
||||
|
||||
|
||||
async def async_handle_record_service(camera, call):
|
||||
|
|
|
@ -7,25 +7,25 @@ a new Stream object. Stream manages:
|
|||
- Home Assistant URLs for viewing a stream
|
||||
- Access tokens for URLs for viewing a stream
|
||||
|
||||
A Stream consists of a background worker and multiple output streams (e.g. hls
|
||||
and recorder). The worker has a callback to retrieve the current active output
|
||||
streams where it writes the decoded output packets. The HLS stream has an
|
||||
inactivity idle timeout that expires the access token. When all output streams
|
||||
are inactive, the background worker is shut down. Alternatively, a Stream
|
||||
can be configured with keepalive to always keep workers active.
|
||||
A Stream consists of a background worker, and one or more output formats each
|
||||
with their own idle timeout managed by the stream component. When an output
|
||||
format is no longer in use, the stream component will expire it. When there
|
||||
are no active output formats, the background worker is shut down and access
|
||||
tokens are expired. Alternatively, a Stream can be configured with keepalive
|
||||
to always keep workers active.
|
||||
"""
|
||||
import logging
|
||||
import secrets
|
||||
import threading
|
||||
import time
|
||||
from typing import List
|
||||
from types import MappingProxyType
|
||||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
from .const import (
|
||||
ATTR_HLS_ENDPOINT,
|
||||
ATTR_ENDPOINTS,
|
||||
ATTR_STREAMS,
|
||||
DOMAIN,
|
||||
MAX_SEGMENTS,
|
||||
|
@ -33,8 +33,8 @@ from .const import (
|
|||
STREAM_RESTART_INCREMENT,
|
||||
STREAM_RESTART_RESET_TIME,
|
||||
)
|
||||
from .core import IdleTimer, StreamOutput
|
||||
from .hls import HlsStreamOutput, async_setup_hls
|
||||
from .core import PROVIDERS, IdleTimer
|
||||
from .hls import async_setup_hls
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -75,10 +75,12 @@ async def async_setup(hass, config):
|
|||
from .recorder import async_setup_recorder
|
||||
|
||||
hass.data[DOMAIN] = {}
|
||||
hass.data[DOMAIN][ATTR_ENDPOINTS] = {}
|
||||
hass.data[DOMAIN][ATTR_STREAMS] = []
|
||||
|
||||
# Setup HLS
|
||||
hass.data[DOMAIN][ATTR_HLS_ENDPOINT] = async_setup_hls(hass)
|
||||
hls_endpoint = async_setup_hls(hass)
|
||||
hass.data[DOMAIN][ATTR_ENDPOINTS]["hls"] = hls_endpoint
|
||||
|
||||
# Setup Recorder
|
||||
async_setup_recorder(hass)
|
||||
|
@ -87,6 +89,7 @@ async def async_setup(hass, config):
|
|||
def shutdown(event):
|
||||
"""Stop all stream workers."""
|
||||
for stream in hass.data[DOMAIN][ATTR_STREAMS]:
|
||||
stream.keepalive = False
|
||||
stream.stop()
|
||||
_LOGGER.info("Stopped stream workers")
|
||||
|
||||
|
@ -107,54 +110,58 @@ class Stream:
|
|||
self.access_token = None
|
||||
self._thread = None
|
||||
self._thread_quit = threading.Event()
|
||||
self._hls = None
|
||||
self._hls_timer = None
|
||||
self._recorder = None
|
||||
self._outputs = {}
|
||||
self._fast_restart_once = False
|
||||
|
||||
if self.options is None:
|
||||
self.options = {}
|
||||
|
||||
def endpoint_url(self) -> str:
|
||||
"""Start the stream and returns a url for the hls endpoint."""
|
||||
if not self._hls:
|
||||
raise ValueError("Stream is not configured for hls")
|
||||
def endpoint_url(self, fmt):
|
||||
"""Start the stream and returns a url for the output format."""
|
||||
if fmt not in self._outputs:
|
||||
raise ValueError(f"Stream is not configured for format '{fmt}'")
|
||||
if not self.access_token:
|
||||
self.access_token = secrets.token_hex()
|
||||
return self.hass.data[DOMAIN][ATTR_HLS_ENDPOINT].format(self.access_token)
|
||||
return self.hass.data[DOMAIN][ATTR_ENDPOINTS][fmt].format(self.access_token)
|
||||
|
||||
def outputs(self) -> List[StreamOutput]:
|
||||
"""Return the active stream outputs."""
|
||||
return [output for output in [self._hls, self._recorder] if output]
|
||||
def outputs(self):
|
||||
"""Return a copy of the stream outputs."""
|
||||
# A copy is returned so the caller can iterate through the outputs
|
||||
# without concern about self._outputs being modified from another thread.
|
||||
return MappingProxyType(self._outputs.copy())
|
||||
|
||||
def hls_output(self) -> StreamOutput:
|
||||
"""Return the hls output stream, creating if not already active."""
|
||||
if not self._hls:
|
||||
self._hls = HlsStreamOutput(self.hass)
|
||||
self._hls_timer = IdleTimer(self.hass, OUTPUT_IDLE_TIMEOUT, self._hls_idle)
|
||||
self._hls_timer.start()
|
||||
self._hls_timer.awake()
|
||||
return self._hls
|
||||
def add_provider(self, fmt, timeout=OUTPUT_IDLE_TIMEOUT):
|
||||
"""Add provider output stream."""
|
||||
if not self._outputs.get(fmt):
|
||||
|
||||
@callback
|
||||
def _hls_idle(self):
|
||||
"""Reset access token and cleanup stream due to inactivity."""
|
||||
self.access_token = None
|
||||
if not self.keepalive:
|
||||
if self._hls:
|
||||
self._hls.cleanup()
|
||||
self._hls = None
|
||||
self._hls_timer = None
|
||||
self._check_idle()
|
||||
@callback
|
||||
def idle_callback():
|
||||
if not self.keepalive and fmt in self._outputs:
|
||||
self.remove_provider(self._outputs[fmt])
|
||||
self.check_idle()
|
||||
|
||||
def _check_idle(self):
|
||||
"""Check if all outputs are idle and shut down worker."""
|
||||
if self.keepalive or self.outputs():
|
||||
return
|
||||
self.stop()
|
||||
provider = PROVIDERS[fmt](
|
||||
self.hass, IdleTimer(self.hass, timeout, idle_callback)
|
||||
)
|
||||
self._outputs[fmt] = provider
|
||||
return self._outputs[fmt]
|
||||
|
||||
def remove_provider(self, provider):
|
||||
"""Remove provider output stream."""
|
||||
if provider.name in self._outputs:
|
||||
self._outputs[provider.name].cleanup()
|
||||
del self._outputs[provider.name]
|
||||
|
||||
if not self._outputs:
|
||||
self.stop()
|
||||
|
||||
def check_idle(self):
|
||||
"""Reset access token if all providers are idle."""
|
||||
if all([p.idle for p in self._outputs.values()]):
|
||||
self.access_token = None
|
||||
|
||||
def start(self):
|
||||
"""Start stream decode worker."""
|
||||
"""Start a stream."""
|
||||
if self._thread is None or not self._thread.is_alive():
|
||||
if self._thread is not None:
|
||||
# The thread must have crashed/exited. Join to clean up the
|
||||
|
@ -210,21 +217,21 @@ class Stream:
|
|||
|
||||
def _worker_finished(self):
|
||||
"""Schedule cleanup of all outputs."""
|
||||
self.hass.loop.call_soon_threadsafe(self.stop)
|
||||
|
||||
@callback
|
||||
def remove_outputs():
|
||||
for provider in self.outputs().values():
|
||||
self.remove_provider(provider)
|
||||
|
||||
self.hass.loop.call_soon_threadsafe(remove_outputs)
|
||||
|
||||
def stop(self):
|
||||
"""Remove outputs and access token."""
|
||||
self._outputs = {}
|
||||
self.access_token = None
|
||||
if self._hls_timer:
|
||||
self._hls_timer.clear()
|
||||
self._hls_timer = None
|
||||
if self._hls:
|
||||
self._hls.cleanup()
|
||||
self._hls = None
|
||||
if self._recorder:
|
||||
self._recorder.save()
|
||||
self._recorder = None
|
||||
self._stop()
|
||||
|
||||
if not self.keepalive:
|
||||
self._stop()
|
||||
|
||||
def _stop(self):
|
||||
"""Stop worker thread."""
|
||||
|
@ -237,35 +244,25 @@ class Stream:
|
|||
async def async_record(self, video_path, duration=30, lookback=5):
|
||||
"""Make a .mp4 recording from a provided stream."""
|
||||
|
||||
# Keep import here so that we can import stream integration without installing reqs
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from .recorder import RecorderOutput
|
||||
|
||||
# Check for file access
|
||||
if not self.hass.config.is_allowed_path(video_path):
|
||||
raise HomeAssistantError(f"Can't write {video_path}, no access to path!")
|
||||
|
||||
# Add recorder
|
||||
if self._recorder:
|
||||
recorder = self.outputs().get("recorder")
|
||||
if recorder:
|
||||
raise HomeAssistantError(
|
||||
f"Stream already recording to {self._recorder.video_path}!"
|
||||
f"Stream already recording to {recorder.video_path}!"
|
||||
)
|
||||
self._recorder = RecorderOutput(self.hass)
|
||||
self._recorder.video_path = video_path
|
||||
recorder = self.add_provider("recorder", timeout=duration)
|
||||
recorder.video_path = video_path
|
||||
|
||||
self.start()
|
||||
|
||||
# Take advantage of lookback
|
||||
if lookback > 0 and self._hls:
|
||||
num_segments = min(int(lookback // self._hls.target_duration), MAX_SEGMENTS)
|
||||
hls = self.outputs().get("hls")
|
||||
if lookback > 0 and hls:
|
||||
num_segments = min(int(lookback // hls.target_duration), MAX_SEGMENTS)
|
||||
# Wait for latest segment, then add the lookback
|
||||
await self._hls.recv()
|
||||
self._recorder.prepend(list(self._hls.get_segment())[-num_segments:])
|
||||
|
||||
@callback
|
||||
def save_recording():
|
||||
if self._recorder:
|
||||
self._recorder.save()
|
||||
self._recorder = None
|
||||
self._check_idle()
|
||||
|
||||
IdleTimer(self.hass, duration, save_recording).start()
|
||||
await hls.recv()
|
||||
recorder.prepend(list(hls.get_segment())[-num_segments:])
|
||||
|
|
|
@ -1,14 +1,10 @@
|
|||
"""Constants for Stream component."""
|
||||
DOMAIN = "stream"
|
||||
|
||||
ATTR_HLS_ENDPOINT = "hls_endpoint"
|
||||
ATTR_ENDPOINTS = "endpoints"
|
||||
ATTR_STREAMS = "streams"
|
||||
|
||||
HLS_OUTPUT = "hls"
|
||||
OUTPUT_FORMATS = [HLS_OUTPUT]
|
||||
OUTPUT_CONTAINER_FORMAT = "mp4"
|
||||
OUTPUT_VIDEO_CODECS = {"hevc", "h264"}
|
||||
OUTPUT_AUDIO_CODECS = {"aac", "mp3"}
|
||||
OUTPUT_FORMATS = ["hls"]
|
||||
|
||||
FORMAT_CONTENT_TYPE = {"hls": "application/vnd.apple.mpegurl"}
|
||||
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
"""Provides core stream functionality."""
|
||||
import abc
|
||||
import asyncio
|
||||
from collections import deque
|
||||
import io
|
||||
from typing import Callable
|
||||
from typing import Any, Callable, List
|
||||
|
||||
from aiohttp import web
|
||||
import attr
|
||||
|
@ -9,8 +10,11 @@ import attr
|
|||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
from homeassistant.util.decorator import Registry
|
||||
|
||||
from .const import ATTR_STREAMS, DOMAIN
|
||||
from .const import ATTR_STREAMS, DOMAIN, MAX_SEGMENTS
|
||||
|
||||
PROVIDERS = Registry()
|
||||
|
||||
|
||||
@attr.s
|
||||
|
@ -76,18 +80,86 @@ class IdleTimer:
|
|||
self._callback()
|
||||
|
||||
|
||||
class StreamOutput(abc.ABC):
|
||||
class StreamOutput:
|
||||
"""Represents a stream output."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant):
|
||||
def __init__(self, hass: HomeAssistant, idle_timer: IdleTimer) -> None:
|
||||
"""Initialize a stream output."""
|
||||
self._hass = hass
|
||||
self._idle_timer = idle_timer
|
||||
self._cursor = None
|
||||
self._event = asyncio.Event()
|
||||
self._segments = deque(maxlen=MAX_SEGMENTS)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return provider name."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def idle(self) -> bool:
|
||||
"""Return True if the output is idle."""
|
||||
return self._idle_timer.idle
|
||||
|
||||
@property
|
||||
def format(self) -> str:
|
||||
"""Return container format."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def audio_codecs(self) -> str:
|
||||
"""Return desired audio codecs."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def video_codecs(self) -> tuple:
|
||||
"""Return desired video codecs."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def container_options(self) -> Callable[[int], dict]:
|
||||
"""Return Callable which takes a sequence number and returns container options."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def segments(self) -> List[int]:
|
||||
"""Return current sequence from segments."""
|
||||
return [s.sequence for s in self._segments]
|
||||
|
||||
@property
|
||||
def target_duration(self) -> int:
|
||||
"""Return the max duration of any given segment in seconds."""
|
||||
segment_length = len(self._segments)
|
||||
if not segment_length:
|
||||
return 1
|
||||
durations = [s.duration for s in self._segments]
|
||||
return round(max(durations)) or 1
|
||||
|
||||
def get_segment(self, sequence: int = None) -> Any:
|
||||
"""Retrieve a specific segment, or the whole list."""
|
||||
self._idle_timer.awake()
|
||||
|
||||
if not sequence:
|
||||
return self._segments
|
||||
|
||||
for segment in self._segments:
|
||||
if segment.sequence == sequence:
|
||||
return segment
|
||||
return None
|
||||
|
||||
async def recv(self) -> Segment:
|
||||
"""Wait for and retrieve the latest segment."""
|
||||
last_segment = max(self.segments, default=0)
|
||||
if self._cursor is None or self._cursor <= last_segment:
|
||||
await self._event.wait()
|
||||
|
||||
if not self._segments:
|
||||
return None
|
||||
|
||||
segment = self.get_segment()[-1]
|
||||
self._cursor = segment.sequence
|
||||
return segment
|
||||
|
||||
def put(self, segment: Segment) -> None:
|
||||
"""Store output."""
|
||||
self._hass.loop.call_soon_threadsafe(self._async_put, segment)
|
||||
|
@ -95,6 +167,17 @@ class StreamOutput(abc.ABC):
|
|||
@callback
|
||||
def _async_put(self, segment: Segment) -> None:
|
||||
"""Store output from event loop."""
|
||||
# Start idle timeout when we start receiving data
|
||||
self._idle_timer.start()
|
||||
self._segments.append(segment)
|
||||
self._event.set()
|
||||
self._event.clear()
|
||||
|
||||
def cleanup(self):
|
||||
"""Handle cleanup."""
|
||||
self._event.set()
|
||||
self._idle_timer.clear()
|
||||
self._segments = deque(maxlen=MAX_SEGMENTS)
|
||||
|
||||
|
||||
class StreamView(HomeAssistantView):
|
||||
|
|
|
@ -1,15 +1,13 @@
|
|||
"""Provide functionality to stream HLS."""
|
||||
import asyncio
|
||||
from collections import deque
|
||||
import io
|
||||
from typing import Any, Callable, List
|
||||
from typing import Callable
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from homeassistant.core import callback
|
||||
|
||||
from .const import FORMAT_CONTENT_TYPE, MAX_SEGMENTS, NUM_PLAYLIST_SEGMENTS
|
||||
from .core import Segment, StreamOutput, StreamView
|
||||
from .const import FORMAT_CONTENT_TYPE, NUM_PLAYLIST_SEGMENTS
|
||||
from .core import PROVIDERS, StreamOutput, StreamView
|
||||
from .fmp4utils import get_codec_string, get_init, get_m4s
|
||||
|
||||
|
||||
|
@ -50,7 +48,8 @@ class HlsMasterPlaylistView(StreamView):
|
|||
|
||||
async def handle(self, request, stream, sequence):
|
||||
"""Return m3u8 playlist."""
|
||||
track = stream.hls_output()
|
||||
track = stream.add_provider("hls")
|
||||
stream.start()
|
||||
# Wait for a segment to be ready
|
||||
if not track.segments:
|
||||
if not await track.recv():
|
||||
|
@ -109,7 +108,8 @@ class HlsPlaylistView(StreamView):
|
|||
|
||||
async def handle(self, request, stream, sequence):
|
||||
"""Return m3u8 playlist."""
|
||||
track = stream.hls_output()
|
||||
track = stream.add_provider("hls")
|
||||
stream.start()
|
||||
# Wait for a segment to be ready
|
||||
if not track.segments:
|
||||
if not await track.recv():
|
||||
|
@ -127,7 +127,7 @@ class HlsInitView(StreamView):
|
|||
|
||||
async def handle(self, request, stream, sequence):
|
||||
"""Return init.mp4."""
|
||||
track = stream.hls_output()
|
||||
track = stream.add_provider("hls")
|
||||
segments = track.get_segment()
|
||||
if not segments:
|
||||
return web.HTTPNotFound()
|
||||
|
@ -144,7 +144,7 @@ class HlsSegmentView(StreamView):
|
|||
|
||||
async def handle(self, request, stream, sequence):
|
||||
"""Return fmp4 segment."""
|
||||
track = stream.hls_output()
|
||||
track = stream.add_provider("hls")
|
||||
segment = track.get_segment(int(sequence))
|
||||
if not segment:
|
||||
return web.HTTPNotFound()
|
||||
|
@ -155,15 +155,29 @@ class HlsSegmentView(StreamView):
|
|||
)
|
||||
|
||||
|
||||
@PROVIDERS.register("hls")
|
||||
class HlsStreamOutput(StreamOutput):
|
||||
"""Represents HLS Output formats."""
|
||||
|
||||
def __init__(self, hass) -> None:
|
||||
"""Initialize HlsStreamOutput."""
|
||||
super().__init__(hass)
|
||||
self._cursor = None
|
||||
self._event = asyncio.Event()
|
||||
self._segments = deque(maxlen=MAX_SEGMENTS)
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return provider name."""
|
||||
return "hls"
|
||||
|
||||
@property
|
||||
def format(self) -> str:
|
||||
"""Return container format."""
|
||||
return "mp4"
|
||||
|
||||
@property
|
||||
def audio_codecs(self) -> str:
|
||||
"""Return desired audio codecs."""
|
||||
return {"aac", "mp3"}
|
||||
|
||||
@property
|
||||
def video_codecs(self) -> tuple:
|
||||
"""Return desired video codecs."""
|
||||
return {"hevc", "h264"}
|
||||
|
||||
@property
|
||||
def container_options(self) -> Callable[[int], dict]:
|
||||
|
@ -174,51 +188,3 @@ class HlsStreamOutput(StreamOutput):
|
|||
"avoid_negative_ts": "make_non_negative",
|
||||
"fragment_index": str(sequence),
|
||||
}
|
||||
|
||||
@property
|
||||
def segments(self) -> List[int]:
|
||||
"""Return current sequence from segments."""
|
||||
return [s.sequence for s in self._segments]
|
||||
|
||||
@property
|
||||
def target_duration(self) -> int:
|
||||
"""Return the max duration of any given segment in seconds."""
|
||||
segment_length = len(self._segments)
|
||||
if not segment_length:
|
||||
return 1
|
||||
durations = [s.duration for s in self._segments]
|
||||
return round(max(durations)) or 1
|
||||
|
||||
def get_segment(self, sequence: int = None) -> Any:
|
||||
"""Retrieve a specific segment, or the whole list."""
|
||||
if not sequence:
|
||||
return self._segments
|
||||
|
||||
for segment in self._segments:
|
||||
if segment.sequence == sequence:
|
||||
return segment
|
||||
return None
|
||||
|
||||
async def recv(self) -> Segment:
|
||||
"""Wait for and retrieve the latest segment."""
|
||||
last_segment = max(self.segments, default=0)
|
||||
if self._cursor is None or self._cursor <= last_segment:
|
||||
await self._event.wait()
|
||||
|
||||
if not self._segments:
|
||||
return None
|
||||
|
||||
segment = self.get_segment()[-1]
|
||||
self._cursor = segment.sequence
|
||||
return segment
|
||||
|
||||
def _async_put(self, segment: Segment) -> None:
|
||||
"""Store output from event loop."""
|
||||
self._segments.append(segment)
|
||||
self._event.set()
|
||||
self._event.clear()
|
||||
|
||||
def cleanup(self):
|
||||
"""Handle cleanup."""
|
||||
self._event.set()
|
||||
self._segments = deque(maxlen=MAX_SEGMENTS)
|
||||
|
|
|
@ -6,10 +6,9 @@ from typing import List
|
|||
|
||||
import av
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
|
||||
from .const import OUTPUT_CONTAINER_FORMAT
|
||||
from .core import Segment, StreamOutput
|
||||
from .core import PROVIDERS, IdleTimer, Segment, StreamOutput
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -19,7 +18,7 @@ def async_setup_recorder(hass):
|
|||
"""Only here so Provider Registry works."""
|
||||
|
||||
|
||||
def recorder_save_worker(file_out: str, segments: List[Segment], container_format):
|
||||
def recorder_save_worker(file_out: str, segments: List[Segment], container_format: str):
|
||||
"""Handle saving stream."""
|
||||
if not os.path.exists(os.path.dirname(file_out)):
|
||||
os.makedirs(os.path.dirname(file_out), exist_ok=True)
|
||||
|
@ -76,31 +75,51 @@ def recorder_save_worker(file_out: str, segments: List[Segment], container_forma
|
|||
output.close()
|
||||
|
||||
|
||||
@PROVIDERS.register("recorder")
|
||||
class RecorderOutput(StreamOutput):
|
||||
"""Represents HLS Output formats."""
|
||||
|
||||
def __init__(self, hass) -> None:
|
||||
def __init__(self, hass: HomeAssistant, idle_timer: IdleTimer) -> None:
|
||||
"""Initialize recorder output."""
|
||||
super().__init__(hass)
|
||||
super().__init__(hass, idle_timer)
|
||||
self.video_path = None
|
||||
self._segments = []
|
||||
|
||||
def _async_put(self, segment: Segment) -> None:
|
||||
"""Store output."""
|
||||
self._segments.append(segment)
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return provider name."""
|
||||
return "recorder"
|
||||
|
||||
@property
|
||||
def format(self) -> str:
|
||||
"""Return container format."""
|
||||
return "mp4"
|
||||
|
||||
@property
|
||||
def audio_codecs(self) -> str:
|
||||
"""Return desired audio codec."""
|
||||
return {"aac", "mp3"}
|
||||
|
||||
@property
|
||||
def video_codecs(self) -> tuple:
|
||||
"""Return desired video codecs."""
|
||||
return {"hevc", "h264"}
|
||||
|
||||
def prepend(self, segments: List[Segment]) -> None:
|
||||
"""Prepend segments to existing list."""
|
||||
segments = [s for s in segments if s.sequence not in self._segments]
|
||||
own_segments = self.segments
|
||||
segments = [s for s in segments if s.sequence not in own_segments]
|
||||
self._segments = segments + self._segments
|
||||
|
||||
def save(self):
|
||||
def cleanup(self):
|
||||
"""Write recording and clean up."""
|
||||
_LOGGER.debug("Starting recorder worker thread")
|
||||
thread = threading.Thread(
|
||||
name="recorder_save_worker",
|
||||
target=recorder_save_worker,
|
||||
args=(self.video_path, self._segments, OUTPUT_CONTAINER_FORMAT),
|
||||
args=(self.video_path, self._segments, self.format),
|
||||
)
|
||||
thread.start()
|
||||
|
||||
super().cleanup()
|
||||
self._segments = []
|
||||
|
|
|
@ -9,9 +9,6 @@ from .const import (
|
|||
MAX_MISSING_DTS,
|
||||
MAX_TIMESTAMP_GAP,
|
||||
MIN_SEGMENT_DURATION,
|
||||
OUTPUT_AUDIO_CODECS,
|
||||
OUTPUT_CONTAINER_FORMAT,
|
||||
OUTPUT_VIDEO_CODECS,
|
||||
PACKETS_TO_WAIT_FOR_AUDIO,
|
||||
STREAM_TIMEOUT,
|
||||
)
|
||||
|
@ -32,7 +29,7 @@ def create_stream_buffer(stream_output, video_stream, audio_stream, sequence):
|
|||
output = av.open(
|
||||
segment,
|
||||
mode="w",
|
||||
format=OUTPUT_CONTAINER_FORMAT,
|
||||
format=stream_output.format,
|
||||
container_options={
|
||||
"video_track_timescale": str(int(1 / video_stream.time_base)),
|
||||
**container_options,
|
||||
|
@ -41,7 +38,7 @@ def create_stream_buffer(stream_output, video_stream, audio_stream, sequence):
|
|||
vstream = output.add_stream(template=video_stream)
|
||||
# Check if audio is requested
|
||||
astream = None
|
||||
if audio_stream and audio_stream.name in OUTPUT_AUDIO_CODECS:
|
||||
if audio_stream and audio_stream.name in stream_output.audio_codecs:
|
||||
astream = output.add_stream(template=audio_stream)
|
||||
return StreamBuffer(segment, output, vstream, astream)
|
||||
|
||||
|
@ -74,8 +71,8 @@ class SegmentBuffer:
|
|||
# Fetch the latest StreamOutputs, which may have changed since the
|
||||
# worker started.
|
||||
self._outputs = []
|
||||
for stream_output in self._outputs_callback():
|
||||
if self._video_stream.name not in OUTPUT_VIDEO_CODECS:
|
||||
for stream_output in self._outputs_callback().values():
|
||||
if self._video_stream.name not in stream_output.video_codecs:
|
||||
continue
|
||||
buffer = create_stream_buffer(
|
||||
stream_output, self._video_stream, self._audio_stream, self._sequence
|
||||
|
|
|
@ -45,7 +45,7 @@ def hls_stream(hass, hass_client):
|
|||
|
||||
async def create_client_for_stream(stream):
|
||||
http_client = await hass_client()
|
||||
parsed_url = urlparse(stream.endpoint_url())
|
||||
parsed_url = urlparse(stream.endpoint_url("hls"))
|
||||
return HlsClient(http_client, parsed_url)
|
||||
|
||||
return create_client_for_stream
|
||||
|
@ -91,7 +91,7 @@ async def test_hls_stream(hass, hls_stream, stream_worker_sync):
|
|||
stream = create_stream(hass, source)
|
||||
|
||||
# Request stream
|
||||
stream.hls_output()
|
||||
stream.add_provider("hls")
|
||||
stream.start()
|
||||
|
||||
hls_client = await hls_stream(stream)
|
||||
|
@ -132,9 +132,9 @@ async def test_stream_timeout(hass, hass_client, stream_worker_sync):
|
|||
stream = create_stream(hass, source)
|
||||
|
||||
# Request stream
|
||||
stream.hls_output()
|
||||
stream.add_provider("hls")
|
||||
stream.start()
|
||||
url = stream.endpoint_url()
|
||||
url = stream.endpoint_url("hls")
|
||||
|
||||
http_client = await hass_client()
|
||||
|
||||
|
@ -174,16 +174,8 @@ async def test_stream_timeout_after_stop(hass, hass_client, stream_worker_sync):
|
|||
stream = create_stream(hass, source)
|
||||
|
||||
# Request stream
|
||||
stream.hls_output()
|
||||
stream.add_provider("hls")
|
||||
stream.start()
|
||||
url = stream.endpoint_url()
|
||||
|
||||
http_client = await hass_client()
|
||||
|
||||
# Fetch playlist
|
||||
parsed_url = urlparse(url)
|
||||
playlist_response = await http_client.get(parsed_url.path)
|
||||
assert playlist_response.status == 200
|
||||
|
||||
stream_worker_sync.resume()
|
||||
stream.stop()
|
||||
|
@ -204,10 +196,12 @@ async def test_stream_ended(hass, stream_worker_sync):
|
|||
# Setup demo HLS track
|
||||
source = generate_h264_video()
|
||||
stream = create_stream(hass, source)
|
||||
track = stream.add_provider("hls")
|
||||
|
||||
# Request stream
|
||||
track = stream.hls_output()
|
||||
stream.add_provider("hls")
|
||||
stream.start()
|
||||
stream.endpoint_url("hls")
|
||||
|
||||
# Run it dead
|
||||
while True:
|
||||
|
@ -233,7 +227,7 @@ async def test_stream_keepalive(hass):
|
|||
# Setup demo HLS track
|
||||
source = "test_stream_keepalive_source"
|
||||
stream = create_stream(hass, source)
|
||||
track = stream.hls_output()
|
||||
track = stream.add_provider("hls")
|
||||
track.num_segments = 2
|
||||
stream.start()
|
||||
|
||||
|
@ -264,12 +258,12 @@ async def test_stream_keepalive(hass):
|
|||
stream.stop()
|
||||
|
||||
|
||||
async def test_hls_playlist_view_no_output(hass, hls_stream):
|
||||
async def test_hls_playlist_view_no_output(hass, hass_client, hls_stream):
|
||||
"""Test rendering the hls playlist with no output segments."""
|
||||
await async_setup_component(hass, "stream", {"stream": {}})
|
||||
|
||||
stream = create_stream(hass, STREAM_SOURCE)
|
||||
stream.hls_output()
|
||||
stream.add_provider("hls")
|
||||
|
||||
hls_client = await hls_stream(stream)
|
||||
|
||||
|
@ -284,7 +278,7 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync):
|
|||
|
||||
stream = create_stream(hass, STREAM_SOURCE)
|
||||
stream_worker_sync.pause()
|
||||
hls = stream.hls_output()
|
||||
hls = stream.add_provider("hls")
|
||||
|
||||
hls.put(Segment(1, SEQUENCE_BYTES, DURATION))
|
||||
await hass.async_block_till_done()
|
||||
|
@ -313,7 +307,7 @@ async def test_hls_max_segments(hass, hls_stream, stream_worker_sync):
|
|||
|
||||
stream = create_stream(hass, STREAM_SOURCE)
|
||||
stream_worker_sync.pause()
|
||||
hls = stream.hls_output()
|
||||
hls = stream.add_provider("hls")
|
||||
|
||||
hls_client = await hls_stream(stream)
|
||||
|
||||
|
@ -358,7 +352,7 @@ async def test_hls_playlist_view_discontinuity(hass, hls_stream, stream_worker_s
|
|||
|
||||
stream = create_stream(hass, STREAM_SOURCE)
|
||||
stream_worker_sync.pause()
|
||||
hls = stream.hls_output()
|
||||
hls = stream.add_provider("hls")
|
||||
|
||||
hls.put(Segment(1, SEQUENCE_BYTES, DURATION, stream_id=0))
|
||||
hls.put(Segment(2, SEQUENCE_BYTES, DURATION, stream_id=0))
|
||||
|
@ -388,7 +382,7 @@ async def test_hls_max_segments_discontinuity(hass, hls_stream, stream_worker_sy
|
|||
|
||||
stream = create_stream(hass, STREAM_SOURCE)
|
||||
stream_worker_sync.pause()
|
||||
hls = stream.hls_output()
|
||||
hls = stream.add_provider("hls")
|
||||
|
||||
hls_client = await hls_stream(stream)
|
||||
|
||||
|
|
|
@ -1,12 +1,10 @@
|
|||
"""The tests for hls streams."""
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
|
||||
import async_timeout
|
||||
import av
|
||||
import pytest
|
||||
|
||||
|
@ -34,30 +32,23 @@ class SaveRecordWorkerSync:
|
|||
def __init__(self):
|
||||
"""Initialize SaveRecordWorkerSync."""
|
||||
self.reset()
|
||||
self._segments = None
|
||||
|
||||
def recorder_save_worker(self, file_out, segments, container_format):
|
||||
def recorder_save_worker(self, *args, **kwargs):
|
||||
"""Mock method for patch."""
|
||||
logging.debug("recorder_save_worker thread started")
|
||||
self._segments = segments
|
||||
assert self._save_thread is None
|
||||
self._save_thread = threading.current_thread()
|
||||
self._save_event.set()
|
||||
|
||||
async def get_segments(self):
|
||||
"""Verify save worker thread was invoked and return saved segments."""
|
||||
with async_timeout.timeout(TEST_TIMEOUT):
|
||||
assert await self._save_event.wait()
|
||||
return self._segments
|
||||
|
||||
def join(self):
|
||||
"""Block until the record worker thread exist to ensure cleanup."""
|
||||
"""Verify save worker was invoked and block on shutdown."""
|
||||
assert self._save_event.wait(timeout=TEST_TIMEOUT)
|
||||
self._save_thread.join()
|
||||
|
||||
def reset(self):
|
||||
"""Reset callback state for reuse in tests."""
|
||||
self._save_thread = None
|
||||
self._save_event = asyncio.Event()
|
||||
self._save_event = threading.Event()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
@ -72,7 +63,7 @@ def record_worker_sync(hass):
|
|||
yield sync
|
||||
|
||||
|
||||
async def test_record_stream(hass, hass_client, record_worker_sync):
|
||||
async def test_record_stream(hass, hass_client, stream_worker_sync, record_worker_sync):
|
||||
"""
|
||||
Test record stream.
|
||||
|
||||
|
@ -82,14 +73,28 @@ async def test_record_stream(hass, hass_client, record_worker_sync):
|
|||
"""
|
||||
await async_setup_component(hass, "stream", {"stream": {}})
|
||||
|
||||
stream_worker_sync.pause()
|
||||
|
||||
# Setup demo track
|
||||
source = generate_h264_video()
|
||||
stream = create_stream(hass, source)
|
||||
with patch.object(hass.config, "is_allowed_path", return_value=True):
|
||||
await stream.async_record("/example/path")
|
||||
|
||||
segments = await record_worker_sync.get_segments()
|
||||
assert len(segments) > 1
|
||||
recorder = stream.add_provider("recorder")
|
||||
while True:
|
||||
segment = await recorder.recv()
|
||||
if not segment:
|
||||
break
|
||||
segments = segment.sequence
|
||||
if segments > 1:
|
||||
stream_worker_sync.resume()
|
||||
|
||||
stream.stop()
|
||||
assert segments > 1
|
||||
|
||||
# Verify that the save worker was invoked, then block until its
|
||||
# thread completes and is shutdown completely to avoid thread leaks.
|
||||
record_worker_sync.join()
|
||||
|
||||
|
||||
|
@ -102,24 +107,19 @@ async def test_record_lookback(
|
|||
source = generate_h264_video()
|
||||
stream = create_stream(hass, source)
|
||||
|
||||
# Don't let the stream finish (and clean itself up) until the test has had
|
||||
# a chance to perform lookback
|
||||
stream_worker_sync.pause()
|
||||
|
||||
# Start an HLS feed to enable lookback
|
||||
stream.hls_output()
|
||||
stream.add_provider("hls")
|
||||
stream.start()
|
||||
|
||||
with patch.object(hass.config, "is_allowed_path", return_value=True):
|
||||
await stream.async_record("/example/path", lookback=4)
|
||||
|
||||
# This test does not need recorder cleanup since it is not fully exercised
|
||||
stream_worker_sync.resume()
|
||||
|
||||
stream.stop()
|
||||
|
||||
|
||||
async def test_recorder_timeout(
|
||||
hass, hass_client, stream_worker_sync, record_worker_sync
|
||||
):
|
||||
async def test_recorder_timeout(hass, hass_client, stream_worker_sync):
|
||||
"""
|
||||
Test recorder timeout.
|
||||
|
||||
|
@ -137,8 +137,9 @@ async def test_recorder_timeout(
|
|||
stream = create_stream(hass, source)
|
||||
with patch.object(hass.config, "is_allowed_path", return_value=True):
|
||||
await stream.async_record("/example/path")
|
||||
recorder = stream.add_provider("recorder")
|
||||
|
||||
assert not mock_timeout.called
|
||||
await recorder.recv()
|
||||
|
||||
# Wait a minute
|
||||
future = dt_util.utcnow() + timedelta(minutes=1)
|
||||
|
@ -148,10 +149,6 @@ async def test_recorder_timeout(
|
|||
assert mock_timeout.called
|
||||
|
||||
stream_worker_sync.resume()
|
||||
# Verify worker is invoked, and do clean shutdown of worker thread
|
||||
await record_worker_sync.get_segments()
|
||||
record_worker_sync.join()
|
||||
|
||||
stream.stop()
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
|
@ -183,7 +180,9 @@ async def test_recorder_save(tmpdir):
|
|||
assert os.path.exists(filename)
|
||||
|
||||
|
||||
async def test_record_stream_audio(hass, hass_client, record_worker_sync):
|
||||
async def test_record_stream_audio(
|
||||
hass, hass_client, stream_worker_sync, record_worker_sync
|
||||
):
|
||||
"""
|
||||
Test treatment of different audio inputs.
|
||||
|
||||
|
@ -199,6 +198,7 @@ async def test_record_stream_audio(hass, hass_client, record_worker_sync):
|
|||
(None, 0), # no audio stream
|
||||
):
|
||||
record_worker_sync.reset()
|
||||
stream_worker_sync.pause()
|
||||
|
||||
# Setup demo track
|
||||
source = generate_h264_video(
|
||||
|
@ -207,14 +207,22 @@ async def test_record_stream_audio(hass, hass_client, record_worker_sync):
|
|||
stream = create_stream(hass, source)
|
||||
with patch.object(hass.config, "is_allowed_path", return_value=True):
|
||||
await stream.async_record("/example/path")
|
||||
recorder = stream.add_provider("recorder")
|
||||
|
||||
segments = await record_worker_sync.get_segments()
|
||||
last_segment = segments[-1]
|
||||
while True:
|
||||
segment = await recorder.recv()
|
||||
if not segment:
|
||||
break
|
||||
last_segment = segment
|
||||
stream_worker_sync.resume()
|
||||
|
||||
result = av.open(last_segment.segment, "r", format="mp4")
|
||||
|
||||
assert len(result.streams.audio) == expected_audio_streams
|
||||
result.close()
|
||||
|
||||
stream.stop()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Verify that the save worker was invoked, then block until its
|
||||
# thread completes and is shutdown completely to avoid thread leaks.
|
||||
record_worker_sync.join()
|
||||
|
|
|
@ -31,6 +31,7 @@ from homeassistant.components.stream.worker import SegmentBuffer, stream_worker
|
|||
|
||||
STREAM_SOURCE = "some-stream-source"
|
||||
# Formats here are arbitrary, not exercised by tests
|
||||
STREAM_OUTPUT_FORMAT = "hls"
|
||||
AUDIO_STREAM_FORMAT = "mp3"
|
||||
VIDEO_STREAM_FORMAT = "h264"
|
||||
VIDEO_FRAME_RATE = 12
|
||||
|
@ -187,7 +188,7 @@ class MockPyAv:
|
|||
async def async_decode_stream(hass, packets, py_av=None):
|
||||
"""Start a stream worker that decodes incoming stream packets into output segments."""
|
||||
stream = Stream(hass, STREAM_SOURCE)
|
||||
stream.hls_output()
|
||||
stream.add_provider(STREAM_OUTPUT_FORMAT)
|
||||
|
||||
if not py_av:
|
||||
py_av = MockPyAv()
|
||||
|
@ -207,7 +208,7 @@ async def async_decode_stream(hass, packets, py_av=None):
|
|||
async def test_stream_open_fails(hass):
|
||||
"""Test failure on stream open."""
|
||||
stream = Stream(hass, STREAM_SOURCE)
|
||||
stream.hls_output()
|
||||
stream.add_provider(STREAM_OUTPUT_FORMAT)
|
||||
with patch("av.open") as av_open:
|
||||
av_open.side_effect = av.error.InvalidDataError(-2, "error")
|
||||
segment_buffer = SegmentBuffer(stream.outputs)
|
||||
|
@ -484,7 +485,7 @@ async def test_stream_stopped_while_decoding(hass):
|
|||
worker_wake = threading.Event()
|
||||
|
||||
stream = Stream(hass, STREAM_SOURCE)
|
||||
stream.hls_output()
|
||||
stream.add_provider(STREAM_OUTPUT_FORMAT)
|
||||
|
||||
py_av = MockPyAv()
|
||||
py_av.container.packets = PacketSequence(TEST_SEQUENCE_LENGTH)
|
||||
|
@ -511,7 +512,7 @@ async def test_update_stream_source(hass):
|
|||
worker_wake = threading.Event()
|
||||
|
||||
stream = Stream(hass, STREAM_SOURCE)
|
||||
stream.hls_output()
|
||||
stream.add_provider(STREAM_OUTPUT_FORMAT)
|
||||
# Note that keepalive is not set here. The stream is "restarted" even though
|
||||
# it is not stopping due to failure.
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue