Accept more than 1 state for numeric entities in Bayesian (#119281)
* test driven delevopment * test driven development - multi numeric state * better multi-state processing * when state==below return true * adds test for a bad state * improve codecov * value error already handled in async_numeric_state * remove whitespace * remove async_get * linting * test_driven dev for error handling * make tests fail correctly * ensure tests fail correctly * prevent bad numeric entries * ensure no overlapping ranges * fix tests, as error caught in validation * remove redundant er call * remove reddundant arg * improves code coverage * filter for numeric states before testing overlap * adress code review * skip non numeric configs but continue * wait to avoid race condition * Better tuples name and better guard clause * better test description * more accurate description * Add comments to calculations * using typing not collections as per ruff * Apply suggestions from code review Co-authored-by: Erik Montnemery <erik@montnemery.com> * follow on from suggestions * Lazy evaluation Co-authored-by: Erik Montnemery <erik@montnemery.com> * update error text in tests * fix broken tests * move validation function call * fixes return type of above_greater_than_below. * improves codecov * fixes validation --------- Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
parent
c21ea6b8da
commit
70ebf2f5d8
4 changed files with 464 additions and 87 deletions
|
@ -5,7 +5,8 @@ from __future__ import annotations
|
|||
from collections import OrderedDict
|
||||
from collections.abc import Callable
|
||||
import logging
|
||||
from typing import Any
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple
|
||||
from uuid import UUID
|
||||
|
||||
import voluptuous as vol
|
||||
|
@ -50,6 +51,7 @@ from .const import (
|
|||
ATTR_OCCURRED_OBSERVATION_ENTITIES,
|
||||
ATTR_PROBABILITY,
|
||||
ATTR_PROBABILITY_THRESHOLD,
|
||||
CONF_NUMERIC_STATE,
|
||||
CONF_OBSERVATIONS,
|
||||
CONF_P_GIVEN_F,
|
||||
CONF_P_GIVEN_T,
|
||||
|
@ -66,18 +68,74 @@ from .issues import raise_mirrored_entries, raise_no_prob_given_false
|
|||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
NUMERIC_STATE_SCHEMA = vol.Schema(
|
||||
{
|
||||
CONF_PLATFORM: "numeric_state",
|
||||
vol.Required(CONF_ENTITY_ID): cv.entity_id,
|
||||
vol.Optional(CONF_ABOVE): vol.Coerce(float),
|
||||
vol.Optional(CONF_BELOW): vol.Coerce(float),
|
||||
vol.Required(CONF_P_GIVEN_T): vol.Coerce(float),
|
||||
vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float),
|
||||
},
|
||||
required=True,
|
||||
def _above_greater_than_below(config: dict[str, Any]) -> dict[str, Any]:
|
||||
if config[CONF_PLATFORM] == CONF_NUMERIC_STATE:
|
||||
above = config.get(CONF_ABOVE)
|
||||
below = config.get(CONF_BELOW)
|
||||
if above is None and below is None:
|
||||
_LOGGER.error(
|
||||
"For bayesian numeric state for entity: %s at least one of 'above' or 'below' must be specified",
|
||||
config[CONF_ENTITY_ID],
|
||||
)
|
||||
raise vol.Invalid(
|
||||
"For bayesian numeric state at least one of 'above' or 'below' must be specified."
|
||||
)
|
||||
if above is not None and below is not None:
|
||||
if above > below:
|
||||
_LOGGER.error(
|
||||
"For bayesian numeric state 'above' (%s) must be less than 'below' (%s)",
|
||||
above,
|
||||
below,
|
||||
)
|
||||
raise vol.Invalid("'above' is greater than 'below'")
|
||||
return config
|
||||
|
||||
|
||||
NUMERIC_STATE_SCHEMA = vol.All(
|
||||
vol.Schema(
|
||||
{
|
||||
CONF_PLATFORM: CONF_NUMERIC_STATE,
|
||||
vol.Required(CONF_ENTITY_ID): cv.entity_id,
|
||||
vol.Optional(CONF_ABOVE): vol.Coerce(float),
|
||||
vol.Optional(CONF_BELOW): vol.Coerce(float),
|
||||
vol.Required(CONF_P_GIVEN_T): vol.Coerce(float),
|
||||
vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float),
|
||||
},
|
||||
required=True,
|
||||
),
|
||||
_above_greater_than_below,
|
||||
)
|
||||
|
||||
|
||||
def _no_overlapping(configs: list[dict]) -> list[dict]:
|
||||
numeric_configs = [
|
||||
config for config in configs if config[CONF_PLATFORM] == CONF_NUMERIC_STATE
|
||||
]
|
||||
if len(numeric_configs) < 2:
|
||||
return configs
|
||||
|
||||
class NumericConfig(NamedTuple):
|
||||
above: float
|
||||
below: float
|
||||
|
||||
d: dict[str, list[NumericConfig]] = {}
|
||||
for _, config in enumerate(numeric_configs):
|
||||
above = config.get(CONF_ABOVE, -math.inf)
|
||||
below = config.get(CONF_BELOW, math.inf)
|
||||
entity_id: str = str(config[CONF_ENTITY_ID])
|
||||
d.setdefault(entity_id, []).append(NumericConfig(above, below))
|
||||
|
||||
for ent_id, intervals in d.items():
|
||||
intervals = sorted(intervals, key=lambda tup: tup.above)
|
||||
|
||||
for i, tup in enumerate(intervals):
|
||||
if len(intervals) > i + 1 and tup.below > intervals[i + 1].above:
|
||||
raise vol.Invalid(
|
||||
f"Ranges for bayesian numeric state entities must not overlap, but {ent_id} has overlapping ranges, above:{tup.above}, below:{tup.below} overlaps with above:{intervals[i+1].above}, below:{intervals[i+1].below}."
|
||||
)
|
||||
return configs
|
||||
|
||||
|
||||
STATE_SCHEMA = vol.Schema(
|
||||
{
|
||||
CONF_PLATFORM: CONF_STATE,
|
||||
|
@ -107,7 +165,8 @@ PLATFORM_SCHEMA = BINARY_SENSOR_PLATFORM_SCHEMA.extend(
|
|||
vol.Required(CONF_OBSERVATIONS): vol.Schema(
|
||||
vol.All(
|
||||
cv.ensure_list,
|
||||
[vol.Any(NUMERIC_STATE_SCHEMA, STATE_SCHEMA, TEMPLATE_SCHEMA)],
|
||||
[vol.Any(TEMPLATE_SCHEMA, STATE_SCHEMA, NUMERIC_STATE_SCHEMA)],
|
||||
_no_overlapping,
|
||||
)
|
||||
),
|
||||
vol.Required(CONF_PRIOR): vol.Coerce(float),
|
||||
|
@ -211,10 +270,11 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
self.observations_by_entity = self._build_observations_by_entity()
|
||||
self.observations_by_template = self._build_observations_by_template()
|
||||
|
||||
self.observation_handlers: dict[str, Callable[[Observation], bool | None]] = {
|
||||
self.observation_handlers: dict[
|
||||
str, Callable[[Observation, bool], bool | None]
|
||||
] = {
|
||||
"numeric_state": self._process_numeric_state,
|
||||
"state": self._process_state,
|
||||
"multi_state": self._process_multi_state,
|
||||
}
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
|
@ -342,8 +402,9 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
for observation in self.observations_by_entity[entity]:
|
||||
platform = observation.platform
|
||||
|
||||
observation.observed = self.observation_handlers[platform](observation)
|
||||
|
||||
observation.observed = self.observation_handlers[platform](
|
||||
observation, observation.multi
|
||||
)
|
||||
local_observations[observation.id] = observation
|
||||
|
||||
return local_observations
|
||||
|
@ -408,9 +469,7 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
if len(entity_observations) == 1:
|
||||
continue
|
||||
for observation in entity_observations:
|
||||
if observation.platform != "state":
|
||||
continue
|
||||
observation.platform = "multi_state"
|
||||
observation.multi = True
|
||||
|
||||
return observations_by_entity
|
||||
|
||||
|
@ -437,14 +496,23 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
|
||||
return observations_by_template
|
||||
|
||||
def _process_numeric_state(self, entity_observation: Observation) -> bool | None:
|
||||
def _process_numeric_state(
|
||||
self, entity_observation: Observation, multi: bool = False
|
||||
) -> bool | None:
|
||||
"""Return True if numeric condition is met, return False if not, return None otherwise."""
|
||||
entity = entity_observation.entity_id
|
||||
entity_id = entity_observation.entity_id
|
||||
# if we are dealing with numeric_state observations entity_id cannot be None
|
||||
if TYPE_CHECKING:
|
||||
assert entity_id is not None
|
||||
|
||||
entity = self.hass.states.get(entity_id)
|
||||
if entity is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
if condition.state(self.hass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]):
|
||||
return None
|
||||
return condition.async_numeric_state(
|
||||
result = condition.async_numeric_state(
|
||||
self.hass,
|
||||
entity,
|
||||
entity_observation.below,
|
||||
|
@ -452,10 +520,24 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
None,
|
||||
entity_observation.to_dict(),
|
||||
)
|
||||
if result:
|
||||
return True
|
||||
if multi:
|
||||
state = float(entity.state)
|
||||
if (
|
||||
entity_observation.below is not None
|
||||
and state == entity_observation.below
|
||||
):
|
||||
return True
|
||||
return None
|
||||
except ConditionError:
|
||||
return None
|
||||
else:
|
||||
return False
|
||||
|
||||
def _process_state(self, entity_observation: Observation) -> bool | None:
|
||||
def _process_state(
|
||||
self, entity_observation: Observation, multi: bool = False
|
||||
) -> bool | None:
|
||||
"""Return True if state conditions are met, return False if they are not.
|
||||
|
||||
Returns None if the state is unavailable.
|
||||
|
@ -467,24 +549,13 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
if condition.state(self.hass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]):
|
||||
return None
|
||||
|
||||
return condition.state(self.hass, entity, entity_observation.to_state)
|
||||
result = condition.state(self.hass, entity, entity_observation.to_state)
|
||||
if multi and not result:
|
||||
return None
|
||||
except ConditionError:
|
||||
return None
|
||||
|
||||
def _process_multi_state(self, entity_observation: Observation) -> bool | None:
|
||||
"""Return True if state conditions are met, otherwise return None.
|
||||
|
||||
Never return False as all other states should have their own probabilities configured.
|
||||
"""
|
||||
|
||||
entity = entity_observation.entity_id
|
||||
|
||||
try:
|
||||
if condition.state(self.hass, entity, entity_observation.to_state):
|
||||
return True
|
||||
except ConditionError:
|
||||
return None
|
||||
return None
|
||||
else:
|
||||
return result
|
||||
|
||||
@property
|
||||
def extra_state_attributes(self) -> dict[str, Any]:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue