Use atomicwrites for mission critical core files (#59606)
This commit is contained in:
parent
04a258bf21
commit
96f7b0d910
19 changed files with 92 additions and 24 deletions
|
@ -42,7 +42,7 @@ class AuthStore:
|
||||||
self._groups: dict[str, models.Group] | None = None
|
self._groups: dict[str, models.Group] | None = None
|
||||||
self._perm_lookup: PermissionLookup | None = None
|
self._perm_lookup: PermissionLookup | None = None
|
||||||
self._store = hass.helpers.storage.Store(
|
self._store = hass.helpers.storage.Store(
|
||||||
STORAGE_VERSION, STORAGE_KEY, private=True
|
STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
|
||||||
)
|
)
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
|
@ -100,7 +100,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
|
||||||
super().__init__(hass, config)
|
super().__init__(hass, config)
|
||||||
self._user_settings: _UsersDict | None = None
|
self._user_settings: _UsersDict | None = None
|
||||||
self._user_store = hass.helpers.storage.Store(
|
self._user_store = hass.helpers.storage.Store(
|
||||||
STORAGE_VERSION, STORAGE_KEY, private=True
|
STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
|
||||||
)
|
)
|
||||||
self._include = config.get(CONF_INCLUDE, [])
|
self._include = config.get(CONF_INCLUDE, [])
|
||||||
self._exclude = config.get(CONF_EXCLUDE, [])
|
self._exclude = config.get(CONF_EXCLUDE, [])
|
||||||
|
|
|
@ -77,7 +77,7 @@ class TotpAuthModule(MultiFactorAuthModule):
|
||||||
super().__init__(hass, config)
|
super().__init__(hass, config)
|
||||||
self._users: dict[str, str] | None = None
|
self._users: dict[str, str] | None = None
|
||||||
self._user_store = hass.helpers.storage.Store(
|
self._user_store = hass.helpers.storage.Store(
|
||||||
STORAGE_VERSION, STORAGE_KEY, private=True
|
STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
|
||||||
)
|
)
|
||||||
self._init_lock = asyncio.Lock()
|
self._init_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
|
@ -63,7 +63,7 @@ class Data:
|
||||||
"""Initialize the user data store."""
|
"""Initialize the user data store."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self._store = hass.helpers.storage.Store(
|
self._store = hass.helpers.storage.Store(
|
||||||
STORAGE_VERSION, STORAGE_KEY, private=True
|
STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
|
||||||
)
|
)
|
||||||
self._data: dict[str, Any] | None = None
|
self._data: dict[str, Any] | None = None
|
||||||
# Legacy mode will allow usernames to start/end with whitespace
|
# Legacy mode will allow usernames to start/end with whitespace
|
||||||
|
|
|
@ -11,7 +11,7 @@ from homeassistant.const import CONF_ID, EVENT_COMPONENT_LOADED
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.setup import ATTR_COMPONENT
|
from homeassistant.setup import ATTR_COMPONENT
|
||||||
from homeassistant.util.file import write_utf8_file
|
from homeassistant.util.file import write_utf8_file_atomic
|
||||||
from homeassistant.util.yaml import dump, load_yaml
|
from homeassistant.util.yaml import dump, load_yaml
|
||||||
|
|
||||||
DOMAIN = "config"
|
DOMAIN = "config"
|
||||||
|
@ -254,4 +254,4 @@ def _write(path, data):
|
||||||
# Do it before opening file. If dump causes error it will now not
|
# Do it before opening file. If dump causes error it will now not
|
||||||
# truncate the file.
|
# truncate the file.
|
||||||
contents = dump(data)
|
contents = dump(data)
|
||||||
write_utf8_file(path, contents)
|
write_utf8_file_atomic(path, contents)
|
||||||
|
|
|
@ -25,7 +25,9 @@ class Network:
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant) -> None:
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
"""Initialize the Network class."""
|
"""Initialize the Network class."""
|
||||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
self._store = hass.helpers.storage.Store(
|
||||||
|
STORAGE_VERSION, STORAGE_KEY, atomic_writes=True
|
||||||
|
)
|
||||||
self._data: dict[str, Any] = {}
|
self._data: dict[str, Any] = {}
|
||||||
self.adapters: list[Adapter] = []
|
self.adapters: list[Adapter] = []
|
||||||
|
|
||||||
|
|
|
@ -1715,7 +1715,7 @@ class Config:
|
||||||
async def async_load(self) -> None:
|
async def async_load(self) -> None:
|
||||||
"""Load [homeassistant] core config."""
|
"""Load [homeassistant] core config."""
|
||||||
store = self.hass.helpers.storage.Store(
|
store = self.hass.helpers.storage.Store(
|
||||||
CORE_STORAGE_VERSION, CORE_STORAGE_KEY, private=True
|
CORE_STORAGE_VERSION, CORE_STORAGE_KEY, private=True, atomic_writes=True
|
||||||
)
|
)
|
||||||
|
|
||||||
if not (data := await store.async_load()):
|
if not (data := await store.async_load()):
|
||||||
|
@ -1763,7 +1763,7 @@ class Config:
|
||||||
}
|
}
|
||||||
|
|
||||||
store = self.hass.helpers.storage.Store(
|
store = self.hass.helpers.storage.Store(
|
||||||
CORE_STORAGE_VERSION, CORE_STORAGE_KEY, private=True
|
CORE_STORAGE_VERSION, CORE_STORAGE_KEY, private=True, atomic_writes=True
|
||||||
)
|
)
|
||||||
await store.async_save(data)
|
await store.async_save(data)
|
||||||
|
|
||||||
|
|
|
@ -49,7 +49,9 @@ class AreaRegistry:
|
||||||
"""Initialize the area registry."""
|
"""Initialize the area registry."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.areas: MutableMapping[str, AreaEntry] = {}
|
self.areas: MutableMapping[str, AreaEntry] = {}
|
||||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
self._store = hass.helpers.storage.Store(
|
||||||
|
STORAGE_VERSION, STORAGE_KEY, atomic_writes=True
|
||||||
|
)
|
||||||
self._normalized_name_area_idx: dict[str, str] = {}
|
self._normalized_name_area_idx: dict[str, str] = {}
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
|
|
@ -162,7 +162,9 @@ class DeviceRegistry:
|
||||||
def __init__(self, hass: HomeAssistant) -> None:
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
"""Initialize the device registry."""
|
"""Initialize the device registry."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
self._store = hass.helpers.storage.Store(
|
||||||
|
STORAGE_VERSION, STORAGE_KEY, atomic_writes=True
|
||||||
|
)
|
||||||
self._clear_index()
|
self._clear_index()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
|
|
@ -155,7 +155,9 @@ class EntityRegistry:
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.entities: dict[str, RegistryEntry]
|
self.entities: dict[str, RegistryEntry]
|
||||||
self._index: dict[tuple[str, str, str], str] = {}
|
self._index: dict[tuple[str, str, str], str] = {}
|
||||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
self._store = hass.helpers.storage.Store(
|
||||||
|
STORAGE_VERSION, STORAGE_KEY, atomic_writes=True
|
||||||
|
)
|
||||||
self.hass.bus.async_listen(
|
self.hass.bus.async_listen(
|
||||||
EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_modified
|
EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_modified
|
||||||
)
|
)
|
||||||
|
|
|
@ -76,6 +76,7 @@ class Store:
|
||||||
private: bool = False,
|
private: bool = False,
|
||||||
*,
|
*,
|
||||||
encoder: type[JSONEncoder] | None = None,
|
encoder: type[JSONEncoder] | None = None,
|
||||||
|
atomic_writes: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize storage class."""
|
"""Initialize storage class."""
|
||||||
self.version = version
|
self.version = version
|
||||||
|
@ -88,6 +89,7 @@ class Store:
|
||||||
self._write_lock = asyncio.Lock()
|
self._write_lock = asyncio.Lock()
|
||||||
self._load_task: asyncio.Future | None = None
|
self._load_task: asyncio.Future | None = None
|
||||||
self._encoder = encoder
|
self._encoder = encoder
|
||||||
|
self._atomic_writes = atomic_writes
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def path(self):
|
def path(self):
|
||||||
|
@ -238,7 +240,13 @@ class Store:
|
||||||
os.makedirs(os.path.dirname(path))
|
os.makedirs(os.path.dirname(path))
|
||||||
|
|
||||||
_LOGGER.debug("Writing data for %s to %s", self.key, path)
|
_LOGGER.debug("Writing data for %s to %s", self.key, path)
|
||||||
json_util.save_json(path, data, self._private, encoder=self._encoder)
|
json_util.save_json(
|
||||||
|
path,
|
||||||
|
data,
|
||||||
|
self._private,
|
||||||
|
encoder=self._encoder,
|
||||||
|
atomic_writes=self._atomic_writes,
|
||||||
|
)
|
||||||
|
|
||||||
async def _async_migrate_func(self, old_version, old_data):
|
async def _async_migrate_func(self, old_version, old_data):
|
||||||
"""Migrate to the new version."""
|
"""Migrate to the new version."""
|
||||||
|
|
|
@ -6,6 +6,7 @@ aiohttp_cors==0.7.0
|
||||||
astral==2.2
|
astral==2.2
|
||||||
async-upnp-client==0.22.12
|
async-upnp-client==0.22.12
|
||||||
async_timeout==4.0.0
|
async_timeout==4.0.0
|
||||||
|
atomicwrites==1.4.0
|
||||||
attrs==21.2.0
|
attrs==21.2.0
|
||||||
awesomeversion==21.10.1
|
awesomeversion==21.10.1
|
||||||
backports.zoneinfo;python_version<"3.9"
|
backports.zoneinfo;python_version<"3.9"
|
||||||
|
|
|
@ -5,6 +5,8 @@ import logging
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
|
from atomicwrites import AtomicWriter
|
||||||
|
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
@ -14,6 +16,33 @@ class WriteError(HomeAssistantError):
|
||||||
"""Error writing the data."""
|
"""Error writing the data."""
|
||||||
|
|
||||||
|
|
||||||
|
def write_utf8_file_atomic(
|
||||||
|
filename: str,
|
||||||
|
utf8_data: str,
|
||||||
|
private: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Write a file and rename it into place using atomicwrites.
|
||||||
|
|
||||||
|
Writes all or nothing.
|
||||||
|
|
||||||
|
This function uses fsync under the hood. It should
|
||||||
|
only be used to write mission critical files as
|
||||||
|
fsync can block for a few seconds or longer is the
|
||||||
|
disk is busy.
|
||||||
|
|
||||||
|
Using this function frequently will significantly
|
||||||
|
negatively impact performance.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with AtomicWriter(filename, overwrite=True).open() as fdesc:
|
||||||
|
if not private:
|
||||||
|
os.fchmod(fdesc.fileno(), 0o644)
|
||||||
|
fdesc.write(utf8_data)
|
||||||
|
except OSError as error:
|
||||||
|
_LOGGER.exception("Saving file failed: %s", filename)
|
||||||
|
raise WriteError(error) from error
|
||||||
|
|
||||||
|
|
||||||
def write_utf8_file(
|
def write_utf8_file(
|
||||||
filename: str,
|
filename: str,
|
||||||
utf8_data: str,
|
utf8_data: str,
|
||||||
|
@ -34,7 +63,7 @@ def write_utf8_file(
|
||||||
fdesc.write(utf8_data)
|
fdesc.write(utf8_data)
|
||||||
tmp_filename = fdesc.name
|
tmp_filename = fdesc.name
|
||||||
if not private:
|
if not private:
|
||||||
os.chmod(tmp_filename, 0o644)
|
os.fchmod(fdesc.fileno(), 0o644)
|
||||||
os.replace(tmp_filename, filename)
|
os.replace(tmp_filename, filename)
|
||||||
except OSError as error:
|
except OSError as error:
|
||||||
_LOGGER.exception("Saving file failed: %s", filename)
|
_LOGGER.exception("Saving file failed: %s", filename)
|
||||||
|
|
|
@ -10,7 +10,7 @@ from typing import Any
|
||||||
from homeassistant.core import Event, State
|
from homeassistant.core import Event, State
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
from .file import write_utf8_file
|
from .file import write_utf8_file, write_utf8_file_atomic
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -49,6 +49,7 @@ def save_json(
|
||||||
private: bool = False,
|
private: bool = False,
|
||||||
*,
|
*,
|
||||||
encoder: type[json.JSONEncoder] | None = None,
|
encoder: type[json.JSONEncoder] | None = None,
|
||||||
|
atomic_writes: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Save JSON data to a file.
|
"""Save JSON data to a file.
|
||||||
|
|
||||||
|
@ -61,6 +62,9 @@ def save_json(
|
||||||
_LOGGER.error(msg)
|
_LOGGER.error(msg)
|
||||||
raise SerializationError(msg) from error
|
raise SerializationError(msg) from error
|
||||||
|
|
||||||
|
if atomic_writes:
|
||||||
|
write_utf8_file_atomic(filename, json_data, private)
|
||||||
|
else:
|
||||||
write_utf8_file(filename, json_data, private)
|
write_utf8_file(filename, json_data, private)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ aiohttp==3.8.0
|
||||||
astral==2.2
|
astral==2.2
|
||||||
async_timeout==4.0.0
|
async_timeout==4.0.0
|
||||||
attrs==21.2.0
|
attrs==21.2.0
|
||||||
|
atomicwrites==1.4.0
|
||||||
awesomeversion==21.10.1
|
awesomeversion==21.10.1
|
||||||
backports.zoneinfo;python_version<"3.9"
|
backports.zoneinfo;python_version<"3.9"
|
||||||
bcrypt==3.1.7
|
bcrypt==3.1.7
|
||||||
|
|
|
@ -31,6 +31,7 @@ responses==0.12.0
|
||||||
respx==0.17.0
|
respx==0.17.0
|
||||||
stdlib-list==0.7.0
|
stdlib-list==0.7.0
|
||||||
tqdm==4.49.0
|
tqdm==4.49.0
|
||||||
|
types-atomicwrites==1.4.1
|
||||||
types-croniter==1.0.0
|
types-croniter==1.0.0
|
||||||
types-backports==0.1.3
|
types-backports==0.1.3
|
||||||
types-certifi==0.1.4
|
types-certifi==0.1.4
|
||||||
|
|
1
setup.py
1
setup.py
|
@ -36,6 +36,7 @@ REQUIRES = [
|
||||||
"astral==2.2",
|
"astral==2.2",
|
||||||
"async_timeout==4.0.0",
|
"async_timeout==4.0.0",
|
||||||
"attrs==21.2.0",
|
"attrs==21.2.0",
|
||||||
|
"atomicwrites==1.4.0",
|
||||||
"awesomeversion==21.10.1",
|
"awesomeversion==21.10.1",
|
||||||
'backports.zoneinfo;python_version<"3.9"',
|
'backports.zoneinfo;python_version<"3.9"',
|
||||||
"bcrypt==3.1.7",
|
"bcrypt==3.1.7",
|
||||||
|
|
|
@ -5,20 +5,21 @@ from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.util.file import WriteError, write_utf8_file
|
from homeassistant.util.file import WriteError, write_utf8_file, write_utf8_file_atomic
|
||||||
|
|
||||||
|
|
||||||
def test_write_utf8_file_private(tmpdir):
|
@pytest.mark.parametrize("func", [write_utf8_file, write_utf8_file_atomic])
|
||||||
|
def test_write_utf8_file_atomic_private(tmpdir, func):
|
||||||
"""Test files can be written as 0o600 or 0o644."""
|
"""Test files can be written as 0o600 or 0o644."""
|
||||||
test_dir = tmpdir.mkdir("files")
|
test_dir = tmpdir.mkdir("files")
|
||||||
test_file = Path(test_dir / "test.json")
|
test_file = Path(test_dir / "test.json")
|
||||||
|
|
||||||
write_utf8_file(test_file, '{"some":"data"}', False)
|
func(test_file, '{"some":"data"}', False)
|
||||||
with open(test_file) as fh:
|
with open(test_file) as fh:
|
||||||
assert fh.read() == '{"some":"data"}'
|
assert fh.read() == '{"some":"data"}'
|
||||||
assert os.stat(test_file).st_mode & 0o777 == 0o644
|
assert os.stat(test_file).st_mode & 0o777 == 0o644
|
||||||
|
|
||||||
write_utf8_file(test_file, '{"some":"data"}', True)
|
func(test_file, '{"some":"data"}', True)
|
||||||
with open(test_file) as fh:
|
with open(test_file) as fh:
|
||||||
assert fh.read() == '{"some":"data"}'
|
assert fh.read() == '{"some":"data"}'
|
||||||
assert os.stat(test_file).st_mode & 0o777 == 0o600
|
assert os.stat(test_file).st_mode & 0o777 == 0o600
|
||||||
|
@ -63,3 +64,16 @@ def test_write_utf8_file_fails_at_rename_and_remove(tmpdir, caplog):
|
||||||
write_utf8_file(test_file, '{"some":"data"}', False)
|
write_utf8_file(test_file, '{"some":"data"}', False)
|
||||||
|
|
||||||
assert "File replacement cleanup failed" in caplog.text
|
assert "File replacement cleanup failed" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_utf8_file_atomic_fails(tmpdir):
|
||||||
|
"""Test OSError from write_utf8_file_atomic is rethrown as WriteError."""
|
||||||
|
test_dir = tmpdir.mkdir("files")
|
||||||
|
test_file = Path(test_dir / "test.json")
|
||||||
|
|
||||||
|
with pytest.raises(WriteError), patch(
|
||||||
|
"homeassistant.util.file.AtomicWriter.open", side_effect=OSError
|
||||||
|
):
|
||||||
|
write_utf8_file_atomic(test_file, '{"some":"data"}', False)
|
||||||
|
|
||||||
|
assert not os.path.exists(test_file)
|
||||||
|
|
|
@ -67,11 +67,12 @@ def test_save_and_load_private():
|
||||||
assert stats.st_mode & 0o77 == 0
|
assert stats.st_mode & 0o77 == 0
|
||||||
|
|
||||||
|
|
||||||
def test_overwrite_and_reload():
|
@pytest.mark.parametrize("atomic_writes", [True, False])
|
||||||
|
def test_overwrite_and_reload(atomic_writes):
|
||||||
"""Test that we can overwrite an existing file and read back."""
|
"""Test that we can overwrite an existing file and read back."""
|
||||||
fname = _path_for("test3")
|
fname = _path_for("test3")
|
||||||
save_json(fname, TEST_JSON_A)
|
save_json(fname, TEST_JSON_A, atomic_writes=atomic_writes)
|
||||||
save_json(fname, TEST_JSON_B)
|
save_json(fname, TEST_JSON_B, atomic_writes=atomic_writes)
|
||||||
data = load_json(fname)
|
data = load_json(fname)
|
||||||
assert data == TEST_JSON_B
|
assert data == TEST_JSON_B
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue