Add support for multiple states/zones in conditions (#36835)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Franck Nijhof 2020-06-16 00:53:13 +02:00 committed by GitHub
parent 16cf16e418
commit 02f174e2e6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 96 additions and 9 deletions

View file

@ -5,7 +5,7 @@ from datetime import datetime, timedelta
import functools as ft import functools as ft
import logging import logging
import sys import sys
from typing import Callable, Container, Optional, Set, Union, cast from typing import Callable, Container, List, Optional, Set, Union, cast
from homeassistant.components import zone as zone_cmp from homeassistant.components import zone as zone_cmp
from homeassistant.components.device_automation import ( from homeassistant.components.device_automation import (
@ -263,7 +263,7 @@ def async_numeric_state_from_config(
def state( def state(
hass: HomeAssistant, hass: HomeAssistant,
entity: Union[None, str, State], entity: Union[None, str, State],
req_state: str, req_state: Union[str, List[str]],
for_period: Optional[timedelta] = None, for_period: Optional[timedelta] = None,
) -> bool: ) -> bool:
"""Test if state matches requirements. """Test if state matches requirements.
@ -277,7 +277,10 @@ def state(
return False return False
assert isinstance(entity, State) assert isinstance(entity, State)
is_state = entity.state == req_state if isinstance(req_state, str):
req_state = [req_state]
is_state = entity.state in req_state
if for_period is None or not is_state: if for_period is None or not is_state:
return is_state return is_state
@ -292,13 +295,16 @@ def state_from_config(
if config_validation: if config_validation:
config = cv.STATE_CONDITION_SCHEMA(config) config = cv.STATE_CONDITION_SCHEMA(config)
entity_ids = config.get(CONF_ENTITY_ID, []) entity_ids = config.get(CONF_ENTITY_ID, [])
req_state = cast(str, config.get(CONF_STATE)) req_states: Union[str, List[str]] = config.get(CONF_STATE, [])
for_period = config.get("for") for_period = config.get("for")
if not isinstance(req_states, list):
req_states = [req_states]
def if_state(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: def if_state(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
"""Test if condition.""" """Test if condition."""
return all( return all(
state(hass, entity_id, req_state, for_period) for entity_id in entity_ids state(hass, entity_id, req_states, for_period) for entity_id in entity_ids
) )
return if_state return if_state
@ -512,11 +518,17 @@ def zone_from_config(
if config_validation: if config_validation:
config = cv.ZONE_CONDITION_SCHEMA(config) config = cv.ZONE_CONDITION_SCHEMA(config)
entity_ids = config.get(CONF_ENTITY_ID, []) entity_ids = config.get(CONF_ENTITY_ID, [])
zone_entity_id = config.get(CONF_ZONE) zone_entity_ids = config.get(CONF_ZONE, [])
def if_in_zone(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: def if_in_zone(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
"""Test if condition.""" """Test if condition."""
return all(zone(hass, zone_entity_id, entity_id) for entity_id in entity_ids) return all(
any(
zone(hass, zone_entity_id, entity_id)
for zone_entity_id in zone_entity_ids
)
for entity_id in entity_ids
)
return if_in_zone return if_in_zone

View file

@ -858,7 +858,7 @@ STATE_CONDITION_SCHEMA = vol.All(
{ {
vol.Required(CONF_CONDITION): "state", vol.Required(CONF_CONDITION): "state",
vol.Required(CONF_ENTITY_ID): entity_ids, vol.Required(CONF_ENTITY_ID): entity_ids,
vol.Required(CONF_STATE): str, vol.Required(CONF_STATE): vol.Any(str, [str]),
vol.Optional(CONF_FOR): vol.All(time_period, positive_timedelta), vol.Optional(CONF_FOR): vol.All(time_period, positive_timedelta),
# To support use_trigger_value in automation # To support use_trigger_value in automation
# Deprecated 2016/04/25 # Deprecated 2016/04/25
@ -906,7 +906,7 @@ ZONE_CONDITION_SCHEMA = vol.Schema(
{ {
vol.Required(CONF_CONDITION): "zone", vol.Required(CONF_CONDITION): "zone",
vol.Required(CONF_ENTITY_ID): entity_ids, vol.Required(CONF_ENTITY_ID): entity_ids,
"zone": entity_id, "zone": entity_ids,
# To support use_trigger_value in automation # To support use_trigger_value in automation
# Deprecated 2016/04/25 # Deprecated 2016/04/25
vol.Optional("event"): vol.Any("enter", "leave"), vol.Optional("event"): vol.Any("enter", "leave"),

View file

@ -295,6 +295,32 @@ async def test_state_multiple_entities(hass):
assert not test(hass) assert not test(hass)
async def test_multiple_states(hass):
"""Test with multiple states in condition."""
test = await condition.async_from_config(
hass,
{
"condition": "and",
"conditions": [
{
"condition": "state",
"entity_id": "sensor.temperature",
"state": ["100", "200"],
},
],
},
)
hass.states.async_set("sensor.temperature", 100)
assert test(hass)
hass.states.async_set("sensor.temperature", 200)
assert test(hass)
hass.states.async_set("sensor.temperature", 42)
assert not test(hass)
async def test_numeric_state_multiple_entities(hass): async def test_numeric_state_multiple_entities(hass):
"""Test with multiple entities in condition.""" """Test with multiple entities in condition."""
test = await condition.async_from_config( test = await condition.async_from_config(
@ -383,6 +409,55 @@ async def test_zone_multiple_entities(hass):
assert not test(hass) assert not test(hass)
async def test_multiple_zones(hass):
"""Test with multiple entities in condition."""
test = await condition.async_from_config(
hass,
{
"condition": "and",
"conditions": [
{
"condition": "zone",
"entity_id": "device_tracker.person",
"zone": ["zone.home", "zone.work"],
},
],
},
)
hass.states.async_set(
"zone.home",
"zoning",
{"name": "home", "latitude": 2.1, "longitude": 1.1, "radius": 10},
)
hass.states.async_set(
"zone.work",
"zoning",
{"name": "work", "latitude": 20.1, "longitude": 10.1, "radius": 10},
)
hass.states.async_set(
"device_tracker.person",
"home",
{"friendly_name": "person", "latitude": 2.1, "longitude": 1.1},
)
assert test(hass)
hass.states.async_set(
"device_tracker.person",
"home",
{"friendly_name": "person", "latitude": 20.1, "longitude": 10.1},
)
assert test(hass)
hass.states.async_set(
"device_tracker.person",
"home",
{"friendly_name": "person", "latitude": 50.1, "longitude": 20.1},
)
assert not test(hass)
async def test_extract_entities(): async def test_extract_entities():
"""Test extracting entities.""" """Test extracting entities."""
assert condition.async_extract_entities( assert condition.async_extract_entities(