diff --git a/homeassistant/components/input_number/__init__.py b/homeassistant/components/input_number/__init__.py index a4438020886..deedfdab2de 100644 --- a/homeassistant/components/input_number/__init__.py +++ b/homeassistant/components/input_number/__init__.py @@ -1,20 +1,27 @@ """Support to set a numeric value from a slider or text box.""" import logging +import typing import voluptuous as vol from homeassistant.const import ( + ATTR_EDITABLE, ATTR_MODE, ATTR_UNIT_OF_MEASUREMENT, CONF_ICON, + CONF_ID, CONF_MODE, CONF_NAME, SERVICE_RELOAD, ) +from homeassistant.core import callback +from homeassistant.helpers import collection, entity_registry import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.restore_state import RestoreEntity import homeassistant.helpers.service +from homeassistant.helpers.storage import Store +from homeassistant.helpers.typing import ConfigType, HomeAssistantType, ServiceCallType _LOGGER = logging.getLogger(__name__) @@ -54,6 +61,28 @@ def _cv_input_number(cfg): return cfg +CREATE_FIELDS = { + vol.Required(CONF_NAME): vol.All(str, vol.Length(min=1)), + vol.Required(CONF_MIN): vol.Coerce(float), + vol.Required(CONF_MAX): vol.Coerce(float), + vol.Optional(CONF_INITIAL): vol.Coerce(float), + vol.Optional(CONF_STEP, default=1): vol.All(vol.Coerce(float), vol.Range(min=1e-3)), + vol.Optional(CONF_ICON): cv.icon, + vol.Optional(ATTR_UNIT_OF_MEASUREMENT): cv.string, + vol.Optional(CONF_MODE, default=MODE_SLIDER): vol.In([MODE_BOX, MODE_SLIDER]), +} + +UPDATE_FIELDS = { + vol.Optional(CONF_NAME): cv.string, + vol.Optional(CONF_MIN): vol.Coerce(float), + vol.Optional(CONF_MAX): vol.Coerce(float), + vol.Optional(CONF_INITIAL): vol.Coerce(float), + vol.Optional(CONF_STEP): vol.All(vol.Coerce(float), vol.Range(min=1e-3)), + vol.Optional(CONF_ICON): cv.icon, + vol.Optional(ATTR_UNIT_OF_MEASUREMENT): cv.string, + vol.Optional(CONF_MODE): vol.In([MODE_BOX, MODE_SLIDER]), +} + CONFIG_SCHEMA = vol.Schema( { DOMAIN: cv.schema_with_slug_keys( @@ -80,22 +109,61 @@ CONFIG_SCHEMA = vol.Schema( extra=vol.ALLOW_EXTRA, ) RELOAD_SERVICE_SCHEMA = vol.Schema({}) +STORAGE_KEY = DOMAIN +STORAGE_VERSION = 1 -async def async_setup(hass, config): +async def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool: """Set up an input slider.""" component = EntityComponent(_LOGGER, DOMAIN, hass) + id_manager = collection.IDManager() - entities = await _async_process_config(config) + yaml_collection = collection.YamlCollection( + logging.getLogger(f"{__name__}.yaml_collection"), id_manager + ) + collection.attach_entity_component_collection( + component, yaml_collection, InputNumber.from_yaml + ) - async def reload_service_handler(service_call): - """Remove all entities and load new ones from config.""" - conf = await component.async_prepare_reload() - if conf is None: + storage_collection = NumberStorageCollection( + Store(hass, STORAGE_VERSION, STORAGE_KEY), + logging.getLogger(f"{__name__}.storage_collection"), + id_manager, + ) + collection.attach_entity_component_collection( + component, storage_collection, InputNumber + ) + + await yaml_collection.async_load( + [{CONF_ID: id_, **(conf or {})} for id_, conf in config[DOMAIN].items()] + ) + await storage_collection.async_load() + + collection.StorageCollectionWebsocket( + storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS + ).async_setup(hass) + + async def _collection_changed( + change_type: str, item_id: str, config: typing.Optional[typing.Dict] + ) -> None: + """Handle a collection change: clean up entity registry on removals.""" + if change_type != collection.CHANGE_REMOVED: return - new_entities = await _async_process_config(conf) - if new_entities: - await component.async_add_entities(new_entities) + + ent_reg = await entity_registry.async_get_registry(hass) + ent_reg.async_remove(ent_reg.async_get_entity_id(DOMAIN, DOMAIN, item_id)) + + yaml_collection.async_add_listener(_collection_changed) + storage_collection.async_add_listener(_collection_changed) + + async def reload_service_handler(service_call: ServiceCallType) -> None: + """Reload yaml entities.""" + conf = await component.async_prepare_reload(skip_reset=True) + if conf is None: + conf = {DOMAIN: {}} + await yaml_collection.async_load( + [{CONF_ID: id_, **conf} for id_, conf in conf[DOMAIN].items()] + ) homeassistant.helpers.service.async_register_admin_service( hass, @@ -115,86 +183,102 @@ async def async_setup(hass, config): component.async_register_entity_service(SERVICE_DECREMENT, {}, "async_decrement") - if entities: - await component.async_add_entities(entities) return True -async def _async_process_config(config): - """Process config and create list of entities.""" - entities = [] +class NumberStorageCollection(collection.StorageCollection): + """Input storage based collection.""" - for object_id, cfg in config[DOMAIN].items(): - name = cfg.get(CONF_NAME) - minimum = cfg.get(CONF_MIN) - maximum = cfg.get(CONF_MAX) - initial = cfg.get(CONF_INITIAL) - step = cfg.get(CONF_STEP) - icon = cfg.get(CONF_ICON) - unit = cfg.get(ATTR_UNIT_OF_MEASUREMENT) - mode = cfg.get(CONF_MODE) + CREATE_SCHEMA = vol.Schema(vol.All(CREATE_FIELDS, _cv_input_number)) + UPDATE_SCHEMA = vol.Schema(UPDATE_FIELDS) - entities.append( - InputNumber( - object_id, name, initial, minimum, maximum, step, icon, unit, mode - ) - ) + async def _process_create_data(self, data: typing.Dict) -> typing.Dict: + """Validate the config is valid.""" + return self.CREATE_SCHEMA(data) - return entities + @callback + def _get_suggested_id(self, info: typing.Dict) -> str: + """Suggest an ID based on the config.""" + return info[CONF_NAME] + + async def _update_data(self, data: dict, update_data: typing.Dict) -> typing.Dict: + """Return a new updated data object.""" + update_data = self.UPDATE_SCHEMA(update_data) + return _cv_input_number({**data, **update_data}) class InputNumber(RestoreEntity): """Representation of a slider.""" - def __init__( - self, object_id, name, initial, minimum, maximum, step, icon, unit, mode - ): + def __init__(self, config: typing.Dict): """Initialize an input number.""" - self.entity_id = ENTITY_ID_FORMAT.format(object_id) - self._name = name - self._current_value = initial - self._initial = initial - self._minimum = minimum - self._maximum = maximum - self._step = step - self._icon = icon - self._unit = unit - self._mode = mode + self._config = config + self.editable = True + self._current_value = config.get(CONF_INITIAL) + + @classmethod + def from_yaml(cls, config: typing.Dict) -> "InputNumber": + """Return entity instance initialized from yaml storage.""" + input_num = cls(config) + input_num.entity_id = ENTITY_ID_FORMAT.format(config[CONF_ID]) + input_num.editable = False + return input_num @property def should_poll(self): """If entity should be polled.""" return False + @property + def _minimum(self) -> float: + """Return minimum allowed value.""" + return self._config[CONF_MIN] + + @property + def _maximum(self) -> float: + """Return maximum allowed value.""" + return self._config[CONF_MAX] + @property def name(self): """Return the name of the input slider.""" - return self._name + return self._config.get(CONF_NAME) @property def icon(self): """Return the icon to be used for this entity.""" - return self._icon + return self._config.get(CONF_ICON) @property def state(self): """Return the state of the component.""" return self._current_value + @property + def _step(self) -> int: + """Return entity's increment/decrement step.""" + return self._config[CONF_STEP] + @property def unit_of_measurement(self): """Return the unit the value is expressed in.""" - return self._unit + return self._config.get(ATTR_UNIT_OF_MEASUREMENT) + + @property + def unique_id(self) -> typing.Optional[str]: + """Return unique id of the entity.""" + return self._config[CONF_ID] @property def state_attributes(self): """Return the state attributes.""" return { - ATTR_INITIAL: self._initial, + ATTR_INITIAL: self._config.get(CONF_INITIAL), + ATTR_EDITABLE: self.editable, ATTR_MIN: self._minimum, ATTR_MAX: self._maximum, ATTR_STEP: self._step, - ATTR_MODE: self._mode, + ATTR_MODE: self._config[CONF_MODE], } async def async_added_to_hass(self): @@ -224,7 +308,7 @@ class InputNumber(RestoreEntity): ) return self._current_value = num_value - await self.async_update_ha_state() + self.async_write_ha_state() async def async_increment(self): """Increment value.""" @@ -238,7 +322,7 @@ class InputNumber(RestoreEntity): ) return self._current_value = new_value - await self.async_update_ha_state() + self.async_write_ha_state() async def async_decrement(self): """Decrement value.""" @@ -252,4 +336,12 @@ class InputNumber(RestoreEntity): ) return self._current_value = new_value - await self.async_update_ha_state() + self.async_write_ha_state() + + async def async_update_config(self, config: typing.Dict) -> None: + """Handle when the config is updated.""" + self._config = config + # just in case min/max values changed + self._current_value = min(self._current_value, self._maximum) + self._current_value = max(self._current_value, self._minimum) + self.async_write_ha_state() diff --git a/tests/components/input_number/test_init.py b/tests/components/input_number/test_init.py index 6d032b639cf..f9763168354 100644 --- a/tests/components/input_number/test_init.py +++ b/tests/components/input_number/test_init.py @@ -12,15 +12,57 @@ from homeassistant.components.input_number import ( SERVICE_RELOAD, SERVICE_SET_VALUE, ) -from homeassistant.const import ATTR_ENTITY_ID +from homeassistant.const import ( + ATTR_EDITABLE, + ATTR_ENTITY_ID, + ATTR_FRIENDLY_NAME, + ATTR_NAME, +) from homeassistant.core import Context, CoreState, State from homeassistant.exceptions import Unauthorized +from homeassistant.helpers import entity_registry from homeassistant.loader import bind_hass from homeassistant.setup import async_setup_component from tests.common import mock_restore_cache +@pytest.fixture +def storage_setup(hass, hass_storage): + """Storage setup.""" + + async def _storage(items=None, config=None): + if items is None: + hass_storage[DOMAIN] = { + "key": DOMAIN, + "version": 1, + "data": { + "items": [ + { + "id": "from_storage", + "initial": 10, + "name": "from storage", + "max": 100, + "min": 0, + "step": 1, + "mode": "slider", + } + ] + }, + } + else: + hass_storage[DOMAIN] = { + "key": DOMAIN, + "version": 1, + "data": {"items": items}, + } + if config is None: + config = {DOMAIN: {}} + return await async_setup_component(hass, DOMAIN, config) + + return _storage + + @bind_hass def set_value(hass, entity_id, value): """Set input_number to value. @@ -258,19 +300,33 @@ async def test_input_number_context(hass, hass_admin_user): async def test_reload(hass, hass_admin_user, hass_read_only_user): """Test reload service.""" count_start = len(hass.states.async_entity_ids()) + ent_reg = await entity_registry.async_get_registry(hass) assert await async_setup_component( - hass, DOMAIN, {DOMAIN: {"test_1": {"initial": 50, "min": 0, "max": 51}}} + hass, + DOMAIN, + { + DOMAIN: { + "test_1": {"initial": 50, "min": 0, "max": 51}, + "test_3": {"initial": 10, "min": 0, "max": 15}, + } + }, ) - assert count_start + 1 == len(hass.states.async_entity_ids()) + assert count_start + 2 == len(hass.states.async_entity_ids()) state_1 = hass.states.get("input_number.test_1") state_2 = hass.states.get("input_number.test_2") + state_3 = hass.states.get("input_number.test_3") assert state_1 is not None assert state_2 is None + assert state_3 is not None assert 50 == float(state_1.state) + assert 10 == float(state_3.state) + assert ent_reg.async_get_entity_id(DOMAIN, DOMAIN, "test_1") is not None + assert ent_reg.async_get_entity_id(DOMAIN, DOMAIN, "test_2") is None + assert ent_reg.async_get_entity_id(DOMAIN, DOMAIN, "test_3") is not None with patch( "homeassistant.config.load_yaml_config_file", @@ -302,8 +358,189 @@ async def test_reload(hass, hass_admin_user, hass_read_only_user): state_1 = hass.states.get("input_number.test_1") state_2 = hass.states.get("input_number.test_2") + state_3 = hass.states.get("input_number.test_3") assert state_1 is not None assert state_2 is not None - assert 40 == float(state_1.state) + assert state_3 is None + assert 50 == float(state_1.state) assert 20 == float(state_2.state) + assert ent_reg.async_get_entity_id(DOMAIN, DOMAIN, "test_1") is not None + assert ent_reg.async_get_entity_id(DOMAIN, DOMAIN, "test_2") is not None + assert ent_reg.async_get_entity_id(DOMAIN, DOMAIN, "test_3") is None + + +async def test_load_from_storage(hass, storage_setup): + """Test set up from storage.""" + assert await storage_setup() + state = hass.states.get(f"{DOMAIN}.from_storage") + assert float(state.state) == 10 + assert state.attributes.get(ATTR_FRIENDLY_NAME) == "from storage" + assert state.attributes.get(ATTR_EDITABLE) + + +async def test_editable_state_attribute(hass, storage_setup): + """Test editable attribute.""" + assert await storage_setup( + config={ + DOMAIN: { + "from_yaml": { + "min": 1, + "max": 10, + "initial": 5, + "step": 1, + "mode": "slider", + } + } + } + ) + + state = hass.states.get(f"{DOMAIN}.from_storage") + assert float(state.state) == 10 + assert state.attributes.get(ATTR_FRIENDLY_NAME) == "from storage" + assert state.attributes.get(ATTR_EDITABLE) + + state = hass.states.get(f"{DOMAIN}.from_yaml") + assert float(state.state) == 5 + assert not state.attributes.get(ATTR_EDITABLE) + + +async def test_ws_list(hass, hass_ws_client, storage_setup): + """Test listing via WS.""" + assert await storage_setup( + config={ + DOMAIN: { + "from_yaml": { + "min": 1, + "max": 10, + "initial": 5, + "step": 1, + "mode": "slider", + } + } + } + ) + + client = await hass_ws_client(hass) + + await client.send_json({"id": 6, "type": f"{DOMAIN}/list"}) + resp = await client.receive_json() + assert resp["success"] + + storage_ent = "from_storage" + yaml_ent = "from_yaml" + result = {item["id"]: item for item in resp["result"]} + + assert len(result) == 1 + assert storage_ent in result + assert yaml_ent not in result + assert result[storage_ent][ATTR_NAME] == "from storage" + + +async def test_ws_delete(hass, hass_ws_client, storage_setup): + """Test WS delete cleans up entity registry.""" + assert await storage_setup() + + input_id = "from_storage" + input_entity_id = f"{DOMAIN}.{input_id}" + ent_reg = await entity_registry.async_get_registry(hass) + + state = hass.states.get(input_entity_id) + assert state is not None + assert ent_reg.async_get_entity_id(DOMAIN, DOMAIN, input_id) is not None + + client = await hass_ws_client(hass) + + await client.send_json( + {"id": 6, "type": f"{DOMAIN}/delete", f"{DOMAIN}_id": f"{input_id}"} + ) + resp = await client.receive_json() + assert resp["success"] + + state = hass.states.get(input_entity_id) + assert state is None + assert ent_reg.async_get_entity_id(DOMAIN, DOMAIN, input_id) is None + + +async def test_update_min_max(hass, hass_ws_client, storage_setup): + """Test updating min/max updates the state.""" + + items = [ + { + "id": "from_storage", + "name": "from storage", + "max": 100, + "min": 0, + "step": 1, + "mode": "slider", + } + ] + assert await storage_setup(items) + + input_id = "from_storage" + input_entity_id = f"{DOMAIN}.{input_id}" + ent_reg = await entity_registry.async_get_registry(hass) + + state = hass.states.get(input_entity_id) + assert state is not None + assert state.state + assert ent_reg.async_get_entity_id(DOMAIN, DOMAIN, input_id) is not None + + client = await hass_ws_client(hass) + + await client.send_json( + {"id": 6, "type": f"{DOMAIN}/update", f"{DOMAIN}_id": f"{input_id}", "min": 9} + ) + resp = await client.receive_json() + assert resp["success"] + + state = hass.states.get(input_entity_id) + assert float(state.state) == 9 + + await client.send_json( + { + "id": 7, + "type": f"{DOMAIN}/update", + f"{DOMAIN}_id": f"{input_id}", + "max": 5, + "min": 0, + } + ) + resp = await client.receive_json() + assert resp["success"] + + state = hass.states.get(input_entity_id) + assert float(state.state) == 5 + + +async def test_ws_create(hass, hass_ws_client, storage_setup): + """Test create WS.""" + assert await storage_setup(items=[]) + + input_id = "new_input" + input_entity_id = f"{DOMAIN}.{input_id}" + ent_reg = await entity_registry.async_get_registry(hass) + + state = hass.states.get(input_entity_id) + assert state is None + assert ent_reg.async_get_entity_id(DOMAIN, DOMAIN, input_id) is None + + client = await hass_ws_client(hass) + + await client.send_json( + { + "id": 6, + "type": f"{DOMAIN}/create", + "name": "New Input", + "max": 20, + "min": 0, + "initial": 10, + "step": 1, + "mode": "slider", + } + ) + resp = await client.receive_json() + assert resp["success"] + + state = hass.states.get(input_entity_id) + assert float(state.state) == 10