Convert bayesian binary_sensor to use async_track_template_result (#39174)
Add coverage to reach 100% line coverage
This commit is contained in:
parent
42227c1c53
commit
b68c5cec94
2 changed files with 250 additions and 30 deletions
|
@ -1,5 +1,6 @@
|
|||
"""Use Bayesian Inference to trigger a binary sensor."""
|
||||
from collections import OrderedDict
|
||||
import logging
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -16,9 +17,14 @@ from homeassistant.const import (
|
|||
STATE_UNKNOWN,
|
||||
)
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.exceptions import TemplateError
|
||||
from homeassistant.helpers import condition
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.event import async_track_state_change_event
|
||||
from homeassistant.helpers.event import (
|
||||
async_track_state_change_event,
|
||||
async_track_template_result,
|
||||
)
|
||||
from homeassistant.helpers.template import result_as_boolean
|
||||
|
||||
ATTR_OBSERVATIONS = "observations"
|
||||
ATTR_OCCURRED_OBSERVATION_ENTITIES = "occurred_observation_entities"
|
||||
|
@ -36,6 +42,9 @@ CONF_TO_STATE = "to_state"
|
|||
DEFAULT_NAME = "Bayesian Binary Sensor"
|
||||
DEFAULT_PROBABILITY_THRESHOLD = 0.5
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
NUMERIC_STATE_SCHEMA = vol.Schema(
|
||||
{
|
||||
CONF_PLATFORM: "numeric_state",
|
||||
|
@ -107,8 +116,7 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info=
|
|||
BayesianBinarySensor(
|
||||
name, prior, observations, probability_threshold, device_class
|
||||
)
|
||||
],
|
||||
True,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
|
@ -122,17 +130,19 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
self._probability_threshold = probability_threshold
|
||||
self._device_class = device_class
|
||||
self._deviation = False
|
||||
self._callbacks = []
|
||||
|
||||
self.prior = prior
|
||||
self.probability = prior
|
||||
|
||||
self.current_observations = OrderedDict({})
|
||||
|
||||
self.observations_by_entity = self._build_observations_by_entity()
|
||||
self.observations_by_template = self._build_observations_by_template()
|
||||
|
||||
self.observation_handlers = {
|
||||
"numeric_state": self._process_numeric_state,
|
||||
"state": self._process_state,
|
||||
"template": self._process_template,
|
||||
}
|
||||
|
||||
async def async_added_to_hass(self):
|
||||
|
@ -166,17 +176,61 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
entity = event.data.get("entity_id")
|
||||
|
||||
self.current_observations.update(self._record_entity_observations(entity))
|
||||
self.probability = self._calculate_new_probability()
|
||||
self._recalculate_and_write_state()
|
||||
|
||||
self.hass.async_add_job(self.async_update_ha_state, True)
|
||||
self.async_on_remove(
|
||||
async_track_state_change_event(
|
||||
self.hass,
|
||||
list(self.observations_by_entity),
|
||||
async_threshold_sensor_state_listener,
|
||||
)
|
||||
)
|
||||
|
||||
@callback
|
||||
def _async_template_result_changed(event, template, last_result, result):
|
||||
entity = event and event.data.get("entity_id")
|
||||
|
||||
if isinstance(result, TemplateError):
|
||||
_LOGGER.error(
|
||||
"TemplateError('%s') "
|
||||
"while processing template '%s' "
|
||||
"in entity '%s'",
|
||||
result,
|
||||
template,
|
||||
self.entity_id,
|
||||
)
|
||||
|
||||
should_trigger = False
|
||||
else:
|
||||
should_trigger = result_as_boolean(result)
|
||||
|
||||
for obs in self.observations_by_template[template]:
|
||||
if should_trigger:
|
||||
obs_entry = {"entity_id": entity, **obs}
|
||||
else:
|
||||
obs_entry = None
|
||||
self.current_observations[obs["id"]] = obs_entry
|
||||
|
||||
self._recalculate_and_write_state()
|
||||
|
||||
for template in self.observations_by_template:
|
||||
info = async_track_template_result(
|
||||
self.hass, template, _async_template_result_changed
|
||||
)
|
||||
|
||||
self._callbacks.append(info)
|
||||
self.async_on_remove(info.async_remove)
|
||||
info.async_refresh()
|
||||
|
||||
self.current_observations.update(self._initialize_current_observations())
|
||||
self.probability = self._calculate_new_probability()
|
||||
async_track_state_change_event(
|
||||
self.hass,
|
||||
list(self.observations_by_entity),
|
||||
async_threshold_sensor_state_listener,
|
||||
)
|
||||
self._deviation = bool(self.probability >= self._probability_threshold)
|
||||
|
||||
@callback
|
||||
def _recalculate_and_write_state(self):
|
||||
self.probability = self._calculate_new_probability()
|
||||
self._deviation = bool(self.probability >= self._probability_threshold)
|
||||
self.async_write_ha_state()
|
||||
|
||||
def _initialize_current_observations(self):
|
||||
local_observations = OrderedDict({})
|
||||
|
@ -186,9 +240,8 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
|
||||
def _record_entity_observations(self, entity):
|
||||
local_observations = OrderedDict({})
|
||||
entity_obs_list = self.observations_by_entity[entity]
|
||||
|
||||
for entity_obs in entity_obs_list:
|
||||
for entity_obs in self.observations_by_entity[entity]:
|
||||
platform = entity_obs["platform"]
|
||||
|
||||
should_trigger = self.observation_handlers[platform](entity_obs)
|
||||
|
@ -233,18 +286,42 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
for ind, obs in enumerate(self._observations):
|
||||
obs["id"] = ind
|
||||
|
||||
if "entity_id" in obs:
|
||||
entity_ids = [obs["entity_id"]]
|
||||
elif "value_template" in obs:
|
||||
entity_ids = obs.get(CONF_VALUE_TEMPLATE).extract_entities()
|
||||
if "entity_id" not in obs:
|
||||
continue
|
||||
|
||||
entity_ids = [obs["entity_id"]]
|
||||
|
||||
for e_id in entity_ids:
|
||||
obs_list = observations_by_entity.get(e_id, [])
|
||||
obs_list.append(obs)
|
||||
observations_by_entity[e_id] = obs_list
|
||||
observations_by_entity.setdefault(e_id, []).append(obs)
|
||||
|
||||
return observations_by_entity
|
||||
|
||||
def _build_observations_by_template(self):
|
||||
"""
|
||||
Build and return data structure of the form below.
|
||||
|
||||
{
|
||||
"template": [{"id": 0, ...}, {"id": 1, ...}],
|
||||
"template2": [{"id": 2, ...}],
|
||||
...
|
||||
}
|
||||
|
||||
Each "observation" must be recognized uniquely, and it should be possible
|
||||
for all relevant observations to be looked up via their `template`.
|
||||
"""
|
||||
|
||||
observations_by_template = {}
|
||||
for ind, obs in enumerate(self._observations):
|
||||
obs["id"] = ind
|
||||
|
||||
if "value_template" not in obs:
|
||||
continue
|
||||
|
||||
template = obs.get(CONF_VALUE_TEMPLATE)
|
||||
observations_by_template.setdefault(template, []).append(obs)
|
||||
|
||||
return observations_by_template
|
||||
|
||||
def _process_numeric_state(self, entity_observation):
|
||||
"""Return True if numeric condition is met."""
|
||||
entity = entity_observation["entity_id"]
|
||||
|
@ -264,12 +341,6 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
|
||||
return condition.state(self.hass, entity, entity_observation.get("to_state"))
|
||||
|
||||
def _process_template(self, entity_observation):
|
||||
"""Return True if template condition is True."""
|
||||
template = entity_observation.get(CONF_VALUE_TEMPLATE)
|
||||
template.hass = self.hass
|
||||
return condition.async_template(self.hass, template, entity_observation)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""Return the name of the sensor."""
|
||||
|
@ -307,7 +378,7 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
{
|
||||
obs.get("entity_id")
|
||||
for obs in self.current_observations.values()
|
||||
if obs is not None
|
||||
if obs is not None and obs.get("entity_id") is not None
|
||||
}
|
||||
),
|
||||
ATTR_PROBABILITY: round(self.probability, 2),
|
||||
|
@ -316,4 +387,10 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
|
||||
async def async_update(self):
|
||||
"""Get the latest data and update the states."""
|
||||
self._deviation = bool(self.probability >= self._probability_threshold)
|
||||
if not self._callbacks:
|
||||
self._recalculate_and_write_state()
|
||||
return
|
||||
# Force recalc of the templates. The states will
|
||||
# update automatically.
|
||||
for call in self._callbacks:
|
||||
call.async_refresh()
|
||||
|
|
|
@ -3,8 +3,12 @@ import json
|
|||
import unittest
|
||||
|
||||
from homeassistant.components.bayesian import binary_sensor as bayesian
|
||||
from homeassistant.const import STATE_UNKNOWN
|
||||
from homeassistant.setup import setup_component
|
||||
from homeassistant.components.homeassistant import (
|
||||
DOMAIN as HA_DOMAIN,
|
||||
SERVICE_UPDATE_ENTITY,
|
||||
)
|
||||
from homeassistant.const import ATTR_ENTITY_ID, STATE_UNKNOWN
|
||||
from homeassistant.setup import async_setup_component, setup_component
|
||||
|
||||
from tests.common import get_test_home_assistant
|
||||
|
||||
|
@ -488,3 +492,142 @@ class TestBayesianBinarySensor(unittest.TestCase):
|
|||
|
||||
for key, attrs in state.attributes.items():
|
||||
json.dumps(attrs)
|
||||
|
||||
|
||||
async def test_template_error(hass, caplog):
|
||||
"""Test sensor with template error."""
|
||||
config = {
|
||||
"binary_sensor": {
|
||||
"name": "Test_Binary",
|
||||
"platform": "bayesian",
|
||||
"observations": [
|
||||
{
|
||||
"platform": "template",
|
||||
"value_template": "{{ xyz + 1 }}",
|
||||
"prob_given_true": 0.9,
|
||||
},
|
||||
],
|
||||
"prior": 0.2,
|
||||
"probability_threshold": 0.32,
|
||||
}
|
||||
}
|
||||
|
||||
await async_setup_component(hass, "binary_sensor", config)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert hass.states.get("binary_sensor.test_binary").state == "off"
|
||||
|
||||
assert "TemplateError" in caplog.text
|
||||
assert "xyz" in caplog.text
|
||||
|
||||
|
||||
async def test_update_request_with_template(hass):
|
||||
"""Test sensor on template platform observations that gets an update request."""
|
||||
config = {
|
||||
"binary_sensor": {
|
||||
"name": "Test_Binary",
|
||||
"platform": "bayesian",
|
||||
"observations": [
|
||||
{
|
||||
"platform": "template",
|
||||
"value_template": "{{states('sensor.test_monitored') == 'off'}}",
|
||||
"prob_given_true": 0.8,
|
||||
"prob_given_false": 0.4,
|
||||
}
|
||||
],
|
||||
"prior": 0.2,
|
||||
"probability_threshold": 0.32,
|
||||
}
|
||||
}
|
||||
|
||||
await async_setup_component(hass, "binary_sensor", config)
|
||||
await async_setup_component(hass, HA_DOMAIN, {})
|
||||
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert hass.states.get("binary_sensor.test_binary").state == "off"
|
||||
|
||||
await hass.services.async_call(
|
||||
HA_DOMAIN,
|
||||
SERVICE_UPDATE_ENTITY,
|
||||
{ATTR_ENTITY_ID: "binary_sensor.test_binary"},
|
||||
blocking=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert hass.states.get("binary_sensor.test_binary").state == "off"
|
||||
|
||||
|
||||
async def test_update_request_without_template(hass):
|
||||
"""Test sensor on template platform observations that gets an update request."""
|
||||
config = {
|
||||
"binary_sensor": {
|
||||
"name": "Test_Binary",
|
||||
"platform": "bayesian",
|
||||
"observations": [
|
||||
{
|
||||
"platform": "state",
|
||||
"entity_id": "sensor.test_monitored",
|
||||
"to_state": "off",
|
||||
"prob_given_true": 0.9,
|
||||
"prob_given_false": 0.4,
|
||||
},
|
||||
],
|
||||
"prior": 0.2,
|
||||
"probability_threshold": 0.32,
|
||||
}
|
||||
}
|
||||
|
||||
await async_setup_component(hass, "binary_sensor", config)
|
||||
await async_setup_component(hass, HA_DOMAIN, {})
|
||||
|
||||
await hass.async_block_till_done()
|
||||
|
||||
hass.states.async_set("sensor.test_monitored", "on")
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert hass.states.get("binary_sensor.test_binary").state == "off"
|
||||
|
||||
await hass.services.async_call(
|
||||
HA_DOMAIN,
|
||||
SERVICE_UPDATE_ENTITY,
|
||||
{ATTR_ENTITY_ID: "binary_sensor.test_binary"},
|
||||
blocking=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert hass.states.get("binary_sensor.test_binary").state == "off"
|
||||
|
||||
|
||||
async def test_monitored_sensor_goes_away(hass):
|
||||
"""Test sensor on template platform observations that goes away."""
|
||||
config = {
|
||||
"binary_sensor": {
|
||||
"name": "Test_Binary",
|
||||
"platform": "bayesian",
|
||||
"observations": [
|
||||
{
|
||||
"platform": "state",
|
||||
"entity_id": "sensor.test_monitored",
|
||||
"to_state": "on",
|
||||
"prob_given_true": 0.9,
|
||||
"prob_given_false": 0.4,
|
||||
},
|
||||
],
|
||||
"prior": 0.2,
|
||||
"probability_threshold": 0.32,
|
||||
}
|
||||
}
|
||||
|
||||
await async_setup_component(hass, "binary_sensor", config)
|
||||
await async_setup_component(hass, HA_DOMAIN, {})
|
||||
|
||||
await hass.async_block_till_done()
|
||||
|
||||
hass.states.async_set("sensor.test_monitored", "on")
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert hass.states.get("binary_sensor.test_binary").state == "on"
|
||||
|
||||
hass.states.async_remove("sensor.test_monitored")
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert hass.states.get("binary_sensor.test_binary").state == "on"
|
||||
|
|
Loading…
Add table
Reference in a new issue