Add device_class to MQTT switch (#58931)

This commit is contained in:
Chris Browet 2021-11-02 17:40:05 +01:00 committed by GitHub
parent 339117aceb
commit 2df1ba2346
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 2 deletions

View file

@ -1,11 +1,14 @@
"""Support for MQTT switches.""" """Support for MQTT switches."""
from __future__ import annotations
import functools import functools
import voluptuous as vol import voluptuous as vol
from homeassistant.components import switch from homeassistant.components import switch
from homeassistant.components.switch import SwitchEntity from homeassistant.components.switch import DEVICE_CLASSES_SCHEMA, SwitchEntity
from homeassistant.const import ( from homeassistant.const import (
CONF_DEVICE_CLASS,
CONF_NAME, CONF_NAME,
CONF_OPTIMISTIC, CONF_OPTIMISTIC,
CONF_PAYLOAD_OFF, CONF_PAYLOAD_OFF,
@ -48,6 +51,7 @@ PLATFORM_SCHEMA = mqtt.MQTT_RW_PLATFORM_SCHEMA.extend(
vol.Optional(CONF_STATE_OFF): cv.string, vol.Optional(CONF_STATE_OFF): cv.string,
vol.Optional(CONF_STATE_ON): cv.string, vol.Optional(CONF_STATE_ON): cv.string,
vol.Optional(CONF_VALUE_TEMPLATE): cv.template, vol.Optional(CONF_VALUE_TEMPLATE): cv.template,
vol.Optional(CONF_DEVICE_CLASS): DEVICE_CLASSES_SCHEMA,
} }
).extend(MQTT_ENTITY_COMMON_SCHEMA.schema) ).extend(MQTT_ENTITY_COMMON_SCHEMA.schema)
@ -158,6 +162,11 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
"""Return true if we do optimistic updates.""" """Return true if we do optimistic updates."""
return self._optimistic return self._optimistic
@property
def device_class(self) -> str | None:
"""Return the device class of the sensor."""
return self._config.get(CONF_DEVICE_CLASS)
async def async_turn_on(self, **kwargs): async def async_turn_on(self, **kwargs):
"""Turn the device on. """Turn the device on.

View file

@ -6,7 +6,12 @@ import pytest
from homeassistant.components import switch from homeassistant.components import switch
from homeassistant.components.mqtt.switch import MQTT_SWITCH_ATTRIBUTES_BLOCKED from homeassistant.components.mqtt.switch import MQTT_SWITCH_ATTRIBUTES_BLOCKED
from homeassistant.const import ATTR_ASSUMED_STATE, STATE_OFF, STATE_ON from homeassistant.const import (
ATTR_ASSUMED_STATE,
ATTR_DEVICE_CLASS,
STATE_OFF,
STATE_ON,
)
import homeassistant.core as ha import homeassistant.core as ha
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -56,6 +61,7 @@ async def test_controlling_state_via_topic(hass, mqtt_mock):
"command_topic": "command-topic", "command_topic": "command-topic",
"payload_on": 1, "payload_on": 1,
"payload_off": 0, "payload_off": 0,
"device_class": "switch",
} }
}, },
) )
@ -63,6 +69,7 @@ async def test_controlling_state_via_topic(hass, mqtt_mock):
state = hass.states.get("switch.test") state = hass.states.get("switch.test")
assert state.state == STATE_OFF assert state.state == STATE_OFF
assert state.attributes.get(ATTR_DEVICE_CLASS) == "switch"
assert not state.attributes.get(ATTR_ASSUMED_STATE) assert not state.attributes.get(ATTR_ASSUMED_STATE)
async_fire_mqtt_message(hass, "state-topic", "1") async_fire_mqtt_message(hass, "state-topic", "1")
@ -387,6 +394,7 @@ async def test_discovery_update_unchanged_switch(hass, mqtt_mock, caplog):
"""Test update of discovered switch.""" """Test update of discovered switch."""
data1 = ( data1 = (
'{ "name": "Beer",' '{ "name": "Beer",'
' "device_class": "switch",'
' "state_topic": "test_topic",' ' "state_topic": "test_topic",'
' "command_topic": "test_topic" }' ' "command_topic": "test_topic" }'
) )