* Always do thread safety check when writing state. Refactor the 3 most common places where the event loop thread safety check is done so that the check is inline, making it fast enough that we can keep it long term. While code review catches most of the thread safety issues in core, some of them still make it through, and new ones keep getting added. It's not possible to catch them all with manual code review, so it's worth the tiny overhead to check each time. Previously the check was limited to custom components because they were the most common source of thread safety issues. * Always do thread safety check when writing state. Refactor the 3 most common places where the event loop thread safety check is done so that the check is inline, making it fast enough that we can keep it long term. While code review catches most of the thread safety issues in core, some of them still make it through, and new ones keep getting added. It's not possible to catch them all with manual code review, so it's worth the tiny overhead to check each time. Previously the check was limited to custom components because they were the most common source of thread safety issues. * async_fire is more common than expected with ccs * fix mock * fix hass mocking
2641 lines
81 KiB
Python
2641 lines
81 KiB
Python
"""Test the entity helper."""
|
|
|
|
import asyncio
|
|
from collections.abc import Iterable
|
|
import dataclasses
|
|
from datetime import timedelta
|
|
from enum import IntFlag
|
|
from functools import cached_property
|
|
import logging
|
|
import threading
|
|
from typing import Any
|
|
from unittest.mock import MagicMock, PropertyMock, patch
|
|
|
|
from freezegun.api import FrozenDateTimeFactory
|
|
import pytest
|
|
from syrupy.assertion import SnapshotAssertion
|
|
import voluptuous as vol
|
|
|
|
from homeassistant.const import (
|
|
ATTR_ATTRIBUTION,
|
|
ATTR_DEVICE_CLASS,
|
|
ATTR_FRIENDLY_NAME,
|
|
STATE_UNAVAILABLE,
|
|
STATE_UNKNOWN,
|
|
)
|
|
from homeassistant.core import (
|
|
Context,
|
|
HassJobType,
|
|
HomeAssistant,
|
|
HomeAssistantError,
|
|
ReleaseChannel,
|
|
callback,
|
|
)
|
|
from homeassistant.helpers import device_registry as dr, entity, entity_registry as er
|
|
from homeassistant.helpers.entity_component import async_update_entity
|
|
from homeassistant.helpers.typing import UNDEFINED, UndefinedType
|
|
|
|
from tests.common import (
|
|
MockConfigEntry,
|
|
MockEntity,
|
|
MockEntityPlatform,
|
|
MockModule,
|
|
MockPlatform,
|
|
mock_integration,
|
|
mock_registry,
|
|
)
|
|
|
|
|
|
def test_generate_entity_id_requires_hass_or_ids() -> None:
    """Generating an id with neither hass nor current ids must raise ValueError."""
    fmt, name = "test.{}", "hello world"
    with pytest.raises(ValueError):
        entity.generate_entity_id(fmt, name)
|
|
|
|
|
|
def test_generate_entity_id_given_keys() -> None:
    """Check id generation against an explicit list of existing ids."""
    fmt = "test.{}"
    name = "overwrite hidden true"

    # The preferred id is already taken, so a numeric suffix is appended.
    taken = entity.generate_entity_id(
        fmt, name, current_ids=["test.overwrite_hidden_true"]
    )
    assert taken == "test.overwrite_hidden_true_2"

    # An unrelated existing id leaves the preferred id available.
    free = entity.generate_entity_id(fmt, name, current_ids=["test.another_entity"])
    assert free == "test.overwrite_hidden_true"
|
|
|
|
|
|
async def test_generate_entity_id_given_hass(hass: HomeAssistant) -> None:
    """Check id generation consults the state machine when hass is passed."""
    hass.states.async_set("test.overwrite_hidden_true", "test")

    generated = entity.generate_entity_id(
        "test.{}", "overwrite hidden true", hass=hass
    )
    assert generated == "test.overwrite_hidden_true_2"
|
|
|
|
|
|
async def test_async_update_support(hass: HomeAssistant) -> None:
    """Test that an attached async_update is used instead of the sync update."""
    sync_update = []
    async_update = []

    class AsyncEntity(entity.Entity):
        """A test entity."""

        entity_id = "sensor.test"

        def update(self):
            """Update entity."""
            sync_update.append(1)

    ent = AsyncEntity()
    ent.hass = hass

    # Only the synchronous update exists, so it is the one that runs.
    await ent.async_update_ha_state(True)

    assert len(sync_update) == 1
    assert len(async_update) == 0

    async def async_update_func():
        """Async update."""
        async_update.append(1)

    # After attaching an async_update, a refresh calls it and the
    # synchronous update is not called again.
    ent.async_update = async_update_func

    await ent.async_update_ha_state(True)

    assert len(sync_update) == 1
    assert len(async_update) == 1
|
|
|
|
|
|
async def test_device_class(hass: HomeAssistant) -> None:
    """Check the device_class attribute is only written once it is set."""
    ent = entity.Entity()
    ent.entity_id = "test.overwrite_hidden_true"
    ent.hass = hass

    def written_device_class():
        """Write the entity state and return the stored device class."""
        ent.async_write_ha_state()
        return hass.states.get(ent.entity_id).attributes.get(ATTR_DEVICE_CLASS)

    assert written_device_class() is None

    ent._attr_device_class = "test_class"
    assert written_device_class() == "test_class"
|
|
|
|
|
|
async def test_warn_slow_update(
    hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
    """Check a warning is logged when an entity update exceeds the threshold."""
    called = False

    async def async_update():
        """Mock async update that sleeps longer than the patched threshold."""
        nonlocal called
        await asyncio.sleep(0.00001)
        called = True

    mock_entity = entity.Entity()
    mock_entity.hass = hass
    mock_entity.entity_id = "comp_test.test_entity"
    mock_entity.async_update = async_update

    threshold = 0.0000001

    with patch.object(entity, "SLOW_UPDATE_WARNING", threshold):
        await mock_entity.async_update_ha_state(True)
        logged = caplog.text
        assert str(threshold) in logged
        assert mock_entity.entity_id in logged
        assert called
|
|
|
|
|
|
async def test_warn_slow_update_with_exception(
    hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
    """Warn we log when entity update takes a long time and throws an exception."""
    update_call = False

    async def async_update():
        """Mock async update that is slow and then fails."""
        nonlocal update_call
        update_call = True
        await asyncio.sleep(0.00001)
        raise AssertionError("Fake update error")

    mock_entity = entity.Entity()
    mock_entity.hass = hass
    mock_entity.entity_id = "comp_test.test_entity"
    mock_entity.async_update = async_update

    fast_update_time = 0.0000001

    # The slow-update warning must still be logged even though the update raises.
    with patch.object(entity, "SLOW_UPDATE_WARNING", fast_update_time):
        await mock_entity.async_update_ha_state(True)
        assert str(fast_update_time) in caplog.text
        assert mock_entity.entity_id in caplog.text
        assert update_call
|
|
|
|
|
|
async def test_warn_slow_device_update_disabled(
    hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
    """Check async_device_update(warning=False) logs no slow update warning."""
    called = False

    async def async_update():
        """Mock async update that sleeps longer than the patched threshold."""
        nonlocal called
        await asyncio.sleep(0.00001)
        called = True

    mock_entity = entity.Entity()
    mock_entity.hass = hass
    mock_entity.entity_id = "comp_test.test_entity"
    mock_entity.async_update = async_update

    threshold = 0.0000001

    with patch.object(entity, "SLOW_UPDATE_WARNING", threshold):
        await mock_entity.async_device_update(warning=False)
        log_text = caplog.text
        assert str(threshold) not in log_text
        assert mock_entity.entity_id not in log_text
        assert called
|
|
|
|
|
|
async def test_async_schedule_update_ha_state(hass: HomeAssistant) -> None:
    """Test async_schedule_update_ha_state(True) runs the entity's async_update."""
    update_call = False

    async def async_update():
        """Mock async update."""
        nonlocal update_call
        update_call = True

    mock_entity = entity.Entity()
    mock_entity.hass = hass
    mock_entity.entity_id = "comp_test.test_entity"
    mock_entity.async_update = async_update

    # The refresh is only scheduled here, so wait for pending work to finish.
    mock_entity.async_schedule_update_ha_state(True)
    await hass.async_block_till_done()

    assert update_call is True
|
|
|
|
|
|
async def test_async_async_request_call_without_lock(hass: HomeAssistant) -> None:
    """Test for async_request_call works without a lock."""
    updates = []

    class AsyncEntity(entity.Entity):
        """Test entity."""

        def __init__(self, entity_id):
            """Initialize Async test entity."""
            self.entity_id = entity_id
            self.hass = hass

        async def testhelper(self, count):
            """Helper function."""
            updates.append(count)

    ent_1 = AsyncEntity("light.test_1")
    ent_2 = AsyncEntity("light.test_2")

    job1 = ent_1.async_request_call(ent_1.testhelper(1))
    job2 = ent_2.async_request_call(ent_2.testhelper(2))

    # gather returns only after both requests have completed, so the
    # updates are already recorded and no polling loop is needed.
    await asyncio.gather(job1, job2)

    assert len(updates) == 2
    updates.sort()
    assert updates == [1, 2]
|
|
|
|
|
|
async def test_async_async_request_call_with_lock(hass: HomeAssistant) -> None:
    """Test for async_request_call works with a semaphore."""
    updates = []

    test_semaphore = asyncio.Semaphore(1)

    class AsyncEntity(entity.Entity):
        """Test entity sharing the given semaphore as parallel_updates."""

        def __init__(self, entity_id, lock):
            """Initialize Async test entity."""
            self.entity_id = entity_id
            self.hass = hass
            self.parallel_updates = lock

        async def testhelper(self, count):
            """Helper function."""
            updates.append(count)

    ent_1 = AsyncEntity("light.test_1", test_semaphore)
    ent_2 = AsyncEntity("light.test_2", test_semaphore)

    try:
        # Take the only semaphore slot before the requests are created.
        assert test_semaphore.locked() is False
        await test_semaphore.acquire()
        assert test_semaphore.locked()

        job1 = ent_1.async_request_call(ent_1.testhelper(1))
        job2 = ent_2.async_request_call(ent_2.testhelper(2))

        hass.async_create_task(job1)
        hass.async_create_task(job2)

        assert len(updates) == 0
        assert updates == []
        assert test_semaphore._value == 0

        # Give the slot back so the queued requests can run.
        test_semaphore.release()

        while True:
            if len(updates) >= 2:
                break
            await asyncio.sleep(0)
    finally:
        # Extra release so a request is not left waiting on the semaphore if
        # an assertion above failed before the slot was given back.
        test_semaphore.release()

    assert len(updates) == 2
    updates.sort()
    assert updates == [1, 2]
|
|
|
|
|
|
async def test_async_parallel_updates_with_zero(hass: HomeAssistant) -> None:
    """Test parallel updates with 0 (disabled): both updates start concurrently."""
    updates = []
    release = asyncio.Event()

    class AsyncEntity(entity.Entity):
        """Test entity whose update blocks until the event is set."""

        def __init__(self, entity_id, count):
            """Initialize Async test entity."""
            self.entity_id = entity_id
            self.hass = hass
            self._count = count

        async def async_update(self):
            """Record the update, then wait for the event."""
            updates.append(self._count)
            await release.wait()

    first = AsyncEntity("sensor.test_1", 1)
    second = AsyncEntity("sensor.test_2", 2)

    try:
        first.async_schedule_update_ha_state(True)
        second.async_schedule_update_ha_state(True)

        # Both updates are recorded while neither has finished.
        while len(updates) < 2:
            await asyncio.sleep(0)

        assert len(updates) == 2
        assert updates == [1, 2]
    finally:
        release.set()
|
|
|
|
|
|
async def test_async_parallel_updates_with_zero_on_sync_update(
    hass: HomeAssistant,
) -> None:
    """Test parallel updates with 0 (disabled) for a synchronous update."""
    updates = []
    release = threading.Event()

    class SyncEntity(entity.Entity):
        """Test entity whose blocking update waits on a thread event."""

        def __init__(self, entity_id, count):
            """Initialize sync test entity."""
            self.entity_id = entity_id
            self.hass = hass
            self._count = count

        def update(self):
            """Record the update, then wait for the event (at most 1 second)."""
            updates.append(self._count)
            if not release.wait(timeout=1):
                # Timed out: record an extra entry so the assertions fail.
                updates.append(self._count)

    first = SyncEntity("sensor.test_1", 1)
    second = SyncEntity("sensor.test_2", 2)

    try:
        first.async_schedule_update_ha_state(True)
        second.async_schedule_update_ha_state(True)

        while len(updates) < 2:
            await asyncio.sleep(0)

        assert len(updates) == 2
        assert updates == [1, 2]
    finally:
        release.set()
        await asyncio.sleep(0)
|
|
|
|
|
|
async def test_async_parallel_updates_with_one(hass: HomeAssistant) -> None:
    """Test parallel updates with 1 (sequential)."""
    updates = []
    test_lock = asyncio.Lock()
    test_semaphore = asyncio.Semaphore(1)

    class AsyncEntity(entity.Entity):
        """Test entity whose update blocks on the shared test lock."""

        def __init__(self, entity_id, count):
            """Initialize Async test entity."""
            self.entity_id = entity_id
            self.hass = hass
            self._count = count
            self.parallel_updates = test_semaphore

        async def async_update(self):
            """Record the update, then wait to acquire the test lock."""
            updates.append(self._count)
            await test_lock.acquire()

    async def wait_for_updates(minimum):
        """Yield to the event loop until ``minimum`` updates were recorded."""
        while len(updates) < minimum:
            await asyncio.sleep(0)

    entities = [AsyncEntity(f"sensor.test_{n}", n) for n in (1, 2, 3)]

    # Hold the lock so the first update blocks until it is released below.
    await test_lock.acquire()

    try:
        for ent in entities:
            ent.async_schedule_update_ha_state(True)

        # With a semaphore of 1 only one update is in flight at a time; each
        # release lets the blocked update acquire the lock and finish, after
        # which the next update starts.
        for expected in (1, 2, 3):
            await wait_for_updates(1)
            assert len(updates) == 1
            assert updates == [expected]

            updates.clear()
            test_lock.release()
            await asyncio.sleep(0)
    finally:
        # we may have more than one lock need to release in case test failed
        for _ in updates:
            test_lock.release()
            await asyncio.sleep(0)
        test_lock.release()
|
|
|
|
|
|
async def test_async_parallel_updates_with_two(hass: HomeAssistant) -> None:
    """Test parallel updates with 2 (parallel)."""
    updates = []
    test_lock = asyncio.Lock()
    test_semaphore = asyncio.Semaphore(2)

    class AsyncEntity(entity.Entity):
        """Test entity whose update blocks on the shared test lock."""

        def __init__(self, entity_id, count):
            """Initialize Async test entity."""
            self.entity_id = entity_id
            self.hass = hass
            self._count = count
            self.parallel_updates = test_semaphore

        async def async_update(self):
            """Record the update, then wait to acquire the test lock."""
            updates.append(self._count)
            await test_lock.acquire()

    async def wait_for_updates(minimum):
        """Yield to the event loop until ``minimum`` updates were recorded."""
        while len(updates) < minimum:
            await asyncio.sleep(0)

    entities = [AsyncEntity(f"sensor.test_{n}", n) for n in (1, 2, 3, 4)]

    # Hold the lock so the first updates block until it is released below.
    await test_lock.acquire()

    try:
        for ent in entities:
            ent.async_schedule_update_ha_state(True)

        # With a semaphore of 2, updates run in pairs; two releases let the
        # current pair finish before the next pair is checked.
        for expected in ([1, 2], [3, 4]):
            await wait_for_updates(2)
            assert len(updates) == 2
            assert updates == expected

            updates.clear()
            for _ in range(2):
                test_lock.release()
                await asyncio.sleep(0)
    finally:
        # we may have more than one lock need to release in case test failed
        for _ in updates:
            test_lock.release()
            await asyncio.sleep(0)
        test_lock.release()
|
|
|
|
|
|
async def test_async_parallel_updates_with_one_using_executor(
    hass: HomeAssistant,
) -> None:
    """Test parallel updates with 1 (sequential) using the executor."""
    test_semaphore = asyncio.Semaphore(1)
    locked = []

    class SyncEntity(entity.Entity):
        """Test entity recording whether its semaphore is held during update."""

        def __init__(self, entity_id):
            """Initialize sync test entity."""
            self.entity_id = entity_id
            self.hass = hass
            self.parallel_updates = test_semaphore

        def update(self):
            """Test update."""
            locked.append(self.parallel_updates.locked())

    entities = [SyncEntity(f"sensor.test_{i}") for i in range(3)]

    tasks = []
    for ent in entities:
        tasks.append(
            hass.async_create_task(
                ent.async_update_ha_state(True),
                f"Entity schedule update ha state {ent.entity_id}",
            )
        )
    await asyncio.gather(*tasks)

    # Every update observed the single semaphore slot as taken.
    assert locked == [True, True, True]
|
|
|
|
|
|
async def test_async_remove_no_platform(hass: HomeAssistant) -> None:
    """Test async_remove method when no platform set."""
    ent = entity.Entity()
    ent.hass = hass
    ent.entity_id = "test.test"

    ent.async_write_ha_state()
    assert len(hass.states.async_entity_ids()) == 1

    # Removing the platform-less entity drops its state.
    await ent.async_remove()
    remaining = hass.states.async_entity_ids()
    assert len(remaining) == 0
|
|
|
|
|
|
async def test_async_remove_runs_callbacks(hass: HomeAssistant) -> None:
    """Test async_remove runs on_remove callback."""
    calls = []

    platform = MockEntityPlatform(hass, domain="test")
    ent = entity.Entity()
    ent.entity_id = "test.test"
    await platform.async_add_entities([ent])

    # The callback is registered after the entity was added.
    ent.async_on_remove(lambda: calls.append(1))
    await ent.async_remove()

    assert len(calls) == 1
|
|
|
|
|
|
async def test_async_remove_ignores_in_flight_polling(hass: HomeAssistant) -> None:
    """Test in flight polling is ignored after removing."""
    removed = []

    platform = MockEntityPlatform(hass, domain="test")
    ent = entity.Entity()
    ent.entity_id = "test.test"
    ent.async_on_remove(lambda: removed.append(1))
    await platform.async_add_entities([ent])
    assert hass.states.get("test.test").state == STATE_UNKNOWN

    # Removing runs the remove callback and drops the state.
    await ent.async_remove()
    assert len(removed) == 1
    assert hass.states.get("test.test") is None

    # A late write, as from a poll already in flight, must not recreate
    # the state or run the remove callback again.
    ent.async_write_ha_state()
    assert len(removed) == 1
    assert hass.states.get("test.test") is None
|
|
|
|
|
|
async def test_async_remove_twice(hass: HomeAssistant) -> None:
    """Test removing an entity twice only cleans up once."""
    result = []

    # Named distinctly so it does not shadow tests.common.MockEntity.
    class RemoveTrackingEntity(entity.Entity):
        """Entity that records each call to async_will_remove_from_hass."""

        def __init__(self) -> None:
            """Initialize the list of recorded removal hook calls."""
            self.remove_calls = []

        async def async_will_remove_from_hass(self):
            """Record that the removal hook ran."""
            self.remove_calls.append(None)

    platform = MockEntityPlatform(hass, domain="test")
    ent = RemoveTrackingEntity()
    ent.hass = hass
    ent.entity_id = "test.test"
    ent.async_on_remove(lambda: result.append(1))
    await platform.async_add_entities([ent])
    assert hass.states.get("test.test").state == STATE_UNKNOWN

    await ent.async_remove()
    assert len(result) == 1
    assert len(ent.remove_calls) == 1

    # A second removal must not run the callbacks or the hook again.
    await ent.async_remove()
    assert len(result) == 1
    assert len(ent.remove_calls) == 1
|
|
|
|
|
|
async def test_set_context(hass: HomeAssistant) -> None:
    """Test a context set on the entity is attached to the written state."""
    ctx = Context()

    ent = entity.Entity()
    ent.hass = hass
    ent.entity_id = "hello.world"
    ent.async_set_context(ctx)
    ent.async_write_ha_state()

    written = hass.states.get("hello.world")
    assert written.context == ctx
|
|
|
|
|
|
async def test_set_context_expired(hass: HomeAssistant) -> None:
    """Test an expired context is not attached to the written state."""
    context = Context()

    # Patch the recent-context window to a negative value so the context set
    # below is treated as expired when the state is written.
    with patch("homeassistant.helpers.entity.CONTEXT_RECENT_TIME_SECONDS", -5):
        ent = entity.Entity()
        ent.hass = hass
        ent.entity_id = "hello.world"
        ent.async_set_context(context)
        ent.async_write_ha_state()

    assert hass.states.get("hello.world").context != context
    # The stale context is also cleared from the entity.
    assert ent._context is None
    assert ent._context_set is None
|
|
|
|
|
|
async def test_warn_disabled(
    hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
    """Test we warn once if we write to a disabled entity."""
    entry = er.RegistryEntry(
        entity_id="hello.world",
        unique_id="test-unique-id",
        platform="test-platform",
        disabled_by=er.RegistryEntryDisabler.USER,
    )
    mock_registry(hass, {"hello.world": entry})

    ent = entity.Entity()
    ent.hass = hass
    ent.entity_id = "hello.world"
    ent.registry_entry = entry
    ent.platform = MagicMock(platform_name="test-platform")

    def write_and_capture() -> str:
        """Write state with a fresh log buffer; return what was logged."""
        caplog.clear()
        ent.async_write_ha_state()
        # A disabled entity never gets a state.
        assert hass.states.get("hello.world") is None
        return caplog.text

    assert "Entity hello.world is incorrectly being triggered" in write_and_capture()
    # The warning is not repeated on the next write.
    assert write_and_capture() == ""
|
|
|
|
|
|
async def test_disabled_in_entity_registry(hass: HomeAssistant) -> None:
    """Test entity is removed if we disable entity registry entry."""
    original_entry = er.RegistryEntry(
        entity_id="hello.world",
        unique_id="test-unique-id",
        platform="test-platform",
        disabled_by=None,
    )
    registry = mock_registry(hass, {"hello.world": original_entry})

    ent = entity.Entity()
    ent.hass = hass
    ent.entity_id = "hello.world"
    ent.registry_entry = original_entry
    assert ent.enabled is True

    ent.add_to_platform_start(hass, MagicMock(platform_name="test-platform"), None)
    await ent.add_to_platform_finish()
    assert hass.states.get("hello.world") is not None

    # Disabling the entry reaches the entity and drops its state.
    disabled_entry = registry.async_update_entity(
        "hello.world", disabled_by=er.RegistryEntryDisabler.USER
    )
    await hass.async_block_till_done()
    assert disabled_entry != original_entry
    assert ent.registry_entry == disabled_entry
    assert ent.enabled is False
    assert hass.states.get("hello.world") is None

    # Re-enabling creates a new entry, but the entity is no longer tracking
    # registry changes, so it keeps the disabled entry.
    reenabled_entry = registry.async_update_entity("hello.world", disabled_by=None)
    await hass.async_block_till_done()
    assert reenabled_entry != disabled_entry
    assert ent.registry_entry == disabled_entry
|
|
|
|
|
|
async def test_capability_attrs(hass: HomeAssistant) -> None:
    """Test we still include capabilities even when unavailable."""
    unavailable = PropertyMock(return_value=False)
    capabilities = PropertyMock(return_value={"always": "there"})

    with (
        patch.object(entity.Entity, "available", unavailable),
        patch.object(entity.Entity, "capability_attributes", capabilities),
    ):
        ent = entity.Entity()
        ent.hass = hass
        ent.entity_id = "hello.world"
        ent.async_write_ha_state()

    state = hass.states.get("hello.world")
    assert state is not None
    assert state.state == STATE_UNAVAILABLE
    assert state.attributes["always"] == "there"
|
|
|
|
|
|
async def test_warn_slow_write_state(
    hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
    """Check that a slow state write logs a warning with a bug report link."""
    mock_entity = entity.Entity()
    mock_entity.hass = hass
    mock_entity.entity_id = "comp_test.test_entity"
    mock_entity.platform = MagicMock(platform_name="hue")

    # The patched timer returns 0 then 10, so the write appears to take 10 s.
    with patch("homeassistant.helpers.entity.timer", side_effect=[0, 10]):
        mock_entity.async_write_ha_state()

    expected = (
        "Updating state for comp_test.test_entity "
        "(<class 'homeassistant.helpers.entity.Entity'>) "
        "took 10.000 seconds. Please create a bug report at "
        "https://github.com/home-assistant/core/issues?"
        "q=is%3Aopen+is%3Aissue+label%3A%22integration%3A+hue%22"
    )
    assert expected in caplog.text
|
|
|
|
|
|
async def test_warn_slow_write_state_custom_component(
    hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
    """Check that a slow state write in a custom component points to its author."""

    class CustomComponentEntity(entity.Entity):
        """Custom component entity."""

        # Make the entity look like it comes from a custom integration.
        __module__ = "custom_components.bla.sensor"

    mock_entity = CustomComponentEntity()
    mock_entity.hass = hass
    mock_entity.entity_id = "comp_test.test_entity"
    mock_entity.platform = MagicMock(platform_name="hue")

    # The patched timer returns 0 then 10, so the write appears to take 10 s.
    with patch("homeassistant.helpers.entity.timer", side_effect=[0, 10]):
        mock_entity.async_write_ha_state()

    expected = (
        "Updating state for comp_test.test_entity (<class 'custom_components.bla.sensor"
        ".test_warn_slow_write_state_custom_component.<locals>.CustomComponentEntity'>)"
        " took 10.000 seconds. Please report it to the author of the 'hue' custom "
        "integration"
    )
    assert expected in caplog.text
|
|
|
|
|
|
async def test_setup_source(hass: HomeAssistant) -> None:
    """Check that we register sources correctly."""
    platform = MockEntityPlatform(hass)

    from_platform = MockEntity(name="Platform Config Source")
    await platform.async_add_entities([from_platform])

    # Entities added after a config entry is attached record that entry.
    platform.config_entry = MockConfigEntry()
    from_entry = MockEntity(name="Config Entry Source")
    await platform.async_add_entities([from_entry])

    expected_sources = {
        "test_domain.platform_config_source": {
            "custom_component": False,
            "domain": "test_platform",
        },
        "test_domain.config_entry_source": {
            "config_entry": platform.config_entry.entry_id,
            "custom_component": False,
            "domain": "test_platform",
        },
    }
    assert entity.entity_sources(hass) == expected_sources

    # Resetting the platform unregisters every source it added.
    await platform.async_reset()
    assert entity.entity_sources(hass) == {}
|
|
|
|
|
|
async def test_removing_entity_unavailable(hass: HomeAssistant) -> None:
    """Test removing an entity that is still registered creates an unavailable state."""
    platform = MockEntityPlatform(hass, domain="hello")
    ent = entity.Entity()
    ent.entity_id = "hello.world"
    ent._attr_unique_id = "test-unique-id"
    await platform.async_add_entities([ent])

    before = hass.states.get("hello.world")
    assert before is not None
    assert before.state == STATE_UNKNOWN

    await ent.async_remove()

    after = hass.states.get("hello.world")
    assert after is not None
    assert after.state == STATE_UNAVAILABLE
|
|
|
|
|
|
async def test_get_supported_features_entity_registry(
    hass: HomeAssistant, entity_registry: er.EntityRegistry
) -> None:
    """Test get_supported_features falls back to entity registry."""
    reg_entry = entity_registry.async_get_or_create(
        "hello", "world", "5678", supported_features=456
    )
    # No state exists, so the registry value is returned.
    assert entity.get_supported_features(hass, reg_entry.entity_id) == 456
|
|
|
|
|
|
async def test_get_supported_features_prioritize_state(
    hass: HomeAssistant, entity_registry: er.EntityRegistry
) -> None:
    """Test get_supported_features gives priority to state."""
    reg_entry = entity_registry.async_get_or_create(
        "hello", "world", "5678", supported_features=456
    )
    entity_id = reg_entry.entity_id

    # Without a state, the registry value is used.
    assert entity.get_supported_features(hass, entity_id) == 456

    # Once a state exists, its attribute wins over the registry value.
    hass.states.async_set(entity_id, None, {"supported_features": 123})
    assert entity.get_supported_features(hass, entity_id) == 123
|
|
|
|
|
|
async def test_get_supported_features_raises_on_unknown(hass: HomeAssistant) -> None:
    """Test get_supported_features raises on unknown entity_id."""
    unknown_entity_id = "hello.world"
    with pytest.raises(HomeAssistantError):
        entity.get_supported_features(hass, unknown_entity_id)
|
|
|
|
|
|
async def test_float_conversion(hass: HomeAssistant) -> None:
    """Test conversion of float state to string rounds."""
    raw_value = 2.4 + 1.2
    # Sanity check: the raw float is not exactly 3.6.
    assert raw_value != 3.6

    with patch.object(entity.Entity, "state", PropertyMock(return_value=raw_value)):
        ent = entity.Entity()
        ent.hass = hass
        ent.entity_id = "hello.world"
        ent.async_write_ha_state()

    state = hass.states.get("hello.world")
    assert state is not None
    assert state.state == "3.6"
|
|
|
|
|
|
async def test_attribution_attribute(hass: HomeAssistant) -> None:
    """Test the attribution attribute is written to the state."""
    attribution = "Home Assistant"

    mock_entity = entity.Entity()
    mock_entity.hass = hass
    mock_entity.entity_id = "hello.world"
    mock_entity._attr_attribution = attribution

    mock_entity.async_schedule_update_ha_state(True)
    await hass.async_block_till_done()

    written = hass.states.get(mock_entity.entity_id)
    assert written.attributes.get(ATTR_ATTRIBUTION) == attribution
|
|
|
|
|
|
async def test_entity_category_property(hass: HomeAssistant) -> None:
    """Test entity category property."""
    # An explicit _attr_entity_category overrides the entity description.
    with_attr = entity.Entity()
    with_attr.hass = hass
    with_attr.entity_description = entity.EntityDescription(
        key="abc", entity_category="ignore_me"
    )
    with_attr.entity_id = "hello.world"
    with_attr._attr_entity_category = entity.EntityCategory.CONFIG
    assert with_attr.entity_category == "config"

    # Without the attribute, the description's category is used.
    from_description = entity.Entity()
    from_description.hass = hass
    from_description.entity_description = entity.EntityDescription(
        key="abc", entity_category=entity.EntityCategory.CONFIG
    )
    from_description.entity_id = "hello.world"
    assert from_description.entity_category == "config"
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("value", "expected"),
    [
        ("config", entity.EntityCategory.CONFIG),
        ("diagnostic", entity.EntityCategory.DIAGNOSTIC),
    ],
)
def test_entity_category_schema(value, expected) -> None:
    """Test the schema converts valid strings to EntityCategory members."""
    validated = vol.Schema(entity.ENTITY_CATEGORIES_SCHEMA)(value)
    assert validated == expected
    assert isinstance(validated, entity.EntityCategory)
|
|
|
|
|
|
@pytest.mark.parametrize("value", [None, "non_existing"])
def test_entity_category_schema_error(value) -> None:
    """Test the entity category schema rejects invalid values."""
    validator = vol.Schema(entity.ENTITY_CATEGORIES_SCHEMA)
    expected_message = r"expected EntityCategory or one of 'config', 'diagnostic'"
    with pytest.raises(vol.Invalid, match=expected_message):
        validator(value)
|
|
|
|
|
|
async def test_entity_description_fallback() -> None:
    """Test entity description has same defaults as entity."""
    plain = entity.Entity()
    described = entity.Entity()
    described.entity_description = entity.EntityDescription(key="test")

    # Every description field except "key" is compared.
    compared_names = (
        desc_field.name
        for desc_field in dataclasses.fields(entity.EntityDescription._dataclass)
        if desc_field.name != "key"
    )
    for name in compared_names:
        assert getattr(plain, name) == getattr(described, name)
|
|
|
|
|
|
async def _test_friendly_name(
    hass: HomeAssistant,
    ent: entity.Entity,
    expected_friendly_name: str | None,
) -> None:
    """Set up ``ent`` through a config entry and check its friendly name.

    The friendly name is checked right after setup and again after the
    entity has been updated.
    """

    async def async_setup_entry(hass, config_entry, async_add_entities):
        """Mock setup entry method."""
        async_add_entities([ent])
        return True

    platform = MockPlatform(async_setup_entry=async_setup_entry)
    config_entry = MockConfigEntry(entry_id="super-mock-id")
    config_entry.add_to_hass(hass)
    entity_platform = MockEntityPlatform(
        hass, platform_name=config_entry.domain, platform=platform
    )

    assert await entity_platform.async_setup_entry(config_entry)
    await hass.async_block_till_done()

    assert len(hass.states.async_entity_ids()) == 1
    state = hass.states.async_all()[0]
    assert state.attributes.get(ATTR_FRIENDLY_NAME) == expected_friendly_name

    await async_update_entity(hass, ent.entity_id)
    # `state` still refers to the object fetched before the update, so fetch
    # it again; otherwise the post-update check would be vacuous.
    state = hass.states.get(ent.entity_id)
    assert state is not None
    assert state.attributes.get(ATTR_FRIENDLY_NAME) == expected_friendly_name
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("has_entity_name", "entity_name", "device_name", "expected_friendly_name"),
    [
        (False, "Entity Blu", "Device Bla", "Entity Blu"),
        (False, None, "Device Bla", None),
        (True, "Entity Blu", "Device Bla", "Device Bla Entity Blu"),
        (True, None, "Device Bla", "Device Bla"),
        (True, "Entity Blu", UNDEFINED, "Entity Blu"),
        (True, "Entity Blu", None, "Mock Title Entity Blu"),
    ],
)
async def test_friendly_name_attr(
    hass: HomeAssistant,
    has_entity_name: bool,
    entity_name: str | None,
    device_name: str | None | UndefinedType,
    expected_friendly_name: str | None,
) -> None:
    """Test friendly name when the entity uses _attr_*."""
    device_info = {
        "identifiers": {("hue", "1234")},
        "connections": {(dr.CONNECTION_NETWORK_MAC, "abcd")},
        "name": device_name,
    }
    ent = MockEntity(unique_id="qwer", device_info=device_info)
    ent._attr_has_entity_name = has_entity_name
    ent._attr_name = entity_name

    await _test_friendly_name(hass, ent, expected_friendly_name)
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("has_entity_name", "entity_name", "expected_friendly_name"),
    [
        (False, "Entity Blu", "Entity Blu"),
        (False, None, None),
        (False, UNDEFINED, None),
        (True, "Entity Blu", "Device Bla Entity Blu"),
        (True, None, "Device Bla"),
        (True, UNDEFINED, "Device Bla None"),
    ],
)
async def test_friendly_name_description(
    hass: HomeAssistant,
    has_entity_name: bool,
    entity_name: str | None | UndefinedType,
    expected_friendly_name: str | None,
) -> None:
    """Test friendly name when the entity has an entity description.

    ``has_entity_name`` and ``entity_name`` are passed to the
    ``EntityDescription``. ``entity_name`` may be UNDEFINED, so it is
    annotated as such.
    """
    ent = MockEntity(
        unique_id="qwer",
        device_info={
            "identifiers": {("hue", "1234")},
            "connections": {(dr.CONNECTION_NETWORK_MAC, "abcd")},
            "name": "Device Bla",
        },
    )
    ent.entity_description = entity.EntityDescription(
        "test", has_entity_name=has_entity_name, name=entity_name
    )
    await _test_friendly_name(
        hass,
        ent,
        expected_friendly_name,
    )
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("has_entity_name", "entity_name", "expected_friendly_name"),
    [
        (False, "Entity Blu", "Entity Blu"),
        (False, None, None),
        (False, UNDEFINED, None),
        (True, "Entity Blu", "Device Bla Entity Blu"),
        (True, None, "Device Bla"),
        (True, UNDEFINED, "Device Bla English cls"),
    ],
)
async def test_friendly_name_description_device_class_name(
    hass: HomeAssistant,
    has_entity_name: bool,
    entity_name: str | None | UndefinedType,
    expected_friendly_name: str | None,
) -> None:
    """Test friendly name when the entity has an entity description.

    The entity opts into device class naming, and the translation lookup is
    patched to return "English cls" for the ``test_class`` device class.
    ``entity_name`` may be UNDEFINED, so it is annotated as such.
    """
    translations = {
        "en": {"component.test_domain.entity_component.test_class.name": "English cls"},
    }

    async def async_get_translations(
        hass: HomeAssistant,
        language: str,
        category: str,
        integrations: Iterable[str] | None = None,
        config_flow: bool | None = None,
    ) -> dict[str, Any]:
        """Return all backend translations."""
        return translations[language]

    class DeviceClassNameMockEntity(MockEntity):
        def _default_to_device_class_name(self) -> bool:
            """Return True if an unnamed entity should be named by its device class."""
            return True

    ent = DeviceClassNameMockEntity(
        unique_id="qwer",
        device_info={
            "identifiers": {("hue", "1234")},
            "connections": {(dr.CONNECTION_NETWORK_MAC, "abcd")},
            "name": "Device Bla",
        },
    )
    ent.entity_description = entity.EntityDescription(
        "test",
        device_class="test_class",
        has_entity_name=has_entity_name,
        name=entity_name,
    )
    with patch(
        "homeassistant.helpers.entity_platform.translation.async_get_translations",
        side_effect=async_get_translations,
    ):
        await _test_friendly_name(
            hass,
            ent,
            expected_friendly_name,
        )
|
|
|
|
|
|
@pytest.mark.parametrize(
    (
        "has_entity_name",
        "translation_key",
        "translations",
        "placeholders",
        "expected_friendly_name",
    ),
    [
        (False, None, None, None, "Entity Blu"),
        (True, None, None, None, "Device Bla Entity Blu"),
        (
            True,
            "test_entity",
            {
                "en": {
                    "component.test.entity.test_domain.test_entity.name": "English ent"
                },
            },
            None,
            "Device Bla English ent",
        ),
        (
            True,
            "test_entity",
            {
                "en": {
                    "component.test.entity.test_domain.test_entity.name": "{placeholder} English ent"
                },
            },
            {"placeholder": "special"},
            "Device Bla special English ent",
        ),
        (
            True,
            "test_entity",
            {
                "en": {
                    "component.test.entity.test_domain.test_entity.name": "English ent {placeholder}"
                },
            },
            {"placeholder": "special"},
            "Device Bla English ent special",
        ),
    ],
)
async def test_entity_name_translation_placeholders(
    hass: HomeAssistant,
    has_entity_name: bool,
    translation_key: str | None,
    translations: dict[str, dict[str, str]] | None,
    placeholders: dict[str, str] | None,
    expected_friendly_name: str | None,
) -> None:
    """Test friendly name when the entity name translation has placeholders.

    ``translations`` maps a language to its translation strings. It is None
    for the cases that do not use a translation key.
    """

    async def async_get_translations(
        hass: HomeAssistant,
        language: str,
        category: str,
        integrations: Iterable[str] | None = None,
        config_flow: bool | None = None,
    ) -> dict[str, Any]:
        """Return all backend translations.

        Returns an empty mapping when the case has no translations, instead
        of raising ``TypeError`` by subscripting ``None``.
        """
        if translations is None:
            return {}
        return translations[language]

    ent = MockEntity(
        unique_id="qwer",
        device_info={
            "identifiers": {("hue", "1234")},
            "connections": {(dr.CONNECTION_NETWORK_MAC, "abcd")},
            "name": "Device Bla",
        },
    )
    ent.entity_description = entity.EntityDescription(
        "test",
        has_entity_name=has_entity_name,
        translation_key=translation_key,
        name="Entity Blu",
    )
    if placeholders is not None:
        ent._attr_translation_placeholders = placeholders
    with patch(
        "homeassistant.helpers.entity_platform.translation.async_get_translations",
        side_effect=async_get_translations,
    ):
        await _test_friendly_name(hass, ent, expected_friendly_name)
|
|
|
|
|
|
@pytest.mark.parametrize(
    (
        "translation_key",
        "translations",
        "placeholders",
        "release_channel",
        "expected_error",
    ),
    [
        (
            "test_entity",
            {
                "en": {
                    "component.test.entity.test_domain.test_entity.name": "{placeholder} English ent {2ndplaceholder}"
                },
            },
            {"placeholder": "special"},
            ReleaseChannel.STABLE,
            (
                "has translation placeholders '{'placeholder': 'special'}' which do "
                "not match the name '{placeholder} English ent {2ndplaceholder}'"
            ),
        ),
        (
            "test_entity",
            {
                "en": {
                    "component.test.entity.test_domain.test_entity.name": "{placeholder} English ent {2ndplaceholder}"
                },
            },
            {"placeholder": "special"},
            ReleaseChannel.BETA,
            "HomeAssistantError: Missing placeholder '2ndplaceholder'",
        ),
        (
            "test_entity",
            {
                "en": {
                    "component.test.entity.test_domain.test_entity.name": "{placeholder} English ent"
                },
            },
            None,
            ReleaseChannel.STABLE,
            (
                "has translation placeholders '{}' which do "
                "not match the name '{placeholder} English ent'"
            ),
        ),
    ],
)
async def test_entity_name_translation_placeholder_errors(
    hass: HomeAssistant,
    translation_key: str,
    translations: dict[str, dict[str, str]],
    placeholders: dict[str, str] | None,
    release_channel: ReleaseChannel,
    expected_error: str,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test entity name translation has placeholder issues.

    Every case supplies a translation whose placeholders do not match the
    entity's placeholders. The expected log text depends on the patched
    release channel. When ``placeholders`` is None, the entity's
    ``_attr_translation_placeholders`` is left unset.
    """

    async def async_get_translations(
        hass: HomeAssistant,
        language: str,
        category: str,
        integrations: Iterable[str] | None = None,
        config_flow: bool | None = None,
    ) -> dict[str, Any]:
        """Return all backend translations."""
        return translations[language]

    async def async_setup_entry(hass, config_entry, async_add_entities):
        """Mock setup entry method."""
        async_add_entities([ent])
        return True

    ent = MockEntity(
        unique_id="qwer",
    )
    ent.entity_description = entity.EntityDescription(
        "test",
        has_entity_name=True,
        translation_key=translation_key,
        name="Entity Blu",
    )
    if placeholders is not None:
        ent._attr_translation_placeholders = placeholders

    platform = MockPlatform(async_setup_entry=async_setup_entry)
    config_entry = MockConfigEntry(entry_id="super-mock-id")
    config_entry.add_to_hass(hass)
    entity_platform = MockEntityPlatform(
        hass, platform_name=config_entry.domain, platform=platform
    )

    caplog.clear()

    with (
        patch(
            "homeassistant.helpers.entity_platform.translation.async_get_translations",
            side_effect=async_get_translations,
        ),
        patch(
            "homeassistant.helpers.entity.get_release_channel",
            return_value=release_channel,
        ),
    ):
        await entity_platform.async_setup_entry(config_entry)

    assert expected_error in caplog.text
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("has_entity_name", "entity_name", "expected_friendly_name"),
    [
        (False, "Entity Blu", "Entity Blu"),
        (False, None, None),
        (False, UNDEFINED, None),
        (True, "Entity Blu", "Device Bla Entity Blu"),
        (True, None, "Device Bla"),
        (True, UNDEFINED, "Device Bla None"),
    ],
)
async def test_friendly_name_property(
    hass: HomeAssistant,
    has_entity_name: bool,
    entity_name: str | None | UndefinedType,
    expected_friendly_name: str | None,
) -> None:
    """Test friendly name when the entity has overridden the name property.

    ``has_entity_name`` and ``entity_name`` are passed straight to
    ``MockEntity``. ``entity_name`` may be UNDEFINED, so it is annotated as
    such.
    """
    ent = MockEntity(
        unique_id="qwer",
        device_info={
            "identifiers": {("hue", "1234")},
            "connections": {(dr.CONNECTION_NETWORK_MAC, "abcd")},
            "name": "Device Bla",
        },
        has_entity_name=has_entity_name,
        name=entity_name,
    )
    await _test_friendly_name(
        hass,
        ent,
        expected_friendly_name,
    )
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("has_entity_name", "entity_name", "expected_friendly_name"),
    [
        (False, "Entity Blu", "Entity Blu"),
        (False, None, None),
        (False, UNDEFINED, None),
        (True, "Entity Blu", "Device Bla Entity Blu"),
        (True, None, "Device Bla"),
        # Won't use the device class name because the entity overrides the name property
        (True, UNDEFINED, "Device Bla None"),
    ],
)
async def test_friendly_name_property_device_class_name(
    hass: HomeAssistant,
    has_entity_name: bool,
    entity_name: str | None | UndefinedType,
    expected_friendly_name: str | None,
) -> None:
    """Test friendly name when the entity has overridden the name property.

    The entity also opts into device class naming, and the translation
    lookup is patched to return "English cls" for ``test_class``.
    ``entity_name`` may be UNDEFINED, so it is annotated as such.
    """
    translations = {
        "en": {"component.test_domain.entity_component.test_class.name": "English cls"},
    }

    async def async_get_translations(
        hass: HomeAssistant,
        language: str,
        category: str,
        integrations: Iterable[str] | None = None,
        config_flow: bool | None = None,
    ) -> dict[str, Any]:
        """Return all backend translations."""
        return translations[language]

    class DeviceClassNameMockEntity(MockEntity):
        def _default_to_device_class_name(self) -> bool:
            """Return True if an unnamed entity should be named by its device class."""
            return True

    ent = DeviceClassNameMockEntity(
        unique_id="qwer",
        device_class="test_class",
        device_info={
            "identifiers": {("hue", "1234")},
            "connections": {(dr.CONNECTION_NETWORK_MAC, "abcd")},
            "name": "Device Bla",
        },
        has_entity_name=has_entity_name,
        name=entity_name,
    )
    with patch(
        "homeassistant.helpers.entity_platform.translation.async_get_translations",
        side_effect=async_get_translations,
    ):
        await _test_friendly_name(
            hass,
            ent,
            expected_friendly_name,
        )
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("has_entity_name", "expected_friendly_name"),
    [
        (False, None),
        (True, "Device Bla English cls"),
    ],
)
async def test_friendly_name_device_class_name(
    hass: HomeAssistant,
    has_entity_name: bool,
    expected_friendly_name: str | None,
) -> None:
    """Check naming falls back to the device class when no name is configured.

    The entity sets no name of any kind. It only opts into device class
    naming and relies on the patched "English cls" translation.
    """
    backend_strings = {
        "en": {"component.test_domain.entity_component.test_class.name": "English cls"},
    }

    async def fake_get_translations(
        hass: HomeAssistant,
        language: str,
        category: str,
        integrations: Iterable[str] | None = None,
        config_flow: bool | None = None,
    ) -> dict[str, Any]:
        """Look up the canned backend translations for ``language``."""
        return backend_strings[language]

    class DeviceClassNameMockEntity(MockEntity):
        def _default_to_device_class_name(self) -> bool:
            """Always name an unnamed entity after its device class."""
            return True

    device_info = {
        "identifiers": {("hue", "1234")},
        "connections": {(dr.CONNECTION_NETWORK_MAC, "abcd")},
        "name": "Device Bla",
    }
    unnamed_ent = DeviceClassNameMockEntity(
        unique_id="qwer",
        device_class="test_class",
        device_info=device_info,
        has_entity_name=has_entity_name,
    )

    translations_patch = patch(
        "homeassistant.helpers.entity_platform.translation.async_get_translations",
        side_effect=fake_get_translations,
    )
    with translations_patch:
        await _test_friendly_name(hass, unnamed_ent, expected_friendly_name)
|
|
|
|
|
|
@pytest.mark.parametrize(
    (
        "entity_name",
        "expected_friendly_name1",
        "expected_friendly_name2",
        "expected_friendly_name3",
    ),
    [
        (
            "Entity Blu",
            "Device Bla Entity Blu",
            "Device Bla2 Entity Blu",
            "New Device Entity Blu",
        ),
        (
            None,
            "Device Bla",
            "Device Bla2",
            "New Device",
        ),
    ],
)
async def test_friendly_name_updated(
    hass: HomeAssistant,
    device_registry: dr.DeviceRegistry,
    entity_registry: er.EntityRegistry,
    entity_name: str | None,
    expected_friendly_name1: str,
    expected_friendly_name2: str,
    expected_friendly_name3: str,
) -> None:
    """Verify the friendly name follows device and entity registry updates."""

    async def async_setup_entry(hass, config_entry, async_add_entities):
        """Mock setup entry method."""
        named_ent = MockEntity(
            unique_id="qwer",
            device_info={
                "identifiers": {("hue", "1234")},
                "connections": {(dr.CONNECTION_NETWORK_MAC, "abcd")},
                "name": "Device Bla",
            },
            has_entity_name=True,
            name=entity_name,
        )
        async_add_entities([named_ent])
        return True

    platform = MockPlatform(async_setup_entry=async_setup_entry)
    config_entry = MockConfigEntry(entry_id="super-mock-id")
    config_entry.add_to_hass(hass)
    entity_platform = MockEntityPlatform(
        hass, platform_name=config_entry.domain, platform=platform
    )

    assert await entity_platform.async_setup_entry(config_entry)
    await hass.async_block_till_done()
    assert len(hass.states.async_entity_ids()) == 1

    def current_friendly_name() -> str | None:
        """Return the friendly name attribute of the only state."""
        return hass.states.async_all()[0].attributes.get(ATTR_FRIENDLY_NAME)

    # Initial name is derived from the device info supplied by the entity.
    assert current_friendly_name() == expected_friendly_name1

    # A user-assigned device name is picked up.
    hue_device = device_registry.async_get_device(identifiers={("hue", "1234")})
    device_registry.async_update_device(hue_device.id, name_by_user="Device Bla2")
    await hass.async_block_till_done()
    assert current_friendly_name() == expected_friendly_name2

    # Re-linking the entity to another device picks up that device's name.
    entity_id = hass.states.async_all()[0].entity_id
    new_device = device_registry.async_get_or_create(
        config_entry_id=config_entry.entry_id,
        identifiers={("hue", "5678")},
        name="New Device",
    )
    entity_registry.async_update_entity(entity_id, device_id=new_device.id)
    await hass.async_block_till_done()
    assert current_friendly_name() == expected_friendly_name3
|
|
|
|
|
|
async def test_translation_key(hass: HomeAssistant) -> None:
    """Check which source the translation key property is read from.

    ``_attr_translation_key`` takes precedence over the entity description.
    Without it, the description's ``translation_key`` is used.
    """

    def make_entity() -> entity.Entity:
        """Build an entity carrying a description with a translation key."""
        new_ent = entity.Entity()
        new_ent.hass = hass
        new_ent.entity_description = entity.EntityDescription(
            key="abc", translation_key="from_entity_description"
        )
        new_ent.entity_id = "hello.world"
        return new_ent

    with_attr = make_entity()
    with_attr._attr_translation_key = "from_attr"
    assert with_attr.translation_key == "from_attr"

    without_attr = make_entity()
    assert without_attr.translation_key == "from_entity_description"
|
|
|
|
|
|
async def test_repr(hass: HomeAssistant) -> None:
    """Test Entity.__repr__.

    ``MyEntity.state`` raises on access. The assertions show the repr is
    still produced before adding, while added (unavailable), and after
    removal.
    """

    class MyEntity(MockEntity):
        """Mock entity."""

        @property
        def state(self):
            """Return the state."""
            raise ValueError("Boom")

    platform = MockEntityPlatform(hass, domain="hello")
    my_entity = MyEntity(entity_id="test.test", available=False)

    # Not yet added
    assert str(my_entity) == "<entity unknown.unknown=unknown>"

    # Added
    await platform.async_add_entities([my_entity])
    assert str(my_entity) == "<entity test.test=unavailable>"

    # Removed
    await platform.async_remove_entity(my_entity.entity_id)
    assert str(my_entity) == "<entity unknown.unknown=unknown>"
|
|
|
|
|
|
async def test_warn_using_async_update_ha_state(
    hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
    """Check the unforced async_update_ha_state warning is logged only once."""
    ent = entity.Entity()
    ent.hass = hass
    ent.platform = MockEntityPlatform(hass)
    ent.entity_id = "hello.world"
    warning_text = "is using self.async_update_ha_state()"

    # (call kwargs, warning expected), in call order:
    # - a forced refresh never warns
    # - the first unforced call warns
    # - later unforced calls stay quiet
    scenarios = (
        ({"force_refresh": True}, False),
        ({}, True),
        ({}, False),
    )
    for call_kwargs, expect_warning in scenarios:
        caplog.clear()
        await ent.async_update_ha_state(**call_kwargs)
        assert (warning_text in caplog.text) is expect_warning
|
|
|
|
|
|
async def test_warn_no_platform(
    hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
    """Test we warn when an entity does not have a platform.

    The warning is logged on the first write without a platform and not
    repeated on later writes.
    """
    ent = entity.Entity()
    ent.hass = hass
    ent.platform = MockEntityPlatform(hass)
    ent.entity_id = "hello.world"
    error_message = "does not have a platform"

    # No warning if the entity has a platform
    caplog.clear()
    ent.async_write_ha_state()
    assert error_message not in caplog.text

    # Without a platform, it should trigger the warning
    ent.platform = None
    caplog.clear()
    ent.async_write_ha_state()
    assert error_message in caplog.text

    # Without a platform, it should not trigger the warning again
    caplog.clear()
    ent.async_write_ha_state()
    assert error_message not in caplog.text
|
|
|
|
|
|
async def test_invalid_state(
    hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
    """Check an over-long state is logged and replaced by STATE_UNKNOWN."""
    ent = entity.Entity()
    ent.entity_id = "test.test"
    ent.hass = hass

    longest_valid = "x" * 255
    one_too_long = "x" * 256

    def write_state(value: str) -> str:
        """Write ``value`` as the entity state and return what was stored."""
        ent._attr_state = value
        ent.async_write_ha_state()
        return hass.states.get("test.test").state

    assert write_state(longest_valid) == longest_valid

    caplog.clear()
    assert write_state(one_too_long) == STATE_UNKNOWN
    expected_record = (
        "homeassistant.helpers.entity",
        logging.ERROR,
        f"Failed to set state for test.test, fall back to {STATE_UNKNOWN}",
    )
    assert expected_record in caplog.record_tuples

    # A valid state written afterwards is stored as-is again.
    assert write_state(longest_valid) == longest_valid
|
|
|
|
|
|
async def test_suggest_report_issue_built_in(
    hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
    """Check report-issue suggestions for an entity of a built-in integration."""
    core_entity = entity.Entity()
    core_entity.entity_id = "comp_test.test_entity"

    without_label = (
        "create a bug report at https://github.com/home-assistant/core/issues"
        "?q=is%3Aopen+is%3Aissue"
    )
    with_label = (
        "create a bug report at https://github.com/home-assistant/core/issues"
        "?q=is%3Aopen+is%3Aissue+label%3A%22integration%3A+test%22"
    )

    # Before joining a platform, the search URL carries no integration label.
    assert core_entity._suggest_report_issue() == without_label

    mock_integration(hass, MockModule(domain="test"), built_in=True)
    platform = MockEntityPlatform(hass, domain="comp_test", platform_name="test")
    await platform.async_add_entities([core_entity])

    # Once on the "test" platform, the search is narrowed by its label.
    assert core_entity._suggest_report_issue() == with_label
|
|
|
|
|
|
async def test_suggest_report_issue_custom_component(
    hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
    """Check report-issue suggestions for an entity of a custom component."""

    class CustomComponentEntity(entity.Entity):
        """Custom component entity."""

        __module__ = "custom_components.bla.sensor"

    cc_entity = CustomComponentEntity()
    cc_entity.entity_id = "comp_test.test_entity"

    # Without a platform only the generic custom-integration hint is given.
    hint = cc_entity._suggest_report_issue()
    assert hint == "report it to the custom integration author"

    manifest = {"issue_tracker": "https://some_url"}
    mock_integration(
        hass,
        MockModule(domain="test", partial_manifest=manifest),
        built_in=False,
    )
    platform = MockEntityPlatform(hass, domain="comp_test", platform_name="test")
    await platform.async_add_entities([cc_entity])

    # Once attached, the issue tracker from the manifest is suggested.
    hint = cc_entity._suggest_report_issue()
    assert hint == "create a bug report at https://some_url"
|
|
|
|
|
|
async def test_reuse_entity_object_after_abort(
    hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
    """Check an entity object whose add was aborted cannot be added again."""
    platform = MockEntityPlatform(hass, domain="test")
    bad_ent = entity.Entity()
    bad_ent.entity_id = "invalid"

    # The first attempt fails on the malformed entity ID.
    await platform.async_add_entities([bad_ent])
    assert "Invalid entity ID: invalid" in caplog.text

    # Re-submitting the very same object is refused.
    await platform.async_add_entities([bad_ent])
    second_add_error = (
        "Entity 'invalid' cannot be added a second time to an entity platform"
    )
    assert second_add_error in caplog.text
|
|
|
|
|
|
async def test_reuse_entity_object_after_entity_registry_remove(
    hass: HomeAssistant,
    entity_registry: er.EntityRegistry,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Check an entity object cannot be re-added after its registry removal."""
    registry_entry = entity_registry.async_get_or_create("test", "test", "5678")
    platform = MockEntityPlatform(hass, domain="test", platform_name="test")
    reused = entity.Entity()
    reused._attr_unique_id = "5678"

    def state_count() -> int:
        """Return how many entity states currently exist."""
        return len(hass.states.async_entity_ids())

    # The entity is linked to the pre-created registry entry.
    await platform.async_add_entities([reused])
    assert reused.registry_entry is registry_entry
    assert state_count() == 1

    # Removing the registry entry also removes the state.
    entity_registry.async_remove(registry_entry.entity_id)
    await hass.async_block_till_done()
    assert state_count() == 0

    # Adding the same object again is rejected and creates no state.
    await platform.async_add_entities([reused])
    assert "Entity 'test.test_5678' cannot be added a second time" in caplog.text
    assert state_count() == 0
|
|
|
|
|
|
async def test_reuse_entity_object_after_entity_registry_disabled(
    hass: HomeAssistant,
    entity_registry: er.EntityRegistry,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Check an entity object cannot be re-added after being disabled."""
    registry_entry = entity_registry.async_get_or_create("test", "test", "5678")
    platform = MockEntityPlatform(hass, domain="test", platform_name="test")
    reused = entity.Entity()
    reused._attr_unique_id = "5678"

    def state_count() -> int:
        """Return how many entity states currently exist."""
        return len(hass.states.async_entity_ids())

    # The entity is linked to the pre-created registry entry.
    await platform.async_add_entities([reused])
    assert reused.registry_entry is registry_entry
    assert state_count() == 1

    # Disabling the entry by the user removes the state.
    entity_registry.async_update_entity(
        registry_entry.entity_id, disabled_by=er.RegistryEntryDisabler.USER
    )
    await hass.async_block_till_done()
    assert state_count() == 0

    # Adding the same object again creates no state and is rejected.
    await platform.async_add_entities([reused])
    assert state_count() == 0
    assert "Entity 'test.test_5678' cannot be added a second time" in caplog.text
|
|
|
|
|
|
async def test_change_entity_id(
    hass: HomeAssistant, entity_registry: er.EntityRegistry
) -> None:
    """Check an entity is removed and re-added on every entity ID rename."""
    removed: list[int] = []

    reg_entry = entity_registry.async_get_or_create(
        "test", "test_platform", "5678", suggested_object_id="test"
    )
    assert reg_entry.entity_id == "test.test"

    class RenameTrackingEntity(entity.Entity):
        """Entity recording its add/remove lifecycle calls."""

        _attr_unique_id = "5678"

        def __init__(self) -> None:
            self.added_calls = []
            self.remove_calls = []

        async def async_added_to_hass(self):
            self.added_calls.append(None)
            # Registered on every add; invoked on removal and counted in `removed`.
            self.async_on_remove(lambda: removed.append(1))

        async def async_will_remove_from_hass(self):
            self.remove_calls.append(None)

    platform = MockEntityPlatform(hass, domain="test")
    tracked = RenameTrackingEntity()
    await platform.async_add_entities([tracked])
    assert hass.states.get("test.test").state == STATE_UNKNOWN
    assert len(tracked.added_calls) == 1

    # Each rename triggers exactly one removal followed by one re-add.
    for renames, new_entity_id in enumerate(("test.test2", "test.test3"), start=1):
        reg_entry = entity_registry.async_update_entity(
            reg_entry.entity_id, new_entity_id=new_entity_id
        )
        await hass.async_block_till_done()

        assert len(removed) == renames
        assert len(tracked.added_calls) == renames + 1
        assert len(tracked.remove_calls) == renames
|
|
|
|
|
|
def test_entity_description_as_dataclass(snapshot: SnapshotAssertion) -> None:
    """Test EntityDescription behaves like a dataclass.

    The base description is frozen. It compares equal to an identically
    built instance, and both the object and its repr are checked against
    the stored snapshot.
    """
    obj = entity.EntityDescription("blah", device_class="test")
    # Frozen: neither assignment nor deletion of a field is allowed.
    with pytest.raises(dataclasses.FrozenInstanceError):
        obj.name = "mutate"
    with pytest.raises(dataclasses.FrozenInstanceError):
        delattr(obj, "name")

    assert dataclasses.is_dataclass(obj)
    assert obj == snapshot
    assert obj == entity.EntityDescription("blah", device_class="test")
    assert repr(obj) == snapshot
|
|
|
|
|
|
def test_extending_entity_description(snapshot: SnapshotAssertion) -> None:
    """Test extending entity descriptions."""

    def assert_matches_snapshot(cls: type, *args: Any, **kwargs: Any) -> Any:
        """Instantiate ``cls`` and compare the instance with the snapshot.

        Runs, in this order:
        - ``obj == snapshot``
        - an equality check against a second instance built from the same
          arguments
        - ``repr(obj) == snapshot``

        Keep this order: syrupy numbers snapshots by assertion order within
        a test, so reordering would invalidate the stored snapshots.
        Returns the first instance so callers can run further checks on it.
        """
        obj = cls(*args, **kwargs)
        assert obj == snapshot
        assert obj == cls(*args, **kwargs)
        assert repr(obj) == snapshot
        return obj

    # Keyword arguments shared by every multiple-parent ("Complex") description.
    complex_kwargs = {"key": "blah", "extra": "foo", "mixin": "mixin", "name": "name"}

    @dataclasses.dataclass(frozen=True)
    class FrozenEntityDescription(entity.EntityDescription):
        extra: str = None

    obj = assert_matches_snapshot(
        FrozenEntityDescription, "blah", extra="foo", name="name"
    )

    # Try mutating
    with pytest.raises(dataclasses.FrozenInstanceError):
        obj.name = "mutate"
    with pytest.raises(dataclasses.FrozenInstanceError):
        delattr(obj, "name")

    @dataclasses.dataclass
    class ThawedEntityDescription(entity.EntityDescription):
        extra: str = None

    obj = assert_matches_snapshot(
        ThawedEntityDescription, "blah", extra="foo", name="name"
    )

    # Try mutating
    obj.name = "mutate"
    assert obj.name == "mutate"
    delattr(obj, "key")
    assert not hasattr(obj, "key")

    # Try multiple levels of FrozenOrThawed
    class ExtendedEntityDescription(entity.EntityDescription, frozen_or_thawed=True):
        extension: str = None

    @dataclasses.dataclass(frozen=True)
    class MyExtendedEntityDescription(ExtendedEntityDescription):
        extra: str = None

    assert_matches_snapshot(
        MyExtendedEntityDescription, "blah", extension="ext", extra="foo", name="name"
    )

    # Try multiple direct parents
    @dataclasses.dataclass(frozen=True)
    class MyMixin1:
        mixin: str

    @dataclasses.dataclass
    class MyMixin2:
        mixin: str

    @dataclasses.dataclass(frozen=True)
    class MyMixin3:
        mixin: str = None

    @dataclasses.dataclass
    class MyMixin4:
        mixin: str = None

    @dataclasses.dataclass(frozen=True, kw_only=True)
    class ComplexEntityDescription1A(MyMixin1, entity.EntityDescription):
        extra: str = None

    assert_matches_snapshot(ComplexEntityDescription1A, **complex_kwargs)

    @dataclasses.dataclass(frozen=True, kw_only=True)
    class ComplexEntityDescription1B(entity.EntityDescription, MyMixin1):
        extra: str = None

    assert_matches_snapshot(ComplexEntityDescription1B, **complex_kwargs)

    @dataclasses.dataclass(frozen=True)
    class ComplexEntityDescription1C(MyMixin1, entity.EntityDescription):
        extra: str = None

    assert_matches_snapshot(ComplexEntityDescription1C, **complex_kwargs)

    @dataclasses.dataclass(frozen=True)
    class ComplexEntityDescription1D(entity.EntityDescription, MyMixin1):
        extra: str = None

    assert_matches_snapshot(ComplexEntityDescription1D, **complex_kwargs)

    @dataclasses.dataclass(kw_only=True)
    class ComplexEntityDescription2A(MyMixin2, entity.EntityDescription):
        extra: str = None

    assert_matches_snapshot(ComplexEntityDescription2A, **complex_kwargs)

    @dataclasses.dataclass(kw_only=True)
    class ComplexEntityDescription2B(entity.EntityDescription, MyMixin2):
        extra: str = None

    assert_matches_snapshot(ComplexEntityDescription2B, **complex_kwargs)

    @dataclasses.dataclass
    class ComplexEntityDescription2C(MyMixin2, entity.EntityDescription):
        extra: str = None

    assert_matches_snapshot(ComplexEntityDescription2C, **complex_kwargs)

    @dataclasses.dataclass
    class ComplexEntityDescription2D(entity.EntityDescription, MyMixin2):
        extra: str = None

    assert_matches_snapshot(ComplexEntityDescription2D, **complex_kwargs)

    @dataclasses.dataclass(frozen=True, kw_only=True)
    class ComplexEntityDescription3A(MyMixin3, entity.EntityDescription):
        extra: str = None

    assert_matches_snapshot(ComplexEntityDescription3A, **complex_kwargs)

    @dataclasses.dataclass(frozen=True, kw_only=True)
    class ComplexEntityDescription3B(entity.EntityDescription, MyMixin3):
        extra: str = None

    assert_matches_snapshot(ComplexEntityDescription3B, **complex_kwargs)

    with pytest.raises(TypeError):

        @dataclasses.dataclass(frozen=True)
        class ComplexEntityDescription3C(MyMixin3, entity.EntityDescription):
            extra: str = None

    with pytest.raises(TypeError):

        @dataclasses.dataclass(frozen=True)
        class ComplexEntityDescription3D(entity.EntityDescription, MyMixin3):
            extra: str = None

    @dataclasses.dataclass(kw_only=True)
    class ComplexEntityDescription4A(MyMixin4, entity.EntityDescription):
        extra: str = None

    assert_matches_snapshot(ComplexEntityDescription4A, **complex_kwargs)

    @dataclasses.dataclass(kw_only=True)
    class ComplexEntityDescription4B(entity.EntityDescription, MyMixin4):
        extra: str = None

    assert_matches_snapshot(ComplexEntityDescription4B, **complex_kwargs)

    with pytest.raises(TypeError):

        @dataclasses.dataclass
        class ComplexEntityDescription4C(MyMixin4, entity.EntityDescription):
            extra: str = None

    with pytest.raises(TypeError):

        @dataclasses.dataclass
        class ComplexEntityDescription4D(entity.EntityDescription, MyMixin4):
            extra: str = None

    # Try inheriting with custom init
    @dataclasses.dataclass
    class CustomInitEntityDescription(entity.EntityDescription):
        def __init__(self, extra, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            self.extra: str = extra

    assert_matches_snapshot(
        CustomInitEntityDescription, key="blah", extra="foo", name="name"
    )
|
|
|
|
|
|
async def test_update_capabilities(
    hass: HomeAssistant,
    entity_registry: er.EntityRegistry,
) -> None:
    """Test entity capabilities are updated automatically.

    Every ``async_write_ha_state`` call should copy the entity's capability
    attributes, device class and supported features into its registry entry.
    """
    platform = MockEntityPlatform(hass)
    mock_ent = MockEntity(unique_id="qwer")
    await platform.async_add_entities([mock_ent])

    def registry_entry():
        """Fetch the current registry entry for the entity under test."""
        return entity_registry.async_get(mock_ent.entity_id)

    reg_entry = registry_entry()
    assert reg_entry.capabilities is None
    assert reg_entry.device_class is None
    assert reg_entry.supported_features == 0

    # Populate all capability-related values and write state.
    mock_ent._values.update(
        capability_attributes={"bla": "blu"},
        device_class="some_class",
        supported_features=127,
    )
    mock_ent.async_write_ha_state()
    reg_entry = registry_entry()
    assert reg_entry.capabilities == {"bla": "blu"}
    assert reg_entry.original_device_class == "some_class"
    assert reg_entry.supported_features == 127

    # Clearing the values resets the registry entry.
    mock_ent._values.update(
        capability_attributes=None,
        device_class=None,
        supported_features=None,
    )
    mock_ent.async_write_ha_state()
    reg_entry = registry_entry()
    assert reg_entry.capabilities is None
    assert reg_entry.original_device_class is None
    assert reg_entry.supported_features == 0

    # Device class can be overridden by user, make sure that does not break the
    # automatic updating.
    entity_registry.async_update_entity(mock_ent.entity_id, device_class="set_by_user")
    await hass.async_block_till_done()
    reg_entry = registry_entry()
    assert reg_entry.capabilities is None
    assert reg_entry.original_device_class is None
    assert reg_entry.supported_features == 0

    # This will not trigger a state change because the device class is shadowed
    # by the entity registry
    mock_ent._values["device_class"] = "some_class"
    mock_ent.async_write_ha_state()
    reg_entry = registry_entry()
    assert reg_entry.capabilities is None
    assert reg_entry.original_device_class == "some_class"
    assert reg_entry.supported_features == 0
|
|
|
|
|
|
async def test_update_capabilities_no_unique_id(
    hass: HomeAssistant,
    entity_registry: er.EntityRegistry,
) -> None:
    """Test capability changes without a unique ID never create a registry entry.

    An entity without a unique ID has no registry entry after being added.
    Writing state with new capabilities must not create one.
    """
    platform = MockEntityPlatform(hass)
    mock_ent = MockEntity()
    await platform.async_add_entities([mock_ent])
    assert entity_registry.async_get(mock_ent.entity_id) is None

    mock_ent._values.update(
        capability_attributes={"bla": "blu"},
        supported_features=127,
    )
    mock_ent.async_write_ha_state()
    assert entity_registry.async_get(mock_ent.entity_id) is None
|
|
|
|
|
|
async def test_update_capabilities_too_often(
    hass: HomeAssistant,
    caplog: pytest.LogCaptureFixture,
    entity_registry: er.EntityRegistry,
) -> None:
    """Test a warning is logged when capabilities are updated too often.

    The first ``entity.CAPABILITIES_UPDATE_LIMIT`` capability updates are
    applied without a warning. The next update is still applied to the
    registry, but it also logs the "too often" warning.
    """
    warning_text = "is updating its capabilities too often"
    platform = MockEntityPlatform(hass)
    mock_ent = MockEntity(unique_id="qwer")
    await platform.async_add_entities([mock_ent])

    reg_entry = entity_registry.async_get(mock_ent.entity_id)
    assert reg_entry.capabilities is None
    assert reg_entry.device_class is None
    assert reg_entry.supported_features == 0

    limit = entity.CAPABILITIES_UPDATE_LIMIT

    def write_and_check(features: int) -> None:
        """Write state using ``features`` and verify the registry follows."""
        mock_ent._values.update(
            capability_attributes={"bla": "blu"},
            device_class="some_class",
            supported_features=features,
        )
        mock_ent.async_write_ha_state()
        updated = entity_registry.async_get(mock_ent.entity_id)
        assert updated.capabilities == {"bla": "blu"}
        assert updated.original_device_class == "some_class"
        assert updated.supported_features == features

    # Staying within the limit must not warn.
    for features in range(1, limit + 1):
        write_and_check(features)
    assert warning_text not in caplog.text

    # One more update crosses the limit and triggers the warning.
    write_and_check(limit + 1)
    assert warning_text in caplog.text
|
|
|
|
|
|
async def test_update_capabilities_too_often_cooldown(
    hass: HomeAssistant,
    caplog: pytest.LogCaptureFixture,
    entity_registry: er.EntityRegistry,
    freezer: FrozenDateTimeFactory,
) -> None:
    """Test the capability update limit is lifted after a cooldown.

    After ``entity.CAPABILITIES_UPDATE_LIMIT`` updates, advancing the clock by
    just over an hour lets one more update through without the warning.
    """
    warning_text = "is updating its capabilities too often"
    platform = MockEntityPlatform(hass)
    mock_ent = MockEntity(unique_id="qwer")
    await platform.async_add_entities([mock_ent])

    reg_entry = entity_registry.async_get(mock_ent.entity_id)
    assert reg_entry.capabilities is None
    assert reg_entry.device_class is None
    assert reg_entry.supported_features == 0

    limit = entity.CAPABILITIES_UPDATE_LIMIT

    def write_and_check(features: int) -> None:
        """Write state using ``features`` and verify the registry follows."""
        mock_ent._values.update(
            capability_attributes={"bla": "blu"},
            device_class="some_class",
            supported_features=features,
        )
        mock_ent.async_write_ha_state()
        updated = entity_registry.async_get(mock_ent.entity_id)
        assert updated.capabilities == {"bla": "blu"}
        assert updated.original_device_class == "some_class"
        assert updated.supported_features == features

    for features in range(1, limit + 1):
        write_and_check(features)
    assert warning_text not in caplog.text

    # Move past the cooldown window before the next update.
    freezer.tick(timedelta(hours=1, seconds=1))

    write_and_check(limit + 1)
    assert warning_text not in caplog.text
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("property_name", "default_value", "values"),
    [("attribution", None, ["abcd", "efgh"])],
)
async def test_cached_entity_properties(
    hass: HomeAssistant, property_name: str, default_value: Any, values: Any
) -> None:
    """Test entity properties are cached per instance.

    Setting, updating and deleting ``_attr_<property_name>`` on one entity
    must be reflected by that entity's property. A second instance of the
    same class must keep returning ``default_value``.

    The parameter is called ``property_name`` (not ``property``) so it does
    not shadow the ``property`` builtin, and it matches the sibling tests.
    """
    attr_name = f"_attr_{property_name}"
    ent1 = entity.Entity()
    ent2 = entity.Entity()
    assert getattr(ent1, property_name) == default_value
    assert getattr(ent2, property_name) == default_value

    # Test set
    setattr(ent1, attr_name, values[0])
    assert getattr(ent1, property_name) == values[0]
    assert getattr(ent2, property_name) == default_value

    # Test update
    setattr(ent1, attr_name, values[1])
    assert getattr(ent1, property_name) == values[1]
    assert getattr(ent2, property_name) == default_value

    # Test delete
    delattr(ent1, attr_name)
    assert getattr(ent1, property_name) == default_value
    assert getattr(ent2, property_name) == default_value
|
|
|
|
|
|
async def test_cached_entity_property_delete_attr(hass: HomeAssistant) -> None:
    """Test deleting an ``_attr_`` that backs a cached property.

    Deleting an ``_attr_`` that was never set raises ``AttributeError``, and
    the property keeps returning ``False``. Deleting one that was set to
    ``True`` makes the property return ``False`` again.
    """
    property_name = "has_entity_name"
    attr_name = f"_attr_{property_name}"

    ent = entity.Entity()

    # Never set: deletion fails and the property is unaffected.
    assert not hasattr(ent, attr_name)
    with pytest.raises(AttributeError):
        delattr(ent, attr_name)
    assert getattr(ent, property_name) is False

    # Still fails after the property has been read once.
    with pytest.raises(AttributeError):
        delattr(ent, attr_name)
    assert not hasattr(ent, attr_name)
    assert getattr(ent, property_name) is False

    # Set, then delete: the property goes back to False.
    setattr(ent, attr_name, True)
    assert getattr(ent, property_name) is True

    delattr(ent, attr_name)
    assert not hasattr(ent, attr_name)
    assert getattr(ent, property_name) is False
|
|
|
|
|
|
async def test_cached_entity_property_class_attribute(hass: HomeAssistant) -> None:
    """Test entity properties on class level work in derived classes.

    Four subclasses override ``_attr_attribution`` in different ways. For
    each one, two instances start with the class-level value, and updating
    the ``_attr_`` on one instance must not affect the other.
    """
    property_name = "attribution"
    values = ["abcd", "efgh"]

    class EntityWithClassAttribute1(entity.Entity):
        """A derived class which overrides an _attr_ from a parent."""

        _attr_attribution = values[0]

    # ``cached_properties`` takes property name strings. Pass ``property_name``;
    # a bare ``property`` here would resolve to the ``property`` builtin type.
    class EntityWithClassAttribute2(entity.Entity, cached_properties={property_name}):
        """A derived class which overrides an _attr_ from a parent.

        This class also redundantly marks the overridden _attr_ as cached.
        """

        _attr_attribution = values[0]

    class EntityWithClassAttribute3(entity.Entity, cached_properties={property_name}):
        """A derived class which overrides an _attr_ from a parent.

        This class overrides the attribute property.
        """

        def __init__(self):
            self._attr_attribution = values[0]

        @cached_property
        def attribution(self) -> str | None:
            """Return the attribution."""
            return self._attr_attribution

    class EntityWithClassAttribute4(entity.Entity, cached_properties={property_name}):
        """A derived class which overrides an _attr_ from a parent.

        This class overrides the attribute property and the _attr_.
        """

        _attr_attribution = values[0]

        @cached_property
        def attribution(self) -> str | None:
            """Return the attribution."""
            return self._attr_attribution

    classes = (
        EntityWithClassAttribute1,
        EntityWithClassAttribute2,
        EntityWithClassAttribute3,
        EntityWithClassAttribute4,
    )

    entities: list[tuple[entity.Entity, entity.Entity]] = [
        (cls(), cls()) for cls in classes
    ]

    for ent in entities:
        assert getattr(ent[0], property_name) == values[0]
        assert getattr(ent[1], property_name) == values[0]

    # Test update
    for ent in entities:
        setattr(ent[0], f"_attr_{property_name}", values[1])
    for ent in entities:
        assert getattr(ent[0], property_name) == values[1]
        assert getattr(ent[1], property_name) == values[0]
|
|
|
|
|
|
async def test_cached_entity_property_override(hass: HomeAssistant) -> None:
    """Test overriding cached _attr_ raises.

    Plain class attributes on ``_attr_attribution`` (annotated, assigned, or
    both) are accepted, as are a property or a method on ``_attr_not_cached``.
    Replacing ``_attr_attribution`` with a property or a method must raise
    ``TypeError`` when the class is created.
    """

    # Accepted: annotation only.
    class EntityWithClassAttribute1(entity.Entity):
        """A derived class which overrides an _attr_ from a parent."""

        _attr_attribution: str

    # Accepted: plain assignment.
    class EntityWithClassAttribute2(entity.Entity):
        """A derived class which overrides an _attr_ from a parent."""

        _attr_attribution = "blabla"

    # Accepted: annotated assignment.
    class EntityWithClassAttribute3(entity.Entity):
        """A derived class which overrides an _attr_ from a parent."""

        _attr_attribution: str = "blabla"

    # Accepted: a property on the _attr_not_cached name.
    class EntityWithClassAttribute4(entity.Entity):
        @property
        def _attr_not_cached(self):
            return "blabla"

    # Accepted: a method on the _attr_not_cached name.
    class EntityWithClassAttribute5(entity.Entity):
        def _attr_not_cached(self):
            return "blabla"

    # Rejected: a property replacing _attr_attribution.
    with pytest.raises(TypeError):

        class EntityWithClassAttribute6(entity.Entity):
            @property
            def _attr_attribution(self):
                return "🤡"

    # Rejected: a method replacing _attr_attribution.
    with pytest.raises(TypeError):

        class EntityWithClassAttribute7(entity.Entity):
            def _attr_attribution(self):
                return "🤡"
|
|
|
|
|
|
async def test_entity_report_deprecated_supported_features_values(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test reporting deprecated supported feature values only happens once."""
    deprecation_text = (
        "is using deprecated supported features values which will be removed"
    )

    class MockEntityFeatures(IntFlag):
        VALUE1 = 1
        VALUE2 = 2

    ent = entity.Entity()

    # The first report logs the deprecation and names the flag member.
    ent._report_deprecated_supported_features_values(MockEntityFeatures(2))
    assert deprecation_text in caplog.text
    assert "MockEntityFeatures.VALUE2" in caplog.text

    # A repeated report for the same entity logs nothing new.
    caplog.clear()
    ent._report_deprecated_supported_features_values(MockEntityFeatures(2))
    assert deprecation_text not in caplog.text
|
|
|
|
|
|
async def test_remove_entity_registry(
    hass: HomeAssistant, entity_registry: er.EntityRegistry
) -> None:
    """Test removing an entity from the registry.

    Removing the registry entry must invoke the entity's
    ``async_will_remove_from_hass`` hook and its ``async_on_remove`` callback
    exactly once, and the entity's state must be gone afterwards.
    """
    # Collects one item per invocation of the async_on_remove callback.
    result = []

    entry = entity_registry.async_get_or_create(
        "test", "test_platform", "5678", suggested_object_id="test"
    )
    assert entry.entity_id == "test.test"

    class MockEntity(entity.Entity):
        _attr_unique_id = "5678"

        def __init__(self) -> None:
            self.added_calls = []
            self.remove_calls = []

        async def async_added_to_hass(self):
            self.added_calls.append(None)
            self.async_on_remove(lambda: result.append(1))

        async def async_will_remove_from_hass(self):
            self.remove_calls.append(None)

    platform = MockEntityPlatform(hass, domain="test")
    ent = MockEntity()
    await platform.async_add_entities([ent])
    assert hass.states.get("test.test").state == STATE_UNKNOWN
    assert len(ent.added_calls) == 1

    # The return value is not needed. Don't rebind ``entry`` to it.
    entity_registry.async_remove(entry.entity_id)
    await hass.async_block_till_done()

    assert len(result) == 1
    assert len(ent.added_calls) == 1
    assert len(ent.remove_calls) == 1

    assert hass.states.get("test.test") is None
|
|
|
|
|
|
async def test_reset_right_after_remove_entity_registry(
    hass: HomeAssistant, entity_registry: er.EntityRegistry
) -> None:
    """Test resetting the platform right after removing an entity from the registry.

    A reset commonly happens during a reload.
    """
    # Collects one item per invocation of the async_on_remove callback.
    result = []

    entry = entity_registry.async_get_or_create(
        "test", "test_platform", "5678", suggested_object_id="test"
    )
    assert entry.entity_id == "test.test"

    class MockEntity(entity.Entity):
        _attr_unique_id = "5678"

        def __init__(self) -> None:
            self.added_calls = []
            self.remove_calls = []

        async def async_added_to_hass(self):
            self.added_calls.append(None)
            self.async_on_remove(lambda: result.append(1))

        async def async_will_remove_from_hass(self):
            self.remove_calls.append(None)

    platform = MockEntityPlatform(hass, domain="test")
    ent = MockEntity()
    await platform.async_add_entities([ent])
    assert hass.states.get("test.test").state == STATE_UNKNOWN
    assert len(ent.added_calls) == 1

    # The return value is not needed. Don't rebind ``entry`` to it.
    entity_registry.async_remove(entry.entity_id)

    # Reset the platform immediately after removing the entity from the registry
    await platform.async_reset()
    await hass.async_block_till_done()

    # Removal hooks must still run exactly once despite the overlapping reset.
    assert len(result) == 1
    assert len(ent.added_calls) == 1
    assert len(ent.remove_calls) == 1

    assert hass.states.get("test.test") is None
|
|
|
|
|
|
async def test_get_hassjob_type(hass: HomeAssistant) -> None:
    """Test get_hassjob_type classifies plain, coroutine and callback methods."""

    class AsyncEntity(entity.Entity):
        """Test entity."""

        def update(self):
            """Test update Executor."""

        async def async_update(self):
            """Test update Coroutinefunction."""

        @callback
        def update_callback(self):
            """Test update Callback."""

    ent_1 = AsyncEntity()

    # Method name -> job type it must be classified as (checked in this order).
    expected_types = {
        "update": HassJobType.Executor,
        "async_update": HassJobType.Coroutinefunction,
        "update_callback": HassJobType.Callback,
    }
    for method_name, job_type in expected_types.items():
        assert ent_1.get_hassjob_type(method_name) is job_type
|
|
|
|
|
|
async def test_async_write_ha_state_thread_safety(hass: HomeAssistant) -> None:
    """Test async_write_ha_state thread safety with debug mode enabled.

    Writing state from the event loop succeeds. Calling it from an executor
    thread raises ``RuntimeError``, and no state is written.
    """
    hass.config.debug = True

    def make_entity(entity_id: str) -> entity.Entity:
        """Build a bare entity bound to ``hass`` with the given entity_id."""
        new_ent = entity.Entity()
        new_ent.entity_id = entity_id
        new_ent.hass = hass
        return new_ent

    loop_ent = make_entity("test.any")
    loop_ent.async_write_ha_state()
    assert hass.states.get(loop_ent.entity_id)

    thread_ent = make_entity("test.any2")
    with pytest.raises(
        RuntimeError,
        match="Detected code that calls async_write_ha_state from a thread.",
    ):
        await hass.async_add_executor_job(thread_ent.async_write_ha_state)
    assert not hass.states.get(thread_ent.entity_id)
|
|
|
|
|
|
async def test_async_write_ha_state_thread_safety_always(
    hass: HomeAssistant,
) -> None:
    """Test async_write_ha_state thread safe check.

    Unlike the debug variant, this test does not set ``hass.config.debug``.
    Each entity is attached to a mock platform. A write from an executor
    thread must still raise ``RuntimeError`` and must not store any state.
    """

    def make_entity(entity_id: str) -> entity.Entity:
        """Build an entity bound to ``hass`` and a fresh mock platform."""
        new_ent = entity.Entity()
        new_ent.entity_id = entity_id
        new_ent.hass = hass
        new_ent.platform = MockEntityPlatform(hass, domain="test")
        return new_ent

    loop_ent = make_entity("test.any")
    loop_ent.async_write_ha_state()
    assert hass.states.get(loop_ent.entity_id)

    thread_ent = make_entity("test.any2")
    with pytest.raises(
        RuntimeError,
        match="Detected code that calls async_write_ha_state from a thread.",
    ):
        await hass.async_add_executor_job(thread_ent.async_write_ha_state)
    assert not hass.states.get(thread_ent.entity_id)
|