Add command_template option to mqtt switch schema (#122103)

This commit is contained in:
Jan Bouwhuis 2024-07-19 12:10:49 +02:00 committed by GitHub
parent c1c5cff993
commit 16434b5306
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 66 additions and 9 deletions

View file

@ -28,9 +28,19 @@ from homeassistant.helpers.typing import ConfigType
from . import subscription
from .config import MQTT_RW_SCHEMA
from .const import CONF_COMMAND_TOPIC, CONF_STATE_TOPIC, PAYLOAD_NONE
from .const import (
CONF_COMMAND_TEMPLATE,
CONF_COMMAND_TOPIC,
CONF_STATE_TOPIC,
PAYLOAD_NONE,
)
from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import MqttValueTemplate, ReceiveMessage
from .models import (
MqttCommandTemplate,
MqttValueTemplate,
PublishPayloadType,
ReceiveMessage,
)
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
DEFAULT_NAME = "MQTT Switch"
@ -41,6 +51,7 @@ CONF_STATE_OFF = "state_off"
PLATFORM_SCHEMA_MODERN = MQTT_RW_SCHEMA.extend(
{
vol.Optional(CONF_COMMAND_TEMPLATE): cv.template,
vol.Optional(CONF_NAME): vol.Any(cv.string, None),
vol.Optional(CONF_PAYLOAD_OFF, default=DEFAULT_PAYLOAD_OFF): cv.string,
vol.Optional(CONF_PAYLOAD_ON, default=DEFAULT_PAYLOAD_ON): cv.string,
@ -79,6 +90,7 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
_optimistic: bool
_is_on_map: dict[str | bytes, bool | None]
_command_template: Callable[[PublishPayloadType], PublishPayloadType]
_value_template: Callable[[ReceivePayloadType], ReceivePayloadType]
@staticmethod
@ -100,8 +112,11 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
config[CONF_OPTIMISTIC] or config.get(CONF_STATE_TOPIC) is None
)
self._attr_assumed_state = bool(self._optimistic)
self._command_template = MqttCommandTemplate(
config.get(CONF_COMMAND_TEMPLATE), entity=self
).async_render
self._value_template = MqttValueTemplate(
self._config.get(CONF_VALUE_TEMPLATE), entity=self
config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value
@callback
@ -132,9 +147,8 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
This method is a coroutine.
"""
await self.async_publish_with_config(
self._config[CONF_COMMAND_TOPIC], self._config[CONF_PAYLOAD_ON]
)
payload = self._command_template(self._config[CONF_PAYLOAD_ON])
await self.async_publish_with_config(self._config[CONF_COMMAND_TOPIC], payload)
if self._optimistic:
# Optimistically assume that switch has changed state.
self._attr_is_on = True
@ -145,9 +159,8 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
This method is a coroutine.
"""
await self.async_publish_with_config(
self._config[CONF_COMMAND_TOPIC], self._config[CONF_PAYLOAD_OFF]
)
payload = self._command_template(self._config[CONF_PAYLOAD_OFF])
await self.async_publish_with_config(self._config[CONF_COMMAND_TOPIC], payload)
if self._optimistic:
# Optimistically assume that switch has changed state.
self._attr_is_on = False

View file

@ -191,6 +191,50 @@ async def test_sending_inital_state_and_optimistic(
assert state.attributes.get(ATTR_ASSUMED_STATE)
@pytest.mark.parametrize(
"hass_config",
[
{
mqtt.DOMAIN: {
switch.DOMAIN: {
"name": "test",
"command_topic": "command-topic",
"command_template": '{"state": "{{ value }}"}',
"payload_on": "beer on",
"payload_off": "beer off",
"qos": "2",
}
}
}
],
)
async def test_sending_mqtt_commands_with_command_template(
hass: HomeAssistant, mqtt_mock_entry: MqttMockHAClientGenerator
) -> None:
"""Test the sending MQTT commands using a command template."""
fake_state = State("switch.test", "on")
mock_restore_cache(hass, (fake_state,))
mqtt_mock = await mqtt_mock_entry()
state = hass.states.get("switch.test")
assert state.state == STATE_ON
assert state.attributes.get(ATTR_ASSUMED_STATE)
await common.async_turn_on(hass, "switch.test")
mqtt_mock.async_publish.assert_called_once_with(
"command-topic", '{"state": "beer on"}', 2, False
)
mqtt_mock.async_publish.reset_mock()
await common.async_turn_off(hass, "switch.test")
mqtt_mock.async_publish.assert_called_once_with(
"command-topic", '{"state": "beer off"}', 2, False
)
@pytest.mark.parametrize(
"hass_config",
[