Add ComponentProtocol to improve type checking (#90586)

This commit is contained in:
epenet 2023-03-31 20:19:58 +02:00 committed by GitHub
parent 03137feba5
commit 611d4135fd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 77 additions and 17 deletions

View file

@ -61,7 +61,7 @@ from .helpers import (
) )
from .helpers.entity_values import EntityValues from .helpers.entity_values import EntityValues
from .helpers.typing import ConfigType from .helpers.typing import ConfigType
from .loader import Integration, IntegrationNotFound from .loader import ComponentProtocol, Integration, IntegrationNotFound
from .requirements import RequirementsNotFound, async_get_integration_with_requirements from .requirements import RequirementsNotFound, async_get_integration_with_requirements
from .util.package import is_docker_env from .util.package import is_docker_env
from .util.unit_system import get_unit_system, validate_unit_system from .util.unit_system import get_unit_system, validate_unit_system
@ -681,7 +681,7 @@ def _log_pkg_error(package: str, component: str, config: dict, message: str) ->
_LOGGER.error(message) _LOGGER.error(message)
def _identify_config_schema(module: ModuleType) -> str | None: def _identify_config_schema(module: ComponentProtocol) -> str | None:
"""Extract the schema and identify list or dict based.""" """Extract the schema and identify list or dict based."""
if not isinstance(module.CONFIG_SCHEMA, vol.Schema): if not isinstance(module.CONFIG_SCHEMA, vol.Schema):
return None return None

View file

@ -383,7 +383,7 @@ class ConfigEntry:
result = await component.async_setup_entry(hass, self) result = await component.async_setup_entry(hass, self)
if not isinstance(result, bool): if not isinstance(result, bool):
_LOGGER.error( _LOGGER.error( # type: ignore[unreachable]
"%s.async_setup_entry did not return boolean", integration.domain "%s.async_setup_entry did not return boolean", integration.domain
) )
result = False result = False
@ -546,8 +546,7 @@ class ConfigEntry:
await self._async_process_on_unload() await self._async_process_on_unload()
# https://github.com/python/mypy/issues/11839 return result
return result # type: ignore[no-any-return]
except Exception as ex: # pylint: disable=broad-except except Exception as ex: # pylint: disable=broad-except
_LOGGER.exception( _LOGGER.exception(
"Error unloading entry %s for %s", self.title, integration.domain "Error unloading entry %s for %s", self.title, integration.domain
@ -628,15 +627,14 @@ class ConfigEntry:
try: try:
result = await component.async_migrate_entry(hass, self) result = await component.async_migrate_entry(hass, self)
if not isinstance(result, bool): if not isinstance(result, bool):
_LOGGER.error( _LOGGER.error( # type: ignore[unreachable]
"%s.async_migrate_entry did not return boolean", self.domain "%s.async_migrate_entry did not return boolean", self.domain
) )
return False return False
if result: if result:
# pylint: disable-next=protected-access # pylint: disable-next=protected-access
hass.config_entries._async_schedule_save() hass.config_entries._async_schedule_save()
# https://github.com/python/mypy/issues/11839 return result
return result # type: ignore[no-any-return]
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
_LOGGER.exception( _LOGGER.exception(
"Error migrating entry %s for %s", self.title, self.domain "Error migrating entry %s for %s", self.title, self.domain

View file

@ -15,13 +15,14 @@ import logging
import pathlib import pathlib
import sys import sys
from types import ModuleType from types import ModuleType
from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar, cast from typing import TYPE_CHECKING, Any, Literal, Protocol, TypedDict, TypeVar, cast
from awesomeversion import ( from awesomeversion import (
AwesomeVersion, AwesomeVersion,
AwesomeVersionException, AwesomeVersionException,
AwesomeVersionStrategy, AwesomeVersionStrategy,
) )
import voluptuous as vol
from . import generated from . import generated
from .generated.application_credentials import APPLICATION_CREDENTIALS from .generated.application_credentials import APPLICATION_CREDENTIALS
@ -35,7 +36,10 @@ from .util.json import JSON_DECODE_EXCEPTIONS, json_loads
# Typing imports that create a circular dependency # Typing imports that create a circular dependency
if TYPE_CHECKING: if TYPE_CHECKING:
from .config_entries import ConfigEntry
from .core import HomeAssistant from .core import HomeAssistant
from .helpers import device_registry as dr
from .helpers.typing import ConfigType
_CallableT = TypeVar("_CallableT", bound=Callable[..., Any]) _CallableT = TypeVar("_CallableT", bound=Callable[..., Any])
@ -260,6 +264,52 @@ async def async_get_config_flows(
return flows return flows
class ComponentProtocol(Protocol):
"""Define the format of an integration."""
CONFIG_SCHEMA: vol.Schema
DOMAIN: str
async def async_setup_entry(
self, hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Set up a config entry."""
async def async_unload_entry(
self, hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Unload a config entry."""
async def async_migrate_entry(
self, hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Migrate an old config entry."""
async def async_remove_entry(
self, hass: HomeAssistant, config_entry: ConfigEntry
) -> None:
"""Remove a config entry."""
async def async_remove_config_entry_device(
self,
hass: HomeAssistant,
config_entry: ConfigEntry,
device_entry: dr.DeviceEntry,
) -> bool:
"""Remove a config entry device."""
async def async_reset_platform(
self, hass: HomeAssistant, integration_name: str
) -> None:
"""Release resources."""
async def async_setup(self, hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up integration."""
def setup(self, hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up integration."""
async def async_get_integration_descriptions( async def async_get_integration_descriptions(
hass: HomeAssistant, hass: HomeAssistant,
) -> dict[str, Any]: ) -> dict[str, Any]:
@ -750,14 +800,18 @@ class Integration:
return self._all_dependencies_resolved return self._all_dependencies_resolved
def get_component(self) -> ModuleType: def get_component(self) -> ComponentProtocol:
"""Return the component.""" """Return the component."""
cache: dict[str, ModuleType] = self.hass.data.setdefault(DATA_COMPONENTS, {}) cache: dict[str, ComponentProtocol] = self.hass.data.setdefault(
DATA_COMPONENTS, {}
)
if self.domain in cache: if self.domain in cache:
return cache[self.domain] return cache[self.domain]
try: try:
cache[self.domain] = importlib.import_module(self.pkg_path) cache[self.domain] = cast(
ComponentProtocol, importlib.import_module(self.pkg_path)
)
except ImportError: except ImportError:
raise raise
except Exception as err: except Exception as err:
@ -922,7 +976,7 @@ class CircularDependency(LoaderError):
def _load_file( def _load_file(
hass: HomeAssistant, comp_or_platform: str, base_paths: list[str] hass: HomeAssistant, comp_or_platform: str, base_paths: list[str]
) -> ModuleType | None: ) -> ComponentProtocol | None:
"""Try to load specified file. """Try to load specified file.
Looks in config dir first, then built-in components. Looks in config dir first, then built-in components.
@ -957,7 +1011,7 @@ def _load_file(
cache[comp_or_platform] = module cache[comp_or_platform] = module
return module return cast(ComponentProtocol, module)
except ImportError as err: except ImportError as err:
# This error happens if for example custom_components/switch # This error happens if for example custom_components/switch
@ -981,7 +1035,7 @@ def _load_file(
class ModuleWrapper: class ModuleWrapper:
"""Class to wrap a Python module and auto fill in hass argument.""" """Class to wrap a Python module and auto fill in hass argument."""
def __init__(self, hass: HomeAssistant, module: ModuleType) -> None: def __init__(self, hass: HomeAssistant, module: ComponentProtocol) -> None:
"""Initialize the module wrapper.""" """Initialize the module wrapper."""
self._hass = hass self._hass = hass
self._module = module self._module = module
@ -1010,7 +1064,7 @@ class Components:
integration = self._hass.data.get(DATA_INTEGRATIONS, {}).get(comp_name) integration = self._hass.data.get(DATA_INTEGRATIONS, {}).get(comp_name)
if isinstance(integration, Integration): if isinstance(integration, Integration):
component: ModuleType | None = integration.get_component() component: ComponentProtocol | None = integration.get_component()
else: else:
# Fallback to importing old-school # Fallback to importing old-school
component = _load_file(self._hass, comp_name, _lookup_path(self._hass)) component = _load_file(self._hass, comp_name, _lookup_path(self._hass))

View file

@ -236,7 +236,7 @@ async def _async_setup_component(
SLOW_SETUP_WARNING, SLOW_SETUP_WARNING,
) )
task = None task: Awaitable[bool] | None = None
result: Any | bool = True result: Any | bool = True
try: try:
if hasattr(component, "async_setup"): if hasattr(component, "async_setup"):

View file

@ -202,6 +202,14 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = {
}, },
return_type="bool", return_type="bool",
), ),
TypeHintMatch(
function_name="async_reset_platform",
arg_types={
0: "HomeAssistant",
1: "str",
},
return_type=None,
),
], ],
"__any_platform__": [ "__any_platform__": [
TypeHintMatch( TypeHintMatch(