Convert bayesian binary_sensor to use async_track_template_result (#39174)
Add coverage to reach 100% line coverage
This commit is contained in:
parent
42227c1c53
commit
b68c5cec94
2 changed files with 250 additions and 30 deletions
|
@ -1,5 +1,6 @@
|
||||||
"""Use Bayesian Inference to trigger a binary sensor."""
|
"""Use Bayesian Inference to trigger a binary sensor."""
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
import logging
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
@ -16,9 +17,14 @@ from homeassistant.const import (
|
||||||
STATE_UNKNOWN,
|
STATE_UNKNOWN,
|
||||||
)
|
)
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
from homeassistant.exceptions import TemplateError
|
||||||
from homeassistant.helpers import condition
|
from homeassistant.helpers import condition
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
from homeassistant.helpers.event import async_track_state_change_event
|
from homeassistant.helpers.event import (
|
||||||
|
async_track_state_change_event,
|
||||||
|
async_track_template_result,
|
||||||
|
)
|
||||||
|
from homeassistant.helpers.template import result_as_boolean
|
||||||
|
|
||||||
ATTR_OBSERVATIONS = "observations"
|
ATTR_OBSERVATIONS = "observations"
|
||||||
ATTR_OCCURRED_OBSERVATION_ENTITIES = "occurred_observation_entities"
|
ATTR_OCCURRED_OBSERVATION_ENTITIES = "occurred_observation_entities"
|
||||||
|
@ -36,6 +42,9 @@ CONF_TO_STATE = "to_state"
|
||||||
DEFAULT_NAME = "Bayesian Binary Sensor"
|
DEFAULT_NAME = "Bayesian Binary Sensor"
|
||||||
DEFAULT_PROBABILITY_THRESHOLD = 0.5
|
DEFAULT_PROBABILITY_THRESHOLD = 0.5
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
NUMERIC_STATE_SCHEMA = vol.Schema(
|
NUMERIC_STATE_SCHEMA = vol.Schema(
|
||||||
{
|
{
|
||||||
CONF_PLATFORM: "numeric_state",
|
CONF_PLATFORM: "numeric_state",
|
||||||
|
@ -107,8 +116,7 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info=
|
||||||
BayesianBinarySensor(
|
BayesianBinarySensor(
|
||||||
name, prior, observations, probability_threshold, device_class
|
name, prior, observations, probability_threshold, device_class
|
||||||
)
|
)
|
||||||
],
|
]
|
||||||
True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -122,17 +130,19 @@ class BayesianBinarySensor(BinarySensorEntity):
|
||||||
self._probability_threshold = probability_threshold
|
self._probability_threshold = probability_threshold
|
||||||
self._device_class = device_class
|
self._device_class = device_class
|
||||||
self._deviation = False
|
self._deviation = False
|
||||||
|
self._callbacks = []
|
||||||
|
|
||||||
self.prior = prior
|
self.prior = prior
|
||||||
self.probability = prior
|
self.probability = prior
|
||||||
|
|
||||||
self.current_observations = OrderedDict({})
|
self.current_observations = OrderedDict({})
|
||||||
|
|
||||||
self.observations_by_entity = self._build_observations_by_entity()
|
self.observations_by_entity = self._build_observations_by_entity()
|
||||||
|
self.observations_by_template = self._build_observations_by_template()
|
||||||
|
|
||||||
self.observation_handlers = {
|
self.observation_handlers = {
|
||||||
"numeric_state": self._process_numeric_state,
|
"numeric_state": self._process_numeric_state,
|
||||||
"state": self._process_state,
|
"state": self._process_state,
|
||||||
"template": self._process_template,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async def async_added_to_hass(self):
|
async def async_added_to_hass(self):
|
||||||
|
@ -166,17 +176,61 @@ class BayesianBinarySensor(BinarySensorEntity):
|
||||||
entity = event.data.get("entity_id")
|
entity = event.data.get("entity_id")
|
||||||
|
|
||||||
self.current_observations.update(self._record_entity_observations(entity))
|
self.current_observations.update(self._record_entity_observations(entity))
|
||||||
self.probability = self._calculate_new_probability()
|
self._recalculate_and_write_state()
|
||||||
|
|
||||||
self.hass.async_add_job(self.async_update_ha_state, True)
|
self.async_on_remove(
|
||||||
|
async_track_state_change_event(
|
||||||
|
self.hass,
|
||||||
|
list(self.observations_by_entity),
|
||||||
|
async_threshold_sensor_state_listener,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_template_result_changed(event, template, last_result, result):
|
||||||
|
entity = event and event.data.get("entity_id")
|
||||||
|
|
||||||
|
if isinstance(result, TemplateError):
|
||||||
|
_LOGGER.error(
|
||||||
|
"TemplateError('%s') "
|
||||||
|
"while processing template '%s' "
|
||||||
|
"in entity '%s'",
|
||||||
|
result,
|
||||||
|
template,
|
||||||
|
self.entity_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
should_trigger = False
|
||||||
|
else:
|
||||||
|
should_trigger = result_as_boolean(result)
|
||||||
|
|
||||||
|
for obs in self.observations_by_template[template]:
|
||||||
|
if should_trigger:
|
||||||
|
obs_entry = {"entity_id": entity, **obs}
|
||||||
|
else:
|
||||||
|
obs_entry = None
|
||||||
|
self.current_observations[obs["id"]] = obs_entry
|
||||||
|
|
||||||
|
self._recalculate_and_write_state()
|
||||||
|
|
||||||
|
for template in self.observations_by_template:
|
||||||
|
info = async_track_template_result(
|
||||||
|
self.hass, template, _async_template_result_changed
|
||||||
|
)
|
||||||
|
|
||||||
|
self._callbacks.append(info)
|
||||||
|
self.async_on_remove(info.async_remove)
|
||||||
|
info.async_refresh()
|
||||||
|
|
||||||
self.current_observations.update(self._initialize_current_observations())
|
self.current_observations.update(self._initialize_current_observations())
|
||||||
self.probability = self._calculate_new_probability()
|
self.probability = self._calculate_new_probability()
|
||||||
async_track_state_change_event(
|
self._deviation = bool(self.probability >= self._probability_threshold)
|
||||||
self.hass,
|
|
||||||
list(self.observations_by_entity),
|
@callback
|
||||||
async_threshold_sensor_state_listener,
|
def _recalculate_and_write_state(self):
|
||||||
)
|
self.probability = self._calculate_new_probability()
|
||||||
|
self._deviation = bool(self.probability >= self._probability_threshold)
|
||||||
|
self.async_write_ha_state()
|
||||||
|
|
||||||
def _initialize_current_observations(self):
|
def _initialize_current_observations(self):
|
||||||
local_observations = OrderedDict({})
|
local_observations = OrderedDict({})
|
||||||
|
@ -186,9 +240,8 @@ class BayesianBinarySensor(BinarySensorEntity):
|
||||||
|
|
||||||
def _record_entity_observations(self, entity):
|
def _record_entity_observations(self, entity):
|
||||||
local_observations = OrderedDict({})
|
local_observations = OrderedDict({})
|
||||||
entity_obs_list = self.observations_by_entity[entity]
|
|
||||||
|
|
||||||
for entity_obs in entity_obs_list:
|
for entity_obs in self.observations_by_entity[entity]:
|
||||||
platform = entity_obs["platform"]
|
platform = entity_obs["platform"]
|
||||||
|
|
||||||
should_trigger = self.observation_handlers[platform](entity_obs)
|
should_trigger = self.observation_handlers[platform](entity_obs)
|
||||||
|
@ -233,18 +286,42 @@ class BayesianBinarySensor(BinarySensorEntity):
|
||||||
for ind, obs in enumerate(self._observations):
|
for ind, obs in enumerate(self._observations):
|
||||||
obs["id"] = ind
|
obs["id"] = ind
|
||||||
|
|
||||||
if "entity_id" in obs:
|
if "entity_id" not in obs:
|
||||||
entity_ids = [obs["entity_id"]]
|
continue
|
||||||
elif "value_template" in obs:
|
|
||||||
entity_ids = obs.get(CONF_VALUE_TEMPLATE).extract_entities()
|
entity_ids = [obs["entity_id"]]
|
||||||
|
|
||||||
for e_id in entity_ids:
|
for e_id in entity_ids:
|
||||||
obs_list = observations_by_entity.get(e_id, [])
|
observations_by_entity.setdefault(e_id, []).append(obs)
|
||||||
obs_list.append(obs)
|
|
||||||
observations_by_entity[e_id] = obs_list
|
|
||||||
|
|
||||||
return observations_by_entity
|
return observations_by_entity
|
||||||
|
|
||||||
|
def _build_observations_by_template(self):
|
||||||
|
"""
|
||||||
|
Build and return data structure of the form below.
|
||||||
|
|
||||||
|
{
|
||||||
|
"template": [{"id": 0, ...}, {"id": 1, ...}],
|
||||||
|
"template2": [{"id": 2, ...}],
|
||||||
|
...
|
||||||
|
}
|
||||||
|
|
||||||
|
Each "observation" must be recognized uniquely, and it should be possible
|
||||||
|
for all relevant observations to be looked up via their `template`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
observations_by_template = {}
|
||||||
|
for ind, obs in enumerate(self._observations):
|
||||||
|
obs["id"] = ind
|
||||||
|
|
||||||
|
if "value_template" not in obs:
|
||||||
|
continue
|
||||||
|
|
||||||
|
template = obs.get(CONF_VALUE_TEMPLATE)
|
||||||
|
observations_by_template.setdefault(template, []).append(obs)
|
||||||
|
|
||||||
|
return observations_by_template
|
||||||
|
|
||||||
def _process_numeric_state(self, entity_observation):
|
def _process_numeric_state(self, entity_observation):
|
||||||
"""Return True if numeric condition is met."""
|
"""Return True if numeric condition is met."""
|
||||||
entity = entity_observation["entity_id"]
|
entity = entity_observation["entity_id"]
|
||||||
|
@ -264,12 +341,6 @@ class BayesianBinarySensor(BinarySensorEntity):
|
||||||
|
|
||||||
return condition.state(self.hass, entity, entity_observation.get("to_state"))
|
return condition.state(self.hass, entity, entity_observation.get("to_state"))
|
||||||
|
|
||||||
def _process_template(self, entity_observation):
|
|
||||||
"""Return True if template condition is True."""
|
|
||||||
template = entity_observation.get(CONF_VALUE_TEMPLATE)
|
|
||||||
template.hass = self.hass
|
|
||||||
return condition.async_template(self.hass, template, entity_observation)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
"""Return the name of the sensor."""
|
"""Return the name of the sensor."""
|
||||||
|
@ -307,7 +378,7 @@ class BayesianBinarySensor(BinarySensorEntity):
|
||||||
{
|
{
|
||||||
obs.get("entity_id")
|
obs.get("entity_id")
|
||||||
for obs in self.current_observations.values()
|
for obs in self.current_observations.values()
|
||||||
if obs is not None
|
if obs is not None and obs.get("entity_id") is not None
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
ATTR_PROBABILITY: round(self.probability, 2),
|
ATTR_PROBABILITY: round(self.probability, 2),
|
||||||
|
@ -316,4 +387,10 @@ class BayesianBinarySensor(BinarySensorEntity):
|
||||||
|
|
||||||
async def async_update(self):
|
async def async_update(self):
|
||||||
"""Get the latest data and update the states."""
|
"""Get the latest data and update the states."""
|
||||||
self._deviation = bool(self.probability >= self._probability_threshold)
|
if not self._callbacks:
|
||||||
|
self._recalculate_and_write_state()
|
||||||
|
return
|
||||||
|
# Force recalc of the templates. The states will
|
||||||
|
# update automatically.
|
||||||
|
for call in self._callbacks:
|
||||||
|
call.async_refresh()
|
||||||
|
|
|
@ -3,8 +3,12 @@ import json
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from homeassistant.components.bayesian import binary_sensor as bayesian
|
from homeassistant.components.bayesian import binary_sensor as bayesian
|
||||||
from homeassistant.const import STATE_UNKNOWN
|
from homeassistant.components.homeassistant import (
|
||||||
from homeassistant.setup import setup_component
|
DOMAIN as HA_DOMAIN,
|
||||||
|
SERVICE_UPDATE_ENTITY,
|
||||||
|
)
|
||||||
|
from homeassistant.const import ATTR_ENTITY_ID, STATE_UNKNOWN
|
||||||
|
from homeassistant.setup import async_setup_component, setup_component
|
||||||
|
|
||||||
from tests.common import get_test_home_assistant
|
from tests.common import get_test_home_assistant
|
||||||
|
|
||||||
|
@ -488,3 +492,142 @@ class TestBayesianBinarySensor(unittest.TestCase):
|
||||||
|
|
||||||
for key, attrs in state.attributes.items():
|
for key, attrs in state.attributes.items():
|
||||||
json.dumps(attrs)
|
json.dumps(attrs)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_template_error(hass, caplog):
|
||||||
|
"""Test sensor with template error."""
|
||||||
|
config = {
|
||||||
|
"binary_sensor": {
|
||||||
|
"name": "Test_Binary",
|
||||||
|
"platform": "bayesian",
|
||||||
|
"observations": [
|
||||||
|
{
|
||||||
|
"platform": "template",
|
||||||
|
"value_template": "{{ xyz + 1 }}",
|
||||||
|
"prob_given_true": 0.9,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"prior": 0.2,
|
||||||
|
"probability_threshold": 0.32,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await async_setup_component(hass, "binary_sensor", config)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert hass.states.get("binary_sensor.test_binary").state == "off"
|
||||||
|
|
||||||
|
assert "TemplateError" in caplog.text
|
||||||
|
assert "xyz" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
async def test_update_request_with_template(hass):
|
||||||
|
"""Test sensor on template platform observations that gets an update request."""
|
||||||
|
config = {
|
||||||
|
"binary_sensor": {
|
||||||
|
"name": "Test_Binary",
|
||||||
|
"platform": "bayesian",
|
||||||
|
"observations": [
|
||||||
|
{
|
||||||
|
"platform": "template",
|
||||||
|
"value_template": "{{states('sensor.test_monitored') == 'off'}}",
|
||||||
|
"prob_given_true": 0.8,
|
||||||
|
"prob_given_false": 0.4,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"prior": 0.2,
|
||||||
|
"probability_threshold": 0.32,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await async_setup_component(hass, "binary_sensor", config)
|
||||||
|
await async_setup_component(hass, HA_DOMAIN, {})
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert hass.states.get("binary_sensor.test_binary").state == "off"
|
||||||
|
|
||||||
|
await hass.services.async_call(
|
||||||
|
HA_DOMAIN,
|
||||||
|
SERVICE_UPDATE_ENTITY,
|
||||||
|
{ATTR_ENTITY_ID: "binary_sensor.test_binary"},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert hass.states.get("binary_sensor.test_binary").state == "off"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_update_request_without_template(hass):
|
||||||
|
"""Test sensor on template platform observations that gets an update request."""
|
||||||
|
config = {
|
||||||
|
"binary_sensor": {
|
||||||
|
"name": "Test_Binary",
|
||||||
|
"platform": "bayesian",
|
||||||
|
"observations": [
|
||||||
|
{
|
||||||
|
"platform": "state",
|
||||||
|
"entity_id": "sensor.test_monitored",
|
||||||
|
"to_state": "off",
|
||||||
|
"prob_given_true": 0.9,
|
||||||
|
"prob_given_false": 0.4,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"prior": 0.2,
|
||||||
|
"probability_threshold": 0.32,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await async_setup_component(hass, "binary_sensor", config)
|
||||||
|
await async_setup_component(hass, HA_DOMAIN, {})
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
hass.states.async_set("sensor.test_monitored", "on")
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert hass.states.get("binary_sensor.test_binary").state == "off"
|
||||||
|
|
||||||
|
await hass.services.async_call(
|
||||||
|
HA_DOMAIN,
|
||||||
|
SERVICE_UPDATE_ENTITY,
|
||||||
|
{ATTR_ENTITY_ID: "binary_sensor.test_binary"},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert hass.states.get("binary_sensor.test_binary").state == "off"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_monitored_sensor_goes_away(hass):
|
||||||
|
"""Test sensor on template platform observations that goes away."""
|
||||||
|
config = {
|
||||||
|
"binary_sensor": {
|
||||||
|
"name": "Test_Binary",
|
||||||
|
"platform": "bayesian",
|
||||||
|
"observations": [
|
||||||
|
{
|
||||||
|
"platform": "state",
|
||||||
|
"entity_id": "sensor.test_monitored",
|
||||||
|
"to_state": "on",
|
||||||
|
"prob_given_true": 0.9,
|
||||||
|
"prob_given_false": 0.4,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"prior": 0.2,
|
||||||
|
"probability_threshold": 0.32,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await async_setup_component(hass, "binary_sensor", config)
|
||||||
|
await async_setup_component(hass, HA_DOMAIN, {})
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
hass.states.async_set("sensor.test_monitored", "on")
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert hass.states.get("binary_sensor.test_binary").state == "on"
|
||||||
|
|
||||||
|
hass.states.async_remove("sensor.test_monitored")
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert hass.states.get("binary_sensor.test_binary").state == "on"
|
||||||
|
|
Loading…
Add table
Reference in a new issue