Store config entry id in entity registry (#14851)

* Store config entry id in entity registry

* Lint
This commit is contained in:
Paulus Schoutsen 2018-06-07 14:23:09 -04:00 committed by Pascal Vizeli
parent bb4d1773d3
commit 67d137cfd5
4 changed files with 31 additions and 8 deletions

View file

@ -260,9 +260,15 @@ class EntityPlatform(object):
suggested_object_id = '{} {}'.format(
self.entity_namespace, suggested_object_id)
if self.config_entry is not None:
config_entry_id = self.config_entry.entry_id
else:
config_entry_id = None
entry = registry.async_get_or_create(
self.domain, self.platform_name, entity.unique_id,
suggested_object_id=suggested_object_id)
suggested_object_id=suggested_object_id,
config_entry_id=config_entry_id)
if entry.disabled:
self.logger.info(

View file

@ -43,6 +43,7 @@ class RegistryEntry:
unique_id = attr.ib(type=str)
platform = attr.ib(type=str)
name = attr.ib(type=str, default=None)
config_entry_id = attr.ib(type=str, default=None)
disabled_by = attr.ib(
type=str, default=None,
validator=attr.validators.in_((DISABLED_HASS, DISABLED_USER, None)))
@ -106,7 +107,7 @@ class EntityRegistry:
@callback
def async_get_or_create(self, domain, platform, unique_id, *,
suggested_object_id=None):
suggested_object_id=None, config_entry_id=None):
"""Get entity. Create if it doesn't exist."""
entity_id = self.async_get_entity_id(domain, platform, unique_id)
if entity_id:
@ -114,8 +115,10 @@ class EntityRegistry:
entity_id = self.async_generate_entity_id(
domain, suggested_object_id or '{}_{}'.format(platform, unique_id))
entity = RegistryEntry(
entity_id=entity_id,
config_entry_id=config_entry_id,
unique_id=unique_id,
platform=platform,
)
@ -179,6 +182,7 @@ class EntityRegistry:
for entity_id, info in data.items():
entities[entity_id] = RegistryEntry(
entity_id=entity_id,
config_entry_id=info.get('config_entry_id'),
unique_id=info['unique_id'],
platform=info['platform'],
name=info.get('name'),
@ -205,6 +209,7 @@ class EntityRegistry:
for entry in self.entities.values():
data[entry.entity_id] = {
'config_entry_id': entry.config_entry_id,
'unique_id': entry.unique_id,
'platform': entry.platform,
'name': entry.name,

View file

@ -16,7 +16,7 @@ import homeassistant.util.dt as dt_util
from tests.common import (
get_test_home_assistant, MockPlatform, fire_time_changed, mock_registry,
MockEntity, MockEntityPlatform, MockConfigEntry, mock_coro)
MockEntity, MockEntityPlatform, MockConfigEntry)
_LOGGER = logging.getLogger(__name__)
DOMAIN = "test_domain"
@ -516,11 +516,19 @@ async def test_entity_registry_updates(hass):
async def test_setup_entry(hass):
"""Test we can setup an entry."""
async_setup_entry = Mock(return_value=mock_coro(True))
registry = mock_registry(hass)
async def async_setup_entry(hass, config_entry, async_add_devices):
"""Mock setup entry method."""
async_add_devices([
MockEntity(name='test1', unique_id='unique')
])
return True
platform = MockPlatform(
async_setup_entry=async_setup_entry
)
config_entry = MockConfigEntry()
config_entry = MockConfigEntry(entry_id='super-mock-id')
entity_platform = MockEntityPlatform(
hass,
platform_name=config_entry.domain,
@ -528,10 +536,13 @@ async def test_setup_entry(hass):
)
assert await entity_platform.async_setup_entry(config_entry)
await hass.async_block_till_done()
full_name = '{}.{}'.format(entity_platform.domain, config_entry.domain)
assert full_name in hass.config.components
assert len(async_setup_entry.mock_calls) == 1
assert len(hass.states.async_entity_ids()) == 1
assert len(registry.entities) == 1
assert registry.entities['test_domain.test1'].config_entry_id == \
'super-mock-id'
async def test_setup_entry_platform_not_ready(hass, caplog):

View file

@ -86,7 +86,8 @@ def test_save_timer_reset_on_subsequent_save(hass, registry):
def test_loading_saving_data(hass, registry):
"""Test that we load/save data correctly."""
orig_entry1 = registry.async_get_or_create('light', 'hue', '1234')
orig_entry2 = registry.async_get_or_create('light', 'hue', '5678')
orig_entry2 = registry.async_get_or_create(
'light', 'hue', '5678', config_entry_id='mock-id')
assert len(registry.entities) == 2