From 5c3a4e3d10c5b0bfc0d5a10bfb64a4bfcc7aa62f Mon Sep 17 00:00:00 2001 From: Adam Mills Date: Wed, 28 Nov 2018 07:16:43 -0500 Subject: [PATCH] Restore states through a JSON store instead of recorder (#17270) * Restore states through a JSON store * Accept entity_id directly in restore state helper * Keep states stored between runs for a limited time * Remove warning --- .../components/alarm_control_panel/manual.py | 6 +- .../components/alarm_control_panel/mqtt.py | 3 +- .../components/automation/__init__.py | 8 +- .../components/binary_sensor/mqtt.py | 3 +- .../components/climate/generic_thermostat.py | 7 +- homeassistant/components/climate/mqtt.py | 3 +- homeassistant/components/counter/__init__.py | 8 +- homeassistant/components/cover/mqtt.py | 3 +- .../components/device_tracker/__init__.py | 8 +- homeassistant/components/fan/mqtt.py | 3 +- homeassistant/components/history.py | 14 - homeassistant/components/input_boolean.py | 7 +- homeassistant/components/input_datetime.py | 8 +- homeassistant/components/input_number.py | 8 +- homeassistant/components/input_select.py | 8 +- homeassistant/components/input_text.py | 8 +- .../components/light/limitlessled.py | 7 +- .../components/light/mqtt/schema_basic.py | 9 +- .../components/light/mqtt/schema_json.py | 10 +- .../components/light/mqtt/schema_template.py | 6 +- homeassistant/components/lock/mqtt.py | 3 +- homeassistant/components/mqtt/__init__.py | 3 + homeassistant/components/recorder/__init__.py | 7 - homeassistant/components/sensor/fastdotcom.py | 8 +- homeassistant/components/sensor/mqtt.py | 3 +- homeassistant/components/sensor/speedtest.py | 8 +- homeassistant/components/switch/mqtt.py | 11 +- homeassistant/components/switch/pilight.py | 7 +- homeassistant/components/timer/__init__.py | 7 +- homeassistant/helpers/entity.py | 11 +- homeassistant/helpers/entity_platform.py | 3 +- homeassistant/helpers/restore_state.py | 233 ++++++++----- homeassistant/helpers/storage.py | 21 +- homeassistant/util/json.py | 7 +- tests/common.py | 31 +- tests/components/emulated_hue/test_upnp.py | 32 +- tests/components/light/test_mqtt.py | 2 +- tests/components/light/test_mqtt_json.py | 2 +- tests/components/light/test_mqtt_template.py | 2 +- tests/components/recorder/test_migrate.py | 17 +- tests/components/switch/test_mqtt.py | 3 +- tests/components/test_history.py | 1 - tests/components/test_logbook.py | 2 - tests/helpers/test_restore_state.py | 315 +++++++++--------- tests/helpers/test_storage.py | 18 +- tests/util/test_json.py | 21 +- 46 files changed, 493 insertions(+), 422 deletions(-) diff --git a/homeassistant/components/alarm_control_panel/manual.py b/homeassistant/components/alarm_control_panel/manual.py index 362923a4ce2..0a79d74d686 100644 --- a/homeassistant/components/alarm_control_panel/manual.py +++ b/homeassistant/components/alarm_control_panel/manual.py @@ -21,7 +21,7 @@ from homeassistant.const import ( import homeassistant.helpers.config_validation as cv from homeassistant.helpers.event import track_point_in_time import homeassistant.util.dt as dt_util -from homeassistant.helpers.restore_state import async_get_last_state +from homeassistant.helpers.restore_state import RestoreEntity _LOGGER = logging.getLogger(__name__) @@ -116,7 +116,7 @@ def setup_platform(hass, config, add_entities, discovery_info=None): )]) -class ManualAlarm(alarm.AlarmControlPanel): +class ManualAlarm(alarm.AlarmControlPanel, RestoreEntity): """ Representation of an alarm status. @@ -310,7 +310,7 @@ class ManualAlarm(alarm.AlarmControlPanel): async def async_added_to_hass(self): """Run when entity about to be added to hass.""" - state = await async_get_last_state(self.hass, self.entity_id) + state = await self.async_get_last_state() if state: self._state = state.state self._state_ts = state.last_updated diff --git a/homeassistant/components/alarm_control_panel/mqtt.py b/homeassistant/components/alarm_control_panel/mqtt.py index 1b9bb020ead..5f0793ae58c 100644 --- a/homeassistant/components/alarm_control_panel/mqtt.py +++ b/homeassistant/components/alarm_control_panel/mqtt.py @@ -108,8 +108,7 @@ class MqttAlarm(MqttAvailability, MqttDiscoveryUpdate, async def async_added_to_hass(self): """Subscribe mqtt events.""" - await MqttAvailability.async_added_to_hass(self) - await MqttDiscoveryUpdate.async_added_to_hass(self) + await super().async_added_to_hass() await self._subscribe_topics() async def discovery_update(self, discovery_payload): diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index f8563071fbc..f44d044ecfa 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -21,7 +21,7 @@ from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import extract_domain_configs, script, condition from homeassistant.helpers.entity import ToggleEntity from homeassistant.helpers.entity_component import EntityComponent -from homeassistant.helpers.restore_state import async_get_last_state +from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.util.dt import utcnow import homeassistant.helpers.config_validation as cv @@ -182,7 +182,7 @@ async def async_setup(hass, config): return True -class AutomationEntity(ToggleEntity): +class AutomationEntity(ToggleEntity, RestoreEntity): """Entity to show status of entity.""" def __init__(self, automation_id, name, async_attach_triggers, cond_func, @@ -227,12 +227,13 @@ class AutomationEntity(ToggleEntity): async def async_added_to_hass(self) -> None: """Startup with initial state or previous state.""" + await super().async_added_to_hass() if self._initial_state is not None: enable_automation = self._initial_state _LOGGER.debug("Automation %s initial state %s from config " "initial_state", self.entity_id, enable_automation) else: - state = await async_get_last_state(self.hass, self.entity_id) + state = await self.async_get_last_state() if state: enable_automation = state.state == STATE_ON self._last_triggered = state.attributes.get('last_triggered') @@ -291,6 +292,7 @@ class AutomationEntity(ToggleEntity): async def async_will_remove_from_hass(self): """Remove listeners when removing automation from HASS.""" + await super().async_will_remove_from_hass() await self.async_turn_off() async def async_enable(self): diff --git a/homeassistant/components/binary_sensor/mqtt.py b/homeassistant/components/binary_sensor/mqtt.py index 4d7e2c07eba..acbad0d0419 100644 --- a/homeassistant/components/binary_sensor/mqtt.py +++ b/homeassistant/components/binary_sensor/mqtt.py @@ -102,8 +102,7 @@ class MqttBinarySensor(MqttAvailability, MqttDiscoveryUpdate, async def async_added_to_hass(self): """Subscribe mqtt events.""" - await MqttAvailability.async_added_to_hass(self) - await MqttDiscoveryUpdate.async_added_to_hass(self) + await super().async_added_to_hass() await self._subscribe_topics() async def discovery_update(self, discovery_payload): diff --git a/homeassistant/components/climate/generic_thermostat.py b/homeassistant/components/climate/generic_thermostat.py index 212c4265d8a..ffab50c989d 100644 --- a/homeassistant/components/climate/generic_thermostat.py +++ b/homeassistant/components/climate/generic_thermostat.py @@ -23,7 +23,7 @@ from homeassistant.helpers import condition from homeassistant.helpers.event import ( async_track_state_change, async_track_time_interval) import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.restore_state import async_get_last_state +from homeassistant.helpers.restore_state import RestoreEntity _LOGGER = logging.getLogger(__name__) @@ -96,7 +96,7 @@ async def async_setup_platform(hass, config, async_add_entities, precision)]) -class GenericThermostat(ClimateDevice): +class GenericThermostat(ClimateDevice, RestoreEntity): """Representation of a Generic Thermostat device.""" def __init__(self, hass, name, heater_entity_id, sensor_entity_id, @@ -155,8 +155,9 @@ class GenericThermostat(ClimateDevice): async def async_added_to_hass(self): """Run when entity about to be added.""" + await super().async_added_to_hass() # Check If we have an old state - old_state = await async_get_last_state(self.hass, self.entity_id) + old_state = await self.async_get_last_state() if old_state is not None: # If we have no initial temperature, restore if self._target_temp is None: diff --git a/homeassistant/components/climate/mqtt.py b/homeassistant/components/climate/mqtt.py index 7436ffc41ea..bccf282f055 100644 --- a/homeassistant/components/climate/mqtt.py +++ b/homeassistant/components/climate/mqtt.py @@ -221,8 +221,7 @@ class MqttClimate(MqttAvailability, MqttDiscoveryUpdate, ClimateDevice): async def async_added_to_hass(self): """Handle being added to home assistant.""" - await MqttAvailability.async_added_to_hass(self) - await MqttDiscoveryUpdate.async_added_to_hass(self) + await super().async_added_to_hass() await self._subscribe_topics() async def discovery_update(self, discovery_payload): diff --git a/homeassistant/components/counter/__init__.py b/homeassistant/components/counter/__init__.py index d67c93c0d6e..228870489a2 100644 --- a/homeassistant/components/counter/__init__.py +++ b/homeassistant/components/counter/__init__.py @@ -10,9 +10,8 @@ import voluptuous as vol from homeassistant.const import ATTR_ENTITY_ID, CONF_ICON, CONF_NAME import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity_component import EntityComponent -from homeassistant.helpers.restore_state import async_get_last_state +from homeassistant.helpers.restore_state import RestoreEntity _LOGGER = logging.getLogger(__name__) @@ -86,7 +85,7 @@ async def async_setup(hass, config): return True -class Counter(Entity): +class Counter(RestoreEntity): """Representation of a counter.""" def __init__(self, object_id, name, initial, restore, step, icon): @@ -128,10 +127,11 @@ class Counter(Entity): async def async_added_to_hass(self): """Call when entity about to be added to Home Assistant.""" + await super().async_added_to_hass() # __init__ will set self._state to self._initial, only override # if needed. if self._restore: - state = await async_get_last_state(self.hass, self.entity_id) + state = await self.async_get_last_state() if state is not None: self._state = int(state.state) diff --git a/homeassistant/components/cover/mqtt.py b/homeassistant/components/cover/mqtt.py index 92394fc026b..94e2b948c48 100644 --- a/homeassistant/components/cover/mqtt.py +++ b/homeassistant/components/cover/mqtt.py @@ -205,8 +205,7 @@ class MqttCover(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo, async def async_added_to_hass(self): """Subscribe MQTT events.""" - await MqttAvailability.async_added_to_hass(self) - await MqttDiscoveryUpdate.async_added_to_hass(self) + await super().async_added_to_hass() await self._subscribe_topics() async def discovery_update(self, discovery_payload): diff --git a/homeassistant/components/device_tracker/__init__.py b/homeassistant/components/device_tracker/__init__.py index a43a7c93bdc..35ecaf71616 100644 --- a/homeassistant/components/device_tracker/__init__.py +++ b/homeassistant/components/device_tracker/__init__.py @@ -22,9 +22,8 @@ from homeassistant.components.zone.zone import async_active_zone from homeassistant.config import load_yaml_config_file, async_log_exception from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import config_per_platform, discovery -from homeassistant.helpers.entity import Entity from homeassistant.helpers.event import async_track_time_interval -from homeassistant.helpers.restore_state import async_get_last_state +from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.typing import GPSType, ConfigType, HomeAssistantType import homeassistant.helpers.config_validation as cv from homeassistant import util @@ -396,7 +395,7 @@ class DeviceTracker: await asyncio.wait(tasks, loop=self.hass.loop) -class Device(Entity): +class Device(RestoreEntity): """Represent a tracked device.""" host_name = None # type: str @@ -564,7 +563,8 @@ class Device(Entity): async def async_added_to_hass(self): """Add an entity.""" - state = await async_get_last_state(self.hass, self.entity_id) + await super().async_added_to_hass() + state = await self.async_get_last_state() if not state: return self._state = state.state diff --git a/homeassistant/components/fan/mqtt.py b/homeassistant/components/fan/mqtt.py index 505a6e90720..75be8e0277c 100644 --- a/homeassistant/components/fan/mqtt.py +++ b/homeassistant/components/fan/mqtt.py @@ -151,8 +151,7 @@ class MqttFan(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo, async def async_added_to_hass(self): """Subscribe to MQTT events.""" - await MqttAvailability.async_added_to_hass(self) - await MqttDiscoveryUpdate.async_added_to_hass(self) + await super().async_added_to_hass() await self._subscribe_topics() async def discovery_update(self, discovery_payload): diff --git a/homeassistant/components/history.py b/homeassistant/components/history.py index 21d4cdc6e56..1773a55b3f1 100644 --- a/homeassistant/components/history.py +++ b/homeassistant/components/history.py @@ -38,20 +38,6 @@ SIGNIFICANT_DOMAINS = ('thermostat', 'climate') IGNORE_DOMAINS = ('zone', 'scene',) -def last_recorder_run(hass): - """Retrieve the last closed recorder run from the database.""" - from homeassistant.components.recorder.models import RecorderRuns - - with session_scope(hass=hass) as session: - res = (session.query(RecorderRuns) - .filter(RecorderRuns.end.isnot(None)) - .order_by(RecorderRuns.end.desc()).first()) - if res is None: - return None - session.expunge(res) - return res - - def get_significant_states(hass, start_time, end_time=None, entity_ids=None, filters=None, include_start_time_state=True): """ diff --git a/homeassistant/components/input_boolean.py b/homeassistant/components/input_boolean.py index 18c9808c6d2..541e38202fc 100644 --- a/homeassistant/components/input_boolean.py +++ b/homeassistant/components/input_boolean.py @@ -15,7 +15,7 @@ from homeassistant.loader import bind_hass import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity import ToggleEntity from homeassistant.helpers.entity_component import EntityComponent -from homeassistant.helpers.restore_state import async_get_last_state +from homeassistant.helpers.restore_state import RestoreEntity DOMAIN = 'input_boolean' @@ -84,7 +84,7 @@ async def async_setup(hass, config): return True -class InputBoolean(ToggleEntity): +class InputBoolean(ToggleEntity, RestoreEntity): """Representation of a boolean input.""" def __init__(self, object_id, name, initial, icon): @@ -117,10 +117,11 @@ class InputBoolean(ToggleEntity): async def async_added_to_hass(self): """Call when entity about to be added to hass.""" # If not None, we got an initial value. + await super().async_added_to_hass() if self._state is not None: return - state = await async_get_last_state(self.hass, self.entity_id) + state = await self.async_get_last_state() self._state = state and state.state == STATE_ON async def async_turn_on(self, **kwargs): diff --git a/homeassistant/components/input_datetime.py b/homeassistant/components/input_datetime.py index df35ae53ba9..6ac9a24d044 100644 --- a/homeassistant/components/input_datetime.py +++ b/homeassistant/components/input_datetime.py @@ -11,9 +11,8 @@ import voluptuous as vol from homeassistant.const import ATTR_ENTITY_ID, CONF_ICON, CONF_NAME import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity_component import EntityComponent -from homeassistant.helpers.restore_state import async_get_last_state +from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.util import dt as dt_util @@ -97,7 +96,7 @@ async def async_setup(hass, config): return True -class InputDatetime(Entity): +class InputDatetime(RestoreEntity): """Representation of a datetime input.""" def __init__(self, object_id, name, has_date, has_time, icon, initial): @@ -112,6 +111,7 @@ class InputDatetime(Entity): async def async_added_to_hass(self): """Run when entity about to be added.""" + await super().async_added_to_hass() restore_val = None # Priority 1: Initial State @@ -120,7 +120,7 @@ class InputDatetime(Entity): # Priority 2: Old state if restore_val is None: - old_state = await async_get_last_state(self.hass, self.entity_id) + old_state = await self.async_get_last_state() if old_state is not None: restore_val = old_state.state diff --git a/homeassistant/components/input_number.py b/homeassistant/components/input_number.py index f52b9add821..b6c6eab3cf5 100644 --- a/homeassistant/components/input_number.py +++ b/homeassistant/components/input_number.py @@ -11,9 +11,8 @@ import voluptuous as vol import homeassistant.helpers.config_validation as cv from homeassistant.const import ( ATTR_ENTITY_ID, ATTR_UNIT_OF_MEASUREMENT, CONF_ICON, CONF_NAME, CONF_MODE) -from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity_component import EntityComponent -from homeassistant.helpers.restore_state import async_get_last_state +from homeassistant.helpers.restore_state import RestoreEntity _LOGGER = logging.getLogger(__name__) @@ -123,7 +122,7 @@ async def async_setup(hass, config): return True -class InputNumber(Entity): +class InputNumber(RestoreEntity): """Representation of a slider.""" def __init__(self, object_id, name, initial, minimum, maximum, step, icon, @@ -178,10 +177,11 @@ class InputNumber(Entity): async def async_added_to_hass(self): """Run when entity about to be added to hass.""" + await super().async_added_to_hass() if self._current_value is not None: return - state = await async_get_last_state(self.hass, self.entity_id) + state = await self.async_get_last_state() value = state and float(state.state) # Check against None because value can be 0 diff --git a/homeassistant/components/input_select.py b/homeassistant/components/input_select.py index b8398e1be3d..cc9a73bf915 100644 --- a/homeassistant/components/input_select.py +++ b/homeassistant/components/input_select.py @@ -10,9 +10,8 @@ import voluptuous as vol from homeassistant.const import ATTR_ENTITY_ID, CONF_ICON, CONF_NAME import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity_component import EntityComponent -from homeassistant.helpers.restore_state import async_get_last_state +from homeassistant.helpers.restore_state import RestoreEntity _LOGGER = logging.getLogger(__name__) @@ -116,7 +115,7 @@ async def async_setup(hass, config): return True -class InputSelect(Entity): +class InputSelect(RestoreEntity): """Representation of a select input.""" def __init__(self, object_id, name, initial, options, icon): @@ -129,10 +128,11 @@ class InputSelect(Entity): async def async_added_to_hass(self): """Run when entity about to be added.""" + await super().async_added_to_hass() if self._current_option is not None: return - state = await async_get_last_state(self.hass, self.entity_id) + state = await self.async_get_last_state() if not state or state.state not in self._options: self._current_option = self._options[0] else: diff --git a/homeassistant/components/input_text.py b/homeassistant/components/input_text.py index 956d9a6466d..8ac64b398f4 100644 --- a/homeassistant/components/input_text.py +++ b/homeassistant/components/input_text.py @@ -11,9 +11,8 @@ import voluptuous as vol import homeassistant.helpers.config_validation as cv from homeassistant.const import ( ATTR_ENTITY_ID, ATTR_UNIT_OF_MEASUREMENT, CONF_ICON, CONF_NAME, CONF_MODE) -from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity_component import EntityComponent -from homeassistant.helpers.restore_state import async_get_last_state +from homeassistant.helpers.restore_state import RestoreEntity _LOGGER = logging.getLogger(__name__) @@ -104,7 +103,7 @@ async def async_setup(hass, config): return True -class InputText(Entity): +class InputText(RestoreEntity): """Represent a text box.""" def __init__(self, object_id, name, initial, minimum, maximum, icon, @@ -157,10 +156,11 @@ class InputText(Entity): async def async_added_to_hass(self): """Run when entity about to be added to hass.""" + await super().async_added_to_hass() if self._current_value is not None: return - state = await async_get_last_state(self.hass, self.entity_id) + state = await self.async_get_last_state() value = state and state.state # Check against None because value can be 0 diff --git a/homeassistant/components/light/limitlessled.py b/homeassistant/components/light/limitlessled.py index 2e2971cfdc2..3a0225d8d65 100644 --- a/homeassistant/components/light/limitlessled.py +++ b/homeassistant/components/light/limitlessled.py @@ -18,7 +18,7 @@ from homeassistant.components.light import ( import homeassistant.helpers.config_validation as cv from homeassistant.util.color import ( color_temperature_mired_to_kelvin, color_hs_to_RGB) -from homeassistant.helpers.restore_state import async_get_last_state +from homeassistant.helpers.restore_state import RestoreEntity REQUIREMENTS = ['limitlessled==1.1.3'] @@ -157,7 +157,7 @@ def state(new_state): return decorator -class LimitlessLEDGroup(Light): +class LimitlessLEDGroup(Light, RestoreEntity): """Representation of a LimitessLED group.""" def __init__(self, group, config): @@ -189,7 +189,8 @@ class LimitlessLEDGroup(Light): async def async_added_to_hass(self): """Handle entity about to be added to hass event.""" - last_state = await async_get_last_state(self.hass, self.entity_id) + await super().async_added_to_hass() + last_state = await self.async_get_last_state() if last_state: self._is_on = (last_state.state == STATE_ON) self._brightness = last_state.attributes.get('brightness') diff --git a/homeassistant/components/light/mqtt/schema_basic.py b/homeassistant/components/light/mqtt/schema_basic.py index 6c7b0e75301..6a151092ef0 100644 --- a/homeassistant/components/light/mqtt/schema_basic.py +++ b/homeassistant/components/light/mqtt/schema_basic.py @@ -22,7 +22,7 @@ from homeassistant.components.mqtt import ( CONF_AVAILABILITY_TOPIC, CONF_COMMAND_TOPIC, CONF_PAYLOAD_AVAILABLE, CONF_PAYLOAD_NOT_AVAILABLE, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC, MqttAvailability, MqttDiscoveryUpdate) -from homeassistant.helpers.restore_state import async_get_last_state +from homeassistant.helpers.restore_state import RestoreEntity import homeassistant.helpers.config_validation as cv import homeassistant.util.color as color_util @@ -166,7 +166,7 @@ async def async_setup_entity_basic(hass, config, async_add_entities, )]) -class MqttLight(MqttAvailability, MqttDiscoveryUpdate, Light): +class MqttLight(MqttAvailability, MqttDiscoveryUpdate, Light, RestoreEntity): """Representation of a MQTT light.""" def __init__(self, name, unique_id, effect_list, topic, templates, @@ -237,8 +237,7 @@ class MqttLight(MqttAvailability, MqttDiscoveryUpdate, Light): async def async_added_to_hass(self): """Subscribe to MQTT events.""" - await MqttAvailability.async_added_to_hass(self) - await MqttDiscoveryUpdate.async_added_to_hass(self) + await super().async_added_to_hass() templates = {} for key, tpl in list(self._templates.items()): @@ -248,7 +247,7 @@ class MqttLight(MqttAvailability, MqttDiscoveryUpdate, Light): tpl.hass = self.hass templates[key] = tpl.async_render_with_possible_json_value - last_state = await async_get_last_state(self.hass, self.entity_id) + last_state = await self.async_get_last_state() @callback def state_received(topic, payload, qos): diff --git a/homeassistant/components/light/mqtt/schema_json.py b/homeassistant/components/light/mqtt/schema_json.py index 43e0f655f0b..55df6cbfd5e 100644 --- a/homeassistant/components/light/mqtt/schema_json.py +++ b/homeassistant/components/light/mqtt/schema_json.py @@ -25,7 +25,7 @@ from homeassistant.const import ( CONF_RGB, CONF_WHITE_VALUE, CONF_XY, STATE_ON) from homeassistant.core import callback import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.restore_state import async_get_last_state +from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.typing import ConfigType, HomeAssistantType import homeassistant.util.color as color_util @@ -121,7 +121,8 @@ async def async_setup_entity_json(hass: HomeAssistantType, config: ConfigType, )]) -class MqttLightJson(MqttAvailability, MqttDiscoveryUpdate, Light): +class MqttLightJson(MqttAvailability, MqttDiscoveryUpdate, Light, + RestoreEntity): """Representation of a MQTT JSON light.""" def __init__(self, name, unique_id, effect_list, topic, qos, retain, @@ -183,10 +184,9 @@ class MqttLightJson(MqttAvailability, MqttDiscoveryUpdate, Light): async def async_added_to_hass(self): """Subscribe to MQTT events.""" - await MqttAvailability.async_added_to_hass(self) - await MqttDiscoveryUpdate.async_added_to_hass(self) + await super().async_added_to_hass() - last_state = await async_get_last_state(self.hass, self.entity_id) + last_state = await self.async_get_last_state() @callback def state_received(topic, payload, qos): diff --git a/homeassistant/components/light/mqtt/schema_template.py b/homeassistant/components/light/mqtt/schema_template.py index 082e4674cb9..81ef3e901dd 100644 --- a/homeassistant/components/light/mqtt/schema_template.py +++ b/homeassistant/components/light/mqtt/schema_template.py @@ -21,7 +21,7 @@ from homeassistant.components.mqtt import ( MqttAvailability) import homeassistant.helpers.config_validation as cv import homeassistant.util.color as color_util -from homeassistant.helpers.restore_state import async_get_last_state +from homeassistant.helpers.restore_state import RestoreEntity _LOGGER = logging.getLogger(__name__) @@ -102,7 +102,7 @@ async def async_setup_entity_template(hass, config, async_add_entities, )]) -class MqttTemplate(MqttAvailability, Light): +class MqttTemplate(MqttAvailability, Light, RestoreEntity): """Representation of a MQTT Template light.""" def __init__(self, hass, name, effect_list, topics, templates, optimistic, @@ -153,7 +153,7 @@ class MqttTemplate(MqttAvailability, Light): """Subscribe to MQTT events.""" await super().async_added_to_hass() - last_state = await async_get_last_state(self.hass, self.entity_id) + last_state = await self.async_get_last_state() @callback def state_received(topic, payload, qos): diff --git a/homeassistant/components/lock/mqtt.py b/homeassistant/components/lock/mqtt.py index b62382e6dd1..28849c88159 100644 --- a/homeassistant/components/lock/mqtt.py +++ b/homeassistant/components/lock/mqtt.py @@ -111,8 +111,7 @@ class MqttLock(MqttAvailability, MqttDiscoveryUpdate, LockDevice): async def async_added_to_hass(self): """Subscribe to MQTT events.""" - await MqttAvailability.async_added_to_hass(self) - await MqttDiscoveryUpdate.async_added_to_hass(self) + await super().async_added_to_hass() @callback def message_received(topic, payload, qos): diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 72684c7ec13..7ff32a79142 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -840,6 +840,7 @@ class MqttAvailability(Entity): This method must be run in the event loop and returns a coroutine. """ + await super().async_added_to_hass() await self._availability_subscribe_topics() async def availability_discovery_update(self, config: dict): @@ -900,6 +901,8 @@ class MqttDiscoveryUpdate(Entity): async def async_added_to_hass(self) -> None: """Subscribe to discovery updates.""" + await super().async_added_to_hass() + from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.components.mqtt.discovery import ( ALREADY_DISCOVERED, MQTT_DISCOVERY_UPDATED) diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index ddb508d1282..c53fa051a27 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -28,7 +28,6 @@ import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entityfilter import generate_filter from homeassistant.helpers.typing import ConfigType import homeassistant.util.dt as dt_util -from homeassistant.loader import bind_hass from . import migration, purge from .const import DATA_INSTANCE @@ -83,12 +82,6 @@ CONFIG_SCHEMA = vol.Schema({ }, extra=vol.ALLOW_EXTRA) -@bind_hass -async def wait_connection_ready(hass): - """Wait till the connection is ready.""" - return await hass.data[DATA_INSTANCE].async_db_ready - - def run_information(hass, point_in_time: Optional[datetime] = None): """Return information about current run. diff --git a/homeassistant/components/sensor/fastdotcom.py b/homeassistant/components/sensor/fastdotcom.py index 761dc7c6a00..8e975c48574 100644 --- a/homeassistant/components/sensor/fastdotcom.py +++ b/homeassistant/components/sensor/fastdotcom.py @@ -10,9 +10,8 @@ import voluptuous as vol from homeassistant.components.sensor import DOMAIN, PLATFORM_SCHEMA import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.entity import Entity from homeassistant.helpers.event import track_time_change -from homeassistant.helpers.restore_state import async_get_last_state +from homeassistant.helpers.restore_state import RestoreEntity import homeassistant.util.dt as dt_util REQUIREMENTS = ['fastdotcom==0.0.3'] @@ -51,7 +50,7 @@ def setup_platform(hass, config, add_entities, discovery_info=None): hass.services.register(DOMAIN, 'update_fastdotcom', update) -class SpeedtestSensor(Entity): +class SpeedtestSensor(RestoreEntity): """Implementation of a FAst.com sensor.""" def __init__(self, speedtest_data): @@ -86,7 +85,8 @@ class SpeedtestSensor(Entity): async def async_added_to_hass(self): """Handle entity which will be added.""" - state = await async_get_last_state(self.hass, self.entity_id) + await super().async_added_to_hass() + state = await self.async_get_last_state() if not state: return self._state = state.state diff --git a/homeassistant/components/sensor/mqtt.py b/homeassistant/components/sensor/mqtt.py index 68f49961cf9..bd97cc0e90d 100644 --- a/homeassistant/components/sensor/mqtt.py +++ b/homeassistant/components/sensor/mqtt.py @@ -119,8 +119,7 @@ class MqttSensor(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo, async def async_added_to_hass(self): """Subscribe to MQTT events.""" - await MqttAvailability.async_added_to_hass(self) - await MqttDiscoveryUpdate.async_added_to_hass(self) + await super().async_added_to_hass() await self._subscribe_topics() async def discovery_update(self, discovery_payload): diff --git a/homeassistant/components/sensor/speedtest.py b/homeassistant/components/sensor/speedtest.py index a08eec56e17..f834b51b064 100644 --- a/homeassistant/components/sensor/speedtest.py +++ b/homeassistant/components/sensor/speedtest.py @@ -11,9 +11,8 @@ import voluptuous as vol from homeassistant.components.sensor import DOMAIN, PLATFORM_SCHEMA from homeassistant.const import ATTR_ATTRIBUTION, CONF_MONITORED_CONDITIONS import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.entity import Entity from homeassistant.helpers.event import track_time_change -from homeassistant.helpers.restore_state import async_get_last_state +from homeassistant.helpers.restore_state import RestoreEntity import homeassistant.util.dt as dt_util REQUIREMENTS = ['speedtest-cli==2.0.2'] @@ -76,7 +75,7 @@ def setup_platform(hass, config, add_entities, discovery_info=None): hass.services.register(DOMAIN, 'update_speedtest', update) -class SpeedtestSensor(Entity): +class SpeedtestSensor(RestoreEntity): """Implementation of a speedtest.net sensor.""" def __init__(self, speedtest_data, sensor_type): @@ -137,7 +136,8 @@ class SpeedtestSensor(Entity): async def async_added_to_hass(self): """Handle all entity which are about to be added.""" - state = await async_get_last_state(self.hass, self.entity_id) + await super().async_added_to_hass() + state = await self.async_get_last_state() if not state: return self._state = state.state diff --git a/homeassistant/components/switch/mqtt.py b/homeassistant/components/switch/mqtt.py index ad2b963629e..250fe36b700 100644 --- a/homeassistant/components/switch/mqtt.py +++ b/homeassistant/components/switch/mqtt.py @@ -24,7 +24,7 @@ from homeassistant.components import mqtt, switch import homeassistant.helpers.config_validation as cv from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.typing import HomeAssistantType, ConfigType -from homeassistant.helpers.restore_state import async_get_last_state +from homeassistant.helpers.restore_state import RestoreEntity _LOGGER = logging.getLogger(__name__) @@ -102,8 +102,9 @@ async def _async_setup_entity(hass, config, async_add_entities, async_add_entities([newswitch]) +# pylint: disable=too-many-ancestors class MqttSwitch(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo, - SwitchDevice): + SwitchDevice, RestoreEntity): """Representation of a switch that can be toggled using MQTT.""" def __init__(self, name, icon, @@ -136,8 +137,7 @@ class MqttSwitch(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo, async def async_added_to_hass(self): """Subscribe to MQTT events.""" - await MqttAvailability.async_added_to_hass(self) - await MqttDiscoveryUpdate.async_added_to_hass(self) + await super().async_added_to_hass() @callback def state_message_received(topic, payload, qos): @@ -161,8 +161,7 @@ class MqttSwitch(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo, self._qos) if self._optimistic: - last_state = await async_get_last_state(self.hass, - self.entity_id) + last_state = await self.async_get_last_state() if last_state: self._state = last_state.state == STATE_ON diff --git a/homeassistant/components/switch/pilight.py b/homeassistant/components/switch/pilight.py index 16dfc075409..3bbe2e69110 100644 --- a/homeassistant/components/switch/pilight.py +++ b/homeassistant/components/switch/pilight.py @@ -13,7 +13,7 @@ from homeassistant.components import pilight from homeassistant.components.switch import (SwitchDevice, PLATFORM_SCHEMA) from homeassistant.const import (CONF_NAME, CONF_ID, CONF_SWITCHES, CONF_STATE, CONF_PROTOCOL, STATE_ON) -from homeassistant.helpers.restore_state import async_get_last_state +from homeassistant.helpers.restore_state import RestoreEntity _LOGGER = logging.getLogger(__name__) @@ -97,7 +97,7 @@ class _ReceiveHandle: switch.set_state(turn_on=turn_on, send_code=self.echo) -class PilightSwitch(SwitchDevice): +class PilightSwitch(SwitchDevice, RestoreEntity): """Representation of a Pilight switch.""" def __init__(self, hass, name, code_on, code_off, code_on_receive, @@ -123,7 +123,8 @@ class PilightSwitch(SwitchDevice): async def async_added_to_hass(self): """Call when entity about to be added to hass.""" - state = await async_get_last_state(self._hass, self.entity_id) + await super().async_added_to_hass() + state = await self.async_get_last_state() if state: self._state = state.state == STATE_ON diff --git a/homeassistant/components/timer/__init__.py b/homeassistant/components/timer/__init__.py index c29df9db858..3f758edea86 100644 --- a/homeassistant/components/timer/__init__.py +++ b/homeassistant/components/timer/__init__.py @@ -12,9 +12,9 @@ import voluptuous as vol import homeassistant.util.dt as dt_util import homeassistant.helpers.config_validation as cv from homeassistant.const import (ATTR_ENTITY_ID, CONF_ICON, CONF_NAME) -from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.event import async_track_point_in_utc_time +from homeassistant.helpers.restore_state import RestoreEntity _LOGGER = logging.getLogger(__name__) @@ -97,7 +97,7 @@ async def async_setup(hass, config): return True -class Timer(Entity): +class Timer(RestoreEntity): """Representation of a timer.""" def __init__(self, hass, object_id, name, icon, duration): @@ -146,8 +146,7 @@ class Timer(Entity): if self._state is not None: return - restore_state = self._hass.helpers.restore_state - state = await restore_state.async_get_last_state(self.entity_id) + state = await self.async_get_last_state() self._state = state and state.state == state async def async_start(self, duration): diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 687ed0b6f8b..2d4ad68dbbe 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -363,10 +363,7 @@ class Entity: async def async_remove(self): """Remove entity from Home Assistant.""" - will_remove = getattr(self, 'async_will_remove_from_hass', None) - - if will_remove: - await will_remove() # pylint: disable=not-callable + await self.async_will_remove_from_hass() if self._on_remove is not None: while self._on_remove: @@ -390,6 +387,12 @@ class Entity: self.hass.async_create_task(readd()) + async def async_added_to_hass(self) -> None: + """Run when entity about to be added to hass.""" + + async def async_will_remove_from_hass(self) -> None: + """Run when entity will be removed from hass.""" + def __eq__(self, other): """Return the comparison.""" if not isinstance(other, self.__class__): diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index ec7b5579342..ece0fbd071a 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -346,8 +346,7 @@ class EntityPlatform: self.entities[entity_id] = entity entity.async_on_remove(lambda: self.entities.pop(entity_id)) - if hasattr(entity, 'async_added_to_hass'): - await entity.async_added_to_hass() + await entity.async_added_to_hass() await entity.async_update_ha_state() diff --git a/homeassistant/helpers/restore_state.py b/homeassistant/helpers/restore_state.py index eb88a3db369..51f1bd76c2a 100644 --- a/homeassistant/helpers/restore_state.py +++ b/homeassistant/helpers/restore_state.py @@ -2,97 +2,174 @@ import asyncio import logging from datetime import timedelta +from typing import Any, Dict, List, Set, Optional # noqa pylint_disable=unused-import -import async_timeout - -from homeassistant.core import HomeAssistant, CoreState, callback -from homeassistant.const import EVENT_HOMEASSISTANT_START -from homeassistant.loader import bind_hass -from homeassistant.components.history import get_states, last_recorder_run -from homeassistant.components.recorder import ( - wait_connection_ready, DOMAIN as _RECORDER) +from homeassistant.core import HomeAssistant, callback, State, CoreState +from homeassistant.const import ( + EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP) import homeassistant.util.dt as dt_util +from homeassistant.helpers.entity import Entity +from homeassistant.helpers.event import async_track_time_interval +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers.json import JSONEncoder +from homeassistant.helpers.storage import Store # noqa pylint_disable=unused-import + +DATA_RESTORE_STATE_TASK = 'restore_state_task' -RECORDER_TIMEOUT = 10 -DATA_RESTORE_CACHE = 'restore_state_cache' -_LOCK = 'restore_lock' _LOGGER = logging.getLogger(__name__) +STORAGE_KEY = 'core.restore_state' +STORAGE_VERSION = 1 + +# How long between periodically saving the current states to disk +STATE_DUMP_INTERVAL = timedelta(minutes=15) + +# How long should a saved state be preserved if the entity no longer exists +STATE_EXPIRATION = timedelta(days=7) + + +class RestoreStateData(): + """Helper class for managing the helper saved data.""" + + @classmethod + async def async_get_instance( + cls, hass: HomeAssistant) -> 'RestoreStateData': + """Get the singleton instance of this data helper.""" + task = hass.data.get(DATA_RESTORE_STATE_TASK) + + if task is None: + async def load_instance(hass: HomeAssistant) -> 'RestoreStateData': + """Set up the restore state helper.""" + data = cls(hass) + + try: + states = await data.store.async_load() + except HomeAssistantError as exc: + _LOGGER.error("Error loading last states", exc_info=exc) + states = None + + if states is None: + _LOGGER.debug('Not creating cache - no saved states found') + data.last_states = {} + else: + data.last_states = { + state['entity_id']: State.from_dict(state) + for state in states} + _LOGGER.debug( + 'Created cache with %s', list(data.last_states)) + + if hass.state == CoreState.running: + data.async_setup_dump() + else: + hass.bus.async_listen_once( + EVENT_HOMEASSISTANT_START, data.async_setup_dump) + + return data + + task = hass.data[DATA_RESTORE_STATE_TASK] = hass.async_create_task( + load_instance(hass)) + + return await task + + def __init__(self, hass: HomeAssistant) -> None: + """Initialize the restore state data class.""" + self.hass = hass # type: HomeAssistant + self.store = Store(hass, STORAGE_VERSION, STORAGE_KEY, + encoder=JSONEncoder) # type: Store + self.last_states = {} # type: Dict[str, State] + self.entity_ids = set() # type: Set[str] + + def async_get_states(self) -> List[State]: + """Get the set of states which should be stored. + + This includes the states of all registered entities, as well as the + stored states from the previous run, which have not been created as + entities on this run, and have not expired. + """ + all_states = self.hass.states.async_all() + current_entity_ids = set(state.entity_id for state in all_states) + + # Start with the currently registered states + states = [state for state in all_states + if state.entity_id in self.entity_ids] + + expiration_time = dt_util.utcnow() - STATE_EXPIRATION + + for entity_id, state in self.last_states.items(): + # Don't save old states that have entities in the current run + if entity_id in current_entity_ids: + continue + + # Don't save old states that have expired + if state.last_updated < expiration_time: + continue + + states.append(state) + + return states + + async def async_dump_states(self) -> None: + """Save the current state machine to storage.""" + _LOGGER.debug("Dumping states") + try: + await self.store.async_save([ + state.as_dict() for state in self.async_get_states()]) + except HomeAssistantError as exc: + _LOGGER.error("Error saving current states", exc_info=exc) -def _load_restore_cache(hass: HomeAssistant): - """Load the restore cache to be used by other components.""" @callback - def remove_cache(event): - """Remove the states cache.""" - hass.data.pop(DATA_RESTORE_CACHE, None) + def async_setup_dump(self, *args: Any) -> None: + """Set up the restore state listeners.""" + # Dump the initial states now. This helps minimize the risk of having + # old states loaded by overwritting the last states once home assistant + # has started and the old states have been read. + self.hass.async_create_task(self.async_dump_states()) - hass.bus.listen_once(EVENT_HOMEASSISTANT_START, remove_cache) + # Dump states periodically + async_track_time_interval( + self.hass, lambda *_: self.hass.async_create_task( + self.async_dump_states()), STATE_DUMP_INTERVAL) - last_run = last_recorder_run(hass) + # Dump states when stopping hass + self.hass.bus.async_listen_once( + EVENT_HOMEASSISTANT_STOP, lambda *_: self.hass.async_create_task( + self.async_dump_states())) - if last_run is None or last_run.end is None: - _LOGGER.debug('Not creating cache - no suitable last run found: %s', - last_run) - hass.data[DATA_RESTORE_CACHE] = {} - return + @callback + def async_register_entity(self, entity_id: str) -> None: + """Store this entity's state when hass is shutdown.""" + self.entity_ids.add(entity_id) - last_end_time = last_run.end - timedelta(seconds=1) - # Unfortunately the recorder_run model do not return offset-aware time - last_end_time = last_end_time.replace(tzinfo=dt_util.UTC) - _LOGGER.debug("Last run: %s - %s", last_run.start, last_end_time) - - states = get_states(hass, last_end_time, run=last_run) - - # Cache the states - hass.data[DATA_RESTORE_CACHE] = { - state.entity_id: state for state in states} - _LOGGER.debug('Created cache with %s', list(hass.data[DATA_RESTORE_CACHE])) + @callback + def async_unregister_entity(self, entity_id: str) -> None: + """Unregister this entity from saving state.""" + self.entity_ids.remove(entity_id) -@bind_hass -async def async_get_last_state(hass, entity_id: str): - """Restore state.""" - if DATA_RESTORE_CACHE in hass.data: - return hass.data[DATA_RESTORE_CACHE].get(entity_id) +class RestoreEntity(Entity): + """Mixin class for restoring previous entity state.""" - if _RECORDER not in hass.config.components: - return None + async def async_added_to_hass(self) -> None: + """Register this entity as a restorable entity.""" + _, data = await asyncio.gather( + super().async_added_to_hass(), + RestoreStateData.async_get_instance(self.hass), + ) + data.async_register_entity(self.entity_id) - if hass.state not in (CoreState.starting, CoreState.not_running): - _LOGGER.debug("Cache for %s can only be loaded during startup, not %s", - entity_id, hass.state) - return None + async def async_will_remove_from_hass(self) -> None: + """Run when entity will be removed from hass.""" + _, data = await asyncio.gather( + super().async_will_remove_from_hass(), + RestoreStateData.async_get_instance(self.hass), + ) + data.async_unregister_entity(self.entity_id) - try: - with async_timeout.timeout(RECORDER_TIMEOUT, loop=hass.loop): - connected = await wait_connection_ready(hass) - except asyncio.TimeoutError: - return None - - if not connected: - return None - - if _LOCK not in hass.data: - hass.data[_LOCK] = asyncio.Lock(loop=hass.loop) - - async with hass.data[_LOCK]: - if DATA_RESTORE_CACHE not in hass.data: - await hass.async_add_job( - _load_restore_cache, hass) - - return hass.data.get(DATA_RESTORE_CACHE, {}).get(entity_id) - - -async def async_restore_state(entity, extract_info): - """Call entity.async_restore_state with cached info.""" - if entity.hass.state not in (CoreState.starting, CoreState.not_running): - _LOGGER.debug("Not restoring state for %s: Hass is not starting: %s", - entity.entity_id, entity.hass.state) - return - - state = await async_get_last_state(entity.hass, entity.entity_id) - - if not state: - return - - await entity.async_restore_state(**extract_info(state)) + async def async_get_last_state(self) -> Optional[State]: + """Get the entity state from the previous run.""" + if self.hass is None or self.entity_id is None: + # Return None if this entity isn't added to hass yet + _LOGGER.warning("Cannot get last state. Entity not added to hass") + return None + data = await RestoreStateData.async_get_instance(self.hass) + return data.last_states.get(self.entity_id) diff --git a/homeassistant/helpers/storage.py b/homeassistant/helpers/storage.py index cfe73d6d147..5fbb7700458 100644 --- a/homeassistant/helpers/storage.py +++ b/homeassistant/helpers/storage.py @@ -1,13 +1,14 @@ """Helper to help store data.""" import asyncio +from json import JSONEncoder import logging import os -from typing import Dict, Optional, Callable, Any +from typing import Dict, List, Optional, Callable, Union from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.core import callback from homeassistant.loader import bind_hass -from homeassistant.util import json +from homeassistant.util import json as json_util from homeassistant.helpers.event import async_call_later STORAGE_DIR = '.storage' @@ -16,7 +17,7 @@ _LOGGER = logging.getLogger(__name__) @bind_hass async def async_migrator(hass, old_path, store, *, - old_conf_load_func=json.load_json, + old_conf_load_func=json_util.load_json, old_conf_migrate_func=None): """Migrate old data to a store and then load data. @@ -46,7 +47,8 @@ async def async_migrator(hass, old_path, store, *, class Store: """Class to help storing data.""" - def __init__(self, hass, version: int, key: str, private: bool = False): + def __init__(self, hass, version: int, key: str, private: bool = False, *, + encoder: JSONEncoder = None): """Initialize storage class.""" self.version = version self.key = key @@ -57,13 +59,14 @@ class Store: self._unsub_stop_listener = None self._write_lock = asyncio.Lock(loop=hass.loop) self._load_task = None + self._encoder = encoder @property def path(self): """Return the config path.""" return self.hass.config.path(STORAGE_DIR, self.key) - async def async_load(self) -> Optional[Dict[str, Any]]: + async def async_load(self) -> Optional[Union[Dict, List]]: """Load data. If the expected version does not match the given version, the migrate @@ -88,7 +91,7 @@ class Store: data['data'] = data.pop('data_func')() else: data = await self.hass.async_add_executor_job( - json.load_json, self.path) + json_util.load_json, self.path) if data == {}: return None @@ -103,7 +106,7 @@ class Store: self._load_task = None return stored - async def async_save(self, data): + async def async_save(self, data: Union[Dict, List]) -> None: """Save data.""" self._data = { 'version': self.version, @@ -178,7 +181,7 @@ class Store: try: await self.hass.async_add_executor_job( self._write_data, self.path, data) - except (json.SerializationError, json.WriteError) as err: + except (json_util.SerializationError, json_util.WriteError) as err: _LOGGER.error('Error writing config for %s: %s', self.key, err) def _write_data(self, path: str, data: Dict): @@ -187,7 +190,7 @@ class Store: os.makedirs(os.path.dirname(path)) _LOGGER.debug('Writing data for %s', self.key) - json.save_json(path, data, self._private) + json_util.save_json(path, data, self._private, encoder=self._encoder) async def _async_migrate_func(self, old_version, old_data): """Migrate to the new version.""" diff --git a/homeassistant/util/json.py b/homeassistant/util/json.py index b002c8e3147..8ca1c702b6c 100644 --- a/homeassistant/util/json.py +++ b/homeassistant/util/json.py @@ -1,6 +1,6 @@ """JSON utility functions.""" import logging -from typing import Union, List, Dict +from typing import Union, List, Dict, Optional import json import os @@ -41,7 +41,8 @@ def load_json(filename: str, default: Union[List, Dict, None] = None) \ def save_json(filename: str, data: Union[List, Dict], - private: bool = False) -> None: + private: bool = False, *, + encoder: Optional[json.JSONEncoder] = None) -> None: """Save JSON data to a file. Returns True on success. @@ -49,7 +50,7 @@ def save_json(filename: str, data: Union[List, Dict], tmp_filename = "" tmp_path = os.path.split(filename)[0] try: - json_data = json.dumps(data, sort_keys=True, indent=4) + json_data = json.dumps(data, sort_keys=True, indent=4, cls=encoder) # Modern versions of Python tempfile create this file with mode 0o600 with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8', dir=tmp_path, delete=False) as fdesc: diff --git a/tests/common.py b/tests/common.py index d5056e220f0..86bc0643d65 100644 --- a/tests/common.py +++ b/tests/common.py @@ -114,8 +114,7 @@ def get_test_home_assistant(): # pylint: disable=protected-access -@asyncio.coroutine -def async_test_home_assistant(loop): +async def async_test_home_assistant(loop): """Return a Home Assistant object pointing at test config dir.""" hass = ha.HomeAssistant(loop) hass.config.async_load = Mock() @@ -168,13 +167,12 @@ def async_test_home_assistant(loop): # Mock async_start orig_start = hass.async_start - @asyncio.coroutine - def mock_async_start(): + async def mock_async_start(): """Start the mocking.""" # We only mock time during tests and we want to track tasks with patch('homeassistant.core._async_create_timer'), \ patch.object(hass, 'async_stop_track_tasks'): - yield from orig_start() + await orig_start() hass.async_start = mock_async_start @@ -715,14 +713,20 @@ def init_recorder_component(hass, add_config=None): def mock_restore_cache(hass, states): """Mock the DATA_RESTORE_CACHE.""" - key = restore_state.DATA_RESTORE_CACHE - hass.data[key] = { + key = restore_state.DATA_RESTORE_STATE_TASK + data = restore_state.RestoreStateData(hass) + + data.last_states = { state.entity_id: state for state in states} - _LOGGER.debug('Restore cache: %s', hass.data[key]) - assert len(hass.data[key]) == len(states), \ + _LOGGER.debug('Restore cache: %s', data.last_states) + assert len(data.last_states) == len(states), \ "Duplicate entity_id? {}".format(states) - hass.state = ha.CoreState.starting - mock_component(hass, recorder.DOMAIN) + + async def get_restore_state_data() -> restore_state.RestoreStateData: + return data + + # Patch the singleton task in hass.data to return our new RestoreStateData + hass.data[key] = hass.async_create_task(get_restore_state_data()) class MockDependency: @@ -846,9 +850,10 @@ def mock_storage(data=None): def mock_write_data(store, path, data_to_write): """Mock version of write data.""" - # To ensure that the data can be serialized _LOGGER.info('Writing data to %s: %s', store.key, data_to_write) - data[store.key] = json.loads(json.dumps(data_to_write)) + # To ensure that the data can be serialized + data[store.key] = json.loads(json.dumps( + data_to_write, cls=store._encoder)) with patch('homeassistant.helpers.storage.Store._async_load', side_effect=mock_async_load, autospec=True), \ diff --git a/tests/components/emulated_hue/test_upnp.py b/tests/components/emulated_hue/test_upnp.py index 9c549f00ee8..0a82dc3513d 100644 --- a/tests/components/emulated_hue/test_upnp.py +++ b/tests/components/emulated_hue/test_upnp.py @@ -6,10 +6,8 @@ from unittest.mock import patch import requests from aiohttp.hdrs import CONTENT_TYPE -from homeassistant import setup, const, core -import homeassistant.components as core_components +from homeassistant import setup, const from homeassistant.components import emulated_hue, http -from homeassistant.util.async_ import run_coroutine_threadsafe from tests.common import get_test_instance_port, get_test_home_assistant @@ -20,29 +18,6 @@ BRIDGE_URL_BASE = 'http://127.0.0.1:{}'.format(BRIDGE_SERVER_PORT) + '{}' JSON_HEADERS = {CONTENT_TYPE: const.CONTENT_TYPE_JSON} -def setup_hass_instance(emulated_hue_config): - """Set up the Home Assistant instance to test.""" - hass = get_test_home_assistant() - - # We need to do this to get access to homeassistant/turn_(on,off) - run_coroutine_threadsafe( - core_components.async_setup(hass, {core.DOMAIN: {}}), hass.loop - ).result() - - setup.setup_component( - hass, http.DOMAIN, - {http.DOMAIN: {http.CONF_SERVER_PORT: HTTP_SERVER_PORT}}) - - setup.setup_component(hass, emulated_hue.DOMAIN, emulated_hue_config) - - return hass - - -def start_hass_instance(hass): - """Start the Home Assistant instance to test.""" - hass.start() - - class TestEmulatedHue(unittest.TestCase): """Test the emulated Hue component.""" @@ -53,11 +28,6 @@ class TestEmulatedHue(unittest.TestCase): """Set up the class.""" cls.hass = hass = get_test_home_assistant() - # We need to do this to get access to homeassistant/turn_(on,off) - run_coroutine_threadsafe( - core_components.async_setup(hass, {core.DOMAIN: {}}), hass.loop - ).result() - setup.setup_component( hass, http.DOMAIN, {http.DOMAIN: {http.CONF_SERVER_PORT: HTTP_SERVER_PORT}}) diff --git a/tests/components/light/test_mqtt.py b/tests/components/light/test_mqtt.py index c56835afc9f..3b4ff586c94 100644 --- a/tests/components/light/test_mqtt.py +++ b/tests/components/light/test_mqtt.py @@ -585,7 +585,7 @@ async def test_sending_mqtt_commands_and_optimistic(hass, mqtt_mock): 'effect': 'random', 'color_temp': 100, 'white_value': 50}) - with patch('homeassistant.components.light.mqtt.schema_basic' + with patch('homeassistant.helpers.restore_state.RestoreEntity' '.async_get_last_state', return_value=mock_coro(fake_state)): with assert_setup_component(1, light.DOMAIN): diff --git a/tests/components/light/test_mqtt_json.py b/tests/components/light/test_mqtt_json.py index e509cd5718c..ae34cb6d827 100644 --- a/tests/components/light/test_mqtt_json.py +++ b/tests/components/light/test_mqtt_json.py @@ -279,7 +279,7 @@ async def test_sending_mqtt_commands_and_optimistic(hass, mqtt_mock): 'color_temp': 100, 'white_value': 50}) - with patch('homeassistant.components.light.mqtt.schema_json' + with patch('homeassistant.helpers.restore_state.RestoreEntity' '.async_get_last_state', return_value=mock_coro(fake_state)): assert await async_setup_component(hass, light.DOMAIN, { diff --git a/tests/components/light/test_mqtt_template.py b/tests/components/light/test_mqtt_template.py index 0d26d6edb12..56030da43f2 100644 --- a/tests/components/light/test_mqtt_template.py +++ b/tests/components/light/test_mqtt_template.py @@ -245,7 +245,7 @@ async def test_optimistic(hass, mqtt_mock): 'color_temp': 100, 'white_value': 50}) - with patch('homeassistant.components.light.mqtt.schema_template' + with patch('homeassistant.helpers.restore_state.RestoreEntity' '.async_get_last_state', return_value=mock_coro(fake_state)): with assert_setup_component(1, light.DOMAIN): diff --git a/tests/components/recorder/test_migrate.py b/tests/components/recorder/test_migrate.py index 93da4ec109b..d008f868466 100644 --- a/tests/components/recorder/test_migrate.py +++ b/tests/components/recorder/test_migrate.py @@ -1,6 +1,5 @@ """The tests for the Recorder component.""" # pylint: disable=protected-access -import asyncio from unittest.mock import patch, call import pytest @@ -9,7 +8,7 @@ from sqlalchemy.pool import StaticPool from homeassistant.bootstrap import async_setup_component from homeassistant.components.recorder import ( - wait_connection_ready, migration, const, models) + migration, const, models) from tests.components.recorder import models_original @@ -23,26 +22,24 @@ def create_engine_test(*args, **kwargs): return engine -@asyncio.coroutine -def test_schema_update_calls(hass): +async def test_schema_update_calls(hass): """Test that schema migrations occur in correct order.""" with patch('sqlalchemy.create_engine', new=create_engine_test), \ patch('homeassistant.components.recorder.migration._apply_update') as \ update: - yield from async_setup_component(hass, 'recorder', { + await async_setup_component(hass, 'recorder', { 'recorder': { 'db_url': 'sqlite://' } }) - yield from wait_connection_ready(hass) + await hass.async_block_till_done() update.assert_has_calls([ call(hass.data[const.DATA_INSTANCE].engine, version+1, 0) for version in range(0, models.SCHEMA_VERSION)]) -@asyncio.coroutine -def test_schema_migrate(hass): +async def test_schema_migrate(hass): """Test the full schema migration logic. We're just testing that the logic can execute successfully here without @@ -52,12 +49,12 @@ def test_schema_migrate(hass): with patch('sqlalchemy.create_engine', new=create_engine_test), \ patch('homeassistant.components.recorder.Recorder._setup_run') as \ setup_run: - yield from async_setup_component(hass, 'recorder', { + await async_setup_component(hass, 'recorder', { 'recorder': { 'db_url': 'sqlite://' } }) - yield from wait_connection_ready(hass) + await hass.async_block_till_done() assert setup_run.called diff --git a/tests/components/switch/test_mqtt.py b/tests/components/switch/test_mqtt.py index 4099a5b7951..5cfefd7a0c8 100644 --- a/tests/components/switch/test_mqtt.py +++ b/tests/components/switch/test_mqtt.py @@ -57,7 +57,8 @@ async def test_sending_mqtt_commands_and_optimistic(hass, mock_publish): """Test the sending MQTT commands in optimistic mode.""" fake_state = ha.State('switch.test', 'on') - with patch('homeassistant.components.switch.mqtt.async_get_last_state', + with patch('homeassistant.helpers.restore_state.RestoreEntity' + '.async_get_last_state', return_value=mock_coro(fake_state)): assert await async_setup_component(hass, switch.DOMAIN, { switch.DOMAIN: { diff --git a/tests/components/test_history.py b/tests/components/test_history.py index 641dff3b4e6..0c9062414e7 100644 --- a/tests/components/test_history.py +++ b/tests/components/test_history.py @@ -519,7 +519,6 @@ async def test_fetch_period_api(hass, hass_client): """Test the fetch period view for history.""" await hass.async_add_job(init_recorder_component, hass) await async_setup_component(hass, 'history', {}) - await hass.components.recorder.wait_connection_ready() await hass.async_add_job(hass.data[recorder.DATA_INSTANCE].block_till_done) client = await hass_client() response = await client.get( diff --git a/tests/components/test_logbook.py b/tests/components/test_logbook.py index ae1e3d1d51a..4619dc7ec2e 100644 --- a/tests/components/test_logbook.py +++ b/tests/components/test_logbook.py @@ -575,7 +575,6 @@ async def test_logbook_view(hass, aiohttp_client): """Test the logbook view.""" await hass.async_add_job(init_recorder_component, hass) await async_setup_component(hass, 'logbook', {}) - await hass.components.recorder.wait_connection_ready() await hass.async_add_job(hass.data[recorder.DATA_INSTANCE].block_till_done) client = await aiohttp_client(hass.http.app) response = await client.get( @@ -587,7 +586,6 @@ async def test_logbook_view_period_entity(hass, aiohttp_client): """Test the logbook view with period and entity.""" await hass.async_add_job(init_recorder_component, hass) await async_setup_component(hass, 'logbook', {}) - await hass.components.recorder.wait_connection_ready() await hass.async_add_job(hass.data[recorder.DATA_INSTANCE].block_till_done) entity_id_test = 'switch.test' diff --git a/tests/helpers/test_restore_state.py b/tests/helpers/test_restore_state.py index 15dda24a529..1ac48264d45 100644 --- a/tests/helpers/test_restore_state.py +++ b/tests/helpers/test_restore_state.py @@ -1,60 +1,52 @@ """The tests for the Restore component.""" -import asyncio -from datetime import timedelta -from unittest.mock import patch, MagicMock +from datetime import datetime -from homeassistant.setup import setup_component from homeassistant.const import EVENT_HOMEASSISTANT_START -from homeassistant.core import CoreState, split_entity_id, State -import homeassistant.util.dt as dt_util -from homeassistant.components import input_boolean, recorder +from homeassistant.core import CoreState, State +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers.entity import Entity from homeassistant.helpers.restore_state import ( - async_get_last_state, DATA_RESTORE_CACHE) -from homeassistant.components.recorder.models import RecorderRuns, States + RestoreStateData, RestoreEntity, DATA_RESTORE_STATE_TASK) +from homeassistant.util import dt as dt_util -from tests.common import ( - get_test_home_assistant, mock_coro, init_recorder_component, - mock_component) +from asynctest import patch + +from tests.common import mock_coro -@asyncio.coroutine -def test_caching_data(hass): +async def test_caching_data(hass): """Test that we cache data.""" - mock_component(hass, 'recorder') - hass.state = CoreState.starting - states = [ State('input_boolean.b0', 'on'), State('input_boolean.b1', 'on'), State('input_boolean.b2', 'on'), ] - with patch('homeassistant.helpers.restore_state.last_recorder_run', - return_value=MagicMock(end=dt_util.utcnow())), \ - patch('homeassistant.helpers.restore_state.get_states', - return_value=states), \ - patch('homeassistant.helpers.restore_state.wait_connection_ready', - return_value=mock_coro(True)): - state = yield from async_get_last_state(hass, 'input_boolean.b1') + data = await RestoreStateData.async_get_instance(hass) + await data.store.async_save([state.as_dict() for state in states]) - assert DATA_RESTORE_CACHE in hass.data - assert hass.data[DATA_RESTORE_CACHE] == {st.entity_id: st for st in states} + # Emulate a fresh load + hass.data[DATA_RESTORE_STATE_TASK] = None + + entity = RestoreEntity() + entity.hass = hass + entity.entity_id = 'input_boolean.b1' + + # Mock that only b1 is present this run + with patch('homeassistant.helpers.restore_state.Store.async_save' + ) as mock_write_data: + state = await entity.async_get_last_state() assert state is not None assert state.entity_id == 'input_boolean.b1' assert state.state == 'on' - hass.bus.async_fire(EVENT_HOMEASSISTANT_START) - - yield from hass.async_block_till_done() - - assert DATA_RESTORE_CACHE not in hass.data + assert mock_write_data.called -@asyncio.coroutine -def test_hass_running(hass): - """Test that cache cannot be accessed while hass is running.""" - mock_component(hass, 'recorder') +async def test_hass_starting(hass): + """Test that we cache data.""" + hass.state = CoreState.starting states = [ State('input_boolean.b0', 'on'), @@ -62,129 +54,144 @@ def test_hass_running(hass): State('input_boolean.b2', 'on'), ] - with patch('homeassistant.helpers.restore_state.last_recorder_run', - return_value=MagicMock(end=dt_util.utcnow())), \ - patch('homeassistant.helpers.restore_state.get_states', - return_value=states), \ - patch('homeassistant.helpers.restore_state.wait_connection_ready', - return_value=mock_coro(True)): - state = yield from async_get_last_state(hass, 'input_boolean.b1') - assert state is None + data = await RestoreStateData.async_get_instance(hass) + await data.store.async_save([state.as_dict() for state in states]) + # Emulate a fresh load + hass.data[DATA_RESTORE_STATE_TASK] = None -@asyncio.coroutine -def test_not_connected(hass): - """Test that cache cannot be accessed if db connection times out.""" - mock_component(hass, 'recorder') - hass.state = CoreState.starting + entity = RestoreEntity() + entity.hass = hass + entity.entity_id = 'input_boolean.b1' - states = [State('input_boolean.b1', 'on')] + # Mock that only b1 is present this run + states = [ + State('input_boolean.b1', 'on'), + ] + with patch('homeassistant.helpers.restore_state.Store.async_save' + ) as mock_write_data, patch.object( + hass.states, 'async_all', return_value=states): + state = await entity.async_get_last_state() - with patch('homeassistant.helpers.restore_state.last_recorder_run', - return_value=MagicMock(end=dt_util.utcnow())), \ - patch('homeassistant.helpers.restore_state.get_states', - return_value=states), \ - patch('homeassistant.helpers.restore_state.wait_connection_ready', - return_value=mock_coro(False)): - state = yield from async_get_last_state(hass, 'input_boolean.b1') - assert state is None - - -@asyncio.coroutine -def test_no_last_run_found(hass): - """Test that cache cannot be accessed if no last run found.""" - mock_component(hass, 'recorder') - hass.state = CoreState.starting - - states = [State('input_boolean.b1', 'on')] - - with patch('homeassistant.helpers.restore_state.last_recorder_run', - return_value=None), \ - patch('homeassistant.helpers.restore_state.get_states', - return_value=states), \ - patch('homeassistant.helpers.restore_state.wait_connection_ready', - return_value=mock_coro(True)): - state = yield from async_get_last_state(hass, 'input_boolean.b1') - assert state is None - - -@asyncio.coroutine -def test_cache_timeout(hass): - """Test that cache timeout returns none.""" - mock_component(hass, 'recorder') - hass.state = CoreState.starting - - states = [State('input_boolean.b1', 'on')] - - @asyncio.coroutine - def timeout_coro(): - raise asyncio.TimeoutError() - - with patch('homeassistant.helpers.restore_state.last_recorder_run', - return_value=MagicMock(end=dt_util.utcnow())), \ - patch('homeassistant.helpers.restore_state.get_states', - return_value=states), \ - patch('homeassistant.helpers.restore_state.wait_connection_ready', - return_value=timeout_coro()): - state = yield from async_get_last_state(hass, 'input_boolean.b1') - assert state is None - - -def _add_data_in_last_run(hass, entities): - """Add test data in the last recorder_run.""" - # pylint: disable=protected-access - t_now = dt_util.utcnow() - timedelta(minutes=10) - t_min_1 = t_now - timedelta(minutes=20) - t_min_2 = t_now - timedelta(minutes=30) - - with recorder.session_scope(hass=hass) as session: - session.add(RecorderRuns( - start=t_min_2, - end=t_now, - created=t_min_2 - )) - - for entity_id, state in entities.items(): - session.add(States( - entity_id=entity_id, - domain=split_entity_id(entity_id)[0], - state=state, - attributes='{}', - last_changed=t_min_1, - last_updated=t_min_1, - created=t_min_1)) - - -def test_filling_the_cache(): - """Test filling the cache from the DB.""" - test_entity_id1 = 'input_boolean.b1' - test_entity_id2 = 'input_boolean.b2' - - hass = get_test_home_assistant() - hass.state = CoreState.starting - - init_recorder_component(hass) - - _add_data_in_last_run(hass, { - test_entity_id1: 'on', - test_entity_id2: 'off', - }) - - hass.block_till_done() - setup_component(hass, input_boolean.DOMAIN, { - input_boolean.DOMAIN: { - 'b1': None, - 'b2': None, - }}) - - hass.start() - - state = hass.states.get('input_boolean.b1') - assert state + assert state is not None + assert state.entity_id == 'input_boolean.b1' assert state.state == 'on' - state = hass.states.get('input_boolean.b2') - assert state - assert state.state == 'off' + # Assert that no data was written yet, since hass is still starting. + assert not mock_write_data.called - hass.stop() + # Finish hass startup + with patch('homeassistant.helpers.restore_state.Store.async_save' + ) as mock_write_data: + hass.bus.async_fire(EVENT_HOMEASSISTANT_START) + await hass.async_block_till_done() + + # Assert that this session states were written + assert mock_write_data.called + + +async def test_dump_data(hass): + """Test that we cache data.""" + states = [ + State('input_boolean.b0', 'on'), + State('input_boolean.b1', 'on'), + State('input_boolean.b2', 'on'), + ] + + entity = Entity() + entity.hass = hass + entity.entity_id = 'input_boolean.b0' + await entity.async_added_to_hass() + + entity = RestoreEntity() + entity.hass = hass + entity.entity_id = 'input_boolean.b1' + await entity.async_added_to_hass() + + data = await RestoreStateData.async_get_instance(hass) + data.last_states = { + 'input_boolean.b0': State('input_boolean.b0', 'off'), + 'input_boolean.b1': State('input_boolean.b1', 'off'), + 'input_boolean.b2': State('input_boolean.b2', 'off'), + 'input_boolean.b3': State('input_boolean.b3', 'off'), + 'input_boolean.b4': State( + 'input_boolean.b4', 'off', last_updated=datetime( + 1985, 10, 26, 1, 22, tzinfo=dt_util.UTC)), + } + + with patch('homeassistant.helpers.restore_state.Store.async_save' + ) as mock_write_data, patch.object( + hass.states, 'async_all', return_value=states): + await data.async_dump_states() + + assert mock_write_data.called + args = mock_write_data.mock_calls[0][1] + written_states = args[0] + + # b0 should not be written, since it didn't extend RestoreEntity + # b1 should be written, since it is present in the current run + # b2 should not be written, since it is not registered with the helper + # b3 should be written, since it is still not expired + # b4 should not be written, since it is now expired + assert len(written_states) == 2 + assert written_states[0]['entity_id'] == 'input_boolean.b1' + assert written_states[0]['state'] == 'on' + assert written_states[1]['entity_id'] == 'input_boolean.b3' + assert written_states[1]['state'] == 'off' + + # Test that removed entities are not persisted + await entity.async_will_remove_from_hass() + + with patch('homeassistant.helpers.restore_state.Store.async_save' + ) as mock_write_data, patch.object( + hass.states, 'async_all', return_value=states): + await data.async_dump_states() + + assert mock_write_data.called + args = mock_write_data.mock_calls[0][1] + written_states = args[0] + assert len(written_states) == 1 + assert written_states[0]['entity_id'] == 'input_boolean.b3' + assert written_states[0]['state'] == 'off' + + +async def test_dump_error(hass): + """Test that we cache data.""" + states = [ + State('input_boolean.b0', 'on'), + State('input_boolean.b1', 'on'), + State('input_boolean.b2', 'on'), + ] + + entity = Entity() + entity.hass = hass + entity.entity_id = 'input_boolean.b0' + await entity.async_added_to_hass() + + entity = RestoreEntity() + entity.hass = hass + entity.entity_id = 'input_boolean.b1' + await entity.async_added_to_hass() + + data = await RestoreStateData.async_get_instance(hass) + + with patch('homeassistant.helpers.restore_state.Store.async_save', + return_value=mock_coro(exception=HomeAssistantError) + ) as mock_write_data, patch.object( + hass.states, 'async_all', return_value=states): + await data.async_dump_states() + + assert mock_write_data.called + + +async def test_load_error(hass): + """Test that we cache data.""" + entity = RestoreEntity() + entity.hass = hass + entity.entity_id = 'input_boolean.b1' + + with patch('homeassistant.helpers.storage.Store.async_load', + return_value=mock_coro(exception=HomeAssistantError)): + state = await entity.async_get_last_state() + + assert state is None diff --git a/tests/helpers/test_storage.py b/tests/helpers/test_storage.py index 38b8a7cd380..7c713082372 100644 --- a/tests/helpers/test_storage.py +++ b/tests/helpers/test_storage.py @@ -1,7 +1,8 @@ """Tests for the storage helper.""" import asyncio from datetime import timedelta -from unittest.mock import patch +import json +from unittest.mock import patch, Mock import pytest @@ -31,6 +32,21 @@ async def test_loading(hass, store): assert data == MOCK_DATA +async def test_custom_encoder(hass): + """Test we can save and load data.""" + class JSONEncoder(json.JSONEncoder): + """Mock JSON encoder.""" + + def default(self, o): + """Mock JSON encode method.""" + return "9" + + store = storage.Store(hass, MOCK_VERSION, MOCK_KEY, encoder=JSONEncoder) + await store.async_save(Mock()) + data = await store.async_load() + assert data == "9" + + async def test_loading_non_existing(hass, store): """Test we can save and load data.""" with patch('homeassistant.util.json.open', side_effect=FileNotFoundError): diff --git a/tests/util/test_json.py b/tests/util/test_json.py index 414a9f400aa..a7df74d9225 100644 --- a/tests/util/test_json.py +++ b/tests/util/test_json.py @@ -1,14 +1,17 @@ """Test Home Assistant json utility functions.""" +from json import JSONEncoder import os import unittest import sys from tempfile import mkdtemp -from homeassistant.util.json import (SerializationError, - load_json, save_json) +from homeassistant.util.json import ( + SerializationError, load_json, save_json) from homeassistant.exceptions import HomeAssistantError import pytest +from unittest.mock import Mock + # Test data that can be saved as JSON TEST_JSON_A = {"a": 1, "B": "two"} TEST_JSON_B = {"a": "one", "B": 2} @@ -74,3 +77,17 @@ class TestJSON(unittest.TestCase): fh.write(TEST_BAD_SERIALIED) with pytest.raises(HomeAssistantError): load_json(fname) + + def test_custom_encoder(self): + """Test serializing with a custom encoder.""" + class MockJSONEncoder(JSONEncoder): + """Mock JSON encoder.""" + + def default(self, o): + """Mock JSON encode method.""" + return "9" + + fname = self._path_for("test6") + save_json(fname, Mock(), encoder=MockJSONEncoder) + data = load_json(fname) + self.assertEqual(data, "9")