Improve typing and code quality in beyesian (#79603)

* strict typing

* Detail implication

* adds newline

* don't change indenting

* really dont change indenting

* Update homeassistant/components/bayesian/binary_sensor.py

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

* typing in async_setup_platform() + remove arg

* less ambiguity

* mypy thinks Literal[False] otherwise

* clearer log

* don't use `and` assignments

* observations not values

* clarify can be None

* observation can't be none

* assert we have at least one

* make it clearer where we're using UUIDs

* remove unnecessary bool

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

* Unnecessary None handling

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

* Better type setting

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

* Reccomended changes.

* remove if statement not needed

* Not strict until _TrackTemplateResultInfo fixed

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>
This commit is contained in:
HarvsG 2022-10-07 21:23:25 +01:00 committed by GitHub
parent a18a0b39dd
commit 9d351a3c10
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 94 additions and 67 deletions

View file

@ -2,12 +2,18 @@
from __future__ import annotations
from collections import OrderedDict
from collections.abc import Callable
import logging
from typing import Any
from uuid import UUID
import voluptuous as vol
from homeassistant.components.binary_sensor import PLATFORM_SCHEMA, BinarySensorEntity
from homeassistant.components.binary_sensor import (
PLATFORM_SCHEMA,
BinarySensorDeviceClass,
BinarySensorEntity,
)
from homeassistant.const import (
CONF_ABOVE,
CONF_BELOW,
@ -20,18 +26,19 @@ from homeassistant.const import (
STATE_UNAVAILABLE,
STATE_UNKNOWN,
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.exceptions import ConditionError, TemplateError
from homeassistant.helpers import condition
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import (
TrackTemplate,
TrackTemplateResult,
async_track_state_change_event,
async_track_template_result,
)
from homeassistant.helpers.reload import async_setup_reload_service
from homeassistant.helpers.template import result_as_boolean
from homeassistant.helpers.template import Template, result_as_boolean
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import DOMAIN, PLATFORMS
@ -107,7 +114,9 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
)
def update_probability(prior, prob_given_true, prob_given_false):
def update_probability(
prior: float, prob_given_true: float, prob_given_false: float
) -> float:
"""Update probability using Bayes' rule."""
numerator = prob_given_true * prior
denominator = numerator + prob_given_false * (1 - prior)
@ -123,18 +132,18 @@ async def async_setup_platform(
"""Set up the Bayesian Binary sensor."""
await async_setup_reload_service(hass, DOMAIN, PLATFORMS)
name = config[CONF_NAME]
observations = config[CONF_OBSERVATIONS]
prior = config[CONF_PRIOR]
probability_threshold = config[CONF_PROBABILITY_THRESHOLD]
device_class = config.get(CONF_DEVICE_CLASS)
name: str = config[CONF_NAME]
observations: list[ConfigType] = config[CONF_OBSERVATIONS]
prior: float = config[CONF_PRIOR]
probability_threshold: float = config[CONF_PROBABILITY_THRESHOLD]
device_class: BinarySensorDeviceClass | None = config.get(CONF_DEVICE_CLASS)
# Should deprecate in some future version (2022.10 at time of writing) & make prob_given_false required in schemas.
broken_observations: list[dict[str, Any]] = []
for observation in observations:
if CONF_P_GIVEN_F not in observation:
text: str = f"{name}/{observation.get(CONF_ENTITY_ID,'')}{observation.get(CONF_VALUE_TEMPLATE,'')}"
raise_no_prob_given_false(hass, observation, text)
raise_no_prob_given_false(hass, text)
_LOGGER.error("Missing prob_given_false YAML entry for %s", text)
broken_observations.append(observation)
observations = [x for x in observations if x not in broken_observations]
@ -153,7 +162,14 @@ class BayesianBinarySensor(BinarySensorEntity):
_attr_should_poll = False
def __init__(self, name, prior, observations, probability_threshold, device_class):
def __init__(
self,
name: str,
prior: float,
observations: list[ConfigType],
probability_threshold: float,
device_class: BinarySensorDeviceClass | None,
) -> None:
"""Initialize the Bayesian sensor."""
self._attr_name = name
self._observations = [
@ -173,17 +189,17 @@ class BayesianBinarySensor(BinarySensorEntity):
self._probability_threshold = probability_threshold
self._attr_device_class = device_class
self._attr_is_on = False
self._callbacks = []
self._callbacks: list = []
self.prior = prior
self.probability = prior
self.current_observations = OrderedDict({})
self.current_observations: OrderedDict[UUID, Observation] = OrderedDict({})
self.observations_by_entity = self._build_observations_by_entity()
self.observations_by_template = self._build_observations_by_template()
self.observation_handlers = {
self.observation_handlers: dict[str, Callable[[Observation], bool | None]] = {
"numeric_state": self._process_numeric_state,
"state": self._process_state,
"multi_state": self._process_multi_state,
@ -205,7 +221,7 @@ class BayesianBinarySensor(BinarySensorEntity):
"""
@callback
def async_threshold_sensor_state_listener(event):
def async_threshold_sensor_state_listener(event: Event) -> None:
"""
Handle sensor state changes.
@ -213,7 +229,7 @@ class BayesianBinarySensor(BinarySensorEntity):
then calculate the new probability.
"""
entity = event.data.get("entity_id")
entity: str = event.data[CONF_ENTITY_ID]
self.current_observations.update(self._record_entity_observations(entity))
self.async_set_context(event.context)
@ -228,11 +244,15 @@ class BayesianBinarySensor(BinarySensorEntity):
)
@callback
def _async_template_result_changed(event, updates):
def _async_template_result_changed(
event: Event | None, updates: list[TrackTemplateResult]
) -> None:
track_template_result = updates.pop()
template = track_template_result.template
result = track_template_result.result
entity = event and event.data.get("entity_id")
entity: str | None = (
None if event is None else event.data.get(CONF_ENTITY_ID)
)
if isinstance(result, TemplateError):
_LOGGER.error(
"TemplateError('%s') "
@ -252,7 +272,7 @@ class BayesianBinarySensor(BinarySensorEntity):
# in some cases a template may update because of the absence of an entity
if entity is not None:
observation.entity_id = str(entity)
observation.entity_id = entity
self.current_observations[observation.id] = observation
@ -273,7 +293,7 @@ class BayesianBinarySensor(BinarySensorEntity):
self.current_observations.update(self._initialize_current_observations())
self.probability = self._calculate_new_probability()
self._attr_is_on = bool(self.probability >= self._probability_threshold)
self._attr_is_on = self.probability >= self._probability_threshold
# detect mirrored entries
for entity, observations in self.observations_by_entity.items():
@ -281,9 +301,9 @@ class BayesianBinarySensor(BinarySensorEntity):
self.hass, observations, text=f"{self._attr_name}/{entity}"
)
all_template_observations = []
for value in self.observations_by_template.values():
all_template_observations.append(value[0])
all_template_observations: list[Observation] = []
for observations in self.observations_by_template.values():
all_template_observations.append(observations[0])
if len(all_template_observations) == 2:
raise_mirrored_entries(
self.hass,
@ -292,62 +312,63 @@ class BayesianBinarySensor(BinarySensorEntity):
)
@callback
def _recalculate_and_write_state(self):
def _recalculate_and_write_state(self) -> None:
self.probability = self._calculate_new_probability()
self._attr_is_on = bool(self.probability >= self._probability_threshold)
self.async_write_ha_state()
def _initialize_current_observations(self):
local_observations = OrderedDict({})
def _initialize_current_observations(self) -> OrderedDict[UUID, Observation]:
local_observations: OrderedDict[UUID, Observation] = OrderedDict({})
for entity in self.observations_by_entity:
local_observations.update(self._record_entity_observations(entity))
return local_observations
def _record_entity_observations(self, entity):
local_observations = OrderedDict({})
def _record_entity_observations(
self, entity: str
) -> OrderedDict[UUID, Observation]:
local_observations: OrderedDict[UUID, Observation] = OrderedDict({})
for observation in self.observations_by_entity[entity]:
platform = observation.platform
observed = self.observation_handlers[platform](observation)
observation.observed = observed
observation.observed = self.observation_handlers[platform](observation)
local_observations[observation.id] = observation
return local_observations
def _calculate_new_probability(self):
def _calculate_new_probability(self) -> float:
prior = self.prior
for observation in self.current_observations.values():
if observation is not None:
if observation.observed is True:
prior = update_probability(
prior,
observation.prob_given_true,
observation.prob_given_false,
)
elif observation.observed is False:
prior = update_probability(
prior,
1 - observation.prob_given_true,
1 - observation.prob_given_false,
)
elif observation.observed is None:
if observation.entity_id is not None:
_LOGGER.debug(
"Observation for entity '%s' returned None, it will not be used for Bayesian updating",
observation.entity_id,
)
else:
_LOGGER.debug(
"Observation for template entity returned None rather than a valid boolean, it will not be used for Bayesian updating",
)
if observation.observed is True:
prior = update_probability(
prior,
observation.prob_given_true,
observation.prob_given_false,
)
continue
if observation.observed is False:
prior = update_probability(
prior,
1 - observation.prob_given_true,
1 - observation.prob_given_false,
)
continue
# observation.observed is None
if observation.entity_id is not None:
_LOGGER.debug(
"Observation for entity '%s' returned None, it will not be used for Bayesian updating",
observation.entity_id,
)
continue
_LOGGER.debug(
"Observation for template entity returned None rather than a valid boolean, it will not be used for Bayesian updating",
)
# the prior has been updated and is now the posterior
return prior
def _build_observations_by_entity(self):
def _build_observations_by_entity(self) -> dict[str, list[Observation]]:
"""
Build and return data structure of the form below.
@ -378,7 +399,7 @@ class BayesianBinarySensor(BinarySensorEntity):
return observations_by_entity
def _build_observations_by_template(self):
def _build_observations_by_template(self) -> dict[Template, list[Observation]]:
"""
Build and return data structure of the form below.
@ -392,7 +413,7 @@ class BayesianBinarySensor(BinarySensorEntity):
for all relevant observations to be looked up via their `template`.
"""
observations_by_template = {}
observations_by_template: dict[Template, list[Observation]] = {}
for observation in self._observations:
if observation.value_template is None:
continue
@ -402,7 +423,7 @@ class BayesianBinarySensor(BinarySensorEntity):
return observations_by_template
def _process_numeric_state(self, entity_observation):
def _process_numeric_state(self, entity_observation: Observation) -> bool | None:
"""Return True if numeric condition is met, return False if not, return None otherwise."""
entity = entity_observation.entity_id
@ -420,7 +441,7 @@ class BayesianBinarySensor(BinarySensorEntity):
except ConditionError:
return None
def _process_state(self, entity_observation):
def _process_state(self, entity_observation: Observation) -> bool | None:
"""Return True if state conditions are met, return False if they are not.
Returns None if the state is unavailable.
@ -436,7 +457,7 @@ class BayesianBinarySensor(BinarySensorEntity):
except ConditionError:
return None
def _process_multi_state(self, entity_observation):
def _process_multi_state(self, entity_observation: Observation) -> bool | None:
"""Return True if state conditions are met, otherwise return None.
Never return False as all other states should have their own probabilities configured.
@ -452,7 +473,7 @@ class BayesianBinarySensor(BinarySensorEntity):
return None
@property
def extra_state_attributes(self):
def extra_state_attributes(self) -> dict[str, Any]:
"""Return the state attributes of the sensor."""
return {

View file

@ -18,7 +18,10 @@ from .const import CONF_P_GIVEN_F, CONF_P_GIVEN_T, CONF_TO_STATE
@dataclass
class Observation:
"""Representation of a sensor or template observation."""
"""Representation of a sensor or template observation.
Either entity_id or value_template should be non-None.
"""
entity_id: str | None
platform: str
@ -29,7 +32,7 @@ class Observation:
below: float | None
value_template: Template | None
observed: bool | None = None
id: str = field(default_factory=lambda: str(uuid.uuid4()))
id: uuid.UUID = field(default_factory=uuid.uuid4)
def to_dict(self) -> dict[str, str | float | bool | None]:
"""Represent Class as a Dict for easier serialization."""

View file

@ -5,9 +5,12 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers import issue_registry
from . import DOMAIN
from .helpers import Observation
def raise_mirrored_entries(hass: HomeAssistant, observations, text: str = "") -> None:
def raise_mirrored_entries(
hass: HomeAssistant, observations: list[Observation], text: str = ""
) -> None:
"""If there are mirrored entries, the user is probably using a workaround for a patched bug."""
if len(observations) != 2:
return
@ -26,7 +29,7 @@ def raise_mirrored_entries(hass: HomeAssistant, observations, text: str = "") ->
# Should deprecate in some future version (2022.10 at time of writing) & make prob_given_false required in schemas.
def raise_no_prob_given_false(hass: HomeAssistant, observation, text: str) -> None:
def raise_no_prob_given_false(hass: HomeAssistant, text: str) -> None:
"""In previous 2022.9 and earlier, prob_given_false was optional and had a default version."""
issue_registry.async_create_issue(
hass,