Fix resetting of attributes in EntityRegistry.async_get_or_create (#77516)

* Fix resetting of attributes in EntityRegistry.async_get_or_create

* Fix typing

* Fix resetting config entry

* Improve test

* Update tests
This commit is contained in:
Erik Montnemery 2022-08-30 21:07:50 +02:00 committed by GitHub
parent 67db380253
commit 4655ed995e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 151 additions and 114 deletions

View file

@ -12,7 +12,7 @@ from __future__ import annotations
from collections import UserDict
from collections.abc import Callable, Iterable, Mapping
import logging
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, TypeVar, cast
import attr
import voluptuous as vol
@ -53,6 +53,8 @@ if TYPE_CHECKING:
from .entity import EntityCategory
T = TypeVar("T")
PATH_REGISTRY = "entity_registry.yaml"
DATA_REGISTRY = "entity_registry"
EVENT_ENTITY_REGISTRY_UPDATED = "entity_registry_updated"
@ -324,41 +326,43 @@ class EntityRegistry:
disabled_by: RegistryEntryDisabler | None = None,
hidden_by: RegistryEntryHider | None = None,
# Data that we want entry to have
area_id: str | None = None,
capabilities: Mapping[str, Any] | None = None,
config_entry: ConfigEntry | None = None,
device_id: str | None = None,
entity_category: EntityCategory | None = None,
has_entity_name: bool | None = None,
original_device_class: str | None = None,
original_icon: str | None = None,
original_name: str | None = None,
supported_features: int | None = None,
unit_of_measurement: str | None = None,
area_id: str | None | UndefinedType = UNDEFINED,
capabilities: Mapping[str, Any] | None | UndefinedType = UNDEFINED,
config_entry: ConfigEntry | None | UndefinedType = UNDEFINED,
device_id: str | None | UndefinedType = UNDEFINED,
entity_category: EntityCategory | UndefinedType | None = UNDEFINED,
has_entity_name: bool | UndefinedType = UNDEFINED,
original_device_class: str | None | UndefinedType = UNDEFINED,
original_icon: str | None | UndefinedType = UNDEFINED,
original_name: str | None | UndefinedType = UNDEFINED,
supported_features: int | None | UndefinedType = UNDEFINED,
unit_of_measurement: str | None | UndefinedType = UNDEFINED,
) -> RegistryEntry:
"""Get entity. Create if it doesn't exist."""
config_entry_id = None
if config_entry:
config_entry_id: str | None | UndefinedType = UNDEFINED
if not config_entry:
config_entry_id = None
elif config_entry is not UNDEFINED:
config_entry_id = config_entry.entry_id
supported_features = supported_features or 0
entity_id = self.async_get_entity_id(domain, platform, unique_id)
if entity_id:
return self.async_update_entity(
entity_id,
area_id=area_id or UNDEFINED,
capabilities=capabilities or UNDEFINED,
config_entry_id=config_entry_id or UNDEFINED,
device_id=device_id or UNDEFINED,
entity_category=entity_category or UNDEFINED,
has_entity_name=has_entity_name
if has_entity_name is not None
else UNDEFINED,
original_device_class=original_device_class or UNDEFINED,
original_icon=original_icon or UNDEFINED,
original_name=original_name or UNDEFINED,
supported_features=supported_features or UNDEFINED,
unit_of_measurement=unit_of_measurement or UNDEFINED,
area_id=area_id,
capabilities=capabilities,
config_entry_id=config_entry_id,
device_id=device_id,
entity_category=entity_category,
has_entity_name=has_entity_name,
original_device_class=original_device_class,
original_icon=original_icon,
original_name=original_name,
supported_features=supported_features,
unit_of_measurement=unit_of_measurement,
# When we changed our slugify algorithm, we invalidated some
# stored entity IDs with either a __ or ending in _.
# Fix introduced in 0.86 (Jan 23, 2019). Next line can be
@ -380,32 +384,41 @@ class EntityRegistry:
if (
disabled_by is None
and config_entry
and config_entry is not UNDEFINED
and config_entry.pref_disable_new_entities
):
disabled_by = RegistryEntryDisabler.INTEGRATION
from .entity import EntityCategory # pylint: disable=import-outside-toplevel
if entity_category and not isinstance(entity_category, EntityCategory):
if (
entity_category
and entity_category is not UNDEFINED
and not isinstance(entity_category, EntityCategory)
):
raise ValueError("entity_category must be a valid EntityCategory instance")
def none_if_undefined(value: T | UndefinedType) -> T | None:
"""Return None if value is UNDEFINED, otherwise return value."""
return None if value is UNDEFINED else value
entry = RegistryEntry(
area_id=area_id,
capabilities=capabilities,
config_entry_id=config_entry_id,
device_id=device_id,
area_id=none_if_undefined(area_id),
capabilities=none_if_undefined(capabilities),
config_entry_id=none_if_undefined(config_entry_id),
device_id=none_if_undefined(device_id),
disabled_by=disabled_by,
entity_category=entity_category,
entity_category=none_if_undefined(entity_category),
entity_id=entity_id,
hidden_by=hidden_by,
has_entity_name=has_entity_name or False,
original_device_class=original_device_class,
original_icon=original_icon,
original_name=original_name,
has_entity_name=none_if_undefined(has_entity_name) or False,
original_device_class=none_if_undefined(original_device_class),
original_icon=none_if_undefined(original_icon),
original_name=none_if_undefined(original_name),
platform=platform,
supported_features=supported_features or 0,
supported_features=none_if_undefined(supported_features) or 0,
unique_id=unique_id,
unit_of_measurement=unit_of_measurement,
unit_of_measurement=none_if_undefined(unit_of_measurement),
)
self.entities[entity_id] = entry
_LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id)

View file

@ -226,6 +226,8 @@ async def test_get_action_capabilities(
"""Test we get the expected capabilities from a sensor trigger."""
platform = getattr(hass.components, f"test.{DOMAIN}")
platform.init()
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
config_entry = MockConfigEntry(domain="test", data={})
config_entry.add_to_hass(hass)
@ -239,8 +241,6 @@ async def test_get_action_capabilities(
platform.ENTITIES["no_arm_code"].unique_id,
device_id=device_entry.id,
)
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
expected_capabilities = {
"arm_away": {"extra_fields": []},
@ -270,6 +270,8 @@ async def test_get_action_capabilities_arm_code(
"""Test we get the expected capabilities from a sensor trigger."""
platform = getattr(hass.components, f"test.{DOMAIN}")
platform.init()
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
config_entry = MockConfigEntry(domain="test", data={})
config_entry.add_to_hass(hass)
@ -283,8 +285,6 @@ async def test_get_action_capabilities_arm_code(
platform.ENTITIES["arm_code"].unique_id,
device_id=device_entry.id,
)
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
expected_capabilities = {
"arm_away": {

View file

@ -49,6 +49,8 @@ async def test_get_conditions(hass, device_reg, entity_reg, enable_custom_integr
"""Test we get the expected conditions from a binary_sensor."""
platform = getattr(hass.components, f"test.{DOMAIN}")
platform.init()
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
config_entry = MockConfigEntry(domain="test", data={})
config_entry.add_to_hass(hass)
@ -64,9 +66,6 @@ async def test_get_conditions(hass, device_reg, entity_reg, enable_custom_integr
device_id=device_entry.id,
)
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
expected_conditions = [
{
"condition": "device",

View file

@ -49,6 +49,8 @@ async def test_get_triggers(hass, device_reg, entity_reg, enable_custom_integrat
"""Test we get the expected triggers from a binary_sensor."""
platform = getattr(hass.components, f"test.{DOMAIN}")
platform.init()
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
config_entry = MockConfigEntry(domain="test", data={})
config_entry.add_to_hass(hass)
@ -64,9 +66,6 @@ async def test_get_triggers(hass, device_reg, entity_reg, enable_custom_integrat
device_id=device_entry.id,
)
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
expected_triggers = [
{
"platform": "device",

View file

@ -9,6 +9,7 @@ from homeassistant.helpers.entity_registry import (
RegistryEntry,
RegistryEntryDisabler,
RegistryEntryHider,
async_get as async_get_entity_registry,
)
from tests.common import (
@ -374,25 +375,15 @@ async def test_update_entity(hass, client):
async def test_update_entity_require_restart(hass, client):
"""Test updating entity."""
entity_id = "test_domain.test_platform_1234"
config_entry = MockConfigEntry(domain="test_platform")
config_entry.add_to_hass(hass)
mock_registry(
hass,
{
"test_domain.world": RegistryEntry(
config_entry_id=config_entry.entry_id,
entity_id="test_domain.world",
unique_id="1234",
# Using component.async_add_entities is equal to platform "domain"
platform="test_platform",
)
},
)
platform = MockEntityPlatform(hass)
platform.config_entry = config_entry
entity = MockEntity(unique_id="1234")
await platform.async_add_entities([entity])
state = hass.states.get("test_domain.world")
state = hass.states.get(entity_id)
assert state is not None
# UPDATE DISABLED_BY TO NONE
@ -400,7 +391,7 @@ async def test_update_entity_require_restart(hass, client):
{
"id": 8,
"type": "config/entity_registry/update",
"entity_id": "test_domain.world",
"entity_id": entity_id,
"disabled_by": None,
}
)
@ -416,7 +407,7 @@ async def test_update_entity_require_restart(hass, client):
"device_id": None,
"disabled_by": None,
"entity_category": None,
"entity_id": "test_domain.world",
"entity_id": entity_id,
"icon": None,
"hidden_by": None,
"has_entity_name": False,
@ -434,6 +425,7 @@ async def test_update_entity_require_restart(hass, client):
async def test_enable_entity_disabled_device(hass, client, device_registry):
"""Test enabling entity of disabled device."""
entity_id = "test_domain.test_platform_1234"
config_entry = MockConfigEntry(domain="test_platform")
config_entry.add_to_hass(hass)
@ -445,33 +437,30 @@ async def test_enable_entity_disabled_device(hass, client, device_registry):
model="model",
disabled_by=DeviceEntryDisabler.USER,
)
device_info = {
"connections": {("ethernet", "12:34:56:78:90:AB:CD:EF")},
}
mock_registry(
hass,
{
"test_domain.world": RegistryEntry(
config_entry_id=config_entry.entry_id,
entity_id="test_domain.world",
unique_id="1234",
# Using component.async_add_entities is equal to platform "domain"
platform="test_platform",
device_id=device.id,
)
},
)
platform = MockEntityPlatform(hass)
entity = MockEntity(unique_id="1234")
platform.config_entry = config_entry
entity = MockEntity(unique_id="1234", device_info=device_info)
await platform.async_add_entities([entity])
state = hass.states.get("test_domain.world")
assert state is not None
state = hass.states.get(entity_id)
assert state is None
entity_reg = async_get_entity_registry(hass)
entity_entry = entity_reg.async_get(entity_id)
assert entity_entry.config_entry_id == config_entry.entry_id
assert entity_entry.device_id == device.id
assert entity_entry.disabled_by == RegistryEntryDisabler.DEVICE
# UPDATE DISABLED_BY TO NONE
await client.send_json(
{
"id": 8,
"type": "config/entity_registry/update",
"entity_id": "test_domain.world",
"entity_id": entity_id,
"disabled_by": None,
}
)

View file

@ -181,6 +181,8 @@ async def test_get_action_capabilities(
),
)
ent = platform.ENTITIES[0]
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
config_entry = MockConfigEntry(domain="test", data={})
config_entry.add_to_hass(hass)
@ -192,9 +194,6 @@ async def test_get_action_capabilities(
DOMAIN, "test", ent.unique_id, device_id=device_entry.id
)
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
actions = await async_get_device_automations(
hass, DeviceAutomationType.ACTION, device_entry.id
)
@ -215,6 +214,8 @@ async def test_get_action_capabilities_set_pos(
platform = getattr(hass.components, f"test.{DOMAIN}")
platform.init()
ent = platform.ENTITIES[1]
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
config_entry = MockConfigEntry(domain="test", data={})
config_entry.add_to_hass(hass)
@ -226,9 +227,6 @@ async def test_get_action_capabilities_set_pos(
DOMAIN, "test", ent.unique_id, device_id=device_entry.id
)
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
expected_capabilities = {
"extra_fields": [
{
@ -264,6 +262,8 @@ async def test_get_action_capabilities_set_tilt_pos(
platform = getattr(hass.components, f"test.{DOMAIN}")
platform.init()
ent = platform.ENTITIES[3]
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
config_entry = MockConfigEntry(domain="test", data={})
config_entry.add_to_hass(hass)
@ -275,9 +275,6 @@ async def test_get_action_capabilities_set_tilt_pos(
DOMAIN, "test", ent.unique_id, device_id=device_entry.id
)
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
expected_capabilities = {
"extra_fields": [
{

View file

@ -171,6 +171,8 @@ async def test_get_condition_capabilities(
platform = getattr(hass.components, f"test.{DOMAIN}")
platform.init()
ent = platform.ENTITIES[0]
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
config_entry = MockConfigEntry(domain="test", data={})
config_entry.add_to_hass(hass)
@ -182,8 +184,6 @@ async def test_get_condition_capabilities(
DOMAIN, "test", ent.unique_id, device_id=device_entry.id
)
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
conditions = await async_get_device_automations(
hass, DeviceAutomationType.CONDITION, device_entry.id
)
@ -202,6 +202,8 @@ async def test_get_condition_capabilities_set_pos(
platform = getattr(hass.components, f"test.{DOMAIN}")
platform.init()
ent = platform.ENTITIES[1]
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
config_entry = MockConfigEntry(domain="test", data={})
config_entry.add_to_hass(hass)
@ -213,8 +215,6 @@ async def test_get_condition_capabilities_set_pos(
DOMAIN, "test", ent.unique_id, device_id=device_entry.id
)
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
expected_capabilities = {
"extra_fields": [
{
@ -256,6 +256,8 @@ async def test_get_condition_capabilities_set_tilt_pos(
platform = getattr(hass.components, f"test.{DOMAIN}")
platform.init()
ent = platform.ENTITIES[3]
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
config_entry = MockConfigEntry(domain="test", data={})
config_entry.add_to_hass(hass)
@ -267,8 +269,6 @@ async def test_get_condition_capabilities_set_tilt_pos(
DOMAIN, "test", ent.unique_id, device_id=device_entry.id
)
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
expected_capabilities = {
"extra_fields": [
{

View file

@ -191,6 +191,8 @@ async def test_get_trigger_capabilities(
platform = getattr(hass.components, f"test.{DOMAIN}")
platform.init()
ent = platform.ENTITIES[0]
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
config_entry = MockConfigEntry(domain="test", data={})
config_entry.add_to_hass(hass)
@ -202,8 +204,6 @@ async def test_get_trigger_capabilities(
DOMAIN, "test", ent.unique_id, device_id=device_entry.id
)
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
triggers = await async_get_device_automations(
hass, DeviceAutomationType.TRIGGER, device_entry.id
)
@ -226,6 +226,8 @@ async def test_get_trigger_capabilities_set_pos(
platform = getattr(hass.components, f"test.{DOMAIN}")
platform.init()
ent = platform.ENTITIES[1]
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
config_entry = MockConfigEntry(domain="test", data={})
config_entry.add_to_hass(hass)
@ -237,8 +239,6 @@ async def test_get_trigger_capabilities_set_pos(
DOMAIN, "test", ent.unique_id, device_id=device_entry.id
)
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
expected_capabilities = {
"extra_fields": [
{
@ -288,6 +288,8 @@ async def test_get_trigger_capabilities_set_tilt_pos(
platform = getattr(hass.components, f"test.{DOMAIN}")
platform.init()
ent = platform.ENTITIES[3]
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
config_entry = MockConfigEntry(domain="test", data={})
config_entry.add_to_hass(hass)
@ -299,8 +301,6 @@ async def test_get_trigger_capabilities_set_tilt_pos(
DOMAIN, "test", ent.unique_id, device_id=device_entry.id
)
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
expected_capabilities = {
"extra_fields": [
{

View file

@ -51,6 +51,8 @@ async def test_get_conditions(hass, device_reg, entity_reg, enable_custom_integr
"""Test we get the expected conditions from a sensor."""
platform = getattr(hass.components, f"test.{DOMAIN}")
platform.init()
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
config_entry = MockConfigEntry(domain="test", data={})
config_entry.add_to_hass(hass)
@ -66,9 +68,6 @@ async def test_get_conditions(hass, device_reg, entity_reg, enable_custom_integr
device_id=device_entry.id,
)
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
expected_conditions = [
{
"condition": "device",

View file

@ -55,6 +55,8 @@ async def test_get_triggers(hass, device_reg, entity_reg, enable_custom_integrat
"""Test we get the expected triggers from a sensor."""
platform = getattr(hass.components, f"test.{DOMAIN}")
platform.init()
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
config_entry = MockConfigEntry(domain="test", data={})
config_entry.add_to_hass(hass)
@ -70,9 +72,6 @@ async def test_get_triggers(hass, device_reg, entity_reg, enable_custom_integrat
device_id=device_entry.id,
)
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_PLATFORM: "test"}})
await hass.async_block_till_done()
expected_triggers = [
{
"platform": "device",

View file

@ -79,8 +79,8 @@ def test_get_or_create_updates_data(registry):
device_id="mock-dev-id",
disabled_by=er.RegistryEntryDisabler.HASS,
entity_category=EntityCategory.CONFIG,
hidden_by=er.RegistryEntryHider.INTEGRATION,
has_entity_name=True,
hidden_by=er.RegistryEntryHider.INTEGRATION,
original_device_class="mock-device-class",
original_icon="initial-original_icon",
original_name="initial-original_name",
@ -99,10 +99,10 @@ def test_get_or_create_updates_data(registry):
device_id="mock-dev-id",
disabled_by=er.RegistryEntryDisabler.HASS,
entity_category=EntityCategory.CONFIG,
has_entity_name=True,
hidden_by=er.RegistryEntryHider.INTEGRATION,
icon=None,
id=orig_entry.id,
has_entity_name=True,
name=None,
original_device_class="mock-device-class",
original_icon="initial-original_icon",
@ -122,9 +122,9 @@ def test_get_or_create_updates_data(registry):
config_entry=new_config_entry,
device_id="new-mock-dev-id",
disabled_by=er.RegistryEntryDisabler.USER,
entity_category=None,
hidden_by=er.RegistryEntryHider.USER,
entity_category=EntityCategory.DIAGNOSTIC,
has_entity_name=False,
hidden_by=er.RegistryEntryHider.USER,
original_device_class="new-mock-device-class",
original_icon="updated-original_icon",
original_name="updated-original_name",
@ -142,11 +142,11 @@ def test_get_or_create_updates_data(registry):
device_class=None,
device_id="new-mock-dev-id",
disabled_by=er.RegistryEntryDisabler.HASS, # Should not be updated
entity_category=EntityCategory.CONFIG,
entity_category=EntityCategory.DIAGNOSTIC,
has_entity_name=False,
hidden_by=er.RegistryEntryHider.INTEGRATION, # Should not be updated
icon=None,
id=orig_entry.id,
has_entity_name=False,
name=None,
original_device_class="new-mock-device-class",
original_icon="updated-original_icon",
@ -155,6 +155,48 @@ def test_get_or_create_updates_data(registry):
unit_of_measurement="updated-unit_of_measurement",
)
new_entry = registry.async_get_or_create(
"light",
"hue",
"5678",
area_id=None,
capabilities=None,
config_entry=None,
device_id=None,
disabled_by=None,
entity_category=None,
has_entity_name=None,
hidden_by=None,
original_device_class=None,
original_icon=None,
original_name=None,
supported_features=None,
unit_of_measurement=None,
)
assert new_entry == er.RegistryEntry(
"light.hue_5678",
"5678",
"hue",
area_id=None,
capabilities=None,
config_entry_id=None,
device_class=None,
device_id=None,
disabled_by=er.RegistryEntryDisabler.HASS, # Should not be updated
entity_category=None,
has_entity_name=None,
hidden_by=er.RegistryEntryHider.INTEGRATION, # Should not be updated
icon=None,
id=orig_entry.id,
name=None,
original_device_class=None,
original_icon=None,
original_name=None,
supported_features=0, # supported_features is stored as an int
unit_of_measurement=None,
)
def test_get_or_create_suggested_object_id_conflict_register(registry):
"""Test that we don't generate an entity id that is already registered."""