Convert bayesian binary_sensor to use async_track_template_result (#39174)

Add coverage to reach 100% line coverage
This commit is contained in:
J. Nick Koston 2020-08-23 02:59:26 -05:00 committed by GitHub
parent 42227c1c53
commit b68c5cec94
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 250 additions and 30 deletions

View file

@ -1,5 +1,6 @@
"""Use Bayesian Inference to trigger a binary sensor."""
from collections import OrderedDict
import logging
import voluptuous as vol
@ -16,9 +17,14 @@ from homeassistant.const import (
STATE_UNKNOWN,
)
from homeassistant.core import callback
from homeassistant.exceptions import TemplateError
from homeassistant.helpers import condition
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.event import async_track_state_change_event
from homeassistant.helpers.event import (
async_track_state_change_event,
async_track_template_result,
)
from homeassistant.helpers.template import result_as_boolean
ATTR_OBSERVATIONS = "observations"
ATTR_OCCURRED_OBSERVATION_ENTITIES = "occurred_observation_entities"
@ -36,6 +42,9 @@ CONF_TO_STATE = "to_state"
DEFAULT_NAME = "Bayesian Binary Sensor"
DEFAULT_PROBABILITY_THRESHOLD = 0.5
_LOGGER = logging.getLogger(__name__)
NUMERIC_STATE_SCHEMA = vol.Schema(
{
CONF_PLATFORM: "numeric_state",
@ -107,8 +116,7 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info=
BayesianBinarySensor(
name, prior, observations, probability_threshold, device_class
)
],
True,
]
)
@ -122,17 +130,19 @@ class BayesianBinarySensor(BinarySensorEntity):
self._probability_threshold = probability_threshold
self._device_class = device_class
self._deviation = False
self._callbacks = []
self.prior = prior
self.probability = prior
self.current_observations = OrderedDict({})
self.observations_by_entity = self._build_observations_by_entity()
self.observations_by_template = self._build_observations_by_template()
self.observation_handlers = {
"numeric_state": self._process_numeric_state,
"state": self._process_state,
"template": self._process_template,
}
async def async_added_to_hass(self):
@ -166,17 +176,61 @@ class BayesianBinarySensor(BinarySensorEntity):
entity = event.data.get("entity_id")
self.current_observations.update(self._record_entity_observations(entity))
self.probability = self._calculate_new_probability()
self._recalculate_and_write_state()
self.hass.async_add_job(self.async_update_ha_state, True)
self.current_observations.update(self._initialize_current_observations())
self.probability = self._calculate_new_probability()
self.async_on_remove(
async_track_state_change_event(
self.hass,
list(self.observations_by_entity),
async_threshold_sensor_state_listener,
)
)
@callback
def _async_template_result_changed(event, template, last_result, result):
entity = event and event.data.get("entity_id")
if isinstance(result, TemplateError):
_LOGGER.error(
"TemplateError('%s') "
"while processing template '%s' "
"in entity '%s'",
result,
template,
self.entity_id,
)
should_trigger = False
else:
should_trigger = result_as_boolean(result)
for obs in self.observations_by_template[template]:
if should_trigger:
obs_entry = {"entity_id": entity, **obs}
else:
obs_entry = None
self.current_observations[obs["id"]] = obs_entry
self._recalculate_and_write_state()
for template in self.observations_by_template:
info = async_track_template_result(
self.hass, template, _async_template_result_changed
)
self._callbacks.append(info)
self.async_on_remove(info.async_remove)
info.async_refresh()
self.current_observations.update(self._initialize_current_observations())
self.probability = self._calculate_new_probability()
self._deviation = bool(self.probability >= self._probability_threshold)
@callback
def _recalculate_and_write_state(self):
self.probability = self._calculate_new_probability()
self._deviation = bool(self.probability >= self._probability_threshold)
self.async_write_ha_state()
def _initialize_current_observations(self):
local_observations = OrderedDict({})
@ -186,9 +240,8 @@ class BayesianBinarySensor(BinarySensorEntity):
def _record_entity_observations(self, entity):
local_observations = OrderedDict({})
entity_obs_list = self.observations_by_entity[entity]
for entity_obs in entity_obs_list:
for entity_obs in self.observations_by_entity[entity]:
platform = entity_obs["platform"]
should_trigger = self.observation_handlers[platform](entity_obs)
@ -233,18 +286,42 @@ class BayesianBinarySensor(BinarySensorEntity):
for ind, obs in enumerate(self._observations):
obs["id"] = ind
if "entity_id" in obs:
if "entity_id" not in obs:
continue
entity_ids = [obs["entity_id"]]
elif "value_template" in obs:
entity_ids = obs.get(CONF_VALUE_TEMPLATE).extract_entities()
for e_id in entity_ids:
obs_list = observations_by_entity.get(e_id, [])
obs_list.append(obs)
observations_by_entity[e_id] = obs_list
observations_by_entity.setdefault(e_id, []).append(obs)
return observations_by_entity
def _build_observations_by_template(self):
"""
Build and return data structure of the form below.
{
"template": [{"id": 0, ...}, {"id": 1, ...}],
"template2": [{"id": 2, ...}],
...
}
Each "observation" must be recognized uniquely, and it should be possible
for all relevant observations to be looked up via their `template`.
"""
observations_by_template = {}
for ind, obs in enumerate(self._observations):
obs["id"] = ind
if "value_template" not in obs:
continue
template = obs.get(CONF_VALUE_TEMPLATE)
observations_by_template.setdefault(template, []).append(obs)
return observations_by_template
def _process_numeric_state(self, entity_observation):
"""Return True if numeric condition is met."""
entity = entity_observation["entity_id"]
@ -264,12 +341,6 @@ class BayesianBinarySensor(BinarySensorEntity):
return condition.state(self.hass, entity, entity_observation.get("to_state"))
def _process_template(self, entity_observation):
"""Return True if template condition is True."""
template = entity_observation.get(CONF_VALUE_TEMPLATE)
template.hass = self.hass
return condition.async_template(self.hass, template, entity_observation)
@property
def name(self):
"""Return the name of the sensor."""
@ -307,7 +378,7 @@ class BayesianBinarySensor(BinarySensorEntity):
{
obs.get("entity_id")
for obs in self.current_observations.values()
if obs is not None
if obs is not None and obs.get("entity_id") is not None
}
),
ATTR_PROBABILITY: round(self.probability, 2),
@ -316,4 +387,10 @@ class BayesianBinarySensor(BinarySensorEntity):
async def async_update(self):
"""Get the latest data and update the states."""
self._deviation = bool(self.probability >= self._probability_threshold)
if not self._callbacks:
self._recalculate_and_write_state()
return
# Force recalc of the templates. The states will
# update automatically.
for call in self._callbacks:
call.async_refresh()

View file

@ -3,8 +3,12 @@ import json
import unittest
from homeassistant.components.bayesian import binary_sensor as bayesian
from homeassistant.const import STATE_UNKNOWN
from homeassistant.setup import setup_component
from homeassistant.components.homeassistant import (
DOMAIN as HA_DOMAIN,
SERVICE_UPDATE_ENTITY,
)
from homeassistant.const import ATTR_ENTITY_ID, STATE_UNKNOWN
from homeassistant.setup import async_setup_component, setup_component
from tests.common import get_test_home_assistant
@ -488,3 +492,142 @@ class TestBayesianBinarySensor(unittest.TestCase):
for key, attrs in state.attributes.items():
json.dumps(attrs)
async def test_template_error(hass, caplog):
"""Test sensor with template error."""
config = {
"binary_sensor": {
"name": "Test_Binary",
"platform": "bayesian",
"observations": [
{
"platform": "template",
"value_template": "{{ xyz + 1 }}",
"prob_given_true": 0.9,
},
],
"prior": 0.2,
"probability_threshold": 0.32,
}
}
await async_setup_component(hass, "binary_sensor", config)
await hass.async_block_till_done()
assert hass.states.get("binary_sensor.test_binary").state == "off"
assert "TemplateError" in caplog.text
assert "xyz" in caplog.text
async def test_update_request_with_template(hass):
"""Test sensor on template platform observations that gets an update request."""
config = {
"binary_sensor": {
"name": "Test_Binary",
"platform": "bayesian",
"observations": [
{
"platform": "template",
"value_template": "{{states('sensor.test_monitored') == 'off'}}",
"prob_given_true": 0.8,
"prob_given_false": 0.4,
}
],
"prior": 0.2,
"probability_threshold": 0.32,
}
}
await async_setup_component(hass, "binary_sensor", config)
await async_setup_component(hass, HA_DOMAIN, {})
await hass.async_block_till_done()
assert hass.states.get("binary_sensor.test_binary").state == "off"
await hass.services.async_call(
HA_DOMAIN,
SERVICE_UPDATE_ENTITY,
{ATTR_ENTITY_ID: "binary_sensor.test_binary"},
blocking=True,
)
await hass.async_block_till_done()
assert hass.states.get("binary_sensor.test_binary").state == "off"
async def test_update_request_without_template(hass):
"""Test sensor on template platform observations that gets an update request."""
config = {
"binary_sensor": {
"name": "Test_Binary",
"platform": "bayesian",
"observations": [
{
"platform": "state",
"entity_id": "sensor.test_monitored",
"to_state": "off",
"prob_given_true": 0.9,
"prob_given_false": 0.4,
},
],
"prior": 0.2,
"probability_threshold": 0.32,
}
}
await async_setup_component(hass, "binary_sensor", config)
await async_setup_component(hass, HA_DOMAIN, {})
await hass.async_block_till_done()
hass.states.async_set("sensor.test_monitored", "on")
await hass.async_block_till_done()
assert hass.states.get("binary_sensor.test_binary").state == "off"
await hass.services.async_call(
HA_DOMAIN,
SERVICE_UPDATE_ENTITY,
{ATTR_ENTITY_ID: "binary_sensor.test_binary"},
blocking=True,
)
await hass.async_block_till_done()
assert hass.states.get("binary_sensor.test_binary").state == "off"
async def test_monitored_sensor_goes_away(hass):
"""Test sensor on template platform observations that goes away."""
config = {
"binary_sensor": {
"name": "Test_Binary",
"platform": "bayesian",
"observations": [
{
"platform": "state",
"entity_id": "sensor.test_monitored",
"to_state": "on",
"prob_given_true": 0.9,
"prob_given_false": 0.4,
},
],
"prior": 0.2,
"probability_threshold": 0.32,
}
}
await async_setup_component(hass, "binary_sensor", config)
await async_setup_component(hass, HA_DOMAIN, {})
await hass.async_block_till_done()
hass.states.async_set("sensor.test_monitored", "on")
await hass.async_block_till_done()
assert hass.states.get("binary_sensor.test_binary").state == "on"
hass.states.async_remove("sensor.test_monitored")
await hass.async_block_till_done()
assert hass.states.get("binary_sensor.test_binary").state == "on"