diff --git a/homeassistant/components/bayesian/binary_sensor.py b/homeassistant/components/bayesian/binary_sensor.py index 1107a682039..86f11cda7e1 100644 --- a/homeassistant/components/bayesian/binary_sensor.py +++ b/homeassistant/components/bayesian/binary_sensor.py @@ -1,5 +1,6 @@ """Use Bayesian Inference to trigger a binary sensor.""" from collections import OrderedDict +import logging import voluptuous as vol @@ -16,9 +17,14 @@ from homeassistant.const import ( STATE_UNKNOWN, ) from homeassistant.core import callback +from homeassistant.exceptions import TemplateError from homeassistant.helpers import condition import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.event import async_track_state_change_event +from homeassistant.helpers.event import ( + async_track_state_change_event, + async_track_template_result, +) +from homeassistant.helpers.template import result_as_boolean ATTR_OBSERVATIONS = "observations" ATTR_OCCURRED_OBSERVATION_ENTITIES = "occurred_observation_entities" @@ -36,6 +42,9 @@ CONF_TO_STATE = "to_state" DEFAULT_NAME = "Bayesian Binary Sensor" DEFAULT_PROBABILITY_THRESHOLD = 0.5 +_LOGGER = logging.getLogger(__name__) + + NUMERIC_STATE_SCHEMA = vol.Schema( { CONF_PLATFORM: "numeric_state", @@ -107,8 +116,7 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info= BayesianBinarySensor( name, prior, observations, probability_threshold, device_class ) - ], - True, + ] ) @@ -122,17 +130,19 @@ class BayesianBinarySensor(BinarySensorEntity): self._probability_threshold = probability_threshold self._device_class = device_class self._deviation = False + self._callbacks = [] + self.prior = prior self.probability = prior self.current_observations = OrderedDict({}) self.observations_by_entity = self._build_observations_by_entity() + self.observations_by_template = self._build_observations_by_template() self.observation_handlers = { "numeric_state": self._process_numeric_state, "state": self._process_state, - "template": self._process_template, } async def async_added_to_hass(self): @@ -166,17 +176,61 @@ class BayesianBinarySensor(BinarySensorEntity): entity = event.data.get("entity_id") self.current_observations.update(self._record_entity_observations(entity)) - self.probability = self._calculate_new_probability() + self._recalculate_and_write_state() - self.hass.async_add_job(self.async_update_ha_state, True) + self.async_on_remove( + async_track_state_change_event( + self.hass, + list(self.observations_by_entity), + async_threshold_sensor_state_listener, + ) + ) + + @callback + def _async_template_result_changed(event, template, last_result, result): + entity = event and event.data.get("entity_id") + + if isinstance(result, TemplateError): + _LOGGER.error( + "TemplateError('%s') " + "while processing template '%s' " + "in entity '%s'", + result, + template, + self.entity_id, + ) + + should_trigger = False + else: + should_trigger = result_as_boolean(result) + + for obs in self.observations_by_template[template]: + if should_trigger: + obs_entry = {"entity_id": entity, **obs} + else: + obs_entry = None + self.current_observations[obs["id"]] = obs_entry + + self._recalculate_and_write_state() + + for template in self.observations_by_template: + info = async_track_template_result( + self.hass, template, _async_template_result_changed + ) + + self._callbacks.append(info) + self.async_on_remove(info.async_remove) + info.async_refresh() self.current_observations.update(self._initialize_current_observations()) self.probability = self._calculate_new_probability() - async_track_state_change_event( - self.hass, - list(self.observations_by_entity), - async_threshold_sensor_state_listener, - ) + self._deviation = bool(self.probability >= self._probability_threshold) + + @callback + def _recalculate_and_write_state(self): + self.probability = self._calculate_new_probability() + self._deviation = bool(self.probability >= self._probability_threshold) + self.async_write_ha_state() def _initialize_current_observations(self): local_observations = OrderedDict({}) @@ -186,9 +240,8 @@ class BayesianBinarySensor(BinarySensorEntity): def _record_entity_observations(self, entity): local_observations = OrderedDict({}) - entity_obs_list = self.observations_by_entity[entity] - for entity_obs in entity_obs_list: + for entity_obs in self.observations_by_entity[entity]: platform = entity_obs["platform"] should_trigger = self.observation_handlers[platform](entity_obs) @@ -233,18 +286,42 @@ class BayesianBinarySensor(BinarySensorEntity): for ind, obs in enumerate(self._observations): obs["id"] = ind - if "entity_id" in obs: - entity_ids = [obs["entity_id"]] - elif "value_template" in obs: - entity_ids = obs.get(CONF_VALUE_TEMPLATE).extract_entities() + if "entity_id" not in obs: + continue + + entity_ids = [obs["entity_id"]] for e_id in entity_ids: - obs_list = observations_by_entity.get(e_id, []) - obs_list.append(obs) - observations_by_entity[e_id] = obs_list + observations_by_entity.setdefault(e_id, []).append(obs) return observations_by_entity + def _build_observations_by_template(self): + """ + Build and return data structure of the form below. + + { + "template": [{"id": 0, ...}, {"id": 1, ...}], + "template2": [{"id": 2, ...}], + ... + } + + Each "observation" must be recognized uniquely, and it should be possible + for all relevant observations to be looked up via their `template`. + """ + + observations_by_template = {} + for ind, obs in enumerate(self._observations): + obs["id"] = ind + + if "value_template" not in obs: + continue + + template = obs.get(CONF_VALUE_TEMPLATE) + observations_by_template.setdefault(template, []).append(obs) + + return observations_by_template + def _process_numeric_state(self, entity_observation): """Return True if numeric condition is met.""" entity = entity_observation["entity_id"] @@ -264,12 +341,6 @@ class BayesianBinarySensor(BinarySensorEntity): return condition.state(self.hass, entity, entity_observation.get("to_state")) - def _process_template(self, entity_observation): - """Return True if template condition is True.""" - template = entity_observation.get(CONF_VALUE_TEMPLATE) - template.hass = self.hass - return condition.async_template(self.hass, template, entity_observation) - @property def name(self): """Return the name of the sensor.""" @@ -307,7 +378,7 @@ class BayesianBinarySensor(BinarySensorEntity): { obs.get("entity_id") for obs in self.current_observations.values() - if obs is not None + if obs is not None and obs.get("entity_id") is not None } ), ATTR_PROBABILITY: round(self.probability, 2), @@ -316,4 +387,10 @@ class BayesianBinarySensor(BinarySensorEntity): async def async_update(self): """Get the latest data and update the states.""" - self._deviation = bool(self.probability >= self._probability_threshold) + if not self._callbacks: + self._recalculate_and_write_state() + return + # Force recalc of the templates. The states will + # update automatically. + for call in self._callbacks: + call.async_refresh() diff --git a/tests/components/bayesian/test_binary_sensor.py b/tests/components/bayesian/test_binary_sensor.py index 7bbb9eeda27..9e4983ab4d5 100644 --- a/tests/components/bayesian/test_binary_sensor.py +++ b/tests/components/bayesian/test_binary_sensor.py @@ -3,8 +3,12 @@ import json import unittest from homeassistant.components.bayesian import binary_sensor as bayesian -from homeassistant.const import STATE_UNKNOWN -from homeassistant.setup import setup_component +from homeassistant.components.homeassistant import ( + DOMAIN as HA_DOMAIN, + SERVICE_UPDATE_ENTITY, +) +from homeassistant.const import ATTR_ENTITY_ID, STATE_UNKNOWN +from homeassistant.setup import async_setup_component, setup_component from tests.common import get_test_home_assistant @@ -488,3 +492,142 @@ class TestBayesianBinarySensor(unittest.TestCase): for key, attrs in state.attributes.items(): json.dumps(attrs) + + +async def test_template_error(hass, caplog): + """Test sensor with template error.""" + config = { + "binary_sensor": { + "name": "Test_Binary", + "platform": "bayesian", + "observations": [ + { + "platform": "template", + "value_template": "{{ xyz + 1 }}", + "prob_given_true": 0.9, + }, + ], + "prior": 0.2, + "probability_threshold": 0.32, + } + } + + await async_setup_component(hass, "binary_sensor", config) + await hass.async_block_till_done() + + assert hass.states.get("binary_sensor.test_binary").state == "off" + + assert "TemplateError" in caplog.text + assert "xyz" in caplog.text + + +async def test_update_request_with_template(hass): + """Test sensor on template platform observations that gets an update request.""" + config = { + "binary_sensor": { + "name": "Test_Binary", + "platform": "bayesian", + "observations": [ + { + "platform": "template", + "value_template": "{{states('sensor.test_monitored') == 'off'}}", + "prob_given_true": 0.8, + "prob_given_false": 0.4, + } + ], + "prior": 0.2, + "probability_threshold": 0.32, + } + } + + await async_setup_component(hass, "binary_sensor", config) + await async_setup_component(hass, HA_DOMAIN, {}) + + await hass.async_block_till_done() + + assert hass.states.get("binary_sensor.test_binary").state == "off" + + await hass.services.async_call( + HA_DOMAIN, + SERVICE_UPDATE_ENTITY, + {ATTR_ENTITY_ID: "binary_sensor.test_binary"}, + blocking=True, + ) + await hass.async_block_till_done() + assert hass.states.get("binary_sensor.test_binary").state == "off" + + +async def test_update_request_without_template(hass): + """Test sensor on template platform observations that gets an update request.""" + config = { + "binary_sensor": { + "name": "Test_Binary", + "platform": "bayesian", + "observations": [ + { + "platform": "state", + "entity_id": "sensor.test_monitored", + "to_state": "off", + "prob_given_true": 0.9, + "prob_given_false": 0.4, + }, + ], + "prior": 0.2, + "probability_threshold": 0.32, + } + } + + await async_setup_component(hass, "binary_sensor", config) + await async_setup_component(hass, HA_DOMAIN, {}) + + await hass.async_block_till_done() + + hass.states.async_set("sensor.test_monitored", "on") + await hass.async_block_till_done() + + assert hass.states.get("binary_sensor.test_binary").state == "off" + + await hass.services.async_call( + HA_DOMAIN, + SERVICE_UPDATE_ENTITY, + {ATTR_ENTITY_ID: "binary_sensor.test_binary"}, + blocking=True, + ) + await hass.async_block_till_done() + assert hass.states.get("binary_sensor.test_binary").state == "off" + + +async def test_monitored_sensor_goes_away(hass): + """Test sensor on template platform observations that goes away.""" + config = { + "binary_sensor": { + "name": "Test_Binary", + "platform": "bayesian", + "observations": [ + { + "platform": "state", + "entity_id": "sensor.test_monitored", + "to_state": "on", + "prob_given_true": 0.9, + "prob_given_false": 0.4, + }, + ], + "prior": 0.2, + "probability_threshold": 0.32, + } + } + + await async_setup_component(hass, "binary_sensor", config) + await async_setup_component(hass, HA_DOMAIN, {}) + + await hass.async_block_till_done() + + hass.states.async_set("sensor.test_monitored", "on") + await hass.async_block_till_done() + + assert hass.states.get("binary_sensor.test_binary").state == "on" + + hass.states.async_remove("sensor.test_monitored") + + await hass.async_block_till_done() + assert hass.states.get("binary_sensor.test_binary").state == "on"