Split out storage delay save (#16017)

* Split out storage delayed write

* Update code using delayed save

* Fix tests

* Fix typing test

* Add callback decorator
This commit is contained in:
Paulus Schoutsen 2018-08-17 20:18:21 +02:00 committed by GitHub
parent fdbab3e20c
commit 2ad0bd4036
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 65 additions and 38 deletions

View file

@ -5,7 +5,7 @@ from logging import getLogger
from typing import Any, Dict, List, Optional # noqa: F401 from typing import Any, Dict, List, Optional # noqa: F401
import hmac import hmac
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant, callback
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from . import models from . import models
@ -32,7 +32,7 @@ class AuthStore:
async def async_get_users(self) -> List[models.User]: async def async_get_users(self) -> List[models.User]:
"""Retrieve all users.""" """Retrieve all users."""
if self._users is None: if self._users is None:
await self.async_load() await self._async_load()
assert self._users is not None assert self._users is not None
return list(self._users.values()) return list(self._users.values())
@ -40,7 +40,7 @@ class AuthStore:
async def async_get_user(self, user_id: str) -> Optional[models.User]: async def async_get_user(self, user_id: str) -> Optional[models.User]:
"""Retrieve a user by id.""" """Retrieve a user by id."""
if self._users is None: if self._users is None:
await self.async_load() await self._async_load()
assert self._users is not None assert self._users is not None
return self._users.get(user_id) return self._users.get(user_id)
@ -52,7 +52,7 @@ class AuthStore:
credentials: Optional[models.Credentials] = None) -> models.User: credentials: Optional[models.Credentials] = None) -> models.User:
"""Create a new user.""" """Create a new user."""
if self._users is None: if self._users is None:
await self.async_load() await self._async_load()
assert self._users is not None assert self._users is not None
kwargs = { kwargs = {
@ -73,7 +73,7 @@ class AuthStore:
self._users[new_user.id] = new_user self._users[new_user.id] = new_user
if credentials is None: if credentials is None:
await self.async_save() self._async_schedule_save()
return new_user return new_user
# Saving is done inside the link. # Saving is done inside the link.
@ -84,33 +84,33 @@ class AuthStore:
credentials: models.Credentials) -> None: credentials: models.Credentials) -> None:
"""Add credentials to an existing user.""" """Add credentials to an existing user."""
user.credentials.append(credentials) user.credentials.append(credentials)
await self.async_save() self._async_schedule_save()
credentials.is_new = False credentials.is_new = False
async def async_remove_user(self, user: models.User) -> None: async def async_remove_user(self, user: models.User) -> None:
"""Remove a user.""" """Remove a user."""
if self._users is None: if self._users is None:
await self.async_load() await self._async_load()
assert self._users is not None assert self._users is not None
self._users.pop(user.id) self._users.pop(user.id)
await self.async_save() self._async_schedule_save()
async def async_activate_user(self, user: models.User) -> None: async def async_activate_user(self, user: models.User) -> None:
"""Activate a user.""" """Activate a user."""
user.is_active = True user.is_active = True
await self.async_save() self._async_schedule_save()
async def async_deactivate_user(self, user: models.User) -> None: async def async_deactivate_user(self, user: models.User) -> None:
"""Activate a user.""" """Activate a user."""
user.is_active = False user.is_active = False
await self.async_save() self._async_schedule_save()
async def async_remove_credentials( async def async_remove_credentials(
self, credentials: models.Credentials) -> None: self, credentials: models.Credentials) -> None:
"""Remove credentials.""" """Remove credentials."""
if self._users is None: if self._users is None:
await self.async_load() await self._async_load()
assert self._users is not None assert self._users is not None
for user in self._users.values(): for user in self._users.values():
@ -125,7 +125,7 @@ class AuthStore:
user.credentials.pop(found) user.credentials.pop(found)
break break
await self.async_save() self._async_schedule_save()
async def async_create_refresh_token( async def async_create_refresh_token(
self, user: models.User, client_id: Optional[str] = None) \ self, user: models.User, client_id: Optional[str] = None) \
@ -133,14 +133,14 @@ class AuthStore:
"""Create a new token for a user.""" """Create a new token for a user."""
refresh_token = models.RefreshToken(user=user, client_id=client_id) refresh_token = models.RefreshToken(user=user, client_id=client_id)
user.refresh_tokens[refresh_token.id] = refresh_token user.refresh_tokens[refresh_token.id] = refresh_token
await self.async_save() self._async_schedule_save()
return refresh_token return refresh_token
async def async_get_refresh_token( async def async_get_refresh_token(
self, token_id: str) -> Optional[models.RefreshToken]: self, token_id: str) -> Optional[models.RefreshToken]:
"""Get refresh token by id.""" """Get refresh token by id."""
if self._users is None: if self._users is None:
await self.async_load() await self._async_load()
assert self._users is not None assert self._users is not None
for user in self._users.values(): for user in self._users.values():
@ -154,7 +154,7 @@ class AuthStore:
self, token: str) -> Optional[models.RefreshToken]: self, token: str) -> Optional[models.RefreshToken]:
"""Get refresh token by token.""" """Get refresh token by token."""
if self._users is None: if self._users is None:
await self.async_load() await self._async_load()
assert self._users is not None assert self._users is not None
found = None found = None
@ -166,7 +166,7 @@ class AuthStore:
return found return found
async def async_load(self) -> None: async def _async_load(self) -> None:
"""Load the users.""" """Load the users."""
data = await self._store.async_load() data = await self._store.async_load()
@ -218,11 +218,18 @@ class AuthStore:
self._users = users self._users = users
async def async_save(self) -> None: @callback
def _async_schedule_save(self) -> None:
"""Save users.""" """Save users."""
if self._users is None: if self._users is None:
await self.async_load() return
assert self._users is not None
self._store.async_delay_save(self._data_to_save, 1)
@callback
def _data_to_save(self) -> Dict:
"""Return the data to store."""
assert self._users is not None
users = [ users = [
{ {
@ -262,10 +269,8 @@ class AuthStore:
for refresh_token in user.refresh_tokens.values() for refresh_token in user.refresh_tokens.values()
] ]
data = { return {
'users': users, 'users': users,
'credentials': credentials, 'credentials': credentials,
'refresh_tokens': refresh_tokens, 'refresh_tokens': refresh_tokens,
} }
await self._store.async_save(data, delay=1)

View file

@ -320,7 +320,7 @@ class ConfigEntries:
raise UnknownEntry raise UnknownEntry
entry = self._entries.pop(found) entry = self._entries.pop(found)
await self._async_schedule_save() self._async_schedule_save()
unloaded = await entry.async_unload(self.hass) unloaded = await entry.async_unload(self.hass)
@ -391,7 +391,7 @@ class ConfigEntries:
source=context['source'], source=context['source'],
) )
self._entries.append(entry) self._entries.append(entry)
await self._async_schedule_save() self._async_schedule_save()
# Setup entry # Setup entry
if entry.domain in self.hass.config.components: if entry.domain in self.hass.config.components:
@ -439,12 +439,16 @@ class ConfigEntries:
flow.init_step = source flow.init_step = source
return flow return flow
async def _async_schedule_save(self): def _async_schedule_save(self):
"""Save the entity registry to a file.""" """Save the entity registry to a file."""
data = { self._store.async_delay_save(self._data_to_save, SAVE_DELAY)
@callback
def _data_to_save(self):
"""Return data to save."""
return {
'entries': [entry.as_dict() for entry in self._entries] 'entries': [entry.as_dict() for entry in self._entries]
} }
await self._store.async_save(data, delay=SAVE_DELAY)
async def _old_conf_migrator(old_config): async def _old_conf_migrator(old_config):

View file

@ -2,7 +2,7 @@
import asyncio import asyncio
import logging import logging
import os import os
from typing import Dict, Optional from typing import Dict, Optional, Callable
from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import callback from homeassistant.core import callback
@ -76,8 +76,13 @@ class Store:
async def _async_load(self): async def _async_load(self):
"""Helper to load the data.""" """Helper to load the data."""
# Check if we have a pending write
if self._data is not None: if self._data is not None:
data = self._data data = self._data
# If we didn't generate data yet, do it now.
if 'data_func' in data:
data['data'] = data.pop('data_func')()
else: else:
data = await self.hass.async_add_executor_job( data = await self.hass.async_add_executor_job(
json.load_json, self.path) json.load_json, self.path)
@ -95,8 +100,8 @@ class Store:
self._load_task = None self._load_task = None
return stored return stored
async def async_save(self, data: Dict, *, delay: Optional[int] = None): async def async_save(self, data):
"""Save data with an optional delay.""" """Save data."""
self._data = { self._data = {
'version': self.version, 'version': self.version,
'key': self.key, 'key': self.key,
@ -104,11 +109,20 @@ class Store:
} }
self._async_cleanup_delay_listener() self._async_cleanup_delay_listener()
self._async_cleanup_stop_listener()
await self._async_handle_write_data()
if delay is None: @callback
self._async_cleanup_stop_listener() def async_delay_save(self, data_func: Callable[[], Dict],
await self._async_handle_write_data() delay: Optional[int] = None):
return """Save data with an optional delay."""
self._data = {
'version': self.version,
'key': self.key,
'data_func': data_func,
}
self._async_cleanup_delay_listener()
self._unsub_delay_listener = async_call_later( self._unsub_delay_listener = async_call_later(
self.hass, delay, self._async_callback_delayed_write) self.hass, delay, self._async_callback_delayed_write)
@ -151,6 +165,10 @@ class Store:
async def _async_handle_write_data(self, *_args): async def _async_handle_write_data(self, *_args):
"""Handler to handle writing the config.""" """Handler to handle writing the config."""
data = self._data data = self._data
if 'data_func' in data:
data['data'] = data.pop('data_func')()
self._data = None self._data = None
async with self._write_lock: async with self._write_lock:

View file

@ -56,7 +56,7 @@ async def test_loading_parallel(hass, store, hass_storage, caplog):
async def test_saving_with_delay(hass, store, hass_storage): async def test_saving_with_delay(hass, store, hass_storage):
"""Test saving data after a delay.""" """Test saving data after a delay."""
await store.async_save(MOCK_DATA, delay=1) store.async_delay_save(lambda: MOCK_DATA, 1)
assert store.key not in hass_storage assert store.key not in hass_storage
async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1)) async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1))
@ -71,7 +71,7 @@ async def test_saving_with_delay(hass, store, hass_storage):
async def test_saving_on_stop(hass, hass_storage): async def test_saving_on_stop(hass, hass_storage):
"""Test delayed saves trigger when we quit Home Assistant.""" """Test delayed saves trigger when we quit Home Assistant."""
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY) store = storage.Store(hass, MOCK_VERSION, MOCK_KEY)
await store.async_save(MOCK_DATA, delay=1) store.async_delay_save(lambda: MOCK_DATA, 1)
assert store.key not in hass_storage assert store.key not in hass_storage
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
@ -92,7 +92,7 @@ async def test_loading_while_delay(hass, store, hass_storage):
'data': {'delay': 'no'}, 'data': {'delay': 'no'},
} }
await store.async_save({'delay': 'yes'}, delay=1) store.async_delay_save(lambda: {'delay': 'yes'}, 1)
assert hass_storage[store.key] == { assert hass_storage[store.key] == {
'version': MOCK_VERSION, 'version': MOCK_VERSION,
'key': MOCK_KEY, 'key': MOCK_KEY,
@ -105,7 +105,7 @@ async def test_loading_while_delay(hass, store, hass_storage):
async def test_writing_while_writing_delay(hass, store, hass_storage): async def test_writing_while_writing_delay(hass, store, hass_storage):
"""Test a write while a write with delay is active.""" """Test a write while a write with delay is active."""
await store.async_save({'delay': 'yes'}, delay=1) store.async_delay_save(lambda: {'delay': 'yes'}, 1)
assert store.key not in hass_storage assert store.key not in hass_storage
await store.async_save({'delay': 'no'}) await store.async_save({'delay': 'no'})
assert hass_storage[store.key] == { assert hass_storage[store.key] == {