diff --git a/homeassistant/components/template/fan.py b/homeassistant/components/template/fan.py index d68f37123bf..78d4829b632 100644 --- a/homeassistant/components/template/fan.py +++ b/homeassistant/components/template/fan.py @@ -23,8 +23,6 @@ from homeassistant.const import ( CONF_FRIENDLY_NAME, CONF_UNIQUE_ID, CONF_VALUE_TEMPLATE, - EVENT_HOMEASSISTANT_START, - MATCH_ALL, STATE_OFF, STATE_ON, STATE_UNAVAILABLE, @@ -36,8 +34,8 @@ import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity import async_generate_entity_id from homeassistant.helpers.script import Script -from . import extract_entities, initialise_templates from .const import CONF_AVAILABILITY_TEMPLATE +from .template_entity import TemplateEntityWithAvailability _LOGGER = logging.getLogger(__name__) @@ -104,17 +102,6 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info= speed_list = device_config[CONF_SPEED_LIST] unique_id = device_config.get(CONF_UNIQUE_ID) - templates = { - CONF_VALUE_TEMPLATE: state_template, - CONF_SPEED_TEMPLATE: speed_template, - CONF_OSCILLATING_TEMPLATE: oscillating_template, - CONF_DIRECTION_TEMPLATE: direction_template, - CONF_AVAILABILITY_TEMPLATE: availability_template, - } - - initialise_templates(hass, templates) - entity_ids = extract_entities(device, "fan", None, templates) - fans.append( TemplateFan( hass, @@ -131,7 +118,6 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info= set_oscillating_action, set_direction_action, speed_list, - entity_ids, unique_id, ) ) @@ -139,7 +125,7 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info= async_add_entities(fans) -class TemplateFan(FanEntity): +class TemplateFan(TemplateEntityWithAvailability, FanEntity): """A template fan component.""" def __init__( @@ -158,10 +144,10 @@ class TemplateFan(FanEntity): set_oscillating_action, set_direction_action, speed_list, - entity_ids, unique_id, ): """Initialize the fan.""" + super().__init__(availability_template) self.hass = hass self.entity_id = async_generate_entity_id( ENTITY_ID_FORMAT, device_id, hass=hass @@ -172,8 +158,6 @@ class TemplateFan(FanEntity): self._speed_template = speed_template self._oscillating_template = oscillating_template self._direction_template = direction_template - self._availability_template = availability_template - self._available = True self._supported_features = 0 domain = __name__.split(".")[-2] @@ -211,7 +195,6 @@ class TemplateFan(FanEntity): if self._direction_template: self._supported_features |= SUPPORT_DIRECTION - self._entities = entity_ids self._unique_id = unique_id # List of valid speeds @@ -257,16 +240,6 @@ class TemplateFan(FanEntity): """Return the oscillation state.""" return self._direction - @property - def should_poll(self): - """Return the polling state.""" - return False - - @property - def available(self): - """Return availability of Device.""" - return self._available - # pylint: disable=arguments-differ async def async_turn_on(self, speed: str = None) -> None: """Turn on the fan.""" @@ -331,125 +304,94 @@ class TemplateFan(FanEntity): ", ".join(_VALID_DIRECTIONS), ) - async def async_added_to_hass(self): - """Register callbacks.""" - - @callback - def template_fan_state_listener(event): - """Handle target device state changes.""" - self.async_schedule_update_ha_state(True) - - @callback - def template_fan_startup(event): - """Update template on startup.""" - if self._entities != MATCH_ALL: - # Track state change only for valid templates - self.hass.helpers.event.async_track_state_change_event( - self._entities, template_fan_state_listener - ) - - self.async_schedule_update_ha_state(True) - - self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, template_fan_startup) - - async def async_update(self): - """Update the state from the template.""" - # Update state - try: - state = self._template.async_render() - except TemplateError as ex: - _LOGGER.error(ex) - state = None + @callback + def _update_state(self, result): + super()._update_state(result) + if isinstance(result, TemplateError): self._state = None + return # Validate state - if state in _VALID_STATES: - self._state = state - elif state in [STATE_UNAVAILABLE, STATE_UNKNOWN]: + if result in _VALID_STATES: + self._state = result + elif result in [STATE_UNAVAILABLE, STATE_UNKNOWN]: self._state = None else: _LOGGER.error( "Received invalid fan is_on state: %s. Expected: %s", - state, + result, ", ".join(_VALID_STATES), ) self._state = None - # Update speed if 'speed_template' is configured + async def async_added_to_hass(self): + """Register callbacks.""" + self.add_template_attribute("_state", self._template, None, self._update_state) if self._speed_template is not None: - try: - speed = self._speed_template.async_render() - except TemplateError as ex: - _LOGGER.error(ex) - speed = None - self._state = None - - # Validate speed - if speed in self._speed_list: - self._speed = speed - elif speed in [STATE_UNAVAILABLE, STATE_UNKNOWN]: - self._speed = None - else: - _LOGGER.error( - "Received invalid speed: %s. Expected: %s", speed, self._speed_list - ) - self._speed = None - - # Update oscillating if 'oscillating_template' is configured + self.add_template_attribute( + "_speed", + self._speed_template, + None, + self._update_speed, + none_on_template_error=True, + ) if self._oscillating_template is not None: - try: - oscillating = self._oscillating_template.async_render() - except TemplateError as ex: - _LOGGER.error(ex) - oscillating = None - self._state = None - - # Validate osc - if oscillating == "True" or oscillating is True: - self._oscillating = True - elif oscillating == "False" or oscillating is False: - self._oscillating = False - elif oscillating in [STATE_UNAVAILABLE, STATE_UNKNOWN]: - self._oscillating = None - else: - _LOGGER.error( - "Received invalid oscillating: %s. Expected: True/False", - oscillating, - ) - self._oscillating = None - - # Update direction if 'direction_template' is configured + self.add_template_attribute( + "_oscillating", + self._oscillating_template, + None, + self._update_oscillating, + none_on_template_error=True, + ) if self._direction_template is not None: - try: - direction = self._direction_template.async_render() - except TemplateError as ex: - _LOGGER.error(ex) - direction = None - self._state = None + self.add_template_attribute( + "_direction", + self._direction_template, + None, + self._update_direction, + none_on_template_error=True, + ) + await super().async_added_to_hass() - # Validate speed - if direction in _VALID_DIRECTIONS: - self._direction = direction - elif direction in [STATE_UNAVAILABLE, STATE_UNKNOWN]: - self._direction = None - else: - _LOGGER.error( - "Received invalid direction: %s. Expected: %s", - direction, - ", ".join(_VALID_DIRECTIONS), - ) - self._direction = None + @callback + def _update_speed(self, speed): + # Validate speed + if speed in self._speed_list: + self._speed = speed + elif speed in [STATE_UNAVAILABLE, STATE_UNKNOWN]: + self._speed = None + else: + _LOGGER.error( + "Received invalid speed: %s. Expected: %s", speed, self._speed_list + ) + self._speed = None - # Update Availability if 'availability_template' is defined - if self._availability_template is not None: - try: - self._available = ( - self._availability_template.async_render().lower() == "true" - ) - except (TemplateError, ValueError) as ex: - _LOGGER.error( - "Could not render %s template %s: %s", - CONF_AVAILABILITY_TEMPLATE, - self._name, - ex, - ) + @callback + def _update_oscillating(self, oscillating): + # Validate osc + if oscillating == "True" or oscillating is True: + self._oscillating = True + elif oscillating == "False" or oscillating is False: + self._oscillating = False + elif oscillating in [STATE_UNAVAILABLE, STATE_UNKNOWN]: + self._oscillating = None + else: + _LOGGER.error( + "Received invalid oscillating: %s. Expected: True/False", oscillating, + ) + self._oscillating = None + + @callback + def _update_direction(self, direction): + # Validate direction + if direction in _VALID_DIRECTIONS: + self._direction = direction + elif direction in [STATE_UNAVAILABLE, STATE_UNKNOWN]: + self._direction = None + else: + _LOGGER.error( + "Received invalid direction: %s. Expected: %s", + direction, + ", ".join(_VALID_DIRECTIONS), + ) + self._direction = None diff --git a/homeassistant/components/template/template_entity.py b/homeassistant/components/template/template_entity.py index 748a130d064..618b240a9d4 100644 --- a/homeassistant/components/template/template_entity.py +++ b/homeassistant/components/template/template_entity.py @@ -26,6 +26,7 @@ class _TemplateAttribute: template: Template, validator: Callable[[Any], Any] = match_all, on_update: Optional[Callable[[Any], None]] = None, + none_on_template_error: Optional[bool] = False, ): """Template attribute.""" self._entity = entity @@ -35,6 +36,7 @@ class _TemplateAttribute: self.on_update = on_update self.async_update = None self.add_complete = False + self.none_on_template_error = none_on_template_error @callback def async_setup(self): @@ -75,7 +77,10 @@ class _TemplateAttribute: self._attribute, self._entity.entity_id, ) - self.on_update(result) + if self.none_on_template_error: + self._default_update(result) + else: + self.on_update(result) self._write_update_if_added() return @@ -139,6 +144,7 @@ class TemplateEntity(Entity): template: Template, validator: Callable[[Any], Any] = match_all, on_update: Optional[Callable[[Any], None]] = None, + none_on_template_error: bool = False, ) -> None: """ Call in the constructor to add a template linked to a attribute. @@ -158,7 +164,9 @@ class TemplateEntity(Entity): if the template or validator resulted in an error. """ - attribute = _TemplateAttribute(self, attribute, template, validator, on_update) + attribute = _TemplateAttribute( + self, attribute, template, validator, on_update, none_on_template_error + ) attribute.async_setup() self._template_attrs.append(attribute) diff --git a/tests/components/template/test_fan.py b/tests/components/template/test_fan.py index a56b55fa123..10af5fd0b74 100644 --- a/tests/components/template/test_fan.py +++ b/tests/components/template/test_fan.py @@ -414,8 +414,9 @@ async def test_invalid_availability_template_keeps_component_available(hass, cap await hass.async_block_till_done() assert hass.states.get("fan.test_fan").state != STATE_UNAVAILABLE - assert ("Could not render availability_template template") in caplog.text - assert ("UndefinedError: 'x' is undefined") in caplog.text + + assert "TemplateError" in caplog.text + assert "x" in caplog.text # End of template tests #