Add handler to restore a backup file with the backup integration (#128365)

* Early pushout of restore handling for core/container

* Adjust after rebase

* Move logging definition, we should only do this if we go ahead with the restore

* First round

* More paths

* Add async_restore_backup to base class

* Block restore of new backup files

* manager tests

* Add websocket test

* Add testing to main

* Add coverage for missing backup file

* Catch FileNotFoundError instead

* Patch Path.read_text instead

* Remove HA_RESTORE from keep

* Use secure paths

* Fix restart test

* extend coverage

* Mock argv

* Adjustments
This commit is contained in:
Joakim Sørensen 2024-11-01 16:25:22 +01:00 committed by GitHub
parent 4da93f6a5e
commit 31dcc25ba5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 481 additions and 1 deletions

View file

@ -9,6 +9,7 @@ import os
import sys
import threading
from .backup_restore import restore_backup
from .const import REQUIRED_PYTHON_VER, RESTART_EXIT_CODE, __version__
FAULT_LOG_FILENAME = "home-assistant.log.fault"
@ -182,6 +183,9 @@ def main() -> int:
return scripts.run(args.script)
config_dir = os.path.abspath(os.path.join(os.getcwd(), args.config))
if restore_backup(config_dir):
return RESTART_EXIT_CODE
ensure_config_path(config_dir)
# pylint: disable-next=import-outside-toplevel

View file

@ -0,0 +1,126 @@
"""Home Assistant module to handle restoring backups."""
from dataclasses import dataclass
import json
import logging
from pathlib import Path
import shutil
import sys
from tempfile import TemporaryDirectory
from awesomeversion import AwesomeVersion
import securetar
from .const import __version__ as HA_VERSION
RESTORE_BACKUP_FILE = ".HA_RESTORE"
KEEP_PATHS = ("backups",)
_LOGGER = logging.getLogger(__name__)
@dataclass
class RestoreBackupFileContent:
"""Definition for restore backup file content."""
backup_file_path: Path
def restore_backup_file_content(config_dir: Path) -> RestoreBackupFileContent | None:
"""Return the contents of the restore backup file."""
instruction_path = config_dir.joinpath(RESTORE_BACKUP_FILE)
try:
instruction_content = instruction_path.read_text(encoding="utf-8")
return RestoreBackupFileContent(
backup_file_path=Path(instruction_content.split(";")[0])
)
except FileNotFoundError:
return None
def _clear_configuration_directory(config_dir: Path) -> None:
"""Delete all files and directories in the config directory except for the backups directory."""
keep_paths = [config_dir.joinpath(path) for path in KEEP_PATHS]
config_contents = sorted(
[entry for entry in config_dir.iterdir() if entry not in keep_paths]
)
for entry in config_contents:
entrypath = config_dir.joinpath(entry)
if entrypath.is_file():
entrypath.unlink()
elif entrypath.is_dir():
shutil.rmtree(entrypath)
def _extract_backup(config_dir: Path, backup_file_path: Path) -> None:
"""Extract the backup file to the config directory."""
with (
TemporaryDirectory() as tempdir,
securetar.SecureTarFile(
backup_file_path,
gzip=False,
mode="r",
) as ostf,
):
ostf.extractall(
path=Path(tempdir, "extracted"),
members=securetar.secure_path(ostf),
filter="fully_trusted",
)
backup_meta_file = Path(tempdir, "extracted", "backup.json")
backup_meta = json.loads(backup_meta_file.read_text(encoding="utf8"))
if (
backup_meta_version := AwesomeVersion(
backup_meta["homeassistant"]["version"]
)
) > HA_VERSION:
raise ValueError(
f"You need at least Home Assistant version {backup_meta_version} to restore this backup"
)
with securetar.SecureTarFile(
Path(
tempdir,
"extracted",
f"homeassistant.tar{'.gz' if backup_meta["compressed"] else ''}",
),
gzip=backup_meta["compressed"],
mode="r",
) as istf:
for member in istf.getmembers():
if member.name == "data":
continue
member.name = member.name.replace("data/", "")
_clear_configuration_directory(config_dir)
istf.extractall(
path=config_dir,
members=[
member
for member in securetar.secure_path(istf)
if member.name != "data"
],
filter="fully_trusted",
)
def restore_backup(config_dir_path: str) -> bool:
"""Restore the backup file if any.
Returns True if a restore backup file was found and restored, False otherwise.
"""
config_dir = Path(config_dir_path)
if not (restore_content := restore_backup_file_content(config_dir)):
return False
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
backup_file_path = restore_content.backup_file_path
_LOGGER.info("Restoring %s", backup_file_path)
try:
_extract_backup(config_dir, backup_file_path)
except FileNotFoundError as err:
raise ValueError(f"Backup file {backup_file_path} does not exist") from err
_LOGGER.info("Restore complete, restarting")
return True

View file

@ -17,6 +17,7 @@ LOGGER = getLogger(__package__)
EXCLUDE_FROM_BACKUP = [
"__pycache__/*",
".DS_Store",
".HA_RESTORE",
"*.db-shm",
"*.log.*",
"*.log",

View file

@ -16,6 +16,7 @@ from typing import Any, Protocol, cast
from securetar import SecureTarFile, atomic_contents_add
from homeassistant.backup_restore import RESTORE_BACKUP_FILE
from homeassistant.const import __version__ as HAVERSION
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
@ -123,6 +124,10 @@ class BaseBackupManager(abc.ABC):
LOGGER.debug("Loaded %s platforms", len(self.platforms))
self.loaded_platforms = True
@abc.abstractmethod
async def async_restore_backup(self, slug: str, **kwargs: Any) -> None:
"""Restpre a backup."""
@abc.abstractmethod
async def async_create_backup(self, **kwargs: Any) -> Backup:
"""Generate a backup."""
@ -291,6 +296,25 @@ class BackupManager(BaseBackupManager):
return tar_file_path.stat().st_size
async def async_restore_backup(self, slug: str, **kwargs: Any) -> None:
"""Restore a backup.
This will write the restore information to .HA_RESTORE which
will be handled during startup by the restore_backup module.
"""
if (backup := await self.async_get_backup(slug=slug)) is None:
raise HomeAssistantError(f"Backup {slug} not found")
def _write_restore_file() -> None:
"""Write the restore file."""
Path(self.hass.config.path(RESTORE_BACKUP_FILE)).write_text(
f"{backup.path.as_posix()};",
encoding="utf-8",
)
await self.hass.async_add_executor_job(_write_restore_file)
await self.hass.services.async_call("homeassistant", "restart", {})
def _generate_slug(date: str, name: str) -> str:
"""Generate a backup slug."""

View file

@ -22,6 +22,7 @@ def async_register_websocket_handlers(hass: HomeAssistant, with_hassio: bool) ->
websocket_api.async_register_command(hass, handle_info)
websocket_api.async_register_command(hass, handle_create)
websocket_api.async_register_command(hass, handle_remove)
websocket_api.async_register_command(hass, handle_restore)
@websocket_api.require_admin
@ -85,6 +86,24 @@ async def handle_remove(
connection.send_result(msg["id"])
@websocket_api.require_admin
@websocket_api.websocket_command(
{
vol.Required("type"): "backup/restore",
vol.Required("slug"): str,
}
)
@websocket_api.async_response
async def handle_restore(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Restore a backup."""
await hass.data[DATA_MANAGER].async_restore_backup(msg["slug"])
connection.send_result(msg["id"])
@websocket_api.require_admin
@websocket_api.websocket_command({vol.Required("type"): "backup/generate"})
@websocket_api.async_response

View file

@ -57,6 +57,7 @@ PyTurboJPEG==1.7.5
pyudev==0.24.1
PyYAML==6.0.2
requests==2.32.3
securetar==2024.2.1
SQLAlchemy==2.0.31
typing-extensions>=4.12.2,<5.0
ulid-transform==1.0.2

View file

@ -63,6 +63,7 @@ dependencies = [
"python-slugify==8.0.4",
"PyYAML==6.0.2",
"requests==2.32.3",
"securetar==2024.2.1",
"SQLAlchemy==2.0.31",
"typing-extensions>=4.12.2,<5.0",
"ulid-transform==1.0.2",

View file

@ -35,6 +35,7 @@ psutil-home-assistant==0.0.1
python-slugify==8.0.4
PyYAML==6.0.2
requests==2.32.3
securetar==2024.2.1
SQLAlchemy==2.0.31
typing-extensions>=4.12.2,<5.0
ulid-transform==1.0.2

View file

@ -269,3 +269,22 @@
'type': 'result',
})
# ---
# name: test_restore[with_hassio]
dict({
'error': dict({
'code': 'unknown_command',
'message': 'Unknown command.',
}),
'id': 1,
'success': False,
'type': 'result',
})
# ---
# name: test_restore[without_hassio]
dict({
'id': 1,
'result': None,
'success': True,
'type': 'result',
})
# ---

View file

@ -333,3 +333,31 @@ async def test_loading_platforms_when_running_async_post_backup_actions(
assert len(manager.platforms) == 1
assert "Loaded 1 platforms" in caplog.text
async def test_async_trigger_restore(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test trigger restore."""
manager = BackupManager(hass)
manager.loaded_backups = True
manager.backups = {TEST_BACKUP.slug: TEST_BACKUP}
with (
patch("pathlib.Path.exists", return_value=True),
patch("pathlib.Path.write_text") as mocked_write_text,
patch("homeassistant.core.ServiceRegistry.async_call") as mocked_service_call,
):
await manager.async_restore_backup(TEST_BACKUP.slug)
assert mocked_write_text.call_args[0][0] == "abc123.tar;"
assert mocked_service_call.called
async def test_async_trigger_restore_missing_backup(hass: HomeAssistant) -> None:
"""Test trigger restore."""
manager = BackupManager(hass)
manager.loaded_backups = True
with pytest.raises(HomeAssistantError, match="Backup abc123 not found"):
await manager.async_restore_backup(TEST_BACKUP.slug)

View file

@ -141,6 +141,32 @@ async def test_generate(
assert snapshot == await client.receive_json()
@pytest.mark.parametrize(
"with_hassio",
[
pytest.param(True, id="with_hassio"),
pytest.param(False, id="without_hassio"),
],
)
async def test_restore(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
with_hassio: bool,
) -> None:
"""Test calling the restore command."""
await setup_backup_integration(hass, with_hassio=with_hassio)
client = await hass_ws_client(hass)
await hass.async_block_till_done()
with patch(
"homeassistant.components.backup.manager.BackupManager.async_restore_backup",
):
await client.send_json_auto_id({"type": "backup/restore", "slug": "abc123"})
assert await client.receive_json() == snapshot
@pytest.mark.parametrize(
"access_token_fixture_name",
["hass_access_token", "hass_supervisor_access_token"],

View file

@ -0,0 +1,220 @@
"""Test methods in backup_restore."""
from pathlib import Path
import tarfile
from unittest import mock
import pytest
from homeassistant import backup_restore
from .common import get_test_config_dir
@pytest.mark.parametrize(
("side_effect", "content", "expected"),
[
(FileNotFoundError, "", None),
(None, "", backup_restore.RestoreBackupFileContent(backup_file_path=Path(""))),
(
None,
"test;",
backup_restore.RestoreBackupFileContent(backup_file_path=Path("test")),
),
(
None,
"test;;;;",
backup_restore.RestoreBackupFileContent(backup_file_path=Path("test")),
),
],
)
def test_reading_the_instruction_contents(
side_effect: Exception | None,
content: str,
expected: backup_restore.RestoreBackupFileContent | None,
) -> None:
"""Test reading the content of the .HA_RESTORE file."""
with (
mock.patch(
"pathlib.Path.read_text",
return_value=content,
side_effect=side_effect,
),
):
read_content = backup_restore.restore_backup_file_content(
Path(get_test_config_dir())
)
assert read_content == expected
def test_restoring_backup_that_does_not_exist() -> None:
"""Test restoring a backup that does not exist."""
backup_file_path = Path(get_test_config_dir("backups", "test"))
with (
mock.patch(
"homeassistant.backup_restore.restore_backup_file_content",
return_value=backup_restore.RestoreBackupFileContent(
backup_file_path=backup_file_path
),
),
mock.patch("pathlib.Path.read_text", side_effect=FileNotFoundError),
pytest.raises(
ValueError, match=f"Backup file {backup_file_path} does not exist"
),
):
assert backup_restore.restore_backup(Path(get_test_config_dir())) is False
def test_restoring_backup_when_instructions_can_not_be_read() -> None:
"""Test restoring a backup when instructions can not be read."""
with (
mock.patch(
"homeassistant.backup_restore.restore_backup_file_content",
return_value=None,
),
):
assert backup_restore.restore_backup(Path(get_test_config_dir())) is False
def test_restoring_backup_that_is_not_a_file() -> None:
"""Test restoring a backup that is not a file."""
backup_file_path = Path(get_test_config_dir("backups", "test"))
with (
mock.patch(
"homeassistant.backup_restore.restore_backup_file_content",
return_value=backup_restore.RestoreBackupFileContent(
backup_file_path=backup_file_path
),
),
mock.patch("pathlib.Path.exists", return_value=True),
mock.patch("pathlib.Path.is_file", return_value=False),
pytest.raises(
ValueError, match=f"Backup file {backup_file_path} does not exist"
),
):
assert backup_restore.restore_backup(Path(get_test_config_dir())) is False
def test_aborting_for_older_versions() -> None:
"""Test that we abort for older versions."""
config_dir = Path(get_test_config_dir())
backup_file_path = Path(config_dir, "backups", "test.tar")
def _patched_path_read_text(path: Path, **kwargs):
return '{"homeassistant": {"version": "9999.99.99"}, "compressed": false}'
with (
mock.patch(
"homeassistant.backup_restore.restore_backup_file_content",
return_value=backup_restore.RestoreBackupFileContent(
backup_file_path=backup_file_path
),
),
mock.patch("securetar.SecureTarFile"),
mock.patch("homeassistant.backup_restore.TemporaryDirectory"),
mock.patch("pathlib.Path.read_text", _patched_path_read_text),
mock.patch("homeassistant.backup_restore.HA_VERSION", "2013.09.17"),
pytest.raises(
ValueError,
match="You need at least Home Assistant version 9999.99.99 to restore this backup",
),
):
assert backup_restore.restore_backup(config_dir) is True
def test_removal_of_current_configuration_when_restoring() -> None:
"""Test that we are removing the current configuration directory."""
config_dir = Path(get_test_config_dir())
backup_file_path = Path(config_dir, "backups", "test.tar")
mock_config_dir = [
{"path": Path(config_dir, ".HA_RESTORE"), "is_file": True},
{"path": Path(config_dir, ".HA_VERSION"), "is_file": True},
{"path": Path(config_dir, "backups"), "is_file": False},
{"path": Path(config_dir, "www"), "is_file": False},
]
def _patched_path_read_text(path: Path, **kwargs):
return '{"homeassistant": {"version": "2013.09.17"}, "compressed": false}'
def _patched_path_is_file(path: Path, **kwargs):
return [x for x in mock_config_dir if x["path"] == path][0]["is_file"]
def _patched_path_is_dir(path: Path, **kwargs):
return not [x for x in mock_config_dir if x["path"] == path][0]["is_file"]
with (
mock.patch(
"homeassistant.backup_restore.restore_backup_file_content",
return_value=backup_restore.RestoreBackupFileContent(
backup_file_path=backup_file_path
),
),
mock.patch("securetar.SecureTarFile"),
mock.patch("homeassistant.backup_restore.TemporaryDirectory"),
mock.patch("homeassistant.backup_restore.HA_VERSION", "2013.09.17"),
mock.patch("pathlib.Path.read_text", _patched_path_read_text),
mock.patch("pathlib.Path.is_file", _patched_path_is_file),
mock.patch("pathlib.Path.is_dir", _patched_path_is_dir),
mock.patch(
"pathlib.Path.iterdir",
return_value=[x["path"] for x in mock_config_dir],
),
mock.patch("pathlib.Path.unlink") as unlink_mock,
mock.patch("shutil.rmtree") as rmtreemock,
):
assert backup_restore.restore_backup(config_dir) is True
assert unlink_mock.call_count == 2
assert (
rmtreemock.call_count == 1
) # We have 2 directories in the config directory, but backups is kept
removed_directories = {Path(call.args[0]) for call in rmtreemock.mock_calls}
assert removed_directories == {Path(config_dir, "www")}
def test_extracting_the_contents_of_a_backup_file() -> None:
"""Test extracting the contents of a backup file."""
config_dir = Path(get_test_config_dir())
backup_file_path = Path(config_dir, "backups", "test.tar")
def _patched_path_read_text(path: Path, **kwargs):
return '{"homeassistant": {"version": "2013.09.17"}, "compressed": false}'
getmembers_mock = mock.MagicMock(
return_value=[
tarfile.TarInfo(name="data"),
tarfile.TarInfo(name="data/../test"),
tarfile.TarInfo(name="data/.HA_VERSION"),
tarfile.TarInfo(name="data/.storage"),
tarfile.TarInfo(name="data/www"),
]
)
extractall_mock = mock.MagicMock()
with (
mock.patch(
"homeassistant.backup_restore.restore_backup_file_content",
return_value=backup_restore.RestoreBackupFileContent(
backup_file_path=backup_file_path
),
),
mock.patch(
"tarfile.open",
return_value=mock.MagicMock(
getmembers=getmembers_mock,
extractall=extractall_mock,
__iter__=lambda x: iter(getmembers_mock.return_value),
),
),
mock.patch("homeassistant.backup_restore.TemporaryDirectory"),
mock.patch("pathlib.Path.read_text", _patched_path_read_text),
mock.patch("pathlib.Path.is_file", return_value=False),
mock.patch("pathlib.Path.iterdir", return_value=[]),
):
assert backup_restore.restore_backup(config_dir) is True
assert getmembers_mock.call_count == 1
assert extractall_mock.call_count == 2
assert {
member.name for member in extractall_mock.mock_calls[-1].kwargs["members"]
} == {".HA_VERSION", ".storage", "www"}

View file

@ -3,7 +3,7 @@
from unittest.mock import PropertyMock, patch
from homeassistant import __main__ as main
from homeassistant.const import REQUIRED_PYTHON_VER
from homeassistant.const import REQUIRED_PYTHON_VER, RESTART_EXIT_CODE
@patch("sys.exit")
@ -86,3 +86,13 @@ def test_skip_pip_mutually_exclusive(mock_exit) -> None:
assert mock_exit.called is False
args = parse_args("--skip-pip", "--skip-pip-packages", "foo")
assert mock_exit.called is True
def test_restart_after_backup_restore() -> None:
"""Test restarting if we restored a backup."""
with (
patch("sys.argv", ["python"]),
patch("homeassistant.__main__.restore_backup", return_value=True),
):
exit_code = main.main()
assert exit_code == RESTART_EXIT_CODE