Implement coordinator class for Tradfri integration (#64166)

* Initial commit coordinator

* More coordinator implementation

* More coordinator implementation

* Allow integration reload

* Move API calls to try/catch block

* Move back fixture

* Remove coordinator test file

* Ensure unchanged file

* Ensure unchanged conftest.py file

* Remove coordinator key check

* Apply suggestions from code review

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Import RequestError

* Move async_setup_platforms to end of setup_entry

* Remove centralised handling of device data and device controllers

* Remove platform_type argument

* Remove exception

* Remove the correct exception

* Refactor coordinator error handling

* Apply suggestions from code review

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Remove platform type from base class

* Remove timeout context manager

* Refactor exception callback

* Simplify starting device observation

* Update test

* Move observe start into update method

* Remove await self.coordinator.async_request_refresh()

* Refactor cover.py

* Uncomment const.py

* Add back extra_state_attributes

* Update homeassistant/components/tradfri/coordinator.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Refactor switch platform

* Expose switch state

* Refactor sensor platform

* Put back accidentally deleted code

* Add set_hub_available

* Apply suggestions from code review

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Fix tests for fan platform

* Update homeassistant/components/tradfri/base_class.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Update homeassistant/components/tradfri/base_class.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Fix non-working tests

* Refresh sensor state

* Remove commented line

* Add group coordinator

* Add groups during setup

* Refactor light platform

* Fix tests

* Move outside of try...except

* Remove error handler

* Remove unneeded methods

* Update sensor

* Update .coveragerc

* Move signal

* Add signals for groups

* Fix signal

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Patrik Lindgren 2022-01-27 11:12:52 +01:00 committed by GitHub
parent 3daaed1056
commit 9d404b749a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 452 additions and 251 deletions

View file

@ -1193,6 +1193,7 @@ omit =
homeassistant/components/tradfri/__init__.py
homeassistant/components/tradfri/base_class.py
homeassistant/components/tradfri/config_flow.py
homeassistant/components/tradfri/coordinator.py
homeassistant/components/tradfri/cover.py
homeassistant/components/tradfri/fan.py
homeassistant/components/tradfri/light.py

View file

@ -7,6 +7,9 @@ from typing import Any
from pytradfri import Gateway, PytradfriError, RequestError
from pytradfri.api.aiocoap_api import APIFactory
from pytradfri.command import Command
from pytradfri.device import Device
from pytradfri.group import Group
import voluptuous as vol
from homeassistant import config_entries
@ -15,7 +18,10 @@ from homeassistant.const import CONF_HOST, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import Event, HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
async_dispatcher_send,
)
from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.typing import ConfigType
@ -28,15 +34,20 @@ from .const import (
CONF_IDENTITY,
CONF_IMPORT_GROUPS,
CONF_KEY,
COORDINATOR,
COORDINATOR_LIST,
DEFAULT_ALLOW_TRADFRI_GROUPS,
DEVICES,
DOMAIN,
GROUPS,
GROUPS_LIST,
KEY_API,
PLATFORMS,
SIGNAL_GW,
TIMEOUT_API,
)
from .coordinator import (
TradfriDeviceDataUpdateCoordinator,
TradfriGroupDataUpdateCoordinator,
)
_LOGGER = logging.getLogger(__name__)
@ -84,9 +95,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_setup_entry(
hass: HomeAssistant,
entry: ConfigEntry,
) -> bool:
"""Create a gateway."""
# host, identity, key, allow_tradfri_groups
tradfri_data: dict[str, Any] = {}
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = tradfri_data
listeners = tradfri_data[LISTENERS] = []
@ -96,11 +109,13 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
psk_id=entry.data[CONF_IDENTITY],
psk=entry.data[CONF_KEY],
)
tradfri_data[FACTORY] = factory # Used for async_unload_entry
async def on_hass_stop(event: Event) -> None:
"""Close connection when hass stops."""
await factory.shutdown()
# Setup listeners
listeners.append(hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, on_hass_stop))
api = factory.request
@ -108,19 +123,17 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
try:
gateway_info = await api(gateway.get_gateway_info(), timeout=TIMEOUT_API)
devices_commands = await api(gateway.get_devices(), timeout=TIMEOUT_API)
devices = await api(devices_commands, timeout=TIMEOUT_API)
groups_commands = await api(gateway.get_groups(), timeout=TIMEOUT_API)
groups = await api(groups_commands, timeout=TIMEOUT_API)
devices_commands: Command = await api(
gateway.get_devices(), timeout=TIMEOUT_API
)
devices: list[Device] = await api(devices_commands, timeout=TIMEOUT_API)
groups_commands: Command = await api(gateway.get_groups(), timeout=TIMEOUT_API)
groups: list[Group] = await api(groups_commands, timeout=TIMEOUT_API)
except PytradfriError as exc:
await factory.shutdown()
raise ConfigEntryNotReady from exc
tradfri_data[KEY_API] = api
tradfri_data[FACTORY] = factory
tradfri_data[DEVICES] = devices
tradfri_data[GROUPS] = groups
dev_reg = await hass.helpers.device_registry.async_get_registry()
dev_reg.async_get_or_create(
config_entry_id=entry.entry_id,
@ -133,7 +146,38 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
sw_version=gateway_info.firmware_version,
)
hass.config_entries.async_setup_platforms(entry, PLATFORMS)
# Setup the device coordinators
coordinator_data = {
CONF_GATEWAY_ID: gateway,
KEY_API: api,
COORDINATOR_LIST: [],
GROUPS_LIST: [],
}
for device in devices:
coordinator = TradfriDeviceDataUpdateCoordinator(
hass=hass, api=api, device=device
)
await coordinator.async_config_entry_first_refresh()
entry.async_on_unload(
async_dispatcher_connect(hass, SIGNAL_GW, coordinator.set_hub_available)
)
coordinator_data[COORDINATOR_LIST].append(coordinator)
for group in groups:
group_coordinator = TradfriGroupDataUpdateCoordinator(
hass=hass, api=api, group=group
)
await group_coordinator.async_config_entry_first_refresh()
entry.async_on_unload(
async_dispatcher_connect(
hass, SIGNAL_GW, group_coordinator.set_hub_available
)
)
coordinator_data[GROUPS_LIST].append(group_coordinator)
tradfri_data[COORDINATOR] = coordinator_data
async def async_keep_alive(now: datetime) -> None:
if hass.is_stopping:
@ -152,6 +196,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async_track_time_interval(hass, async_keep_alive, timedelta(seconds=60))
)
hass.config_entries.async_setup_platforms(entry, PLATFORMS)
return True

View file

@ -1,29 +1,22 @@
"""Base class for IKEA TRADFRI."""
from __future__ import annotations
from abc import abstractmethod
from collections.abc import Callable
from functools import wraps
import logging
from typing import Any
from typing import Any, cast
from pytradfri.command import Command
from pytradfri.device import Device
from pytradfri.device.air_purifier import AirPurifier
from pytradfri.device.air_purifier_control import AirPurifierControl
from pytradfri.device.blind import Blind
from pytradfri.device.blind_control import BlindControl
from pytradfri.device.light import Light
from pytradfri.device.light_control import LightControl
from pytradfri.device.signal_repeater_control import SignalRepeaterControl
from pytradfri.device.socket import Socket
from pytradfri.device.socket_control import SocketControl
from pytradfri.error import PytradfriError
from homeassistant.core import callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import DeviceInfo, Entity
from homeassistant.helpers.entity import DeviceInfo
from homeassistant.helpers.update_coordinator import CoordinatorEntity
from .const import DOMAIN, SIGNAL_GW
from .const import DOMAIN
from .coordinator import TradfriDeviceDataUpdateCoordinator
_LOGGER = logging.getLogger(__name__)
@ -44,102 +37,44 @@ def handle_error(
return wrapper
class TradfriBaseClass(Entity):
"""Base class for IKEA TRADFRI.
class TradfriBaseEntity(CoordinatorEntity):
"""Base Tradfri device."""
All devices and groups should ultimately inherit from this class.
"""
_attr_should_poll = False
coordinator: TradfriDeviceDataUpdateCoordinator
def __init__(
self,
device: Device,
api: Callable[[Command | list[Command]], Any],
device_coordinator: TradfriDeviceDataUpdateCoordinator,
gateway_id: str,
api: Callable[[Command | list[Command]], Any],
) -> None:
"""Initialize a device."""
self._api = handle_error(api)
self._attr_name = device.name
self._device: Device = device
self._device_control: BlindControl | LightControl | SocketControl | SignalRepeaterControl | AirPurifierControl | None = (
None
)
self._device_data: Socket | Light | Blind | AirPurifier | None = None
super().__init__(device_coordinator)
self._gateway_id = gateway_id
async def _async_run_observe(self, cmd: Command) -> None:
"""Run observe in a coroutine."""
try:
await self._api(cmd)
except PytradfriError as err:
self._attr_available = False
self.async_write_ha_state()
_LOGGER.warning("Observation failed, trying again", exc_info=err)
self._async_start_observe()
self._device: Device = device_coordinator.data
self._device_id = self._device.id
self._api = handle_error(api)
self._attr_name = self._device.name
self._attr_unique_id = f"{self._gateway_id}-{self._device.id}"
@abstractmethod
@callback
def _refresh(self) -> None:
"""Refresh device data."""
@callback
def _async_start_observe(self, exc: Exception | None = None) -> None:
"""Start observation of device."""
if exc:
self._attr_available = False
self.async_write_ha_state()
_LOGGER.warning("Observation failed for %s", self._attr_name, exc_info=exc)
cmd = self._device.observe(
callback=self._observe_update,
err_callback=self._async_start_observe,
duration=0,
)
self.hass.async_create_task(self._async_run_observe(cmd))
def _handle_coordinator_update(self) -> None:
"""
Handle updated data from the coordinator.
async def async_added_to_hass(self) -> None:
"""Start thread when added to hass."""
self._async_start_observe()
@callback
def _observe_update(self, device: Device) -> None:
"""Receive new state data for this device."""
self._refresh(device)
def _refresh(self, device: Device, write_ha: bool = True) -> None:
"""Refresh the device data."""
self._device = device
self._attr_name = device.name
if write_ha:
self.async_write_ha_state()
class TradfriBaseDevice(TradfriBaseClass):
"""Base class for a TRADFRI device.
All devices should inherit from this class.
"""
def __init__(
self,
device: Device,
api: Callable[[Command | list[Command]], Any],
gateway_id: str,
) -> None:
"""Initialize a device."""
self._attr_available = device.reachable
self._hub_available = True
super().__init__(device, api, gateway_id)
async def async_added_to_hass(self) -> None:
"""Start thread when added to hass."""
# Only devices shall receive SIGNAL_GW
self.async_on_remove(
async_dispatcher_connect(self.hass, SIGNAL_GW, self.set_hub_available)
)
await super().async_added_to_hass()
@callback
def set_hub_available(self, available: bool) -> None:
"""Set status of hub."""
if available != self._hub_available:
self._hub_available = available
self._refresh(self._device)
Tests fails without this method.
"""
self._refresh()
super()._handle_coordinator_update()
@property
def device_info(self) -> DeviceInfo:
@ -154,10 +89,7 @@ class TradfriBaseDevice(TradfriBaseClass):
via_device=(DOMAIN, self._gateway_id),
)
def _refresh(self, device: Device, write_ha: bool = True) -> None:
"""Refresh the device data."""
# The base class _refresh cannot be used, because
# there are devices (group) that do not have .reachable
# so set _attr_available here and let the base class do the rest.
self._attr_available = device.reachable and self._hub_available
super()._refresh(device, write_ha)
@property
def available(self) -> bool:
"""Return if entity is available."""
return cast(bool, self._device.reachable) and super().available

View file

@ -37,3 +37,9 @@ PLATFORMS = [
]
TIMEOUT_API = 30
ATTR_MAX_FAN_STEPS = 49
SCAN_INTERVAL = 60 # Interval for updating the coordinator
COORDINATOR = "coordinator"
COORDINATOR_LIST = "coordinator_list"
GROUPS_LIST = "groups_list"

View file

@ -0,0 +1,145 @@
"""Tradfri DataUpdateCoordinator."""
from __future__ import annotations
from collections.abc import Callable
from datetime import timedelta
import logging
from typing import Any
from pytradfri.command import Command
from pytradfri.device import Device
from pytradfri.error import RequestError
from pytradfri.group import Group
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
from .const import SCAN_INTERVAL
_LOGGER = logging.getLogger(__name__)
class TradfriDeviceDataUpdateCoordinator(DataUpdateCoordinator[Device]):
"""Coordinator to manage data for a specific Tradfri device."""
def __init__(
self,
hass: HomeAssistant,
*,
api: Callable[[Command | list[Command]], Any],
device: Device,
) -> None:
"""Initialize device coordinator."""
self.api = api
self.device = device
self._exception: Exception | None = None
super().__init__(
hass,
_LOGGER,
name=f"Update coordinator for {device}",
update_interval=timedelta(seconds=SCAN_INTERVAL),
)
async def set_hub_available(self, available: bool) -> None:
"""Set status of hub."""
if available != self.last_update_success:
if not available:
self.last_update_success = False
await self.async_request_refresh()
@callback
def _observe_update(self, device: Device) -> None:
"""Update the coordinator for a device when a change is detected."""
self.update_interval = timedelta(seconds=SCAN_INTERVAL) # Reset update interval
self.async_set_updated_data(data=device)
@callback
def _exception_callback(self, device: Device, exc: Exception | None = None) -> None:
"""Schedule handling exception.."""
self.hass.async_create_task(self._handle_exception(device=device, exc=exc))
async def _handle_exception(
self, device: Device, exc: Exception | None = None
) -> None:
"""Handle observe exceptions in a coroutine."""
self._exception = (
exc # Store exception so that it gets raised in _async_update_data
)
_LOGGER.debug("Observation failed for %s, trying again", device, exc_info=exc)
self.update_interval = timedelta(
seconds=5
) # Change interval so we get a swift refresh
await self.async_request_refresh()
async def _async_update_data(self) -> Device:
"""Fetch data from the gateway for a specific device."""
try:
if self._exception:
exc = self._exception
self._exception = None # Clear stored exception
raise exc # pylint: disable-msg=raising-bad-type
except RequestError as err:
raise UpdateFailed(
f"Error communicating with API: {err}. Try unplugging and replugging your "
f"IKEA gateway."
) from err
if not self.data or not self.last_update_success: # Start subscription
try:
cmd = self.device.observe(
callback=self._observe_update,
err_callback=self._exception_callback,
duration=0,
)
await self.api(cmd)
except RequestError as exc:
await self._handle_exception(device=self.device, exc=exc)
return self.device
class TradfriGroupDataUpdateCoordinator(DataUpdateCoordinator[Group]):
"""Coordinator to manage data for a specific Tradfri group."""
def __init__(
self,
hass: HomeAssistant,
*,
api: Callable[[Command | list[Command]], Any],
group: Group,
) -> None:
"""Initialize group coordinator."""
self.api = api
self.group = group
self._exception: Exception | None = None
super().__init__(
hass,
_LOGGER,
name=f"Update coordinator for {group}",
update_interval=timedelta(seconds=SCAN_INTERVAL),
)
async def set_hub_available(self, available: bool) -> None:
"""Set status of hub."""
if available != self.last_update_success:
if not available:
self.last_update_success = False
await self.async_request_refresh()
async def _async_update_data(self) -> Group:
"""Fetch data from the gateway for a specific group."""
self.update_interval = timedelta(seconds=SCAN_INTERVAL) # Reset update interval
cmd = self.group.update()
try:
await self.api(cmd)
except RequestError as exc:
self.update_interval = timedelta(
seconds=5
) # Change interval so we get a swift refresh
raise UpdateFailed("Unable to update group coordinator") from exc
return self.group

View file

@ -11,8 +11,16 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .base_class import TradfriBaseDevice
from .const import ATTR_MODEL, CONF_GATEWAY_ID, DEVICES, DOMAIN, KEY_API
from .base_class import TradfriBaseEntity
from .const import (
ATTR_MODEL,
CONF_GATEWAY_ID,
COORDINATOR,
COORDINATOR_LIST,
DOMAIN,
KEY_API,
)
from .coordinator import TradfriDeviceDataUpdateCoordinator
async def async_setup_entry(
@ -22,28 +30,42 @@ async def async_setup_entry(
) -> None:
"""Load Tradfri covers based on a config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID]
tradfri_data = hass.data[DOMAIN][config_entry.entry_id]
api = tradfri_data[KEY_API]
devices = tradfri_data[DEVICES]
coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR]
api = coordinator_data[KEY_API]
async_add_entities(
TradfriCover(dev, api, gateway_id) for dev in devices if dev.has_blind_control
TradfriCover(
device_coordinator,
api,
gateway_id,
)
for device_coordinator in coordinator_data[COORDINATOR_LIST]
if device_coordinator.device.has_blind_control
)
class TradfriCover(TradfriBaseDevice, CoverEntity):
class TradfriCover(TradfriBaseEntity, CoverEntity):
"""The platform class required by Home Assistant."""
def __init__(
self,
device: Command,
device_coordinator: TradfriDeviceDataUpdateCoordinator,
api: Callable[[Command | list[Command]], Any],
gateway_id: str,
) -> None:
"""Initialize a cover."""
self._attr_unique_id = f"{gateway_id}-{device.id}"
super().__init__(device, api, gateway_id)
self._refresh(device, write_ha=False)
"""Initialize a switch."""
super().__init__(
device_coordinator=device_coordinator,
api=api,
gateway_id=gateway_id,
)
self._device_control = self._device.blind_control
self._device_data = self._device_control.blinds[0]
def _refresh(self) -> None:
"""Refresh the device."""
self._device_data = self.coordinator.data.blind_control.blinds[0]
@property
def extra_state_attributes(self) -> dict[str, str] | None:
@ -88,11 +110,3 @@ class TradfriCover(TradfriBaseDevice, CoverEntity):
def is_closed(self) -> bool:
"""Return if the cover is closed or not."""
return self.current_cover_position == 0
def _refresh(self, device: Command, write_ha: bool = True) -> None:
"""Refresh the cover data."""
# Caching of BlindControl and cover object
self._device = device
self._device_control = device.blind_control
self._device_data = device.blind_control.blinds[0]
super()._refresh(device, write_ha=write_ha)

View file

@ -16,15 +16,17 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .base_class import TradfriBaseDevice
from .base_class import TradfriBaseEntity
from .const import (
ATTR_AUTO,
ATTR_MAX_FAN_STEPS,
CONF_GATEWAY_ID,
DEVICES,
COORDINATOR,
COORDINATOR_LIST,
DOMAIN,
KEY_API,
)
from .coordinator import TradfriDeviceDataUpdateCoordinator
def _from_fan_percentage(percentage: int) -> int:
@ -44,30 +46,42 @@ async def async_setup_entry(
) -> None:
"""Load Tradfri switches based on a config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID]
tradfri_data = hass.data[DOMAIN][config_entry.entry_id]
api = tradfri_data[KEY_API]
devices = tradfri_data[DEVICES]
coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR]
api = coordinator_data[KEY_API]
async_add_entities(
TradfriAirPurifierFan(dev, api, gateway_id)
for dev in devices
if dev.has_air_purifier_control
TradfriAirPurifierFan(
device_coordinator,
api,
gateway_id,
)
for device_coordinator in coordinator_data[COORDINATOR_LIST]
if device_coordinator.device.has_air_purifier_control
)
class TradfriAirPurifierFan(TradfriBaseDevice, FanEntity):
class TradfriAirPurifierFan(TradfriBaseEntity, FanEntity):
"""The platform class required by Home Assistant."""
def __init__(
self,
device: Command,
device_coordinator: TradfriDeviceDataUpdateCoordinator,
api: Callable[[Command | list[Command]], Any],
gateway_id: str,
) -> None:
"""Initialize a switch."""
super().__init__(device, api, gateway_id)
self._attr_unique_id = f"{gateway_id}-{device.id}"
self._refresh(device, write_ha=False)
super().__init__(
device_coordinator=device_coordinator,
api=api,
gateway_id=gateway_id,
)
self._device_control = self._device.air_purifier_control
self._device_data = self._device_control.air_purifiers[0]
def _refresh(self) -> None:
"""Refresh the device."""
self._device_data = self.coordinator.data.air_purifier_control.air_purifiers[0]
@property
def supported_features(self) -> int:
@ -168,10 +182,3 @@ class TradfriAirPurifierFan(TradfriBaseDevice, FanEntity):
if not self._device_control:
return
await self._api(self._device_control.turn_off())
def _refresh(self, device: Command, write_ha: bool = True) -> None:
"""Refresh the purifier data."""
# Caching of air purifier control and purifier object
self._device_control = device.air_purifier_control
self._device_data = device.air_purifier_control.air_purifiers[0]
super()._refresh(device, write_ha=write_ha)

View file

@ -5,6 +5,7 @@ from collections.abc import Callable
from typing import Any, cast
from pytradfri.command import Command
from pytradfri.group import Group
from homeassistant.components.light import (
ATTR_BRIGHTNESS,
@ -19,9 +20,10 @@ from homeassistant.components.light import (
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity
import homeassistant.util.color as color_util
from .base_class import TradfriBaseClass, TradfriBaseDevice
from .base_class import TradfriBaseEntity
from .const import (
ATTR_DIMMER,
ATTR_HUE,
@ -29,13 +31,18 @@ from .const import (
ATTR_TRANSITION_TIME,
CONF_GATEWAY_ID,
CONF_IMPORT_GROUPS,
DEVICES,
COORDINATOR,
COORDINATOR_LIST,
DOMAIN,
GROUPS,
GROUPS_LIST,
KEY_API,
SUPPORTED_GROUP_FEATURES,
SUPPORTED_LIGHT_FEATURES,
)
from .coordinator import (
TradfriDeviceDataUpdateCoordinator,
TradfriGroupDataUpdateCoordinator,
)
async def async_setup_entry(
@ -45,56 +52,66 @@ async def async_setup_entry(
) -> None:
"""Load Tradfri lights based on a config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID]
tradfri_data = hass.data[DOMAIN][config_entry.entry_id]
api = tradfri_data[KEY_API]
devices = tradfri_data[DEVICES]
coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR]
api = coordinator_data[KEY_API]
entities: list[TradfriBaseClass] = [
TradfriLight(dev, api, gateway_id) for dev in devices if dev.has_light_control
entities: list = [
TradfriLight(
device_coordinator,
api,
gateway_id,
)
for device_coordinator in coordinator_data[COORDINATOR_LIST]
if device_coordinator.device.has_light_control
]
if config_entry.data[CONF_IMPORT_GROUPS] and (groups := tradfri_data[GROUPS]):
entities.extend([TradfriGroup(group, api, gateway_id) for group in groups])
if config_entry.data[CONF_IMPORT_GROUPS] and (
group_coordinators := coordinator_data[GROUPS_LIST]
):
entities.extend(
[
TradfriGroup(group_coordinator, api, gateway_id)
for group_coordinator in group_coordinators
]
)
async_add_entities(entities)
class TradfriGroup(TradfriBaseClass, LightEntity):
class TradfriGroup(CoordinatorEntity, LightEntity):
"""The platform class for light groups required by hass."""
_attr_supported_features = SUPPORTED_GROUP_FEATURES
def __init__(
self,
device: Command,
group_coordinator: TradfriGroupDataUpdateCoordinator,
api: Callable[[Command | list[Command]], Any],
gateway_id: str,
) -> None:
"""Initialize a Group."""
super().__init__(device, api, gateway_id)
super().__init__(coordinator=group_coordinator)
self._attr_unique_id = f"group-{gateway_id}-{device.id}"
self._attr_should_poll = True
self._refresh(device, write_ha=False)
self._group: Group = self.coordinator.data
async def async_update(self) -> None:
"""Fetch new state data for the group.
This method is required for groups to update properly.
"""
await self._api(self._device.update())
self._api = api
self._attr_unique_id = f"group-{gateway_id}-{self._group.id}"
@property
def is_on(self) -> bool:
"""Return true if group lights are on."""
return cast(bool, self._device.state)
return cast(bool, self._group.state)
@property
def brightness(self) -> int | None:
"""Return the brightness of the group lights."""
return cast(int, self._device.dimmer)
return cast(int, self._group.dimmer)
async def async_turn_off(self, **kwargs: Any) -> None:
"""Instruct the group lights to turn off."""
await self._api(self._device.set_state(0))
await self._api(self._group.set_state(0))
await self.coordinator.async_request_refresh()
async def async_turn_on(self, **kwargs: Any) -> None:
"""Instruct the group lights to turn on, or dim."""
@ -106,39 +123,53 @@ class TradfriGroup(TradfriBaseClass, LightEntity):
if kwargs[ATTR_BRIGHTNESS] == 255:
kwargs[ATTR_BRIGHTNESS] = 254
await self._api(self._device.set_dimmer(kwargs[ATTR_BRIGHTNESS], **keys))
await self._api(self._group.set_dimmer(kwargs[ATTR_BRIGHTNESS], **keys))
else:
await self._api(self._device.set_state(1))
await self._api(self._group.set_state(1))
await self.coordinator.async_request_refresh()
class TradfriLight(TradfriBaseDevice, LightEntity):
class TradfriLight(TradfriBaseEntity, LightEntity):
"""The platform class required by Home Assistant."""
def __init__(
self,
device: Command,
device_coordinator: TradfriDeviceDataUpdateCoordinator,
api: Callable[[Command | list[Command]], Any],
gateway_id: str,
) -> None:
"""Initialize a Light."""
super().__init__(device, api, gateway_id)
self._attr_unique_id = f"light-{gateway_id}-{device.id}"
super().__init__(
device_coordinator=device_coordinator,
api=api,
gateway_id=gateway_id,
)
self._device_control = self._device.light_control
self._device_data = self._device_control.lights[0]
self._attr_unique_id = f"light-{gateway_id}-{self._device_id}"
self._hs_color = None
# Calculate supported features
_features = SUPPORTED_LIGHT_FEATURES
if device.light_control.can_set_dimmer:
if self._device.light_control.can_set_dimmer:
_features |= SUPPORT_BRIGHTNESS
if device.light_control.can_set_color:
if self._device.light_control.can_set_color:
_features |= SUPPORT_COLOR | SUPPORT_COLOR_TEMP
if device.light_control.can_set_temp:
if self._device.light_control.can_set_temp:
_features |= SUPPORT_COLOR_TEMP
self._attr_supported_features = _features
self._refresh(device, write_ha=False)
if self._device_control:
self._attr_min_mireds = self._device_control.min_mireds
self._attr_max_mireds = self._device_control.max_mireds
def _refresh(self) -> None:
"""Refresh the device."""
self._device_data = self.coordinator.data.light_control.lights[0]
@property
def is_on(self) -> bool:
"""Return true if light is on."""
@ -268,10 +299,3 @@ class TradfriLight(TradfriBaseDevice, LightEntity):
await self._api(temp_command)
if command is not None:
await self._api(command)
def _refresh(self, device: Command, write_ha: bool = True) -> None:
"""Refresh the light data."""
# Caching of LightControl and light object
self._device_control = device.light_control
self._device_data = device.light_control.lights[0]
super()._refresh(device, write_ha=write_ha)

View file

@ -2,7 +2,7 @@
from __future__ import annotations
from collections.abc import Callable
from typing import Any, cast
from typing import Any
from pytradfri.command import Command
@ -12,8 +12,9 @@ from homeassistant.const import PERCENTAGE
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .base_class import TradfriBaseDevice
from .const import CONF_GATEWAY_ID, DEVICES, DOMAIN, KEY_API
from .base_class import TradfriBaseEntity
from .const import CONF_GATEWAY_ID, COORDINATOR, COORDINATOR_LIST, DOMAIN, KEY_API
from .coordinator import TradfriDeviceDataUpdateCoordinator
async def async_setup_entry(
@ -23,24 +24,27 @@ async def async_setup_entry(
) -> None:
"""Set up a Tradfri config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID]
tradfri_data = hass.data[DOMAIN][config_entry.entry_id]
api = tradfri_data[KEY_API]
devices = tradfri_data[DEVICES]
coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR]
api = coordinator_data[KEY_API]
async_add_entities(
TradfriSensor(dev, api, gateway_id)
for dev in devices
TradfriSensor(
device_coordinator,
api,
gateway_id,
)
for device_coordinator in coordinator_data[COORDINATOR_LIST]
if (
not dev.has_light_control
and not dev.has_socket_control
and not dev.has_blind_control
and not dev.has_signal_repeater_control
and not dev.has_air_purifier_control
not device_coordinator.device.has_light_control
and not device_coordinator.device.has_socket_control
and not device_coordinator.device.has_blind_control
and not device_coordinator.device.has_signal_repeater_control
and not device_coordinator.device.has_air_purifier_control
)
)
class TradfriSensor(TradfriBaseDevice, SensorEntity):
class TradfriSensor(TradfriBaseEntity, SensorEntity):
"""The platform class required by Home Assistant."""
_attr_device_class = SensorDeviceClass.BATTERY
@ -48,17 +52,19 @@ class TradfriSensor(TradfriBaseDevice, SensorEntity):
def __init__(
self,
device: Command,
device_coordinator: TradfriDeviceDataUpdateCoordinator,
api: Callable[[Command | list[Command]], Any],
gateway_id: str,
) -> None:
"""Initialize the device."""
super().__init__(device, api, gateway_id)
self._attr_unique_id = f"{gateway_id}-{device.id}"
"""Initialize a switch."""
super().__init__(
device_coordinator=device_coordinator,
api=api,
gateway_id=gateway_id,
)
@property
def native_value(self) -> int | None:
"""Return the current state of the device."""
if not self._device:
return None
return cast(int, self._device.device_info.battery_level)
self._refresh() # Set initial state
def _refresh(self) -> None:
"""Refresh the device."""
self._attr_native_value = self.coordinator.data.device_info.battery_level

View file

@ -11,8 +11,9 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .base_class import TradfriBaseDevice
from .const import CONF_GATEWAY_ID, DEVICES, DOMAIN, KEY_API
from .base_class import TradfriBaseEntity
from .const import CONF_GATEWAY_ID, COORDINATOR, COORDINATOR_LIST, DOMAIN, KEY_API
from .coordinator import TradfriDeviceDataUpdateCoordinator
async def async_setup_entry(
@ -22,35 +23,42 @@ async def async_setup_entry(
) -> None:
"""Load Tradfri switches based on a config entry."""
gateway_id = config_entry.data[CONF_GATEWAY_ID]
tradfri_data = hass.data[DOMAIN][config_entry.entry_id]
api = tradfri_data[KEY_API]
devices = tradfri_data[DEVICES]
coordinator_data = hass.data[DOMAIN][config_entry.entry_id][COORDINATOR]
api = coordinator_data[KEY_API]
async_add_entities(
TradfriSwitch(dev, api, gateway_id) for dev in devices if dev.has_socket_control
TradfriSwitch(
device_coordinator,
api,
gateway_id,
)
for device_coordinator in coordinator_data[COORDINATOR_LIST]
if device_coordinator.device.has_socket_control
)
class TradfriSwitch(TradfriBaseDevice, SwitchEntity):
class TradfriSwitch(TradfriBaseEntity, SwitchEntity):
"""The platform class required by Home Assistant."""
def __init__(
self,
device: Command,
device_coordinator: TradfriDeviceDataUpdateCoordinator,
api: Callable[[Command | list[Command]], Any],
gateway_id: str,
) -> None:
"""Initialize a switch."""
super().__init__(device, api, gateway_id)
self._attr_unique_id = f"{gateway_id}-{device.id}"
self._refresh(device, write_ha=False)
super().__init__(
device_coordinator=device_coordinator,
api=api,
gateway_id=gateway_id,
)
def _refresh(self, device: Command, write_ha: bool = True) -> None:
"""Refresh the switch data."""
# Caching of switch control and switch object
self._device_control = device.socket_control
self._device_data = device.socket_control.sockets[0]
super()._refresh(device, write_ha=write_ha)
self._device_control = self._device.socket_control
self._device_data = self._device_control.sockets[0]
def _refresh(self) -> None:
"""Refresh the device."""
self._device_data = self.coordinator.data.socket_control.sockets[0]
@property
def is_on(self) -> bool:

View file

@ -22,3 +22,5 @@ async def setup_integration(hass):
entry.add_to_hass(hass)
await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
return entry

View file

@ -121,7 +121,6 @@ async def test_set_percentage(
"""Test setting speed of a fan."""
# Note pytradfri style, not hass. Values not really important.
initial_state = {"percentage": 10, "fan_speed": 3}
# Setup the gateway with a mock fan.
fan = mock_fan(test_state=initial_state, device_number=0)
mock_gateway.mock_devices.append(fan)

View file

@ -317,6 +317,7 @@ def mock_group(test_state=None, group_number=0):
_mock_group = Mock(member_ids=[], observe=Mock(), **state)
_mock_group.name = f"tradfri_group_{group_number}"
_mock_group.id = group_number
return _mock_group
@ -327,11 +328,11 @@ async def test_group(hass, mock_gateway, mock_api_factory):
mock_gateway.mock_groups.append(mock_group(state, 1))
await setup_integration(hass)
group = hass.states.get("light.tradfri_group_0")
group = hass.states.get("light.tradfri_group_mock_gateway_id_0")
assert group is not None
assert group.state == "off"
group = hass.states.get("light.tradfri_group_1")
group = hass.states.get("light.tradfri_group_mock_gateway_id_1")
assert group is not None
assert group.state == "on"
assert group.attributes["brightness"] == 100
@ -348,19 +349,26 @@ async def test_group_turn_on(hass, mock_gateway, mock_api_factory):
await setup_integration(hass)
# Use the turn_off service call to change the light state.
await hass.services.async_call(
"light", "turn_on", {"entity_id": "light.tradfri_group_0"}, blocking=True
)
await hass.services.async_call(
"light",
"turn_on",
{"entity_id": "light.tradfri_group_1", "brightness": 100},
{"entity_id": "light.tradfri_group_mock_gateway_id_0"},
blocking=True,
)
await hass.services.async_call(
"light",
"turn_on",
{"entity_id": "light.tradfri_group_2", "brightness": 100, "transition": 1},
{"entity_id": "light.tradfri_group_mock_gateway_id_1", "brightness": 100},
blocking=True,
)
await hass.services.async_call(
"light",
"turn_on",
{
"entity_id": "light.tradfri_group_mock_gateway_id_2",
"brightness": 100,
"transition": 1,
},
blocking=True,
)
await hass.async_block_till_done()
@ -378,7 +386,10 @@ async def test_group_turn_off(hass, mock_gateway, mock_api_factory):
# Use the turn_off service call to change the light state.
await hass.services.async_call(
"light", "turn_off", {"entity_id": "light.tradfri_group_0"}, blocking=True
"light",
"turn_off",
{"entity_id": "light.tradfri_group_mock_gateway_id_0"},
blocking=True,
)
await hass.async_block_till_done()