Fix zha circular import (#29802)

* Refactor zha.core.helpers.
* Make zha isort-able.
* Const import reorg.
* Keep DATA_ZHA config key on entry unload.
* Cleanup ZHA config flow.
* isort.
* Add test.
This commit is contained in:
Alexei Chetroi 2019-12-10 00:00:04 -05:00 committed by GitHub
parent 12f273eb11
commit 315d0064fe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 113 additions and 73 deletions

View file

@ -1,7 +1,4 @@
"""Support for Zigbee Home Automation devices.
isort:skip_file
"""
"""Support for Zigbee Home Automation devices."""
import logging
@ -11,7 +8,6 @@ from homeassistant import config_entries, const as ha_const
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.device_registry import CONNECTION_ZIGBEE
from . import config_flow # noqa: F401 pylint: disable=unused-import
from . import api
from .core import ZHAGateway
from .core.const import (
@ -147,5 +143,4 @@ async def async_unload_entry(hass, config_entry):
for component in COMPONENTS:
await hass.config_entries.async_forward_entry_unload(config_entry, component)
del hass.data[DATA_ZHA]
return True

View file

@ -50,7 +50,11 @@ from .core.const import (
WARNING_DEVICE_STROBE_HIGH,
WARNING_DEVICE_STROBE_YES,
)
from .core.helpers import async_is_bindable_target, get_matched_clusters
from .core.helpers import (
async_get_device_info,
async_is_bindable_target,
get_matched_clusters,
)
_LOGGER = logging.getLogger(__name__)
@ -423,31 +427,6 @@ async def websocket_remove_group_members(hass, connection, msg):
connection.send_result(msg[ID], ret_group)
@callback
def async_get_device_info(hass, device, ha_device_registry=None):
"""Get ZHA device."""
zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
ret_device = {}
ret_device.update(device.device_info)
ret_device["entities"] = [
{
"entity_id": entity_ref.reference_id,
ATTR_NAME: entity_ref.device_info[ATTR_NAME],
}
for entity_ref in zha_gateway.device_registry[device.ieee]
]
if ha_device_registry is not None:
reg_device = ha_device_registry.async_get_device(
{(DOMAIN, str(device.ieee))}, set()
)
if reg_device is not None:
ret_device["user_given_name"] = reg_device.name_by_user
ret_device["device_reg_id"] = reg_device.id
ret_device["area_id"] = reg_device.area_id
return ret_device
async def get_groups(hass,):
"""Get ZHA Groups."""
zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]

View file

@ -1,4 +1,5 @@
"""Config flow for ZHA."""
import asyncio
from collections import OrderedDict
import os
@ -9,11 +10,14 @@ from homeassistant import config_entries
from .core.const import (
CONF_RADIO_TYPE,
CONF_USB_PATH,
CONTROLLER,
DEFAULT_BAUDRATE,
DEFAULT_DATABASE_NAME,
DOMAIN,
ZHA_GW_RADIO,
RadioType,
)
from .core.helpers import check_zigpy_connection
from .core.registries import RADIO_TYPES
@config_entries.HANDLERS.register(DOMAIN)
@ -57,3 +61,20 @@ class ZhaFlowHandler(config_entries.ConfigFlow):
return self.async_create_entry(
title=import_info[CONF_USB_PATH], data=import_info
)
async def check_zigpy_connection(usb_path, radio_type, database_path):
"""Test zigpy radio connection."""
try:
radio = RADIO_TYPES[radio_type][ZHA_GW_RADIO]()
controller_application = RADIO_TYPES[radio_type][CONTROLLER]
except KeyError:
return False
try:
await radio.connect(usb_path, DEFAULT_BAUDRATE)
controller = controller_application(radio, database_path)
await asyncio.wait_for(controller.startup(auto_form=True), timeout=30)
await controller.shutdown()
except Exception: # pylint: disable=broad-except
return False
return True

View file

@ -20,7 +20,6 @@ from homeassistant.helpers.device_registry import (
)
from homeassistant.helpers.dispatcher import async_dispatcher_send
from ..api import async_get_device_info
from .const import (
ATTR_IEEE,
ATTR_MANUFACTURER,
@ -65,6 +64,7 @@ from .const import (
)
from .device import DeviceStatus, ZHADevice
from .discovery import async_dispatch_discovery_info, async_process_endpoint
from .helpers import async_get_device_info
from .patches import apply_application_controller_patch
from .registries import RADIO_TYPES
from .store import async_get_registry

View file

@ -4,29 +4,20 @@ Helpers for Zigbee Home Automation.
For more details about this component, please refer to the documentation at
https://home-assistant.io/integrations/zha/
"""
import asyncio
import collections
import logging
import bellows.ezsp
import bellows.zigbee.application
import zigpy.types
import zigpy_deconz.api
import zigpy_deconz.zigbee.application
import zigpy_xbee.api
import zigpy_xbee.zigbee.application
import zigpy_zigate.api
import zigpy_zigate.zigbee.application
from homeassistant.core import callback
from .const import (
ATTR_NAME,
CLUSTER_TYPE_IN,
CLUSTER_TYPE_OUT,
DATA_ZHA,
DATA_ZHA_GATEWAY,
DEFAULT_BAUDRATE,
RadioType,
DOMAIN,
)
from .registries import BINDABLE_CLUSTERS
@ -56,30 +47,6 @@ async def safe_read(
return {}
async def check_zigpy_connection(usb_path, radio_type, database_path):
"""Test zigpy radio connection."""
if radio_type == RadioType.ezsp.name:
radio = bellows.ezsp.EZSP()
ControllerApplication = bellows.zigbee.application.ControllerApplication
elif radio_type == RadioType.xbee.name:
radio = zigpy_xbee.api.XBee()
ControllerApplication = zigpy_xbee.zigbee.application.ControllerApplication
elif radio_type == RadioType.deconz.name:
radio = zigpy_deconz.api.Deconz()
ControllerApplication = zigpy_deconz.zigbee.application.ControllerApplication
elif radio_type == RadioType.zigate.name:
radio = zigpy_zigate.api.ZiGate()
ControllerApplication = zigpy_zigate.zigbee.application.ControllerApplication
try:
await radio.connect(usb_path, DEFAULT_BAUDRATE)
controller = ControllerApplication(radio, database_path)
await asyncio.wait_for(controller.startup(auto_form=True), timeout=30)
await controller.shutdown()
except Exception: # pylint: disable=broad-except
return False
return True
def get_attr_id_by_name(cluster, attr_name):
"""Get the attribute id for a cluster attribute by its name."""
return next(
@ -164,3 +131,28 @@ class LogMixin:
def error(self, msg, *args):
"""Error level log."""
return self.log(logging.ERROR, msg, *args)
@callback
def async_get_device_info(hass, device, ha_device_registry=None):
"""Get ZHA device."""
zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
ret_device = {}
ret_device.update(device.device_info)
ret_device["entities"] = [
{
"entity_id": entity_ref.reference_id,
ATTR_NAME: entity_ref.device_info[ATTR_NAME],
}
for entity_ref in zha_gateway.device_registry[device.ieee]
]
if ha_device_registry is not None:
reg_device = ha_device_registry.async_get_device(
{(DOMAIN, str(device.ieee))}, set()
)
if reg_device is not None:
ret_device["user_given_name"] = reg_device.name_by_user
ret_device["device_reg_id"] = reg_device.id
ret_device["area_id"] = reg_device.area_id
return ret_device

View file

@ -1,8 +1,11 @@
"""Tests for ZHA config flow."""
from asynctest import patch
from unittest import mock
import asynctest
from homeassistant.components.zha import config_flow
from homeassistant.components.zha.core.const import DOMAIN
from homeassistant.components.zha.core.const import CONTROLLER, DOMAIN, ZHA_GW_RADIO
import homeassistant.components.zha.core.registries
from tests.common import MockConfigEntry
@ -12,7 +15,7 @@ async def test_user_flow(hass):
flow = config_flow.ZhaFlowHandler()
flow.hass = hass
with patch(
with asynctest.patch(
"homeassistant.components.zha.config_flow" ".check_zigpy_connection",
return_value=False,
):
@ -22,7 +25,7 @@ async def test_user_flow(hass):
assert result["errors"] == {"base": "cannot_connect"}
with patch(
with asynctest.patch(
"homeassistant.components.zha.config_flow" ".check_zigpy_connection",
return_value=True,
):
@ -71,3 +74,53 @@ async def test_import_flow_existing_config_entry(hass):
)
assert result["type"] == "abort"
async def test_check_zigpy_connection():
"""Test config flow validator."""
mock_radio = asynctest.MagicMock()
mock_radio.connect = asynctest.CoroutineMock()
radio_cls = asynctest.MagicMock(return_value=mock_radio)
bad_radio = asynctest.MagicMock()
bad_radio.connect = asynctest.CoroutineMock(side_effect=Exception)
bad_radio_cls = asynctest.MagicMock(return_value=bad_radio)
mock_ctrl = asynctest.MagicMock()
mock_ctrl.startup = asynctest.CoroutineMock()
mock_ctrl.shutdown = asynctest.CoroutineMock()
ctrl_cls = asynctest.MagicMock(return_value=mock_ctrl)
new_radios = {
mock.sentinel.radio: {ZHA_GW_RADIO: radio_cls, CONTROLLER: ctrl_cls},
mock.sentinel.bad_radio: {ZHA_GW_RADIO: bad_radio_cls, CONTROLLER: ctrl_cls},
}
with mock.patch.dict(
homeassistant.components.zha.core.registries.RADIO_TYPES, new_radios, clear=True
):
assert not await config_flow.check_zigpy_connection(
mock.sentinel.usb_path, mock.sentinel.unk_radio, mock.sentinel.zigbee_db
)
assert mock_radio.connect.call_count == 0
assert bad_radio.connect.call_count == 0
assert mock_ctrl.startup.call_count == 0
assert mock_ctrl.shutdown.call_count == 0
# unsuccessful radio connect
assert not await config_flow.check_zigpy_connection(
mock.sentinel.usb_path, mock.sentinel.bad_radio, mock.sentinel.zigbee_db
)
assert mock_radio.connect.call_count == 0
assert bad_radio.connect.call_count == 1
assert mock_ctrl.startup.call_count == 0
assert mock_ctrl.shutdown.call_count == 0
# successful radio connect
assert await config_flow.check_zigpy_connection(
mock.sentinel.usb_path, mock.sentinel.radio, mock.sentinel.zigbee_db
)
assert mock_radio.connect.call_count == 1
assert bad_radio.connect.call_count == 1
assert mock_ctrl.startup.call_count == 1
assert mock_ctrl.shutdown.call_count == 1