Store config entry id in entity registry (#14851)

* Store config entry id in entity registry

* Lint
This commit is contained in:
Paulus Schoutsen 2018-06-07 14:23:09 -04:00 committed by Pascal Vizeli
parent bb4d1773d3
commit 67d137cfd5
4 changed files with 31 additions and 8 deletions

View file

@ -260,9 +260,15 @@ class EntityPlatform(object):
suggested_object_id = '{} {}'.format( suggested_object_id = '{} {}'.format(
self.entity_namespace, suggested_object_id) self.entity_namespace, suggested_object_id)
if self.config_entry is not None:
config_entry_id = self.config_entry.entry_id
else:
config_entry_id = None
entry = registry.async_get_or_create( entry = registry.async_get_or_create(
self.domain, self.platform_name, entity.unique_id, self.domain, self.platform_name, entity.unique_id,
suggested_object_id=suggested_object_id) suggested_object_id=suggested_object_id,
config_entry_id=config_entry_id)
if entry.disabled: if entry.disabled:
self.logger.info( self.logger.info(

View file

@ -43,6 +43,7 @@ class RegistryEntry:
unique_id = attr.ib(type=str) unique_id = attr.ib(type=str)
platform = attr.ib(type=str) platform = attr.ib(type=str)
name = attr.ib(type=str, default=None) name = attr.ib(type=str, default=None)
config_entry_id = attr.ib(type=str, default=None)
disabled_by = attr.ib( disabled_by = attr.ib(
type=str, default=None, type=str, default=None,
validator=attr.validators.in_((DISABLED_HASS, DISABLED_USER, None))) validator=attr.validators.in_((DISABLED_HASS, DISABLED_USER, None)))
@ -106,7 +107,7 @@ class EntityRegistry:
@callback @callback
def async_get_or_create(self, domain, platform, unique_id, *, def async_get_or_create(self, domain, platform, unique_id, *,
suggested_object_id=None): suggested_object_id=None, config_entry_id=None):
"""Get entity. Create if it doesn't exist.""" """Get entity. Create if it doesn't exist."""
entity_id = self.async_get_entity_id(domain, platform, unique_id) entity_id = self.async_get_entity_id(domain, platform, unique_id)
if entity_id: if entity_id:
@ -114,8 +115,10 @@ class EntityRegistry:
entity_id = self.async_generate_entity_id( entity_id = self.async_generate_entity_id(
domain, suggested_object_id or '{}_{}'.format(platform, unique_id)) domain, suggested_object_id or '{}_{}'.format(platform, unique_id))
entity = RegistryEntry( entity = RegistryEntry(
entity_id=entity_id, entity_id=entity_id,
config_entry_id=config_entry_id,
unique_id=unique_id, unique_id=unique_id,
platform=platform, platform=platform,
) )
@ -179,6 +182,7 @@ class EntityRegistry:
for entity_id, info in data.items(): for entity_id, info in data.items():
entities[entity_id] = RegistryEntry( entities[entity_id] = RegistryEntry(
entity_id=entity_id, entity_id=entity_id,
config_entry_id=info.get('config_entry_id'),
unique_id=info['unique_id'], unique_id=info['unique_id'],
platform=info['platform'], platform=info['platform'],
name=info.get('name'), name=info.get('name'),
@ -205,6 +209,7 @@ class EntityRegistry:
for entry in self.entities.values(): for entry in self.entities.values():
data[entry.entity_id] = { data[entry.entity_id] = {
'config_entry_id': entry.config_entry_id,
'unique_id': entry.unique_id, 'unique_id': entry.unique_id,
'platform': entry.platform, 'platform': entry.platform,
'name': entry.name, 'name': entry.name,

View file

@ -16,7 +16,7 @@ import homeassistant.util.dt as dt_util
from tests.common import ( from tests.common import (
get_test_home_assistant, MockPlatform, fire_time_changed, mock_registry, get_test_home_assistant, MockPlatform, fire_time_changed, mock_registry,
MockEntity, MockEntityPlatform, MockConfigEntry, mock_coro) MockEntity, MockEntityPlatform, MockConfigEntry)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DOMAIN = "test_domain" DOMAIN = "test_domain"
@ -516,11 +516,19 @@ async def test_entity_registry_updates(hass):
async def test_setup_entry(hass): async def test_setup_entry(hass):
"""Test we can setup an entry.""" """Test we can setup an entry."""
async_setup_entry = Mock(return_value=mock_coro(True)) registry = mock_registry(hass)
async def async_setup_entry(hass, config_entry, async_add_devices):
"""Mock setup entry method."""
async_add_devices([
MockEntity(name='test1', unique_id='unique')
])
return True
platform = MockPlatform( platform = MockPlatform(
async_setup_entry=async_setup_entry async_setup_entry=async_setup_entry
) )
config_entry = MockConfigEntry() config_entry = MockConfigEntry(entry_id='super-mock-id')
entity_platform = MockEntityPlatform( entity_platform = MockEntityPlatform(
hass, hass,
platform_name=config_entry.domain, platform_name=config_entry.domain,
@ -528,10 +536,13 @@ async def test_setup_entry(hass):
) )
assert await entity_platform.async_setup_entry(config_entry) assert await entity_platform.async_setup_entry(config_entry)
await hass.async_block_till_done()
full_name = '{}.{}'.format(entity_platform.domain, config_entry.domain) full_name = '{}.{}'.format(entity_platform.domain, config_entry.domain)
assert full_name in hass.config.components assert full_name in hass.config.components
assert len(async_setup_entry.mock_calls) == 1 assert len(hass.states.async_entity_ids()) == 1
assert len(registry.entities) == 1
assert registry.entities['test_domain.test1'].config_entry_id == \
'super-mock-id'
async def test_setup_entry_platform_not_ready(hass, caplog): async def test_setup_entry_platform_not_ready(hass, caplog):

View file

@ -86,7 +86,8 @@ def test_save_timer_reset_on_subsequent_save(hass, registry):
def test_loading_saving_data(hass, registry): def test_loading_saving_data(hass, registry):
"""Test that we load/save data correctly.""" """Test that we load/save data correctly."""
orig_entry1 = registry.async_get_or_create('light', 'hue', '1234') orig_entry1 = registry.async_get_or_create('light', 'hue', '1234')
orig_entry2 = registry.async_get_or_create('light', 'hue', '5678') orig_entry2 = registry.async_get_or_create(
'light', 'hue', '5678', config_entry_id='mock-id')
assert len(registry.entities) == 2 assert len(registry.entities) == 2