Support restoring NumberEntity native_value (#73475)

This commit is contained in:
Erik Montnemery 2022-06-14 19:56:27 +02:00 committed by GitHub
parent 61e4b56e19
commit 23fa19b75a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 189 additions and 6 deletions

View file

@ -3,7 +3,7 @@ from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from contextlib import suppress from contextlib import suppress
from dataclasses import dataclass import dataclasses
from datetime import timedelta from datetime import timedelta
import inspect import inspect
import logging import logging
@ -22,6 +22,7 @@ from homeassistant.helpers.config_validation import ( # noqa: F401
) )
from homeassistant.helpers.entity import Entity, EntityDescription from homeassistant.helpers.entity import Entity, EntityDescription
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.restore_state import ExtraStoredData, RestoreEntity
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.util import temperature as temperature_util from homeassistant.util import temperature as temperature_util
@ -112,7 +113,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return await component.async_unload_entry(entry) return await component.async_unload_entry(entry)
@dataclass @dataclasses.dataclass
class NumberEntityDescription(EntityDescription): class NumberEntityDescription(EntityDescription):
"""A class that describes number entities.""" """A class that describes number entities."""
@ -324,7 +325,7 @@ class NumberEntity(Entity):
@property @property
def native_value(self) -> float | None: def native_value(self) -> float | None:
"""Return the value reported by the sensor.""" """Return the value reported by the number."""
return self._attr_native_value return self._attr_native_value
@property @property
@ -419,3 +420,53 @@ class NumberEntity(Entity):
type(self), type(self),
report_issue, report_issue,
) )
@dataclasses.dataclass
class NumberExtraStoredData(ExtraStoredData):
"""Object to hold extra stored data."""
native_max_value: float | None
native_min_value: float | None
native_step: float | None
native_unit_of_measurement: str | None
native_value: float | None
def as_dict(self) -> dict[str, Any]:
"""Return a dict representation of the number data."""
return dataclasses.asdict(self)
@classmethod
def from_dict(cls, restored: dict[str, Any]) -> NumberExtraStoredData | None:
"""Initialize a stored number state from a dict."""
try:
return cls(
restored["native_max_value"],
restored["native_min_value"],
restored["native_step"],
restored["native_unit_of_measurement"],
restored["native_value"],
)
except KeyError:
return None
class RestoreNumber(NumberEntity, RestoreEntity):
"""Mixin class for restoring previous number state."""
@property
def extra_restore_state_data(self) -> NumberExtraStoredData:
"""Return number specific state data to be restored."""
return NumberExtraStoredData(
self.native_max_value,
self.native_min_value,
self.native_step,
self.native_unit_of_measurement,
self.native_value,
)
async def async_get_last_number_data(self) -> NumberExtraStoredData | None:
"""Restore native_*."""
if (restored_last_extra_data := await self.async_get_last_extra_data()) is None:
return None
return NumberExtraStoredData.from_dict(restored_last_extra_data.as_dict())

View file

@ -21,10 +21,13 @@ from homeassistant.const import (
TEMP_CELSIUS, TEMP_CELSIUS,
TEMP_FAHRENHEIT, TEMP_FAHRENHEIT,
) )
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant, State
from homeassistant.helpers.restore_state import STORAGE_KEY as RESTORE_STATE_KEY
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.util.unit_system import IMPERIAL_SYSTEM, METRIC_SYSTEM from homeassistant.util.unit_system import IMPERIAL_SYSTEM, METRIC_SYSTEM
from tests.common import mock_restore_cache_with_extra_data
class MockDefaultNumberEntity(NumberEntity): class MockDefaultNumberEntity(NumberEntity):
"""Mock NumberEntity device to use in tests. """Mock NumberEntity device to use in tests.
@ -570,3 +573,115 @@ async def test_temperature_conversion(
state = hass.states.get(entity0.entity_id) state = hass.states.get(entity0.entity_id)
assert float(state.state) == pytest.approx(float(state_max_value), rel=0.1) assert float(state.state) == pytest.approx(float(state_max_value), rel=0.1)
RESTORE_DATA = {
"native_max_value": 200.0,
"native_min_value": -10.0,
"native_step": 2.0,
"native_unit_of_measurement": "°F",
"native_value": 123.0,
}
async def test_restore_number_save_state(
hass,
hass_storage,
enable_custom_integrations,
):
"""Test RestoreNumber."""
platform = getattr(hass.components, "test.number")
platform.init(empty=True)
platform.ENTITIES.append(
platform.MockRestoreNumber(
name="Test",
native_max_value=200.0,
native_min_value=-10.0,
native_step=2.0,
native_unit_of_measurement=TEMP_FAHRENHEIT,
native_value=123.0,
device_class=NumberDeviceClass.TEMPERATURE,
)
)
entity0 = platform.ENTITIES[0]
assert await async_setup_component(hass, "number", {"number": {"platform": "test"}})
await hass.async_block_till_done()
# Trigger saving state
await hass.async_stop()
assert len(hass_storage[RESTORE_STATE_KEY]["data"]) == 1
state = hass_storage[RESTORE_STATE_KEY]["data"][0]["state"]
assert state["entity_id"] == entity0.entity_id
extra_data = hass_storage[RESTORE_STATE_KEY]["data"][0]["extra_data"]
assert extra_data == RESTORE_DATA
assert type(extra_data["native_value"]) == float
@pytest.mark.parametrize(
"native_max_value, native_min_value, native_step, native_value, native_value_type, extra_data, device_class, uom",
[
(
200.0,
-10.0,
2.0,
123.0,
float,
RESTORE_DATA,
NumberDeviceClass.TEMPERATURE,
"°F",
),
(100.0, 0.0, None, None, type(None), None, None, None),
(100.0, 0.0, None, None, type(None), {}, None, None),
(100.0, 0.0, None, None, type(None), {"beer": 123}, None, None),
(
100.0,
0.0,
None,
None,
type(None),
{"native_unit_of_measurement": "°F", "native_value": {}},
None,
None,
),
],
)
async def test_restore_number_restore_state(
hass,
enable_custom_integrations,
hass_storage,
native_max_value,
native_min_value,
native_step,
native_value,
native_value_type,
extra_data,
device_class,
uom,
):
"""Test RestoreNumber."""
mock_restore_cache_with_extra_data(hass, ((State("number.test", ""), extra_data),))
platform = getattr(hass.components, "test.number")
platform.init(empty=True)
platform.ENTITIES.append(
platform.MockRestoreNumber(
device_class=device_class,
name="Test",
native_value=None,
)
)
entity0 = platform.ENTITIES[0]
assert await async_setup_component(hass, "number", {"number": {"platform": "test"}})
await hass.async_block_till_done()
assert hass.states.get(entity0.entity_id)
assert entity0.native_max_value == native_max_value
assert entity0.native_min_value == native_min_value
assert entity0.native_step == native_step
assert entity0.native_value == native_value
assert type(entity0.native_value) == native_value_type
assert entity0.native_unit_of_measurement == uom

View file

@ -3,7 +3,7 @@ Provide a mock number platform.
Call init before using it in your tests to ensure clean test data. Call init before using it in your tests to ensure clean test data.
""" """
from homeassistant.components.number import NumberEntity from homeassistant.components.number import NumberEntity, RestoreNumber
from tests.common import MockEntity from tests.common import MockEntity
@ -37,7 +37,7 @@ class MockNumberEntity(MockEntity, NumberEntity):
@property @property
def native_value(self): def native_value(self):
"""Return the native value of this sensor.""" """Return the native value of this number."""
return self._handle("native_value") return self._handle("native_value")
def set_native_value(self, value: float) -> None: def set_native_value(self, value: float) -> None:
@ -45,6 +45,23 @@ class MockNumberEntity(MockEntity, NumberEntity):
self._values["native_value"] = value self._values["native_value"] = value
class MockRestoreNumber(MockNumberEntity, RestoreNumber):
"""Mock RestoreNumber class."""
async def async_added_to_hass(self) -> None:
"""Restore native_*."""
await super().async_added_to_hass()
if (last_number_data := await self.async_get_last_number_data()) is None:
return
self._values["native_max_value"] = last_number_data.native_max_value
self._values["native_min_value"] = last_number_data.native_min_value
self._values["native_step"] = last_number_data.native_step
self._values[
"native_unit_of_measurement"
] = last_number_data.native_unit_of_measurement
self._values["native_value"] = last_number_data.native_value
class LegacyMockNumberEntity(MockEntity, NumberEntity): class LegacyMockNumberEntity(MockEntity, NumberEntity):
"""Mock Number class using deprecated features.""" """Mock Number class using deprecated features."""