diff --git a/homeassistant/components/mqtt/abbreviations.py b/homeassistant/components/mqtt/abbreviations.py index 4eef2d372ae..23c94ada4c0 100644 --- a/homeassistant/components/mqtt/abbreviations.py +++ b/homeassistant/components/mqtt/abbreviations.py @@ -160,6 +160,7 @@ ABBREVIATIONS = { "spd_val_tpl": "speed_value_template", "spds": "speeds", "src_type": "source_type", + "stat_cla": "state_class", "stat_clsd": "state_closed", "stat_closing": "state_closing", "stat_off": "state_off", diff --git a/homeassistant/components/mqtt/sensor.py b/homeassistant/components/mqtt/sensor.py index ca399161b25..145af55daa8 100644 --- a/homeassistant/components/mqtt/sensor.py +++ b/homeassistant/components/mqtt/sensor.py @@ -7,7 +7,11 @@ import functools import voluptuous as vol from homeassistant.components import sensor -from homeassistant.components.sensor import DEVICE_CLASSES_SCHEMA, SensorEntity +from homeassistant.components.sensor import ( + DEVICE_CLASSES_SCHEMA, + STATE_CLASSES_SCHEMA, + SensorEntity, +) from homeassistant.const import ( CONF_DEVICE_CLASS, CONF_FORCE_UPDATE, @@ -33,6 +37,7 @@ from .mixins import ( ) CONF_EXPIRE_AFTER = "expire_after" +CONF_STATE_CLASS = "state_class" DEFAULT_NAME = "MQTT Sensor" DEFAULT_FORCE_UPDATE = False @@ -42,6 +47,7 @@ PLATFORM_SCHEMA = mqtt.MQTT_RO_PLATFORM_SCHEMA.extend( vol.Optional(CONF_EXPIRE_AFTER): cv.positive_int, vol.Optional(CONF_FORCE_UPDATE, default=DEFAULT_FORCE_UPDATE): cv.boolean, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, + vol.Optional(CONF_STATE_CLASS): STATE_CLASSES_SCHEMA, vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string, } ).extend(MQTT_ENTITY_COMMON_SCHEMA.schema) @@ -173,6 +179,11 @@ class MqttSensor(MqttEntity, SensorEntity): """Return the device class of the sensor.""" return self._config.get(CONF_DEVICE_CLASS) + @property + def state_class(self) -> str | None: + """Return the state class of the sensor.""" + return self._config.get(CONF_STATE_CLASS) + @property def available(self) -> bool: """Return true if the device is available and value has not expired.""" diff --git a/tests/components/mqtt/test_sensor.py b/tests/components/mqtt/test_sensor.py index c6ebbe98dc4..fe97bdfbfde 100644 --- a/tests/components/mqtt/test_sensor.py +++ b/tests/components/mqtt/test_sensor.py @@ -381,6 +381,51 @@ async def test_valid_device_class(hass, mqtt_mock): assert "device_class" not in state.attributes +async def test_invalid_state_class(hass, mqtt_mock): + """Test state_class option with invalid value.""" + assert await async_setup_component( + hass, + sensor.DOMAIN, + { + sensor.DOMAIN: { + "platform": "mqtt", + "name": "test", + "state_topic": "test-topic", + "state_class": "foobarnotreal", + } + }, + ) + await hass.async_block_till_done() + + state = hass.states.get("sensor.test") + assert state is None + + +async def test_valid_state_class(hass, mqtt_mock): + """Test state_class option with valid values.""" + assert await async_setup_component( + hass, + "sensor", + { + "sensor": [ + { + "platform": "mqtt", + "name": "Test 1", + "state_topic": "test-topic", + "state_class": "measurement", + }, + {"platform": "mqtt", "name": "Test 2", "state_topic": "test-topic"}, + ] + }, + ) + await hass.async_block_till_done() + + state = hass.states.get("sensor.test_1") + assert state.attributes["state_class"] == "measurement" + state = hass.states.get("sensor.test_2") + assert "state_class" not in state.attributes + + async def test_setting_attribute_via_mqtt_json_message(hass, mqtt_mock): """Test the setting of attribute via MQTT with JSON payload.""" await help_test_setting_attribute_via_mqtt_json_message(