Refactor integration platforms to import in the executor (#112168)
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
a9caa3e582
commit
8b017016b0
7 changed files with 242 additions and 60 deletions
|
@ -37,68 +37,113 @@ class IntegrationPlatform:
|
|||
|
||||
|
||||
@callback
|
||||
def _get_platform(
|
||||
integration: Integration | Exception, component_name: str, platform_name: str
|
||||
) -> ModuleType | None:
|
||||
"""Get a platform from an integration."""
|
||||
if isinstance(integration, Exception):
|
||||
_LOGGER.exception(
|
||||
"Error importing integration %s for %s",
|
||||
component_name,
|
||||
platform_name,
|
||||
)
|
||||
return None
|
||||
|
||||
#
|
||||
# Loading the platform may do quite a bit of blocking I/O
|
||||
# and CPU work. (https://github.com/python/cpython/issues/92041)
|
||||
#
|
||||
# We don't want to block the event loop for too
|
||||
# long so we check if the platform exists with `platform_exists`
|
||||
# before trying to load it. `platform_exists` will do two
|
||||
# `stat()` system calls which is far cheaper than calling
|
||||
# `integration.get_platform`
|
||||
#
|
||||
if not integration.platforms_exists((platform_name,)):
|
||||
# If the platform cannot possibly exist, don't bother trying to load it
|
||||
return None
|
||||
|
||||
try:
|
||||
return integration.get_platform(platform_name)
|
||||
except ImportError as err:
|
||||
if f"{component_name}.{platform_name}" not in str(err):
|
||||
_LOGGER.exception(
|
||||
"Unexpected error importing %s/%s.py",
|
||||
component_name,
|
||||
platform_name,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@callback
|
||||
def _async_process_integration_platforms_for_component(
|
||||
def _async_integration_platform_component_loaded(
|
||||
hass: HomeAssistant,
|
||||
integration_platforms: list[IntegrationPlatform],
|
||||
event: EventType[EventComponentLoaded],
|
||||
) -> None:
|
||||
"""Process integration platforms for a component."""
|
||||
component_name = event.data[ATTR_COMPONENT]
|
||||
if "." in component_name:
|
||||
if "." in (component_name := event.data[ATTR_COMPONENT]):
|
||||
return
|
||||
|
||||
integration = async_get_loaded_integration(hass, component_name)
|
||||
# First filter out platforms that the integration already processed.
|
||||
integration_platforms_by_name: dict[str, IntegrationPlatform] = {}
|
||||
for integration_platform in integration_platforms:
|
||||
if component_name in integration_platform.seen_components or not (
|
||||
platform := _get_platform(
|
||||
integration, component_name, integration_platform.platform_name
|
||||
)
|
||||
):
|
||||
if component_name in integration_platform.seen_components:
|
||||
continue
|
||||
integration_platform.seen_components.add(component_name)
|
||||
hass.async_run_hass_job(
|
||||
integration_platform.process_job, hass, component_name, platform
|
||||
integration_platforms_by_name[
|
||||
integration_platform.platform_name
|
||||
] = integration_platform
|
||||
|
||||
if not integration_platforms_by_name:
|
||||
return
|
||||
|
||||
# Next, check which platforms exist for this integration.
|
||||
platforms_that_exist = integration.platforms_exists(integration_platforms_by_name)
|
||||
if not platforms_that_exist:
|
||||
return
|
||||
|
||||
# If everything is already loaded, we can avoid creating a task.
|
||||
can_use_cache = True
|
||||
platforms: dict[str, ModuleType] = {}
|
||||
for platform_name in platforms_that_exist:
|
||||
if platform := integration.get_platform_cached(platform_name):
|
||||
platforms[platform_name] = platform
|
||||
else:
|
||||
can_use_cache = False
|
||||
break
|
||||
|
||||
if can_use_cache:
|
||||
_process_integration_platforms(
|
||||
hass,
|
||||
integration,
|
||||
platforms,
|
||||
integration_platforms_by_name,
|
||||
)
|
||||
return
|
||||
|
||||
# At least one of the platforms is not loaded, we need to load them
|
||||
# so we have to fall back to creating a task.
|
||||
hass.async_create_task(
|
||||
_async_process_integration_platforms_for_component(
|
||||
hass, integration, platforms_that_exist, integration_platforms_by_name
|
||||
),
|
||||
eager_start=True,
|
||||
)
|
||||
|
||||
|
||||
async def _async_process_integration_platforms_for_component(
|
||||
hass: HomeAssistant,
|
||||
integration: Integration,
|
||||
platforms_that_exist: list[str],
|
||||
integration_platforms_by_name: dict[str, IntegrationPlatform],
|
||||
) -> None:
|
||||
"""Process integration platforms for a component."""
|
||||
# Now we know which platforms to load, let's load them.
|
||||
try:
|
||||
platforms = await integration.async_get_platforms(platforms_that_exist)
|
||||
except ImportError:
|
||||
_LOGGER.debug(
|
||||
"Unexpected error importing integration platforms for %s",
|
||||
integration.domain,
|
||||
)
|
||||
return
|
||||
|
||||
if futures := _process_integration_platforms(
|
||||
hass,
|
||||
integration,
|
||||
platforms,
|
||||
integration_platforms_by_name,
|
||||
):
|
||||
await asyncio.gather(*futures)
|
||||
|
||||
|
||||
@callback
|
||||
def _process_integration_platforms(
|
||||
hass: HomeAssistant,
|
||||
integration: Integration,
|
||||
platforms: dict[str, ModuleType],
|
||||
integration_platforms_by_name: dict[str, IntegrationPlatform],
|
||||
) -> list[asyncio.Future[Awaitable[None] | None]]:
|
||||
"""Process integration platforms for a component.
|
||||
|
||||
Only the platforms that are passed in will be processed.
|
||||
"""
|
||||
return [
|
||||
future
|
||||
for platform_name, platform in platforms.items()
|
||||
if (integration_platform := integration_platforms_by_name[platform_name])
|
||||
and (
|
||||
future := hass.async_run_hass_job(
|
||||
integration_platform.process_job,
|
||||
hass,
|
||||
integration.domain,
|
||||
platform,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def _format_err(name: str, platform_name: str, *args: Any) -> str:
|
||||
|
@ -120,10 +165,11 @@ async def async_process_integration_platforms(
|
|||
hass.bus.async_listen(
|
||||
EVENT_COMPONENT_LOADED,
|
||||
partial(
|
||||
_async_process_integration_platforms_for_component,
|
||||
_async_integration_platform_component_loaded,
|
||||
hass,
|
||||
integration_platforms,
|
||||
),
|
||||
run_immediately=True,
|
||||
)
|
||||
else:
|
||||
integration_platforms = hass.data[DATA_INTEGRATION_PLATFORMS]
|
||||
|
@ -140,16 +186,42 @@ async def async_process_integration_platforms(
|
|||
integration_platform = IntegrationPlatform(
|
||||
platform_name, process_job, top_level_components
|
||||
)
|
||||
# Tell the loader that it should try to pre-load the integration
|
||||
# for any future components that are loaded so we can reduce the
|
||||
# amount of import executor usage.
|
||||
async_register_preload_platform(hass, platform_name)
|
||||
integration_platforms.append(integration_platform)
|
||||
|
||||
if not top_level_components:
|
||||
return
|
||||
|
||||
integrations = await async_get_integrations(hass, top_level_components)
|
||||
if futures := [
|
||||
future
|
||||
for comp in top_level_components
|
||||
if (platform := _get_platform(integrations[comp], comp, platform_name))
|
||||
and (future := hass.async_run_hass_job(process_job, hass, comp, platform))
|
||||
]:
|
||||
loaded_integrations: list[Integration] = [
|
||||
integration
|
||||
for integration in integrations.values()
|
||||
if not isinstance(integration, Exception)
|
||||
]
|
||||
# Finally, fetch the platforms for each integration and process them.
|
||||
# This uses the import executor in a loop. If there are a lot
|
||||
# of integration with the integration platform to process,
|
||||
# this could be a bottleneck.
|
||||
futures: list[asyncio.Future[None]] = []
|
||||
for integration in loaded_integrations:
|
||||
if not integration.platforms_exists((platform_name,)):
|
||||
continue
|
||||
try:
|
||||
platform = await integration.async_get_platform(platform_name)
|
||||
except ImportError:
|
||||
_LOGGER.debug(
|
||||
"Unexpected error importing %s for %s",
|
||||
platform_name,
|
||||
integration.domain,
|
||||
)
|
||||
continue
|
||||
|
||||
if future := hass.async_run_hass_job(
|
||||
process_job, hass, integration.domain, platform
|
||||
):
|
||||
futures.append(future)
|
||||
|
||||
if futures:
|
||||
await asyncio.gather(*futures)
|
||||
|
|
|
@ -1122,6 +1122,10 @@ class Integration:
|
|||
raise self._missing_platforms_cache[full_name]
|
||||
return None
|
||||
|
||||
def get_platform_cached(self, platform_name: str) -> ModuleType | None:
|
||||
"""Return a platform for an integration from cache."""
|
||||
return self._cache.get(f"{self.domain}.{platform_name}") # type: ignore[return-value]
|
||||
|
||||
def get_platform(self, platform_name: str) -> ModuleType:
|
||||
"""Return a platform for an integration."""
|
||||
if platform := self._get_platform_cached_or_raise(
|
||||
|
|
|
@ -6,6 +6,7 @@ from pydiscovergy.const import API_BASE
|
|||
|
||||
from homeassistant.components.discovergy.const import DOMAIN
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.loader import async_get_integration
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import get_system_health_info
|
||||
|
@ -17,8 +18,11 @@ async def test_discovergy_system_health(
|
|||
) -> None:
|
||||
"""Test Discovergy system health."""
|
||||
aioclient_mock.get(API_BASE, text="")
|
||||
integration = await async_get_integration(hass, DOMAIN)
|
||||
await integration.async_get_component()
|
||||
hass.config.components.add(DOMAIN)
|
||||
assert await async_setup_component(hass, "system_health", {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
info = await get_system_health_info(hass, DOMAIN)
|
||||
|
||||
|
@ -34,8 +38,11 @@ async def test_discovergy_system_health_fail(
|
|||
) -> None:
|
||||
"""Test Discovergy system health."""
|
||||
aioclient_mock.get(API_BASE, exc=ClientError)
|
||||
integration = await async_get_integration(hass, DOMAIN)
|
||||
await integration.async_get_component()
|
||||
hass.config.components.add(DOMAIN)
|
||||
assert await async_setup_component(hass, "system_health", {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
info = await get_system_health_info(hass, DOMAIN)
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ from aiohttp import ClientError
|
|||
|
||||
from homeassistant.components.gios.const import DOMAIN
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.loader import async_get_integration
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import get_system_health_info
|
||||
|
@ -16,6 +17,8 @@ async def test_gios_system_health(
|
|||
) -> None:
|
||||
"""Test GIOS system health."""
|
||||
aioclient_mock.get("http://api.gios.gov.pl/", text="")
|
||||
integration = await async_get_integration(hass, DOMAIN)
|
||||
await integration.async_get_component()
|
||||
hass.config.components.add(DOMAIN)
|
||||
assert await async_setup_component(hass, "system_health", {})
|
||||
|
||||
|
@ -33,6 +36,8 @@ async def test_gios_system_health_fail(
|
|||
) -> None:
|
||||
"""Test GIOS system health."""
|
||||
aioclient_mock.get("http://api.gios.gov.pl/", exc=ClientError)
|
||||
integration = await async_get_integration(hass, DOMAIN)
|
||||
await integration.async_get_component()
|
||||
hass.config.components.add(DOMAIN)
|
||||
assert await async_setup_component(hass, "system_health", {})
|
||||
|
||||
|
|
|
@ -553,6 +553,7 @@ async def test_group_updated_after_device_tracker_zone_change(
|
|||
|
||||
assert await async_setup_component(hass, "group", {})
|
||||
assert await async_setup_component(hass, "device_tracker", {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
await group.Group.async_create_group(
|
||||
hass,
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
"""Test integration platform helpers."""
|
||||
from collections.abc import Callable
|
||||
from types import ModuleType
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant import loader
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.integration_platform import (
|
||||
|
@ -52,6 +53,84 @@ async def test_process_integration_platforms(hass: HomeAssistant) -> None:
|
|||
assert len(processed) == 2
|
||||
|
||||
|
||||
async def test_process_integration_platforms_import_fails(
|
||||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Test processing integrations when one fails to import."""
|
||||
loaded_platform = Mock()
|
||||
mock_platform(hass, "loaded.platform_to_check", loaded_platform)
|
||||
hass.config.components.add("loaded")
|
||||
|
||||
event_platform = Mock()
|
||||
mock_platform(hass, "event.platform_to_check", event_platform)
|
||||
|
||||
processed = []
|
||||
|
||||
async def _process_platform(hass, domain, platform):
|
||||
"""Process platform."""
|
||||
processed.append((domain, platform))
|
||||
|
||||
loaded_integration = await loader.async_get_integration(hass, "loaded")
|
||||
with patch.object(
|
||||
loaded_integration, "async_get_platform", side_effect=ImportError
|
||||
):
|
||||
await async_process_integration_platforms(
|
||||
hass, "platform_to_check", _process_platform
|
||||
)
|
||||
|
||||
assert len(processed) == 0
|
||||
assert "Unexpected error importing platform_to_check for loaded" in caplog.text
|
||||
|
||||
hass.bus.async_fire(EVENT_COMPONENT_LOADED, {ATTR_COMPONENT: "event"})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(processed) == 1
|
||||
assert processed[0][0] == "event"
|
||||
assert processed[0][1] == event_platform
|
||||
|
||||
hass.bus.async_fire(EVENT_COMPONENT_LOADED, {ATTR_COMPONENT: "event"})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Firing again should not check again
|
||||
assert len(processed) == 1
|
||||
|
||||
|
||||
async def test_process_integration_platforms_import_fails_after_registered(
|
||||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Test processing integrations when one fails to import."""
|
||||
loaded_platform = Mock()
|
||||
mock_platform(hass, "loaded.platform_to_check", loaded_platform)
|
||||
hass.config.components.add("loaded")
|
||||
|
||||
event_platform = Mock()
|
||||
mock_platform(hass, "event.platform_to_check", event_platform)
|
||||
|
||||
processed = []
|
||||
|
||||
async def _process_platform(hass, domain, platform):
|
||||
"""Process platform."""
|
||||
processed.append((domain, platform))
|
||||
|
||||
await async_process_integration_platforms(
|
||||
hass, "platform_to_check", _process_platform
|
||||
)
|
||||
|
||||
assert len(processed) == 1
|
||||
assert processed[0][0] == "loaded"
|
||||
assert processed[0][1] == loaded_platform
|
||||
|
||||
event_integration = await loader.async_get_integration(hass, "event")
|
||||
with patch.object(
|
||||
event_integration, "async_get_platforms", side_effect=ImportError
|
||||
), patch.object(event_integration, "get_platform_cached", return_value=None):
|
||||
hass.bus.async_fire(EVENT_COMPONENT_LOADED, {ATTR_COMPONENT: "event"})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(processed) == 1
|
||||
assert "Unexpected error importing integration platforms for event" in caplog.text
|
||||
|
||||
|
||||
@callback
|
||||
def _process_platform_callback(
|
||||
hass: HomeAssistant, domain: str, platform: ModuleType
|
||||
|
@ -126,8 +205,9 @@ async def test_broken_integration(
|
|||
hass, "platform_to_check", _process_platform
|
||||
)
|
||||
|
||||
# This should never actually happen as the component cannot be
|
||||
# in hass.config.components without a loaded manifest
|
||||
assert len(processed) == 0
|
||||
assert "Error importing integration loaded for platform_to_check" in caplog.text
|
||||
|
||||
|
||||
async def test_process_integration_platforms_no_integrations(
|
||||
|
|
|
@ -120,6 +120,8 @@ async def test_custom_component_name(
|
|||
|
||||
integration = await loader.async_get_integration(hass, "test")
|
||||
platform = integration.get_platform("light")
|
||||
assert integration.get_platform_cached("light") is platform
|
||||
|
||||
assert platform.__name__ == "custom_components.test.light"
|
||||
assert platform.__package__ == "custom_components.test"
|
||||
|
||||
|
@ -277,6 +279,9 @@ async def test_async_get_platform_caches_failures_when_component_loaded(
|
|||
with pytest.raises(ImportError):
|
||||
assert await integration.async_get_platform("light") == hue_light
|
||||
|
||||
# The cache should never be filled because the import error is remembered
|
||||
assert integration.get_platform_cached("light") is None
|
||||
|
||||
|
||||
async def test_async_get_platforms_caches_failures_when_component_loaded(
|
||||
hass: HomeAssistant,
|
||||
|
@ -312,6 +317,9 @@ async def test_async_get_platforms_caches_failures_when_component_loaded(
|
|||
with pytest.raises(ImportError):
|
||||
assert await integration.async_get_platforms(["light"]) == {"light": hue_light}
|
||||
|
||||
# The cache should never be filled because the import error is remembered
|
||||
assert integration.get_platform_cached("light") is None
|
||||
|
||||
|
||||
async def test_get_integration_legacy(
|
||||
hass: HomeAssistant, enable_custom_integrations: None
|
||||
|
@ -320,6 +328,7 @@ async def test_get_integration_legacy(
|
|||
integration = await loader.async_get_integration(hass, "test_embedded")
|
||||
assert integration.get_component().DOMAIN == "test_embedded"
|
||||
assert integration.get_platform("switch") is not None
|
||||
assert integration.get_platform_cached("switch") is not None
|
||||
|
||||
|
||||
async def test_get_integration_custom_component(
|
||||
|
@ -1549,6 +1558,9 @@ async def test_async_get_platforms_loads_loop_if_already_in_sys_modules(
|
|||
"switch": switch_module_mock,
|
||||
"light": light_module_mock,
|
||||
}
|
||||
assert integration.get_platform_cached("button") is button_module_mock
|
||||
assert integration.get_platform_cached("switch") is switch_module_mock
|
||||
assert integration.get_platform_cached("light") is light_module_mock
|
||||
|
||||
|
||||
async def test_async_get_platforms_concurrent_loads(
|
||||
|
@ -1610,3 +1622,4 @@ async def test_async_get_platforms_concurrent_loads(
|
|||
assert load_result2 == {"button": button_module_mock}
|
||||
|
||||
assert imports == [button_module_name]
|
||||
assert integration.get_platform_cached("button") is button_module_mock
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue