Check if attributes are present in new_state before accessing them (#71967)

* Check if attributes are present in new_state before accessing them.

* Early return if new state is None|Unknown|Unavailable

* Removed whitespace at line endings. +black run

* Update test for coverage
This commit is contained in:
RoboMagus 2022-05-25 08:44:08 +02:00 committed by GitHub
parent 5dfeb1e02a
commit c1ddde3764
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 6 deletions

View file

@ -191,11 +191,16 @@ class IntegrationSensor(RestoreEntity, SensorEntity):
old_state = event.data.get("old_state")
new_state = event.data.get("new_state")
if new_state is None or new_state.state in (
STATE_UNKNOWN,
STATE_UNAVAILABLE,
):
return
# We may want to update our state before an early return,
# based on the source sensor's unit_of_measurement
# or device_class.
update_state = False
unit = new_state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)
if unit is not None:
new_unit_of_measurement = self._unit_template.format(unit)
@ -214,11 +219,9 @@ class IntegrationSensor(RestoreEntity, SensorEntity):
if update_state:
self.async_write_ha_state()
if (
old_state is None
or new_state is None
or old_state.state in (STATE_UNKNOWN, STATE_UNAVAILABLE)
or new_state.state in (STATE_UNKNOWN, STATE_UNAVAILABLE)
if old_state is None or old_state.state in (
STATE_UNKNOWN,
STATE_UNAVAILABLE,
):
return

View file

@ -9,6 +9,7 @@ from homeassistant.const import (
ENERGY_WATT_HOUR,
POWER_KILO_WATT,
POWER_WATT,
STATE_UNAVAILABLE,
STATE_UNKNOWN,
TIME_SECONDS,
)
@ -350,6 +351,15 @@ async def test_units(hass):
# they became valid
assert state.attributes.get("unit_of_measurement") == ENERGY_WATT_HOUR
# When source state goes to None / Unknown, expect an early exit without
# changes to the state or unit_of_measurement
hass.states.async_set(entity_id, STATE_UNAVAILABLE, None)
await hass.async_block_till_done()
new_state = hass.states.get("sensor.integration")
assert state == new_state
assert state.attributes.get("unit_of_measurement") == ENERGY_WATT_HOUR
async def test_device_class(hass):
"""Test integration sensor units using a power source."""