hass-core/homeassistant/components/update/__init__.py
Joakim Sørensen 1a78e18eeb
Add update integration (#66552)
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
2022-03-03 08:27:09 -08:00

273 lines
8.1 KiB
Python

"""Support for Update."""
from __future__ import annotations
import asyncio
import dataclasses
import logging
from typing import Any, Protocol
import async_timeout
import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import integration_platform, storage
from homeassistant.helpers.typing import ConfigType
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
# Integration domain; also used as the hass.data key, the storage key and the
# integration-platform name that is looked up in other integrations.
DOMAIN = "update"
# Timeout (seconds) applied to each platform's async_list_updates call.
INFO_CALLBACK_TIMEOUT = 5
# Version handed to the storage helper for the persisted skip list.
STORAGE_VERSION = 1
class IntegrationUpdateFailed(HomeAssistantError):
    """Signal that performing an update did not succeed.

    The ``update/update`` websocket handler in this module catches this
    exception and forwards its string form to the client as the error message.
    """
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
    """Set up the Update integration.

    Stores a fresh UpdateManager under ``hass.data[DOMAIN]`` and registers the
    websocket commands of this module. Always returns True.
    """
    hass.data[DOMAIN] = UpdateManager(hass=hass)

    # Registration order is kept identical to the original: info, update, skip.
    for command_handler in (handle_info, handle_update, handle_skip):
        websocket_api.async_register_command(hass, command_handler)

    return True
@websocket_api.websocket_command({vol.Required("type"): "update/info"})
@websocket_api.async_response
async def handle_info(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Reply to ``update/info`` with the pending updates of every platform.

    The result is whatever ``UpdateManager.gather_updates`` returns, sent back
    under the id of the incoming message.
    """
    update_manager: UpdateManager = hass.data[DOMAIN]
    pending_updates = await update_manager.gather_updates()
    connection.send_result(msg["id"], pending_updates)
@websocket_api.websocket_command(
    {
        vol.Required("type"): "update/skip",
        vol.Required("domain"): str,
        vol.Required("identifier"): str,
        vol.Required("version"): str,
    }
)
@websocket_api.async_response
async def handle_skip(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Handle ``update/skip``: mark one version of one update as skipped.

    Replies with ERR_NOT_FOUND when the manager has no platform registered for
    the requested domain; otherwise records the skip and sends an empty result.
    """
    update_manager: UpdateManager = hass.data[DOMAIN]
    domain = msg["domain"]

    if await update_manager.domain_is_valid(domain):
        update_manager.skip_update(domain, msg["identifier"], msg["version"])
        connection.send_result(msg["id"])
        return

    connection.send_error(
        msg["id"], websocket_api.ERR_NOT_FOUND, "Domain not supported"
    )
@websocket_api.websocket_command(
    {
        vol.Required("type"): "update/update",
        vol.Required("domain"): str,
        vol.Required("identifier"): str,
        vol.Required("version"): str,
        vol.Optional("backup"): bool,
    }
)
@websocket_api.async_response
async def handle_update(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Handle ``update/update``: ask a platform to install an update.

    Outcomes sent back on the connection:
    - ERR_NOT_FOUND when no platform is registered for the domain.
    - ``update_failed`` with the exception text for IntegrationUpdateFailed.
    - ``update_failed`` with "Unknown Error" for any other exception (the
      traceback is logged).
    - An empty result on success.
    """
    update_manager: UpdateManager = hass.data[DOMAIN]
    domain = msg["domain"]

    if not await update_manager.domain_is_valid(domain):
        connection.send_error(
            msg["id"],
            websocket_api.ERR_NOT_FOUND,
            f"{domain} is not a supported domain",
        )
        return

    identifier = msg["identifier"]
    version = msg["version"]

    try:
        await update_manager.perform_update(
            domain=domain,
            identifier=identifier,
            version=version,
            # "backup" is optional in the schema; None is forwarded if absent.
            backup=msg.get("backup"),
        )
    except IntegrationUpdateFailed as err:
        connection.send_error(msg["id"], "update_failed", str(err))
        return
    except Exception:  # pylint: disable=broad-except
        # Boundary handler: log the traceback, keep details from the client.
        _LOGGER.exception(
            "Update of %s to version %s failed",
            identifier,
            version,
        )
        connection.send_error(msg["id"], "update_failed", "Unknown Error")
        return

    # Only reached when perform_update completed without raising.
    connection.send_result(msg["id"])
class UpdatePlatformProtocol(Protocol):
    """Structural interface that an integration's update platform provides.

    The update manager calls these two coroutines on every registered platform.
    """

    async def async_list_updates(self, hass: HomeAssistant) -> list[UpdateDescription]:
        """Return the updates the integration currently has available."""

    async def async_perform_update(
        self, hass: HomeAssistant, identifier: str, version: str, **kwargs: Any
    ) -> None:
        """Install ``version`` of the update named by ``identifier``.

        Extra keyword arguments are passed through from the caller unchanged.
        """
@dataclasses.dataclass
class UpdateDescription:
    """Describe a single available update reported by a platform."""

    # Required fields. The identifier and the available version, together with
    # the integration domain, form the key used to remember skipped updates.
    identifier: str
    name: str
    current_version: str
    available_version: str

    # Optional metadata; each defaults to None when the platform omits it.
    changelog_content: str | None = None
    changelog_url: str | None = None
    icon_url: str | None = None

    # NOTE(review): presumably pairs with the optional "backup" flag of the
    # update/update websocket command - confirm against the platforms.
    supports_backup: bool = False
class UpdateManager:
    """Update manager for the update integration.

    Keeps the update platforms registered by other integrations and the set of
    skipped updates. Both are loaded lazily, once, on first use.
    """

    def __init__(self, hass: HomeAssistant) -> None:
        """Initialize the update manager."""
        self._hass = hass
        self._store = storage.Store(
            hass=hass,
            version=STORAGE_VERSION,
            key=DOMAIN,
        )
        # Keys built by _skip_key() for updates the user chose to skip.
        self._skip: set[str] = set()
        # Registered platforms, keyed by integration domain.
        self._platforms: dict[str, UpdatePlatformProtocol] = {}
        self._loaded = False
        # Serializes the lazy load: several websocket commands can arrive
        # before the first load finishes, and running it twice would let the
        # later run overwrite skips recorded in between.
        self._load_lock = asyncio.Lock()

    @staticmethod
    def _skip_key(domain: str, identifier: str, version: str) -> str:
        """Build the key under which a skipped update is remembered."""
        return f"{domain}_{identifier}_{version}"

    async def add_platform(
        self,
        hass: HomeAssistant,
        integration_domain: str,
        platform: UpdatePlatformProtocol,
    ) -> None:
        """Add a platform to the update manager.

        Used as the callback for integration platform processing; a later
        registration for the same domain replaces the earlier one.
        """
        self._platforms[integration_domain] = platform

    async def _async_ensure_loaded(self) -> None:
        """Run the one-time load unless it has already completed."""
        if self._loaded:
            return
        async with self._load_lock:
            # Re-check: another task may have finished loading while we waited.
            if not self._loaded:
                await self._load()

    async def _load(self) -> None:
        """Load platforms and data from storage."""
        await integration_platform.async_process_integration_platforms(
            self._hass, DOMAIN, self.add_platform
        )
        from_storage = await self._store.async_load()
        if isinstance(from_storage, dict):
            # Tolerate stored data without a "skipped" entry instead of
            # raising KeyError, and merge so nothing already recorded is lost.
            self._skip.update(from_storage.get("skipped") or ())
        self._loaded = True

    async def gather_updates(self) -> list[dict[str, Any]]:
        """Gather updates.

        Queries every registered platform concurrently and returns one dict
        per pending update: the description's fields plus a "domain" key.
        Platforms that failed or timed out contribute nothing, and updates
        whose version was skipped are left out.
        """
        await self._async_ensure_loaded()

        # Snapshot the platforms so that domains and results stay paired even
        # if another platform gets registered while the queries are running.
        platforms = list(self._platforms.items())
        results = await asyncio.gather(
            *(
                self._get_integration_info(integration_domain, platform)
                for integration_domain, platform in platforms
            )
        )

        pending: list[dict[str, Any]] = []
        for (integration_domain, _), descriptions in zip(platforms, results):
            if descriptions is None:
                continue
            for description in descriptions:
                key = self._skip_key(
                    integration_domain,
                    description.identifier,
                    description.available_version,
                )
                if key not in self._skip:
                    pending.append(
                        {
                            "domain": integration_domain,
                            **dataclasses.asdict(description),
                        }
                    )
        return pending

    async def domain_is_valid(self, domain: str) -> bool:
        """Return if the domain is valid, i.e. has a registered platform."""
        await self._async_ensure_loaded()
        return domain in self._platforms

    @callback
    def _data_to_save(self) -> dict[str, Any]:
        """Return the data to persist in storage."""
        return {"skipped": list(self._skip)}

    async def perform_update(
        self,
        domain: str,
        identifier: str,
        version: str,
        **kwargs: Any,
    ) -> None:
        """Perform an update.

        Delegates to the platform registered for ``domain``, forwarding any
        extra keyword arguments. Raises KeyError when no platform is
        registered for ``domain``; exceptions from the platform propagate.
        """
        await self._platforms[domain].async_perform_update(
            hass=self._hass,
            identifier=identifier,
            version=version,
            **kwargs,
        )

    @callback
    def skip_update(self, domain: str, identifier: str, version: str) -> None:
        """Skip an update and schedule a delayed save of the skip list."""
        self._skip.add(self._skip_key(domain, identifier, version))
        self._store.async_delay_save(self._data_to_save, 60)

    async def _get_integration_info(
        self,
        integration_domain: str,
        platform: UpdatePlatformProtocol,
    ) -> list[UpdateDescription] | None:
        """Get integration update details.

        Returns the platform's update list, or None when the call times out
        after INFO_CALLBACK_TIMEOUT or raises; both cases are logged so that
        one misbehaving platform does not break the whole listing.
        """
        try:
            async with async_timeout.timeout(INFO_CALLBACK_TIMEOUT):
                return await platform.async_list_updates(hass=self._hass)
        except asyncio.TimeoutError:
            _LOGGER.warning("Timeout while getting updates from %s", integration_domain)
        except Exception:  # pylint: disable=broad-except
            _LOGGER.exception("Error fetching info from %s", integration_domain)
        return None