Compare commits

...
Sign in to create a new pull request.

6 commits
dev ... trend

Author SHA1 Message Date
Phil Bruckner
b78d854bbd Ruff PT006 2023-02-15 08:44:45 -06:00
Phil Bruckner
a5fe2cc59a Add some more tests 2023-02-15 08:32:57 -06:00
Phil Bruckner
2d7b264f7a Add unique_id & allow max_samples of 0 to disable max # of samples 2023-02-15 08:32:57 -06:00
Phil Bruckner
998a953951 Update per core changes & add typing 2023-02-15 08:32:57 -06:00
Phil Bruckner
c631c333f7 Add trend min_samples option
Allow sample_duration to use any positive time period specification.
2023-02-15 08:32:57 -06:00
Phil Bruckner
b0feaa0747 Remove trend samples as they get too old 2023-02-15 08:32:57 -06:00
3 changed files with 330 additions and 122 deletions

View file

@ -2,8 +2,12 @@
from __future__ import annotations
from collections import deque
from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime, timedelta
import logging
import math
from typing import Any, cast
import numpy as np
import voluptuous as vol
@ -13,26 +17,29 @@ from homeassistant.components.binary_sensor import (
ENTITY_ID_FORMAT,
PLATFORM_SCHEMA,
BinarySensorEntity,
BinarySensorEntityDescription,
)
from homeassistant.const import (
ATTR_ENTITY_ID,
ATTR_FRIENDLY_NAME,
CONF_ATTRIBUTE,
CONF_DEVICE_CLASS,
CONF_ENTITY_ID,
CONF_FRIENDLY_NAME,
CONF_SENSORS,
CONF_UNIQUE_ID,
STATE_UNAVAILABLE,
STATE_UNKNOWN,
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.core import Event, HomeAssistant, State, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import generate_entity_id
from homeassistant.helpers.entity import async_generate_entity_id
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import async_track_state_change_event
from homeassistant.helpers.event import (
async_track_point_in_utc_time,
async_track_state_change_event,
)
from homeassistant.helpers.reload import async_setup_reload_service
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.util.dt import utcnow
from . import PLATFORMS
from .const import (
@ -44,23 +51,51 @@ from .const import (
CONF_INVERT,
CONF_MAX_SAMPLES,
CONF_MIN_GRADIENT,
CONF_MIN_SAMPLES,
CONF_SAMPLE_DURATION,
DOMAIN,
)
_LOGGER = logging.getLogger(__name__)
SENSOR_SCHEMA = vol.Schema(
{
vol.Required(CONF_ENTITY_ID): cv.entity_id,
vol.Optional(CONF_ATTRIBUTE): cv.string,
vol.Optional(CONF_DEVICE_CLASS): DEVICE_CLASSES_SCHEMA,
vol.Optional(CONF_FRIENDLY_NAME): cv.string,
vol.Optional(CONF_INVERT, default=False): cv.boolean,
vol.Optional(CONF_MAX_SAMPLES, default=2): cv.positive_int,
vol.Optional(CONF_MIN_GRADIENT, default=0.0): vol.Coerce(float),
vol.Optional(CONF_SAMPLE_DURATION, default=0): cv.positive_int,
}
def _check_sample_options(config: ConfigType) -> ConfigType:
"""Check min/max sample options."""
if config[CONF_MAX_SAMPLES]:
if config[CONF_MAX_SAMPLES] < config[CONF_MIN_SAMPLES]:
raise vol.Invalid(
f"{CONF_MAX_SAMPLES} must not be smaller than {CONF_MIN_SAMPLES}"
)
elif not config[CONF_SAMPLE_DURATION]:
raise vol.Invalid(
f"{CONF_MAX_SAMPLES} & {CONF_SAMPLE_DURATION} cannot both be zero"
)
else:
config[CONF_MAX_SAMPLES] = None
return config
SENSOR_SCHEMA = vol.All(
vol.Schema(
{
vol.Required(CONF_ENTITY_ID): cv.entity_id,
vol.Optional(CONF_UNIQUE_ID): cv.string,
vol.Optional(CONF_ATTRIBUTE): cv.string,
vol.Optional(CONF_DEVICE_CLASS): DEVICE_CLASSES_SCHEMA,
vol.Optional(CONF_FRIENDLY_NAME): cv.string,
vol.Optional(CONF_INVERT, default=False): cv.boolean,
vol.Optional(CONF_MAX_SAMPLES, default=2): vol.All(
vol.Coerce(int),
vol.Any(0, vol.Range(min=2), msg="must be 0 or 2 or larger"),
),
vol.Optional(CONF_MIN_GRADIENT, default=0.0): vol.Coerce(float),
vol.Optional(CONF_MIN_SAMPLES, default=2): vol.All(
vol.Coerce(int), vol.Range(min=2)
),
vol.Optional(CONF_SAMPLE_DURATION, default=0): cv.positive_time_period,
}
),
_check_sample_options,
)
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
@ -68,6 +103,14 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
)
@dataclass
class Sample:
"""Trend sample."""
time: datetime
value: float
async def async_setup_platform(
hass: HomeAssistant,
config: ConfigType,
@ -77,118 +120,138 @@ async def async_setup_platform(
"""Set up the trend sensors."""
await async_setup_reload_service(hass, DOMAIN, PLATFORMS)
sensors = []
for device_id, device_config in config[CONF_SENSORS].items():
entity_id = device_config[ATTR_ENTITY_ID]
attribute = device_config.get(CONF_ATTRIBUTE)
device_class = device_config.get(CONF_DEVICE_CLASS)
friendly_name = device_config.get(ATTR_FRIENDLY_NAME, device_id)
invert = device_config[CONF_INVERT]
max_samples = device_config[CONF_MAX_SAMPLES]
min_gradient = device_config[CONF_MIN_GRADIENT]
sample_duration = device_config[CONF_SAMPLE_DURATION]
sensors.append(
async_add_entities(
[
SensorTrend(
hass,
device_id,
friendly_name,
entity_id,
attribute,
device_class,
invert,
max_samples,
min_gradient,
sample_duration,
BinarySensorEntityDescription(
key=async_generate_entity_id(
ENTITY_ID_FORMAT, device_id, hass=hass
),
device_class=device_config.get(CONF_DEVICE_CLASS),
name=device_config.get(CONF_FRIENDLY_NAME, device_id),
),
device_config,
)
)
if not sensors:
_LOGGER.error("No sensors added")
return
async_add_entities(sensors)
for device_id, device_config in cast(
ConfigType, config[CONF_SENSORS]
).items()
]
)
class SensorTrend(BinarySensorEntity):
"""Representation of a trend Sensor."""
_attr_should_poll = False
_unsub: Callable[..., None] | None = None
def __init__(
self,
hass,
device_id,
friendly_name,
entity_id,
attribute,
device_class,
invert,
max_samples,
min_gradient,
sample_duration,
):
entity_description: BinarySensorEntityDescription,
config: ConfigType,
) -> None:
"""Initialize the sensor."""
self._hass = hass
self.entity_id = generate_entity_id(ENTITY_ID_FORMAT, device_id, hass=hass)
self._name = friendly_name
self._entity_id = entity_id
self._attribute = attribute
self._device_class = device_class
self._invert = invert
self._sample_duration = sample_duration
self._min_gradient = min_gradient
self._gradient = None
self._state = None
self.samples = deque(maxlen=max_samples)
self._attr_unique_id = config.get(CONF_UNIQUE_ID)
self.entity_description = entity_description
self.entity_id = entity_description.key
@property
def name(self):
"""Return the name of the sensor."""
return self._name
self._entity_id: str = config[ATTR_ENTITY_ID]
self._attribute: str | None = config.get(CONF_ATTRIBUTE)
self._invert: bool = config[CONF_INVERT]
self._sample_duration: timedelta = config[CONF_SAMPLE_DURATION]
self._min_gradient: float = config[CONF_MIN_GRADIENT]
self._min_samples: int = config[CONF_MIN_SAMPLES]
@property
def is_on(self):
"""Return true if sensor is on."""
return self._state
@property
def device_class(self):
"""Return the sensor class of the sensor."""
return self._device_class
@property
def extra_state_attributes(self):
"""Return the state attributes of the sensor."""
return {
self._attr_extra_state_attributes = {
ATTR_ENTITY_ID: self._entity_id,
ATTR_FRIENDLY_NAME: self._name,
ATTR_GRADIENT: self._gradient,
ATTR_GRADIENT: None,
ATTR_INVERT: self._invert,
ATTR_MIN_GRADIENT: self._min_gradient,
ATTR_SAMPLE_COUNT: len(self.samples),
ATTR_SAMPLE_DURATION: self._sample_duration,
ATTR_SAMPLE_COUNT: 0,
ATTR_SAMPLE_DURATION: self._sample_duration.total_seconds(),
}
self._samples: deque[Sample] = deque(maxlen=config[CONF_MAX_SAMPLES])
def unsub():
if self._unsub:
self._unsub()
self._unsub = None
self.async_on_remove(unsub)
@property
def _gradient(self) -> float | None:
"""Return gradient."""
return self._attr_extra_state_attributes[ATTR_GRADIENT]
@_gradient.setter
def _gradient(self, gradient: float | None) -> None:
"""Set gradient."""
self._attr_extra_state_attributes[ATTR_GRADIENT] = gradient
def _remove_oldest_sample(self) -> None:
"""Remove oldest sample."""
self._samples.popleft()
self._attr_extra_state_attributes[ATTR_SAMPLE_COUNT] = len(self._samples)
def _add_sample(self, sample: Sample) -> None:
"""Add new sample."""
self._samples.append(sample)
self._attr_extra_state_attributes[ATTR_SAMPLE_COUNT] = len(self._samples)
async def async_added_to_hass(self) -> None:
"""Complete device setup after being added to hass."""
@callback
def trend_sensor_state_listener(event):
def remove_stale_sample(remove_time: datetime) -> None:
"""Remove stale sample."""
self._remove_oldest_sample()
if self._samples:
self._unsub = async_track_point_in_utc_time(
self.hass,
remove_stale_sample,
self._samples[0].time + self._sample_duration,
)
else:
self._unsub = None
self.async_schedule_update_ha_state(True)
@callback
def trend_sensor_state_listener(event: Event) -> None:
"""Handle state changes on the observed device."""
if (new_state := event.data.get("new_state")) is None:
if (new_state := cast(State | Any, event.data.get("new_state"))) is None:
return
if self._attribute:
state = new_state.attributes.get(self._attribute)
else:
state = new_state.state
if state in (None, STATE_UNKNOWN, STATE_UNAVAILABLE):
return
samle_count_was = len(self._samples)
try:
if self._attribute:
state = new_state.attributes.get(self._attribute)
else:
state = new_state.state
if state not in (STATE_UNKNOWN, STATE_UNAVAILABLE):
sample = (new_state.last_updated.timestamp(), float(state))
self.samples.append(sample)
self.async_schedule_update_ha_state(True)
except (ValueError, TypeError) as ex:
_LOGGER.error(ex)
sample = Sample(new_state.last_updated, float(state)) # type: ignore[arg-type]
except (TypeError, ValueError):
_LOGGER.error("Input value %s for %s is not a number", state, self.name)
return
self._add_sample(sample)
if self._sample_duration and samle_count_was in [
0,
self._samples.maxlen,
]:
if self._unsub:
self._unsub()
self._unsub = async_track_point_in_utc_time(
self.hass,
remove_stale_sample,
self._samples[0].time + self._sample_duration,
)
self.async_schedule_update_ha_state(True)
self.async_on_remove(
async_track_state_change_event(
@ -198,33 +261,29 @@ class SensorTrend(BinarySensorEntity):
async def async_update(self) -> None:
"""Get the latest data and update the states."""
# Remove outdated samples
if self._sample_duration > 0:
cutoff = utcnow().timestamp() - self._sample_duration
while self.samples and self.samples[0][0] < cutoff:
self.samples.popleft()
if len(self.samples) < 2:
if len(self._samples) < self._min_samples:
self._gradient = None
self._attr_is_on = None
return
def calculate_gradient() -> float:
"""Compute the linear trend gradient of the current samples.
This need run inside executor.
"""
timestamps = np.array([s.time.timestamp() for s in self._samples])
values = np.array([s.value for s in self._samples])
coeffs = np.polyfit(timestamps, values, 1)
return coeffs[0]
# Calculate gradient of linear trend
await self.hass.async_add_executor_job(self._calculate_gradient)
self._gradient = await self.hass.async_add_executor_job(calculate_gradient)
# Update state
self._state = (
self._attr_is_on = (
abs(self._gradient) > abs(self._min_gradient)
and math.copysign(self._gradient, self._min_gradient) == self._gradient
)
if self._invert:
self._state = not self._state
def _calculate_gradient(self):
"""Compute the linear trend gradient of the current samples.
This need run inside executor.
"""
timestamps = np.array([t for t, _ in self.samples])
values = np.array([s for _, s in self.samples])
coeffs = np.polyfit(timestamps, values, 1)
self._gradient = coeffs[0]
self._attr_is_on = not self._attr_is_on

View file

@ -11,4 +11,5 @@ ATTR_SAMPLE_COUNT = "sample_count"
CONF_INVERT = "invert"
CONF_MAX_SAMPLES = "max_samples"
CONF_MIN_GRADIENT = "min_gradient"
CONF_MIN_SAMPLES = "min_samples"
CONF_SAMPLE_DURATION = "sample_duration"

View file

@ -2,6 +2,8 @@
from datetime import timedelta
from unittest.mock import patch
import pytest
from homeassistant import config as hass_config, setup
from homeassistant.components.trend.const import DOMAIN
from homeassistant.const import SERVICE_RELOAD, STATE_UNKNOWN
@ -10,6 +12,7 @@ import homeassistant.util.dt as dt_util
from tests.common import (
assert_setup_component,
fire_time_changed,
get_fixture_path,
get_test_home_assistant,
)
@ -378,6 +381,151 @@ class TestTrendBinarySensor:
)
assert self.hass.states.all("binary_sensor") == []
@pytest.mark.parametrize(
("sample_duration", "min_samples", "max_samples"),
[
(-1, 2, 2),
(0, 1, 2),
(0, 2, 1),
(0, 3, 2),
(0, 2, 0),
],
)
def test_invalid_samples_parameters(
self, sample_duration, min_samples, max_samples
):
"""Test invalid samples parameters."""
config = {
"binary_sensor": {
"platform": "trend",
"sensors": {
"test_trend_sensor": {
"entity_id": "sensor.test_state",
"sample_duration": timedelta(seconds=sample_duration),
"min_samples": min_samples,
"max_samples": max_samples,
}
},
}
}
with assert_setup_component(0):
assert setup.setup_component(self.hass, "binary_sensor", config)
assert self.hass.states.all("binary_sensor") == []
@pytest.mark.parametrize(
("max_samples", "data"),
[
(
3,
[
(0, 1, "unknown"),
(10, 2, "on"),
(20, 3, "on"), # Do not exceed either limit
(None, 3, "on"),
(None, 3, "on"),
(None, 2, "on"),
(None, 1, "unknown"),
(None, 0, "unknown"),
(0, 1, "unknown"),
(10, 2, "on"),
(20, 3, "on"),
(30, 3, "on"), # Exceed max_samples
(None, 3, "on"),
(None, 3, "on"),
(None, 2, "on"),
(None, 1, "unknown"),
(None, 0, "unknown"),
(0, 1, "unknown"),
(None, 1, "unknown"),
(10, 2, "on"),
(None, 2, "on"),
(20, 3, "on"),
(None, 2, "on"), # Exceed sample_duration
(None, 2, "on"),
(None, 1, "unknown"),
(None, 1, "unknown"),
(None, 0, "unknown"),
],
),
(
0,
[
(0, 1, "unknown"),
(10, 2, "on"),
(20, 3, "on"), # Do not exceed sample_duration
(None, 3, "on"),
(None, 3, "on"),
(None, 2, "on"),
(None, 1, "unknown"),
(None, 0, "unknown"),
(0, 1, "unknown"),
(None, 1, "unknown"),
(10, 2, "on"),
(None, 2, "on"),
(20, 3, "on"),
(None, 2, "on"), # Exceed sample_duration
(None, 2, "on"),
(None, 1, "unknown"),
(None, 1, "unknown"),
(None, 0, "unknown"),
],
),
],
)
def test_stale_samples_removed(self, max_samples, data) -> None:
"""Test samples outside of max_samples or sample_duration are dropped."""
assert setup.setup_component(
self.hass,
"binary_sensor",
{
"binary_sensor": {
"platform": "trend",
"sensors": {
"test_trend_sensor": {
"entity_id": "sensor.test_state",
"sample_duration": 9,
"min_gradient": 2,
"max_samples": max_samples,
}
},
}
},
)
self.hass.block_till_done()
state = self.hass.states.get("binary_sensor.test_trend_sensor")
assert state.state == "unknown"
assert state.attributes["sample_count"] == 0
now = dt_util.utcnow()
value, expected_count, expected_state = data[0]
if value is not None:
with patch("homeassistant.util.dt.utcnow", return_value=now):
self.hass.states.set("sensor.test_state", value)
self.hass.block_till_done()
state = self.hass.states.get("binary_sensor.test_trend_sensor")
assert state.attributes["sample_count"] == expected_count
assert state.state == expected_state
for idx in range(1, len(data)):
value, expected_count, expected_state = data[idx]
now += timedelta(seconds=2)
fire_time_changed(self.hass, now)
self.hass.block_till_done()
if value is not None:
with patch("homeassistant.util.dt.utcnow", return_value=now):
self.hass.states.set("sensor.test_state", value)
self.hass.block_till_done()
state = self.hass.states.get("binary_sensor.test_trend_sensor")
assert state.attributes["sample_count"] == expected_count
assert state.state == expected_state
async def test_reload(hass: HomeAssistant) -> None:
"""Verify we can reload trend sensors."""