diff --git a/homeassistant/components/trend/binary_sensor.py b/homeassistant/components/trend/binary_sensor.py index 089e82b0f07..2d00f35202c 100644 --- a/homeassistant/components/trend/binary_sensor.py +++ b/homeassistant/components/trend/binary_sensor.py @@ -25,6 +25,7 @@ from homeassistant.const import ( CONF_ENTITY_ID, CONF_FRIENDLY_NAME, CONF_SENSORS, + STATE_ON, STATE_UNAVAILABLE, STATE_UNKNOWN, ) @@ -37,6 +38,7 @@ from homeassistant.helpers.event import ( async_track_state_change_event, ) from homeassistant.helpers.reload import async_setup_reload_service +from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, EventType from homeassistant.util.dt import utcnow @@ -116,7 +118,7 @@ async def async_setup_platform( async_add_entities(sensors) -class SensorTrend(BinarySensorEntity): +class SensorTrend(BinarySensorEntity, RestoreEntity): """Representation of a trend Sensor.""" _attr_should_poll = False @@ -194,6 +196,12 @@ class SensorTrend(BinarySensorEntity): ) ) + if not (state := await self.async_get_last_state()): + return + if state.state == STATE_UNKNOWN: + return + self._state = state.state == STATE_ON + async def async_update(self) -> None: """Get the latest data and update the states.""" # Remove outdated samples diff --git a/tests/components/trend/test_binary_sensor.py b/tests/components/trend/test_binary_sensor.py index c477b9a11fe..cccf1add61b 100644 --- a/tests/components/trend/test_binary_sensor.py +++ b/tests/components/trend/test_binary_sensor.py @@ -2,16 +2,19 @@ from datetime import timedelta from unittest.mock import patch +import pytest + from homeassistant import config as hass_config, setup from homeassistant.components.trend.const import DOMAIN from homeassistant.const import SERVICE_RELOAD, STATE_UNKNOWN -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, State import homeassistant.util.dt as dt_util from tests.common import ( assert_setup_component, get_fixture_path, get_test_home_assistant, + mock_restore_cache, ) @@ -413,3 +416,28 @@ async def test_reload(hass: HomeAssistant) -> None: assert hass.states.get("binary_sensor.test_trend_sensor") is None assert hass.states.get("binary_sensor.second_test_trend_sensor") + + +@pytest.mark.parametrize( + ("saved_state", "restored_state"), + [("on", "on"), ("off", "off"), ("unknown", "unknown")], +) +async def test_restore_state( + hass: HomeAssistant, saved_state: str, restored_state: str +) -> None: + """Test we restore the trend state.""" + mock_restore_cache(hass, (State("binary_sensor.test_trend_sensor", saved_state),)) + + assert await setup.async_setup_component( + hass, + "binary_sensor", + { + "binary_sensor": { + "platform": "trend", + "sensors": {"test_trend_sensor": {"entity_id": "sensor.test_state"}}, + } + }, + ) + await hass.async_block_till_done() + + assert hass.states.get("binary_sensor.test_trend_sensor").state == restored_state