Clean tradfri hass data and add tests (#39620)
This commit is contained in:
parent
d128443a2a
commit
bde0bdbf80
9 changed files with 153 additions and 64 deletions
|
@ -1,6 +1,5 @@
|
|||
"""Support for IKEA Tradfri."""
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from pytradfri import Gateway, RequestError
|
||||
from pytradfri.api.aiocoap_api import APIFactory
|
||||
|
@ -31,9 +30,8 @@ from .const import (
|
|||
PLATFORMS,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
FACTORY = "tradfri_factory"
|
||||
LISTENERS = "tradfri_listeners"
|
||||
|
||||
CONFIG_SCHEMA = vol.Schema(
|
||||
{
|
||||
|
@ -98,6 +96,8 @@ async def async_setup(hass, config):
|
|||
async def async_setup_entry(hass, entry):
|
||||
"""Create a gateway."""
|
||||
# host, identity, key, allow_tradfri_groups
|
||||
tradfri_data = hass.data.setdefault(DOMAIN, {})[entry.entry_id] = {}
|
||||
listeners = tradfri_data[LISTENERS] = []
|
||||
|
||||
factory = await APIFactory.init(
|
||||
entry.data[CONF_HOST],
|
||||
|
@ -109,7 +109,7 @@ async def async_setup_entry(hass, entry):
|
|||
"""Close connection when hass stops."""
|
||||
await factory.shutdown()
|
||||
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, on_hass_stop)
|
||||
listeners.append(hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, on_hass_stop))
|
||||
|
||||
api = factory.request
|
||||
gateway = Gateway()
|
||||
|
@ -120,9 +120,8 @@ async def async_setup_entry(hass, entry):
|
|||
await factory.shutdown()
|
||||
raise ConfigEntryNotReady from err
|
||||
|
||||
hass.data.setdefault(KEY_API, {})[entry.entry_id] = api
|
||||
hass.data.setdefault(KEY_GATEWAY, {})[entry.entry_id] = gateway
|
||||
tradfri_data = hass.data.setdefault(DOMAIN, {})[entry.entry_id] = {}
|
||||
tradfri_data[KEY_API] = api
|
||||
tradfri_data[KEY_GATEWAY] = gateway
|
||||
tradfri_data[FACTORY] = factory
|
||||
|
||||
dev_reg = await hass.helpers.device_registry.async_get_registry()
|
||||
|
@ -156,10 +155,11 @@ async def async_unload_entry(hass, entry):
|
|||
)
|
||||
)
|
||||
if unload_ok:
|
||||
hass.data[KEY_API].pop(entry.entry_id)
|
||||
hass.data[KEY_GATEWAY].pop(entry.entry_id)
|
||||
tradfri_data = hass.data[DOMAIN].pop(entry.entry_id)
|
||||
factory = tradfri_data[FACTORY]
|
||||
await factory.shutdown()
|
||||
# unsubscribe listeners
|
||||
for listener in tradfri_data[LISTENERS]:
|
||||
listener()
|
||||
|
||||
return unload_ok
|
||||
|
|
|
@ -3,14 +3,15 @@
|
|||
from homeassistant.components.cover import ATTR_POSITION, CoverEntity
|
||||
|
||||
from .base_class import TradfriBaseDevice
|
||||
from .const import ATTR_MODEL, CONF_GATEWAY_ID, KEY_API, KEY_GATEWAY
|
||||
from .const import ATTR_MODEL, CONF_GATEWAY_ID, DOMAIN, KEY_API, KEY_GATEWAY
|
||||
|
||||
|
||||
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||
"""Load Tradfri covers based on a config entry."""
|
||||
gateway_id = config_entry.data[CONF_GATEWAY_ID]
|
||||
api = hass.data[KEY_API][config_entry.entry_id]
|
||||
gateway = hass.data[KEY_GATEWAY][config_entry.entry_id]
|
||||
tradfri_data = hass.data[DOMAIN][config_entry.entry_id]
|
||||
api = tradfri_data[KEY_API]
|
||||
gateway = tradfri_data[KEY_GATEWAY]
|
||||
|
||||
devices_commands = await api(gateway.get_devices())
|
||||
devices = await api(devices_commands)
|
||||
|
|
|
@ -21,6 +21,7 @@ from .const import (
|
|||
ATTR_TRANSITION_TIME,
|
||||
CONF_GATEWAY_ID,
|
||||
CONF_IMPORT_GROUPS,
|
||||
DOMAIN,
|
||||
KEY_API,
|
||||
KEY_GATEWAY,
|
||||
SUPPORTED_GROUP_FEATURES,
|
||||
|
@ -33,8 +34,9 @@ _LOGGER = logging.getLogger(__name__)
|
|||
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||
"""Load Tradfri lights based on a config entry."""
|
||||
gateway_id = config_entry.data[CONF_GATEWAY_ID]
|
||||
api = hass.data[KEY_API][config_entry.entry_id]
|
||||
gateway = hass.data[KEY_GATEWAY][config_entry.entry_id]
|
||||
tradfri_data = hass.data[DOMAIN][config_entry.entry_id]
|
||||
api = tradfri_data[KEY_API]
|
||||
gateway = tradfri_data[KEY_GATEWAY]
|
||||
|
||||
devices_commands = await api(gateway.get_devices())
|
||||
devices = await api(devices_commands)
|
||||
|
|
|
@ -3,14 +3,15 @@
|
|||
from homeassistant.const import DEVICE_CLASS_BATTERY, UNIT_PERCENTAGE
|
||||
|
||||
from .base_class import TradfriBaseDevice
|
||||
from .const import CONF_GATEWAY_ID, KEY_API, KEY_GATEWAY
|
||||
from .const import CONF_GATEWAY_ID, DOMAIN, KEY_API, KEY_GATEWAY
|
||||
|
||||
|
||||
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||
"""Set up a Tradfri config entry."""
|
||||
gateway_id = config_entry.data[CONF_GATEWAY_ID]
|
||||
api = hass.data[KEY_API][config_entry.entry_id]
|
||||
gateway = hass.data[KEY_GATEWAY][config_entry.entry_id]
|
||||
tradfri_data = hass.data[DOMAIN][config_entry.entry_id]
|
||||
api = tradfri_data[KEY_API]
|
||||
gateway = tradfri_data[KEY_GATEWAY]
|
||||
|
||||
devices_commands = await api(gateway.get_devices())
|
||||
all_devices = await api(devices_commands)
|
||||
|
|
|
@ -2,14 +2,15 @@
|
|||
from homeassistant.components.switch import SwitchEntity
|
||||
|
||||
from .base_class import TradfriBaseDevice
|
||||
from .const import CONF_GATEWAY_ID, KEY_API, KEY_GATEWAY
|
||||
from .const import CONF_GATEWAY_ID, DOMAIN, KEY_API, KEY_GATEWAY
|
||||
|
||||
|
||||
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||
"""Load Tradfri switches based on a config entry."""
|
||||
gateway_id = config_entry.data[CONF_GATEWAY_ID]
|
||||
api = hass.data[KEY_API][config_entry.entry_id]
|
||||
gateway = hass.data[KEY_GATEWAY][config_entry.entry_id]
|
||||
tradfri_data = hass.data[DOMAIN][config_entry.entry_id]
|
||||
api = tradfri_data[KEY_API]
|
||||
gateway = tradfri_data[KEY_GATEWAY]
|
||||
|
||||
devices_commands = await api(gateway.get_devices())
|
||||
devices = await api(devices_commands)
|
||||
|
|
|
@ -1 +1,2 @@
|
|||
"""Tests for the tradfri component."""
|
||||
MOCK_GATEWAY_ID = "mock-gateway-id"
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
"""Common tradfri test fixtures."""
|
||||
import pytest
|
||||
|
||||
from tests.async_mock import patch
|
||||
from . import MOCK_GATEWAY_ID
|
||||
|
||||
from tests.async_mock import Mock, patch
|
||||
|
||||
# pylint: disable=protected-access
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -9,8 +13,8 @@ def mock_gateway_info():
|
|||
"""Mock get_gateway_info."""
|
||||
with patch(
|
||||
"homeassistant.components.tradfri.config_flow.get_gateway_info"
|
||||
) as mock_gateway:
|
||||
yield mock_gateway
|
||||
) as gateway_info:
|
||||
yield gateway_info
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -19,3 +23,64 @@ def mock_entry_setup():
|
|||
with patch("homeassistant.components.tradfri.async_setup_entry") as mock_setup:
|
||||
mock_setup.return_value = True
|
||||
yield mock_setup
|
||||
|
||||
|
||||
@pytest.fixture(name="gateway_id")
|
||||
def mock_gateway_id_fixture():
|
||||
"""Return mock gateway_id."""
|
||||
return MOCK_GATEWAY_ID
|
||||
|
||||
|
||||
@pytest.fixture(name="mock_gateway")
|
||||
def mock_gateway_fixture(gateway_id):
|
||||
"""Mock a Tradfri gateway."""
|
||||
|
||||
def get_devices():
|
||||
"""Return mock devices."""
|
||||
return gateway.mock_devices
|
||||
|
||||
def get_groups():
|
||||
"""Return mock groups."""
|
||||
return gateway.mock_groups
|
||||
|
||||
gateway_info = Mock(id=gateway_id, firmware_version="1.2.1234")
|
||||
|
||||
def get_gateway_info():
|
||||
"""Return mock gateway info."""
|
||||
return gateway_info
|
||||
|
||||
gateway = Mock(
|
||||
get_devices=get_devices,
|
||||
get_groups=get_groups,
|
||||
get_gateway_info=get_gateway_info,
|
||||
mock_devices=[],
|
||||
mock_groups=[],
|
||||
mock_responses=[],
|
||||
)
|
||||
with patch("homeassistant.components.tradfri.Gateway", return_value=gateway), patch(
|
||||
"homeassistant.components.tradfri.config_flow.Gateway", return_value=gateway
|
||||
):
|
||||
yield gateway
|
||||
|
||||
|
||||
@pytest.fixture(name="mock_api")
|
||||
def mock_api_fixture(mock_gateway):
|
||||
"""Mock api."""
|
||||
|
||||
async def api(command):
|
||||
"""Mock api function."""
|
||||
# Store the data for "real" command objects.
|
||||
if hasattr(command, "_data") and not isinstance(command, Mock):
|
||||
mock_gateway.mock_responses.append(command._data)
|
||||
return command
|
||||
|
||||
return api
|
||||
|
||||
|
||||
@pytest.fixture(name="api_factory")
|
||||
def mock_api_factory_fixture(mock_api):
|
||||
"""Mock pytradfri api factory."""
|
||||
with patch("homeassistant.components.tradfri.APIFactory", autospec=True) as factory:
|
||||
factory.init.return_value = factory.return_value
|
||||
factory.return_value.request = mock_api
|
||||
yield factory.return_value
|
||||
|
|
|
@ -1,4 +1,9 @@
|
|||
"""Tests for Tradfri setup."""
|
||||
from homeassistant.components import tradfri
|
||||
from homeassistant.helpers.device_registry import (
|
||||
async_entries_for_config_entry,
|
||||
async_get_registry as async_get_device_registry,
|
||||
)
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.async_mock import patch
|
||||
|
@ -48,13 +53,15 @@ async def test_config_json_host_not_imported(hass):
|
|||
assert len(mock_init.mock_calls) == 0
|
||||
|
||||
|
||||
async def test_config_json_host_imported(hass, mock_gateway_info, mock_entry_setup):
|
||||
async def test_config_json_host_imported(
|
||||
hass, mock_gateway_info, mock_entry_setup, gateway_id
|
||||
):
|
||||
"""Test that we import a configured host."""
|
||||
mock_gateway_info.side_effect = lambda hass, host, identity, key: {
|
||||
"host": host,
|
||||
"identity": identity,
|
||||
"key": key,
|
||||
"gateway_id": "mock-gateway",
|
||||
"gateway_id": gateway_id,
|
||||
}
|
||||
|
||||
with patch(
|
||||
|
@ -68,3 +75,45 @@ async def test_config_json_host_imported(hass, mock_gateway_info, mock_entry_set
|
|||
assert config_entry.domain == "tradfri"
|
||||
assert config_entry.source == "import"
|
||||
assert config_entry.title == "mock-host"
|
||||
|
||||
|
||||
async def test_entry_setup_unload(hass, api_factory, gateway_id):
|
||||
"""Test config entry setup and unload."""
|
||||
entry = MockConfigEntry(
|
||||
domain=tradfri.DOMAIN,
|
||||
data={
|
||||
tradfri.CONF_HOST: "mock-host",
|
||||
tradfri.CONF_IDENTITY: "mock-identity",
|
||||
tradfri.CONF_KEY: "mock-key",
|
||||
tradfri.CONF_IMPORT_GROUPS: True,
|
||||
tradfri.CONF_GATEWAY_ID: gateway_id,
|
||||
},
|
||||
)
|
||||
|
||||
entry.add_to_hass(hass)
|
||||
with patch.object(
|
||||
hass.config_entries, "async_forward_entry_setup", return_value=True
|
||||
) as setup:
|
||||
await hass.config_entries.async_setup(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
assert setup.call_count == len(tradfri.PLATFORMS)
|
||||
|
||||
dev_reg = await async_get_device_registry(hass)
|
||||
dev_entries = async_entries_for_config_entry(dev_reg, entry.entry_id)
|
||||
|
||||
assert dev_entries
|
||||
dev_entry = dev_entries[0]
|
||||
assert dev_entry.identifiers == {
|
||||
(tradfri.DOMAIN, entry.data[tradfri.CONF_GATEWAY_ID])
|
||||
}
|
||||
assert dev_entry.manufacturer == tradfri.ATTR_TRADFRI_MANUFACTURER
|
||||
assert dev_entry.name == tradfri.ATTR_TRADFRI_GATEWAY
|
||||
assert dev_entry.model == tradfri.ATTR_TRADFRI_GATEWAY_MODEL
|
||||
|
||||
with patch.object(
|
||||
hass.config_entries, "async_forward_entry_unload", return_value=True
|
||||
) as unload:
|
||||
assert await hass.config_entries.async_unload(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
assert unload.call_count == len(tradfri.PLATFORMS)
|
||||
assert api_factory.shutdown.call_count == 1
|
||||
|
|
|
@ -9,6 +9,8 @@ from pytradfri.device.light_control import LightControl
|
|||
|
||||
from homeassistant.components import tradfri
|
||||
|
||||
from . import MOCK_GATEWAY_ID
|
||||
|
||||
from tests.async_mock import MagicMock, Mock, PropertyMock, patch
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
@ -93,42 +95,6 @@ def setup(request):
|
|||
request.addfinalizer(teardown)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_gateway():
|
||||
"""Mock a Tradfri gateway."""
|
||||
|
||||
def get_devices():
|
||||
"""Return mock devices."""
|
||||
return gateway.mock_devices
|
||||
|
||||
def get_groups():
|
||||
"""Return mock groups."""
|
||||
return gateway.mock_groups
|
||||
|
||||
gateway = Mock(
|
||||
get_devices=get_devices,
|
||||
get_groups=get_groups,
|
||||
mock_devices=[],
|
||||
mock_groups=[],
|
||||
mock_responses=[],
|
||||
)
|
||||
return gateway
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api(mock_gateway):
|
||||
"""Mock api."""
|
||||
|
||||
async def api(command):
|
||||
"""Mock api function."""
|
||||
# Store the data for "real" command objects.
|
||||
if hasattr(command, "_data") and not isinstance(command, Mock):
|
||||
mock_gateway.mock_responses.append(command._data)
|
||||
return command
|
||||
|
||||
return api
|
||||
|
||||
|
||||
async def generate_psk(self, code):
|
||||
"""Mock psk."""
|
||||
return "mock"
|
||||
|
@ -143,11 +109,14 @@ async def setup_gateway(hass, mock_gateway, mock_api):
|
|||
"identity": "mock-identity",
|
||||
"key": "mock-key",
|
||||
"import_groups": True,
|
||||
"gateway_id": "mock-gateway-id",
|
||||
"gateway_id": MOCK_GATEWAY_ID,
|
||||
},
|
||||
)
|
||||
hass.data[tradfri.KEY_GATEWAY] = {entry.entry_id: mock_gateway}
|
||||
hass.data[tradfri.KEY_API] = {entry.entry_id: mock_api}
|
||||
tradfri_data = {}
|
||||
hass.data[tradfri.DOMAIN] = {entry.entry_id: tradfri_data}
|
||||
tradfri_data[tradfri.KEY_API] = mock_api
|
||||
tradfri_data[tradfri.KEY_GATEWAY] = mock_gateway
|
||||
|
||||
await hass.config_entries.async_forward_entry_setup(entry, "light")
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue