From 9c96ec858a1506a60bcc7ee4031ba030efea530a Mon Sep 17 00:00:00 2001 From: Santobert Date: Fri, 4 Oct 2019 23:32:10 +0200 Subject: [PATCH] switch reproduce state (#27202) --- .../components/switch/reproduce_state.py | 61 +++++++++++++++++++ .../components/switch/test_reproduce_state.py | 50 +++++++++++++++ 2 files changed, 111 insertions(+) create mode 100644 homeassistant/components/switch/reproduce_state.py create mode 100644 tests/components/switch/test_reproduce_state.py diff --git a/homeassistant/components/switch/reproduce_state.py b/homeassistant/components/switch/reproduce_state.py new file mode 100644 index 00000000000..7ed1f70cb97 --- /dev/null +++ b/homeassistant/components/switch/reproduce_state.py @@ -0,0 +1,61 @@ +"""Reproduce an Switch state.""" +import asyncio +import logging +from typing import Iterable, Optional + +from homeassistant.const import ( + ATTR_ENTITY_ID, + STATE_ON, + STATE_OFF, + SERVICE_TURN_OFF, + SERVICE_TURN_ON, +) +from homeassistant.core import Context, State +from homeassistant.helpers.typing import HomeAssistantType + +from . import DOMAIN + +_LOGGER = logging.getLogger(__name__) + +VALID_STATES = {STATE_ON, STATE_OFF} + + +async def _async_reproduce_state( + hass: HomeAssistantType, state: State, context: Optional[Context] = None +) -> None: + """Reproduce a single state.""" + cur_state = hass.states.get(state.entity_id) + + if cur_state is None: + _LOGGER.warning("Unable to find entity %s", state.entity_id) + return + + if state.state not in VALID_STATES: + _LOGGER.warning( + "Invalid state specified for %s: %s", state.entity_id, state.state + ) + return + + # Return if we are already at the right state. + if cur_state.state == state.state: + return + + service_data = {ATTR_ENTITY_ID: state.entity_id} + + if state.state == STATE_ON: + service = SERVICE_TURN_ON + elif state.state == STATE_OFF: + service = SERVICE_TURN_OFF + + await hass.services.async_call( + DOMAIN, service, service_data, context=context, blocking=True + ) + + +async def async_reproduce_states( + hass: HomeAssistantType, states: Iterable[State], context: Optional[Context] = None +) -> None: + """Reproduce Switch states.""" + await asyncio.gather( + *(_async_reproduce_state(hass, state, context) for state in states) + ) diff --git a/tests/components/switch/test_reproduce_state.py b/tests/components/switch/test_reproduce_state.py new file mode 100644 index 00000000000..4b6db84bfdd --- /dev/null +++ b/tests/components/switch/test_reproduce_state.py @@ -0,0 +1,50 @@ +"""Test reproduce state for Switch.""" +from homeassistant.core import State + +from tests.common import async_mock_service + + +async def test_reproducing_states(hass, caplog): + """Test reproducing Switch states.""" + hass.states.async_set("switch.entity_off", "off", {}) + hass.states.async_set("switch.entity_on", "on", {}) + + turn_on_calls = async_mock_service(hass, "switch", "turn_on") + turn_off_calls = async_mock_service(hass, "switch", "turn_off") + + # These calls should do nothing as entities already in desired state + await hass.helpers.state.async_reproduce_state( + [State("switch.entity_off", "off"), State("switch.entity_on", "on", {})], + blocking=True, + ) + + assert len(turn_on_calls) == 0 + assert len(turn_off_calls) == 0 + + # Test invalid state is handled + await hass.helpers.state.async_reproduce_state( + [State("switch.entity_off", "not_supported")], blocking=True + ) + + assert "not_supported" in caplog.text + assert len(turn_on_calls) == 0 + assert len(turn_off_calls) == 0 + + # Make sure correct services are called + await hass.helpers.state.async_reproduce_state( + [ + State("switch.entity_on", "off"), + State("switch.entity_off", "on", {}), + # Should not raise + State("switch.non_existing", "on"), + ], + blocking=True, + ) + + assert len(turn_on_calls) == 1 + assert turn_on_calls[0].domain == "switch" + assert turn_on_calls[0].data == {"entity_id": "switch.entity_off"} + + assert len(turn_off_calls) == 1 + assert turn_off_calls[0].domain == "switch" + assert turn_off_calls[0].data == {"entity_id": "switch.entity_on"}