Use atomicwrites for mission critical core files (#59606)

This commit is contained in:
J. Nick Koston 2021-11-15 04:19:31 -06:00 committed by GitHub
parent 04a258bf21
commit 96f7b0d910
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 92 additions and 24 deletions

View file

@ -42,7 +42,7 @@ class AuthStore:
self._groups: dict[str, models.Group] | None = None self._groups: dict[str, models.Group] | None = None
self._perm_lookup: PermissionLookup | None = None self._perm_lookup: PermissionLookup | None = None
self._store = hass.helpers.storage.Store( self._store = hass.helpers.storage.Store(
STORAGE_VERSION, STORAGE_KEY, private=True STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
) )
self._lock = asyncio.Lock() self._lock = asyncio.Lock()

View file

@ -100,7 +100,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
super().__init__(hass, config) super().__init__(hass, config)
self._user_settings: _UsersDict | None = None self._user_settings: _UsersDict | None = None
self._user_store = hass.helpers.storage.Store( self._user_store = hass.helpers.storage.Store(
STORAGE_VERSION, STORAGE_KEY, private=True STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
) )
self._include = config.get(CONF_INCLUDE, []) self._include = config.get(CONF_INCLUDE, [])
self._exclude = config.get(CONF_EXCLUDE, []) self._exclude = config.get(CONF_EXCLUDE, [])

View file

@ -77,7 +77,7 @@ class TotpAuthModule(MultiFactorAuthModule):
super().__init__(hass, config) super().__init__(hass, config)
self._users: dict[str, str] | None = None self._users: dict[str, str] | None = None
self._user_store = hass.helpers.storage.Store( self._user_store = hass.helpers.storage.Store(
STORAGE_VERSION, STORAGE_KEY, private=True STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
) )
self._init_lock = asyncio.Lock() self._init_lock = asyncio.Lock()

View file

@ -63,7 +63,7 @@ class Data:
"""Initialize the user data store.""" """Initialize the user data store."""
self.hass = hass self.hass = hass
self._store = hass.helpers.storage.Store( self._store = hass.helpers.storage.Store(
STORAGE_VERSION, STORAGE_KEY, private=True STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
) )
self._data: dict[str, Any] | None = None self._data: dict[str, Any] | None = None
# Legacy mode will allow usernames to start/end with whitespace # Legacy mode will allow usernames to start/end with whitespace

View file

@ -11,7 +11,7 @@ from homeassistant.const import CONF_ID, EVENT_COMPONENT_LOADED
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import ATTR_COMPONENT from homeassistant.setup import ATTR_COMPONENT
from homeassistant.util.file import write_utf8_file from homeassistant.util.file import write_utf8_file_atomic
from homeassistant.util.yaml import dump, load_yaml from homeassistant.util.yaml import dump, load_yaml
DOMAIN = "config" DOMAIN = "config"
@ -254,4 +254,4 @@ def _write(path, data):
# Do it before opening file. If dump causes error it will now not # Do it before opening file. If dump causes error it will now not
# truncate the file. # truncate the file.
contents = dump(data) contents = dump(data)
write_utf8_file(path, contents) write_utf8_file_atomic(path, contents)

View file

@ -25,7 +25,9 @@ class Network:
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the Network class.""" """Initialize the Network class."""
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) self._store = hass.helpers.storage.Store(
STORAGE_VERSION, STORAGE_KEY, atomic_writes=True
)
self._data: dict[str, Any] = {} self._data: dict[str, Any] = {}
self.adapters: list[Adapter] = [] self.adapters: list[Adapter] = []

View file

@ -1715,7 +1715,7 @@ class Config:
async def async_load(self) -> None: async def async_load(self) -> None:
"""Load [homeassistant] core config.""" """Load [homeassistant] core config."""
store = self.hass.helpers.storage.Store( store = self.hass.helpers.storage.Store(
CORE_STORAGE_VERSION, CORE_STORAGE_KEY, private=True CORE_STORAGE_VERSION, CORE_STORAGE_KEY, private=True, atomic_writes=True
) )
if not (data := await store.async_load()): if not (data := await store.async_load()):
@ -1763,7 +1763,7 @@ class Config:
} }
store = self.hass.helpers.storage.Store( store = self.hass.helpers.storage.Store(
CORE_STORAGE_VERSION, CORE_STORAGE_KEY, private=True CORE_STORAGE_VERSION, CORE_STORAGE_KEY, private=True, atomic_writes=True
) )
await store.async_save(data) await store.async_save(data)

View file

@ -49,7 +49,9 @@ class AreaRegistry:
"""Initialize the area registry.""" """Initialize the area registry."""
self.hass = hass self.hass = hass
self.areas: MutableMapping[str, AreaEntry] = {} self.areas: MutableMapping[str, AreaEntry] = {}
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) self._store = hass.helpers.storage.Store(
STORAGE_VERSION, STORAGE_KEY, atomic_writes=True
)
self._normalized_name_area_idx: dict[str, str] = {} self._normalized_name_area_idx: dict[str, str] = {}
@callback @callback

View file

@ -162,7 +162,9 @@ class DeviceRegistry:
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the device registry.""" """Initialize the device registry."""
self.hass = hass self.hass = hass
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) self._store = hass.helpers.storage.Store(
STORAGE_VERSION, STORAGE_KEY, atomic_writes=True
)
self._clear_index() self._clear_index()
@callback @callback

View file

@ -155,7 +155,9 @@ class EntityRegistry:
self.hass = hass self.hass = hass
self.entities: dict[str, RegistryEntry] self.entities: dict[str, RegistryEntry]
self._index: dict[tuple[str, str, str], str] = {} self._index: dict[tuple[str, str, str], str] = {}
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) self._store = hass.helpers.storage.Store(
STORAGE_VERSION, STORAGE_KEY, atomic_writes=True
)
self.hass.bus.async_listen( self.hass.bus.async_listen(
EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_modified EVENT_DEVICE_REGISTRY_UPDATED, self.async_device_modified
) )

View file

@ -76,6 +76,7 @@ class Store:
private: bool = False, private: bool = False,
*, *,
encoder: type[JSONEncoder] | None = None, encoder: type[JSONEncoder] | None = None,
atomic_writes: bool = False,
) -> None: ) -> None:
"""Initialize storage class.""" """Initialize storage class."""
self.version = version self.version = version
@ -88,6 +89,7 @@ class Store:
self._write_lock = asyncio.Lock() self._write_lock = asyncio.Lock()
self._load_task: asyncio.Future | None = None self._load_task: asyncio.Future | None = None
self._encoder = encoder self._encoder = encoder
self._atomic_writes = atomic_writes
@property @property
def path(self): def path(self):
@ -238,7 +240,13 @@ class Store:
os.makedirs(os.path.dirname(path)) os.makedirs(os.path.dirname(path))
_LOGGER.debug("Writing data for %s to %s", self.key, path) _LOGGER.debug("Writing data for %s to %s", self.key, path)
json_util.save_json(path, data, self._private, encoder=self._encoder) json_util.save_json(
path,
data,
self._private,
encoder=self._encoder,
atomic_writes=self._atomic_writes,
)
async def _async_migrate_func(self, old_version, old_data): async def _async_migrate_func(self, old_version, old_data):
"""Migrate to the new version.""" """Migrate to the new version."""

View file

@ -6,6 +6,7 @@ aiohttp_cors==0.7.0
astral==2.2 astral==2.2
async-upnp-client==0.22.12 async-upnp-client==0.22.12
async_timeout==4.0.0 async_timeout==4.0.0
atomicwrites==1.4.0
attrs==21.2.0 attrs==21.2.0
awesomeversion==21.10.1 awesomeversion==21.10.1
backports.zoneinfo;python_version<"3.9" backports.zoneinfo;python_version<"3.9"

View file

@ -5,6 +5,8 @@ import logging
import os import os
import tempfile import tempfile
from atomicwrites import AtomicWriter
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -14,6 +16,33 @@ class WriteError(HomeAssistantError):
"""Error writing the data.""" """Error writing the data."""
def write_utf8_file_atomic(
filename: str,
utf8_data: str,
private: bool = False,
) -> None:
"""Write a file and rename it into place using atomicwrites.
Writes all or nothing.
This function uses fsync under the hood. It should
only be used to write mission critical files as
fsync can block for a few seconds or longer is the
disk is busy.
Using this function frequently will significantly
negatively impact performance.
"""
try:
with AtomicWriter(filename, overwrite=True).open() as fdesc:
if not private:
os.fchmod(fdesc.fileno(), 0o644)
fdesc.write(utf8_data)
except OSError as error:
_LOGGER.exception("Saving file failed: %s", filename)
raise WriteError(error) from error
def write_utf8_file( def write_utf8_file(
filename: str, filename: str,
utf8_data: str, utf8_data: str,
@ -34,7 +63,7 @@ def write_utf8_file(
fdesc.write(utf8_data) fdesc.write(utf8_data)
tmp_filename = fdesc.name tmp_filename = fdesc.name
if not private: if not private:
os.chmod(tmp_filename, 0o644) os.fchmod(fdesc.fileno(), 0o644)
os.replace(tmp_filename, filename) os.replace(tmp_filename, filename)
except OSError as error: except OSError as error:
_LOGGER.exception("Saving file failed: %s", filename) _LOGGER.exception("Saving file failed: %s", filename)

View file

@ -10,7 +10,7 @@ from typing import Any
from homeassistant.core import Event, State from homeassistant.core import Event, State
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from .file import write_utf8_file from .file import write_utf8_file, write_utf8_file_atomic
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -49,6 +49,7 @@ def save_json(
private: bool = False, private: bool = False,
*, *,
encoder: type[json.JSONEncoder] | None = None, encoder: type[json.JSONEncoder] | None = None,
atomic_writes: bool = False,
) -> None: ) -> None:
"""Save JSON data to a file. """Save JSON data to a file.
@ -61,6 +62,9 @@ def save_json(
_LOGGER.error(msg) _LOGGER.error(msg)
raise SerializationError(msg) from error raise SerializationError(msg) from error
if atomic_writes:
write_utf8_file_atomic(filename, json_data, private)
else:
write_utf8_file(filename, json_data, private) write_utf8_file(filename, json_data, private)

View file

@ -5,6 +5,7 @@ aiohttp==3.8.0
astral==2.2 astral==2.2
async_timeout==4.0.0 async_timeout==4.0.0
attrs==21.2.0 attrs==21.2.0
atomicwrites==1.4.0
awesomeversion==21.10.1 awesomeversion==21.10.1
backports.zoneinfo;python_version<"3.9" backports.zoneinfo;python_version<"3.9"
bcrypt==3.1.7 bcrypt==3.1.7

View file

@ -31,6 +31,7 @@ responses==0.12.0
respx==0.17.0 respx==0.17.0
stdlib-list==0.7.0 stdlib-list==0.7.0
tqdm==4.49.0 tqdm==4.49.0
types-atomicwrites==1.4.1
types-croniter==1.0.0 types-croniter==1.0.0
types-backports==0.1.3 types-backports==0.1.3
types-certifi==0.1.4 types-certifi==0.1.4

View file

@ -36,6 +36,7 @@ REQUIRES = [
"astral==2.2", "astral==2.2",
"async_timeout==4.0.0", "async_timeout==4.0.0",
"attrs==21.2.0", "attrs==21.2.0",
"atomicwrites==1.4.0",
"awesomeversion==21.10.1", "awesomeversion==21.10.1",
'backports.zoneinfo;python_version<"3.9"', 'backports.zoneinfo;python_version<"3.9"',
"bcrypt==3.1.7", "bcrypt==3.1.7",

View file

@ -5,20 +5,21 @@ from unittest.mock import patch
import pytest import pytest
from homeassistant.util.file import WriteError, write_utf8_file from homeassistant.util.file import WriteError, write_utf8_file, write_utf8_file_atomic
def test_write_utf8_file_private(tmpdir): @pytest.mark.parametrize("func", [write_utf8_file, write_utf8_file_atomic])
def test_write_utf8_file_atomic_private(tmpdir, func):
"""Test files can be written as 0o600 or 0o644.""" """Test files can be written as 0o600 or 0o644."""
test_dir = tmpdir.mkdir("files") test_dir = tmpdir.mkdir("files")
test_file = Path(test_dir / "test.json") test_file = Path(test_dir / "test.json")
write_utf8_file(test_file, '{"some":"data"}', False) func(test_file, '{"some":"data"}', False)
with open(test_file) as fh: with open(test_file) as fh:
assert fh.read() == '{"some":"data"}' assert fh.read() == '{"some":"data"}'
assert os.stat(test_file).st_mode & 0o777 == 0o644 assert os.stat(test_file).st_mode & 0o777 == 0o644
write_utf8_file(test_file, '{"some":"data"}', True) func(test_file, '{"some":"data"}', True)
with open(test_file) as fh: with open(test_file) as fh:
assert fh.read() == '{"some":"data"}' assert fh.read() == '{"some":"data"}'
assert os.stat(test_file).st_mode & 0o777 == 0o600 assert os.stat(test_file).st_mode & 0o777 == 0o600
@ -63,3 +64,16 @@ def test_write_utf8_file_fails_at_rename_and_remove(tmpdir, caplog):
write_utf8_file(test_file, '{"some":"data"}', False) write_utf8_file(test_file, '{"some":"data"}', False)
assert "File replacement cleanup failed" in caplog.text assert "File replacement cleanup failed" in caplog.text
def test_write_utf8_file_atomic_fails(tmpdir):
"""Test OSError from write_utf8_file_atomic is rethrown as WriteError."""
test_dir = tmpdir.mkdir("files")
test_file = Path(test_dir / "test.json")
with pytest.raises(WriteError), patch(
"homeassistant.util.file.AtomicWriter.open", side_effect=OSError
):
write_utf8_file_atomic(test_file, '{"some":"data"}', False)
assert not os.path.exists(test_file)

View file

@ -67,11 +67,12 @@ def test_save_and_load_private():
assert stats.st_mode & 0o77 == 0 assert stats.st_mode & 0o77 == 0
def test_overwrite_and_reload(): @pytest.mark.parametrize("atomic_writes", [True, False])
def test_overwrite_and_reload(atomic_writes):
"""Test that we can overwrite an existing file and read back.""" """Test that we can overwrite an existing file and read back."""
fname = _path_for("test3") fname = _path_for("test3")
save_json(fname, TEST_JSON_A) save_json(fname, TEST_JSON_A, atomic_writes=atomic_writes)
save_json(fname, TEST_JSON_B) save_json(fname, TEST_JSON_B, atomic_writes=atomic_writes)
data = load_json(fname) data = load_json(fname)
assert data == TEST_JSON_B assert data == TEST_JSON_B