diff --git a/homeassistant/components/mqtt/switch.py b/homeassistant/components/mqtt/switch.py index d3252525b76..9cc13ac94bd 100644 --- a/homeassistant/components/mqtt/switch.py +++ b/homeassistant/components/mqtt/switch.py @@ -1,11 +1,14 @@ """Support for MQTT switches.""" +from __future__ import annotations + import functools import voluptuous as vol from homeassistant.components import switch -from homeassistant.components.switch import SwitchEntity +from homeassistant.components.switch import DEVICE_CLASSES_SCHEMA, SwitchEntity from homeassistant.const import ( + CONF_DEVICE_CLASS, CONF_NAME, CONF_OPTIMISTIC, CONF_PAYLOAD_OFF, @@ -48,6 +51,7 @@ PLATFORM_SCHEMA = mqtt.MQTT_RW_PLATFORM_SCHEMA.extend( vol.Optional(CONF_STATE_OFF): cv.string, vol.Optional(CONF_STATE_ON): cv.string, vol.Optional(CONF_VALUE_TEMPLATE): cv.template, + vol.Optional(CONF_DEVICE_CLASS): DEVICE_CLASSES_SCHEMA, } ).extend(MQTT_ENTITY_COMMON_SCHEMA.schema) @@ -158,6 +162,11 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity): """Return true if we do optimistic updates.""" return self._optimistic + @property + def device_class(self) -> str | None: + """Return the device class of the sensor.""" + return self._config.get(CONF_DEVICE_CLASS) + async def async_turn_on(self, **kwargs): """Turn the device on. diff --git a/tests/components/mqtt/test_switch.py b/tests/components/mqtt/test_switch.py index 263ec0a2825..a3ef29d0d08 100644 --- a/tests/components/mqtt/test_switch.py +++ b/tests/components/mqtt/test_switch.py @@ -6,7 +6,12 @@ import pytest from homeassistant.components import switch from homeassistant.components.mqtt.switch import MQTT_SWITCH_ATTRIBUTES_BLOCKED -from homeassistant.const import ATTR_ASSUMED_STATE, STATE_OFF, STATE_ON +from homeassistant.const import ( + ATTR_ASSUMED_STATE, + ATTR_DEVICE_CLASS, + STATE_OFF, + STATE_ON, +) import homeassistant.core as ha from homeassistant.setup import async_setup_component @@ -56,6 +61,7 @@ async def test_controlling_state_via_topic(hass, mqtt_mock): "command_topic": "command-topic", "payload_on": 1, "payload_off": 0, + "device_class": "switch", } }, ) @@ -63,6 +69,7 @@ async def test_controlling_state_via_topic(hass, mqtt_mock): state = hass.states.get("switch.test") assert state.state == STATE_OFF + assert state.attributes.get(ATTR_DEVICE_CLASS) == "switch" assert not state.attributes.get(ATTR_ASSUMED_STATE) async_fire_mqtt_message(hass, "state-topic", "1") @@ -387,6 +394,7 @@ async def test_discovery_update_unchanged_switch(hass, mqtt_mock, caplog): """Test update of discovered switch.""" data1 = ( '{ "name": "Beer",' + ' "device_class": "switch",' ' "state_topic": "test_topic",' ' "command_topic": "test_topic" }' )