Add handler to restore a backup file with the backup integration (#128365)
* Early pushout of restore handling for core/container * Adjust after rebase * Move logging definition, we should only do this if we go ahead with the restore * First round * More paths * Add async_restore_backup to base class * Block restore of new backup files * manager tests * Add websocket test * Add testing to main * Add coverage for missing backup file * Catch FileNotFoundError instead * Patch Path.read_text instead * Remove HA_RESTORE from keep * Use secure paths * Fix restart test * extend coverage * Mock argv * Adjustments
This commit is contained in:
parent
4da93f6a5e
commit
31dcc25ba5
13 changed files with 481 additions and 1 deletions
|
@ -9,6 +9,7 @@ import os
|
|||
import sys
|
||||
import threading
|
||||
|
||||
from .backup_restore import restore_backup
|
||||
from .const import REQUIRED_PYTHON_VER, RESTART_EXIT_CODE, __version__
|
||||
|
||||
FAULT_LOG_FILENAME = "home-assistant.log.fault"
|
||||
|
@ -182,6 +183,9 @@ def main() -> int:
|
|||
return scripts.run(args.script)
|
||||
|
||||
config_dir = os.path.abspath(os.path.join(os.getcwd(), args.config))
|
||||
if restore_backup(config_dir):
|
||||
return RESTART_EXIT_CODE
|
||||
|
||||
ensure_config_path(config_dir)
|
||||
|
||||
# pylint: disable-next=import-outside-toplevel
|
||||
|
|
126
homeassistant/backup_restore.py
Normal file
126
homeassistant/backup_restore.py
Normal file
|
@ -0,0 +1,126 @@
|
|||
"""Home Assistant module to handle restoring backups."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import sys
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
from awesomeversion import AwesomeVersion
|
||||
import securetar
|
||||
|
||||
from .const import __version__ as HA_VERSION
|
||||
|
||||
RESTORE_BACKUP_FILE = ".HA_RESTORE"
|
||||
KEEP_PATHS = ("backups",)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RestoreBackupFileContent:
|
||||
"""Definition for restore backup file content."""
|
||||
|
||||
backup_file_path: Path
|
||||
|
||||
|
||||
def restore_backup_file_content(config_dir: Path) -> RestoreBackupFileContent | None:
|
||||
"""Return the contents of the restore backup file."""
|
||||
instruction_path = config_dir.joinpath(RESTORE_BACKUP_FILE)
|
||||
try:
|
||||
instruction_content = instruction_path.read_text(encoding="utf-8")
|
||||
return RestoreBackupFileContent(
|
||||
backup_file_path=Path(instruction_content.split(";")[0])
|
||||
)
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
|
||||
def _clear_configuration_directory(config_dir: Path) -> None:
|
||||
"""Delete all files and directories in the config directory except for the backups directory."""
|
||||
keep_paths = [config_dir.joinpath(path) for path in KEEP_PATHS]
|
||||
config_contents = sorted(
|
||||
[entry for entry in config_dir.iterdir() if entry not in keep_paths]
|
||||
)
|
||||
|
||||
for entry in config_contents:
|
||||
entrypath = config_dir.joinpath(entry)
|
||||
|
||||
if entrypath.is_file():
|
||||
entrypath.unlink()
|
||||
elif entrypath.is_dir():
|
||||
shutil.rmtree(entrypath)
|
||||
|
||||
|
||||
def _extract_backup(config_dir: Path, backup_file_path: Path) -> None:
|
||||
"""Extract the backup file to the config directory."""
|
||||
with (
|
||||
TemporaryDirectory() as tempdir,
|
||||
securetar.SecureTarFile(
|
||||
backup_file_path,
|
||||
gzip=False,
|
||||
mode="r",
|
||||
) as ostf,
|
||||
):
|
||||
ostf.extractall(
|
||||
path=Path(tempdir, "extracted"),
|
||||
members=securetar.secure_path(ostf),
|
||||
filter="fully_trusted",
|
||||
)
|
||||
backup_meta_file = Path(tempdir, "extracted", "backup.json")
|
||||
backup_meta = json.loads(backup_meta_file.read_text(encoding="utf8"))
|
||||
|
||||
if (
|
||||
backup_meta_version := AwesomeVersion(
|
||||
backup_meta["homeassistant"]["version"]
|
||||
)
|
||||
) > HA_VERSION:
|
||||
raise ValueError(
|
||||
f"You need at least Home Assistant version {backup_meta_version} to restore this backup"
|
||||
)
|
||||
|
||||
with securetar.SecureTarFile(
|
||||
Path(
|
||||
tempdir,
|
||||
"extracted",
|
||||
f"homeassistant.tar{'.gz' if backup_meta["compressed"] else ''}",
|
||||
),
|
||||
gzip=backup_meta["compressed"],
|
||||
mode="r",
|
||||
) as istf:
|
||||
for member in istf.getmembers():
|
||||
if member.name == "data":
|
||||
continue
|
||||
member.name = member.name.replace("data/", "")
|
||||
_clear_configuration_directory(config_dir)
|
||||
istf.extractall(
|
||||
path=config_dir,
|
||||
members=[
|
||||
member
|
||||
for member in securetar.secure_path(istf)
|
||||
if member.name != "data"
|
||||
],
|
||||
filter="fully_trusted",
|
||||
)
|
||||
|
||||
|
||||
def restore_backup(config_dir_path: str) -> bool:
|
||||
"""Restore the backup file if any.
|
||||
|
||||
Returns True if a restore backup file was found and restored, False otherwise.
|
||||
"""
|
||||
config_dir = Path(config_dir_path)
|
||||
if not (restore_content := restore_backup_file_content(config_dir)):
|
||||
return False
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||
backup_file_path = restore_content.backup_file_path
|
||||
_LOGGER.info("Restoring %s", backup_file_path)
|
||||
try:
|
||||
_extract_backup(config_dir, backup_file_path)
|
||||
except FileNotFoundError as err:
|
||||
raise ValueError(f"Backup file {backup_file_path} does not exist") from err
|
||||
_LOGGER.info("Restore complete, restarting")
|
||||
return True
|
|
@ -17,6 +17,7 @@ LOGGER = getLogger(__package__)
|
|||
EXCLUDE_FROM_BACKUP = [
|
||||
"__pycache__/*",
|
||||
".DS_Store",
|
||||
".HA_RESTORE",
|
||||
"*.db-shm",
|
||||
"*.log.*",
|
||||
"*.log",
|
||||
|
|
|
@ -16,6 +16,7 @@ from typing import Any, Protocol, cast
|
|||
|
||||
from securetar import SecureTarFile, atomic_contents_add
|
||||
|
||||
from homeassistant.backup_restore import RESTORE_BACKUP_FILE
|
||||
from homeassistant.const import __version__ as HAVERSION
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
@ -123,6 +124,10 @@ class BaseBackupManager(abc.ABC):
|
|||
LOGGER.debug("Loaded %s platforms", len(self.platforms))
|
||||
self.loaded_platforms = True
|
||||
|
||||
@abc.abstractmethod
|
||||
async def async_restore_backup(self, slug: str, **kwargs: Any) -> None:
|
||||
"""Restpre a backup."""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def async_create_backup(self, **kwargs: Any) -> Backup:
|
||||
"""Generate a backup."""
|
||||
|
@ -291,6 +296,25 @@ class BackupManager(BaseBackupManager):
|
|||
|
||||
return tar_file_path.stat().st_size
|
||||
|
||||
async def async_restore_backup(self, slug: str, **kwargs: Any) -> None:
|
||||
"""Restore a backup.
|
||||
|
||||
This will write the restore information to .HA_RESTORE which
|
||||
will be handled during startup by the restore_backup module.
|
||||
"""
|
||||
if (backup := await self.async_get_backup(slug=slug)) is None:
|
||||
raise HomeAssistantError(f"Backup {slug} not found")
|
||||
|
||||
def _write_restore_file() -> None:
|
||||
"""Write the restore file."""
|
||||
Path(self.hass.config.path(RESTORE_BACKUP_FILE)).write_text(
|
||||
f"{backup.path.as_posix()};",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
await self.hass.async_add_executor_job(_write_restore_file)
|
||||
await self.hass.services.async_call("homeassistant", "restart", {})
|
||||
|
||||
|
||||
def _generate_slug(date: str, name: str) -> str:
|
||||
"""Generate a backup slug."""
|
||||
|
|
|
@ -22,6 +22,7 @@ def async_register_websocket_handlers(hass: HomeAssistant, with_hassio: bool) ->
|
|||
websocket_api.async_register_command(hass, handle_info)
|
||||
websocket_api.async_register_command(hass, handle_create)
|
||||
websocket_api.async_register_command(hass, handle_remove)
|
||||
websocket_api.async_register_command(hass, handle_restore)
|
||||
|
||||
|
||||
@websocket_api.require_admin
|
||||
|
@ -85,6 +86,24 @@ async def handle_remove(
|
|||
connection.send_result(msg["id"])
|
||||
|
||||
|
||||
@websocket_api.require_admin
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "backup/restore",
|
||||
vol.Required("slug"): str,
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def handle_restore(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Restore a backup."""
|
||||
await hass.data[DATA_MANAGER].async_restore_backup(msg["slug"])
|
||||
connection.send_result(msg["id"])
|
||||
|
||||
|
||||
@websocket_api.require_admin
|
||||
@websocket_api.websocket_command({vol.Required("type"): "backup/generate"})
|
||||
@websocket_api.async_response
|
||||
|
|
|
@ -57,6 +57,7 @@ PyTurboJPEG==1.7.5
|
|||
pyudev==0.24.1
|
||||
PyYAML==6.0.2
|
||||
requests==2.32.3
|
||||
securetar==2024.2.1
|
||||
SQLAlchemy==2.0.31
|
||||
typing-extensions>=4.12.2,<5.0
|
||||
ulid-transform==1.0.2
|
||||
|
|
|
@ -63,6 +63,7 @@ dependencies = [
|
|||
"python-slugify==8.0.4",
|
||||
"PyYAML==6.0.2",
|
||||
"requests==2.32.3",
|
||||
"securetar==2024.2.1",
|
||||
"SQLAlchemy==2.0.31",
|
||||
"typing-extensions>=4.12.2,<5.0",
|
||||
"ulid-transform==1.0.2",
|
||||
|
|
|
@ -35,6 +35,7 @@ psutil-home-assistant==0.0.1
|
|||
python-slugify==8.0.4
|
||||
PyYAML==6.0.2
|
||||
requests==2.32.3
|
||||
securetar==2024.2.1
|
||||
SQLAlchemy==2.0.31
|
||||
typing-extensions>=4.12.2,<5.0
|
||||
ulid-transform==1.0.2
|
||||
|
|
|
@ -269,3 +269,22 @@
|
|||
'type': 'result',
|
||||
})
|
||||
# ---
|
||||
# name: test_restore[with_hassio]
|
||||
dict({
|
||||
'error': dict({
|
||||
'code': 'unknown_command',
|
||||
'message': 'Unknown command.',
|
||||
}),
|
||||
'id': 1,
|
||||
'success': False,
|
||||
'type': 'result',
|
||||
})
|
||||
# ---
|
||||
# name: test_restore[without_hassio]
|
||||
dict({
|
||||
'id': 1,
|
||||
'result': None,
|
||||
'success': True,
|
||||
'type': 'result',
|
||||
})
|
||||
# ---
|
||||
|
|
|
@ -333,3 +333,31 @@ async def test_loading_platforms_when_running_async_post_backup_actions(
|
|||
assert len(manager.platforms) == 1
|
||||
|
||||
assert "Loaded 1 platforms" in caplog.text
|
||||
|
||||
|
||||
async def test_async_trigger_restore(
|
||||
hass: HomeAssistant,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test trigger restore."""
|
||||
manager = BackupManager(hass)
|
||||
manager.loaded_backups = True
|
||||
manager.backups = {TEST_BACKUP.slug: TEST_BACKUP}
|
||||
|
||||
with (
|
||||
patch("pathlib.Path.exists", return_value=True),
|
||||
patch("pathlib.Path.write_text") as mocked_write_text,
|
||||
patch("homeassistant.core.ServiceRegistry.async_call") as mocked_service_call,
|
||||
):
|
||||
await manager.async_restore_backup(TEST_BACKUP.slug)
|
||||
assert mocked_write_text.call_args[0][0] == "abc123.tar;"
|
||||
assert mocked_service_call.called
|
||||
|
||||
|
||||
async def test_async_trigger_restore_missing_backup(hass: HomeAssistant) -> None:
|
||||
"""Test trigger restore."""
|
||||
manager = BackupManager(hass)
|
||||
manager.loaded_backups = True
|
||||
|
||||
with pytest.raises(HomeAssistantError, match="Backup abc123 not found"):
|
||||
await manager.async_restore_backup(TEST_BACKUP.slug)
|
||||
|
|
|
@ -141,6 +141,32 @@ async def test_generate(
|
|||
assert snapshot == await client.receive_json()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"with_hassio",
|
||||
[
|
||||
pytest.param(True, id="with_hassio"),
|
||||
pytest.param(False, id="without_hassio"),
|
||||
],
|
||||
)
|
||||
async def test_restore(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
snapshot: SnapshotAssertion,
|
||||
with_hassio: bool,
|
||||
) -> None:
|
||||
"""Test calling the restore command."""
|
||||
await setup_backup_integration(hass, with_hassio=with_hassio)
|
||||
|
||||
client = await hass_ws_client(hass)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.backup.manager.BackupManager.async_restore_backup",
|
||||
):
|
||||
await client.send_json_auto_id({"type": "backup/restore", "slug": "abc123"})
|
||||
assert await client.receive_json() == snapshot
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"access_token_fixture_name",
|
||||
["hass_access_token", "hass_supervisor_access_token"],
|
||||
|
|
220
tests/test_backup_restore.py
Normal file
220
tests/test_backup_restore.py
Normal file
|
@ -0,0 +1,220 @@
|
|||
"""Test methods in backup_restore."""
|
||||
|
||||
from pathlib import Path
|
||||
import tarfile
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant import backup_restore
|
||||
|
||||
from .common import get_test_config_dir
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("side_effect", "content", "expected"),
|
||||
[
|
||||
(FileNotFoundError, "", None),
|
||||
(None, "", backup_restore.RestoreBackupFileContent(backup_file_path=Path(""))),
|
||||
(
|
||||
None,
|
||||
"test;",
|
||||
backup_restore.RestoreBackupFileContent(backup_file_path=Path("test")),
|
||||
),
|
||||
(
|
||||
None,
|
||||
"test;;;;",
|
||||
backup_restore.RestoreBackupFileContent(backup_file_path=Path("test")),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_reading_the_instruction_contents(
|
||||
side_effect: Exception | None,
|
||||
content: str,
|
||||
expected: backup_restore.RestoreBackupFileContent | None,
|
||||
) -> None:
|
||||
"""Test reading the content of the .HA_RESTORE file."""
|
||||
with (
|
||||
mock.patch(
|
||||
"pathlib.Path.read_text",
|
||||
return_value=content,
|
||||
side_effect=side_effect,
|
||||
),
|
||||
):
|
||||
read_content = backup_restore.restore_backup_file_content(
|
||||
Path(get_test_config_dir())
|
||||
)
|
||||
assert read_content == expected
|
||||
|
||||
|
||||
def test_restoring_backup_that_does_not_exist() -> None:
|
||||
"""Test restoring a backup that does not exist."""
|
||||
backup_file_path = Path(get_test_config_dir("backups", "test"))
|
||||
with (
|
||||
mock.patch(
|
||||
"homeassistant.backup_restore.restore_backup_file_content",
|
||||
return_value=backup_restore.RestoreBackupFileContent(
|
||||
backup_file_path=backup_file_path
|
||||
),
|
||||
),
|
||||
mock.patch("pathlib.Path.read_text", side_effect=FileNotFoundError),
|
||||
pytest.raises(
|
||||
ValueError, match=f"Backup file {backup_file_path} does not exist"
|
||||
),
|
||||
):
|
||||
assert backup_restore.restore_backup(Path(get_test_config_dir())) is False
|
||||
|
||||
|
||||
def test_restoring_backup_when_instructions_can_not_be_read() -> None:
|
||||
"""Test restoring a backup when instructions can not be read."""
|
||||
with (
|
||||
mock.patch(
|
||||
"homeassistant.backup_restore.restore_backup_file_content",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
assert backup_restore.restore_backup(Path(get_test_config_dir())) is False
|
||||
|
||||
|
||||
def test_restoring_backup_that_is_not_a_file() -> None:
|
||||
"""Test restoring a backup that is not a file."""
|
||||
backup_file_path = Path(get_test_config_dir("backups", "test"))
|
||||
with (
|
||||
mock.patch(
|
||||
"homeassistant.backup_restore.restore_backup_file_content",
|
||||
return_value=backup_restore.RestoreBackupFileContent(
|
||||
backup_file_path=backup_file_path
|
||||
),
|
||||
),
|
||||
mock.patch("pathlib.Path.exists", return_value=True),
|
||||
mock.patch("pathlib.Path.is_file", return_value=False),
|
||||
pytest.raises(
|
||||
ValueError, match=f"Backup file {backup_file_path} does not exist"
|
||||
),
|
||||
):
|
||||
assert backup_restore.restore_backup(Path(get_test_config_dir())) is False
|
||||
|
||||
|
||||
def test_aborting_for_older_versions() -> None:
|
||||
"""Test that we abort for older versions."""
|
||||
config_dir = Path(get_test_config_dir())
|
||||
backup_file_path = Path(config_dir, "backups", "test.tar")
|
||||
|
||||
def _patched_path_read_text(path: Path, **kwargs):
|
||||
return '{"homeassistant": {"version": "9999.99.99"}, "compressed": false}'
|
||||
|
||||
with (
|
||||
mock.patch(
|
||||
"homeassistant.backup_restore.restore_backup_file_content",
|
||||
return_value=backup_restore.RestoreBackupFileContent(
|
||||
backup_file_path=backup_file_path
|
||||
),
|
||||
),
|
||||
mock.patch("securetar.SecureTarFile"),
|
||||
mock.patch("homeassistant.backup_restore.TemporaryDirectory"),
|
||||
mock.patch("pathlib.Path.read_text", _patched_path_read_text),
|
||||
mock.patch("homeassistant.backup_restore.HA_VERSION", "2013.09.17"),
|
||||
pytest.raises(
|
||||
ValueError,
|
||||
match="You need at least Home Assistant version 9999.99.99 to restore this backup",
|
||||
),
|
||||
):
|
||||
assert backup_restore.restore_backup(config_dir) is True
|
||||
|
||||
|
||||
def test_removal_of_current_configuration_when_restoring() -> None:
|
||||
"""Test that we are removing the current configuration directory."""
|
||||
config_dir = Path(get_test_config_dir())
|
||||
backup_file_path = Path(config_dir, "backups", "test.tar")
|
||||
mock_config_dir = [
|
||||
{"path": Path(config_dir, ".HA_RESTORE"), "is_file": True},
|
||||
{"path": Path(config_dir, ".HA_VERSION"), "is_file": True},
|
||||
{"path": Path(config_dir, "backups"), "is_file": False},
|
||||
{"path": Path(config_dir, "www"), "is_file": False},
|
||||
]
|
||||
|
||||
def _patched_path_read_text(path: Path, **kwargs):
|
||||
return '{"homeassistant": {"version": "2013.09.17"}, "compressed": false}'
|
||||
|
||||
def _patched_path_is_file(path: Path, **kwargs):
|
||||
return [x for x in mock_config_dir if x["path"] == path][0]["is_file"]
|
||||
|
||||
def _patched_path_is_dir(path: Path, **kwargs):
|
||||
return not [x for x in mock_config_dir if x["path"] == path][0]["is_file"]
|
||||
|
||||
with (
|
||||
mock.patch(
|
||||
"homeassistant.backup_restore.restore_backup_file_content",
|
||||
return_value=backup_restore.RestoreBackupFileContent(
|
||||
backup_file_path=backup_file_path
|
||||
),
|
||||
),
|
||||
mock.patch("securetar.SecureTarFile"),
|
||||
mock.patch("homeassistant.backup_restore.TemporaryDirectory"),
|
||||
mock.patch("homeassistant.backup_restore.HA_VERSION", "2013.09.17"),
|
||||
mock.patch("pathlib.Path.read_text", _patched_path_read_text),
|
||||
mock.patch("pathlib.Path.is_file", _patched_path_is_file),
|
||||
mock.patch("pathlib.Path.is_dir", _patched_path_is_dir),
|
||||
mock.patch(
|
||||
"pathlib.Path.iterdir",
|
||||
return_value=[x["path"] for x in mock_config_dir],
|
||||
),
|
||||
mock.patch("pathlib.Path.unlink") as unlink_mock,
|
||||
mock.patch("shutil.rmtree") as rmtreemock,
|
||||
):
|
||||
assert backup_restore.restore_backup(config_dir) is True
|
||||
assert unlink_mock.call_count == 2
|
||||
assert (
|
||||
rmtreemock.call_count == 1
|
||||
) # We have 2 directories in the config directory, but backups is kept
|
||||
|
||||
removed_directories = {Path(call.args[0]) for call in rmtreemock.mock_calls}
|
||||
assert removed_directories == {Path(config_dir, "www")}
|
||||
|
||||
|
||||
def test_extracting_the_contents_of_a_backup_file() -> None:
|
||||
"""Test extracting the contents of a backup file."""
|
||||
config_dir = Path(get_test_config_dir())
|
||||
backup_file_path = Path(config_dir, "backups", "test.tar")
|
||||
|
||||
def _patched_path_read_text(path: Path, **kwargs):
|
||||
return '{"homeassistant": {"version": "2013.09.17"}, "compressed": false}'
|
||||
|
||||
getmembers_mock = mock.MagicMock(
|
||||
return_value=[
|
||||
tarfile.TarInfo(name="data"),
|
||||
tarfile.TarInfo(name="data/../test"),
|
||||
tarfile.TarInfo(name="data/.HA_VERSION"),
|
||||
tarfile.TarInfo(name="data/.storage"),
|
||||
tarfile.TarInfo(name="data/www"),
|
||||
]
|
||||
)
|
||||
extractall_mock = mock.MagicMock()
|
||||
|
||||
with (
|
||||
mock.patch(
|
||||
"homeassistant.backup_restore.restore_backup_file_content",
|
||||
return_value=backup_restore.RestoreBackupFileContent(
|
||||
backup_file_path=backup_file_path
|
||||
),
|
||||
),
|
||||
mock.patch(
|
||||
"tarfile.open",
|
||||
return_value=mock.MagicMock(
|
||||
getmembers=getmembers_mock,
|
||||
extractall=extractall_mock,
|
||||
__iter__=lambda x: iter(getmembers_mock.return_value),
|
||||
),
|
||||
),
|
||||
mock.patch("homeassistant.backup_restore.TemporaryDirectory"),
|
||||
mock.patch("pathlib.Path.read_text", _patched_path_read_text),
|
||||
mock.patch("pathlib.Path.is_file", return_value=False),
|
||||
mock.patch("pathlib.Path.iterdir", return_value=[]),
|
||||
):
|
||||
assert backup_restore.restore_backup(config_dir) is True
|
||||
assert getmembers_mock.call_count == 1
|
||||
assert extractall_mock.call_count == 2
|
||||
|
||||
assert {
|
||||
member.name for member in extractall_mock.mock_calls[-1].kwargs["members"]
|
||||
} == {".HA_VERSION", ".storage", "www"}
|
|
@ -3,7 +3,7 @@
|
|||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
from homeassistant import __main__ as main
|
||||
from homeassistant.const import REQUIRED_PYTHON_VER
|
||||
from homeassistant.const import REQUIRED_PYTHON_VER, RESTART_EXIT_CODE
|
||||
|
||||
|
||||
@patch("sys.exit")
|
||||
|
@ -86,3 +86,13 @@ def test_skip_pip_mutually_exclusive(mock_exit) -> None:
|
|||
assert mock_exit.called is False
|
||||
args = parse_args("--skip-pip", "--skip-pip-packages", "foo")
|
||||
assert mock_exit.called is True
|
||||
|
||||
|
||||
def test_restart_after_backup_restore() -> None:
|
||||
"""Test restarting if we restored a backup."""
|
||||
with (
|
||||
patch("sys.argv", ["python"]),
|
||||
patch("homeassistant.__main__.restore_backup", return_value=True),
|
||||
):
|
||||
exit_code = main.main()
|
||||
assert exit_code == RESTART_EXIT_CODE
|
||||
|
|
Loading…
Add table
Reference in a new issue