Accept more than 1 state for numeric entities in Bayesian (#119281)

* test driven delevopment

* test driven development - multi numeric state

* better multi-state processing

* when state==below return true

* adds test for a bad state

* improve codecov

* value error already handled in async_numeric_state

* remove whitespace

* remove async_get

* linting

* test_driven dev for error handling

* make tests fail correctly

* ensure tests fail correctly

* prevent bad numeric entries

* ensure no overlapping ranges

* fix tests, as error caught in validation

* remove redundant er call

* remove reddundant arg

* improves code coverage

* filter for numeric states before testing overlap

* adress code review

* skip non numeric configs but continue

* wait to avoid race condition

* Better tuples name and better guard clause

* better test description

* more accurate description

* Add comments to calculations

* using typing not collections as per ruff

* Apply suggestions from code review

Co-authored-by: Erik Montnemery <erik@montnemery.com>

* follow on from suggestions

* Lazy evaluation

Co-authored-by: Erik Montnemery <erik@montnemery.com>

* update error text in tests

* fix broken tests

* move validation function call

* fixes return type of above_greater_than_below.

* improves codecov

* fixes validation

---------

Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
HarvsG 2024-09-12 11:06:18 +01:00 committed by GitHub
parent c21ea6b8da
commit 70ebf2f5d8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 464 additions and 87 deletions

View file

@ -5,7 +5,8 @@ from __future__ import annotations
from collections import OrderedDict
from collections.abc import Callable
import logging
from typing import Any
import math
from typing import TYPE_CHECKING, Any, NamedTuple
from uuid import UUID
import voluptuous as vol
@ -50,6 +51,7 @@ from .const import (
ATTR_OCCURRED_OBSERVATION_ENTITIES,
ATTR_PROBABILITY,
ATTR_PROBABILITY_THRESHOLD,
CONF_NUMERIC_STATE,
CONF_OBSERVATIONS,
CONF_P_GIVEN_F,
CONF_P_GIVEN_T,
@ -66,18 +68,74 @@ from .issues import raise_mirrored_entries, raise_no_prob_given_false
_LOGGER = logging.getLogger(__name__)
NUMERIC_STATE_SCHEMA = vol.Schema(
{
CONF_PLATFORM: "numeric_state",
vol.Required(CONF_ENTITY_ID): cv.entity_id,
vol.Optional(CONF_ABOVE): vol.Coerce(float),
vol.Optional(CONF_BELOW): vol.Coerce(float),
vol.Required(CONF_P_GIVEN_T): vol.Coerce(float),
vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float),
},
required=True,
def _above_greater_than_below(config: dict[str, Any]) -> dict[str, Any]:
if config[CONF_PLATFORM] == CONF_NUMERIC_STATE:
above = config.get(CONF_ABOVE)
below = config.get(CONF_BELOW)
if above is None and below is None:
_LOGGER.error(
"For bayesian numeric state for entity: %s at least one of 'above' or 'below' must be specified",
config[CONF_ENTITY_ID],
)
raise vol.Invalid(
"For bayesian numeric state at least one of 'above' or 'below' must be specified."
)
if above is not None and below is not None:
if above > below:
_LOGGER.error(
"For bayesian numeric state 'above' (%s) must be less than 'below' (%s)",
above,
below,
)
raise vol.Invalid("'above' is greater than 'below'")
return config
NUMERIC_STATE_SCHEMA = vol.All(
vol.Schema(
{
CONF_PLATFORM: CONF_NUMERIC_STATE,
vol.Required(CONF_ENTITY_ID): cv.entity_id,
vol.Optional(CONF_ABOVE): vol.Coerce(float),
vol.Optional(CONF_BELOW): vol.Coerce(float),
vol.Required(CONF_P_GIVEN_T): vol.Coerce(float),
vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float),
},
required=True,
),
_above_greater_than_below,
)
def _no_overlapping(configs: list[dict]) -> list[dict]:
numeric_configs = [
config for config in configs if config[CONF_PLATFORM] == CONF_NUMERIC_STATE
]
if len(numeric_configs) < 2:
return configs
class NumericConfig(NamedTuple):
above: float
below: float
d: dict[str, list[NumericConfig]] = {}
for _, config in enumerate(numeric_configs):
above = config.get(CONF_ABOVE, -math.inf)
below = config.get(CONF_BELOW, math.inf)
entity_id: str = str(config[CONF_ENTITY_ID])
d.setdefault(entity_id, []).append(NumericConfig(above, below))
for ent_id, intervals in d.items():
intervals = sorted(intervals, key=lambda tup: tup.above)
for i, tup in enumerate(intervals):
if len(intervals) > i + 1 and tup.below > intervals[i + 1].above:
raise vol.Invalid(
f"Ranges for bayesian numeric state entities must not overlap, but {ent_id} has overlapping ranges, above:{tup.above}, below:{tup.below} overlaps with above:{intervals[i+1].above}, below:{intervals[i+1].below}."
)
return configs
STATE_SCHEMA = vol.Schema(
{
CONF_PLATFORM: CONF_STATE,
@ -107,7 +165,8 @@ PLATFORM_SCHEMA = BINARY_SENSOR_PLATFORM_SCHEMA.extend(
vol.Required(CONF_OBSERVATIONS): vol.Schema(
vol.All(
cv.ensure_list,
[vol.Any(NUMERIC_STATE_SCHEMA, STATE_SCHEMA, TEMPLATE_SCHEMA)],
[vol.Any(TEMPLATE_SCHEMA, STATE_SCHEMA, NUMERIC_STATE_SCHEMA)],
_no_overlapping,
)
),
vol.Required(CONF_PRIOR): vol.Coerce(float),
@ -211,10 +270,11 @@ class BayesianBinarySensor(BinarySensorEntity):
self.observations_by_entity = self._build_observations_by_entity()
self.observations_by_template = self._build_observations_by_template()
self.observation_handlers: dict[str, Callable[[Observation], bool | None]] = {
self.observation_handlers: dict[
str, Callable[[Observation, bool], bool | None]
] = {
"numeric_state": self._process_numeric_state,
"state": self._process_state,
"multi_state": self._process_multi_state,
}
async def async_added_to_hass(self) -> None:
@ -342,8 +402,9 @@ class BayesianBinarySensor(BinarySensorEntity):
for observation in self.observations_by_entity[entity]:
platform = observation.platform
observation.observed = self.observation_handlers[platform](observation)
observation.observed = self.observation_handlers[platform](
observation, observation.multi
)
local_observations[observation.id] = observation
return local_observations
@ -408,9 +469,7 @@ class BayesianBinarySensor(BinarySensorEntity):
if len(entity_observations) == 1:
continue
for observation in entity_observations:
if observation.platform != "state":
continue
observation.platform = "multi_state"
observation.multi = True
return observations_by_entity
@ -437,14 +496,23 @@ class BayesianBinarySensor(BinarySensorEntity):
return observations_by_template
def _process_numeric_state(self, entity_observation: Observation) -> bool | None:
def _process_numeric_state(
self, entity_observation: Observation, multi: bool = False
) -> bool | None:
"""Return True if numeric condition is met, return False if not, return None otherwise."""
entity = entity_observation.entity_id
entity_id = entity_observation.entity_id
# if we are dealing with numeric_state observations entity_id cannot be None
if TYPE_CHECKING:
assert entity_id is not None
entity = self.hass.states.get(entity_id)
if entity is None:
return None
try:
if condition.state(self.hass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]):
return None
return condition.async_numeric_state(
result = condition.async_numeric_state(
self.hass,
entity,
entity_observation.below,
@ -452,10 +520,24 @@ class BayesianBinarySensor(BinarySensorEntity):
None,
entity_observation.to_dict(),
)
if result:
return True
if multi:
state = float(entity.state)
if (
entity_observation.below is not None
and state == entity_observation.below
):
return True
return None
except ConditionError:
return None
else:
return False
def _process_state(self, entity_observation: Observation) -> bool | None:
def _process_state(
self, entity_observation: Observation, multi: bool = False
) -> bool | None:
"""Return True if state conditions are met, return False if they are not.
Returns None if the state is unavailable.
@ -467,24 +549,13 @@ class BayesianBinarySensor(BinarySensorEntity):
if condition.state(self.hass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]):
return None
return condition.state(self.hass, entity, entity_observation.to_state)
result = condition.state(self.hass, entity, entity_observation.to_state)
if multi and not result:
return None
except ConditionError:
return None
def _process_multi_state(self, entity_observation: Observation) -> bool | None:
"""Return True if state conditions are met, otherwise return None.
Never return False as all other states should have their own probabilities configured.
"""
entity = entity_observation.entity_id
try:
if condition.state(self.hass, entity, entity_observation.to_state):
return True
except ConditionError:
return None
return None
else:
return result
@property
def extra_state_attributes(self) -> dict[str, Any]: