diff --git a/homeassistant/components/tradfri/__init__.py b/homeassistant/components/tradfri/__init__.py index d3caaf54762..002231e4e20 100644 --- a/homeassistant/components/tradfri/__init__.py +++ b/homeassistant/components/tradfri/__init__.py @@ -24,9 +24,10 @@ from .const import ( CONF_KEY, CONFIG_FILE, DEFAULT_ALLOW_TRADFRI_GROUPS, + DEVICES, DOMAIN, + GROUPS, KEY_API, - KEY_GATEWAY, PLATFORMS, ) @@ -116,13 +117,18 @@ async def async_setup_entry(hass, entry): try: gateway_info = await api(gateway.get_gateway_info()) + devices_commands = await api(gateway.get_devices()) + devices = await api(devices_commands) + groups_commands = await api(gateway.get_groups()) + groups = await api(groups_commands) except RequestError as err: await factory.shutdown() raise ConfigEntryNotReady from err tradfri_data[KEY_API] = api - tradfri_data[KEY_GATEWAY] = gateway tradfri_data[FACTORY] = factory + tradfri_data[DEVICES] = devices + tradfri_data[GROUPS] = groups dev_reg = await hass.helpers.device_registry.async_get_registry() dev_reg.async_get_or_create( diff --git a/homeassistant/components/tradfri/base_class.py b/homeassistant/components/tradfri/base_class.py index 0850bec6c9b..0c9f2f7312f 100644 --- a/homeassistant/components/tradfri/base_class.py +++ b/homeassistant/components/tradfri/base_class.py @@ -1,4 +1,5 @@ """Base class for IKEA TRADFRI.""" +from functools import wraps import logging from pytradfri.error import PytradfriError @@ -11,6 +12,20 @@ from .const import DOMAIN _LOGGER = logging.getLogger(__name__) +def handle_error(func): + """Handle tradfri api call error.""" + + @wraps(func) + async def wrapper(command): + """Decorate api call.""" + try: + await func(command) + except PytradfriError as err: + _LOGGER.error("Unable to execute command %s: %s", command, err) + + return wrapper + + class TradfriBaseClass(Entity): """Base class for IKEA TRADFRI. @@ -19,7 +34,7 @@ class TradfriBaseClass(Entity): def __init__(self, device, api, gateway_id): """Initialize a device.""" - self._api = api + self._api = handle_error(api) self._device = None self._device_control = None self._device_data = None diff --git a/homeassistant/components/tradfri/const.py b/homeassistant/components/tradfri/const.py index 423620ecb2b..f7c2bf6cbe5 100644 --- a/homeassistant/components/tradfri/const.py +++ b/homeassistant/components/tradfri/const.py @@ -19,7 +19,8 @@ CONFIG_FILE = ".tradfri_psk.conf" DEFAULT_ALLOW_TRADFRI_GROUPS = False DOMAIN = "tradfri" KEY_API = "tradfri_api" -KEY_GATEWAY = "tradfri_gateway" +DEVICES = "tradfri_devices" +GROUPS = "tradfri_groups" KEY_SECURITY_CODE = "security_code" SUPPORTED_GROUP_FEATURES = SUPPORT_BRIGHTNESS | SUPPORT_TRANSITION SUPPORTED_LIGHT_FEATURES = SUPPORT_TRANSITION diff --git a/homeassistant/components/tradfri/cover.py b/homeassistant/components/tradfri/cover.py index cab7b6bbab7..2d99de7756a 100644 --- a/homeassistant/components/tradfri/cover.py +++ b/homeassistant/components/tradfri/cover.py @@ -3,7 +3,7 @@ from homeassistant.components.cover import ATTR_POSITION, CoverEntity from .base_class import TradfriBaseDevice -from .const import ATTR_MODEL, CONF_GATEWAY_ID, DOMAIN, KEY_API, KEY_GATEWAY +from .const import ATTR_MODEL, CONF_GATEWAY_ID, DEVICES, DOMAIN, KEY_API async def async_setup_entry(hass, config_entry, async_add_entities): @@ -11,10 +11,8 @@ async def async_setup_entry(hass, config_entry, async_add_entities): gateway_id = config_entry.data[CONF_GATEWAY_ID] tradfri_data = hass.data[DOMAIN][config_entry.entry_id] api = tradfri_data[KEY_API] - gateway = tradfri_data[KEY_GATEWAY] + devices = tradfri_data[DEVICES] - devices_commands = await api(gateway.get_devices()) - devices = await api(devices_commands) covers = [dev for dev in devices if dev.has_blind_control] if covers: async_add_entities(TradfriCover(cover, api, gateway_id) for cover in covers) diff --git a/homeassistant/components/tradfri/light.py b/homeassistant/components/tradfri/light.py index 29e096b2c49..939968852d9 100644 --- a/homeassistant/components/tradfri/light.py +++ b/homeassistant/components/tradfri/light.py @@ -21,9 +21,10 @@ from .const import ( ATTR_TRANSITION_TIME, CONF_GATEWAY_ID, CONF_IMPORT_GROUPS, + DEVICES, DOMAIN, + GROUPS, KEY_API, - KEY_GATEWAY, SUPPORTED_GROUP_FEATURES, SUPPORTED_LIGHT_FEATURES, ) @@ -36,17 +37,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities): gateway_id = config_entry.data[CONF_GATEWAY_ID] tradfri_data = hass.data[DOMAIN][config_entry.entry_id] api = tradfri_data[KEY_API] - gateway = tradfri_data[KEY_GATEWAY] + devices = tradfri_data[DEVICES] - devices_commands = await api(gateway.get_devices()) - devices = await api(devices_commands) lights = [dev for dev in devices if dev.has_light_control] if lights: async_add_entities(TradfriLight(light, api, gateway_id) for light in lights) if config_entry.data[CONF_IMPORT_GROUPS]: - groups_commands = await api(gateway.get_groups()) - groups = await api(groups_commands) + groups = tradfri_data[GROUPS] if groups: async_add_entities(TradfriGroup(group, api, gateway_id) for group in groups) diff --git a/homeassistant/components/tradfri/sensor.py b/homeassistant/components/tradfri/sensor.py index e82e352c009..c2bf640e2aa 100644 --- a/homeassistant/components/tradfri/sensor.py +++ b/homeassistant/components/tradfri/sensor.py @@ -3,7 +3,7 @@ from homeassistant.const import DEVICE_CLASS_BATTERY, PERCENTAGE from .base_class import TradfriBaseDevice -from .const import CONF_GATEWAY_ID, DOMAIN, KEY_API, KEY_GATEWAY +from .const import CONF_GATEWAY_ID, DEVICES, DOMAIN, KEY_API async def async_setup_entry(hass, config_entry, async_add_entities): @@ -11,20 +11,18 @@ async def async_setup_entry(hass, config_entry, async_add_entities): gateway_id = config_entry.data[CONF_GATEWAY_ID] tradfri_data = hass.data[DOMAIN][config_entry.entry_id] api = tradfri_data[KEY_API] - gateway = tradfri_data[KEY_GATEWAY] + devices = tradfri_data[DEVICES] - devices_commands = await api(gateway.get_devices()) - all_devices = await api(devices_commands) - devices = ( + sensors = ( dev - for dev in all_devices + for dev in devices if not dev.has_light_control and not dev.has_socket_control and not dev.has_blind_control and not dev.has_signal_repeater_control ) - if devices: - async_add_entities(TradfriSensor(device, api, gateway_id) for device in devices) + if sensors: + async_add_entities(TradfriSensor(sensor, api, gateway_id) for sensor in sensors) class TradfriSensor(TradfriBaseDevice): diff --git a/homeassistant/components/tradfri/switch.py b/homeassistant/components/tradfri/switch.py index 5bc5e6ab8e8..6634090d00d 100644 --- a/homeassistant/components/tradfri/switch.py +++ b/homeassistant/components/tradfri/switch.py @@ -2,7 +2,7 @@ from homeassistant.components.switch import SwitchEntity from .base_class import TradfriBaseDevice -from .const import CONF_GATEWAY_ID, DOMAIN, KEY_API, KEY_GATEWAY +from .const import CONF_GATEWAY_ID, DEVICES, DOMAIN, KEY_API async def async_setup_entry(hass, config_entry, async_add_entities): @@ -10,10 +10,8 @@ async def async_setup_entry(hass, config_entry, async_add_entities): gateway_id = config_entry.data[CONF_GATEWAY_ID] tradfri_data = hass.data[DOMAIN][config_entry.entry_id] api = tradfri_data[KEY_API] - gateway = tradfri_data[KEY_GATEWAY] + devices = tradfri_data[DEVICES] - devices_commands = await api(gateway.get_devices()) - devices = await api(devices_commands) switches = [dev for dev in devices if dev.has_socket_control] if switches: async_add_entities( diff --git a/tests/components/tradfri/test_light.py b/tests/components/tradfri/test_light.py index b4c209c1493..653a9ce62df 100644 --- a/tests/components/tradfri/test_light.py +++ b/tests/components/tradfri/test_light.py @@ -100,7 +100,7 @@ async def generate_psk(self, code): return "mock" -async def setup_gateway(hass, mock_gateway, mock_api): +async def setup_integration(hass): """Load the Tradfri platform with a mock gateway.""" entry = MockConfigEntry( domain=tradfri.DOMAIN, @@ -112,43 +112,44 @@ async def setup_gateway(hass, mock_gateway, mock_api): "gateway_id": MOCK_GATEWAY_ID, }, ) - tradfri_data = {} - hass.data[tradfri.DOMAIN] = {entry.entry_id: tradfri_data} - tradfri_data[tradfri.KEY_API] = mock_api - tradfri_data[tradfri.KEY_GATEWAY] = mock_gateway - await hass.config_entries.async_forward_entry_setup(entry, "light") + entry.add_to_hass(hass) + await hass.config_entries.async_setup(entry.entry_id) await hass.async_block_till_done() -def mock_light(test_features={}, test_state={}, n=0): +def mock_light(test_features=None, test_state=None, light_number=0): """Mock a tradfri light.""" + if test_features is None: + test_features = {} + if test_state is None: + test_state = {} mock_light_data = Mock(**test_state) dev_info_mock = MagicMock() dev_info_mock.manufacturer = "manufacturer" dev_info_mock.model_number = "model" dev_info_mock.firmware_version = "1.2.3" - mock_light = Mock( - id=f"mock-light-id-{n}", + _mock_light = Mock( + id=f"mock-light-id-{light_number}", reachable=True, observe=Mock(), device_info=dev_info_mock, ) - mock_light.name = f"tradfri_light_{n}" + _mock_light.name = f"tradfri_light_{light_number}" # Set supported features for the light. features = {**DEFAULT_TEST_FEATURES, **test_features} - lc = LightControl(mock_light) - for k, v in features.items(): - setattr(lc, k, v) + light_control = LightControl(_mock_light) + for attr, value in features.items(): + setattr(light_control, attr, value) # Store the initial state. - setattr(lc, "lights", [mock_light_data]) - mock_light.light_control = lc - return mock_light + setattr(light_control, "lights", [mock_light_data]) + _mock_light.light_control = light_control + return _mock_light -async def test_light(hass, mock_gateway, mock_api): +async def test_light(hass, mock_gateway, api_factory): """Test that lights are correctly added.""" features = {"can_set_dimmer": True, "can_set_color": True, "can_set_temp": True} @@ -162,7 +163,7 @@ async def test_light(hass, mock_gateway, mock_api): mock_gateway.mock_devices.append( mock_light(test_features=features, test_state=state) ) - await setup_gateway(hass, mock_gateway, mock_api) + await setup_integration(hass) lamp_1 = hass.states.get("light.tradfri_light_0") assert lamp_1 is not None @@ -171,48 +172,60 @@ async def test_light(hass, mock_gateway, mock_api): assert lamp_1.attributes["hs_color"] == (0.549, 0.153) -async def test_light_observed(hass, mock_gateway, mock_api): +async def test_light_observed(hass, mock_gateway, api_factory): """Test that lights are correctly observed.""" light = mock_light() mock_gateway.mock_devices.append(light) - await setup_gateway(hass, mock_gateway, mock_api) + await setup_integration(hass) assert len(light.observe.mock_calls) > 0 -async def test_light_available(hass, mock_gateway, mock_api): +async def test_light_available(hass, mock_gateway, api_factory): """Test light available property.""" - light = mock_light({"state": True}, n=1) + light = mock_light({"state": True}, light_number=1) light.reachable = True - light2 = mock_light({"state": True}, n=2) + light2 = mock_light({"state": True}, light_number=2) light2.reachable = False mock_gateway.mock_devices.append(light) mock_gateway.mock_devices.append(light2) - await setup_gateway(hass, mock_gateway, mock_api) + await setup_integration(hass) assert hass.states.get("light.tradfri_light_1").state == "on" assert hass.states.get("light.tradfri_light_2").state == "unavailable" -# Combine TURN_ON_TEST_CASES and TRANSITION_CASES_FOR_TESTS -ALL_TURN_ON_TEST_CASES = [["test_features", "test_data", "expected_result", "id"], []] +def create_all_turn_on_cases(): + """Create all turn on test cases.""" + # Combine TURN_ON_TEST_CASES and TRANSITION_CASES_FOR_TESTS + all_turn_on_test_cases = [ + ["test_features", "test_data", "expected_result", "device_id"], + [], + ] + index = 1 + for test_case in TURN_ON_TEST_CASES: + for trans in TRANSITION_CASES_FOR_TESTS: + case = deepcopy(test_case) + if trans is not None: + case[1]["transition"] = trans + case.append(index) + index += 1 + all_turn_on_test_cases[1].append(case) -idx = 1 -for tc in TURN_ON_TEST_CASES: - for trans in TRANSITION_CASES_FOR_TESTS: - case = deepcopy(tc) - if trans is not None: - case[1]["transition"] = trans - case.append(idx) - idx = idx + 1 - ALL_TURN_ON_TEST_CASES[1].append(case) + return all_turn_on_test_cases -@pytest.mark.parametrize(*ALL_TURN_ON_TEST_CASES) +@pytest.mark.parametrize(*create_all_turn_on_cases()) async def test_turn_on( - hass, mock_gateway, mock_api, test_features, test_data, expected_result, id + hass, + mock_gateway, + api_factory, + test_features, + test_data, + expected_result, + device_id, ): """Test turning on a light.""" # Note pytradfri style, not hass. Values not really important. @@ -224,15 +237,17 @@ async def test_turn_on( } # Setup the gateway with a mock light. - light = mock_light(test_features=test_features, test_state=initial_state, n=id) + light = mock_light( + test_features=test_features, test_state=initial_state, light_number=device_id + ) mock_gateway.mock_devices.append(light) - await setup_gateway(hass, mock_gateway, mock_api) + await setup_integration(hass) # Use the turn_on service call to change the light state. await hass.services.async_call( "light", "turn_on", - {"entity_id": f"light.tradfri_light_{id}", **test_data}, + {"entity_id": f"light.tradfri_light_{device_id}", **test_data}, blocking=True, ) await hass.async_block_till_done() @@ -243,39 +258,39 @@ async def test_turn_on( _, callkwargs = mock_func.call_args assert "callback" in callkwargs # Callback function to refresh light state. - cb = callkwargs["callback"] + callback = callkwargs["callback"] responses = mock_gateway.mock_responses # State on command data. data = {"3311": [{"5850": 1}]} # Add data for all sent commands. - for r in responses: - data["3311"][0] = {**data["3311"][0], **r["3311"][0]} + for resp in responses: + data["3311"][0] = {**data["3311"][0], **resp["3311"][0]} # Use the callback function to update the light state. dev = Device(data) light_data = Light(dev, 0) light.light_control.lights[0] = light_data - cb(light) + callback(light) await hass.async_block_till_done() # Check that the state is correct. - states = hass.states.get(f"light.tradfri_light_{id}") - for k, v in expected_result.items(): - if k == "state": - assert states.state == v + states = hass.states.get(f"light.tradfri_light_{device_id}") + for result, value in expected_result.items(): + if result == "state": + assert states.state == value else: # Allow some rounding error in color conversions. - assert states.attributes[k] == pytest.approx(v, abs=0.01) + assert states.attributes[result] == pytest.approx(value, abs=0.01) -async def test_turn_off(hass, mock_gateway, mock_api): +async def test_turn_off(hass, mock_gateway, api_factory): """Test turning off a light.""" state = {"state": True, "dimmer": 100} light = mock_light(test_state=state) mock_gateway.mock_devices.append(light) - await setup_gateway(hass, mock_gateway, mock_api) + await setup_integration(hass) # Use the turn_off service call to change the light state. await hass.services.async_call( @@ -289,19 +304,19 @@ async def test_turn_off(hass, mock_gateway, mock_api): _, callkwargs = mock_func.call_args assert "callback" in callkwargs # Callback function to refresh light state. - cb = callkwargs["callback"] + callback = callkwargs["callback"] responses = mock_gateway.mock_responses data = {"3311": [{}]} # Add data for all sent commands. - for r in responses: - data["3311"][0] = {**data["3311"][0], **r["3311"][0]} + for resp in responses: + data["3311"][0] = {**data["3311"][0], **resp["3311"][0]} # Use the callback function to update the light state. dev = Device(data) light_data = Light(dev, 0) light.light_control.lights[0] = light_data - cb(light) + callback(light) await hass.async_block_till_done() # Check that the state is correct. @@ -309,23 +324,25 @@ async def test_turn_off(hass, mock_gateway, mock_api): assert states.state == "off" -def mock_group(test_state={}, n=0): +def mock_group(test_state=None, group_number=0): """Mock a Tradfri group.""" + if test_state is None: + test_state = {} default_state = {"state": False, "dimmer": 0} state = {**default_state, **test_state} - mock_group = Mock(member_ids=[], observe=Mock(), **state) - mock_group.name = f"tradfri_group_{n}" - return mock_group + _mock_group = Mock(member_ids=[], observe=Mock(), **state) + _mock_group.name = f"tradfri_group_{group_number}" + return _mock_group -async def test_group(hass, mock_gateway, mock_api): +async def test_group(hass, mock_gateway, api_factory): """Test that groups are correctly added.""" mock_gateway.mock_groups.append(mock_group()) state = {"state": True, "dimmer": 100} mock_gateway.mock_groups.append(mock_group(state, 1)) - await setup_gateway(hass, mock_gateway, mock_api) + await setup_integration(hass) group = hass.states.get("light.tradfri_group_0") assert group is not None @@ -337,15 +354,15 @@ async def test_group(hass, mock_gateway, mock_api): assert group.attributes["brightness"] == 100 -async def test_group_turn_on(hass, mock_gateway, mock_api): +async def test_group_turn_on(hass, mock_gateway, api_factory): """Test turning on a group.""" group = mock_group() - group2 = mock_group(n=1) - group3 = mock_group(n=2) + group2 = mock_group(group_number=1) + group3 = mock_group(group_number=2) mock_gateway.mock_groups.append(group) mock_gateway.mock_groups.append(group2) mock_gateway.mock_groups.append(group3) - await setup_gateway(hass, mock_gateway, mock_api) + await setup_integration(hass) # Use the turn_off service call to change the light state. await hass.services.async_call( @@ -370,11 +387,11 @@ async def test_group_turn_on(hass, mock_gateway, mock_api): group3.set_dimmer.assert_called_with(100, transition_time=10) -async def test_group_turn_off(hass, mock_gateway, mock_api): +async def test_group_turn_off(hass, mock_gateway, api_factory): """Test turning off a group.""" group = mock_group({"state": True}) mock_gateway.mock_groups.append(group) - await setup_gateway(hass, mock_gateway, mock_api) + await setup_integration(hass) # Use the turn_off service call to change the light state. await hass.services.async_call(