Clean tradfri hass data and add tests (#39620)

This commit is contained in:
Martin Hjelmare 2020-09-03 18:39:24 +02:00 committed by GitHub
parent d128443a2a
commit bde0bdbf80
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 153 additions and 64 deletions

View file

@ -1,6 +1,5 @@
"""Support for IKEA Tradfri."""
import asyncio
import logging
from pytradfri import Gateway, RequestError
from pytradfri.api.aiocoap_api import APIFactory
@ -31,9 +30,8 @@ from .const import (
PLATFORMS,
)
_LOGGER = logging.getLogger(__name__)
FACTORY = "tradfri_factory"
LISTENERS = "tradfri_listeners"
CONFIG_SCHEMA = vol.Schema(
{
@ -98,6 +96,8 @@ async def async_setup(hass, config):
async def async_setup_entry(hass, entry):
"""Create a gateway."""
# host, identity, key, allow_tradfri_groups
tradfri_data = hass.data.setdefault(DOMAIN, {})[entry.entry_id] = {}
listeners = tradfri_data[LISTENERS] = []
factory = await APIFactory.init(
entry.data[CONF_HOST],
@ -109,7 +109,7 @@ async def async_setup_entry(hass, entry):
"""Close connection when hass stops."""
await factory.shutdown()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, on_hass_stop)
listeners.append(hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, on_hass_stop))
api = factory.request
gateway = Gateway()
@ -120,9 +120,8 @@ async def async_setup_entry(hass, entry):
await factory.shutdown()
raise ConfigEntryNotReady from err
hass.data.setdefault(KEY_API, {})[entry.entry_id] = api
hass.data.setdefault(KEY_GATEWAY, {})[entry.entry_id] = gateway
tradfri_data = hass.data.setdefault(DOMAIN, {})[entry.entry_id] = {}
tradfri_data[KEY_API] = api
tradfri_data[KEY_GATEWAY] = gateway
tradfri_data[FACTORY] = factory
dev_reg = await hass.helpers.device_registry.async_get_registry()
@ -156,10 +155,11 @@ async def async_unload_entry(hass, entry):
)
)
if unload_ok:
hass.data[KEY_API].pop(entry.entry_id)
hass.data[KEY_GATEWAY].pop(entry.entry_id)
tradfri_data = hass.data[DOMAIN].pop(entry.entry_id)
factory = tradfri_data[FACTORY]
await factory.shutdown()
# unsubscribe listeners
for listener in tradfri_data[LISTENERS]:
listener()
return unload_ok

View file

@ -3,14 +3,15 @@
from homeassistant.components.cover import ATTR_POSITION, CoverEntity
from .base_class import TradfriBaseDevice
from .const import ATTR_MODEL, CONF_GATEWAY_ID, KEY_API, KEY_GATEWAY
from .const import ATTR_MODEL, CONF_GATEWAY_ID, DOMAIN, KEY_API, KEY_GATEWAY
async def async_setup_entry(hass, config_entry, async_add_entities):
"""Load Tradfri covers based on a config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID]
api = hass.data[KEY_API][config_entry.entry_id]
gateway = hass.data[KEY_GATEWAY][config_entry.entry_id]
tradfri_data = hass.data[DOMAIN][config_entry.entry_id]
api = tradfri_data[KEY_API]
gateway = tradfri_data[KEY_GATEWAY]
devices_commands = await api(gateway.get_devices())
devices = await api(devices_commands)

View file

@ -21,6 +21,7 @@ from .const import (
ATTR_TRANSITION_TIME,
CONF_GATEWAY_ID,
CONF_IMPORT_GROUPS,
DOMAIN,
KEY_API,
KEY_GATEWAY,
SUPPORTED_GROUP_FEATURES,
@ -33,8 +34,9 @@ _LOGGER = logging.getLogger(__name__)
async def async_setup_entry(hass, config_entry, async_add_entities):
"""Load Tradfri lights based on a config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID]
api = hass.data[KEY_API][config_entry.entry_id]
gateway = hass.data[KEY_GATEWAY][config_entry.entry_id]
tradfri_data = hass.data[DOMAIN][config_entry.entry_id]
api = tradfri_data[KEY_API]
gateway = tradfri_data[KEY_GATEWAY]
devices_commands = await api(gateway.get_devices())
devices = await api(devices_commands)

View file

@ -3,14 +3,15 @@
from homeassistant.const import DEVICE_CLASS_BATTERY, UNIT_PERCENTAGE
from .base_class import TradfriBaseDevice
from .const import CONF_GATEWAY_ID, KEY_API, KEY_GATEWAY
from .const import CONF_GATEWAY_ID, DOMAIN, KEY_API, KEY_GATEWAY
async def async_setup_entry(hass, config_entry, async_add_entities):
"""Set up a Tradfri config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID]
api = hass.data[KEY_API][config_entry.entry_id]
gateway = hass.data[KEY_GATEWAY][config_entry.entry_id]
tradfri_data = hass.data[DOMAIN][config_entry.entry_id]
api = tradfri_data[KEY_API]
gateway = tradfri_data[KEY_GATEWAY]
devices_commands = await api(gateway.get_devices())
all_devices = await api(devices_commands)

View file

@ -2,14 +2,15 @@
from homeassistant.components.switch import SwitchEntity
from .base_class import TradfriBaseDevice
from .const import CONF_GATEWAY_ID, KEY_API, KEY_GATEWAY
from .const import CONF_GATEWAY_ID, DOMAIN, KEY_API, KEY_GATEWAY
async def async_setup_entry(hass, config_entry, async_add_entities):
"""Load Tradfri switches based on a config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID]
api = hass.data[KEY_API][config_entry.entry_id]
gateway = hass.data[KEY_GATEWAY][config_entry.entry_id]
tradfri_data = hass.data[DOMAIN][config_entry.entry_id]
api = tradfri_data[KEY_API]
gateway = tradfri_data[KEY_GATEWAY]
devices_commands = await api(gateway.get_devices())
devices = await api(devices_commands)

View file

@ -1 +1,2 @@
"""Tests for the tradfri component."""
MOCK_GATEWAY_ID = "mock-gateway-id"

View file

@ -1,7 +1,11 @@
"""Common tradfri test fixtures."""
import pytest
from tests.async_mock import patch
from . import MOCK_GATEWAY_ID
from tests.async_mock import Mock, patch
# pylint: disable=protected-access
@pytest.fixture
@ -9,8 +13,8 @@ def mock_gateway_info():
"""Mock get_gateway_info."""
with patch(
"homeassistant.components.tradfri.config_flow.get_gateway_info"
) as mock_gateway:
yield mock_gateway
) as gateway_info:
yield gateway_info
@pytest.fixture
@ -19,3 +23,64 @@ def mock_entry_setup():
with patch("homeassistant.components.tradfri.async_setup_entry") as mock_setup:
mock_setup.return_value = True
yield mock_setup
@pytest.fixture(name="gateway_id")
def mock_gateway_id_fixture():
"""Return mock gateway_id."""
return MOCK_GATEWAY_ID
@pytest.fixture(name="mock_gateway")
def mock_gateway_fixture(gateway_id):
"""Mock a Tradfri gateway."""
def get_devices():
"""Return mock devices."""
return gateway.mock_devices
def get_groups():
"""Return mock groups."""
return gateway.mock_groups
gateway_info = Mock(id=gateway_id, firmware_version="1.2.1234")
def get_gateway_info():
"""Return mock gateway info."""
return gateway_info
gateway = Mock(
get_devices=get_devices,
get_groups=get_groups,
get_gateway_info=get_gateway_info,
mock_devices=[],
mock_groups=[],
mock_responses=[],
)
with patch("homeassistant.components.tradfri.Gateway", return_value=gateway), patch(
"homeassistant.components.tradfri.config_flow.Gateway", return_value=gateway
):
yield gateway
@pytest.fixture(name="mock_api")
def mock_api_fixture(mock_gateway):
"""Mock api."""
async def api(command):
"""Mock api function."""
# Store the data for "real" command objects.
if hasattr(command, "_data") and not isinstance(command, Mock):
mock_gateway.mock_responses.append(command._data)
return command
return api
@pytest.fixture(name="api_factory")
def mock_api_factory_fixture(mock_api):
"""Mock pytradfri api factory."""
with patch("homeassistant.components.tradfri.APIFactory", autospec=True) as factory:
factory.init.return_value = factory.return_value
factory.return_value.request = mock_api
yield factory.return_value

View file

@ -1,4 +1,9 @@
"""Tests for Tradfri setup."""
from homeassistant.components import tradfri
from homeassistant.helpers.device_registry import (
async_entries_for_config_entry,
async_get_registry as async_get_device_registry,
)
from homeassistant.setup import async_setup_component
from tests.async_mock import patch
@ -48,13 +53,15 @@ async def test_config_json_host_not_imported(hass):
assert len(mock_init.mock_calls) == 0
async def test_config_json_host_imported(hass, mock_gateway_info, mock_entry_setup):
async def test_config_json_host_imported(
hass, mock_gateway_info, mock_entry_setup, gateway_id
):
"""Test that we import a configured host."""
mock_gateway_info.side_effect = lambda hass, host, identity, key: {
"host": host,
"identity": identity,
"key": key,
"gateway_id": "mock-gateway",
"gateway_id": gateway_id,
}
with patch(
@ -68,3 +75,45 @@ async def test_config_json_host_imported(hass, mock_gateway_info, mock_entry_set
assert config_entry.domain == "tradfri"
assert config_entry.source == "import"
assert config_entry.title == "mock-host"
async def test_entry_setup_unload(hass, api_factory, gateway_id):
"""Test config entry setup and unload."""
entry = MockConfigEntry(
domain=tradfri.DOMAIN,
data={
tradfri.CONF_HOST: "mock-host",
tradfri.CONF_IDENTITY: "mock-identity",
tradfri.CONF_KEY: "mock-key",
tradfri.CONF_IMPORT_GROUPS: True,
tradfri.CONF_GATEWAY_ID: gateway_id,
},
)
entry.add_to_hass(hass)
with patch.object(
hass.config_entries, "async_forward_entry_setup", return_value=True
) as setup:
await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
assert setup.call_count == len(tradfri.PLATFORMS)
dev_reg = await async_get_device_registry(hass)
dev_entries = async_entries_for_config_entry(dev_reg, entry.entry_id)
assert dev_entries
dev_entry = dev_entries[0]
assert dev_entry.identifiers == {
(tradfri.DOMAIN, entry.data[tradfri.CONF_GATEWAY_ID])
}
assert dev_entry.manufacturer == tradfri.ATTR_TRADFRI_MANUFACTURER
assert dev_entry.name == tradfri.ATTR_TRADFRI_GATEWAY
assert dev_entry.model == tradfri.ATTR_TRADFRI_GATEWAY_MODEL
with patch.object(
hass.config_entries, "async_forward_entry_unload", return_value=True
) as unload:
assert await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
assert unload.call_count == len(tradfri.PLATFORMS)
assert api_factory.shutdown.call_count == 1

View file

@ -9,6 +9,8 @@ from pytradfri.device.light_control import LightControl
from homeassistant.components import tradfri
from . import MOCK_GATEWAY_ID
from tests.async_mock import MagicMock, Mock, PropertyMock, patch
from tests.common import MockConfigEntry
@ -93,42 +95,6 @@ def setup(request):
request.addfinalizer(teardown)
@pytest.fixture
def mock_gateway():
"""Mock a Tradfri gateway."""
def get_devices():
"""Return mock devices."""
return gateway.mock_devices
def get_groups():
"""Return mock groups."""
return gateway.mock_groups
gateway = Mock(
get_devices=get_devices,
get_groups=get_groups,
mock_devices=[],
mock_groups=[],
mock_responses=[],
)
return gateway
@pytest.fixture
def mock_api(mock_gateway):
"""Mock api."""
async def api(command):
"""Mock api function."""
# Store the data for "real" command objects.
if hasattr(command, "_data") and not isinstance(command, Mock):
mock_gateway.mock_responses.append(command._data)
return command
return api
async def generate_psk(self, code):
"""Mock psk."""
return "mock"
@ -143,11 +109,14 @@ async def setup_gateway(hass, mock_gateway, mock_api):
"identity": "mock-identity",
"key": "mock-key",
"import_groups": True,
"gateway_id": "mock-gateway-id",
"gateway_id": MOCK_GATEWAY_ID,
},
)
hass.data[tradfri.KEY_GATEWAY] = {entry.entry_id: mock_gateway}
hass.data[tradfri.KEY_API] = {entry.entry_id: mock_api}
tradfri_data = {}
hass.data[tradfri.DOMAIN] = {entry.entry_id: tradfri_data}
tradfri_data[tradfri.KEY_API] = mock_api
tradfri_data[tradfri.KEY_GATEWAY] = mock_gateway
await hass.config_entries.async_forward_entry_setup(entry, "light")
await hass.async_block_till_done()