diff --git a/homeassistant/components/sensor/fitbit.py b/homeassistant/components/sensor/fitbit.py index f312d1f22cc..87bd735a03d 100644 --- a/homeassistant/components/sensor/fitbit.py +++ b/homeassistant/components/sensor/fitbit.py @@ -225,7 +225,7 @@ def setup_platform(hass, config, add_devices, discovery_info=None): hass, config, add_devices, config_path, discovery_info=None) return False else: - config_file = save_json(config_path, DEFAULT_CONFIG) + save_json(config_path, DEFAULT_CONFIG) request_app_setup( hass, config, add_devices, config_path, discovery_info=None) return False diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index db2912d7b42..13cb7de62ef 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -112,15 +112,13 @@ the flow from the config panel. """ import logging -import os import uuid -from . import data_entry_flow -from .core import callback -from .exceptions import HomeAssistantError -from .setup import async_setup_component, async_process_deps_reqs -from .util.json import load_json, save_json -from .util.decorator import Registry +from homeassistant import data_entry_flow +from homeassistant.core import callback +from homeassistant.exceptions import HomeAssistantError +from homeassistant.setup import async_setup_component, async_process_deps_reqs +from homeassistant.util.decorator import Registry _LOGGER = logging.getLogger(__name__) @@ -136,6 +134,10 @@ FLOWS = [ ] +STORAGE_KEY = 'core.config_entries' +STORAGE_VERSION = 1 + +# Deprecated since 0.73 PATH_CONFIG = '.config_entries.json' SAVE_DELAY = 1 @@ -271,7 +273,7 @@ class ConfigEntries: hass, self._async_create_flow, self._async_finish_flow) self._hass_config = hass_config self._entries = None - self._sched_save = None + self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) @callback def async_domains(self): @@ -305,7 +307,7 @@ class ConfigEntries: raise UnknownEntry entry = self._entries.pop(found) - self._async_schedule_save() + await self._async_schedule_save() unloaded = await entry.async_unload(self.hass) @@ -314,14 +316,14 @@ class ConfigEntries: } async def async_load(self): - """Load the config.""" - path = self.hass.config.path(PATH_CONFIG) - if not os.path.isfile(path): - self._entries = [] - return + """Handle loading the config.""" + # Migrating for config entries stored before 0.73 + config = await self.hass.helpers.storage.async_migrator( + self.hass.config.path(PATH_CONFIG), self._store, + old_conf_migrate_func=_old_conf_migrator + ) - entries = await self.hass.async_add_job(load_json, path) - self._entries = [ConfigEntry(**entry) for entry in entries] + self._entries = [ConfigEntry(**entry) for entry in config['entries']] async def async_forward_entry_setup(self, entry, component): """Forward the setup of an entry to a different component. @@ -372,7 +374,7 @@ class ConfigEntries: source=result['source'], ) self._entries.append(entry) - self._async_schedule_save() + await self._async_schedule_save() # Setup entry if entry.domain in self.hass.config.components: @@ -416,20 +418,14 @@ class ConfigEntries: return handler() - @callback - def _async_schedule_save(self): - """Schedule saving the entity registry.""" - if self._sched_save is not None: - self._sched_save.cancel() - - self._sched_save = self.hass.loop.call_later( - SAVE_DELAY, self.hass.async_add_job, self._async_save - ) - - async def _async_save(self): + async def _async_schedule_save(self): """Save the entity registry to a file.""" - self._sched_save = None - data = [entry.as_dict() for entry in self._entries] + data = { + 'entries': [entry.as_dict() for entry in self._entries] + } + await self._store.async_save(data, delay=SAVE_DELAY) - await self.hass.async_add_job( - save_json, self.hass.config.path(PATH_CONFIG), data) + +async def _old_conf_migrator(old_config): + """Migrate the pre-0.73 config format to the latest version.""" + return {'entries': old_config} diff --git a/homeassistant/core.py b/homeassistant/core.py index 5e6dcd81310..e0950172913 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -230,6 +230,20 @@ class HomeAssistant(object): return task + @callback + def async_add_executor_job( + self, + target: Callable[..., Any], + *args: Any) -> asyncio.tasks.Task: + """Add an executor job from within the event loop.""" + task = self.loop.run_in_executor(None, target, *args) + + # If a task is scheduled + if self._track_task: + self._pending_tasks.append(task) + + return task + @callback def async_track_tasks(self): """Track tasks so you can wait for all tasks to be done.""" diff --git a/homeassistant/helpers/storage.py b/homeassistant/helpers/storage.py new file mode 100644 index 00000000000..4b0c576f129 --- /dev/null +++ b/homeassistant/helpers/storage.py @@ -0,0 +1,157 @@ +"""Helper to help store data.""" +import asyncio +import logging +import os +from typing import Dict, Optional + +from homeassistant.const import EVENT_HOMEASSISTANT_STOP +from homeassistant.core import callback +from homeassistant.loader import bind_hass +from homeassistant.util import json +from homeassistant.helpers.event import async_call_later + +STORAGE_DIR = '.storage' +_LOGGER = logging.getLogger(__name__) + + +@bind_hass +async def async_migrator(hass, old_path, store, *, old_conf_migrate_func=None): + """Helper function to migrate old data to a store and then load data. + + async def old_conf_migrate_func(old_data) + """ + def load_old_config(): + """Helper to load old config.""" + if not os.path.isfile(old_path): + return None + + return json.load_json(old_path) + + config = await hass.async_add_executor_job(load_old_config) + + if config is None: + return await store.async_load() + + if old_conf_migrate_func is not None: + config = await old_conf_migrate_func(config) + + await store.async_save(config) + await hass.async_add_executor_job(os.remove, old_path) + return config + + +@bind_hass +class Store: + """Class to help storing data.""" + + def __init__(self, hass, version: int, key: str): + """Initialize storage class.""" + self.version = version + self.key = key + self.hass = hass + self._data = None + self._unsub_delay_listener = None + self._unsub_stop_listener = None + self._write_lock = asyncio.Lock() + + @property + def path(self): + """Return the config path.""" + return self.hass.config.path(STORAGE_DIR, self.key) + + async def async_load(self): + """Load data. + + If the expected version does not match the given version, the migrate + function will be invoked with await migrate_func(version, config). + """ + if self._data is not None: + data = self._data + else: + data = await self.hass.async_add_executor_job( + json.load_json, self.path, None) + + if data is None: + return {} + + if data['version'] == self.version: + return data['data'] + + return await self._async_migrate_func(data['version'], data['data']) + + async def async_save(self, data: Dict, *, delay: Optional[int] = None): + """Save data with an optional delay.""" + self._data = { + 'version': self.version, + 'key': self.key, + 'data': data, + } + + self._async_cleanup_delay_listener() + + if delay is None: + self._async_cleanup_stop_listener() + await self._async_handle_write_data() + return + + self._unsub_delay_listener = async_call_later( + self.hass, delay, self._async_callback_delayed_write) + + self._async_ensure_stop_listener() + + @callback + def _async_ensure_stop_listener(self): + """Ensure that we write if we quit before delay has passed.""" + if self._unsub_stop_listener is None: + self._unsub_stop_listener = self.hass.bus.async_listen_once( + EVENT_HOMEASSISTANT_STOP, self._async_callback_stop_write) + + @callback + def _async_cleanup_stop_listener(self): + """Clean up a stop listener.""" + if self._unsub_stop_listener is not None: + self._unsub_stop_listener() + self._unsub_stop_listener = None + + @callback + def _async_cleanup_delay_listener(self): + """Clean up a delay listener.""" + if self._unsub_delay_listener is not None: + self._unsub_delay_listener() + self._unsub_delay_listener = None + + async def _async_callback_delayed_write(self, _now): + """Handle a delayed write callback.""" + self._unsub_delay_listener = None + self._async_cleanup_stop_listener() + await self._async_handle_write_data() + + async def _async_callback_stop_write(self, _event): + """Handle a write because Home Assistant is stopping.""" + self._unsub_stop_listener = None + self._async_cleanup_delay_listener() + await self._async_handle_write_data() + + async def _async_handle_write_data(self, *_args): + """Handler to handle writing the config.""" + data = self._data + self._data = None + + async with self._write_lock: + try: + await self.hass.async_add_executor_job( + self._write_data, self.path, data) + except (json.SerializationError, json.WriteError) as err: + _LOGGER.error('Error writing config for %s: %s', self.key, err) + + def _write_data(self, path: str, data: Dict): + """Write the data.""" + if not os.path.isdir(os.path.dirname(path)): + os.makedirs(os.path.dirname(path)) + + _LOGGER.debug('Writing data for %s', self.key) + json.save_json(path, data) + + async def _async_migrate_func(self, old_version, old_data): + """Migrate to the new version.""" + raise NotImplementedError diff --git a/homeassistant/util/json.py b/homeassistant/util/json.py index b2577ff6be6..0e53342b0ca 100644 --- a/homeassistant/util/json.py +++ b/homeassistant/util/json.py @@ -11,6 +11,14 @@ _LOGGER = logging.getLogger(__name__) _UNDEFINED = object() +class SerializationError(HomeAssistantError): + """Error serializing the data to JSON.""" + + +class WriteError(HomeAssistantError): + """Error writing the data.""" + + def load_json(filename: str, default: Union[List, Dict] = _UNDEFINED) \ -> Union[List, Dict]: """Load JSON data from a file and return as dict or list. @@ -41,13 +49,11 @@ def save_json(filename: str, data: Union[List, Dict]): data = json.dumps(data, sort_keys=True, indent=4) with open(filename, 'w', encoding='utf-8') as fdesc: fdesc.write(data) - return True except TypeError as error: _LOGGER.exception('Failed to serialize to JSON: %s', filename) - raise HomeAssistantError(error) + raise SerializationError(error) except OSError as error: _LOGGER.exception('Saving JSON file failed: %s', filename) - raise HomeAssistantError(error) - return False + raise WriteError(error) diff --git a/tests/common.py b/tests/common.py index 556935a6ac1..56575bdb1e9 100644 --- a/tests/common.py +++ b/tests/common.py @@ -14,7 +14,7 @@ from homeassistant import auth, core as ha, data_entry_flow, config_entries from homeassistant.setup import setup_component, async_setup_component from homeassistant.config import async_process_component_config from homeassistant.helpers import ( - intent, entity, restore_state, entity_registry, + intent, entity, restore_state, entity_registry, entity_platform) from homeassistant.util.unit_system import METRIC_SYSTEM import homeassistant.util.dt as date_util @@ -110,8 +110,6 @@ def get_test_home_assistant(): def async_test_home_assistant(loop): """Return a Home Assistant object pointing at test config dir.""" hass = ha.HomeAssistant(loop) - hass.config_entries = config_entries.ConfigEntries(hass, {}) - hass.config_entries._entries = [] hass.config.async_load = Mock() store = auth.AuthStore(hass) hass.auth = auth.AuthManager(hass, store, {}) @@ -137,6 +135,10 @@ def async_test_home_assistant(loop): hass.config.units = METRIC_SYSTEM hass.config.skip_pip = True + hass.config_entries = config_entries.ConfigEntries(hass, {}) + hass.config_entries._entries = [] + hass.config_entries._store._async_ensure_stop_listener = lambda: None + hass.state = ha.CoreState.running # Mock async_start diff --git a/tests/helpers/test_storage.py b/tests/helpers/test_storage.py new file mode 100644 index 00000000000..289d07edab2 --- /dev/null +++ b/tests/helpers/test_storage.py @@ -0,0 +1,158 @@ +"""Tests for the storage helper.""" +from datetime import timedelta +from unittest.mock import patch + +import pytest + +from homeassistant.const import EVENT_HOMEASSISTANT_STOP +from homeassistant.helpers import storage +from homeassistant.util import dt + +from tests.common import async_fire_time_changed, mock_coro + + +MOCK_VERSION = 1 +MOCK_KEY = 'storage-test' +MOCK_DATA = {'hello': 'world'} + + +@pytest.fixture +def mock_save(): + """Fixture to mock JSON save.""" + written = [] + with patch('homeassistant.util.json.save_json', + side_effect=lambda *args: written.append(args)): + yield written + + +@pytest.fixture +def mock_load(mock_save): + """Fixture to mock JSON read.""" + with patch('homeassistant.util.json.load_json', + side_effect=lambda *args: mock_save[-1][1]): + yield + + +@pytest.fixture +def store(hass): + """Fixture of a store that prevents writing on HASS stop.""" + store = storage.Store(hass, MOCK_VERSION, MOCK_KEY) + store._async_ensure_stop_listener = lambda: None + yield store + + +async def test_loading(hass, store, mock_save, mock_load): + """Test we can save and load data.""" + await store.async_save(MOCK_DATA) + data = await store.async_load() + assert data == MOCK_DATA + + +async def test_loading_non_existing(hass, store): + """Test we can save and load data.""" + with patch('homeassistant.util.json.open', side_effect=FileNotFoundError): + data = await store.async_load() + assert data == {} + + +async def test_saving_with_delay(hass, store, mock_save): + """Test saving data after a delay.""" + await store.async_save(MOCK_DATA, delay=1) + assert len(mock_save) == 0 + + async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1)) + await hass.async_block_till_done() + assert len(mock_save) == 1 + + +async def test_saving_on_stop(hass, mock_save): + """Test delayed saves trigger when we quit Home Assistant.""" + store = storage.Store(hass, MOCK_VERSION, MOCK_KEY) + await store.async_save(MOCK_DATA, delay=1) + assert len(mock_save) == 0 + + hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) + await hass.async_block_till_done() + assert len(mock_save) == 1 + + +async def test_loading_while_delay(hass, store, mock_save, mock_load): + """Test we load new data even if not written yet.""" + await store.async_save({'delay': 'no'}) + assert len(mock_save) == 1 + + await store.async_save({'delay': 'yes'}, delay=1) + assert len(mock_save) == 1 + + data = await store.async_load() + assert data == {'delay': 'yes'} + + +async def test_writing_while_writing_delay(hass, store, mock_save, mock_load): + """Test a write while a write with delay is active.""" + await store.async_save({'delay': 'yes'}, delay=1) + assert len(mock_save) == 0 + await store.async_save({'delay': 'no'}) + assert len(mock_save) == 1 + + async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1)) + await hass.async_block_till_done() + assert len(mock_save) == 1 + + data = await store.async_load() + assert data == {'delay': 'no'} + + +async def test_migrator_no_existing_config(hass, store, mock_save): + """Test migrator with no existing config.""" + with patch('os.path.isfile', return_value=False), \ + patch.object(store, 'async_load', + return_value=mock_coro({'cur': 'config'})): + data = await storage.async_migrator( + hass, 'old-path', store) + + assert data == {'cur': 'config'} + assert len(mock_save) == 0 + + +async def test_migrator_existing_config(hass, store, mock_save): + """Test migrating existing config.""" + with patch('os.path.isfile', return_value=True), \ + patch('os.remove') as mock_remove, \ + patch('homeassistant.util.json.load_json', + return_value={'old': 'config'}): + data = await storage.async_migrator( + hass, 'old-path', store) + + assert len(mock_remove.mock_calls) == 1 + assert data == {'old': 'config'} + assert len(mock_save) == 1 + assert mock_save[0][1] == { + 'key': MOCK_KEY, + 'version': MOCK_VERSION, + 'data': data, + } + + +async def test_migrator_transforming_config(hass, store, mock_save): + """Test migrating config to new format.""" + async def old_conf_migrate_func(old_config): + """Migrate old config to new format.""" + return {'new': old_config['old']} + + with patch('os.path.isfile', return_value=True), \ + patch('os.remove') as mock_remove, \ + patch('homeassistant.util.json.load_json', + return_value={'old': 'config'}): + data = await storage.async_migrator( + hass, 'old-path', store, + old_conf_migrate_func=old_conf_migrate_func) + + assert len(mock_remove.mock_calls) == 1 + assert data == {'new': 'config'} + assert len(mock_save) == 1 + assert mock_save[0][1] == { + 'key': MOCK_KEY, + 'version': MOCK_VERSION, + 'data': data, + } diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 84bd0771542..fc0a549f1ae 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -1,13 +1,16 @@ """Test the config manager.""" import asyncio +from datetime import timedelta from unittest.mock import MagicMock, patch, mock_open import pytest from homeassistant import config_entries, loader, data_entry_flow from homeassistant.setup import async_setup_component +from homeassistant.util import dt -from tests.common import MockModule, mock_coro, MockConfigEntry +from tests.common import ( + MockModule, mock_coro, MockConfigEntry, async_fire_time_changed) @pytest.fixture @@ -15,6 +18,7 @@ def manager(hass): """Fixture of a loaded config manager.""" manager = config_entries.ConfigEntries(hass, {}) manager._entries = [] + manager._store._async_ensure_stop_listener = lambda: None hass.config_entries = manager return manager @@ -151,7 +155,9 @@ def test_domains_gets_uniques(manager): @asyncio.coroutine def test_saving_and_loading(hass): """Test that we're saving and loading correctly.""" - loader.set_component(hass, 'test', MockModule('test')) + loader.set_component( + hass, 'test', + MockModule('test', async_setup_entry=lambda *args: mock_coro(True))) class TestFlow(data_entry_flow.FlowHandler): VERSION = 5 @@ -183,13 +189,12 @@ def test_saving_and_loading(hass): json_path = 'homeassistant.util.json.open' with patch('homeassistant.config_entries.HANDLERS.get', - return_value=Test2Flow), \ - patch.object(config_entries, 'SAVE_DELAY', 0): + return_value=Test2Flow): yield from hass.config_entries.flow.async_init('test') with patch(json_path, mock_open(), create=True) as mock_write: # To trigger the call_later - yield from asyncio.sleep(0, loop=hass.loop) + async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1)) # To execute the save yield from hass.async_block_till_done() @@ -199,7 +204,7 @@ def test_saving_and_loading(hass): # Now load written data in new config manager manager = config_entries.ConfigEntries(hass, {}) - with patch('os.path.isfile', return_value=True), \ + with patch('os.path.isfile', return_value=False), \ patch(json_path, mock_open(read_data=written), create=True): yield from manager.async_load()