Fix zha circular import (#29802)
* Refactor zha.core.helpers. * Make zha isort-able. * Const import reorg. * Keep DATA_ZHA config key on entry unload. * Cleanup ZHA config flow. * isort. * Add test.
This commit is contained in:
parent
12f273eb11
commit
315d0064fe
6 changed files with 113 additions and 73 deletions
|
@ -1,7 +1,4 @@
|
||||||
"""Support for Zigbee Home Automation devices.
|
"""Support for Zigbee Home Automation devices."""
|
||||||
|
|
||||||
isort:skip_file
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -11,7 +8,6 @@ from homeassistant import config_entries, const as ha_const
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
from homeassistant.helpers.device_registry import CONNECTION_ZIGBEE
|
from homeassistant.helpers.device_registry import CONNECTION_ZIGBEE
|
||||||
|
|
||||||
from . import config_flow # noqa: F401 pylint: disable=unused-import
|
|
||||||
from . import api
|
from . import api
|
||||||
from .core import ZHAGateway
|
from .core import ZHAGateway
|
||||||
from .core.const import (
|
from .core.const import (
|
||||||
|
@ -147,5 +143,4 @@ async def async_unload_entry(hass, config_entry):
|
||||||
for component in COMPONENTS:
|
for component in COMPONENTS:
|
||||||
await hass.config_entries.async_forward_entry_unload(config_entry, component)
|
await hass.config_entries.async_forward_entry_unload(config_entry, component)
|
||||||
|
|
||||||
del hass.data[DATA_ZHA]
|
|
||||||
return True
|
return True
|
||||||
|
|
|
@ -50,7 +50,11 @@ from .core.const import (
|
||||||
WARNING_DEVICE_STROBE_HIGH,
|
WARNING_DEVICE_STROBE_HIGH,
|
||||||
WARNING_DEVICE_STROBE_YES,
|
WARNING_DEVICE_STROBE_YES,
|
||||||
)
|
)
|
||||||
from .core.helpers import async_is_bindable_target, get_matched_clusters
|
from .core.helpers import (
|
||||||
|
async_get_device_info,
|
||||||
|
async_is_bindable_target,
|
||||||
|
get_matched_clusters,
|
||||||
|
)
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -423,31 +427,6 @@ async def websocket_remove_group_members(hass, connection, msg):
|
||||||
connection.send_result(msg[ID], ret_group)
|
connection.send_result(msg[ID], ret_group)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def async_get_device_info(hass, device, ha_device_registry=None):
|
|
||||||
"""Get ZHA device."""
|
|
||||||
zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
|
|
||||||
ret_device = {}
|
|
||||||
ret_device.update(device.device_info)
|
|
||||||
ret_device["entities"] = [
|
|
||||||
{
|
|
||||||
"entity_id": entity_ref.reference_id,
|
|
||||||
ATTR_NAME: entity_ref.device_info[ATTR_NAME],
|
|
||||||
}
|
|
||||||
for entity_ref in zha_gateway.device_registry[device.ieee]
|
|
||||||
]
|
|
||||||
|
|
||||||
if ha_device_registry is not None:
|
|
||||||
reg_device = ha_device_registry.async_get_device(
|
|
||||||
{(DOMAIN, str(device.ieee))}, set()
|
|
||||||
)
|
|
||||||
if reg_device is not None:
|
|
||||||
ret_device["user_given_name"] = reg_device.name_by_user
|
|
||||||
ret_device["device_reg_id"] = reg_device.id
|
|
||||||
ret_device["area_id"] = reg_device.area_id
|
|
||||||
return ret_device
|
|
||||||
|
|
||||||
|
|
||||||
async def get_groups(hass,):
|
async def get_groups(hass,):
|
||||||
"""Get ZHA Groups."""
|
"""Get ZHA Groups."""
|
||||||
zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
|
zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""Config flow for ZHA."""
|
"""Config flow for ZHA."""
|
||||||
|
import asyncio
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
@ -9,11 +10,14 @@ from homeassistant import config_entries
|
||||||
from .core.const import (
|
from .core.const import (
|
||||||
CONF_RADIO_TYPE,
|
CONF_RADIO_TYPE,
|
||||||
CONF_USB_PATH,
|
CONF_USB_PATH,
|
||||||
|
CONTROLLER,
|
||||||
|
DEFAULT_BAUDRATE,
|
||||||
DEFAULT_DATABASE_NAME,
|
DEFAULT_DATABASE_NAME,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
|
ZHA_GW_RADIO,
|
||||||
RadioType,
|
RadioType,
|
||||||
)
|
)
|
||||||
from .core.helpers import check_zigpy_connection
|
from .core.registries import RADIO_TYPES
|
||||||
|
|
||||||
|
|
||||||
@config_entries.HANDLERS.register(DOMAIN)
|
@config_entries.HANDLERS.register(DOMAIN)
|
||||||
|
@ -57,3 +61,20 @@ class ZhaFlowHandler(config_entries.ConfigFlow):
|
||||||
return self.async_create_entry(
|
return self.async_create_entry(
|
||||||
title=import_info[CONF_USB_PATH], data=import_info
|
title=import_info[CONF_USB_PATH], data=import_info
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def check_zigpy_connection(usb_path, radio_type, database_path):
|
||||||
|
"""Test zigpy radio connection."""
|
||||||
|
try:
|
||||||
|
radio = RADIO_TYPES[radio_type][ZHA_GW_RADIO]()
|
||||||
|
controller_application = RADIO_TYPES[radio_type][CONTROLLER]
|
||||||
|
except KeyError:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
await radio.connect(usb_path, DEFAULT_BAUDRATE)
|
||||||
|
controller = controller_application(radio, database_path)
|
||||||
|
await asyncio.wait_for(controller.startup(auto_form=True), timeout=30)
|
||||||
|
await controller.shutdown()
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
|
@ -20,7 +20,6 @@ from homeassistant.helpers.device_registry import (
|
||||||
)
|
)
|
||||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||||
|
|
||||||
from ..api import async_get_device_info
|
|
||||||
from .const import (
|
from .const import (
|
||||||
ATTR_IEEE,
|
ATTR_IEEE,
|
||||||
ATTR_MANUFACTURER,
|
ATTR_MANUFACTURER,
|
||||||
|
@ -65,6 +64,7 @@ from .const import (
|
||||||
)
|
)
|
||||||
from .device import DeviceStatus, ZHADevice
|
from .device import DeviceStatus, ZHADevice
|
||||||
from .discovery import async_dispatch_discovery_info, async_process_endpoint
|
from .discovery import async_dispatch_discovery_info, async_process_endpoint
|
||||||
|
from .helpers import async_get_device_info
|
||||||
from .patches import apply_application_controller_patch
|
from .patches import apply_application_controller_patch
|
||||||
from .registries import RADIO_TYPES
|
from .registries import RADIO_TYPES
|
||||||
from .store import async_get_registry
|
from .store import async_get_registry
|
||||||
|
|
|
@ -4,29 +4,20 @@ Helpers for Zigbee Home Automation.
|
||||||
For more details about this component, please refer to the documentation at
|
For more details about this component, please refer to the documentation at
|
||||||
https://home-assistant.io/integrations/zha/
|
https://home-assistant.io/integrations/zha/
|
||||||
"""
|
"""
|
||||||
import asyncio
|
|
||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import bellows.ezsp
|
|
||||||
import bellows.zigbee.application
|
|
||||||
import zigpy.types
|
import zigpy.types
|
||||||
import zigpy_deconz.api
|
|
||||||
import zigpy_deconz.zigbee.application
|
|
||||||
import zigpy_xbee.api
|
|
||||||
import zigpy_xbee.zigbee.application
|
|
||||||
import zigpy_zigate.api
|
|
||||||
import zigpy_zigate.zigbee.application
|
|
||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
|
ATTR_NAME,
|
||||||
CLUSTER_TYPE_IN,
|
CLUSTER_TYPE_IN,
|
||||||
CLUSTER_TYPE_OUT,
|
CLUSTER_TYPE_OUT,
|
||||||
DATA_ZHA,
|
DATA_ZHA,
|
||||||
DATA_ZHA_GATEWAY,
|
DATA_ZHA_GATEWAY,
|
||||||
DEFAULT_BAUDRATE,
|
DOMAIN,
|
||||||
RadioType,
|
|
||||||
)
|
)
|
||||||
from .registries import BINDABLE_CLUSTERS
|
from .registries import BINDABLE_CLUSTERS
|
||||||
|
|
||||||
|
@ -56,30 +47,6 @@ async def safe_read(
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
async def check_zigpy_connection(usb_path, radio_type, database_path):
|
|
||||||
"""Test zigpy radio connection."""
|
|
||||||
if radio_type == RadioType.ezsp.name:
|
|
||||||
radio = bellows.ezsp.EZSP()
|
|
||||||
ControllerApplication = bellows.zigbee.application.ControllerApplication
|
|
||||||
elif radio_type == RadioType.xbee.name:
|
|
||||||
radio = zigpy_xbee.api.XBee()
|
|
||||||
ControllerApplication = zigpy_xbee.zigbee.application.ControllerApplication
|
|
||||||
elif radio_type == RadioType.deconz.name:
|
|
||||||
radio = zigpy_deconz.api.Deconz()
|
|
||||||
ControllerApplication = zigpy_deconz.zigbee.application.ControllerApplication
|
|
||||||
elif radio_type == RadioType.zigate.name:
|
|
||||||
radio = zigpy_zigate.api.ZiGate()
|
|
||||||
ControllerApplication = zigpy_zigate.zigbee.application.ControllerApplication
|
|
||||||
try:
|
|
||||||
await radio.connect(usb_path, DEFAULT_BAUDRATE)
|
|
||||||
controller = ControllerApplication(radio, database_path)
|
|
||||||
await asyncio.wait_for(controller.startup(auto_form=True), timeout=30)
|
|
||||||
await controller.shutdown()
|
|
||||||
except Exception: # pylint: disable=broad-except
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def get_attr_id_by_name(cluster, attr_name):
|
def get_attr_id_by_name(cluster, attr_name):
|
||||||
"""Get the attribute id for a cluster attribute by its name."""
|
"""Get the attribute id for a cluster attribute by its name."""
|
||||||
return next(
|
return next(
|
||||||
|
@ -164,3 +131,28 @@ class LogMixin:
|
||||||
def error(self, msg, *args):
|
def error(self, msg, *args):
|
||||||
"""Error level log."""
|
"""Error level log."""
|
||||||
return self.log(logging.ERROR, msg, *args)
|
return self.log(logging.ERROR, msg, *args)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_get_device_info(hass, device, ha_device_registry=None):
|
||||||
|
"""Get ZHA device."""
|
||||||
|
zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
|
||||||
|
ret_device = {}
|
||||||
|
ret_device.update(device.device_info)
|
||||||
|
ret_device["entities"] = [
|
||||||
|
{
|
||||||
|
"entity_id": entity_ref.reference_id,
|
||||||
|
ATTR_NAME: entity_ref.device_info[ATTR_NAME],
|
||||||
|
}
|
||||||
|
for entity_ref in zha_gateway.device_registry[device.ieee]
|
||||||
|
]
|
||||||
|
|
||||||
|
if ha_device_registry is not None:
|
||||||
|
reg_device = ha_device_registry.async_get_device(
|
||||||
|
{(DOMAIN, str(device.ieee))}, set()
|
||||||
|
)
|
||||||
|
if reg_device is not None:
|
||||||
|
ret_device["user_given_name"] = reg_device.name_by_user
|
||||||
|
ret_device["device_reg_id"] = reg_device.id
|
||||||
|
ret_device["area_id"] = reg_device.area_id
|
||||||
|
return ret_device
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
"""Tests for ZHA config flow."""
|
"""Tests for ZHA config flow."""
|
||||||
from asynctest import patch
|
from unittest import mock
|
||||||
|
|
||||||
|
import asynctest
|
||||||
|
|
||||||
from homeassistant.components.zha import config_flow
|
from homeassistant.components.zha import config_flow
|
||||||
from homeassistant.components.zha.core.const import DOMAIN
|
from homeassistant.components.zha.core.const import CONTROLLER, DOMAIN, ZHA_GW_RADIO
|
||||||
|
import homeassistant.components.zha.core.registries
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
|
@ -12,7 +15,7 @@ async def test_user_flow(hass):
|
||||||
flow = config_flow.ZhaFlowHandler()
|
flow = config_flow.ZhaFlowHandler()
|
||||||
flow.hass = hass
|
flow.hass = hass
|
||||||
|
|
||||||
with patch(
|
with asynctest.patch(
|
||||||
"homeassistant.components.zha.config_flow" ".check_zigpy_connection",
|
"homeassistant.components.zha.config_flow" ".check_zigpy_connection",
|
||||||
return_value=False,
|
return_value=False,
|
||||||
):
|
):
|
||||||
|
@ -22,7 +25,7 @@ async def test_user_flow(hass):
|
||||||
|
|
||||||
assert result["errors"] == {"base": "cannot_connect"}
|
assert result["errors"] == {"base": "cannot_connect"}
|
||||||
|
|
||||||
with patch(
|
with asynctest.patch(
|
||||||
"homeassistant.components.zha.config_flow" ".check_zigpy_connection",
|
"homeassistant.components.zha.config_flow" ".check_zigpy_connection",
|
||||||
return_value=True,
|
return_value=True,
|
||||||
):
|
):
|
||||||
|
@ -71,3 +74,53 @@ async def test_import_flow_existing_config_entry(hass):
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["type"] == "abort"
|
assert result["type"] == "abort"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_check_zigpy_connection():
|
||||||
|
"""Test config flow validator."""
|
||||||
|
|
||||||
|
mock_radio = asynctest.MagicMock()
|
||||||
|
mock_radio.connect = asynctest.CoroutineMock()
|
||||||
|
radio_cls = asynctest.MagicMock(return_value=mock_radio)
|
||||||
|
|
||||||
|
bad_radio = asynctest.MagicMock()
|
||||||
|
bad_radio.connect = asynctest.CoroutineMock(side_effect=Exception)
|
||||||
|
bad_radio_cls = asynctest.MagicMock(return_value=bad_radio)
|
||||||
|
|
||||||
|
mock_ctrl = asynctest.MagicMock()
|
||||||
|
mock_ctrl.startup = asynctest.CoroutineMock()
|
||||||
|
mock_ctrl.shutdown = asynctest.CoroutineMock()
|
||||||
|
ctrl_cls = asynctest.MagicMock(return_value=mock_ctrl)
|
||||||
|
new_radios = {
|
||||||
|
mock.sentinel.radio: {ZHA_GW_RADIO: radio_cls, CONTROLLER: ctrl_cls},
|
||||||
|
mock.sentinel.bad_radio: {ZHA_GW_RADIO: bad_radio_cls, CONTROLLER: ctrl_cls},
|
||||||
|
}
|
||||||
|
|
||||||
|
with mock.patch.dict(
|
||||||
|
homeassistant.components.zha.core.registries.RADIO_TYPES, new_radios, clear=True
|
||||||
|
):
|
||||||
|
assert not await config_flow.check_zigpy_connection(
|
||||||
|
mock.sentinel.usb_path, mock.sentinel.unk_radio, mock.sentinel.zigbee_db
|
||||||
|
)
|
||||||
|
assert mock_radio.connect.call_count == 0
|
||||||
|
assert bad_radio.connect.call_count == 0
|
||||||
|
assert mock_ctrl.startup.call_count == 0
|
||||||
|
assert mock_ctrl.shutdown.call_count == 0
|
||||||
|
|
||||||
|
# unsuccessful radio connect
|
||||||
|
assert not await config_flow.check_zigpy_connection(
|
||||||
|
mock.sentinel.usb_path, mock.sentinel.bad_radio, mock.sentinel.zigbee_db
|
||||||
|
)
|
||||||
|
assert mock_radio.connect.call_count == 0
|
||||||
|
assert bad_radio.connect.call_count == 1
|
||||||
|
assert mock_ctrl.startup.call_count == 0
|
||||||
|
assert mock_ctrl.shutdown.call_count == 0
|
||||||
|
|
||||||
|
# successful radio connect
|
||||||
|
assert await config_flow.check_zigpy_connection(
|
||||||
|
mock.sentinel.usb_path, mock.sentinel.radio, mock.sentinel.zigbee_db
|
||||||
|
)
|
||||||
|
assert mock_radio.connect.call_count == 1
|
||||||
|
assert bad_radio.connect.call_count == 1
|
||||||
|
assert mock_ctrl.startup.call_count == 1
|
||||||
|
assert mock_ctrl.shutdown.call_count == 1
|
||||||
|
|
Loading…
Add table
Reference in a new issue