Add ComponentProtocol to improve type checking (#90586)
This commit is contained in:
parent
03137feba5
commit
611d4135fd
5 changed files with 77 additions and 17 deletions
|
@ -61,7 +61,7 @@ from .helpers import (
|
||||||
)
|
)
|
||||||
from .helpers.entity_values import EntityValues
|
from .helpers.entity_values import EntityValues
|
||||||
from .helpers.typing import ConfigType
|
from .helpers.typing import ConfigType
|
||||||
from .loader import Integration, IntegrationNotFound
|
from .loader import ComponentProtocol, Integration, IntegrationNotFound
|
||||||
from .requirements import RequirementsNotFound, async_get_integration_with_requirements
|
from .requirements import RequirementsNotFound, async_get_integration_with_requirements
|
||||||
from .util.package import is_docker_env
|
from .util.package import is_docker_env
|
||||||
from .util.unit_system import get_unit_system, validate_unit_system
|
from .util.unit_system import get_unit_system, validate_unit_system
|
||||||
|
@ -681,7 +681,7 @@ def _log_pkg_error(package: str, component: str, config: dict, message: str) ->
|
||||||
_LOGGER.error(message)
|
_LOGGER.error(message)
|
||||||
|
|
||||||
|
|
||||||
def _identify_config_schema(module: ModuleType) -> str | None:
|
def _identify_config_schema(module: ComponentProtocol) -> str | None:
|
||||||
"""Extract the schema and identify list or dict based."""
|
"""Extract the schema and identify list or dict based."""
|
||||||
if not isinstance(module.CONFIG_SCHEMA, vol.Schema):
|
if not isinstance(module.CONFIG_SCHEMA, vol.Schema):
|
||||||
return None
|
return None
|
||||||
|
|
|
@ -383,7 +383,7 @@ class ConfigEntry:
|
||||||
result = await component.async_setup_entry(hass, self)
|
result = await component.async_setup_entry(hass, self)
|
||||||
|
|
||||||
if not isinstance(result, bool):
|
if not isinstance(result, bool):
|
||||||
_LOGGER.error(
|
_LOGGER.error( # type: ignore[unreachable]
|
||||||
"%s.async_setup_entry did not return boolean", integration.domain
|
"%s.async_setup_entry did not return boolean", integration.domain
|
||||||
)
|
)
|
||||||
result = False
|
result = False
|
||||||
|
@ -546,8 +546,7 @@ class ConfigEntry:
|
||||||
|
|
||||||
await self._async_process_on_unload()
|
await self._async_process_on_unload()
|
||||||
|
|
||||||
# https://github.com/python/mypy/issues/11839
|
return result
|
||||||
return result # type: ignore[no-any-return]
|
|
||||||
except Exception as ex: # pylint: disable=broad-except
|
except Exception as ex: # pylint: disable=broad-except
|
||||||
_LOGGER.exception(
|
_LOGGER.exception(
|
||||||
"Error unloading entry %s for %s", self.title, integration.domain
|
"Error unloading entry %s for %s", self.title, integration.domain
|
||||||
|
@ -628,15 +627,14 @@ class ConfigEntry:
|
||||||
try:
|
try:
|
||||||
result = await component.async_migrate_entry(hass, self)
|
result = await component.async_migrate_entry(hass, self)
|
||||||
if not isinstance(result, bool):
|
if not isinstance(result, bool):
|
||||||
_LOGGER.error(
|
_LOGGER.error( # type: ignore[unreachable]
|
||||||
"%s.async_migrate_entry did not return boolean", self.domain
|
"%s.async_migrate_entry did not return boolean", self.domain
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
if result:
|
if result:
|
||||||
# pylint: disable-next=protected-access
|
# pylint: disable-next=protected-access
|
||||||
hass.config_entries._async_schedule_save()
|
hass.config_entries._async_schedule_save()
|
||||||
# https://github.com/python/mypy/issues/11839
|
return result
|
||||||
return result # type: ignore[no-any-return]
|
|
||||||
except Exception: # pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
_LOGGER.exception(
|
_LOGGER.exception(
|
||||||
"Error migrating entry %s for %s", self.title, self.domain
|
"Error migrating entry %s for %s", self.title, self.domain
|
||||||
|
|
|
@ -15,13 +15,14 @@ import logging
|
||||||
import pathlib
|
import pathlib
|
||||||
import sys
|
import sys
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar, cast
|
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypedDict, TypeVar, cast
|
||||||
|
|
||||||
from awesomeversion import (
|
from awesomeversion import (
|
||||||
AwesomeVersion,
|
AwesomeVersion,
|
||||||
AwesomeVersionException,
|
AwesomeVersionException,
|
||||||
AwesomeVersionStrategy,
|
AwesomeVersionStrategy,
|
||||||
)
|
)
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
from . import generated
|
from . import generated
|
||||||
from .generated.application_credentials import APPLICATION_CREDENTIALS
|
from .generated.application_credentials import APPLICATION_CREDENTIALS
|
||||||
|
@ -35,7 +36,10 @@ from .util.json import JSON_DECODE_EXCEPTIONS, json_loads
|
||||||
|
|
||||||
# Typing imports that create a circular dependency
|
# Typing imports that create a circular dependency
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from .config_entries import ConfigEntry
|
||||||
from .core import HomeAssistant
|
from .core import HomeAssistant
|
||||||
|
from .helpers import device_registry as dr
|
||||||
|
from .helpers.typing import ConfigType
|
||||||
|
|
||||||
_CallableT = TypeVar("_CallableT", bound=Callable[..., Any])
|
_CallableT = TypeVar("_CallableT", bound=Callable[..., Any])
|
||||||
|
|
||||||
|
@ -260,6 +264,52 @@ async def async_get_config_flows(
|
||||||
return flows
|
return flows
|
||||||
|
|
||||||
|
|
||||||
|
class ComponentProtocol(Protocol):
|
||||||
|
"""Define the format of an integration."""
|
||||||
|
|
||||||
|
CONFIG_SCHEMA: vol.Schema
|
||||||
|
DOMAIN: str
|
||||||
|
|
||||||
|
async def async_setup_entry(
|
||||||
|
self, hass: HomeAssistant, config_entry: ConfigEntry
|
||||||
|
) -> bool:
|
||||||
|
"""Set up a config entry."""
|
||||||
|
|
||||||
|
async def async_unload_entry(
|
||||||
|
self, hass: HomeAssistant, config_entry: ConfigEntry
|
||||||
|
) -> bool:
|
||||||
|
"""Unload a config entry."""
|
||||||
|
|
||||||
|
async def async_migrate_entry(
|
||||||
|
self, hass: HomeAssistant, config_entry: ConfigEntry
|
||||||
|
) -> bool:
|
||||||
|
"""Migrate an old config entry."""
|
||||||
|
|
||||||
|
async def async_remove_entry(
|
||||||
|
self, hass: HomeAssistant, config_entry: ConfigEntry
|
||||||
|
) -> None:
|
||||||
|
"""Remove a config entry."""
|
||||||
|
|
||||||
|
async def async_remove_config_entry_device(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
device_entry: dr.DeviceEntry,
|
||||||
|
) -> bool:
|
||||||
|
"""Remove a config entry device."""
|
||||||
|
|
||||||
|
async def async_reset_platform(
|
||||||
|
self, hass: HomeAssistant, integration_name: str
|
||||||
|
) -> None:
|
||||||
|
"""Release resources."""
|
||||||
|
|
||||||
|
async def async_setup(self, hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
|
"""Set up integration."""
|
||||||
|
|
||||||
|
def setup(self, hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
|
"""Set up integration."""
|
||||||
|
|
||||||
|
|
||||||
async def async_get_integration_descriptions(
|
async def async_get_integration_descriptions(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
|
@ -750,14 +800,18 @@ class Integration:
|
||||||
|
|
||||||
return self._all_dependencies_resolved
|
return self._all_dependencies_resolved
|
||||||
|
|
||||||
def get_component(self) -> ModuleType:
|
def get_component(self) -> ComponentProtocol:
|
||||||
"""Return the component."""
|
"""Return the component."""
|
||||||
cache: dict[str, ModuleType] = self.hass.data.setdefault(DATA_COMPONENTS, {})
|
cache: dict[str, ComponentProtocol] = self.hass.data.setdefault(
|
||||||
|
DATA_COMPONENTS, {}
|
||||||
|
)
|
||||||
if self.domain in cache:
|
if self.domain in cache:
|
||||||
return cache[self.domain]
|
return cache[self.domain]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cache[self.domain] = importlib.import_module(self.pkg_path)
|
cache[self.domain] = cast(
|
||||||
|
ComponentProtocol, importlib.import_module(self.pkg_path)
|
||||||
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise
|
raise
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
|
@ -922,7 +976,7 @@ class CircularDependency(LoaderError):
|
||||||
|
|
||||||
def _load_file(
|
def _load_file(
|
||||||
hass: HomeAssistant, comp_or_platform: str, base_paths: list[str]
|
hass: HomeAssistant, comp_or_platform: str, base_paths: list[str]
|
||||||
) -> ModuleType | None:
|
) -> ComponentProtocol | None:
|
||||||
"""Try to load specified file.
|
"""Try to load specified file.
|
||||||
|
|
||||||
Looks in config dir first, then built-in components.
|
Looks in config dir first, then built-in components.
|
||||||
|
@ -957,7 +1011,7 @@ def _load_file(
|
||||||
|
|
||||||
cache[comp_or_platform] = module
|
cache[comp_or_platform] = module
|
||||||
|
|
||||||
return module
|
return cast(ComponentProtocol, module)
|
||||||
|
|
||||||
except ImportError as err:
|
except ImportError as err:
|
||||||
# This error happens if for example custom_components/switch
|
# This error happens if for example custom_components/switch
|
||||||
|
@ -981,7 +1035,7 @@ def _load_file(
|
||||||
class ModuleWrapper:
|
class ModuleWrapper:
|
||||||
"""Class to wrap a Python module and auto fill in hass argument."""
|
"""Class to wrap a Python module and auto fill in hass argument."""
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant, module: ModuleType) -> None:
|
def __init__(self, hass: HomeAssistant, module: ComponentProtocol) -> None:
|
||||||
"""Initialize the module wrapper."""
|
"""Initialize the module wrapper."""
|
||||||
self._hass = hass
|
self._hass = hass
|
||||||
self._module = module
|
self._module = module
|
||||||
|
@ -1010,7 +1064,7 @@ class Components:
|
||||||
integration = self._hass.data.get(DATA_INTEGRATIONS, {}).get(comp_name)
|
integration = self._hass.data.get(DATA_INTEGRATIONS, {}).get(comp_name)
|
||||||
|
|
||||||
if isinstance(integration, Integration):
|
if isinstance(integration, Integration):
|
||||||
component: ModuleType | None = integration.get_component()
|
component: ComponentProtocol | None = integration.get_component()
|
||||||
else:
|
else:
|
||||||
# Fallback to importing old-school
|
# Fallback to importing old-school
|
||||||
component = _load_file(self._hass, comp_name, _lookup_path(self._hass))
|
component = _load_file(self._hass, comp_name, _lookup_path(self._hass))
|
||||||
|
|
|
@ -236,7 +236,7 @@ async def _async_setup_component(
|
||||||
SLOW_SETUP_WARNING,
|
SLOW_SETUP_WARNING,
|
||||||
)
|
)
|
||||||
|
|
||||||
task = None
|
task: Awaitable[bool] | None = None
|
||||||
result: Any | bool = True
|
result: Any | bool = True
|
||||||
try:
|
try:
|
||||||
if hasattr(component, "async_setup"):
|
if hasattr(component, "async_setup"):
|
||||||
|
|
|
@ -202,6 +202,14 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = {
|
||||||
},
|
},
|
||||||
return_type="bool",
|
return_type="bool",
|
||||||
),
|
),
|
||||||
|
TypeHintMatch(
|
||||||
|
function_name="async_reset_platform",
|
||||||
|
arg_types={
|
||||||
|
0: "HomeAssistant",
|
||||||
|
1: "str",
|
||||||
|
},
|
||||||
|
return_type=None,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
"__any_platform__": [
|
"__any_platform__": [
|
||||||
TypeHintMatch(
|
TypeHintMatch(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue