Add support for multiple states/zones in conditions (#36835)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Franck Nijhof 2020-06-16 00:53:13 +02:00 committed by GitHub
parent 16cf16e418
commit 02f174e2e6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 96 additions and 9 deletions

View file

@ -5,7 +5,7 @@ from datetime import datetime, timedelta
import functools as ft
import logging
import sys
from typing import Callable, Container, Optional, Set, Union, cast
from typing import Callable, Container, List, Optional, Set, Union, cast
from homeassistant.components import zone as zone_cmp
from homeassistant.components.device_automation import (
@ -263,7 +263,7 @@ def async_numeric_state_from_config(
def state(
hass: HomeAssistant,
entity: Union[None, str, State],
req_state: str,
req_state: Union[str, List[str]],
for_period: Optional[timedelta] = None,
) -> bool:
"""Test if state matches requirements.
@ -277,7 +277,10 @@ def state(
return False
assert isinstance(entity, State)
is_state = entity.state == req_state
if isinstance(req_state, str):
req_state = [req_state]
is_state = entity.state in req_state
if for_period is None or not is_state:
return is_state
@ -292,13 +295,16 @@ def state_from_config(
if config_validation:
config = cv.STATE_CONDITION_SCHEMA(config)
entity_ids = config.get(CONF_ENTITY_ID, [])
req_state = cast(str, config.get(CONF_STATE))
req_states: Union[str, List[str]] = config.get(CONF_STATE, [])
for_period = config.get("for")
if not isinstance(req_states, list):
req_states = [req_states]
def if_state(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
"""Test if condition."""
return all(
state(hass, entity_id, req_state, for_period) for entity_id in entity_ids
state(hass, entity_id, req_states, for_period) for entity_id in entity_ids
)
return if_state
@ -512,11 +518,17 @@ def zone_from_config(
if config_validation:
config = cv.ZONE_CONDITION_SCHEMA(config)
entity_ids = config.get(CONF_ENTITY_ID, [])
zone_entity_id = config.get(CONF_ZONE)
zone_entity_ids = config.get(CONF_ZONE, [])
def if_in_zone(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
"""Test if condition."""
return all(zone(hass, zone_entity_id, entity_id) for entity_id in entity_ids)
return all(
any(
zone(hass, zone_entity_id, entity_id)
for zone_entity_id in zone_entity_ids
)
for entity_id in entity_ids
)
return if_in_zone

View file

@ -858,7 +858,7 @@ STATE_CONDITION_SCHEMA = vol.All(
{
vol.Required(CONF_CONDITION): "state",
vol.Required(CONF_ENTITY_ID): entity_ids,
vol.Required(CONF_STATE): str,
vol.Required(CONF_STATE): vol.Any(str, [str]),
vol.Optional(CONF_FOR): vol.All(time_period, positive_timedelta),
# To support use_trigger_value in automation
# Deprecated 2016/04/25
@ -906,7 +906,7 @@ ZONE_CONDITION_SCHEMA = vol.Schema(
{
vol.Required(CONF_CONDITION): "zone",
vol.Required(CONF_ENTITY_ID): entity_ids,
"zone": entity_id,
"zone": entity_ids,
# To support use_trigger_value in automation
# Deprecated 2016/04/25
vol.Optional("event"): vol.Any("enter", "leave"),

View file

@ -295,6 +295,32 @@ async def test_state_multiple_entities(hass):
assert not test(hass)
async def test_multiple_states(hass):
"""Test with multiple states in condition."""
test = await condition.async_from_config(
hass,
{
"condition": "and",
"conditions": [
{
"condition": "state",
"entity_id": "sensor.temperature",
"state": ["100", "200"],
},
],
},
)
hass.states.async_set("sensor.temperature", 100)
assert test(hass)
hass.states.async_set("sensor.temperature", 200)
assert test(hass)
hass.states.async_set("sensor.temperature", 42)
assert not test(hass)
async def test_numeric_state_multiple_entities(hass):
"""Test with multiple entities in condition."""
test = await condition.async_from_config(
@ -383,6 +409,55 @@ async def test_zone_multiple_entities(hass):
assert not test(hass)
async def test_multiple_zones(hass):
"""Test with multiple entities in condition."""
test = await condition.async_from_config(
hass,
{
"condition": "and",
"conditions": [
{
"condition": "zone",
"entity_id": "device_tracker.person",
"zone": ["zone.home", "zone.work"],
},
],
},
)
hass.states.async_set(
"zone.home",
"zoning",
{"name": "home", "latitude": 2.1, "longitude": 1.1, "radius": 10},
)
hass.states.async_set(
"zone.work",
"zoning",
{"name": "work", "latitude": 20.1, "longitude": 10.1, "radius": 10},
)
hass.states.async_set(
"device_tracker.person",
"home",
{"friendly_name": "person", "latitude": 2.1, "longitude": 1.1},
)
assert test(hass)
hass.states.async_set(
"device_tracker.person",
"home",
{"friendly_name": "person", "latitude": 20.1, "longitude": 10.1},
)
assert test(hass)
hass.states.async_set(
"device_tracker.person",
"home",
{"friendly_name": "person", "latitude": 50.1, "longitude": 20.1},
)
assert not test(hass)
async def test_extract_entities():
"""Test extracting entities."""
assert condition.async_extract_entities(