diff --git a/homeassistant/components/counter/reproduce_state.py b/homeassistant/components/counter/reproduce_state.py new file mode 100644 index 00000000000..ac5045d68e7 --- /dev/null +++ b/homeassistant/components/counter/reproduce_state.py @@ -0,0 +1,71 @@ +"""Reproduce an Counter state.""" +import asyncio +import logging +from typing import Iterable, Optional + +from homeassistant.const import ATTR_ENTITY_ID +from homeassistant.core import Context, State +from homeassistant.helpers.typing import HomeAssistantType + +from . import ( + ATTR_INITIAL, + ATTR_MAXIMUM, + ATTR_MINIMUM, + ATTR_STEP, + VALUE, + DOMAIN, + SERVICE_CONFIGURE, +) + +_LOGGER = logging.getLogger(__name__) + + +async def _async_reproduce_state( + hass: HomeAssistantType, state: State, context: Optional[Context] = None +) -> None: + """Reproduce a single state.""" + cur_state = hass.states.get(state.entity_id) + + if cur_state is None: + _LOGGER.warning("Unable to find entity %s", state.entity_id) + return + + if not state.state.isdigit(): + _LOGGER.warning( + "Invalid state specified for %s: %s", state.entity_id, state.state + ) + return + + # Return if we are already at the right state. + if ( + cur_state.state == state.state + and cur_state.attributes.get(ATTR_INITIAL) == state.attributes.get(ATTR_INITIAL) + and cur_state.attributes.get(ATTR_MAXIMUM) == state.attributes.get(ATTR_MAXIMUM) + and cur_state.attributes.get(ATTR_MINIMUM) == state.attributes.get(ATTR_MINIMUM) + and cur_state.attributes.get(ATTR_STEP) == state.attributes.get(ATTR_STEP) + ): + return + + service_data = {ATTR_ENTITY_ID: state.entity_id, VALUE: state.state} + service = SERVICE_CONFIGURE + if ATTR_INITIAL in state.attributes: + service_data[ATTR_INITIAL] = state.attributes[ATTR_INITIAL] + if ATTR_MAXIMUM in state.attributes: + service_data[ATTR_MAXIMUM] = state.attributes[ATTR_MAXIMUM] + if ATTR_MINIMUM in state.attributes: + service_data[ATTR_MINIMUM] = state.attributes[ATTR_MINIMUM] + if ATTR_STEP in state.attributes: + service_data[ATTR_STEP] = state.attributes[ATTR_STEP] + + await hass.services.async_call( + DOMAIN, service, service_data, context=context, blocking=True + ) + + +async def async_reproduce_states( + hass: HomeAssistantType, states: Iterable[State], context: Optional[Context] = None +) -> None: + """Reproduce Counter states.""" + await asyncio.gather( + *(_async_reproduce_state(hass, state, context) for state in states) + ) diff --git a/tests/components/counter/test_reproduce_state.py b/tests/components/counter/test_reproduce_state.py new file mode 100644 index 00000000000..aa2c5ddbd9a --- /dev/null +++ b/tests/components/counter/test_reproduce_state.py @@ -0,0 +1,71 @@ +"""Test reproduce state for Counter.""" +from homeassistant.core import State + +from tests.common import async_mock_service + + +async def test_reproducing_states(hass, caplog): + """Test reproducing Counter states.""" + hass.states.async_set("counter.entity", "5", {}) + hass.states.async_set( + "counter.entity_attr", + "8", + {"initial": 12, "minimum": 5, "maximum": 15, "step": 3}, + ) + + configure_calls = async_mock_service(hass, "counter", "configure") + + # These calls should do nothing as entities already in desired state + await hass.helpers.state.async_reproduce_state( + [ + State("counter.entity", "5"), + State( + "counter.entity_attr", + "8", + {"initial": 12, "minimum": 5, "maximum": 15, "step": 3}, + ), + ], + blocking=True, + ) + + assert len(configure_calls) == 0 + + # Test invalid state is handled + await hass.helpers.state.async_reproduce_state( + [State("counter.entity", "not_supported")], blocking=True + ) + + assert "not_supported" in caplog.text + assert len(configure_calls) == 0 + + # Make sure correct services are called + await hass.helpers.state.async_reproduce_state( + [ + State("counter.entity", "2"), + State( + "counter.entity_attr", + "7", + {"initial": 10, "minimum": 3, "maximum": 21, "step": 5}, + ), + # Should not raise + State("counter.non_existing", "6"), + ], + blocking=True, + ) + + valid_calls = [ + {"entity_id": "counter.entity", "value": "2"}, + { + "entity_id": "counter.entity_attr", + "value": "7", + "initial": 10, + "minimum": 3, + "maximum": 21, + "step": 5, + }, + ] + assert len(configure_calls) == 2 + for call in configure_calls: + assert call.domain == "counter" + assert call.data in valid_calls + valid_calls.remove(call.data)