From 47286fbe2ac77be57170d47b226a3569979a2c3c Mon Sep 17 00:00:00 2001
From: Alexei Chetroi <lexoid@gmail.com>
Date: Mon, 28 Sep 2020 20:55:08 -0400
Subject: [PATCH] Refactor permit services to allow joining using install codes
 (#40652)

* Update zha.permit schema to support install code

* Move install code to core helpers

* QR code converter for enbrighten

* Fix schemas

* Update test for permit service

* Refactor zha.permit to accept install codes

* Test zha.permit from QR code

* Fix regex for Embrighten QR code

* Add regex for Aqara QR codes

* Add Consciot regex for QR code

* Reuse test params for WS tests

* ZHA WS permit command with install code

* Tests for zha.permit WS service

* Refactor zha.permit and zha.remove service to use ATTR_IEEE for the address

* Make pylint happy

* Deprecate only ieee_address param for now
---
 homeassistant/components/zha/api.py          |  85 ++++--
 homeassistant/components/zha/core/helpers.py |  66 ++++-
 homeassistant/components/zha/services.yaml   |   9 +
 tests/components/zha/test_api.py             | 263 ++++++++++++++++++-
 4 files changed, 400 insertions(+), 23 deletions(-)

diff --git a/homeassistant/components/zha/api.py b/homeassistant/components/zha/api.py
index 73613c04371..a5b409c7116 100644
--- a/homeassistant/components/zha/api.py
+++ b/homeassistant/components/zha/api.py
@@ -23,6 +23,7 @@ from .core.const import (
     ATTR_COMMAND,
     ATTR_COMMAND_TYPE,
     ATTR_ENDPOINT_ID,
+    ATTR_IEEE,
     ATTR_LEVEL,
     ATTR_MANUFACTURER,
     ATTR_MEMBERS,
@@ -54,7 +55,12 @@ from .core.const import (
     WARNING_DEVICE_STROBE_YES,
 )
 from .core.group import GroupMember
-from .core.helpers import async_is_bindable_target, get_matched_clusters
+from .core.helpers import (
+    async_is_bindable_target,
+    convert_install_code,
+    get_matched_clusters,
+    qr_to_install_code,
+)
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -67,9 +73,10 @@ DEVICE_INFO = "device_info"
 ATTR_DURATION = "duration"
 ATTR_GROUP = "group"
 ATTR_IEEE_ADDRESS = "ieee_address"
-ATTR_IEEE = "ieee"
+ATTR_INSTALL_CODE = "install_code"
 ATTR_SOURCE_IEEE = "source_ieee"
 ATTR_TARGET_IEEE = "target_ieee"
+ATTR_QR_CODE = "qr_code"
 
 SERVICE_PERMIT = "permit"
 SERVICE_REMOVE = "remove"
@@ -83,16 +90,29 @@ SERVICE_WARNING_DEVICE_WARN = "warning_device_warn"
 SERVICE_ZIGBEE_BIND = "service_zigbee_bind"
 IEEE_SERVICE = "ieee_based_service"
 
+SERVICE_PERMIT_PARAMS = {
+    vol.Optional(ATTR_IEEE, default=None): EUI64.convert,
+    vol.Optional(ATTR_DURATION, default=60): vol.All(
+        vol.Coerce(int), vol.Range(0, 254)
+    ),
+    vol.Inclusive(ATTR_SOURCE_IEEE, "install_code"): EUI64.convert,
+    vol.Inclusive(ATTR_INSTALL_CODE, "install_code"): convert_install_code,
+    vol.Exclusive(ATTR_QR_CODE, "install_code"): vol.All(str, qr_to_install_code),
+}
+
 SERVICE_SCHEMAS = {
     SERVICE_PERMIT: vol.Schema(
-        {
-            vol.Optional(ATTR_IEEE_ADDRESS, default=None): EUI64.convert,
-            vol.Optional(ATTR_DURATION, default=60): vol.All(
-                vol.Coerce(int), vol.Range(0, 254)
-            ),
-        }
+        vol.All(
+            cv.deprecated(ATTR_IEEE_ADDRESS, replacement_key=ATTR_IEEE),
+            SERVICE_PERMIT_PARAMS,
+        )
+    ),
+    IEEE_SERVICE: vol.Schema(
+        vol.All(
+            cv.deprecated(ATTR_IEEE_ADDRESS, replacement_key=ATTR_IEEE),
+            {vol.Required(ATTR_IEEE): EUI64.convert},
+        )
     ),
-    IEEE_SERVICE: vol.Schema({vol.Required(ATTR_IEEE_ADDRESS): EUI64.convert}),
     SERVICE_SET_ZIGBEE_CLUSTER_ATTRIBUTE: vol.Schema(
         {
             vol.Required(ATTR_IEEE): EUI64.convert,
@@ -169,13 +189,7 @@ ClusterBinding = collections.namedtuple("ClusterBinding", "id endpoint_id type n
 @websocket_api.require_admin
 @websocket_api.async_response
 @websocket_api.websocket_command(
-    {
-        vol.Required("type"): "zha/devices/permit",
-        vol.Optional(ATTR_IEEE, default=None): EUI64.convert,
-        vol.Optional(ATTR_DURATION, default=60): vol.All(
-            vol.Coerce(int), vol.Range(0, 254)
-        ),
-    }
+    {vol.Required("type"): "zha/devices/permit", **SERVICE_PERMIT_PARAMS}
 )
 async def websocket_permit_devices(hass, connection, msg):
     """Permit ZHA zigbee devices."""
@@ -199,7 +213,21 @@ async def websocket_permit_devices(hass, connection, msg):
 
     connection.subscriptions[msg["id"]] = async_cleanup
     zha_gateway.async_enable_debug_mode()
-    await zha_gateway.application_controller.permit(time_s=duration, node=ieee)
+    if ATTR_SOURCE_IEEE in msg:
+        src_ieee = msg[ATTR_SOURCE_IEEE]
+        code = msg[ATTR_INSTALL_CODE]
+        _LOGGER.debug("Allowing join for %s device with install code", src_ieee)
+        await zha_gateway.application_controller.permit_with_key(
+            time_s=duration, node=src_ieee, code=code
+        )
+    elif ATTR_QR_CODE in msg:
+        src_ieee, code = msg[ATTR_QR_CODE]
+        _LOGGER.debug("Allowing join for %s device with install code", src_ieee)
+        await zha_gateway.application_controller.permit_with_key(
+            time_s=duration, node=src_ieee, code=code
+        )
+    else:
+        await zha_gateway.application_controller.permit(time_s=duration, node=ieee)
     connection.send_result(msg["id"])
 
 
@@ -826,8 +854,25 @@ def async_load_api(hass):
 
     async def permit(service):
         """Allow devices to join this network."""
-        duration = service.data.get(ATTR_DURATION)
-        ieee = service.data.get(ATTR_IEEE_ADDRESS)
+        duration = service.data[ATTR_DURATION]
+        ieee = service.data.get(ATTR_IEEE)
+        if ATTR_SOURCE_IEEE in service.data:
+            src_ieee = service.data[ATTR_SOURCE_IEEE]
+            code = service.data[ATTR_INSTALL_CODE]
+            _LOGGER.info("Allowing join for %s device with install code", src_ieee)
+            await application_controller.permit_with_key(
+                time_s=duration, node=src_ieee, code=code
+            )
+            return
+
+        if ATTR_QR_CODE in service.data:
+            src_ieee, code = service.data[ATTR_QR_CODE]
+            _LOGGER.info("Allowing join for %s device with install code", src_ieee)
+            await application_controller.permit_with_key(
+                time_s=duration, node=src_ieee, code=code
+            )
+            return
+
         if ieee:
             _LOGGER.info("Permitting joins for %ss on %s device", duration, ieee)
         else:
@@ -840,7 +885,7 @@ def async_load_api(hass):
 
     async def remove(service):
         """Remove a node from the network."""
-        ieee = service.data[ATTR_IEEE_ADDRESS]
+        ieee = service.data[ATTR_IEEE]
         zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
         zha_device = zha_gateway.get_device(ieee)
         if zha_device is not None and zha_device.is_coordinator:
diff --git a/homeassistant/components/zha/core/helpers.py b/homeassistant/components/zha/core/helpers.py
index 7813c7133ad..0e967a7a123 100644
--- a/homeassistant/components/zha/core/helpers.py
+++ b/homeassistant/components/zha/core/helpers.py
@@ -6,15 +6,19 @@ https://home-assistant.io/integrations/zha/
 """
 
 import asyncio
+import binascii
 import collections
 import functools
 import itertools
 import logging
 from random import uniform
-from typing import Any, Callable, Iterator, List, Optional
+import re
+from typing import Any, Callable, Iterator, List, Optional, Tuple
 
+import voluptuous as vol
 import zigpy.exceptions
 import zigpy.types
+import zigpy.util
 
 from homeassistant.core import State, callback
 
@@ -205,3 +209,63 @@ def retryable_req(
         return wrapper
 
     return decorator
+
+
+def convert_install_code(value: str) -> bytes:
+    """Convert string to install code bytes and validate length."""
+
+    try:
+        code = binascii.unhexlify(value.replace("-", "").lower())
+    except binascii.Error as exc:
+        raise vol.Invalid(f"invalid hex string: {value}") from exc
+
+    if len(code) != 18:  # 16 byte code + 2 crc bytes
+        raise vol.Invalid("invalid length of the install code")
+
+    if zigpy.util.convert_install_code(code) is None:
+        raise vol.Invalid("invalid install code")
+
+    return code
+
+
+QR_CODES = (
+    # Consciot
+    r"^([\da-fA-F]{16})\|([\da-fA-F]{36})$",
+    # Enbrighten
+    r"""
+        ^Z:
+        ([0-9a-fA-F]{16})  # IEEE address
+        \$I:
+        ([0-9a-fA-F]{36})  # install code
+        $
+    """,
+    # Aqara
+    r"""
+        \$A:
+        ([0-9a-fA-F]{16})  # IEEE address
+        \$I:
+        ([0-9a-fA-F]{36})  # install code
+        $
+    """,
+)
+
+
+def qr_to_install_code(qr_code: str) -> Tuple[zigpy.types.EUI64, bytes]:
+    """Try to parse the QR code.
+
+    if successful, return a tuple of a EUI64 address and install code.
+    """
+
+    for code_pattern in QR_CODES:
+        match = re.search(code_pattern, qr_code, re.VERBOSE)
+        if match is None:
+            continue
+
+        ieee_hex = binascii.unhexlify(match[1])
+        ieee = zigpy.types.EUI64(ieee_hex[::-1])
+        install_code = match[2]
+        # install_code sanity check
+        install_code = convert_install_code(install_code)
+        return ieee, install_code
+
+    raise vol.Invalid(f"couldn't convert qr code: {qr_code}")
diff --git a/homeassistant/components/zha/services.yaml b/homeassistant/components/zha/services.yaml
index 257d1026f7f..74793d6000f 100644
--- a/homeassistant/components/zha/services.yaml
+++ b/homeassistant/components/zha/services.yaml
@@ -9,6 +9,15 @@ permit:
     ieee_address:
       description: IEEE address of the node permitting new joins
       example: "00:0d:6f:00:05:7d:2d:34"
+    source_ieee:
+      description: IEEE address of the joining device (must be used with install code)
+      example: "00:0a:bf:00:01:10:23:35"
+    install_code:
+      description: Install code of the joining device (must be used with source_ieee)
+      example: "1234-5678-1234-5678-AABB-CCDD-AABB-CCDD-EEFF"
+    qr_code:
+      description: value of the QR install code (different between vendors)
+      example: "Z:000D6FFFFED4163B$I:52797BF4A5084DAA8E1712B61741CA024051"
 
 remove:
   description: Remove a node from the Zigbee network.
diff --git a/tests/components/zha/test_api.py b/tests/components/zha/test_api.py
index 0587bd14c8c..b67891d7cbb 100644
--- a/tests/components/zha/test_api.py
+++ b/tests/components/zha/test_api.py
@@ -1,11 +1,24 @@
 """Test ZHA API."""
+from binascii import unhexlify
 
 import pytest
+import voluptuous as vol
 import zigpy.profiles.zha
+import zigpy.types
 import zigpy.zcl.clusters.general as general
 
 from homeassistant.components.websocket_api import const
-from homeassistant.components.zha.api import ID, TYPE, async_load_api
+from homeassistant.components.zha import DOMAIN
+from homeassistant.components.zha.api import (
+    ATTR_DURATION,
+    ATTR_INSTALL_CODE,
+    ATTR_QR_CODE,
+    ATTR_SOURCE_IEEE,
+    ID,
+    SERVICE_PERMIT,
+    TYPE,
+    async_load_api,
+)
 from homeassistant.components.zha.core.const import (
     ATTR_CLUSTER_ID,
     ATTR_CLUSTER_TYPE,
@@ -16,13 +29,18 @@ from homeassistant.components.zha.core.const import (
     ATTR_NAME,
     ATTR_QUIRK_APPLIED,
     CLUSTER_TYPE_IN,
+    DATA_ZHA,
+    DATA_ZHA_GATEWAY,
     GROUP_ID,
     GROUP_IDS,
     GROUP_NAME,
 )
+from homeassistant.core import Context
 
 from .conftest import FIXTURE_GRP_ID, FIXTURE_GRP_NAME
 
+from tests.async_mock import AsyncMock, patch
+
 IEEE_SWITCH_DEVICE = "01:2d:6f:00:0a:90:69:e7"
 IEEE_GROUPABLE_DEVICE = "01:2d:6f:00:0a:90:69:e8"
 
@@ -225,7 +243,7 @@ async def test_get_group(zha_client):
 
 async def test_get_group_not_found(zha_client):
     """Test not found response from get group API."""
-    await zha_client.send_json({ID: 9, TYPE: "zha/group", GROUP_ID: 1234567})
+    await zha_client.send_json({ID: 9, TYPE: "zha/group", GROUP_ID: 1_234_567})
 
     msg = await zha_client.receive_json()
 
@@ -335,3 +353,244 @@ async def test_remove_group(zha_client):
 
     groups = msg["result"]
     assert len(groups) == 0
+
+
+@pytest.fixture
+async def app_controller(hass, setup_zha):
+    """Fixture for zigpy Application Controller."""
+    await setup_zha()
+    controller = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY].application_controller
+    p1 = patch.object(controller, "permit")
+    p2 = patch.object(controller, "permit_with_key", new=AsyncMock())
+    with p1, p2:
+        yield controller
+
+
+@pytest.mark.parametrize(
+    "params, duration, node",
+    (
+        ({}, 60, None),
+        ({ATTR_DURATION: 30}, 30, None),
+        (
+            {ATTR_DURATION: 33, ATTR_IEEE: "aa:bb:cc:dd:aa:bb:cc:dd"},
+            33,
+            zigpy.types.EUI64.convert("aa:bb:cc:dd:aa:bb:cc:dd"),
+        ),
+        (
+            {ATTR_IEEE: "aa:bb:cc:dd:aa:bb:cc:d1"},
+            60,
+            zigpy.types.EUI64.convert("aa:bb:cc:dd:aa:bb:cc:d1"),
+        ),
+    ),
+)
+async def test_permit_ha12(
+    hass, app_controller, hass_admin_user, params, duration, node
+):
+    """Test permit service."""
+
+    await hass.services.async_call(
+        DOMAIN, SERVICE_PERMIT, params, True, Context(user_id=hass_admin_user.id)
+    )
+    assert app_controller.permit.await_count == 1
+    assert app_controller.permit.await_args[1]["time_s"] == duration
+    assert app_controller.permit.await_args[1]["node"] == node
+    assert app_controller.permit_with_key.call_count == 0
+
+
+IC_TEST_PARAMS = (
+    (
+        {
+            ATTR_SOURCE_IEEE: IEEE_SWITCH_DEVICE,
+            ATTR_INSTALL_CODE: "5279-7BF4-A508-4DAA-8E17-12B6-1741-CA02-4051",
+        },
+        zigpy.types.EUI64.convert(IEEE_SWITCH_DEVICE),
+        unhexlify("52797BF4A5084DAA8E1712B61741CA024051"),
+    ),
+    (
+        {
+            ATTR_SOURCE_IEEE: IEEE_SWITCH_DEVICE,
+            ATTR_INSTALL_CODE: "52797BF4A5084DAA8E1712B61741CA024051",
+        },
+        zigpy.types.EUI64.convert(IEEE_SWITCH_DEVICE),
+        unhexlify("52797BF4A5084DAA8E1712B61741CA024051"),
+    ),
+)
+
+
+@pytest.mark.parametrize("params, src_ieee, code", IC_TEST_PARAMS)
+async def test_permit_with_install_code(
+    hass, app_controller, hass_admin_user, params, src_ieee, code
+):
+    """Test permit service with install code."""
+
+    await hass.services.async_call(
+        DOMAIN, SERVICE_PERMIT, params, True, Context(user_id=hass_admin_user.id)
+    )
+    assert app_controller.permit.await_count == 0
+    assert app_controller.permit_with_key.call_count == 1
+    assert app_controller.permit_with_key.await_args[1]["time_s"] == 60
+    assert app_controller.permit_with_key.await_args[1]["node"] == src_ieee
+    assert app_controller.permit_with_key.await_args[1]["code"] == code
+
+
+IC_FAIL_PARAMS = (
+    {
+        # wrong install code
+        ATTR_SOURCE_IEEE: IEEE_SWITCH_DEVICE,
+        ATTR_INSTALL_CODE: "5279-7BF4-A508-4DAA-8E17-12B6-1741-CA02-4052",
+    },
+    # incorrect service params
+    {ATTR_INSTALL_CODE: "5279-7BF4-A508-4DAA-8E17-12B6-1741-CA02-4051"},
+    {ATTR_SOURCE_IEEE: IEEE_SWITCH_DEVICE},
+    {
+        # incorrect service params
+        ATTR_INSTALL_CODE: "5279-7BF4-A508-4DAA-8E17-12B6-1741-CA02-4051",
+        ATTR_QR_CODE: "Z:000D6FFFFED4163B$I:52797BF4A5084DAA8E1712B61741CA024051",
+    },
+    {
+        # incorrect service params
+        ATTR_SOURCE_IEEE: IEEE_SWITCH_DEVICE,
+        ATTR_QR_CODE: "Z:000D6FFFFED4163B$I:52797BF4A5084DAA8E1712B61741CA024051",
+    },
+    {
+        # good regex match, but bad code
+        ATTR_QR_CODE: "Z:000D6FFFFED4163B$I:52797BF4A5084DAA8E1712B61741CA024052"
+    },
+    {
+        # good aqara regex match, but bad code
+        ATTR_QR_CODE: (
+            "G$M:751$S:357S00001579$D:000000000F350FFD%Z$A:04CF8CDF"
+            "3C3C3C3C$I:52797BF4A5084DAA8E1712B61741CA024052"
+        )
+    },
+    # good consciot regex match, but bad code
+    {ATTR_QR_CODE: "000D6FFFFED4163B|52797BF4A5084DAA8E1712B61741CA024052"},
+)
+
+
+@pytest.mark.parametrize("params", IC_FAIL_PARAMS)
+async def test_permit_with_install_code_fail(
+    hass, app_controller, hass_admin_user, params
+):
+    """Test permit service with install code."""
+
+    with pytest.raises(vol.Invalid):
+        await hass.services.async_call(
+            DOMAIN, SERVICE_PERMIT, params, True, Context(user_id=hass_admin_user.id)
+        )
+    assert app_controller.permit.await_count == 0
+    assert app_controller.permit_with_key.call_count == 0
+
+
+IC_QR_CODE_TEST_PARAMS = (
+    (
+        {ATTR_QR_CODE: "000D6FFFFED4163B|52797BF4A5084DAA8E1712B61741CA024051"},
+        zigpy.types.EUI64.convert("00:0D:6F:FF:FE:D4:16:3B"),
+        unhexlify("52797BF4A5084DAA8E1712B61741CA024051"),
+    ),
+    (
+        {ATTR_QR_CODE: "Z:000D6FFFFED4163B$I:52797BF4A5084DAA8E1712B61741CA024051"},
+        zigpy.types.EUI64.convert("00:0D:6F:FF:FE:D4:16:3B"),
+        unhexlify("52797BF4A5084DAA8E1712B61741CA024051"),
+    ),
+    (
+        {
+            ATTR_QR_CODE: (
+                "G$M:751$S:357S00001579$D:000000000F350FFD%Z$A:04CF8CDF"
+                "3C3C3C3C$I:52797BF4A5084DAA8E1712B61741CA024051"
+            )
+        },
+        zigpy.types.EUI64.convert("04:CF:8C:DF:3C:3C:3C:3C"),
+        unhexlify("52797BF4A5084DAA8E1712B61741CA024051"),
+    ),
+)
+
+
+@pytest.mark.parametrize("params, src_ieee, code", IC_QR_CODE_TEST_PARAMS)
+async def test_permit_with_qr_code(
+    hass, app_controller, hass_admin_user, params, src_ieee, code
+):
+    """Test permit service with install code from qr code."""
+
+    await hass.services.async_call(
+        DOMAIN, SERVICE_PERMIT, params, True, Context(user_id=hass_admin_user.id)
+    )
+    assert app_controller.permit.await_count == 0
+    assert app_controller.permit_with_key.call_count == 1
+    assert app_controller.permit_with_key.await_args[1]["time_s"] == 60
+    assert app_controller.permit_with_key.await_args[1]["node"] == src_ieee
+    assert app_controller.permit_with_key.await_args[1]["code"] == code
+
+
+@pytest.mark.parametrize("params, src_ieee, code", IC_QR_CODE_TEST_PARAMS)
+async def test_ws_permit_with_qr_code(
+    app_controller, zha_client, params, src_ieee, code
+):
+    """Test permit service with install code from qr code."""
+
+    await zha_client.send_json(
+        {ID: 14, TYPE: f"{DOMAIN}/devices/{SERVICE_PERMIT}", **params}
+    )
+
+    msg = await zha_client.receive_json()
+    assert msg["id"] == 14
+    assert msg["type"] == const.TYPE_RESULT
+    assert msg["success"]
+
+    assert app_controller.permit.await_count == 0
+    assert app_controller.permit_with_key.call_count == 1
+    assert app_controller.permit_with_key.await_args[1]["time_s"] == 60
+    assert app_controller.permit_with_key.await_args[1]["node"] == src_ieee
+    assert app_controller.permit_with_key.await_args[1]["code"] == code
+
+
+@pytest.mark.parametrize("params", IC_FAIL_PARAMS)
+async def test_ws_permit_with_install_code_fail(app_controller, zha_client, params):
+    """Test permit ws service with install code."""
+
+    await zha_client.send_json(
+        {ID: 14, TYPE: f"{DOMAIN}/devices/{SERVICE_PERMIT}", **params}
+    )
+
+    msg = await zha_client.receive_json()
+    assert msg["id"] == 14
+    assert msg["type"] == const.TYPE_RESULT
+    assert msg["success"] is False
+
+    assert app_controller.permit.await_count == 0
+    assert app_controller.permit_with_key.call_count == 0
+
+
+@pytest.mark.parametrize(
+    "params, duration, node",
+    (
+        ({}, 60, None),
+        ({ATTR_DURATION: 30}, 30, None),
+        (
+            {ATTR_DURATION: 33, ATTR_IEEE: "aa:bb:cc:dd:aa:bb:cc:dd"},
+            33,
+            zigpy.types.EUI64.convert("aa:bb:cc:dd:aa:bb:cc:dd"),
+        ),
+        (
+            {ATTR_IEEE: "aa:bb:cc:dd:aa:bb:cc:d1"},
+            60,
+            zigpy.types.EUI64.convert("aa:bb:cc:dd:aa:bb:cc:d1"),
+        ),
+    ),
+)
+async def test_ws_permit_ha12(app_controller, zha_client, params, duration, node):
+    """Test permit ws service."""
+
+    await zha_client.send_json(
+        {ID: 14, TYPE: f"{DOMAIN}/devices/{SERVICE_PERMIT}", **params}
+    )
+
+    msg = await zha_client.receive_json()
+    assert msg["id"] == 14
+    assert msg["type"] == const.TYPE_RESULT
+    assert msg["success"]
+
+    assert app_controller.permit.await_count == 1
+    assert app_controller.permit.await_args[1]["time_s"] == duration
+    assert app_controller.permit.await_args[1]["node"] == node
+    assert app_controller.permit_with_key.call_count == 0