Add setup function to the component loader (#98148)

* Add setup function to the component loader

* Update test

* Setup the loader in safe mode and in check_config script
This commit is contained in:
Erik Montnemery 2023-08-15 10:59:42 +02:00 committed by GitHub
parent b1e5b3be34
commit 3b9d6f2dde
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 38 additions and 48 deletions

View file

@ -134,6 +134,7 @@ async def async_setup_hass(
_LOGGER.info("Config directory: %s", runtime_config.config_dir) _LOGGER.info("Config directory: %s", runtime_config.config_dir)
loader.async_setup(hass)
config_dict = None config_dict = None
basic_setup_success = False basic_setup_success = False
@ -185,6 +186,8 @@ async def async_setup_hass(
hass.config.internal_url = old_config.internal_url hass.config.internal_url = old_config.internal_url
hass.config.external_url = old_config.external_url hass.config.external_url = old_config.external_url
hass.config.config_dir = old_config.config_dir hass.config.config_dir = old_config.config_dir
# Setup loader cache after the config dir has been set
loader.async_setup(hass)
if safe_mode: if safe_mode:
_LOGGER.info("Starting in safe mode") _LOGGER.info("Starting in safe mode")

View file

@ -166,6 +166,13 @@ class Manifest(TypedDict, total=False):
loggers: list[str] loggers: list[str]
def async_setup(hass: HomeAssistant) -> None:
"""Set up the necessary data structures."""
_async_mount_config_dir(hass)
hass.data[DATA_COMPONENTS] = {}
hass.data[DATA_INTEGRATIONS] = {}
def manifest_from_legacy_module(domain: str, module: ModuleType) -> Manifest: def manifest_from_legacy_module(domain: str, module: ModuleType) -> Manifest:
"""Generate a manifest from a legacy module.""" """Generate a manifest from a legacy module."""
return { return {
@ -802,9 +809,7 @@ class Integration:
def get_component(self) -> ComponentProtocol: def get_component(self) -> ComponentProtocol:
"""Return the component.""" """Return the component."""
cache: dict[str, ComponentProtocol] = self.hass.data.setdefault( cache: dict[str, ComponentProtocol] = self.hass.data[DATA_COMPONENTS]
DATA_COMPONENTS, {}
)
if self.domain in cache: if self.domain in cache:
return cache[self.domain] return cache[self.domain]
@ -824,7 +829,7 @@ class Integration:
def get_platform(self, platform_name: str) -> ModuleType: def get_platform(self, platform_name: str) -> ModuleType:
"""Return a platform for an integration.""" """Return a platform for an integration."""
cache: dict[str, ModuleType] = self.hass.data.setdefault(DATA_COMPONENTS, {}) cache: dict[str, ModuleType] = self.hass.data[DATA_COMPONENTS]
full_name = f"{self.domain}.{platform_name}" full_name = f"{self.domain}.{platform_name}"
if full_name in cache: if full_name in cache:
return cache[full_name] return cache[full_name]
@ -883,11 +888,7 @@ async def async_get_integrations(
hass: HomeAssistant, domains: Iterable[str] hass: HomeAssistant, domains: Iterable[str]
) -> dict[str, Integration | Exception]: ) -> dict[str, Integration | Exception]:
"""Get integrations.""" """Get integrations."""
if (cache := hass.data.get(DATA_INTEGRATIONS)) is None: cache = hass.data[DATA_INTEGRATIONS]
if not _async_mount_config_dir(hass):
return {domain: IntegrationNotFound(domain) for domain in domains}
cache = hass.data[DATA_INTEGRATIONS] = {}
results: dict[str, Integration | Exception] = {} results: dict[str, Integration | Exception] = {}
needed: dict[str, asyncio.Future[None]] = {} needed: dict[str, asyncio.Future[None]] = {}
in_progress: dict[str, asyncio.Future[None]] = {} in_progress: dict[str, asyncio.Future[None]] = {}
@ -993,10 +994,7 @@ def _load_file(
comp_or_platform comp_or_platform
] ]
if (cache := hass.data.get(DATA_COMPONENTS)) is None: cache = hass.data[DATA_COMPONENTS]
if not _async_mount_config_dir(hass):
return None
cache = hass.data[DATA_COMPONENTS] = {}
for path in (f"{base}.{comp_or_platform}" for base in base_paths): for path in (f"{base}.{comp_or_platform}" for base in base_paths):
try: try:
@ -1066,7 +1064,7 @@ class Components:
def __getattr__(self, comp_name: str) -> ModuleWrapper: def __getattr__(self, comp_name: str) -> ModuleWrapper:
"""Fetch a component.""" """Fetch a component."""
# Test integration cache # Test integration cache
integration = self._hass.data.get(DATA_INTEGRATIONS, {}).get(comp_name) integration = self._hass.data[DATA_INTEGRATIONS].get(comp_name)
if isinstance(integration, Integration): if isinstance(integration, Integration):
component: ComponentProtocol | None = integration.get_component() component: ComponentProtocol | None = integration.get_component()

View file

@ -11,7 +11,7 @@ import os
from typing import Any from typing import Any
from unittest.mock import patch from unittest.mock import patch
from homeassistant import core from homeassistant import core, loader
from homeassistant.config import get_default_config_dir from homeassistant.config import get_default_config_dir
from homeassistant.config_entries import ConfigEntries from homeassistant.config_entries import ConfigEntries
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@ -232,6 +232,7 @@ def check(config_dir, secrets=False):
async def async_check_config(config_dir): async def async_check_config(config_dir):
"""Check the HA config.""" """Check the HA config."""
hass = core.HomeAssistant() hass = core.HomeAssistant()
loader.async_setup(hass)
hass.config.config_dir = config_dir hass.config.config_dir = config_dir
hass.config_entries = ConfigEntries(hass, {}) hass.config_entries = ConfigEntries(hass, {})
await ar.async_load(hass) await ar.async_load(hass)

View file

@ -256,6 +256,7 @@ async def async_test_home_assistant(event_loop, load_registries=True):
# Load the registries # Load the registries
entity.async_setup(hass) entity.async_setup(hass)
loader.async_setup(hass)
if load_registries: if load_registries:
with patch( with patch(
"homeassistant.helpers.storage.Store.async_load", return_value=None "homeassistant.helpers.storage.Store.async_load", return_value=None
@ -1339,16 +1340,10 @@ def mock_integration(
integration._import_platform = mock_import_platform integration._import_platform = mock_import_platform
_LOGGER.info("Adding mock integration: %s", module.DOMAIN) _LOGGER.info("Adding mock integration: %s", module.DOMAIN)
integration_cache = hass.data.get(loader.DATA_INTEGRATIONS) integration_cache = hass.data[loader.DATA_INTEGRATIONS]
if integration_cache is None:
integration_cache = hass.data[loader.DATA_INTEGRATIONS] = {}
loader._async_mount_config_dir(hass)
integration_cache[module.DOMAIN] = integration integration_cache[module.DOMAIN] = integration
module_cache = hass.data.get(loader.DATA_COMPONENTS) module_cache = hass.data[loader.DATA_COMPONENTS]
if module_cache is None:
module_cache = hass.data[loader.DATA_COMPONENTS] = {}
loader._async_mount_config_dir(hass)
module_cache[module.DOMAIN] = module module_cache[module.DOMAIN] = module
return integration return integration
@ -1374,15 +1369,8 @@ def mock_platform(
platform_path is in form hue.config_flow. platform_path is in form hue.config_flow.
""" """
domain = platform_path.split(".")[0] domain = platform_path.split(".")[0]
integration_cache = hass.data.get(loader.DATA_INTEGRATIONS) integration_cache = hass.data[loader.DATA_INTEGRATIONS]
if integration_cache is None: module_cache = hass.data[loader.DATA_COMPONENTS]
integration_cache = hass.data[loader.DATA_INTEGRATIONS] = {}
loader._async_mount_config_dir(hass)
module_cache = hass.data.get(loader.DATA_COMPONENTS)
if module_cache is None:
module_cache = hass.data[loader.DATA_COMPONENTS] = {}
loader._async_mount_config_dir(hass)
if domain not in integration_cache: if domain not in integration_cache:
mock_integration(hass, MockModule(domain)) mock_integration(hass, MockModule(domain))

View file

@ -304,7 +304,7 @@ async def test_websocket_get_action_capabilities(
return {"extra_fields": vol.Schema({vol.Optional("code"): str})} return {"extra_fields": vol.Schema({vol.Optional("code"): str})}
return {} return {}
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_action"] module = module_cache["fake_integration.device_action"]
module.async_get_action_capabilities = _async_get_action_capabilities module.async_get_action_capabilities = _async_get_action_capabilities
@ -406,7 +406,7 @@ async def test_websocket_get_action_capabilities_bad_action(
await async_setup_component(hass, "device_automation", {}) await async_setup_component(hass, "device_automation", {})
expected_capabilities = {} expected_capabilities = {}
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_action"] module = module_cache["fake_integration.device_action"]
module.async_get_action_capabilities = Mock( module.async_get_action_capabilities = Mock(
side_effect=InvalidDeviceAutomationConfig side_effect=InvalidDeviceAutomationConfig
@ -459,7 +459,7 @@ async def test_websocket_get_condition_capabilities(
"""List condition capabilities.""" """List condition capabilities."""
return await toggle_entity.async_get_condition_capabilities(hass, config) return await toggle_entity.async_get_condition_capabilities(hass, config)
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_condition"] module = module_cache["fake_integration.device_condition"]
module.async_get_condition_capabilities = _async_get_condition_capabilities module.async_get_condition_capabilities = _async_get_condition_capabilities
@ -569,7 +569,7 @@ async def test_websocket_get_condition_capabilities_bad_condition(
await async_setup_component(hass, "device_automation", {}) await async_setup_component(hass, "device_automation", {})
expected_capabilities = {} expected_capabilities = {}
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_condition"] module = module_cache["fake_integration.device_condition"]
module.async_get_condition_capabilities = Mock( module.async_get_condition_capabilities = Mock(
side_effect=InvalidDeviceAutomationConfig side_effect=InvalidDeviceAutomationConfig
@ -747,7 +747,7 @@ async def test_websocket_get_trigger_capabilities(
"""List trigger capabilities.""" """List trigger capabilities."""
return await toggle_entity.async_get_trigger_capabilities(hass, config) return await toggle_entity.async_get_trigger_capabilities(hass, config)
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_trigger"] module = module_cache["fake_integration.device_trigger"]
module.async_get_trigger_capabilities = _async_get_trigger_capabilities module.async_get_trigger_capabilities = _async_get_trigger_capabilities
@ -857,7 +857,7 @@ async def test_websocket_get_trigger_capabilities_bad_trigger(
await async_setup_component(hass, "device_automation", {}) await async_setup_component(hass, "device_automation", {})
expected_capabilities = {} expected_capabilities = {}
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_trigger"] module = module_cache["fake_integration.device_trigger"]
module.async_get_trigger_capabilities = Mock( module.async_get_trigger_capabilities = Mock(
side_effect=InvalidDeviceAutomationConfig side_effect=InvalidDeviceAutomationConfig
@ -912,7 +912,7 @@ async def test_automation_with_device_action(
) -> None: ) -> None:
"""Test automation with a device action.""" """Test automation with a device action."""
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_action"] module = module_cache["fake_integration.device_action"]
module.async_call_action_from_config = AsyncMock() module.async_call_action_from_config = AsyncMock()
@ -949,7 +949,7 @@ async def test_automation_with_dynamically_validated_action(
) -> None: ) -> None:
"""Test device automation with an action which is dynamically validated.""" """Test device automation with an action which is dynamically validated."""
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_action"] module = module_cache["fake_integration.device_action"]
module.async_validate_action_config = AsyncMock() module.async_validate_action_config = AsyncMock()
@ -1003,7 +1003,7 @@ async def test_automation_with_device_condition(
) -> None: ) -> None:
"""Test automation with a device condition.""" """Test automation with a device condition."""
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_condition"] module = module_cache["fake_integration.device_condition"]
module.async_condition_from_config = Mock() module.async_condition_from_config = Mock()
@ -1037,7 +1037,7 @@ async def test_automation_with_dynamically_validated_condition(
) -> None: ) -> None:
"""Test device automation with a condition which is dynamically validated.""" """Test device automation with a condition which is dynamically validated."""
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_condition"] module = module_cache["fake_integration.device_condition"]
module.async_validate_condition_config = AsyncMock() module.async_validate_condition_config = AsyncMock()
@ -1102,7 +1102,7 @@ async def test_automation_with_device_trigger(
) -> None: ) -> None:
"""Test automation with a device trigger.""" """Test automation with a device trigger."""
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_trigger"] module = module_cache["fake_integration.device_trigger"]
module.async_attach_trigger = AsyncMock() module.async_attach_trigger = AsyncMock()
@ -1136,7 +1136,7 @@ async def test_automation_with_dynamically_validated_trigger(
) -> None: ) -> None:
"""Test device automation with a trigger which is dynamically validated.""" """Test device automation with a trigger which is dynamically validated."""
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_trigger"] module = module_cache["fake_integration.device_trigger"]
module.async_attach_trigger = AsyncMock() module.async_attach_trigger = AsyncMock()
module.async_validate_trigger_config = AsyncMock(wraps=lambda hass, config: config) module.async_validate_trigger_config = AsyncMock(wraps=lambda hass, config: config)
@ -1457,7 +1457,7 @@ async def test_automation_with_unknown_device(
) -> None: ) -> None:
"""Test device automation with a trigger with an unknown device.""" """Test device automation with a trigger with an unknown device."""
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_trigger"] module = module_cache["fake_integration.device_trigger"]
module.async_validate_trigger_config = AsyncMock() module.async_validate_trigger_config = AsyncMock()
@ -1492,7 +1492,7 @@ async def test_automation_with_device_wrong_domain(
) -> None: ) -> None:
"""Test device automation where the device doesn't have the right config entry.""" """Test device automation where the device doesn't have the right config entry."""
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_trigger"] module = module_cache["fake_integration.device_trigger"]
module.async_validate_trigger_config = AsyncMock() module.async_validate_trigger_config = AsyncMock()
@ -1534,7 +1534,7 @@ async def test_automation_with_device_component_not_loaded(
) -> None: ) -> None:
"""Test device automation where the device's config entry is not loaded.""" """Test device automation where the device's config entry is not loaded."""
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_trigger"] module = module_cache["fake_integration.device_trigger"]
module.async_validate_trigger_config = AsyncMock() module.async_validate_trigger_config = AsyncMock()
module.async_attach_trigger = AsyncMock() module.async_attach_trigger = AsyncMock()

View file

@ -1810,7 +1810,7 @@ async def test_execute_script_with_dynamically_validated_action(
ws_client = await hass_ws_client(hass) ws_client = await hass_ws_client(hass)
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_action"] module = module_cache["fake_integration.device_action"]
module.async_call_action_from_config = AsyncMock() module.async_call_action_from_config = AsyncMock()
module.async_validate_action_config = AsyncMock( module.async_validate_action_config = AsyncMock(