Make EntityComponent generic (#78473)

This commit is contained in:
epenet 2022-09-14 20:16:23 +02:00 committed by GitHub
parent fd05d949cc
commit 996bcbdac6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 53 additions and 50 deletions

View file

@ -154,12 +154,12 @@ def automations_with_entity(hass: HomeAssistant, entity_id: str) -> list[str]:
if DOMAIN not in hass.data: if DOMAIN not in hass.data:
return [] return []
component: EntityComponent = hass.data[DOMAIN] component: EntityComponent[AutomationEntity] = hass.data[DOMAIN]
return [ return [
automation_entity.entity_id automation_entity.entity_id
for automation_entity in component.entities for automation_entity in component.entities
if entity_id in cast(AutomationEntity, automation_entity).referenced_entities if entity_id in automation_entity.referenced_entities
] ]
@ -169,12 +169,12 @@ def entities_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]:
if DOMAIN not in hass.data: if DOMAIN not in hass.data:
return [] return []
component: EntityComponent = hass.data[DOMAIN] component: EntityComponent[AutomationEntity] = hass.data[DOMAIN]
if (automation_entity := component.get_entity(entity_id)) is None: if (automation_entity := component.get_entity(entity_id)) is None:
return [] return []
return list(cast(AutomationEntity, automation_entity).referenced_entities) return list(automation_entity.referenced_entities)
@callback @callback
@ -183,12 +183,12 @@ def automations_with_device(hass: HomeAssistant, device_id: str) -> list[str]:
if DOMAIN not in hass.data: if DOMAIN not in hass.data:
return [] return []
component: EntityComponent = hass.data[DOMAIN] component: EntityComponent[AutomationEntity] = hass.data[DOMAIN]
return [ return [
automation_entity.entity_id automation_entity.entity_id
for automation_entity in component.entities for automation_entity in component.entities
if device_id in cast(AutomationEntity, automation_entity).referenced_devices if device_id in automation_entity.referenced_devices
] ]
@ -198,12 +198,12 @@ def devices_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]:
if DOMAIN not in hass.data: if DOMAIN not in hass.data:
return [] return []
component: EntityComponent = hass.data[DOMAIN] component: EntityComponent[AutomationEntity] = hass.data[DOMAIN]
if (automation_entity := component.get_entity(entity_id)) is None: if (automation_entity := component.get_entity(entity_id)) is None:
return [] return []
return list(cast(AutomationEntity, automation_entity).referenced_devices) return list(automation_entity.referenced_devices)
@callback @callback
@ -212,12 +212,12 @@ def automations_with_area(hass: HomeAssistant, area_id: str) -> list[str]:
if DOMAIN not in hass.data: if DOMAIN not in hass.data:
return [] return []
component: EntityComponent = hass.data[DOMAIN] component: EntityComponent[AutomationEntity] = hass.data[DOMAIN]
return [ return [
automation_entity.entity_id automation_entity.entity_id
for automation_entity in component.entities for automation_entity in component.entities
if area_id in cast(AutomationEntity, automation_entity).referenced_areas if area_id in automation_entity.referenced_areas
] ]
@ -227,12 +227,12 @@ def areas_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]:
if DOMAIN not in hass.data: if DOMAIN not in hass.data:
return [] return []
component: EntityComponent = hass.data[DOMAIN] component: EntityComponent[AutomationEntity] = hass.data[DOMAIN]
if (automation_entity := component.get_entity(entity_id)) is None: if (automation_entity := component.get_entity(entity_id)) is None:
return [] return []
return list(cast(AutomationEntity, automation_entity).referenced_areas) return list(automation_entity.referenced_areas)
@callback @callback
@ -252,7 +252,9 @@ def automations_with_blueprint(hass: HomeAssistant, blueprint_path: str) -> list
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up all automations.""" """Set up all automations."""
hass.data[DOMAIN] = component = EntityComponent(LOGGER, DOMAIN, hass) hass.data[DOMAIN] = component = EntityComponent[AutomationEntity](
LOGGER, DOMAIN, hass
)
# Process integration platforms right away since # Process integration platforms right away since
# we will create entities before firing EVENT_COMPONENT_LOADED # we will create entities before firing EVENT_COMPONENT_LOADED

View file

@ -93,7 +93,7 @@ CONFIG_SCHEMA = vol.Schema(
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the counters.""" """Set up the counters."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent[Counter](_LOGGER, DOMAIN, hass)
id_manager = collection.IDManager() id_manager = collection.IDManager()
yaml_collection = collection.YamlCollection( yaml_collection = collection.YamlCollection(

View file

@ -67,9 +67,9 @@ def setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up is called when Home Assistant is loading our component.""" """Set up is called when Home Assistant is loading our component."""
dominos = Dominos(hass, config) dominos = Dominos(hass, config)
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent[DominosOrder](_LOGGER, DOMAIN, hass)
hass.data[DOMAIN] = {} hass.data[DOMAIN] = {}
entities = [] entities: list[DominosOrder] = []
conf = config[DOMAIN] conf = config[DOMAIN]
hass.services.register( hass.services.register(

View file

@ -85,7 +85,9 @@ class FaceInformation(TypedDict, total=False):
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the image processing.""" """Set up the image processing."""
component = EntityComponent(_LOGGER, DOMAIN, hass, SCAN_INTERVAL) component = EntityComponent[ImageProcessingEntity](
_LOGGER, DOMAIN, hass, SCAN_INTERVAL
)
await component.async_setup(config) await component.async_setup(config)

View file

@ -92,7 +92,7 @@ def is_on(hass: HomeAssistant, entity_id: str) -> bool:
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up an input boolean.""" """Set up an input boolean."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent[InputBoolean](_LOGGER, DOMAIN, hass)
# Process integration platforms right away since # Process integration platforms right away since
# we will create entities before firing EVENT_COMPONENT_LOADED # we will create entities before firing EVENT_COMPONENT_LOADED

View file

@ -77,7 +77,7 @@ class InputButtonStorageCollection(collection.StorageCollection):
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up an input button.""" """Set up an input button."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent[InputButton](_LOGGER, DOMAIN, hass)
# Process integration platforms right away since # Process integration platforms right away since
# we will create entities before firing EVENT_COMPONENT_LOADED # we will create entities before firing EVENT_COMPONENT_LOADED

View file

@ -130,7 +130,7 @@ RELOAD_SERVICE_SCHEMA = vol.Schema({})
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up an input datetime.""" """Set up an input datetime."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent[InputDatetime](_LOGGER, DOMAIN, hass)
# Process integration platforms right away since # Process integration platforms right away since
# we will create entities before firing EVENT_COMPONENT_LOADED # we will create entities before firing EVENT_COMPONENT_LOADED

View file

@ -107,7 +107,7 @@ STORAGE_VERSION = 1
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up an input slider.""" """Set up an input slider."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent[InputNumber](_LOGGER, DOMAIN, hass)
# Process integration platforms right away since # Process integration platforms right away since
# we will create entities before firing EVENT_COMPONENT_LOADED # we will create entities before firing EVENT_COMPONENT_LOADED

View file

@ -132,7 +132,7 @@ class InputSelectStore(Store):
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up an input select.""" """Set up an input select."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent[InputSelect](_LOGGER, DOMAIN, hass)
# Process integration platforms right away since # Process integration platforms right away since
# we will create entities before firing EVENT_COMPONENT_LOADED # we will create entities before firing EVENT_COMPONENT_LOADED

View file

@ -107,7 +107,7 @@ RELOAD_SERVICE_SCHEMA = vol.Schema({})
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up an input text.""" """Set up an input text."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent[InputText](_LOGGER, DOMAIN, hass)
# Process integration platforms right away since # Process integration platforms right away since
# we will create entities before firing EVENT_COMPONENT_LOADED # we will create entities before firing EVENT_COMPONENT_LOADED

View file

@ -83,7 +83,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
mailboxes.append(mailbox) mailboxes.append(mailbox)
mailbox_entity = MailboxEntity(mailbox) mailbox_entity = MailboxEntity(mailbox)
component = EntityComponent( component = EntityComponent[MailboxEntity](
logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL
) )
await component.async_add_entities([mailbox_entity]) await component.async_add_entities([mailbox_entity])

View file

@ -326,7 +326,7 @@ The following persons point at invalid users:
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the person component.""" """Set up the person component."""
entity_component = EntityComponent(_LOGGER, DOMAIN, hass) entity_component = EntityComponent[Person](_LOGGER, DOMAIN, hass)
id_manager = collection.IDManager() id_manager = collection.IDManager()
yaml_collection = collection.YamlCollection( yaml_collection = collection.YamlCollection(
logging.getLogger(f"{__name__}.yaml_collection"), id_manager logging.getLogger(f"{__name__}.yaml_collection"), id_manager

View file

@ -111,7 +111,7 @@ CONFIG_SCHEMA = vol.Schema({DOMAIN: {cv.string: PLANT_SCHEMA}}, extra=vol.ALLOW_
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the Plant component.""" """Set up the Plant component."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent[Plant](_LOGGER, DOMAIN, hass)
entities = [] entities = []
for plant_name, plant_config in config[DOMAIN].items(): for plant_name, plant_config in config[DOMAIN].items():

View file

@ -52,7 +52,7 @@ SERVICE_SCHEMA_COMPLETE_TASK = vol.Schema({vol.Required(CONF_ID): cv.string})
def setup(hass: HomeAssistant, config: ConfigType) -> bool: def setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the Remember the milk component.""" """Set up the Remember the milk component."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent[RememberTheMilk](_LOGGER, DOMAIN, hass)
stored_rtm_config = RememberTheMilkConfiguration(hass) stored_rtm_config = RememberTheMilkConfiguration(hass)
for rtm_config in config[DOMAIN]: for rtm_config in config[DOMAIN]:

View file

@ -27,6 +27,7 @@ from homeassistant.const import (
) )
from homeassistant.core import HomeAssistant, ServiceCall, callback from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.helpers import discovery, template from homeassistant.helpers import discovery, template
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import ( from homeassistant.helpers.entity_component import (
DEFAULT_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL,
EntityComponent, EntityComponent,
@ -53,7 +54,7 @@ COORDINATOR_AWARE_PLATFORMS = [SENSOR_DOMAIN, BINARY_SENSOR_DOMAIN]
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the rest platforms.""" """Set up the rest platforms."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent[Entity](_LOGGER, DOMAIN, hass)
_async_setup_shared_data(hass) _async_setup_shared_data(hass)
async def reload_service_handler(service: ServiceCall) -> None: async def reload_service_handler(service: ServiceCall) -> None:

View file

@ -154,7 +154,7 @@ ENTITY_SCHEMA = vol.Schema(
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up an input select.""" """Set up an input select."""
component = EntityComponent(LOGGER, DOMAIN, hass) component = EntityComponent[Schedule](LOGGER, DOMAIN, hass)
# Process integration platforms right away since # Process integration platforms right away since
# we will create entities before firing EVENT_COMPONENT_LOADED # we will create entities before firing EVENT_COMPONENT_LOADED

View file

@ -183,7 +183,7 @@ def scripts_with_blueprint(hass: HomeAssistant, blueprint_path: str) -> list[str
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Load the scripts from the configuration.""" """Load the scripts from the configuration."""
hass.data[DOMAIN] = component = EntityComponent(LOGGER, DOMAIN, hass) hass.data[DOMAIN] = component = EntityComponent[ScriptEntity](LOGGER, DOMAIN, hass)
# Process integration platforms right away since # Process integration platforms right away since
# we will create entities before firing EVENT_COMPONENT_LOADED # we will create entities before firing EVENT_COMPONENT_LOADED
@ -205,9 +205,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def turn_on_service(service: ServiceCall) -> None: async def turn_on_service(service: ServiceCall) -> None:
"""Call a service to turn script on.""" """Call a service to turn script on."""
variables = service.data.get(ATTR_VARIABLES) variables = service.data.get(ATTR_VARIABLES)
script_entities: list[ScriptEntity] = cast( script_entities = await component.async_extract_from_service(service)
list[ScriptEntity], await component.async_extract_from_service(service)
)
for script_entity in script_entities: for script_entity in script_entities:
await script_entity.async_turn_on( await script_entity.async_turn_on(
variables=variables, context=service.context, wait=False variables=variables, context=service.context, wait=False
@ -216,9 +214,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def turn_off_service(service: ServiceCall) -> None: async def turn_off_service(service: ServiceCall) -> None:
"""Cancel a script.""" """Cancel a script."""
# Stopping a script is ok to be done in parallel # Stopping a script is ok to be done in parallel
script_entities: list[ScriptEntity] = cast( script_entities = await component.async_extract_from_service(service)
list[ScriptEntity], await component.async_extract_from_service(service)
)
if not script_entities: if not script_entities:
return return
@ -232,9 +228,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def toggle_service(service: ServiceCall) -> None: async def toggle_service(service: ServiceCall) -> None:
"""Toggle a script.""" """Toggle a script."""
script_entities: list[ScriptEntity] = cast( script_entities = await component.async_extract_from_service(service)
list[ScriptEntity], await component.async_extract_from_service(service)
)
for script_entity in script_entities: for script_entity in script_entities:
await script_entity.async_toggle(context=service.context, wait=False) await script_entity.async_toggle(context=service.context, wait=False)

View file

@ -106,7 +106,7 @@ RELOAD_SERVICE_SCHEMA = vol.Schema({})
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up an input select.""" """Set up an input select."""
component = EntityComponent(_LOGGER, DOMAIN, hass) component = EntityComponent[Timer](_LOGGER, DOMAIN, hass)
id_manager = collection.IDManager() id_manager = collection.IDManager()
yaml_collection = collection.YamlCollection( yaml_collection = collection.YamlCollection(

View file

@ -185,7 +185,7 @@ class ZoneStorageCollection(collection.StorageCollection):
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up configured zones as well as Home Assistant zone if necessary.""" """Set up configured zones as well as Home Assistant zone if necessary."""
component = entity_component.EntityComponent(_LOGGER, DOMAIN, hass) component = entity_component.EntityComponent[Zone](_LOGGER, DOMAIN, hass)
id_manager = collection.IDManager() id_manager = collection.IDManager()
yaml_collection = collection.IDLessCollection( yaml_collection = collection.IDLessCollection(

View file

@ -7,7 +7,7 @@ from datetime import timedelta
from itertools import chain from itertools import chain
import logging import logging
from types import ModuleType from types import ModuleType
from typing import Any from typing import Any, Generic, TypeVar
import voluptuous as vol import voluptuous as vol
@ -30,6 +30,8 @@ from .typing import ConfigType, DiscoveryInfoType
DEFAULT_SCAN_INTERVAL = timedelta(seconds=15) DEFAULT_SCAN_INTERVAL = timedelta(seconds=15)
DATA_INSTANCES = "entity_components" DATA_INSTANCES = "entity_components"
_EntityT = TypeVar("_EntityT", bound=entity.Entity)
@bind_hass @bind_hass
async def async_update_entity(hass: HomeAssistant, entity_id: str) -> None: async def async_update_entity(hass: HomeAssistant, entity_id: str) -> None:
@ -52,7 +54,7 @@ async def async_update_entity(hass: HomeAssistant, entity_id: str) -> None:
await entity_obj.async_update_ha_state(True) await entity_obj.async_update_ha_state(True)
class EntityComponent: class EntityComponent(Generic[_EntityT]):
"""The EntityComponent manages platforms that manages entities. """The EntityComponent manages platforms that manages entities.
This class has the following responsibilities: This class has the following responsibilities:
@ -86,18 +88,19 @@ class EntityComponent:
hass.data.setdefault(DATA_INSTANCES, {})[domain] = self hass.data.setdefault(DATA_INSTANCES, {})[domain] = self
@property @property
def entities(self) -> Iterable[entity.Entity]: def entities(self) -> Iterable[_EntityT]:
"""Return an iterable that returns all entities.""" """Return an iterable that returns all entities."""
return chain.from_iterable( return chain.from_iterable(
platform.entities.values() for platform in self._platforms.values() platform.entities.values() # type: ignore[misc]
for platform in self._platforms.values()
) )
def get_entity(self, entity_id: str) -> entity.Entity | None: def get_entity(self, entity_id: str) -> _EntityT | None:
"""Get an entity.""" """Get an entity."""
for platform in self._platforms.values(): for platform in self._platforms.values():
entity_obj = platform.entities.get(entity_id) entity_obj = platform.entities.get(entity_id)
if entity_obj is not None: if entity_obj is not None:
return entity_obj return entity_obj # type: ignore[return-value]
return None return None
def setup(self, config: ConfigType) -> None: def setup(self, config: ConfigType) -> None:
@ -176,14 +179,14 @@ class EntityComponent:
async def async_extract_from_service( async def async_extract_from_service(
self, service_call: ServiceCall, expand_group: bool = True self, service_call: ServiceCall, expand_group: bool = True
) -> list[entity.Entity]: ) -> list[_EntityT]:
"""Extract all known and available entities from a service call. """Extract all known and available entities from a service call.
Will return an empty list if entities specified but unknown. Will return an empty list if entities specified but unknown.
This method must be run in the event loop. This method must be run in the event loop.
""" """
return await service.async_extract_entities( return await service.async_extract_entities( # type: ignore[return-value]
self.hass, self.entities, service_call, expand_group self.hass, self.entities, service_call, expand_group
) )

View file

@ -14,6 +14,7 @@ from homeassistant.loader import async_get_integration
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from . import config_per_platform from . import config_per_platform
from .entity import Entity
from .entity_component import EntityComponent from .entity_component import EntityComponent
from .entity_platform import EntityPlatform, async_get_platforms from .entity_platform import EntityPlatform, async_get_platforms
from .service import async_register_admin_service from .service import async_register_admin_service
@ -120,7 +121,7 @@ async def _async_setup_platform(
) )
return return
entity_component: EntityComponent = hass.data[integration_platform] entity_component: EntityComponent[Entity] = hass.data[integration_platform]
tasks = [ tasks = [
entity_component.async_setup_platform(integration_name, p_config) entity_component.async_setup_platform(integration_name, p_config)
for p_config in platform_configs for p_config in platform_configs