Avoid race in entity_platform.async_add_entities() (#18445)

This avoids a race between multiple concurrent calls to
entity_platform.async_add_entities() which may cause
entities to be created with non-unique entity_id
This commit is contained in:
emontnemery 2018-11-19 10:13:50 +01:00 committed by Paulus Schoutsen
parent 7e702d3caa
commit f241becf7f
2 changed files with 11 additions and 13 deletions

View file

@ -206,7 +206,6 @@ class EntityPlatform:
return return
hass = self.hass hass = self.hass
component_entities = set(hass.states.async_entity_ids(self.domain))
device_registry = await \ device_registry = await \
hass.helpers.device_registry.async_get_registry() hass.helpers.device_registry.async_get_registry()
@ -214,8 +213,7 @@ class EntityPlatform:
hass.helpers.entity_registry.async_get_registry() hass.helpers.entity_registry.async_get_registry()
tasks = [ tasks = [
self._async_add_entity(entity, update_before_add, self._async_add_entity(entity, update_before_add,
component_entities, entity_registry, entity_registry, device_registry)
device_registry)
for entity in new_entities] for entity in new_entities]
# No entities for processing # No entities for processing
@ -235,8 +233,7 @@ class EntityPlatform:
) )
async def _async_add_entity(self, entity, update_before_add, async def _async_add_entity(self, entity, update_before_add,
component_entities, entity_registry, entity_registry, device_registry):
device_registry):
"""Add an entity to the platform.""" """Add an entity to the platform."""
if entity is None: if entity is None:
raise ValueError('Entity cannot be None') raise ValueError('Entity cannot be None')
@ -329,25 +326,24 @@ class EntityPlatform:
if self.entity_namespace is not None: if self.entity_namespace is not None:
suggested_object_id = '{} {}'.format(self.entity_namespace, suggested_object_id = '{} {}'.format(self.entity_namespace,
suggested_object_id) suggested_object_id)
entity.entity_id = entity_registry.async_generate_entity_id( entity.entity_id = entity_registry.async_generate_entity_id(
self.domain, suggested_object_id) self.domain, suggested_object_id, self.entities.keys())
# Make sure it is valid in case an entity set the value themselves # Make sure it is valid in case an entity set the value themselves
if not valid_entity_id(entity.entity_id): if not valid_entity_id(entity.entity_id):
raise HomeAssistantError( raise HomeAssistantError(
'Invalid entity id: {}'.format(entity.entity_id)) 'Invalid entity id: {}'.format(entity.entity_id))
elif entity.entity_id in component_entities: elif (entity.entity_id in self.entities or
entity.entity_id in self.hass.states.async_entity_ids(
self.domain)):
msg = 'Entity id already exists: {}'.format(entity.entity_id) msg = 'Entity id already exists: {}'.format(entity.entity_id)
if entity.unique_id is not None: if entity.unique_id is not None:
msg += '. Platform {} does not generate unique IDs'.format( msg += '. Platform {} does not generate unique IDs'.format(
self.platform_name) self.platform_name)
raise HomeAssistantError( raise HomeAssistantError(msg)
msg)
entity_id = entity.entity_id entity_id = entity.entity_id
self.entities[entity_id] = entity self.entities[entity_id] = entity
component_entities.add(entity_id)
entity.async_on_remove(lambda: self.entities.pop(entity_id)) entity.async_on_remove(lambda: self.entities.pop(entity_id))
if hasattr(entity, 'async_added_to_hass'): if hasattr(entity, 'async_added_to_hass'):

View file

@ -95,7 +95,8 @@ class EntityRegistry:
return None return None
@callback @callback
def async_generate_entity_id(self, domain, suggested_object_id): def async_generate_entity_id(self, domain, suggested_object_id,
known_object_ids=None):
"""Generate an entity ID that does not conflict. """Generate an entity ID that does not conflict.
Conflicts checked against registered and currently existing entities. Conflicts checked against registered and currently existing entities.
@ -103,7 +104,8 @@ class EntityRegistry:
return ensure_unique_string( return ensure_unique_string(
'{}.{}'.format(domain, slugify(suggested_object_id)), '{}.{}'.format(domain, slugify(suggested_object_id)),
chain(self.entities.keys(), chain(self.entities.keys(),
self.hass.states.async_entity_ids(domain)) self.hass.states.async_entity_ids(domain),
known_object_ids if known_object_ids else [])
) )
@callback @callback