diff --git a/homeassistant/components/bayesian/binary_sensor.py b/homeassistant/components/bayesian/binary_sensor.py index 190fb889553..1d2674255f9 100644 --- a/homeassistant/components/bayesian/binary_sensor.py +++ b/homeassistant/components/bayesian/binary_sensor.py @@ -2,12 +2,18 @@ from __future__ import annotations from collections import OrderedDict +from collections.abc import Callable import logging from typing import Any +from uuid import UUID import voluptuous as vol -from homeassistant.components.binary_sensor import PLATFORM_SCHEMA, BinarySensorEntity +from homeassistant.components.binary_sensor import ( + PLATFORM_SCHEMA, + BinarySensorDeviceClass, + BinarySensorEntity, +) from homeassistant.const import ( CONF_ABOVE, CONF_BELOW, @@ -20,18 +26,19 @@ from homeassistant.const import ( STATE_UNAVAILABLE, STATE_UNKNOWN, ) -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import Event, HomeAssistant, callback from homeassistant.exceptions import ConditionError, TemplateError from homeassistant.helpers import condition import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.event import ( TrackTemplate, + TrackTemplateResult, async_track_state_change_event, async_track_template_result, ) from homeassistant.helpers.reload import async_setup_reload_service -from homeassistant.helpers.template import result_as_boolean +from homeassistant.helpers.template import Template, result_as_boolean from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from . import DOMAIN, PLATFORMS @@ -107,7 +114,9 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( ) -def update_probability(prior, prob_given_true, prob_given_false): +def update_probability( + prior: float, prob_given_true: float, prob_given_false: float +) -> float: """Update probability using Bayes' rule.""" numerator = prob_given_true * prior denominator = numerator + prob_given_false * (1 - prior) @@ -123,18 +132,18 @@ async def async_setup_platform( """Set up the Bayesian Binary sensor.""" await async_setup_reload_service(hass, DOMAIN, PLATFORMS) - name = config[CONF_NAME] - observations = config[CONF_OBSERVATIONS] - prior = config[CONF_PRIOR] - probability_threshold = config[CONF_PROBABILITY_THRESHOLD] - device_class = config.get(CONF_DEVICE_CLASS) + name: str = config[CONF_NAME] + observations: list[ConfigType] = config[CONF_OBSERVATIONS] + prior: float = config[CONF_PRIOR] + probability_threshold: float = config[CONF_PROBABILITY_THRESHOLD] + device_class: BinarySensorDeviceClass | None = config.get(CONF_DEVICE_CLASS) # Should deprecate in some future version (2022.10 at time of writing) & make prob_given_false required in schemas. broken_observations: list[dict[str, Any]] = [] for observation in observations: if CONF_P_GIVEN_F not in observation: text: str = f"{name}/{observation.get(CONF_ENTITY_ID,'')}{observation.get(CONF_VALUE_TEMPLATE,'')}" - raise_no_prob_given_false(hass, observation, text) + raise_no_prob_given_false(hass, text) _LOGGER.error("Missing prob_given_false YAML entry for %s", text) broken_observations.append(observation) observations = [x for x in observations if x not in broken_observations] @@ -153,7 +162,14 @@ class BayesianBinarySensor(BinarySensorEntity): _attr_should_poll = False - def __init__(self, name, prior, observations, probability_threshold, device_class): + def __init__( + self, + name: str, + prior: float, + observations: list[ConfigType], + probability_threshold: float, + device_class: BinarySensorDeviceClass | None, + ) -> None: """Initialize the Bayesian sensor.""" self._attr_name = name self._observations = [ @@ -173,17 +189,17 @@ class BayesianBinarySensor(BinarySensorEntity): self._probability_threshold = probability_threshold self._attr_device_class = device_class self._attr_is_on = False - self._callbacks = [] + self._callbacks: list = [] self.prior = prior self.probability = prior - self.current_observations = OrderedDict({}) + self.current_observations: OrderedDict[UUID, Observation] = OrderedDict({}) self.observations_by_entity = self._build_observations_by_entity() self.observations_by_template = self._build_observations_by_template() - self.observation_handlers = { + self.observation_handlers: dict[str, Callable[[Observation], bool | None]] = { "numeric_state": self._process_numeric_state, "state": self._process_state, "multi_state": self._process_multi_state, @@ -205,7 +221,7 @@ class BayesianBinarySensor(BinarySensorEntity): """ @callback - def async_threshold_sensor_state_listener(event): + def async_threshold_sensor_state_listener(event: Event) -> None: """ Handle sensor state changes. @@ -213,7 +229,7 @@ class BayesianBinarySensor(BinarySensorEntity): then calculate the new probability. """ - entity = event.data.get("entity_id") + entity: str = event.data[CONF_ENTITY_ID] self.current_observations.update(self._record_entity_observations(entity)) self.async_set_context(event.context) @@ -228,11 +244,15 @@ class BayesianBinarySensor(BinarySensorEntity): ) @callback - def _async_template_result_changed(event, updates): + def _async_template_result_changed( + event: Event | None, updates: list[TrackTemplateResult] + ) -> None: track_template_result = updates.pop() template = track_template_result.template result = track_template_result.result - entity = event and event.data.get("entity_id") + entity: str | None = ( + None if event is None else event.data.get(CONF_ENTITY_ID) + ) if isinstance(result, TemplateError): _LOGGER.error( "TemplateError('%s') " @@ -252,7 +272,7 @@ class BayesianBinarySensor(BinarySensorEntity): # in some cases a template may update because of the absence of an entity if entity is not None: - observation.entity_id = str(entity) + observation.entity_id = entity self.current_observations[observation.id] = observation @@ -273,7 +293,7 @@ class BayesianBinarySensor(BinarySensorEntity): self.current_observations.update(self._initialize_current_observations()) self.probability = self._calculate_new_probability() - self._attr_is_on = bool(self.probability >= self._probability_threshold) + self._attr_is_on = self.probability >= self._probability_threshold # detect mirrored entries for entity, observations in self.observations_by_entity.items(): @@ -281,9 +301,9 @@ class BayesianBinarySensor(BinarySensorEntity): self.hass, observations, text=f"{self._attr_name}/{entity}" ) - all_template_observations = [] - for value in self.observations_by_template.values(): - all_template_observations.append(value[0]) + all_template_observations: list[Observation] = [] + for observations in self.observations_by_template.values(): + all_template_observations.append(observations[0]) if len(all_template_observations) == 2: raise_mirrored_entries( self.hass, @@ -292,62 +312,63 @@ class BayesianBinarySensor(BinarySensorEntity): ) @callback - def _recalculate_and_write_state(self): + def _recalculate_and_write_state(self) -> None: self.probability = self._calculate_new_probability() self._attr_is_on = bool(self.probability >= self._probability_threshold) self.async_write_ha_state() - def _initialize_current_observations(self): - local_observations = OrderedDict({}) - + def _initialize_current_observations(self) -> OrderedDict[UUID, Observation]: + local_observations: OrderedDict[UUID, Observation] = OrderedDict({}) for entity in self.observations_by_entity: local_observations.update(self._record_entity_observations(entity)) return local_observations - def _record_entity_observations(self, entity): - local_observations = OrderedDict({}) + def _record_entity_observations( + self, entity: str + ) -> OrderedDict[UUID, Observation]: + local_observations: OrderedDict[UUID, Observation] = OrderedDict({}) for observation in self.observations_by_entity[entity]: platform = observation.platform - observed = self.observation_handlers[platform](observation) - observation.observed = observed + observation.observed = self.observation_handlers[platform](observation) local_observations[observation.id] = observation return local_observations - def _calculate_new_probability(self): + def _calculate_new_probability(self) -> float: prior = self.prior for observation in self.current_observations.values(): - if observation is not None: - if observation.observed is True: - prior = update_probability( - prior, - observation.prob_given_true, - observation.prob_given_false, - ) - elif observation.observed is False: - prior = update_probability( - prior, - 1 - observation.prob_given_true, - 1 - observation.prob_given_false, - ) - elif observation.observed is None: - if observation.entity_id is not None: - _LOGGER.debug( - "Observation for entity '%s' returned None, it will not be used for Bayesian updating", - observation.entity_id, - ) - else: - _LOGGER.debug( - "Observation for template entity returned None rather than a valid boolean, it will not be used for Bayesian updating", - ) - + if observation.observed is True: + prior = update_probability( + prior, + observation.prob_given_true, + observation.prob_given_false, + ) + continue + if observation.observed is False: + prior = update_probability( + prior, + 1 - observation.prob_given_true, + 1 - observation.prob_given_false, + ) + continue + # observation.observed is None + if observation.entity_id is not None: + _LOGGER.debug( + "Observation for entity '%s' returned None, it will not be used for Bayesian updating", + observation.entity_id, + ) + continue + _LOGGER.debug( + "Observation for template entity returned None rather than a valid boolean, it will not be used for Bayesian updating", + ) + # the prior has been updated and is now the posterior return prior - def _build_observations_by_entity(self): + def _build_observations_by_entity(self) -> dict[str, list[Observation]]: """ Build and return data structure of the form below. @@ -378,7 +399,7 @@ class BayesianBinarySensor(BinarySensorEntity): return observations_by_entity - def _build_observations_by_template(self): + def _build_observations_by_template(self) -> dict[Template, list[Observation]]: """ Build and return data structure of the form below. @@ -392,7 +413,7 @@ class BayesianBinarySensor(BinarySensorEntity): for all relevant observations to be looked up via their `template`. """ - observations_by_template = {} + observations_by_template: dict[Template, list[Observation]] = {} for observation in self._observations: if observation.value_template is None: continue @@ -402,7 +423,7 @@ class BayesianBinarySensor(BinarySensorEntity): return observations_by_template - def _process_numeric_state(self, entity_observation): + def _process_numeric_state(self, entity_observation: Observation) -> bool | None: """Return True if numeric condition is met, return False if not, return None otherwise.""" entity = entity_observation.entity_id @@ -420,7 +441,7 @@ class BayesianBinarySensor(BinarySensorEntity): except ConditionError: return None - def _process_state(self, entity_observation): + def _process_state(self, entity_observation: Observation) -> bool | None: """Return True if state conditions are met, return False if they are not. Returns None if the state is unavailable. @@ -436,7 +457,7 @@ class BayesianBinarySensor(BinarySensorEntity): except ConditionError: return None - def _process_multi_state(self, entity_observation): + def _process_multi_state(self, entity_observation: Observation) -> bool | None: """Return True if state conditions are met, otherwise return None. Never return False as all other states should have their own probabilities configured. @@ -452,7 +473,7 @@ class BayesianBinarySensor(BinarySensorEntity): return None @property - def extra_state_attributes(self): + def extra_state_attributes(self) -> dict[str, Any]: """Return the state attributes of the sensor.""" return { diff --git a/homeassistant/components/bayesian/helpers.py b/homeassistant/components/bayesian/helpers.py index 22c5d518b46..6e78de63607 100644 --- a/homeassistant/components/bayesian/helpers.py +++ b/homeassistant/components/bayesian/helpers.py @@ -18,7 +18,10 @@ from .const import CONF_P_GIVEN_F, CONF_P_GIVEN_T, CONF_TO_STATE @dataclass class Observation: - """Representation of a sensor or template observation.""" + """Representation of a sensor or template observation. + + Either entity_id or value_template should be non-None. + """ entity_id: str | None platform: str @@ -29,7 +32,7 @@ class Observation: below: float | None value_template: Template | None observed: bool | None = None - id: str = field(default_factory=lambda: str(uuid.uuid4())) + id: uuid.UUID = field(default_factory=uuid.uuid4) def to_dict(self) -> dict[str, str | float | bool | None]: """Represent Class as a Dict for easier serialization.""" diff --git a/homeassistant/components/bayesian/repairs.py b/homeassistant/components/bayesian/repairs.py index 2b04a6a6605..9a527636948 100644 --- a/homeassistant/components/bayesian/repairs.py +++ b/homeassistant/components/bayesian/repairs.py @@ -5,9 +5,12 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers import issue_registry from . import DOMAIN +from .helpers import Observation -def raise_mirrored_entries(hass: HomeAssistant, observations, text: str = "") -> None: +def raise_mirrored_entries( + hass: HomeAssistant, observations: list[Observation], text: str = "" +) -> None: """If there are mirrored entries, the user is probably using a workaround for a patched bug.""" if len(observations) != 2: return @@ -26,7 +29,7 @@ def raise_mirrored_entries(hass: HomeAssistant, observations, text: str = "") -> # Should deprecate in some future version (2022.10 at time of writing) & make prob_given_false required in schemas. -def raise_no_prob_given_false(hass: HomeAssistant, observation, text: str) -> None: +def raise_no_prob_given_false(hass: HomeAssistant, text: str) -> None: """In previous 2022.9 and earlier, prob_given_false was optional and had a default version.""" issue_registry.async_create_issue( hass,