From 9a32e285749518bd2ec3dcb53c5942e75fe723dc Mon Sep 17 00:00:00 2001
From: uvjustin <46082645+uvjustin@users.noreply.github.com>
Date: Mon, 28 Sep 2020 04:38:14 +0800
Subject: [PATCH] Create master playlist for cast (#40483)

Co-authored-by: Jason Hunter <hunterjm@gmail.com>
---
 homeassistant/components/stream/fmp4utils.py | 111 ++++++++++++++++
 homeassistant/components/stream/hls.py       | 126 +++++++++++--------
 homeassistant/components/stream/recorder.py  |   2 +-
 3 files changed, 188 insertions(+), 51 deletions(-)

diff --git a/homeassistant/components/stream/fmp4utils.py b/homeassistant/components/stream/fmp4utils.py
index 00603807215..dc929e531c1 100644
--- a/homeassistant/components/stream/fmp4utils.py
+++ b/homeassistant/components/stream/fmp4utils.py
@@ -36,3 +36,114 @@ def get_m4s(segment: io.BytesIO, sequence: int) -> bytes:
     mfra_location = next(find_box(segment, b"mfra"))
     segment.seek(moof_location)
     return segment.read(mfra_location - moof_location)
+
+
+def get_codec_string(segment: io.BytesIO) -> str:
+    """Get RFC 6381 codec string."""
+    codecs = []
+
+    # Find moov
+    moov_location = next(find_box(segment, b"moov"))
+
+    # Find tracks
+    for trak_location in find_box(segment, b"trak", moov_location):
+        # Drill down to media info
+        mdia_location = next(find_box(segment, b"mdia", trak_location))
+        minf_location = next(find_box(segment, b"minf", mdia_location))
+        stbl_location = next(find_box(segment, b"stbl", minf_location))
+        stsd_location = next(find_box(segment, b"stsd", stbl_location))
+
+        # Get stsd box
+        segment.seek(stsd_location)
+        stsd_length = int.from_bytes(segment.read(4), byteorder="big")
+        segment.seek(stsd_location)
+        stsd_box = segment.read(stsd_length)
+
+        # Base Codec
+        codec = stsd_box[20:24].decode("utf-8")
+
+        # Handle H264
+        if (
+            codec in ("avc1", "avc2", "avc3", "avc4")
+            and stsd_length > 110
+            and stsd_box[106:110] == b"avcC"
+        ):
+            profile = stsd_box[111:112].hex()
+            compatibility = stsd_box[112:113].hex()
+            level = stsd_box[113:114].hex()
+            codec += "." + profile + compatibility + level
+
+        # Handle H265
+        elif (
+            codec in ("hev1", "hvc1")
+            and stsd_length > 110
+            and stsd_box[106:110] == b"hvcC"
+        ):
+            tmp_byte = int.from_bytes(stsd_box[111:112], byteorder="big")
+
+            # Profile Space
+            codec += "."
+            profile_space_map = {0: "", 1: "A", 2: "B", 3: "C"}
+            profile_space = tmp_byte >> 6
+            codec += profile_space_map[profile_space]
+            general_profile_idc = tmp_byte & 31
+            codec += str(general_profile_idc)
+
+            # Compatibility
+            codec += "."
+            general_profile_compatibility = int.from_bytes(
+                stsd_box[112:116], byteorder="big"
+            )
+            reverse = 0
+            for i in range(0, 32):
+                reverse |= general_profile_compatibility & 1
+                if i == 31:
+                    break
+                reverse <<= 1
+                general_profile_compatibility >>= 1
+            codec += hex(reverse)[2:]
+
+            # Tier Flag
+            if (tmp_byte & 32) >> 5 == 0:
+                codec += ".L"
+            else:
+                codec += ".H"
+            codec += str(int.from_bytes(stsd_box[122:123], byteorder="big"))
+
+            # Constraint String
+            has_byte = False
+            constraint_string = ""
+            for i in range(121, 115, -1):
+                gci = int.from_bytes(stsd_box[i : i + 1], byteorder="big")
+                if gci or has_byte:
+                    constraint_string = "." + hex(gci)[2:] + constraint_string
+                    has_byte = True
+            codec += constraint_string
+
+        # Handle Audio
+        elif codec == "mp4a":
+            oti = None
+            dsi = None
+
+            # Parse ES Descriptors
+            oti_loc = stsd_box.find(b"\x04\x80\x80\x80")
+            if oti_loc > 0:
+                oti = stsd_box[oti_loc + 5 : oti_loc + 6].hex()
+                codec += f".{oti}"
+
+            dsi_loc = stsd_box.find(b"\x05\x80\x80\x80")
+            if dsi_loc > 0:
+                dsi_length = int.from_bytes(
+                    stsd_box[dsi_loc + 4 : dsi_loc + 5], byteorder="big"
+                )
+                dsi_data = stsd_box[dsi_loc + 5 : dsi_loc + 5 + dsi_length]
+                dsi0 = int.from_bytes(dsi_data[0:1], byteorder="big")
+                dsi = (dsi0 & 248) >> 3
+                if dsi == 31 and len(dsi_data) >= 2:
+                    dsi1 = int.from_bytes(dsi_data[1:2], byteorder="big")
+                    dsi = 32 + ((dsi0 & 7) << 3) + ((dsi1 & 224) >> 5)
+                codec += f".{dsi}"
+
+        codecs.append(codec)
+
+    return ",".join(codecs)
diff --git a/homeassistant/components/stream/hls.py b/homeassistant/components/stream/hls.py
index 816d1231c4c..09729f79ada 100644
--- a/homeassistant/components/stream/hls.py
+++ b/homeassistant/components/stream/hls.py
@@ -1,4 +1,5 @@
 """Provide functionality to stream HLS."""
+import io
 from typing import Callable
 
 from aiohttp import web
@@ -7,7 +8,7 @@ from homeassistant.core import callback
 
 from .const import FORMAT_CONTENT_TYPE
 from .core import PROVIDERS, StreamOutput, StreamView
-from .fmp4utils import get_init, get_m4s
+from .fmp4utils import get_codec_string, get_init, get_m4s
 
 
 @callback
@@ -16,7 +17,43 @@ def async_setup_hls(hass):
     hass.http.register_view(HlsPlaylistView())
     hass.http.register_view(HlsSegmentView())
     hass.http.register_view(HlsInitView())
-    return "/api/hls/{}/playlist.m3u8"
+    hass.http.register_view(HlsMasterPlaylistView())
+    return "/api/hls/{}/master_playlist.m3u8"
+
+
+class HlsMasterPlaylistView(StreamView):
+    """Stream view used only for Chromecast compatibility."""
+
+    url = r"/api/hls/{token:[a-f0-9]+}/master_playlist.m3u8"
+    name = "api:stream:hls:master_playlist"
+    cors_allowed = True
+
+    @staticmethod
+    def render(track):
+        """Render M3U8 file."""
+        # Need to calculate max bandwidth as input_container.bit_rate doesn't seem to work
+        # Calculate file size / duration and use a multiplier to account for variation
+        segment = track.get_segment(track.segments[-1])
+        bandwidth = round(
+            segment.segment.seek(0, io.SEEK_END) * 8 / segment.duration * 3
+        )
+        codecs = get_codec_string(segment.segment)
+        lines = [
+            "#EXTM3U",
+            f'#EXT-X-STREAM-INF:BANDWIDTH={bandwidth},CODECS="{codecs}"',
+            "playlist.m3u8",
+        ]
+        return "\n".join(lines) + "\n"
+
+    async def handle(self, request, stream, sequence):
+        """Return m3u8 playlist."""
+        track = stream.add_provider("hls")
+        stream.start()
+        # Wait for a segment to be ready
+        if not track.segments:
+            await track.recv()
+        headers = {"Content-Type": FORMAT_CONTENT_TYPE["hls"]}
+        return web.Response(body=self.render(track).encode("utf-8"), headers=headers)
 
 
 class HlsPlaylistView(StreamView):
@@ -26,18 +63,50 @@ class HlsPlaylistView(StreamView):
     name = "api:stream:hls:playlist"
     cors_allowed = True
 
+    @staticmethod
+    def render_preamble(track):
+        """Render preamble."""
+        return [
+            "#EXT-X-VERSION:7",
+            f"#EXT-X-TARGETDURATION:{track.target_duration}",
+            '#EXT-X-MAP:URI="init.mp4"',
+        ]
+
+    @staticmethod
+    def render_playlist(track):
+        """Render playlist."""
+        segments = track.segments
+
+        if not segments:
+            return []
+
+        playlist = ["#EXT-X-MEDIA-SEQUENCE:{}".format(segments[0])]
+
+        for sequence in segments:
+            segment = track.get_segment(sequence)
+            playlist.extend(
+                [
+                    "#EXTINF:{:.04f},".format(float(segment.duration)),
+                    f"./segment/{segment.sequence}.m4s",
+                ]
+            )
+
+        return playlist
+
+    def render(self, track):
+        """Render M3U8 file."""
+        lines = ["#EXTM3U"] + self.render_preamble(track) + self.render_playlist(track)
+        return "\n".join(lines) + "\n"
+
     async def handle(self, request, stream, sequence):
         """Return m3u8 playlist."""
-        renderer = M3U8Renderer(stream)
         track = stream.add_provider("hls")
         stream.start()
         # Wait for a segment to be ready
         if not track.segments:
             await track.recv()
         headers = {"Content-Type": FORMAT_CONTENT_TYPE["hls"]}
-        return web.Response(
-            body=renderer.render(track).encode("utf-8"), headers=headers
-        )
+        return web.Response(body=self.render(track).encode("utf-8"), headers=headers)
 
 
 class HlsInitView(StreamView):
@@ -77,49 +146,6 @@ class HlsSegmentView(StreamView):
         )
 
 
-class M3U8Renderer:
-    """M3U8 Render Helper."""
-
-    def __init__(self, stream):
-        """Initialize renderer."""
-        self.stream = stream
-
-    @staticmethod
-    def render_preamble(track):
-        """Render preamble."""
-        return [
-            "#EXT-X-VERSION:7",
-            f"#EXT-X-TARGETDURATION:{track.target_duration}",
-            '#EXT-X-MAP:URI="init.mp4"',
-        ]
-
-    @staticmethod
-    def render_playlist(track):
-        """Render playlist."""
-        segments = track.segments
-
-        if not segments:
-            return []
-
-        playlist = ["#EXT-X-MEDIA-SEQUENCE:{}".format(segments[0])]
-
-        for sequence in segments:
-            segment = track.get_segment(sequence)
-            playlist.extend(
-                [
-                    "#EXTINF:{:.04f},".format(float(segment.duration)),
-                    f"./segment/{segment.sequence}.m4s",
-                ]
-            )
-
-        return playlist
-
-    def render(self, track):
-        """Render M3U8 file."""
-        lines = ["#EXTM3U"] + self.render_preamble(track) + self.render_playlist(track)
-        return "\n".join(lines) + "\n"
-
-
 @PROVIDERS.register("hls")
 class HlsStreamOutput(StreamOutput):
     """Represents HLS Output formats."""
@@ -137,7 +163,7 @@ class HlsStreamOutput(StreamOutput):
     @property
     def audio_codecs(self) -> str:
         """Return desired audio codecs."""
-        return {"aac", "ac3", "mp3"}
+        return {"aac", "mp3"}
 
     @property
     def video_codecs(self) -> tuple:
diff --git a/homeassistant/components/stream/recorder.py b/homeassistant/components/stream/recorder.py
index 82b146cc51f..d0b8789f602 100644
--- a/homeassistant/components/stream/recorder.py
+++ b/homeassistant/components/stream/recorder.py
@@ -78,7 +78,7 @@ class RecorderOutput(StreamOutput):
     @property
     def audio_codecs(self) -> str:
         """Return desired audio codec."""
-        return {"aac", "ac3", "mp3"}
+        return {"aac", "mp3"}
 
     @property
     def video_codecs(self) -> tuple: