diff --git a/homeassistant/components/flux/switch.py b/homeassistant/components/flux/switch.py index 800ccd1938f..7b58ffbe449 100644 --- a/homeassistant/components/flux/switch.py +++ b/homeassistant/components/flux/switch.py @@ -31,10 +31,12 @@ from homeassistant.const import ( CONF_LIGHTS, CONF_MODE, SERVICE_TURN_ON, + STATE_ON, SUN_EVENT_SUNRISE, SUN_EVENT_SUNSET, ) from homeassistant.helpers.event import async_track_time_interval +from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.sun import get_astral_event_date from homeassistant.util import slugify from homeassistant.util.color import ( @@ -169,7 +171,7 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info= hass.services.async_register(DOMAIN, service_name, async_update) -class FluxSwitch(SwitchDevice): +class FluxSwitch(SwitchDevice, RestoreEntity): """Representation of a Flux switch.""" def __init__( @@ -214,6 +216,12 @@ class FluxSwitch(SwitchDevice): """Return true if switch is on.""" return self.unsub_tracker is not None + async def async_added_to_hass(self): + """Call when entity about to be added to hass.""" + last_state = await self.async_get_last_state() + if last_state and last_state.state == STATE_ON: + await self.async_turn_on() + async def async_turn_on(self, **kwargs): """Turn on flux.""" if self.is_on: diff --git a/tests/components/flux/test_switch.py b/tests/components/flux/test_switch.py index fb35485f5c9..91871666f46 100644 --- a/tests/components/flux/test_switch.py +++ b/tests/components/flux/test_switch.py @@ -10,12 +10,14 @@ from homeassistant.const import ( SERVICE_TURN_ON, SUN_EVENT_SUNRISE, ) +from homeassistant.core import State import homeassistant.util.dt as dt_util from tests.common import ( assert_setup_component, async_fire_time_changed, async_mock_service, + mock_restore_cache, ) from tests.components.light import common as common_light from tests.components.switch import common @@ -35,6 +37,52 @@ async def test_valid_config(hass): }, ) + state = hass.states.get("switch.flux") + assert state + assert state.state == "off" + + +async def test_restore_state_last_on(hass): + """Test restoring state when the last state is on.""" + mock_restore_cache(hass, [State("switch.flux", "on")]) + + assert await async_setup_component( + hass, + "switch", + { + "switch": { + "platform": "flux", + "name": "flux", + "lights": ["light.desk", "light.lamp"], + } + }, + ) + + state = hass.states.get("switch.flux") + assert state + assert state.state == "on" + + +async def test_restore_state_last_off(hass): + """Test restoring state when the last state is off.""" + mock_restore_cache(hass, [State("switch.flux", "off")]) + + assert await async_setup_component( + hass, + "switch", + { + "switch": { + "platform": "flux", + "name": "flux", + "lights": ["light.desk", "light.lamp"], + } + }, + ) + + state = hass.states.get("switch.flux") + assert state + assert state.state == "off" + async def test_valid_config_with_info(hass): """Test configuration."""