Split out storage delay save (#16017)
* Split out storage delayed write * Update code using delayed save * Fix tests * Fix typing test * Add callback decorator
This commit is contained in:
parent
fdbab3e20c
commit
2ad0bd4036
4 changed files with 65 additions and 38 deletions
|
@ -5,7 +5,7 @@ from logging import getLogger
|
|||
from typing import Any, Dict, List, Optional # noqa: F401
|
||||
import hmac
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from . import models
|
||||
|
@ -32,7 +32,7 @@ class AuthStore:
|
|||
async def async_get_users(self) -> List[models.User]:
|
||||
"""Retrieve all users."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
await self._async_load()
|
||||
assert self._users is not None
|
||||
|
||||
return list(self._users.values())
|
||||
|
@ -40,7 +40,7 @@ class AuthStore:
|
|||
async def async_get_user(self, user_id: str) -> Optional[models.User]:
|
||||
"""Retrieve a user by id."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
await self._async_load()
|
||||
assert self._users is not None
|
||||
|
||||
return self._users.get(user_id)
|
||||
|
@ -52,7 +52,7 @@ class AuthStore:
|
|||
credentials: Optional[models.Credentials] = None) -> models.User:
|
||||
"""Create a new user."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
await self._async_load()
|
||||
assert self._users is not None
|
||||
|
||||
kwargs = {
|
||||
|
@ -73,7 +73,7 @@ class AuthStore:
|
|||
self._users[new_user.id] = new_user
|
||||
|
||||
if credentials is None:
|
||||
await self.async_save()
|
||||
self._async_schedule_save()
|
||||
return new_user
|
||||
|
||||
# Saving is done inside the link.
|
||||
|
@ -84,33 +84,33 @@ class AuthStore:
|
|||
credentials: models.Credentials) -> None:
|
||||
"""Add credentials to an existing user."""
|
||||
user.credentials.append(credentials)
|
||||
await self.async_save()
|
||||
self._async_schedule_save()
|
||||
credentials.is_new = False
|
||||
|
||||
async def async_remove_user(self, user: models.User) -> None:
|
||||
"""Remove a user."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
await self._async_load()
|
||||
assert self._users is not None
|
||||
|
||||
self._users.pop(user.id)
|
||||
await self.async_save()
|
||||
self._async_schedule_save()
|
||||
|
||||
async def async_activate_user(self, user: models.User) -> None:
|
||||
"""Activate a user."""
|
||||
user.is_active = True
|
||||
await self.async_save()
|
||||
self._async_schedule_save()
|
||||
|
||||
async def async_deactivate_user(self, user: models.User) -> None:
|
||||
"""Activate a user."""
|
||||
user.is_active = False
|
||||
await self.async_save()
|
||||
self._async_schedule_save()
|
||||
|
||||
async def async_remove_credentials(
|
||||
self, credentials: models.Credentials) -> None:
|
||||
"""Remove credentials."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
await self._async_load()
|
||||
assert self._users is not None
|
||||
|
||||
for user in self._users.values():
|
||||
|
@ -125,7 +125,7 @@ class AuthStore:
|
|||
user.credentials.pop(found)
|
||||
break
|
||||
|
||||
await self.async_save()
|
||||
self._async_schedule_save()
|
||||
|
||||
async def async_create_refresh_token(
|
||||
self, user: models.User, client_id: Optional[str] = None) \
|
||||
|
@ -133,14 +133,14 @@ class AuthStore:
|
|||
"""Create a new token for a user."""
|
||||
refresh_token = models.RefreshToken(user=user, client_id=client_id)
|
||||
user.refresh_tokens[refresh_token.id] = refresh_token
|
||||
await self.async_save()
|
||||
self._async_schedule_save()
|
||||
return refresh_token
|
||||
|
||||
async def async_get_refresh_token(
|
||||
self, token_id: str) -> Optional[models.RefreshToken]:
|
||||
"""Get refresh token by id."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
await self._async_load()
|
||||
assert self._users is not None
|
||||
|
||||
for user in self._users.values():
|
||||
|
@ -154,7 +154,7 @@ class AuthStore:
|
|||
self, token: str) -> Optional[models.RefreshToken]:
|
||||
"""Get refresh token by token."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
await self._async_load()
|
||||
assert self._users is not None
|
||||
|
||||
found = None
|
||||
|
@ -166,7 +166,7 @@ class AuthStore:
|
|||
|
||||
return found
|
||||
|
||||
async def async_load(self) -> None:
|
||||
async def _async_load(self) -> None:
|
||||
"""Load the users."""
|
||||
data = await self._store.async_load()
|
||||
|
||||
|
@ -218,10 +218,17 @@ class AuthStore:
|
|||
|
||||
self._users = users
|
||||
|
||||
async def async_save(self) -> None:
|
||||
@callback
|
||||
def _async_schedule_save(self) -> None:
|
||||
"""Save users."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
return
|
||||
|
||||
self._store.async_delay_save(self._data_to_save, 1)
|
||||
|
||||
@callback
|
||||
def _data_to_save(self) -> Dict:
|
||||
"""Return the data to store."""
|
||||
assert self._users is not None
|
||||
|
||||
users = [
|
||||
|
@ -262,10 +269,8 @@ class AuthStore:
|
|||
for refresh_token in user.refresh_tokens.values()
|
||||
]
|
||||
|
||||
data = {
|
||||
return {
|
||||
'users': users,
|
||||
'credentials': credentials,
|
||||
'refresh_tokens': refresh_tokens,
|
||||
}
|
||||
|
||||
await self._store.async_save(data, delay=1)
|
||||
|
|
|
@ -320,7 +320,7 @@ class ConfigEntries:
|
|||
raise UnknownEntry
|
||||
|
||||
entry = self._entries.pop(found)
|
||||
await self._async_schedule_save()
|
||||
self._async_schedule_save()
|
||||
|
||||
unloaded = await entry.async_unload(self.hass)
|
||||
|
||||
|
@ -391,7 +391,7 @@ class ConfigEntries:
|
|||
source=context['source'],
|
||||
)
|
||||
self._entries.append(entry)
|
||||
await self._async_schedule_save()
|
||||
self._async_schedule_save()
|
||||
|
||||
# Setup entry
|
||||
if entry.domain in self.hass.config.components:
|
||||
|
@ -439,12 +439,16 @@ class ConfigEntries:
|
|||
flow.init_step = source
|
||||
return flow
|
||||
|
||||
async def _async_schedule_save(self):
|
||||
def _async_schedule_save(self):
|
||||
"""Save the entity registry to a file."""
|
||||
data = {
|
||||
self._store.async_delay_save(self._data_to_save, SAVE_DELAY)
|
||||
|
||||
@callback
|
||||
def _data_to_save(self):
|
||||
"""Return data to save."""
|
||||
return {
|
||||
'entries': [entry.as_dict() for entry in self._entries]
|
||||
}
|
||||
await self._store.async_save(data, delay=SAVE_DELAY)
|
||||
|
||||
|
||||
async def _old_conf_migrator(old_config):
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, Optional, Callable
|
||||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||
from homeassistant.core import callback
|
||||
|
@ -76,8 +76,13 @@ class Store:
|
|||
|
||||
async def _async_load(self):
|
||||
"""Helper to load the data."""
|
||||
# Check if we have a pending write
|
||||
if self._data is not None:
|
||||
data = self._data
|
||||
|
||||
# If we didn't generate data yet, do it now.
|
||||
if 'data_func' in data:
|
||||
data['data'] = data.pop('data_func')()
|
||||
else:
|
||||
data = await self.hass.async_add_executor_job(
|
||||
json.load_json, self.path)
|
||||
|
@ -95,8 +100,8 @@ class Store:
|
|||
self._load_task = None
|
||||
return stored
|
||||
|
||||
async def async_save(self, data: Dict, *, delay: Optional[int] = None):
|
||||
"""Save data with an optional delay."""
|
||||
async def async_save(self, data):
|
||||
"""Save data."""
|
||||
self._data = {
|
||||
'version': self.version,
|
||||
'key': self.key,
|
||||
|
@ -104,11 +109,20 @@ class Store:
|
|||
}
|
||||
|
||||
self._async_cleanup_delay_listener()
|
||||
|
||||
if delay is None:
|
||||
self._async_cleanup_stop_listener()
|
||||
await self._async_handle_write_data()
|
||||
return
|
||||
|
||||
@callback
|
||||
def async_delay_save(self, data_func: Callable[[], Dict],
|
||||
delay: Optional[int] = None):
|
||||
"""Save data with an optional delay."""
|
||||
self._data = {
|
||||
'version': self.version,
|
||||
'key': self.key,
|
||||
'data_func': data_func,
|
||||
}
|
||||
|
||||
self._async_cleanup_delay_listener()
|
||||
|
||||
self._unsub_delay_listener = async_call_later(
|
||||
self.hass, delay, self._async_callback_delayed_write)
|
||||
|
@ -151,6 +165,10 @@ class Store:
|
|||
async def _async_handle_write_data(self, *_args):
|
||||
"""Handler to handle writing the config."""
|
||||
data = self._data
|
||||
|
||||
if 'data_func' in data:
|
||||
data['data'] = data.pop('data_func')()
|
||||
|
||||
self._data = None
|
||||
|
||||
async with self._write_lock:
|
||||
|
|
|
@ -56,7 +56,7 @@ async def test_loading_parallel(hass, store, hass_storage, caplog):
|
|||
|
||||
async def test_saving_with_delay(hass, store, hass_storage):
|
||||
"""Test saving data after a delay."""
|
||||
await store.async_save(MOCK_DATA, delay=1)
|
||||
store.async_delay_save(lambda: MOCK_DATA, 1)
|
||||
assert store.key not in hass_storage
|
||||
|
||||
async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1))
|
||||
|
@ -71,7 +71,7 @@ async def test_saving_with_delay(hass, store, hass_storage):
|
|||
async def test_saving_on_stop(hass, hass_storage):
|
||||
"""Test delayed saves trigger when we quit Home Assistant."""
|
||||
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY)
|
||||
await store.async_save(MOCK_DATA, delay=1)
|
||||
store.async_delay_save(lambda: MOCK_DATA, 1)
|
||||
assert store.key not in hass_storage
|
||||
|
||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
|
||||
|
@ -92,7 +92,7 @@ async def test_loading_while_delay(hass, store, hass_storage):
|
|||
'data': {'delay': 'no'},
|
||||
}
|
||||
|
||||
await store.async_save({'delay': 'yes'}, delay=1)
|
||||
store.async_delay_save(lambda: {'delay': 'yes'}, 1)
|
||||
assert hass_storage[store.key] == {
|
||||
'version': MOCK_VERSION,
|
||||
'key': MOCK_KEY,
|
||||
|
@ -105,7 +105,7 @@ async def test_loading_while_delay(hass, store, hass_storage):
|
|||
|
||||
async def test_writing_while_writing_delay(hass, store, hass_storage):
|
||||
"""Test a write while a write with delay is active."""
|
||||
await store.async_save({'delay': 'yes'}, delay=1)
|
||||
store.async_delay_save(lambda: {'delay': 'yes'}, 1)
|
||||
assert store.key not in hass_storage
|
||||
await store.async_save({'delay': 'no'})
|
||||
assert hass_storage[store.key] == {
|
||||
|
|
Loading…
Add table
Reference in a new issue