diff --git a/homeassistant/components/pi_hole/__init__.py b/homeassistant/components/pi_hole/__init__.py index a0d6c5da6d1..eba9053183b 100644 --- a/homeassistant/components/pi_hole/__init__.py +++ b/homeassistant/components/pi_hole/__init__.py @@ -17,10 +17,12 @@ from homeassistant.const import ( from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.helpers import config_validation as cv from homeassistant.helpers.aiohttp_client import async_get_clientsession -from homeassistant.util import Throttle +from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed from .const import ( CONF_LOCATION, + DATA_KEY_API, + DATA_KEY_COORDINATOR, DEFAULT_LOCATION, DEFAULT_NAME, DEFAULT_SSL, @@ -34,7 +36,7 @@ from .const import ( SERVICE_ENABLE_ATTR_NAME, ) -LOGGER = logging.getLogger(__name__) +_LOGGER = logging.getLogger(__name__) PI_HOLE_SCHEMA = vol.Schema( vol.All( @@ -56,7 +58,7 @@ CONFIG_SCHEMA = vol.Schema( async def async_setup(hass, config): - """Set up the pi_hole integration.""" + """Set up the Pi_hole integration.""" service_disable_schema = vol.Schema( vol.All( @@ -82,37 +84,36 @@ async def async_setup(hass, config): ) ) - def get_pi_hole_from_name(name): - pi_hole = hass.data[DOMAIN].get(name) - if pi_hole is None: - LOGGER.error("Unknown Pi-hole name %s", name) + def get_api_from_name(name): + """Get Pi-hole API object from user configured name.""" + hole_data = hass.data[DOMAIN].get(name) + if hole_data is None: + _LOGGER.error("Unknown Pi-hole name %s", name) return None - if not pi_hole.api.api_token: - LOGGER.error( + api = hole_data[DATA_KEY_API] + if not api.api_token: + _LOGGER.error( "Pi-hole %s must have an api_key provided in configuration to be enabled", name, ) return None - return pi_hole + return api async def disable_service_handler(call): - """Handle the service call to disable a single Pi-Hole or all configured Pi-Holes.""" + """Handle the service call to disable a single Pi-hole or all configured Pi-holes.""" duration = call.data[SERVICE_DISABLE_ATTR_DURATION].total_seconds() name = call.data.get(SERVICE_DISABLE_ATTR_NAME) async def do_disable(name): - """Disable the named Pi-Hole.""" - pi_hole = get_pi_hole_from_name(name) - if pi_hole is None: + """Disable the named Pi-hole.""" + api = get_api_from_name(name) + if api is None: return - LOGGER.debug( - "Disabling Pi-hole '%s' (%s) for %d seconds", - name, - pi_hole.api.host, - duration, + _LOGGER.debug( + "Disabling Pi-hole '%s' (%s) for %d seconds", name, api.host, duration, ) - await pi_hole.api.disable(duration) + await api.disable(duration) if name is not None: await do_disable(name) @@ -121,18 +122,18 @@ async def async_setup(hass, config): await do_disable(name) async def enable_service_handler(call): - """Handle the service call to enable a single Pi-Hole or all configured Pi-Holes.""" + """Handle the service call to enable a single Pi-hole or all configured Pi-holes.""" name = call.data.get(SERVICE_ENABLE_ATTR_NAME) async def do_enable(name): - """Enable the named Pi-Hole.""" - pi_hole = get_pi_hole_from_name(name) - if pi_hole is None: + """Enable the named Pi-hole.""" + api = get_api_from_name(name) + if api is None: return - LOGGER.debug("Enabling Pi-hole '%s' (%s)", name, pi_hole.api.host) - await pi_hole.api.enable() + _LOGGER.debug("Enabling Pi-hole '%s' (%s)", name, api.host) + await api.enable() if name is not None: await do_enable(name) @@ -160,27 +161,37 @@ async def async_setup_entry(hass, entry): location = entry.data[CONF_LOCATION] api_key = entry.data.get(CONF_API_KEY) - LOGGER.debug("Setting up %s integration with host %s", DOMAIN, host) + _LOGGER.debug("Setting up %s integration with host %s", DOMAIN, host) try: session = async_get_clientsession(hass, verify_tls) - pi_hole = PiHoleData( - Hole( - host, - hass.loop, - session, - location=location, - tls=use_tls, - api_token=api_key, - ), - name, + api = Hole( + host, hass.loop, session, location=location, tls=use_tls, api_token=api_key, ) - await pi_hole.async_update() - hass.data[DOMAIN][name] = pi_hole + await api.get_data() except HoleError as ex: - LOGGER.warning("Failed to connect: %s", ex) + _LOGGER.warning("Failed to connect: %s", ex) raise ConfigEntryNotReady + async def async_update_data(): + """Fetch data from API endpoint.""" + try: + await api.get_data() + except HoleError as err: + raise UpdateFailed(f"Failed to communicating with API: {err}") + + coordinator = DataUpdateCoordinator( + hass, + _LOGGER, + name=name, + update_method=async_update_data, + update_interval=MIN_TIME_BETWEEN_UPDATES, + ) + hass.data[DOMAIN][name] = { + DATA_KEY_API: api, + DATA_KEY_COORDINATOR: coordinator, + } + hass.async_create_task( hass.config_entries.async_forward_entry_setup(entry, SENSOR_DOMAIN) ) @@ -192,24 +203,3 @@ async def async_unload_entry(hass, entry): """Unload pi-hole entry.""" hass.data[DOMAIN].pop(entry.data[CONF_NAME]) return await hass.config_entries.async_forward_entry_unload(entry, SENSOR_DOMAIN) - - -class PiHoleData: - """Get the latest data and update the states.""" - - def __init__(self, api, name): - """Initialize the data object.""" - self.api = api - self.name = name - self.available = True - - @Throttle(MIN_TIME_BETWEEN_UPDATES) - async def async_update(self): - """Get the latest data from the Pi-hole.""" - - try: - await self.api.get_data() - self.available = True - except HoleError: - LOGGER.error("Unable to fetch data from Pi-hole") - self.available = False diff --git a/homeassistant/components/pi_hole/const.py b/homeassistant/components/pi_hole/const.py index eec71ca441d..a5807de5575 100644 --- a/homeassistant/components/pi_hole/const.py +++ b/homeassistant/components/pi_hole/const.py @@ -23,6 +23,9 @@ ATTR_BLOCKED_DOMAINS = "domains_blocked" MIN_TIME_BETWEEN_UPDATES = timedelta(minutes=5) +DATA_KEY_API = "api" +DATA_KEY_COORDINATOR = "coordinator" + SENSOR_DICT = { "ads_blocked_today": ["Ads Blocked Today", "ads", "mdi:close-octagon-outline"], "ads_percentage_today": [ diff --git a/homeassistant/components/pi_hole/sensor.py b/homeassistant/components/pi_hole/sensor.py index bbc42cdd8a5..d0009f1ebba 100644 --- a/homeassistant/components/pi_hole/sensor.py +++ b/homeassistant/components/pi_hole/sensor.py @@ -6,6 +6,8 @@ from homeassistant.helpers.entity import Entity from .const import ( ATTR_BLOCKED_DOMAINS, + DATA_KEY_API, + DATA_KEY_COORDINATOR, DOMAIN as PIHOLE_DOMAIN, SENSOR_DICT, SENSOR_LIST, @@ -15,10 +17,17 @@ LOGGER = logging.getLogger(__name__) async def async_setup_entry(hass, entry, async_add_entities): - """Set up the pi-hole sensor.""" - pi_hole = hass.data[PIHOLE_DOMAIN][entry.data[CONF_NAME]] + """Set up the Pi-hole sensor.""" + name = entry.data[CONF_NAME] + hole_data = hass.data[PIHOLE_DOMAIN][name] sensors = [ - PiHoleSensor(pi_hole, sensor_name, entry.entry_id) + PiHoleSensor( + hole_data[DATA_KEY_API], + hole_data[DATA_KEY_COORDINATOR], + name, + sensor_name, + entry.entry_id, + ) for sensor_name in SENSOR_LIST ] async_add_entities(sensors, True) @@ -27,10 +36,11 @@ async def async_setup_entry(hass, entry, async_add_entities): class PiHoleSensor(Entity): """Representation of a Pi-hole sensor.""" - def __init__(self, pi_hole, sensor_name, server_unique_id): + def __init__(self, api, coordinator, name, sensor_name, server_unique_id): """Initialize a Pi-hole sensor.""" - self.pi_hole = pi_hole - self._name = pi_hole.name + self.api = api + self.coordinator = coordinator + self._name = name self._condition = sensor_name self._server_unique_id = server_unique_id @@ -38,7 +48,12 @@ class PiHoleSensor(Entity): self._condition_name = variable_info[0] self._unit_of_measurement = variable_info[1] self._icon = variable_info[2] - self.data = {} + + async def async_added_to_hass(self): + """When entity is added to hass.""" + self.async_on_remove( + self.coordinator.async_add_listener(self.async_write_ha_state) + ) @property def name(self): @@ -73,21 +88,25 @@ class PiHoleSensor(Entity): def state(self): """Return the state of the device.""" try: - return round(self.data[self._condition], 2) + return round(self.api.data[self._condition], 2) except TypeError: - return self.data[self._condition] + return self.api.data[self._condition] @property def device_state_attributes(self): - """Return the state attributes of the Pi-Hole.""" - return {ATTR_BLOCKED_DOMAINS: self.data["domains_being_blocked"]} + """Return the state attributes of the Pi-hole.""" + return {ATTR_BLOCKED_DOMAINS: self.api.data["domains_being_blocked"]} @property def available(self): """Could the device be accessed during the last update call.""" - return self.pi_hole.available + return self.coordinator.last_update_success + + @property + def should_poll(self): + """No need to poll. Coordinator notifies entity of updates.""" + return False async def async_update(self): """Get the latest data from the Pi-hole API.""" - await self.pi_hole.async_update() - self.data = self.pi_hole.api.data + await self.coordinator.async_request_refresh() diff --git a/tests/components/pi_hole/test_init.py b/tests/components/pi_hole/test_init.py index 73a501c74ce..d6cee176775 100644 --- a/tests/components/pi_hole/test_init.py +++ b/tests/components/pi_hole/test_init.py @@ -1,11 +1,13 @@ """Test pi_hole component.""" from homeassistant.components import pi_hole +from homeassistant.components.pi_hole.const import MIN_TIME_BETWEEN_UPDATES +from homeassistant.util import dt as dt_util from . import _create_mocked_hole, _patch_config_flow_hole from tests.async_mock import patch -from tests.common import async_setup_component +from tests.common import async_fire_time_changed, async_setup_component def _patch_init_hole(mocked_hole): @@ -138,3 +140,23 @@ async def test_enable_service_call(hass): await hass.async_block_till_done() assert mocked_hole.enable.call_count == 2 + + +async def test_update_coordinator(hass): + """Test update coordinator.""" + mocked_hole = _create_mocked_hole() + sensor_entity_id = "sensor.pi_hole_ads_blocked_today" + with _patch_config_flow_hole(mocked_hole), _patch_init_hole(mocked_hole): + assert await async_setup_component( + hass, pi_hole.DOMAIN, {pi_hole.DOMAIN: [{"host": "pi.hole"}]} + ) + await hass.async_block_till_done() + assert mocked_hole.get_data.call_count == 3 + assert hass.states.get(sensor_entity_id).state == "0" + + mocked_hole.data["ads_blocked_today"] = 1 + utcnow = dt_util.utcnow() + async_fire_time_changed(hass, utcnow + MIN_TIME_BETWEEN_UPDATES) + await hass.async_block_till_done() + assert mocked_hole.get_data.call_count == 4 + assert hass.states.get(sensor_entity_id).state == "1"