Fix refactoring error with updating polling entities in sequence (#93693)

* Fix refactoring error with updating in sequence

see #93649

* coverage

* make sure entities are being updated in parallel

* make sure entities are being updated in sequence
This commit is contained in:
J. Nick Koston 2023-05-28 09:20:48 -05:00 committed by GitHub
parent 49c3a8886f
commit 083cf7a38b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 109 additions and 6 deletions

View file

@ -136,7 +136,7 @@ class EntityPlatform:
self._process_updates: asyncio.Lock | None = None self._process_updates: asyncio.Lock | None = None
self.parallel_updates: asyncio.Semaphore | None = None self.parallel_updates: asyncio.Semaphore | None = None
self._update_in_parallel: bool = True self._update_in_sequence: bool = False
# Platform is None for the EntityComponent "catch-all" EntityPlatform # Platform is None for the EntityComponent "catch-all" EntityPlatform
# which powers entity_component.add_entities # which powers entity_component.add_entities
@ -187,7 +187,7 @@ class EntityPlatform:
if parallel_updates is not None: if parallel_updates is not None:
self.parallel_updates = asyncio.Semaphore(parallel_updates) self.parallel_updates = asyncio.Semaphore(parallel_updates)
self._update_in_parallel = parallel_updates != 1 self._update_in_sequence = parallel_updates == 1
return self.parallel_updates return self.parallel_updates
@ -846,11 +846,13 @@ class EntityPlatform:
return return
async with self._process_updates: async with self._process_updates:
if self._update_in_parallel or len(self.entities) <= 1: if self._update_in_sequence or len(self.entities) <= 1:
# If we know are going to update sequentially, we want to update # If we know we will update sequentially, we want to avoid scheduling
# to avoid scheduling the coroutines as tasks that will we know # the coroutines as tasks that will wait on the semaphore lock.
# are going to wait on the semaphore lock.
for entity in list(self.entities.values()): for entity in list(self.entities.values()):
# If the entity is removed from hass during the previous
# entity being updated, we need to skip updating the
# entity.
if entity.should_poll and entity.hass: if entity.should_poll and entity.hass:
await entity.async_update_ha_state(True) await entity.async_update_ha_state(True)
return return

View file

@ -2,6 +2,7 @@
import asyncio import asyncio
from datetime import timedelta from datetime import timedelta
import logging import logging
from typing import Any
from unittest.mock import ANY, Mock, patch from unittest.mock import ANY, Mock, patch
import pytest import pytest
@ -307,6 +308,7 @@ async def test_parallel_updates_async_platform(hass: HomeAssistant) -> None:
entity = AsyncEntity() entity = AsyncEntity()
await handle.async_add_entities([entity]) await handle.async_add_entities([entity])
assert entity.parallel_updates is None assert entity.parallel_updates is None
assert handle._update_in_sequence is False
async def test_parallel_updates_async_platform_with_constant( async def test_parallel_updates_async_platform_with_constant(
@ -336,6 +338,7 @@ async def test_parallel_updates_async_platform_with_constant(
await handle.async_add_entities([entity]) await handle.async_add_entities([entity])
assert entity.parallel_updates is not None assert entity.parallel_updates is not None
assert entity.parallel_updates._value == 2 assert entity.parallel_updates._value == 2
assert handle._update_in_sequence is False
async def test_parallel_updates_sync_platform(hass: HomeAssistant) -> None: async def test_parallel_updates_sync_platform(hass: HomeAssistant) -> None:
@ -412,6 +415,104 @@ async def test_parallel_updates_sync_platform_with_constant(
assert entity.parallel_updates._value == 2 assert entity.parallel_updates._value == 2
async def test_parallel_updates_async_platform_updates_in_parallel(
hass: HomeAssistant,
) -> None:
"""Test an async platform is updated in parallel."""
platform = MockPlatform()
mock_entity_platform(hass, "test_domain.async_platform", platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
component._platforms = {}
await component.async_setup({DOMAIN: {"platform": "async_platform"}})
await hass.async_block_till_done()
handle = list(component._platforms.values())[-1]
updating = []
peak_update_count = 0
class AsyncEntity(MockEntity):
"""Mock entity that has async_update."""
async def async_update(self):
pass
async def async_update_ha_state(self, *args: Any, **kwargs: Any) -> None:
nonlocal peak_update_count
updating.append(self.entity_id)
await asyncio.sleep(0)
peak_update_count = max(len(updating), peak_update_count)
await asyncio.sleep(0)
updating.remove(self.entity_id)
entity1 = AsyncEntity()
entity2 = AsyncEntity()
entity3 = AsyncEntity()
await handle.async_add_entities([entity1, entity2, entity3])
assert entity1.parallel_updates is None
assert entity2.parallel_updates is None
assert entity3.parallel_updates is None
assert handle._update_in_sequence is False
await handle._update_entity_states(dt_util.utcnow())
assert peak_update_count > 1
async def test_parallel_updates_sync_platform_updates_in_sequence(
hass: HomeAssistant,
) -> None:
"""Test a sync platform is updated in sequence."""
platform = MockPlatform()
mock_entity_platform(hass, "test_domain.platform", platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
component._platforms = {}
await component.async_setup({DOMAIN: {"platform": "platform"}})
await hass.async_block_till_done()
handle = list(component._platforms.values())[-1]
updating = []
peak_update_count = 0
class SyncEntity(MockEntity):
"""Mock entity that has update."""
def update(self):
pass
async def async_update_ha_state(self, *args: Any, **kwargs: Any) -> None:
nonlocal peak_update_count
updating.append(self.entity_id)
await asyncio.sleep(0)
peak_update_count = max(len(updating), peak_update_count)
await asyncio.sleep(0)
updating.remove(self.entity_id)
entity1 = SyncEntity()
entity2 = SyncEntity()
entity3 = SyncEntity()
await handle.async_add_entities([entity1, entity2, entity3])
assert entity1.parallel_updates is not None
assert entity1.parallel_updates._value == 1
assert entity2.parallel_updates is not None
assert entity2.parallel_updates._value == 1
assert entity3.parallel_updates is not None
assert entity3.parallel_updates._value == 1
assert handle._update_in_sequence is True
await handle._update_entity_states(dt_util.utcnow())
assert peak_update_count == 1
async def test_raise_error_on_update(hass: HomeAssistant) -> None: async def test_raise_error_on_update(hass: HomeAssistant) -> None:
"""Test the add entity if they raise an error on update.""" """Test the add entity if they raise an error on update."""
updates = [] updates = []