Improve typing and code quality in beyesian (#79603)
* strict typing * Detail implication * adds newline * don't change indenting * really dont change indenting * Update homeassistant/components/bayesian/binary_sensor.py Co-authored-by: epenet <6771947+epenet@users.noreply.github.com> * typing in async_setup_platform() + remove arg * less ambiguity * mypy thinks Literal[False] otherwise * clearer log * don't use `and` assignments * observations not values * clarify can be None * observation can't be none * assert we have at least one * make it clearer where we're using UUIDs * remove unnecessary bool Co-authored-by: epenet <6771947+epenet@users.noreply.github.com> * Unnecessary None handling Co-authored-by: epenet <6771947+epenet@users.noreply.github.com> * Better type setting Co-authored-by: epenet <6771947+epenet@users.noreply.github.com> * Reccomended changes. * remove if statement not needed * Not strict until _TrackTemplateResultInfo fixed Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>
This commit is contained in:
parent
a18a0b39dd
commit
9d351a3c10
3 changed files with 94 additions and 67 deletions
|
@ -2,12 +2,18 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Callable
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.binary_sensor import PLATFORM_SCHEMA, BinarySensorEntity
|
||||
from homeassistant.components.binary_sensor import (
|
||||
PLATFORM_SCHEMA,
|
||||
BinarySensorDeviceClass,
|
||||
BinarySensorEntity,
|
||||
)
|
||||
from homeassistant.const import (
|
||||
CONF_ABOVE,
|
||||
CONF_BELOW,
|
||||
|
@ -20,18 +26,19 @@ from homeassistant.const import (
|
|||
STATE_UNAVAILABLE,
|
||||
STATE_UNKNOWN,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
from homeassistant.exceptions import ConditionError, TemplateError
|
||||
from homeassistant.helpers import condition
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.helpers.event import (
|
||||
TrackTemplate,
|
||||
TrackTemplateResult,
|
||||
async_track_state_change_event,
|
||||
async_track_template_result,
|
||||
)
|
||||
from homeassistant.helpers.reload import async_setup_reload_service
|
||||
from homeassistant.helpers.template import result_as_boolean
|
||||
from homeassistant.helpers.template import Template, result_as_boolean
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
|
||||
from . import DOMAIN, PLATFORMS
|
||||
|
@ -107,7 +114,9 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
|
|||
)
|
||||
|
||||
|
||||
def update_probability(prior, prob_given_true, prob_given_false):
|
||||
def update_probability(
|
||||
prior: float, prob_given_true: float, prob_given_false: float
|
||||
) -> float:
|
||||
"""Update probability using Bayes' rule."""
|
||||
numerator = prob_given_true * prior
|
||||
denominator = numerator + prob_given_false * (1 - prior)
|
||||
|
@ -123,18 +132,18 @@ async def async_setup_platform(
|
|||
"""Set up the Bayesian Binary sensor."""
|
||||
await async_setup_reload_service(hass, DOMAIN, PLATFORMS)
|
||||
|
||||
name = config[CONF_NAME]
|
||||
observations = config[CONF_OBSERVATIONS]
|
||||
prior = config[CONF_PRIOR]
|
||||
probability_threshold = config[CONF_PROBABILITY_THRESHOLD]
|
||||
device_class = config.get(CONF_DEVICE_CLASS)
|
||||
name: str = config[CONF_NAME]
|
||||
observations: list[ConfigType] = config[CONF_OBSERVATIONS]
|
||||
prior: float = config[CONF_PRIOR]
|
||||
probability_threshold: float = config[CONF_PROBABILITY_THRESHOLD]
|
||||
device_class: BinarySensorDeviceClass | None = config.get(CONF_DEVICE_CLASS)
|
||||
|
||||
# Should deprecate in some future version (2022.10 at time of writing) & make prob_given_false required in schemas.
|
||||
broken_observations: list[dict[str, Any]] = []
|
||||
for observation in observations:
|
||||
if CONF_P_GIVEN_F not in observation:
|
||||
text: str = f"{name}/{observation.get(CONF_ENTITY_ID,'')}{observation.get(CONF_VALUE_TEMPLATE,'')}"
|
||||
raise_no_prob_given_false(hass, observation, text)
|
||||
raise_no_prob_given_false(hass, text)
|
||||
_LOGGER.error("Missing prob_given_false YAML entry for %s", text)
|
||||
broken_observations.append(observation)
|
||||
observations = [x for x in observations if x not in broken_observations]
|
||||
|
@ -153,7 +162,14 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
|
||||
_attr_should_poll = False
|
||||
|
||||
def __init__(self, name, prior, observations, probability_threshold, device_class):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
prior: float,
|
||||
observations: list[ConfigType],
|
||||
probability_threshold: float,
|
||||
device_class: BinarySensorDeviceClass | None,
|
||||
) -> None:
|
||||
"""Initialize the Bayesian sensor."""
|
||||
self._attr_name = name
|
||||
self._observations = [
|
||||
|
@ -173,17 +189,17 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
self._probability_threshold = probability_threshold
|
||||
self._attr_device_class = device_class
|
||||
self._attr_is_on = False
|
||||
self._callbacks = []
|
||||
self._callbacks: list = []
|
||||
|
||||
self.prior = prior
|
||||
self.probability = prior
|
||||
|
||||
self.current_observations = OrderedDict({})
|
||||
self.current_observations: OrderedDict[UUID, Observation] = OrderedDict({})
|
||||
|
||||
self.observations_by_entity = self._build_observations_by_entity()
|
||||
self.observations_by_template = self._build_observations_by_template()
|
||||
|
||||
self.observation_handlers = {
|
||||
self.observation_handlers: dict[str, Callable[[Observation], bool | None]] = {
|
||||
"numeric_state": self._process_numeric_state,
|
||||
"state": self._process_state,
|
||||
"multi_state": self._process_multi_state,
|
||||
|
@ -205,7 +221,7 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
"""
|
||||
|
||||
@callback
|
||||
def async_threshold_sensor_state_listener(event):
|
||||
def async_threshold_sensor_state_listener(event: Event) -> None:
|
||||
"""
|
||||
Handle sensor state changes.
|
||||
|
||||
|
@ -213,7 +229,7 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
then calculate the new probability.
|
||||
"""
|
||||
|
||||
entity = event.data.get("entity_id")
|
||||
entity: str = event.data[CONF_ENTITY_ID]
|
||||
|
||||
self.current_observations.update(self._record_entity_observations(entity))
|
||||
self.async_set_context(event.context)
|
||||
|
@ -228,11 +244,15 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
)
|
||||
|
||||
@callback
|
||||
def _async_template_result_changed(event, updates):
|
||||
def _async_template_result_changed(
|
||||
event: Event | None, updates: list[TrackTemplateResult]
|
||||
) -> None:
|
||||
track_template_result = updates.pop()
|
||||
template = track_template_result.template
|
||||
result = track_template_result.result
|
||||
entity = event and event.data.get("entity_id")
|
||||
entity: str | None = (
|
||||
None if event is None else event.data.get(CONF_ENTITY_ID)
|
||||
)
|
||||
if isinstance(result, TemplateError):
|
||||
_LOGGER.error(
|
||||
"TemplateError('%s') "
|
||||
|
@ -252,7 +272,7 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
|
||||
# in some cases a template may update because of the absence of an entity
|
||||
if entity is not None:
|
||||
observation.entity_id = str(entity)
|
||||
observation.entity_id = entity
|
||||
|
||||
self.current_observations[observation.id] = observation
|
||||
|
||||
|
@ -273,7 +293,7 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
|
||||
self.current_observations.update(self._initialize_current_observations())
|
||||
self.probability = self._calculate_new_probability()
|
||||
self._attr_is_on = bool(self.probability >= self._probability_threshold)
|
||||
self._attr_is_on = self.probability >= self._probability_threshold
|
||||
|
||||
# detect mirrored entries
|
||||
for entity, observations in self.observations_by_entity.items():
|
||||
|
@ -281,9 +301,9 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
self.hass, observations, text=f"{self._attr_name}/{entity}"
|
||||
)
|
||||
|
||||
all_template_observations = []
|
||||
for value in self.observations_by_template.values():
|
||||
all_template_observations.append(value[0])
|
||||
all_template_observations: list[Observation] = []
|
||||
for observations in self.observations_by_template.values():
|
||||
all_template_observations.append(observations[0])
|
||||
if len(all_template_observations) == 2:
|
||||
raise_mirrored_entries(
|
||||
self.hass,
|
||||
|
@ -292,62 +312,63 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
)
|
||||
|
||||
@callback
|
||||
def _recalculate_and_write_state(self):
|
||||
def _recalculate_and_write_state(self) -> None:
|
||||
self.probability = self._calculate_new_probability()
|
||||
self._attr_is_on = bool(self.probability >= self._probability_threshold)
|
||||
self.async_write_ha_state()
|
||||
|
||||
def _initialize_current_observations(self):
|
||||
local_observations = OrderedDict({})
|
||||
|
||||
def _initialize_current_observations(self) -> OrderedDict[UUID, Observation]:
|
||||
local_observations: OrderedDict[UUID, Observation] = OrderedDict({})
|
||||
for entity in self.observations_by_entity:
|
||||
local_observations.update(self._record_entity_observations(entity))
|
||||
return local_observations
|
||||
|
||||
def _record_entity_observations(self, entity):
|
||||
local_observations = OrderedDict({})
|
||||
def _record_entity_observations(
|
||||
self, entity: str
|
||||
) -> OrderedDict[UUID, Observation]:
|
||||
local_observations: OrderedDict[UUID, Observation] = OrderedDict({})
|
||||
|
||||
for observation in self.observations_by_entity[entity]:
|
||||
platform = observation.platform
|
||||
|
||||
observed = self.observation_handlers[platform](observation)
|
||||
observation.observed = observed
|
||||
observation.observed = self.observation_handlers[platform](observation)
|
||||
|
||||
local_observations[observation.id] = observation
|
||||
|
||||
return local_observations
|
||||
|
||||
def _calculate_new_probability(self):
|
||||
def _calculate_new_probability(self) -> float:
|
||||
prior = self.prior
|
||||
|
||||
for observation in self.current_observations.values():
|
||||
if observation is not None:
|
||||
if observation.observed is True:
|
||||
prior = update_probability(
|
||||
prior,
|
||||
observation.prob_given_true,
|
||||
observation.prob_given_false,
|
||||
)
|
||||
elif observation.observed is False:
|
||||
prior = update_probability(
|
||||
prior,
|
||||
1 - observation.prob_given_true,
|
||||
1 - observation.prob_given_false,
|
||||
)
|
||||
elif observation.observed is None:
|
||||
if observation.entity_id is not None:
|
||||
_LOGGER.debug(
|
||||
"Observation for entity '%s' returned None, it will not be used for Bayesian updating",
|
||||
observation.entity_id,
|
||||
)
|
||||
else:
|
||||
_LOGGER.debug(
|
||||
"Observation for template entity returned None rather than a valid boolean, it will not be used for Bayesian updating",
|
||||
)
|
||||
|
||||
if observation.observed is True:
|
||||
prior = update_probability(
|
||||
prior,
|
||||
observation.prob_given_true,
|
||||
observation.prob_given_false,
|
||||
)
|
||||
continue
|
||||
if observation.observed is False:
|
||||
prior = update_probability(
|
||||
prior,
|
||||
1 - observation.prob_given_true,
|
||||
1 - observation.prob_given_false,
|
||||
)
|
||||
continue
|
||||
# observation.observed is None
|
||||
if observation.entity_id is not None:
|
||||
_LOGGER.debug(
|
||||
"Observation for entity '%s' returned None, it will not be used for Bayesian updating",
|
||||
observation.entity_id,
|
||||
)
|
||||
continue
|
||||
_LOGGER.debug(
|
||||
"Observation for template entity returned None rather than a valid boolean, it will not be used for Bayesian updating",
|
||||
)
|
||||
# the prior has been updated and is now the posterior
|
||||
return prior
|
||||
|
||||
def _build_observations_by_entity(self):
|
||||
def _build_observations_by_entity(self) -> dict[str, list[Observation]]:
|
||||
"""
|
||||
Build and return data structure of the form below.
|
||||
|
||||
|
@ -378,7 +399,7 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
|
||||
return observations_by_entity
|
||||
|
||||
def _build_observations_by_template(self):
|
||||
def _build_observations_by_template(self) -> dict[Template, list[Observation]]:
|
||||
"""
|
||||
Build and return data structure of the form below.
|
||||
|
||||
|
@ -392,7 +413,7 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
for all relevant observations to be looked up via their `template`.
|
||||
"""
|
||||
|
||||
observations_by_template = {}
|
||||
observations_by_template: dict[Template, list[Observation]] = {}
|
||||
for observation in self._observations:
|
||||
if observation.value_template is None:
|
||||
continue
|
||||
|
@ -402,7 +423,7 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
|
||||
return observations_by_template
|
||||
|
||||
def _process_numeric_state(self, entity_observation):
|
||||
def _process_numeric_state(self, entity_observation: Observation) -> bool | None:
|
||||
"""Return True if numeric condition is met, return False if not, return None otherwise."""
|
||||
entity = entity_observation.entity_id
|
||||
|
||||
|
@ -420,7 +441,7 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
except ConditionError:
|
||||
return None
|
||||
|
||||
def _process_state(self, entity_observation):
|
||||
def _process_state(self, entity_observation: Observation) -> bool | None:
|
||||
"""Return True if state conditions are met, return False if they are not.
|
||||
|
||||
Returns None if the state is unavailable.
|
||||
|
@ -436,7 +457,7 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
except ConditionError:
|
||||
return None
|
||||
|
||||
def _process_multi_state(self, entity_observation):
|
||||
def _process_multi_state(self, entity_observation: Observation) -> bool | None:
|
||||
"""Return True if state conditions are met, otherwise return None.
|
||||
|
||||
Never return False as all other states should have their own probabilities configured.
|
||||
|
@ -452,7 +473,7 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
return None
|
||||
|
||||
@property
|
||||
def extra_state_attributes(self):
|
||||
def extra_state_attributes(self) -> dict[str, Any]:
|
||||
"""Return the state attributes of the sensor."""
|
||||
|
||||
return {
|
||||
|
|
|
@ -18,7 +18,10 @@ from .const import CONF_P_GIVEN_F, CONF_P_GIVEN_T, CONF_TO_STATE
|
|||
|
||||
@dataclass
|
||||
class Observation:
|
||||
"""Representation of a sensor or template observation."""
|
||||
"""Representation of a sensor or template observation.
|
||||
|
||||
Either entity_id or value_template should be non-None.
|
||||
"""
|
||||
|
||||
entity_id: str | None
|
||||
platform: str
|
||||
|
@ -29,7 +32,7 @@ class Observation:
|
|||
below: float | None
|
||||
value_template: Template | None
|
||||
observed: bool | None = None
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
id: uuid.UUID = field(default_factory=uuid.uuid4)
|
||||
|
||||
def to_dict(self) -> dict[str, str | float | bool | None]:
|
||||
"""Represent Class as a Dict for easier serialization."""
|
||||
|
|
|
@ -5,9 +5,12 @@ from homeassistant.core import HomeAssistant
|
|||
from homeassistant.helpers import issue_registry
|
||||
|
||||
from . import DOMAIN
|
||||
from .helpers import Observation
|
||||
|
||||
|
||||
def raise_mirrored_entries(hass: HomeAssistant, observations, text: str = "") -> None:
|
||||
def raise_mirrored_entries(
|
||||
hass: HomeAssistant, observations: list[Observation], text: str = ""
|
||||
) -> None:
|
||||
"""If there are mirrored entries, the user is probably using a workaround for a patched bug."""
|
||||
if len(observations) != 2:
|
||||
return
|
||||
|
@ -26,7 +29,7 @@ def raise_mirrored_entries(hass: HomeAssistant, observations, text: str = "") ->
|
|||
|
||||
|
||||
# Should deprecate in some future version (2022.10 at time of writing) & make prob_given_false required in schemas.
|
||||
def raise_no_prob_given_false(hass: HomeAssistant, observation, text: str) -> None:
|
||||
def raise_no_prob_given_false(hass: HomeAssistant, text: str) -> None:
|
||||
"""In previous 2022.9 and earlier, prob_given_false was optional and had a default version."""
|
||||
issue_registry.async_create_issue(
|
||||
hass,
|
||||
|
|
Loading…
Add table
Reference in a new issue