Show OTA update progress for Shelly gen2 devices (#99534)

* Show OTA update progress

* Use an event listener instead of a dispatcher

* Add tests

* Fix name

* Improve tests coverage

* Fix subscribe/unsubscribe logic

* Use async_on_remove()
This commit is contained in:
Maciej Bieniek 2023-09-06 06:17:45 +00:00 committed by GitHub
parent 4f05e61072
commit d9a1ebafdd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 159 additions and 20 deletions

View file

@ -181,3 +181,8 @@ PUSH_UPDATE_ISSUE_ID = "push_update_{unique}"
NOT_CALIBRATED_ISSUE_ID = "not_calibrated_{unique}" NOT_CALIBRATED_ISSUE_ID = "not_calibrated_{unique}"
GAS_VALVE_OPEN_STATES = ("opening", "opened") GAS_VALVE_OPEN_STATES = ("opening", "opened")
OTA_BEGIN = "ota_begin"
OTA_ERROR = "ota_error"
OTA_PROGRESS = "ota_progress"
OTA_SUCCESS = "ota_success"

View file

@ -44,6 +44,10 @@ from .const import (
LOGGER, LOGGER,
MAX_PUSH_UPDATE_FAILURES, MAX_PUSH_UPDATE_FAILURES,
MODELS_SUPPORTING_LIGHT_EFFECTS, MODELS_SUPPORTING_LIGHT_EFFECTS,
OTA_BEGIN,
OTA_ERROR,
OTA_PROGRESS,
OTA_SUCCESS,
PUSH_UPDATE_ISSUE_ID, PUSH_UPDATE_ISSUE_ID,
REST_SENSORS_UPDATE_INTERVAL, REST_SENSORS_UPDATE_INTERVAL,
RPC_INPUTS_EVENTS_TYPES, RPC_INPUTS_EVENTS_TYPES,
@ -384,6 +388,7 @@ class ShellyRpcCoordinator(ShellyCoordinatorBase[RpcDevice]):
self._disconnected_callbacks: list[CALLBACK_TYPE] = [] self._disconnected_callbacks: list[CALLBACK_TYPE] = []
self._connection_lock = asyncio.Lock() self._connection_lock = asyncio.Lock()
self._event_listeners: list[Callable[[dict[str, Any]], None]] = [] self._event_listeners: list[Callable[[dict[str, Any]], None]] = []
self._ota_event_listeners: list[Callable[[dict[str, Any]], None]] = []
entry.async_on_unload( entry.async_on_unload(
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._handle_ha_stop) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._handle_ha_stop)
@ -408,6 +413,19 @@ class ShellyRpcCoordinator(ShellyCoordinatorBase[RpcDevice]):
return True return True
@callback
def async_subscribe_ota_events(
self, ota_event_callback: Callable[[dict[str, Any]], None]
) -> CALLBACK_TYPE:
"""Subscribe to OTA events."""
def _unsubscribe() -> None:
self._ota_event_listeners.remove(ota_event_callback)
self._ota_event_listeners.append(ota_event_callback)
return _unsubscribe
@callback @callback
def async_subscribe_events( def async_subscribe_events(
self, event_callback: Callable[[dict[str, Any]], None] self, event_callback: Callable[[dict[str, Any]], None]
@ -461,6 +479,9 @@ class ShellyRpcCoordinator(ShellyCoordinatorBase[RpcDevice]):
ATTR_GENERATION: 2, ATTR_GENERATION: 2,
}, },
) )
elif event_type in (OTA_BEGIN, OTA_ERROR, OTA_PROGRESS, OTA_SUCCESS):
for event_callback in self._ota_event_listeners:
event_callback(event)
async def _async_update_data(self) -> None: async def _async_update_data(self) -> None:
"""Fetch data.""" """Fetch data."""

View file

@ -18,12 +18,12 @@ from homeassistant.components.update import (
) )
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import EntityCategory from homeassistant.const import EntityCategory
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.restore_state import RestoreEntity
from .const import CONF_SLEEP_PERIOD from .const import CONF_SLEEP_PERIOD, OTA_BEGIN, OTA_ERROR, OTA_PROGRESS, OTA_SUCCESS
from .coordinator import ShellyBlockCoordinator, ShellyRpcCoordinator from .coordinator import ShellyBlockCoordinator, ShellyRpcCoordinator
from .entity import ( from .entity import (
RestEntityDescription, RestEntityDescription,
@ -229,7 +229,28 @@ class RpcUpdateEntity(ShellyRpcAttributeEntity, UpdateEntity):
) -> None: ) -> None:
"""Initialize update entity.""" """Initialize update entity."""
super().__init__(coordinator, key, attribute, description) super().__init__(coordinator, key, attribute, description)
self._in_progress_old_version: str | None = None self._ota_in_progress: bool = False
async def async_added_to_hass(self) -> None:
"""When entity is added to hass."""
await super().async_added_to_hass()
self.async_on_remove(
self.coordinator.async_subscribe_ota_events(self._ota_progress_callback)
)
@callback
def _ota_progress_callback(self, event: dict[str, Any]) -> None:
"""Handle device OTA progress."""
if self._ota_in_progress:
event_type = event["event"]
if event_type == OTA_BEGIN:
self._attr_in_progress = 0
elif event_type == OTA_PROGRESS:
self._attr_in_progress = event["progress_percent"]
elif event_type in (OTA_ERROR, OTA_SUCCESS):
self._attr_in_progress = False
self._ota_in_progress = False
self.async_write_ha_state()
@property @property
def installed_version(self) -> str | None: def installed_version(self) -> str | None:
@ -245,16 +266,10 @@ class RpcUpdateEntity(ShellyRpcAttributeEntity, UpdateEntity):
return self.installed_version return self.installed_version
@property
def in_progress(self) -> bool:
"""Update installation in progress."""
return self._in_progress_old_version == self.installed_version
async def async_install( async def async_install(
self, version: str | None, backup: bool, **kwargs: Any self, version: str | None, backup: bool, **kwargs: Any
) -> None: ) -> None:
"""Install the latest firmware version.""" """Install the latest firmware version."""
self._in_progress_old_version = self.installed_version
beta = self.entity_description.beta beta = self.entity_description.beta
update_data = self.coordinator.device.status["sys"]["available_updates"] update_data = self.coordinator.device.status["sys"]["available_updates"]
LOGGER.debug("OTA update service - update_data: %s", update_data) LOGGER.debug("OTA update service - update_data: %s", update_data)
@ -280,6 +295,7 @@ class RpcUpdateEntity(ShellyRpcAttributeEntity, UpdateEntity):
except InvalidAuthError: except InvalidAuthError:
self.coordinator.entry.async_start_reauth(self.hass) self.coordinator.entry.async_start_reauth(self.hass)
else: else:
self._ota_in_progress = True
LOGGER.debug("OTA update call successful") LOGGER.debug("OTA update call successful")

View file

@ -29,6 +29,7 @@ from homeassistant.helpers.entity_registry import async_get
from . import ( from . import (
MOCK_MAC, MOCK_MAC,
init_integration, init_integration,
inject_rpc_device_event,
mock_rest_update, mock_rest_update,
register_device, register_device,
register_entity, register_entity,
@ -222,6 +223,7 @@ async def test_block_update_auth_error(
async def test_rpc_update(hass: HomeAssistant, mock_rpc_device, monkeypatch) -> None: async def test_rpc_update(hass: HomeAssistant, mock_rpc_device, monkeypatch) -> None:
"""Test RPC device update entity.""" """Test RPC device update entity."""
entity_id = "update.test_name_firmware_update"
monkeypatch.setitem(mock_rpc_device.shelly, "ver", "1") monkeypatch.setitem(mock_rpc_device.shelly, "ver", "1")
monkeypatch.setitem( monkeypatch.setitem(
mock_rpc_device.status["sys"], mock_rpc_device.status["sys"],
@ -232,7 +234,7 @@ async def test_rpc_update(hass: HomeAssistant, mock_rpc_device, monkeypatch) ->
) )
await init_integration(hass, 2) await init_integration(hass, 2)
state = hass.states.get("update.test_name_firmware_update") state = hass.states.get(entity_id)
assert state.state == STATE_ON assert state.state == STATE_ON
assert state.attributes[ATTR_INSTALLED_VERSION] == "1" assert state.attributes[ATTR_INSTALLED_VERSION] == "1"
assert state.attributes[ATTR_LATEST_VERSION] == "2" assert state.attributes[ATTR_LATEST_VERSION] == "2"
@ -243,21 +245,68 @@ async def test_rpc_update(hass: HomeAssistant, mock_rpc_device, monkeypatch) ->
await hass.services.async_call( await hass.services.async_call(
UPDATE_DOMAIN, UPDATE_DOMAIN,
SERVICE_INSTALL, SERVICE_INSTALL,
{ATTR_ENTITY_ID: "update.test_name_firmware_update"}, {ATTR_ENTITY_ID: entity_id},
blocking=True, blocking=True,
) )
inject_rpc_device_event(
monkeypatch,
mock_rpc_device,
{
"events": [
{
"event": "ota_begin",
"id": 1,
"ts": 1668522399.2,
}
],
"ts": 1668522399.2,
},
)
assert mock_rpc_device.trigger_ota_update.call_count == 1 assert mock_rpc_device.trigger_ota_update.call_count == 1
state = hass.states.get("update.test_name_firmware_update") state = hass.states.get(entity_id)
assert state.state == STATE_ON assert state.state == STATE_ON
assert state.attributes[ATTR_INSTALLED_VERSION] == "1" assert state.attributes[ATTR_INSTALLED_VERSION] == "1"
assert state.attributes[ATTR_LATEST_VERSION] == "2" assert state.attributes[ATTR_LATEST_VERSION] == "2"
assert state.attributes[ATTR_IN_PROGRESS] is True assert state.attributes[ATTR_IN_PROGRESS] == 0
inject_rpc_device_event(
monkeypatch,
mock_rpc_device,
{
"events": [
{
"event": "ota_progress",
"id": 1,
"ts": 1668522399.2,
"progress_percent": 50,
}
],
"ts": 1668522399.2,
},
)
assert hass.states.get(entity_id).attributes[ATTR_IN_PROGRESS] == 50
inject_rpc_device_event(
monkeypatch,
mock_rpc_device,
{
"events": [
{
"event": "ota_success",
"id": 1,
"ts": 1668522399.2,
}
],
"ts": 1668522399.2,
},
)
monkeypatch.setitem(mock_rpc_device.shelly, "ver", "2") monkeypatch.setitem(mock_rpc_device.shelly, "ver", "2")
mock_rpc_device.mock_update() mock_rpc_device.mock_update()
state = hass.states.get("update.test_name_firmware_update") state = hass.states.get(entity_id)
assert state.state == STATE_OFF assert state.state == STATE_OFF
assert state.attributes[ATTR_INSTALLED_VERSION] == "2" assert state.attributes[ATTR_INSTALLED_VERSION] == "2"
assert state.attributes[ATTR_LATEST_VERSION] == "2" assert state.attributes[ATTR_LATEST_VERSION] == "2"
@ -401,6 +450,7 @@ async def test_rpc_beta_update(
suggested_object_id="test_name_beta_firmware_update", suggested_object_id="test_name_beta_firmware_update",
disabled_by=None, disabled_by=None,
) )
entity_id = "update.test_name_beta_firmware_update"
monkeypatch.setitem(mock_rpc_device.shelly, "ver", "1") monkeypatch.setitem(mock_rpc_device.shelly, "ver", "1")
monkeypatch.setitem( monkeypatch.setitem(
mock_rpc_device.status["sys"], mock_rpc_device.status["sys"],
@ -412,7 +462,7 @@ async def test_rpc_beta_update(
) )
await init_integration(hass, 2) await init_integration(hass, 2)
state = hass.states.get("update.test_name_beta_firmware_update") state = hass.states.get(entity_id)
assert state.state == STATE_OFF assert state.state == STATE_OFF
assert state.attributes[ATTR_INSTALLED_VERSION] == "1" assert state.attributes[ATTR_INSTALLED_VERSION] == "1"
assert state.attributes[ATTR_LATEST_VERSION] == "1" assert state.attributes[ATTR_LATEST_VERSION] == "1"
@ -428,7 +478,7 @@ async def test_rpc_beta_update(
) )
await mock_rest_update(hass, freezer) await mock_rest_update(hass, freezer)
state = hass.states.get("update.test_name_beta_firmware_update") state = hass.states.get(entity_id)
assert state.state == STATE_ON assert state.state == STATE_ON
assert state.attributes[ATTR_INSTALLED_VERSION] == "1" assert state.attributes[ATTR_INSTALLED_VERSION] == "1"
assert state.attributes[ATTR_LATEST_VERSION] == "2b" assert state.attributes[ATTR_LATEST_VERSION] == "2b"
@ -437,21 +487,68 @@ async def test_rpc_beta_update(
await hass.services.async_call( await hass.services.async_call(
UPDATE_DOMAIN, UPDATE_DOMAIN,
SERVICE_INSTALL, SERVICE_INSTALL,
{ATTR_ENTITY_ID: "update.test_name_beta_firmware_update"}, {ATTR_ENTITY_ID: entity_id},
blocking=True, blocking=True,
) )
inject_rpc_device_event(
monkeypatch,
mock_rpc_device,
{
"events": [
{
"event": "ota_begin",
"id": 1,
"ts": 1668522399.2,
}
],
"ts": 1668522399.2,
},
)
assert mock_rpc_device.trigger_ota_update.call_count == 1 assert mock_rpc_device.trigger_ota_update.call_count == 1
state = hass.states.get("update.test_name_beta_firmware_update") state = hass.states.get(entity_id)
assert state.state == STATE_ON assert state.state == STATE_ON
assert state.attributes[ATTR_INSTALLED_VERSION] == "1" assert state.attributes[ATTR_INSTALLED_VERSION] == "1"
assert state.attributes[ATTR_LATEST_VERSION] == "2b" assert state.attributes[ATTR_LATEST_VERSION] == "2b"
assert state.attributes[ATTR_IN_PROGRESS] is True assert state.attributes[ATTR_IN_PROGRESS] == 0
inject_rpc_device_event(
monkeypatch,
mock_rpc_device,
{
"events": [
{
"event": "ota_progress",
"id": 1,
"ts": 1668522399.2,
"progress_percent": 40,
}
],
"ts": 1668522399.2,
},
)
assert hass.states.get(entity_id).attributes[ATTR_IN_PROGRESS] == 40
inject_rpc_device_event(
monkeypatch,
mock_rpc_device,
{
"events": [
{
"event": "ota_success",
"id": 1,
"ts": 1668522399.2,
}
],
"ts": 1668522399.2,
},
)
monkeypatch.setitem(mock_rpc_device.shelly, "ver", "2b") monkeypatch.setitem(mock_rpc_device.shelly, "ver", "2b")
await mock_rest_update(hass, freezer) await mock_rest_update(hass, freezer)
state = hass.states.get("update.test_name_beta_firmware_update") state = hass.states.get(entity_id)
assert state.state == STATE_OFF assert state.state == STATE_OFF
assert state.attributes[ATTR_INSTALLED_VERSION] == "2b" assert state.attributes[ATTR_INSTALLED_VERSION] == "2b"
assert state.attributes[ATTR_LATEST_VERSION] == "2b" assert state.attributes[ATTR_LATEST_VERSION] == "2b"