Fix zha circular import (#29802)
* Refactor zha.core.helpers. * Make zha isort-able. * Const import reorg. * Keep DATA_ZHA config key on entry unload. * Cleanup ZHA config flow. * isort. * Add test.
This commit is contained in:
parent
12f273eb11
commit
315d0064fe
6 changed files with 113 additions and 73 deletions
|
@ -1,7 +1,4 @@
|
|||
"""Support for Zigbee Home Automation devices.
|
||||
|
||||
isort:skip_file
|
||||
"""
|
||||
"""Support for Zigbee Home Automation devices."""
|
||||
|
||||
import logging
|
||||
|
||||
|
@ -11,7 +8,6 @@ from homeassistant import config_entries, const as ha_const
|
|||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.device_registry import CONNECTION_ZIGBEE
|
||||
|
||||
from . import config_flow # noqa: F401 pylint: disable=unused-import
|
||||
from . import api
|
||||
from .core import ZHAGateway
|
||||
from .core.const import (
|
||||
|
@ -147,5 +143,4 @@ async def async_unload_entry(hass, config_entry):
|
|||
for component in COMPONENTS:
|
||||
await hass.config_entries.async_forward_entry_unload(config_entry, component)
|
||||
|
||||
del hass.data[DATA_ZHA]
|
||||
return True
|
||||
|
|
|
@ -50,7 +50,11 @@ from .core.const import (
|
|||
WARNING_DEVICE_STROBE_HIGH,
|
||||
WARNING_DEVICE_STROBE_YES,
|
||||
)
|
||||
from .core.helpers import async_is_bindable_target, get_matched_clusters
|
||||
from .core.helpers import (
|
||||
async_get_device_info,
|
||||
async_is_bindable_target,
|
||||
get_matched_clusters,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -423,31 +427,6 @@ async def websocket_remove_group_members(hass, connection, msg):
|
|||
connection.send_result(msg[ID], ret_group)
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_device_info(hass, device, ha_device_registry=None):
|
||||
"""Get ZHA device."""
|
||||
zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
|
||||
ret_device = {}
|
||||
ret_device.update(device.device_info)
|
||||
ret_device["entities"] = [
|
||||
{
|
||||
"entity_id": entity_ref.reference_id,
|
||||
ATTR_NAME: entity_ref.device_info[ATTR_NAME],
|
||||
}
|
||||
for entity_ref in zha_gateway.device_registry[device.ieee]
|
||||
]
|
||||
|
||||
if ha_device_registry is not None:
|
||||
reg_device = ha_device_registry.async_get_device(
|
||||
{(DOMAIN, str(device.ieee))}, set()
|
||||
)
|
||||
if reg_device is not None:
|
||||
ret_device["user_given_name"] = reg_device.name_by_user
|
||||
ret_device["device_reg_id"] = reg_device.id
|
||||
ret_device["area_id"] = reg_device.area_id
|
||||
return ret_device
|
||||
|
||||
|
||||
async def get_groups(hass,):
|
||||
"""Get ZHA Groups."""
|
||||
zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Config flow for ZHA."""
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
|
||||
|
@ -9,11 +10,14 @@ from homeassistant import config_entries
|
|||
from .core.const import (
|
||||
CONF_RADIO_TYPE,
|
||||
CONF_USB_PATH,
|
||||
CONTROLLER,
|
||||
DEFAULT_BAUDRATE,
|
||||
DEFAULT_DATABASE_NAME,
|
||||
DOMAIN,
|
||||
ZHA_GW_RADIO,
|
||||
RadioType,
|
||||
)
|
||||
from .core.helpers import check_zigpy_connection
|
||||
from .core.registries import RADIO_TYPES
|
||||
|
||||
|
||||
@config_entries.HANDLERS.register(DOMAIN)
|
||||
|
@ -57,3 +61,20 @@ class ZhaFlowHandler(config_entries.ConfigFlow):
|
|||
return self.async_create_entry(
|
||||
title=import_info[CONF_USB_PATH], data=import_info
|
||||
)
|
||||
|
||||
|
||||
async def check_zigpy_connection(usb_path, radio_type, database_path):
|
||||
"""Test zigpy radio connection."""
|
||||
try:
|
||||
radio = RADIO_TYPES[radio_type][ZHA_GW_RADIO]()
|
||||
controller_application = RADIO_TYPES[radio_type][CONTROLLER]
|
||||
except KeyError:
|
||||
return False
|
||||
try:
|
||||
await radio.connect(usb_path, DEFAULT_BAUDRATE)
|
||||
controller = controller_application(radio, database_path)
|
||||
await asyncio.wait_for(controller.startup(auto_form=True), timeout=30)
|
||||
await controller.shutdown()
|
||||
except Exception: # pylint: disable=broad-except
|
||||
return False
|
||||
return True
|
||||
|
|
|
@ -20,7 +20,6 @@ from homeassistant.helpers.device_registry import (
|
|||
)
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||
|
||||
from ..api import async_get_device_info
|
||||
from .const import (
|
||||
ATTR_IEEE,
|
||||
ATTR_MANUFACTURER,
|
||||
|
@ -65,6 +64,7 @@ from .const import (
|
|||
)
|
||||
from .device import DeviceStatus, ZHADevice
|
||||
from .discovery import async_dispatch_discovery_info, async_process_endpoint
|
||||
from .helpers import async_get_device_info
|
||||
from .patches import apply_application_controller_patch
|
||||
from .registries import RADIO_TYPES
|
||||
from .store import async_get_registry
|
||||
|
|
|
@ -4,29 +4,20 @@ Helpers for Zigbee Home Automation.
|
|||
For more details about this component, please refer to the documentation at
|
||||
https://home-assistant.io/integrations/zha/
|
||||
"""
|
||||
import asyncio
|
||||
import collections
|
||||
import logging
|
||||
|
||||
import bellows.ezsp
|
||||
import bellows.zigbee.application
|
||||
import zigpy.types
|
||||
import zigpy_deconz.api
|
||||
import zigpy_deconz.zigbee.application
|
||||
import zigpy_xbee.api
|
||||
import zigpy_xbee.zigbee.application
|
||||
import zigpy_zigate.api
|
||||
import zigpy_zigate.zigbee.application
|
||||
|
||||
from homeassistant.core import callback
|
||||
|
||||
from .const import (
|
||||
ATTR_NAME,
|
||||
CLUSTER_TYPE_IN,
|
||||
CLUSTER_TYPE_OUT,
|
||||
DATA_ZHA,
|
||||
DATA_ZHA_GATEWAY,
|
||||
DEFAULT_BAUDRATE,
|
||||
RadioType,
|
||||
DOMAIN,
|
||||
)
|
||||
from .registries import BINDABLE_CLUSTERS
|
||||
|
||||
|
@ -56,30 +47,6 @@ async def safe_read(
|
|||
return {}
|
||||
|
||||
|
||||
async def check_zigpy_connection(usb_path, radio_type, database_path):
|
||||
"""Test zigpy radio connection."""
|
||||
if radio_type == RadioType.ezsp.name:
|
||||
radio = bellows.ezsp.EZSP()
|
||||
ControllerApplication = bellows.zigbee.application.ControllerApplication
|
||||
elif radio_type == RadioType.xbee.name:
|
||||
radio = zigpy_xbee.api.XBee()
|
||||
ControllerApplication = zigpy_xbee.zigbee.application.ControllerApplication
|
||||
elif radio_type == RadioType.deconz.name:
|
||||
radio = zigpy_deconz.api.Deconz()
|
||||
ControllerApplication = zigpy_deconz.zigbee.application.ControllerApplication
|
||||
elif radio_type == RadioType.zigate.name:
|
||||
radio = zigpy_zigate.api.ZiGate()
|
||||
ControllerApplication = zigpy_zigate.zigbee.application.ControllerApplication
|
||||
try:
|
||||
await radio.connect(usb_path, DEFAULT_BAUDRATE)
|
||||
controller = ControllerApplication(radio, database_path)
|
||||
await asyncio.wait_for(controller.startup(auto_form=True), timeout=30)
|
||||
await controller.shutdown()
|
||||
except Exception: # pylint: disable=broad-except
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_attr_id_by_name(cluster, attr_name):
|
||||
"""Get the attribute id for a cluster attribute by its name."""
|
||||
return next(
|
||||
|
@ -164,3 +131,28 @@ class LogMixin:
|
|||
def error(self, msg, *args):
|
||||
"""Error level log."""
|
||||
return self.log(logging.ERROR, msg, *args)
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_device_info(hass, device, ha_device_registry=None):
|
||||
"""Get ZHA device."""
|
||||
zha_gateway = hass.data[DATA_ZHA][DATA_ZHA_GATEWAY]
|
||||
ret_device = {}
|
||||
ret_device.update(device.device_info)
|
||||
ret_device["entities"] = [
|
||||
{
|
||||
"entity_id": entity_ref.reference_id,
|
||||
ATTR_NAME: entity_ref.device_info[ATTR_NAME],
|
||||
}
|
||||
for entity_ref in zha_gateway.device_registry[device.ieee]
|
||||
]
|
||||
|
||||
if ha_device_registry is not None:
|
||||
reg_device = ha_device_registry.async_get_device(
|
||||
{(DOMAIN, str(device.ieee))}, set()
|
||||
)
|
||||
if reg_device is not None:
|
||||
ret_device["user_given_name"] = reg_device.name_by_user
|
||||
ret_device["device_reg_id"] = reg_device.id
|
||||
ret_device["area_id"] = reg_device.area_id
|
||||
return ret_device
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
"""Tests for ZHA config flow."""
|
||||
from asynctest import patch
|
||||
from unittest import mock
|
||||
|
||||
import asynctest
|
||||
|
||||
from homeassistant.components.zha import config_flow
|
||||
from homeassistant.components.zha.core.const import DOMAIN
|
||||
from homeassistant.components.zha.core.const import CONTROLLER, DOMAIN, ZHA_GW_RADIO
|
||||
import homeassistant.components.zha.core.registries
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
@ -12,7 +15,7 @@ async def test_user_flow(hass):
|
|||
flow = config_flow.ZhaFlowHandler()
|
||||
flow.hass = hass
|
||||
|
||||
with patch(
|
||||
with asynctest.patch(
|
||||
"homeassistant.components.zha.config_flow" ".check_zigpy_connection",
|
||||
return_value=False,
|
||||
):
|
||||
|
@ -22,7 +25,7 @@ async def test_user_flow(hass):
|
|||
|
||||
assert result["errors"] == {"base": "cannot_connect"}
|
||||
|
||||
with patch(
|
||||
with asynctest.patch(
|
||||
"homeassistant.components.zha.config_flow" ".check_zigpy_connection",
|
||||
return_value=True,
|
||||
):
|
||||
|
@ -71,3 +74,53 @@ async def test_import_flow_existing_config_entry(hass):
|
|||
)
|
||||
|
||||
assert result["type"] == "abort"
|
||||
|
||||
|
||||
async def test_check_zigpy_connection():
|
||||
"""Test config flow validator."""
|
||||
|
||||
mock_radio = asynctest.MagicMock()
|
||||
mock_radio.connect = asynctest.CoroutineMock()
|
||||
radio_cls = asynctest.MagicMock(return_value=mock_radio)
|
||||
|
||||
bad_radio = asynctest.MagicMock()
|
||||
bad_radio.connect = asynctest.CoroutineMock(side_effect=Exception)
|
||||
bad_radio_cls = asynctest.MagicMock(return_value=bad_radio)
|
||||
|
||||
mock_ctrl = asynctest.MagicMock()
|
||||
mock_ctrl.startup = asynctest.CoroutineMock()
|
||||
mock_ctrl.shutdown = asynctest.CoroutineMock()
|
||||
ctrl_cls = asynctest.MagicMock(return_value=mock_ctrl)
|
||||
new_radios = {
|
||||
mock.sentinel.radio: {ZHA_GW_RADIO: radio_cls, CONTROLLER: ctrl_cls},
|
||||
mock.sentinel.bad_radio: {ZHA_GW_RADIO: bad_radio_cls, CONTROLLER: ctrl_cls},
|
||||
}
|
||||
|
||||
with mock.patch.dict(
|
||||
homeassistant.components.zha.core.registries.RADIO_TYPES, new_radios, clear=True
|
||||
):
|
||||
assert not await config_flow.check_zigpy_connection(
|
||||
mock.sentinel.usb_path, mock.sentinel.unk_radio, mock.sentinel.zigbee_db
|
||||
)
|
||||
assert mock_radio.connect.call_count == 0
|
||||
assert bad_radio.connect.call_count == 0
|
||||
assert mock_ctrl.startup.call_count == 0
|
||||
assert mock_ctrl.shutdown.call_count == 0
|
||||
|
||||
# unsuccessful radio connect
|
||||
assert not await config_flow.check_zigpy_connection(
|
||||
mock.sentinel.usb_path, mock.sentinel.bad_radio, mock.sentinel.zigbee_db
|
||||
)
|
||||
assert mock_radio.connect.call_count == 0
|
||||
assert bad_radio.connect.call_count == 1
|
||||
assert mock_ctrl.startup.call_count == 0
|
||||
assert mock_ctrl.shutdown.call_count == 0
|
||||
|
||||
# successful radio connect
|
||||
assert await config_flow.check_zigpy_connection(
|
||||
mock.sentinel.usb_path, mock.sentinel.radio, mock.sentinel.zigbee_db
|
||||
)
|
||||
assert mock_radio.connect.call_count == 1
|
||||
assert bad_radio.connect.call_count == 1
|
||||
assert mock_ctrl.startup.call_count == 1
|
||||
assert mock_ctrl.shutdown.call_count == 1
|
||||
|
|
Loading…
Add table
Reference in a new issue