Refactor integration platforms to import in the executor (#112168)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
J. Nick Koston 2024-03-04 19:21:18 -10:00 committed by GitHub
parent a9caa3e582
commit 8b017016b0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 242 additions and 60 deletions

View file

@ -37,68 +37,113 @@ class IntegrationPlatform:
@callback @callback
def _get_platform( def _async_integration_platform_component_loaded(
integration: Integration | Exception, component_name: str, platform_name: str
) -> ModuleType | None:
"""Get a platform from an integration."""
if isinstance(integration, Exception):
_LOGGER.exception(
"Error importing integration %s for %s",
component_name,
platform_name,
)
return None
#
# Loading the platform may do quite a bit of blocking I/O
# and CPU work. (https://github.com/python/cpython/issues/92041)
#
# We don't want to block the event loop for too
# long so we check if the platform exists with `platform_exists`
# before trying to load it. `platform_exists` will do two
# `stat()` system calls which is far cheaper than calling
# `integration.get_platform`
#
if not integration.platforms_exists((platform_name,)):
# If the platform cannot possibly exist, don't bother trying to load it
return None
try:
return integration.get_platform(platform_name)
except ImportError as err:
if f"{component_name}.{platform_name}" not in str(err):
_LOGGER.exception(
"Unexpected error importing %s/%s.py",
component_name,
platform_name,
)
return None
@callback
def _async_process_integration_platforms_for_component(
hass: HomeAssistant, hass: HomeAssistant,
integration_platforms: list[IntegrationPlatform], integration_platforms: list[IntegrationPlatform],
event: EventType[EventComponentLoaded], event: EventType[EventComponentLoaded],
) -> None: ) -> None:
"""Process integration platforms for a component.""" """Process integration platforms for a component."""
component_name = event.data[ATTR_COMPONENT] if "." in (component_name := event.data[ATTR_COMPONENT]):
if "." in component_name:
return return
integration = async_get_loaded_integration(hass, component_name) integration = async_get_loaded_integration(hass, component_name)
# First filter out platforms that the integration already processed.
integration_platforms_by_name: dict[str, IntegrationPlatform] = {}
for integration_platform in integration_platforms: for integration_platform in integration_platforms:
if component_name in integration_platform.seen_components or not ( if component_name in integration_platform.seen_components:
platform := _get_platform(
integration, component_name, integration_platform.platform_name
)
):
continue continue
integration_platform.seen_components.add(component_name) integration_platform.seen_components.add(component_name)
hass.async_run_hass_job( integration_platforms_by_name[
integration_platform.process_job, hass, component_name, platform integration_platform.platform_name
] = integration_platform
if not integration_platforms_by_name:
return
# Next, check which platforms exist for this integration.
platforms_that_exist = integration.platforms_exists(integration_platforms_by_name)
if not platforms_that_exist:
return
# If everything is already loaded, we can avoid creating a task.
can_use_cache = True
platforms: dict[str, ModuleType] = {}
for platform_name in platforms_that_exist:
if platform := integration.get_platform_cached(platform_name):
platforms[platform_name] = platform
else:
can_use_cache = False
break
if can_use_cache:
_process_integration_platforms(
hass,
integration,
platforms,
integration_platforms_by_name,
) )
return
# At least one of the platforms is not loaded, we need to load them
# so we have to fall back to creating a task.
hass.async_create_task(
_async_process_integration_platforms_for_component(
hass, integration, platforms_that_exist, integration_platforms_by_name
),
eager_start=True,
)
async def _async_process_integration_platforms_for_component(
hass: HomeAssistant,
integration: Integration,
platforms_that_exist: list[str],
integration_platforms_by_name: dict[str, IntegrationPlatform],
) -> None:
"""Process integration platforms for a component."""
# Now we know which platforms to load, let's load them.
try:
platforms = await integration.async_get_platforms(platforms_that_exist)
except ImportError:
_LOGGER.debug(
"Unexpected error importing integration platforms for %s",
integration.domain,
)
return
if futures := _process_integration_platforms(
hass,
integration,
platforms,
integration_platforms_by_name,
):
await asyncio.gather(*futures)
@callback
def _process_integration_platforms(
hass: HomeAssistant,
integration: Integration,
platforms: dict[str, ModuleType],
integration_platforms_by_name: dict[str, IntegrationPlatform],
) -> list[asyncio.Future[Awaitable[None] | None]]:
"""Process integration platforms for a component.
Only the platforms that are passed in will be processed.
"""
return [
future
for platform_name, platform in platforms.items()
if (integration_platform := integration_platforms_by_name[platform_name])
and (
future := hass.async_run_hass_job(
integration_platform.process_job,
hass,
integration.domain,
platform,
)
)
]
def _format_err(name: str, platform_name: str, *args: Any) -> str: def _format_err(name: str, platform_name: str, *args: Any) -> str:
@ -120,10 +165,11 @@ async def async_process_integration_platforms(
hass.bus.async_listen( hass.bus.async_listen(
EVENT_COMPONENT_LOADED, EVENT_COMPONENT_LOADED,
partial( partial(
_async_process_integration_platforms_for_component, _async_integration_platform_component_loaded,
hass, hass,
integration_platforms, integration_platforms,
), ),
run_immediately=True,
) )
else: else:
integration_platforms = hass.data[DATA_INTEGRATION_PLATFORMS] integration_platforms = hass.data[DATA_INTEGRATION_PLATFORMS]
@ -140,16 +186,42 @@ async def async_process_integration_platforms(
integration_platform = IntegrationPlatform( integration_platform = IntegrationPlatform(
platform_name, process_job, top_level_components platform_name, process_job, top_level_components
) )
# Tell the loader that it should try to pre-load the integration
# for any future components that are loaded so we can reduce the
# amount of import executor usage.
async_register_preload_platform(hass, platform_name)
integration_platforms.append(integration_platform) integration_platforms.append(integration_platform)
if not top_level_components: if not top_level_components:
return return
integrations = await async_get_integrations(hass, top_level_components) integrations = await async_get_integrations(hass, top_level_components)
if futures := [ loaded_integrations: list[Integration] = [
future integration
for comp in top_level_components for integration in integrations.values()
if (platform := _get_platform(integrations[comp], comp, platform_name)) if not isinstance(integration, Exception)
and (future := hass.async_run_hass_job(process_job, hass, comp, platform)) ]
]: # Finally, fetch the platforms for each integration and process them.
# This uses the import executor in a loop. If there are a lot
# of integration with the integration platform to process,
# this could be a bottleneck.
futures: list[asyncio.Future[None]] = []
for integration in loaded_integrations:
if not integration.platforms_exists((platform_name,)):
continue
try:
platform = await integration.async_get_platform(platform_name)
except ImportError:
_LOGGER.debug(
"Unexpected error importing %s for %s",
platform_name,
integration.domain,
)
continue
if future := hass.async_run_hass_job(
process_job, hass, integration.domain, platform
):
futures.append(future)
if futures:
await asyncio.gather(*futures) await asyncio.gather(*futures)

View file

@ -1122,6 +1122,10 @@ class Integration:
raise self._missing_platforms_cache[full_name] raise self._missing_platforms_cache[full_name]
return None return None
def get_platform_cached(self, platform_name: str) -> ModuleType | None:
"""Return a platform for an integration from cache."""
return self._cache.get(f"{self.domain}.{platform_name}") # type: ignore[return-value]
def get_platform(self, platform_name: str) -> ModuleType: def get_platform(self, platform_name: str) -> ModuleType:
"""Return a platform for an integration.""" """Return a platform for an integration."""
if platform := self._get_platform_cached_or_raise( if platform := self._get_platform_cached_or_raise(

View file

@ -6,6 +6,7 @@ from pydiscovergy.const import API_BASE
from homeassistant.components.discovergy.const import DOMAIN from homeassistant.components.discovergy.const import DOMAIN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.loader import async_get_integration
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import get_system_health_info from tests.common import get_system_health_info
@ -17,8 +18,11 @@ async def test_discovergy_system_health(
) -> None: ) -> None:
"""Test Discovergy system health.""" """Test Discovergy system health."""
aioclient_mock.get(API_BASE, text="") aioclient_mock.get(API_BASE, text="")
integration = await async_get_integration(hass, DOMAIN)
await integration.async_get_component()
hass.config.components.add(DOMAIN) hass.config.components.add(DOMAIN)
assert await async_setup_component(hass, "system_health", {}) assert await async_setup_component(hass, "system_health", {})
await hass.async_block_till_done()
info = await get_system_health_info(hass, DOMAIN) info = await get_system_health_info(hass, DOMAIN)
@ -34,8 +38,11 @@ async def test_discovergy_system_health_fail(
) -> None: ) -> None:
"""Test Discovergy system health.""" """Test Discovergy system health."""
aioclient_mock.get(API_BASE, exc=ClientError) aioclient_mock.get(API_BASE, exc=ClientError)
integration = await async_get_integration(hass, DOMAIN)
await integration.async_get_component()
hass.config.components.add(DOMAIN) hass.config.components.add(DOMAIN)
assert await async_setup_component(hass, "system_health", {}) assert await async_setup_component(hass, "system_health", {})
await hass.async_block_till_done()
info = await get_system_health_info(hass, DOMAIN) info = await get_system_health_info(hass, DOMAIN)

View file

@ -5,6 +5,7 @@ from aiohttp import ClientError
from homeassistant.components.gios.const import DOMAIN from homeassistant.components.gios.const import DOMAIN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.loader import async_get_integration
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import get_system_health_info from tests.common import get_system_health_info
@ -16,6 +17,8 @@ async def test_gios_system_health(
) -> None: ) -> None:
"""Test GIOS system health.""" """Test GIOS system health."""
aioclient_mock.get("http://api.gios.gov.pl/", text="") aioclient_mock.get("http://api.gios.gov.pl/", text="")
integration = await async_get_integration(hass, DOMAIN)
await integration.async_get_component()
hass.config.components.add(DOMAIN) hass.config.components.add(DOMAIN)
assert await async_setup_component(hass, "system_health", {}) assert await async_setup_component(hass, "system_health", {})
@ -33,6 +36,8 @@ async def test_gios_system_health_fail(
) -> None: ) -> None:
"""Test GIOS system health.""" """Test GIOS system health."""
aioclient_mock.get("http://api.gios.gov.pl/", exc=ClientError) aioclient_mock.get("http://api.gios.gov.pl/", exc=ClientError)
integration = await async_get_integration(hass, DOMAIN)
await integration.async_get_component()
hass.config.components.add(DOMAIN) hass.config.components.add(DOMAIN)
assert await async_setup_component(hass, "system_health", {}) assert await async_setup_component(hass, "system_health", {})

View file

@ -553,6 +553,7 @@ async def test_group_updated_after_device_tracker_zone_change(
assert await async_setup_component(hass, "group", {}) assert await async_setup_component(hass, "group", {})
assert await async_setup_component(hass, "device_tracker", {}) assert await async_setup_component(hass, "device_tracker", {})
await hass.async_block_till_done()
await group.Group.async_create_group( await group.Group.async_create_group(
hass, hass,

View file

@ -1,10 +1,11 @@
"""Test integration platform helpers.""" """Test integration platform helpers."""
from collections.abc import Callable from collections.abc import Callable
from types import ModuleType from types import ModuleType
from unittest.mock import Mock from unittest.mock import Mock, patch
import pytest import pytest
from homeassistant import loader
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.integration_platform import ( from homeassistant.helpers.integration_platform import (
@ -52,6 +53,84 @@ async def test_process_integration_platforms(hass: HomeAssistant) -> None:
assert len(processed) == 2 assert len(processed) == 2
async def test_process_integration_platforms_import_fails(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test processing integrations when one fails to import."""
loaded_platform = Mock()
mock_platform(hass, "loaded.platform_to_check", loaded_platform)
hass.config.components.add("loaded")
event_platform = Mock()
mock_platform(hass, "event.platform_to_check", event_platform)
processed = []
async def _process_platform(hass, domain, platform):
"""Process platform."""
processed.append((domain, platform))
loaded_integration = await loader.async_get_integration(hass, "loaded")
with patch.object(
loaded_integration, "async_get_platform", side_effect=ImportError
):
await async_process_integration_platforms(
hass, "platform_to_check", _process_platform
)
assert len(processed) == 0
assert "Unexpected error importing platform_to_check for loaded" in caplog.text
hass.bus.async_fire(EVENT_COMPONENT_LOADED, {ATTR_COMPONENT: "event"})
await hass.async_block_till_done()
assert len(processed) == 1
assert processed[0][0] == "event"
assert processed[0][1] == event_platform
hass.bus.async_fire(EVENT_COMPONENT_LOADED, {ATTR_COMPONENT: "event"})
await hass.async_block_till_done()
# Firing again should not check again
assert len(processed) == 1
async def test_process_integration_platforms_import_fails_after_registered(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test processing integrations when one fails to import."""
loaded_platform = Mock()
mock_platform(hass, "loaded.platform_to_check", loaded_platform)
hass.config.components.add("loaded")
event_platform = Mock()
mock_platform(hass, "event.platform_to_check", event_platform)
processed = []
async def _process_platform(hass, domain, platform):
"""Process platform."""
processed.append((domain, platform))
await async_process_integration_platforms(
hass, "platform_to_check", _process_platform
)
assert len(processed) == 1
assert processed[0][0] == "loaded"
assert processed[0][1] == loaded_platform
event_integration = await loader.async_get_integration(hass, "event")
with patch.object(
event_integration, "async_get_platforms", side_effect=ImportError
), patch.object(event_integration, "get_platform_cached", return_value=None):
hass.bus.async_fire(EVENT_COMPONENT_LOADED, {ATTR_COMPONENT: "event"})
await hass.async_block_till_done()
assert len(processed) == 1
assert "Unexpected error importing integration platforms for event" in caplog.text
@callback @callback
def _process_platform_callback( def _process_platform_callback(
hass: HomeAssistant, domain: str, platform: ModuleType hass: HomeAssistant, domain: str, platform: ModuleType
@ -126,8 +205,9 @@ async def test_broken_integration(
hass, "platform_to_check", _process_platform hass, "platform_to_check", _process_platform
) )
# This should never actually happen as the component cannot be
# in hass.config.components without a loaded manifest
assert len(processed) == 0 assert len(processed) == 0
assert "Error importing integration loaded for platform_to_check" in caplog.text
async def test_process_integration_platforms_no_integrations( async def test_process_integration_platforms_no_integrations(

View file

@ -120,6 +120,8 @@ async def test_custom_component_name(
integration = await loader.async_get_integration(hass, "test") integration = await loader.async_get_integration(hass, "test")
platform = integration.get_platform("light") platform = integration.get_platform("light")
assert integration.get_platform_cached("light") is platform
assert platform.__name__ == "custom_components.test.light" assert platform.__name__ == "custom_components.test.light"
assert platform.__package__ == "custom_components.test" assert platform.__package__ == "custom_components.test"
@ -277,6 +279,9 @@ async def test_async_get_platform_caches_failures_when_component_loaded(
with pytest.raises(ImportError): with pytest.raises(ImportError):
assert await integration.async_get_platform("light") == hue_light assert await integration.async_get_platform("light") == hue_light
# The cache should never be filled because the import error is remembered
assert integration.get_platform_cached("light") is None
async def test_async_get_platforms_caches_failures_when_component_loaded( async def test_async_get_platforms_caches_failures_when_component_loaded(
hass: HomeAssistant, hass: HomeAssistant,
@ -312,6 +317,9 @@ async def test_async_get_platforms_caches_failures_when_component_loaded(
with pytest.raises(ImportError): with pytest.raises(ImportError):
assert await integration.async_get_platforms(["light"]) == {"light": hue_light} assert await integration.async_get_platforms(["light"]) == {"light": hue_light}
# The cache should never be filled because the import error is remembered
assert integration.get_platform_cached("light") is None
async def test_get_integration_legacy( async def test_get_integration_legacy(
hass: HomeAssistant, enable_custom_integrations: None hass: HomeAssistant, enable_custom_integrations: None
@ -320,6 +328,7 @@ async def test_get_integration_legacy(
integration = await loader.async_get_integration(hass, "test_embedded") integration = await loader.async_get_integration(hass, "test_embedded")
assert integration.get_component().DOMAIN == "test_embedded" assert integration.get_component().DOMAIN == "test_embedded"
assert integration.get_platform("switch") is not None assert integration.get_platform("switch") is not None
assert integration.get_platform_cached("switch") is not None
async def test_get_integration_custom_component( async def test_get_integration_custom_component(
@ -1549,6 +1558,9 @@ async def test_async_get_platforms_loads_loop_if_already_in_sys_modules(
"switch": switch_module_mock, "switch": switch_module_mock,
"light": light_module_mock, "light": light_module_mock,
} }
assert integration.get_platform_cached("button") is button_module_mock
assert integration.get_platform_cached("switch") is switch_module_mock
assert integration.get_platform_cached("light") is light_module_mock
async def test_async_get_platforms_concurrent_loads( async def test_async_get_platforms_concurrent_loads(
@ -1610,3 +1622,4 @@ async def test_async_get_platforms_concurrent_loads(
assert load_result2 == {"button": button_module_mock} assert load_result2 == {"button": button_module_mock}
assert imports == [button_module_name] assert imports == [button_module_name]
assert integration.get_platform_cached("button") is button_module_mock